Repository: smallstep/certificates Branch: master Commit: 6e8ec6140523 Files: 543 Total size: 4.5 MB Directory structure: gitextract_pt61h5ml/ ├── .VERSION ├── .dockerignore ├── .gitattributes ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug-report.yml │ │ ├── config.yml │ │ ├── documentation-request.md │ │ └── enhancement.md │ ├── PULL_REQUEST_TEMPLATE │ ├── dependabot.yml │ ├── workflows/ │ │ ├── actionci.yml │ │ ├── ci.yml │ │ ├── code-scan-cron.yml │ │ ├── dependabot-auto-merge.yml │ │ ├── publish-packages.yml │ │ ├── release.yml │ │ └── triage.yml │ └── zizmor.yml ├── .gitignore ├── .goreleaser.yml ├── .version.sh ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── acme/ │ ├── account.go │ ├── account_test.go │ ├── api/ │ │ ├── account.go │ │ ├── account_test.go │ │ ├── eab.go │ │ ├── eab_test.go │ │ ├── handler.go │ │ ├── handler_test.go │ │ ├── middleware.go │ │ ├── middleware_test.go │ │ ├── order.go │ │ ├── order_test.go │ │ ├── revoke.go │ │ ├── revoke_test.go │ │ └── wire_integration_test.go │ ├── authorization.go │ ├── authorization_test.go │ ├── certificate.go │ ├── challenge.go │ ├── challenge_test.go │ ├── challenge_tpmsimulator_test.go │ ├── challenge_wire_test.go │ ├── client.go │ ├── common.go │ ├── db/ │ │ └── nosql/ │ │ ├── account.go │ │ ├── account_test.go │ │ ├── authz.go │ │ ├── authz_test.go │ │ ├── certificate.go │ │ ├── certificate_test.go │ │ ├── challenge.go │ │ ├── challenge_test.go │ │ ├── eab.go │ │ ├── eab_test.go │ │ ├── nonce.go │ │ ├── nonce_test.go │ │ ├── nosql.go │ │ ├── nosql_test.go │ │ ├── order.go │ │ ├── order_test.go │ │ ├── wire.go │ │ └── wire_test.go │ ├── db.go │ ├── db_test.go │ ├── errors.go │ ├── errors_test.go │ ├── linker.go │ ├── linker_test.go │ ├── nonce.go │ ├── order.go │ ├── order_test.go │ ├── status.go │ └── wire/ │ ├── id.go │ └── id_test.go ├── api/ │ ├── api.go │ ├── api_test.go │ ├── crl.go │ ├── crl_test.go │ ├── log/ │ │ ├── log.go │ │ └── log_test.go │ ├── 
models/ │ │ └── scep.go │ ├── read/ │ │ ├── read.go │ │ └── read_test.go │ ├── rekey.go │ ├── render/ │ │ ├── render.go │ │ └── render_test.go │ ├── renew.go │ ├── revoke.go │ ├── revoke_test.go │ ├── sign.go │ ├── ssh.go │ ├── sshRekey.go │ ├── sshRenew.go │ ├── sshRevoke.go │ └── ssh_test.go ├── authority/ │ ├── admin/ │ │ ├── api/ │ │ │ ├── acme.go │ │ │ ├── acme_test.go │ │ │ ├── admin.go │ │ │ ├── admin_test.go │ │ │ ├── handler.go │ │ │ ├── middleware.go │ │ │ ├── middleware_test.go │ │ │ ├── policy.go │ │ │ ├── policy_test.go │ │ │ ├── provisioner.go │ │ │ ├── provisioner_test.go │ │ │ ├── webhook.go │ │ │ └── webhook_test.go │ │ ├── db/ │ │ │ └── nosql/ │ │ │ ├── admin.go │ │ │ ├── admin_test.go │ │ │ ├── nosql.go │ │ │ ├── policy.go │ │ │ ├── policy_test.go │ │ │ ├── provisioner.go │ │ │ └── provisioner_test.go │ │ ├── db.go │ │ └── errors.go │ ├── administrator/ │ │ └── collection.go │ ├── admins.go │ ├── authority.go │ ├── authority_test.go │ ├── authorize.go │ ├── authorize_test.go │ ├── config/ │ │ ├── config.go │ │ ├── config_test.go │ │ ├── ssh.go │ │ ├── ssh_test.go │ │ ├── tls_options.go │ │ ├── tls_options_test.go │ │ ├── types.go │ │ └── types_test.go │ ├── config.go │ ├── export.go │ ├── http_client.go │ ├── http_client_test.go │ ├── internal/ │ │ └── constraints/ │ │ ├── constraints.go │ │ ├── constraints_test.go │ │ └── verify.go │ ├── linkedca.go │ ├── meter.go │ ├── options.go │ ├── policy/ │ │ ├── engine.go │ │ ├── options.go │ │ ├── options_test.go │ │ ├── policy.go │ │ └── policy_test.go │ ├── policy.go │ ├── policy_test.go │ ├── poolhttp/ │ │ ├── poolhttp.go │ │ └── poolhttp_test.go │ ├── provisioner/ │ │ ├── acme.go │ │ ├── acme_118_test.go │ │ ├── acme_119_test.go │ │ ├── acme_test.go │ │ ├── aws.go │ │ ├── aws_certificates.pem │ │ ├── aws_test.go │ │ ├── azure.go │ │ ├── azure_test.go │ │ ├── claims.go │ │ ├── claims_test.go │ │ ├── collection.go │ │ ├── collection_test.go │ │ ├── controller.go │ │ ├── controller_test.go │ │ ├── 
duration.go │ │ ├── duration_test.go │ │ ├── extension.go │ │ ├── extension_test.go │ │ ├── gcp/ │ │ │ ├── projectvalidator.go │ │ │ └── projectvalidator_test.go │ │ ├── gcp.go │ │ ├── gcp_test.go │ │ ├── jwk.go │ │ ├── jwk_test.go │ │ ├── k8sSA.go │ │ ├── k8sSA_test.go │ │ ├── keystore.go │ │ ├── keystore_test.go │ │ ├── method.go │ │ ├── nebula.go │ │ ├── nebula_test.go │ │ ├── noop.go │ │ ├── noop_test.go │ │ ├── oidc.go │ │ ├── oidc_test.go │ │ ├── options.go │ │ ├── options_test.go │ │ ├── policy.go │ │ ├── provisioner.go │ │ ├── provisioner_test.go │ │ ├── scep.go │ │ ├── scep_test.go │ │ ├── sign_options.go │ │ ├── sign_options_test.go │ │ ├── sign_ssh_options.go │ │ ├── sign_ssh_options_test.go │ │ ├── ssh_options.go │ │ ├── ssh_options_test.go │ │ ├── ssh_test.go │ │ ├── sshpop.go │ │ ├── sshpop_test.go │ │ ├── testdata/ │ │ │ ├── certs/ │ │ │ │ ├── apple-att-ca.crt │ │ │ │ ├── aws-test.crt │ │ │ │ ├── aws.crt │ │ │ │ ├── bad-extension.crt │ │ │ │ ├── bar.pub │ │ │ │ ├── ecdsa.csr │ │ │ │ ├── ed25519.csr │ │ │ │ ├── foo.crt │ │ │ │ ├── foo.pub │ │ │ │ ├── good-extension.crt │ │ │ │ ├── root_ca.crt │ │ │ │ ├── rsa.csr │ │ │ │ ├── short-rsa.csr │ │ │ │ ├── ssh_host_ca_key.pub │ │ │ │ ├── ssh_user_ca_key.pub │ │ │ │ ├── x5c-leaf.crt │ │ │ │ └── yubico-piv-ca.crt │ │ │ ├── secrets/ │ │ │ │ ├── bar.priv │ │ │ │ ├── bar_host_ssh_key │ │ │ │ ├── ecdsa.key │ │ │ │ ├── ed25519.key │ │ │ │ ├── foo.key │ │ │ │ ├── foo.priv │ │ │ │ ├── foo_user_ssh_key │ │ │ │ ├── rsa.key │ │ │ │ ├── ssh_host_ca_key │ │ │ │ ├── ssh_user_ca_key │ │ │ │ └── x5c-leaf.key │ │ │ └── templates/ │ │ │ └── cr.tpl │ │ ├── timeduration.go │ │ ├── timeduration_test.go │ │ ├── utils_test.go │ │ ├── webhook.go │ │ ├── webhook_test.go │ │ ├── wire/ │ │ │ ├── dpop_options.go │ │ │ ├── dpop_options_test.go │ │ │ ├── oidc_options.go │ │ │ ├── oidc_options_test.go │ │ │ ├── wire_options.go │ │ │ └── wire_options_test.go │ │ ├── x5c.go │ │ └── x5c_test.go │ ├── provisioners.go │ ├── provisioners_test.go 
│ ├── root.go │ ├── root_test.go │ ├── ssh.go │ ├── ssh_test.go │ ├── testdata/ │ │ ├── certs/ │ │ │ ├── badsig.csr │ │ │ ├── foo.crt │ │ │ ├── foo.csr │ │ │ ├── intermediate_ca.crt │ │ │ ├── provisioner-not-found.crt │ │ │ ├── renew-disabled.crt │ │ │ ├── root_ca.crt │ │ │ ├── ssh_host_ca_key.pub │ │ │ └── ssh_user_ca_key.pub │ │ ├── scep/ │ │ │ ├── intermediate.crt │ │ │ ├── intermediate.key │ │ │ ├── root.crt │ │ │ └── root.key │ │ ├── secrets/ │ │ │ ├── foo.key │ │ │ ├── intermediate_ca_key │ │ │ ├── max_priv.jwk │ │ │ ├── max_pub.jwk │ │ │ ├── provisioner-not-found.key │ │ │ ├── renew-disabled.key │ │ │ ├── ssh_host_ca_key │ │ │ ├── ssh_user_ca_key │ │ │ ├── step_cli_key │ │ │ ├── step_cli_key.public │ │ │ ├── step_cli_key_priv.jwk │ │ │ └── step_cli_key_pub.jwk │ │ └── templates/ │ │ ├── badjsonsyntax.tpl │ │ ├── badjsonvalue.tpl │ │ ├── ca.tpl │ │ ├── config.tpl │ │ ├── error.tpl │ │ ├── fail.tpl │ │ ├── include.tpl │ │ ├── known_hosts.tpl │ │ ├── sshd_config.tpl │ │ └── step_includes.tpl │ ├── tls.go │ ├── tls_test.go │ ├── version.go │ ├── webhook.go │ └── webhook_test.go ├── autocert/ │ └── README.md ├── ca/ │ ├── acmeClient.go │ ├── acmeClient_test.go │ ├── adminClient.go │ ├── bootstrap.go │ ├── bootstrap_test.go │ ├── ca.go │ ├── ca_test.go │ ├── client/ │ │ └── requestid.go │ ├── client.go │ ├── client_test.go │ ├── identity/ │ │ ├── client.go │ │ ├── client_test.go │ │ ├── identity.go │ │ ├── identity_test.go │ │ └── testdata/ │ │ ├── certs/ │ │ │ ├── intermediate_ca.crt │ │ │ ├── root_ca.crt │ │ │ └── server.crt │ │ ├── config/ │ │ │ ├── badIdentity.json │ │ │ ├── badca.json │ │ │ ├── badroot.json │ │ │ ├── ca.json │ │ │ ├── defaults.json │ │ │ ├── fail.json │ │ │ ├── identity.json │ │ │ └── tunnel.json │ │ ├── identity/ │ │ │ ├── expired.crt │ │ │ ├── identity.crt │ │ │ ├── identity_key │ │ │ └── not_before.crt │ │ └── secrets/ │ │ ├── intermediate_ca_key │ │ ├── root_ca_key │ │ └── server_key │ ├── mutable_tls_config.go │ ├── provisioner.go │ ├── 
provisioner_test.go │ ├── renew.go │ ├── signal.go │ ├── testdata/ │ │ ├── ca.json │ │ ├── federated-ca.json │ │ ├── rotate-ca-0.json │ │ ├── rotate-ca-1.json │ │ ├── rotate-ca-2.json │ │ ├── rotate-ca-3.json │ │ ├── rotated/ │ │ │ ├── intermediate_ca.crt │ │ │ ├── intermediate_ca_key │ │ │ ├── root_ca.crt │ │ │ └── root_ca_key │ │ ├── rsaca.json │ │ └── secrets/ │ │ ├── federated_ca.crt │ │ ├── intermediate_ca.crt │ │ ├── intermediate_ca_key │ │ ├── ott_key │ │ ├── ott_key.public │ │ ├── ott_mariano_priv.jwk │ │ ├── ott_mariano_pub.jwk │ │ ├── root_ca.crt │ │ ├── root_ca_key │ │ ├── rsa_intermediate_ca.crt │ │ ├── rsa_intermediate_ca_key │ │ ├── rsa_root_ca.crt │ │ ├── rsa_root_ca_key │ │ ├── step_cli_key │ │ ├── step_cli_key.public │ │ ├── step_cli_key_priv.jwk │ │ └── step_cli_key_pub.jwk │ ├── tls.go │ ├── tls_options.go │ ├── tls_options_test.go │ └── tls_test.go ├── cas/ │ ├── apiv1/ │ │ ├── extension.go │ │ ├── extension_test.go │ │ ├── options.go │ │ ├── options_test.go │ │ ├── registry.go │ │ ├── registry_test.go │ │ ├── requests.go │ │ ├── services.go │ │ └── services_test.go │ ├── cas.go │ ├── cas_test.go │ ├── cloudcas/ │ │ ├── certificate.go │ │ ├── certificate_test.go │ │ ├── cloudcas.go │ │ ├── cloudcas_test.go │ │ ├── mock_client_test.go │ │ └── mock_operation_server_test.go │ ├── softcas/ │ │ ├── softcas.go │ │ └── softcas_test.go │ ├── stepcas/ │ │ ├── issuer.go │ │ ├── issuer_test.go │ │ ├── jwk_issuer.go │ │ ├── jwk_issuer_test.go │ │ ├── stepcas.go │ │ ├── stepcas_test.go │ │ ├── x5c_issuer.go │ │ └── x5c_issuer_test.go │ └── vaultcas/ │ ├── auth/ │ │ ├── approle/ │ │ │ ├── approle.go │ │ │ └── approle_test.go │ │ ├── aws/ │ │ │ ├── aws.go │ │ │ └── aws_test.go │ │ └── kubernetes/ │ │ ├── kubernetes.go │ │ ├── kubernetes_test.go │ │ └── token │ ├── vaultcas.go │ └── vaultcas_test.go ├── cmd/ │ └── step-ca/ │ └── main.go ├── commands/ │ ├── app.go │ ├── export.go │ └── onboard.go ├── cosign.pub ├── db/ │ ├── db.go │ ├── db_test.go │ ├── 
simple.go │ └── simple_test.go ├── debian/ │ └── copyright ├── docker/ │ ├── Dockerfile │ ├── Dockerfile.hsm │ └── entrypoint.sh ├── errs/ │ ├── error.go │ └── errors_test.go ├── examples/ │ ├── README.md │ ├── ansible/ │ │ ├── smallstep-certs/ │ │ │ ├── defaults/ │ │ │ │ └── main.yml │ │ │ └── tasks/ │ │ │ └── main.yml │ │ ├── smallstep-install/ │ │ │ ├── defaults/ │ │ │ │ └── main.yml │ │ │ └── tasks/ │ │ │ └── main.yml │ │ └── smallstep-ssh/ │ │ ├── defaults/ │ │ │ └── main.yml │ │ └── tasks/ │ │ └── main.yml │ ├── basic-client/ │ │ └── client.go │ ├── basic-federation/ │ │ ├── client/ │ │ │ └── main.go │ │ ├── pki/ │ │ │ ├── cloud/ │ │ │ │ ├── certs/ │ │ │ │ │ ├── intermediate_ca.crt │ │ │ │ │ ├── kubernetes_root_ca.crt │ │ │ │ │ └── root_ca.crt │ │ │ │ ├── config/ │ │ │ │ │ ├── ca.federated.json │ │ │ │ │ └── ca.json │ │ │ │ └── secrets/ │ │ │ │ ├── intermediate_ca_key │ │ │ │ └── root_ca_key │ │ │ └── kubernetes/ │ │ │ ├── certs/ │ │ │ │ ├── cloud_root_ca.crt │ │ │ │ ├── intermediate_ca.crt │ │ │ │ └── root_ca.crt │ │ │ ├── config/ │ │ │ │ ├── ca.federated.json │ │ │ │ └── ca.json │ │ │ └── secrets/ │ │ │ ├── intermediate_ca_key │ │ │ └── root_ca_key │ │ └── server/ │ │ └── main.go │ ├── bootstrap-client/ │ │ └── client.go │ ├── bootstrap-mtls-server/ │ │ └── server.go │ ├── bootstrap-tls-server/ │ │ └── server.go │ ├── docker/ │ │ ├── Makefile │ │ ├── ca/ │ │ │ ├── Dockerfile │ │ │ └── pki/ │ │ │ ├── config/ │ │ │ │ └── ca.json │ │ │ └── secrets/ │ │ │ ├── intermediate_ca.crt │ │ │ ├── intermediate_ca_key │ │ │ ├── root_ca.crt │ │ │ └── root_ca_key │ │ ├── docker-compose.yml │ │ ├── nginx/ │ │ │ ├── Dockerfile │ │ │ ├── certwatch.sh │ │ │ ├── entrypoint.sh │ │ │ └── site.conf │ │ ├── password.txt │ │ └── renewer/ │ │ ├── Dockerfile │ │ ├── crontab │ │ └── entrypoint.sh │ ├── pki/ │ │ ├── config/ │ │ │ └── ca.json │ │ └── secrets/ │ │ ├── intermediate_ca.crt │ │ ├── intermediate_ca_key │ │ ├── root_ca.crt │ │ └── root_ca_key │ └── puppet/ │ ├── ca.json.erb │ 
├── defaults.json.erb │ ├── step.pp │ ├── step_ca.pp │ └── tls_server.pp ├── go.mod ├── go.sum ├── internal/ │ ├── cast/ │ │ ├── cast.go │ │ └── cast_test.go │ ├── httptransport/ │ │ └── httptransport.go │ ├── metrix/ │ │ └── meter.go │ └── userid/ │ └── userid.go ├── logging/ │ ├── clf.go │ ├── handler.go │ ├── handler_test.go │ ├── logger.go │ └── responselogger.go ├── middleware/ │ └── requestid/ │ ├── requestid.go │ └── requestid_test.go ├── monitoring/ │ └── monitoring.go ├── pki/ │ ├── helm.go │ ├── helm_test.go │ ├── pki.go │ ├── pki_test.go │ ├── templates.go │ └── testdata/ │ └── helm/ │ ├── simple.yml │ ├── with-acme-and-duplicate-provisioner-name.yml │ ├── with-acme.yml │ ├── with-admin.yml │ ├── with-provisioner.yml │ ├── with-ssh-and-acme.yml │ ├── with-ssh-and-duplicate-provisioner-name.yml │ └── with-ssh.yml ├── policy/ │ ├── engine.go │ ├── engine_test.go │ ├── options.go │ ├── options_test.go │ ├── ssh.go │ ├── validate.go │ └── x509.go ├── scep/ │ ├── api/ │ │ ├── api.go │ │ └── api_test.go │ ├── authority.go │ ├── authority_test.go │ ├── options.go │ ├── provisioner.go │ └── scep.go ├── scripts/ │ ├── README.md │ ├── badger-migration/ │ │ └── main.go │ ├── install-step-ra.sh │ ├── package-repo-import.sh │ └── package-upload.sh ├── server/ │ └── server.go ├── systemd/ │ ├── README.md │ └── step-ca.service ├── templates/ │ ├── templates.go │ ├── templates_test.go │ ├── values.go │ └── values_test.go ├── test/ │ └── integration/ │ ├── requestid_test.go │ └── scep/ │ ├── common_test.go │ ├── decrypter_cas_test.go │ ├── decrypter_test.go │ ├── internal/ │ │ └── x509/ │ │ ├── debug.go │ │ ├── doc.go │ │ ├── oid.go │ │ ├── parser.go │ │ ├── pkcs1.go │ │ ├── verify.go │ │ └── x509.go │ ├── regular_cas_test.go │ ├── regular_test.go │ ├── windows_go1.23_test.go │ └── windows_test.go ├── tools.go └── webhook/ ├── options.go ├── options_test.go └── types.go ================================================ FILE CONTENTS 
================================================ ================================================ FILE: .VERSION ================================================ $Format:%d$ ================================================ FILE: .dockerignore ================================================ bin coverage.txt *.test *.out .travis-releases ================================================ FILE: .gitattributes ================================================ .VERSION export-subst ================================================ FILE: .github/ISSUE_TEMPLATE/bug-report.yml ================================================ name: Bug Report description: File a bug report title: "[Bug]: " labels: ["bug", "needs triage"] body: - type: markdown attributes: value: | Thanks for taking the time to fill out this bug report! - type: textarea id: steps attributes: label: Steps to Reproduce description: Tell us how to reproduce this issue. placeholder: These are the steps! validations: required: true - type: textarea id: your-env attributes: label: Your Environment value: |- * OS - * `step-ca` Version - validations: required: true - type: textarea id: expected-behavior attributes: label: Expected Behavior description: What did you expect to happen? validations: required: true - type: textarea id: actual-behavior attributes: label: Actual Behavior description: What happens instead? validations: required: true - type: textarea id: context attributes: label: Additional Context description: Add any other context about the problem here. validations: required: false - type: textarea id: contributing attributes: label: Contributing value: | Vote on this issue by adding a 👍 reaction. To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already). 
validations: required: false ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: true contact_links: - name: Ask on Discord url: https://discord.gg/7xgjhVAg6g about: You can ask for help here! - name: Want to contribute to step certificates? url: https://github.com/smallstep/certificates/blob/master/CONTRIBUTING.md about: Be sure to read contributing guidelines! ================================================ FILE: .github/ISSUE_TEMPLATE/documentation-request.md ================================================ --- name: Documentation Request about: Request documentation for a feature title: '[Docs]:' labels: docs, needs triage assignees: '' --- ## Hello! - Vote on this issue by adding a 👍 reaction - If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.) ## Affected area/feature ================================================ FILE: .github/ISSUE_TEMPLATE/enhancement.md ================================================ --- name: Enhancement about: Suggest an enhancement to step-ca title: '' labels: enhancement, needs triage assignees: '' --- ## Hello! - Vote on this issue by adding a 👍 reaction - If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.) ## Issue details ## Why is this needed? ================================================ FILE: .github/PULL_REQUEST_TEMPLATE ================================================ #### Name of feature: #### Pain or issue this feature alleviates: #### Why is this important to the project (if not answered above): #### Is there documentation on how to use this feature? If so, where? #### In what environments or workflows is this feature supported? #### In what environments or workflows is this feature explicitly NOT supported (if any)? #### Supporting links/other PRs/issues: 💔Thank you! 
================================================ FILE: .github/dependabot.yml ================================================ # To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates version: 2 updates: - package-ecosystem: "gomod" # See documentation for possible values directory: "/" # Location of package manifests schedule: interval: "weekly" - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" ================================================ FILE: .github/workflows/actionci.yml ================================================ name: Action CI on: push: tags-ignore: - 'v*' branches: - "master" pull_request: workflow_call: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: actionci: permissions: contents: read security-events: write uses: smallstep/workflows/.github/workflows/actionci.yml@main with: zizmor-advanced-security: true secrets: inherit ================================================ FILE: .github/workflows/ci.yml ================================================ name: CI on: push: tags-ignore: - 'v*' branches: - "master" pull_request: workflow_call: secrets: CODECOV_TOKEN: required: true concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true permissions: contents: read jobs: ci: permissions: actions: read contents: read security-events: write uses: smallstep/workflows/.github/workflows/goCI.yml@main with: only-latest-golang: false os-dependencies: 'libpcsclite-dev' run-codeql: true test-command: 'V=1 make test' secrets: inherit ================================================ FILE: .github/workflows/code-scan-cron.yml 
================================================ on: schedule: - cron: '0 0 * * *' permissions: actions: read contents: read security-events: write jobs: code-scan: uses: smallstep/workflows/.github/workflows/code-scan.yml@main ================================================ FILE: .github/workflows/dependabot-auto-merge.yml ================================================ name: Dependabot auto-merge on: pull_request permissions: contents: write pull-requests: write jobs: dependabot-auto-merge: uses: smallstep/workflows/.github/workflows/dependabot-auto-merge.yml@main secrets: inherit ================================================ FILE: .github/workflows/publish-packages.yml ================================================ name: Publish to packages.smallstep.com # Independently publish packages to Red Hat (RPM) and Debian (DEB) repositories # without running a full release. Downloads packages from GitHub releases, # uploads to GCS, and imports to Artifact Registry. # # Usage (CLI): # gh workflow run publish-packages.yml -f tag=v0.28.0 on: workflow_dispatch: inputs: tag: description: 'Git tag to publish (e.g., v0.28.0)' required: true type: string jobs: publish: runs-on: ubuntu-latest permissions: id-token: write contents: read steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ inputs.tag }} fetch-depth: 0 - name: Extract version id: version run: echo "version=${TAG#v}" >> "$GITHUB_OUTPUT" env: TAG: ${{ inputs.tag }} - name: Is Pre-release id: is_prerelease run: | if [[ "$TAG" == *"-rc"* ]]; then echo "is_prerelease=true" >> "$GITHUB_OUTPUT" else echo "is_prerelease=false" >> "$GITHUB_OUTPUT" fi env: TAG: ${{ inputs.tag }} - name: Authenticate to Google Cloud uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0 with: workload_identity_provider: ${{ secrets.GOOGLE_CLOUD_WORKLOAD_IDENTITY_PROVIDER }} service_account: ${{ secrets.GOOGLE_CLOUD_GITHUB_SERVICE_ACCOUNT }} - name: 
Set up Cloud SDK uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1 with: project_id: ${{ secrets.GOOGLE_CLOUD_PACKAGES_PROJECT_ID }} - name: Download packages from GitHub release run: | mkdir -p dist gh release download "$TAG" --pattern "*${VERSION}*.deb" --pattern "*${VERSION}*.rpm" --dir dist env: TAG: ${{ inputs.tag }} VERSION: ${{ steps.version.outputs.version }} GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Upload packages to GCS run: | for pkg in dist/*.deb dist/*.rpm; do ./scripts/package-upload.sh "$pkg" step-ca ${{ steps.version.outputs.version }} done - name: Import packages to Artifact Registry run: ./scripts/package-repo-import.sh step-ca ${{ steps.version.outputs.version }} env: IS_PRERELEASE: ${{ steps.is_prerelease.outputs.is_prerelease }} ================================================ FILE: .github/workflows/release.yml ================================================ name: Create Release & Upload Assets on: push: # Sequence of patterns matched against refs/tags tags: - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 permissions: contents: write jobs: ci: permissions: contents: read actions: read security-events: write uses: smallstep/certificates/.github/workflows/ci.yml@master secrets: inherit create_release: name: Create Release permissions: contents: write needs: ci runs-on: ubuntu-latest env: DOCKER_IMAGE: smallstep/step-ca outputs: version: ${{ steps.extract-tag.outputs.VERSION }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} docker_tags: ${{ env.DOCKER_TAGS }} docker_tags_hsm: ${{ env.DOCKER_TAGS_HSM }} steps: - name: Is Pre-release id: is_prerelease env: REF: ${{ github.ref }} run: | set +e echo "${REF}" | grep "\-rc.*" OUT=$? 
if [ $OUT -eq 0 ]; then IS_PRERELEASE=true; else IS_PRERELEASE=false; fi echo "IS_PRERELEASE=${IS_PRERELEASE}" >> "${GITHUB_OUTPUT}" - name: Extract Tag Names id: extract-tag run: | VERSION=${GITHUB_REF#refs/tags/v} echo "VERSION=${VERSION}" >> "${GITHUB_OUTPUT}" echo "DOCKER_TAGS=${{ env.DOCKER_IMAGE }}:${VERSION}" >> "${GITHUB_ENV}" echo "DOCKER_TAGS_HSM=${{ env.DOCKER_IMAGE }}:${VERSION}-hsm" >> "${GITHUB_ENV}" - name: Add Latest Tag if: steps.is_prerelease.outputs.IS_PRERELEASE == 'false' run: | echo "DOCKER_TAGS=${{ env.DOCKER_TAGS }},${{ env.DOCKER_IMAGE }}:latest" >> "${GITHUB_ENV}" echo "DOCKER_TAGS_HSM=${{ env.DOCKER_TAGS_HSM }},${{ env.DOCKER_IMAGE }}:hsm" >> "${GITHUB_ENV}" - name: Create Release id: create_release uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2.6.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: ${{ github.ref_name }} name: Release ${{ github.ref_name }} draft: false prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} goreleaser: needs: create_release permissions: id-token: write contents: write packages: write uses: smallstep/workflows/.github/workflows/goreleaser.yml@main with: enable-packages-upload: true is-prerelease: ${{ needs.create_release.outputs.is_prerelease == 'true' }} secrets: inherit build_upload_docker: name: Build & Upload Docker Images needs: create_release permissions: id-token: write contents: write uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main with: platforms: linux/amd64,linux/386,linux/arm,linux/arm64 tags: ${{ needs.create_release.outputs.docker_tags }} docker_image: smallstep/step-ca docker_file: docker/Dockerfile secrets: inherit build_upload_docker_hsm: name: Build & Upload HSM Enabled Docker Images needs: create_release permissions: id-token: write contents: write uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main with: platforms: linux/amd64,linux/386,linux/arm,linux/arm64 tags: ${{ 
needs.create_release.outputs.docker_tags_hsm }} docker_image: smallstep/step-ca docker_file: docker/Dockerfile.hsm secrets: inherit ================================================ FILE: .github/workflows/triage.yml ================================================ name: Add Issues and PRs to Triage on: issues: types: - opened - reopened pull_request_target: types: - opened - reopened permissions: pull-requests: write issues: write jobs: triage: uses: smallstep/workflows/.github/workflows/triage.yml@main secrets: inherit ================================================ FILE: .github/zizmor.yml ================================================ rules: unpinned-uses: config: policies: "smallstep/*": ref-pin secrets-inherit: disable: true ref-confusion: disable: true dangerous-triggers: ignore: - triage.yml ================================================ FILE: .gitignore ================================================ # Binaries for programs and plugins /bin *.exe *.exe~ *.dll *.so *.dylib # Go Workspaces go.work go.work.sum # Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out # Others *.swp .releases coverage.txt output vendor dist/ .idea .envrc # Packages files 0x889B19391F774443-Certify.key gha-creds-*.json ================================================ FILE: .goreleaser.yml ================================================ # Documentation: https://goreleaser.com/customization/ # yaml-language-server: $schema=https://goreleaser.com/static/schema-pro.json project_name: step-ca version: 2 variables: packageName: step-ca packageRelease: 1 # Manually update release: in the nfpm section to match this value if you change this before: hooks: # You may remove this if you don't use go modules. - go mod download after: hooks: # This script depends on IS_PRERELEASE env being set. This is set by CI in the Is Pre-release step. 
- cmd: bash scripts/package-repo-import.sh {{ .Var.packageName }} {{ .Version }} output: true builds: - id: step-ca env: - CGO_ENABLED=0 targets: - darwin_amd64 - darwin_arm64 - freebsd_amd64 - linux_386 - linux_amd64 - linux_arm64 - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 flags: - -trimpath main: ./cmd/step-ca/main.go binary: step-ca ldflags: - -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}} archives: - &ARCHIVE # Can be used to change the archive formats for specific GOOSs. # Most common use case is to archive as zip on Windows. # Default is empty. name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}" format_overrides: - goos: windows format: zip files: - README.md - LICENSE allow_different_binary_count: true - << : *ARCHIVE id: unversioned name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}" wrap_in_directory: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}" nfpms: # Configure nFPM for .deb and .rpm releases # # See https://nfpm.goreleaser.com/configuration/ # and https://goreleaser.com/customization/nfpm/ # # Useful tools for debugging .debs: # List file contents: dpkg -c dist/step_...deb # Package metadata: dpkg --info dist/step_....deb # - &NFPM id: packages builds: - step-ca package_name: "{{ .Var.packageName }}" release: "1" file_name_template: >- {{- trimsuffix .ConventionalFileName .ConventionalExtension -}} {{- if and (eq .Arm "6") (eq .ConventionalExtension ".deb") }}6{{ end -}} {{- if not (eq .Amd64 "v1")}}{{ .Amd64 }}{{ end -}} {{- .ConventionalExtension -}} vendor: Smallstep Labs homepage: https://github.com/smallstep/certificates maintainer: Smallstep description: > step-ca is an online certificate authority for secure, automated certificate management. 
license: Apache 2.0 section: utils formats: - deb - rpm priority: optional bindir: /usr/bin contents: - src: debian/copyright dst: /usr/share/doc/step-ca/copyright rpm: signature: key_file: "{{ .Env.GPG_PRIVATE_KEY_FILE }}" deb: signature: key_file: "{{ .Env.GPG_PRIVATE_KEY_FILE }}" type: origin - << : *NFPM id: unversioned file_name_template: "{{ .PackageName }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}" source: enabled: true name_template: '{{ .ProjectName }}_{{ .Version }}' checksum: name_template: 'checksums.txt' extra_files: - glob: ./.releases/* signs: - cmd: cosign signature: "${artifact}.sigstore.json" args: - "sign-blob" - "--bundle=${signature}" - "${artifact}" - "--yes" artifacts: all publishers: - name: Google Cloud Artifact Registry ids: - packages cmd: ./scripts/package-upload.sh {{ abs .ArtifactPath }} {{ .Var.packageName }} {{ .Version }} {{ .Var.packageRelease }} snapshot: name_template: "{{ .Tag }}-next" release: # Repo in which the release will be created. # Default is extracted from the origin remote URL or empty if it's privately hosted. # Note: it can only be one: either github, gitlab or gitea github: owner: smallstep name: certificates # IDs of the archives to use. # Defaults to all. #ids: # - foo # - bar # If set to true, will not auto-publish the release. # Default is false. draft: false # If set to auto, will mark the release as not ready for production # in case there is an indicator for this in the tag e.g. v1.0.0-rc1 # If set to true, will mark the release as not ready for production. # Default is false. prerelease: auto # You can change the name of the release. # Default is `{{.Tag}}` name_template: "Step CA {{ .Tag }} ({{ .Env.RELEASE_DATE }})" # Header template for the release body. # Defaults to empty. 
header: | ## Official Release Artifacts #### Linux - 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz) - 📦 [step-ca_{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}_amd64.deb](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}_amd64.deb) - 📦 [step-ca-{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}.x86_64.rpm](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca-{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}.x86_64.rpm) - 📦 [step-ca_{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}_arm64.deb](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}_arm64.deb) - 📦 [step-ca-{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}.aarch64.rpm](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca-{{ replace .Version "-" "." }}-{{ .Var.packageRelease }}.aarch64.rpm) #### OSX Darwin - 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz) - 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz) #### Windows - 📦 [step-ca_windows_{{ .Version }}_amd64.zip](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip) For more builds across platforms and architectures, see the `Assets` section below. And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation). Don't see the artifact you need? 
Open an issue [here](https://github.com/smallstep/certificates/issues/new/choose). ## Signatures and Checksums `step-ca` uses [sigstore/cosign](https://github.com/sigstore/cosign) for signing and verifying release artifacts. Below is an example using `cosign` to verify a release artifact: ``` cosign verify-blob \ --bundle step-ca_darwin_{{ .Version }}_amd64.tar.gz.sigstore.json \ --certificate-identity-regexp "https://github\.com/smallstep/workflows/.*" \ --certificate-oidc-issuer https://token.actions.githubusercontent.com \ step-ca_darwin_{{ .Version }}_amd64.tar.gz ``` The `checksums.txt` file (in the `Assets` section below) contains a checksum for every artifact in the release. # Footer template for the release body. # Defaults to empty. footer: | ## Thanks! Those were the changes on {{ .Tag }}! Come join us on [Discord](https://discord.gg/X2RKGwEbV9) to ask questions, chat about PKI, or get a sneak peek at the freshest PKI memes. # You can disable this pipe in order to not upload any artifacts. # Defaults to false. #disable: true # You can add extra pre-existing files to the release. # The filename on the release will be the last part of the path (base). If # another file with the same name exists, the latest one found will be used. # Defaults to empty. extra_files: - glob: ./.releases/* #extra_files: # - glob: ./path/to/file.txt # - glob: ./glob/**/to/**/file/**/* # - glob: ./glob/foo/to/bar/file/foobar/override_from_previous winget: - # IDs of the archives to use. # Empty means all IDs. ids: [ default ] # # Default: ProjectName # Templates: allowed name: step-ca # Publisher name. # # Templates: allowed # Required. publisher: Smallstep # Your app's description. # # Templates: allowed # Required. short_description: "A private certificate authority (X.509 & SSH) & ACME server for secure automated certificate management." # Package identifier. # # Default: Publisher.ProjectName # Templates: allowed package_identifier: Smallstep.step-ca # License name. 
# # Templates: allowed # Required. license: "Apache-2.0" # Publisher URL. # # Templates: allowed publisher_url: "https://smallstep.com" # Publisher support URL. # # Templates: allowed publisher_support_url: "https://github.com/smallstep/certificates/discussions" # Privacy URL. # # Templates: allowed privacy_url: "https://smallstep.com/privacy-policy/" # URL which is determined by the given Token (github, gitlab or gitea). # # Default depends on the client. # Templates: allowed url_template: "https://github.com/smallstep/certificates/releases/download/{{ .Tag }}/{{ .ArtifactName }}" # Git author used to commit to the repository. commit_author: name: goreleaserbot email: goreleaser@smallstep.com # The project name and current git tag are used in the format string. # # Templates: allowed commit_msg_template: "{{ .PackageIdentifier }}: {{ .Tag }}" # Your app's homepage. homepage: "https://github.com/smallstep/certificates" # Your app's long description. # # Templates: allowed description: "step-ca is an online certificate authority for secure, automated certificate management. It issues X.509 and SSH certificates using protocols like ACME, OIDC, and SCEP." # License URL. # # Templates: allowed license_url: "https://github.com/smallstep/certificates/blob/master/LICENSE" # Release notes. # # Templates: allowed release_notes: "{{.Changelog}}" # Release notes URL. # # Templates: allowed release_notes_url: "https://github.com/smallstep/certificates/releases/tag/{{ .Tag }}" # Installation notes. # # Templates: allowed installation_notes: "After installation, run 'step-ca --help' to get started. Documentation: https://smallstep.com/docs/step-ca" # Create the PR - for testing skip_upload: auto # Tags. tags: - certificates - smallstep - tls # Repository to push the generated files to. 
repository: owner: smallstep name: winget-pkgs branch: "step-ca-{{.Version}}" # Optionally a token can be provided, if it differs from the token # provided to GoReleaser # Templates: allowed #token: "{{ .Env.GITHUB_PERSONAL_AUTH_TOKEN }}" # Sets up pull request creation instead of just pushing to the given branch. # Make sure the 'branch' property is different from base before enabling # it. # # Since: v1.17 pull_request: # Whether to enable it or not. enabled: true check_boxes: true # Whether to open the PR as a draft or not. # # Default: false # Since: v1.19 # draft: true # Base can also be another repository, in which case the owner and name # above will be used as HEAD, allowing cross-repository pull requests. # # Since: v1.19 base: owner: microsoft name: winget-pkgs branch: master scoops: - ids: [ default ] # Template for the url which is determined by the given Token (github or gitlab) # Default for github is "https://github.com///releases/download/{{ .Tag }}/{{ .ArtifactName }}" # Default for gitlab is "https://gitlab.com///uploads/{{ .ArtifactUploadHash }}/{{ .ArtifactName }}" # Default for gitea is "https://gitea.com///releases/download/{{ .Tag }}/{{ .ArtifactName }}" url_template: "http://github.com/smallstep/certificates/releases/download/{{ .Tag }}/{{ .ArtifactName }}" # Repository to push the app manifest to. repository: owner: smallstep name: scoop-bucket branch: main # Git author used to commit to the repository. # Defaults are shown. commit_author: name: goreleaserbot email: goreleaser@smallstep.com # The project name and current git tag are used in the format string. commit_msg_template: "Scoop update for {{ .ProjectName }} version {{ .Tag }}" # Your app's homepage. # Default is empty. homepage: "https://smallstep.com/docs/step-ca" # Skip uploads for prerelease. skip_upload: auto # Your app's description. # Default is empty. 
description: "A private certificate authority (X.509 & SSH) & ACME server for secure automated certificate management, so you can use TLS everywhere & SSO for SSH." # Your app's license # Default is empty. license: "Apache-2.0" ================================================ FILE: .version.sh ================================================ #!/usr/bin/env sh read -r firstline < .VERSION last_half="${firstline##*tag: }" case "$last_half" in v*) version_string="${last_half%%[,)]*}" ;; esac echo "${version_string:-v0.0.0}" ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## TEMPLATE -- do not alter or remove --- ## [x.y.z] - aaaa-bb-cc ### Added ### Changed ### Deprecated ### Removed ### Fixed ### Security --- ## [0.30.1] - 2026-03-18 - Fix release issue ## [0.30.0] - 2026-03-18 ### Added - Warn when ACME provisioner is configured without a database (smallstep/certificates#2526) - Validate webhooks configured on the ca.json (smallstep/certificates#2570) - Add HTTP transport decorator (smallstep/certificates#2533) ### Changed - Upgrade HSM-enabled Docker images from Debian Bookworm (12) to Debian Trixie (13) (smallstep/certificates#2493) - Use JSON array format for Dockerfile's `CMD` instruction. This prevents shell interpolation of environment variables like `CONFIGPATH` and `PWDPATH`, ensuring consistent command execution.
Commands can still be overridden via Kubernetes or Docker configuration when needed (smallstep/certificates#2493) ### Fixed - Fix CRL IssuingDistributionPoint marshaling to correctly unset `OnlyContainsUserCerts` and `OnlyContainsCACerts` flags (smallstep/certificates#2511) - Fix CRL DER download content-disposition filename extension from `.der` to `.crl` (smallstep/certificates#2537) - Fix SSH agent KMS when CA is configured with Prometheus instrumented signer (smallstep/certificates#2379) - Return helpful error message when root certificate is not found (smallstep/certificates#1893) - Fix missing version number when building step-ca from source archive (smallstep/certificates#2513) - Fix potential panic if a certificate had an empty tcg-kp-AIKCertificate extended key usage (smallstep/certificates#2569) - Fix CA startup when configured with SCEP and Google Cloud CAS (smallstep/certificates#2517) - Close idle connections on client certificate renew (smallstep/certificates#2515) ## [0.29.0] - 2025-12-03 ### Added - Add support for YubiKeys 5.7.4+ (smallstep/certificates#2370) - Support managed device ID OID for step attestation format (smallstep/certificates#2382) - Add support for remote configuration of GCP Organization-Id (smallstep/certificates#2408) - Add additional DOCKER_STEPCA_INIT_* envs for docker/entrypoint.sh (smallstep/certificates#2461) - Add sd_notify support (smallstep/certificates#2463) ### Changed - Use errgroup to shutdown services concurrently (smallstep/certificates#2343) ### Fixed - Fix process hanging after SIGTERM (smallstep/certificates#2338) - Disable execute permission on a few policy/engine source files (smallstep/certificates#2435) - Fix backdate support for ACME provisioner (smallstep/certificates#2444) ### Security - Authorization Bypass in ACME and SCEP Provisioners (smallstep/certificates#2491) - Improper Authorization Check for SSH Certificate Revocation (smallstep/certificates#2491) ## [0.28.4] - 2025-07-13 ### Added - Add support 
for using key usage, extended key usage, and basic constraints from certificate requests in certificate templates (smallstep/crypto#767) - Allow to specify audience when generating JWK provisioner tokens (smallstep/certificates#2326) - Add SSH certificate type to exposed metrics (smallstep/certificates#2290) - Enable dynamic validation of project ownership within a GCP organization when using the GCP Cloud Instance Identity provisioner (smallstep/certificates#2133) ### Changed - Introduce poolhttp package for improved memory performance of Authority httpClients (smallstep/certificates#2325) ## [0.28.3] - 2025-03-17 - dependabot updates ## [0.28.2] - 2025-02-20 ### Added - Added support for imported keys on YubiKey (smallstep/certificates#2113) - Enable storing ACME attestation payload (smallstep/certificates#2114) - Add ACME attestation format field to ACME challenge (smallstep/certificates#2124) ### Changed - Added internal httptransport package to replace cloning of http.DefaultTransport (smallstep/certificates#2098, smallstep/certificates#2103, smallstep/certificates#2104) - For example, replacing http.DefaultTransport clone in provisioner webhook business logic. 
## [0.28.1] - 2024-11-19 ### Added - Support for using template data from SCEPCHALLENGE webhooks (smallstep/certificates#2065) - New field to Webhook response that allows for propagation of human readable errors to the client (smallstep/certificates#2066, smallstep/certificates#2069) - CICD for pushing DEB and RPM packages to packages.smallstep.com on releases (smallstep/certificates#2076) - PKCS11 utilities in HSM container image (smallstep/certificates#2077) ### Changed - Artifact names for RPM and DEB packages in conformance with standards (smallstep/certificates#2076) ## [0.28.0] - 2024-10-29 ### Added - Add options to GCP IID provisioner to enable or disable signing of SSH user and host certificates (smallstep/certificates#2045) ### Changed - For IID provisioners with disableCustomSANs set to true, validate that the requested DNS names are a subset of the allowed DNS names (based on the IID token), rather than requiring an exact match to the entire list of allowed DNS names. (smallstep/certificates#2044) ## [0.27.5] - 2024-10-17 ### Added - Option to log real IP (x-forwarded-for) in logging middleware (smallstep/certificates#2002) ### Fixed - Pulled in updates to smallstep/pkcs7 to fix failing Windows SCEP enrollment certificates (smallstep/certificates#1994) ## [0.27.4] - 2024-09-13 ### Fixed - Release workflow ## [0.27.3] - 2024-09-13 ### Added - AWS auth method for Vault RA mode (smallstep/certificates#1976) - API endpoints for retrieving Intermediate certificates (smallstep/certificates#1962) - Enable use of OIDC provisioner with private identity providers and a certificate from step-ca (smallstep/certificates#1940) - Support for verifying `cnf` and `x5rt#S256` claim when provided in token (smallstep/certificates#1660) - Add Wire integration to ACME provisioner (smallstep/certificates#1666) ### Changed - Clarified SSH certificate policy errors (smallstep/certificates#1951) ### Fixed - Nebula ECDSA P-256 support (smallstep/certificates#1662) ## [0.27.2] -
2024-07-18 ### Added - `--console` option to default step-ssh config (smallstep/certificates#1931) ## [0.27.1] - 2024-07-12 ### Changed - Enable use of strict FQDN with a flag (smallstep/certificates#1926) - This reverses a change in 0.27.0 that required the use of strict FQDNs (smallstep/certificates#1910) ## [0.27.0] - 2024-07-11 ### Added - Support for validity windows in templates (smallstep/certificates#1903) - Create identity certificate with host URI when using any provisioner (smallstep/certificates#1922) ### Changed - Do strict DNS lookup on ACME (smallstep/certificates#1910) ### Fixed - Handle bad attestation object in deviceAttest01 validation (smallstep/certificates#1913) ## [0.26.2] - 2024-06-13 ### Added - Add provisionerID to ACME accounts (smallstep/certificates#1830) - Enable verifying ACME provisioner using provisionerID if available (smallstep/certificates#1844) - Add methods to Authority to get intermediate certificates (smallstep/certificates#1848) - Add GetX509Signer method (smallstep/certificates#1850) ### Changed - Make ISErrNotFound more flexible (smallstep/certificates#1819) - Log errors using slog.Logger (smallstep/certificates#1849) - Update hardcoded AWS certificates (smallstep/certificates#1881) ## [0.26.1] - 2024-04-22 ### Added - Allow configuration of a custom SCEP key manager (smallstep/certificates#1797) ### Fixed - id-scep-failInfoText OID (smallstep/certificates#1794) - CA startup with Vault RA configuration (smallstep/certificates#1803) ## [0.26.0] - 2024-03-28 ### Added - [TPM KMS](https://github.com/smallstep/crypto/tree/master/kms/tpmkms) support for CA keys (smallstep/certificates#1772) - Propagation of HTTP request identifier using X-Request-Id header (smallstep/certificates#1743, smallstep/certificates#1542) - Expires header in CRL response (smallstep/certificates#1708) - Support for providing TLS configuration programmatically (smallstep/certificates#1685) - Support for providing external CAS implementation
(smallstep/certificates#1684) - AWS `ca-west-1` identity document root certificate (smallstep/certificates#1715) - [COSE RS1](https://www.rfc-editor.org/rfc/rfc8812.html#section-2) as a supported algorithm with ACME `device-attest-01` challenge (smallstep/certificates#1663) ### Changed - In an RA setup, let the CA decide the RA certificate lifetime (smallstep/certificates#1764) - Use Debian Bookworm in Docker containers (smallstep/certificates#1615) - Error message for CSR validation (smallstep/certificates#1665) - Updated dependencies ### Fixed - Stop CA when any of the required servers fails to start (smallstep/certificates#1751). Before the fix, the CA would continue running and only log the server failure when stopped. - Configuration loading errors when not using context were not returned. Fixed in [cli-utils/109](https://github.com/smallstep/cli-utils/pull/109). - HTTP_PROXY and HTTPS_PROXY support for ACME validation client (smallstep/certificates#1658). ### Security - Upgrade to using cosign v2 for signing artifacts ## [0.25.1] - 2023-11-28 ### Added - Provisioner name in SCEP webhook request body in (smallstep/certificates#1617) - Support for ASN1 boolean encoding in (smallstep/certificates#1590) ### Changed - Generation of first provisioner name on `step ca init` in (smallstep/certificates#1566) - Processing of SCEP Get PKIOperation requests in (smallstep/certificates#1570) - Support for signing identity certificate during SSH sign by skipping URI validation in (smallstep/certificates#1572) - Dependency on `micromdm/scep` and `go.mozilla.org/pkcs7` to use Smallstep forks in (smallstep/certificates#1600) - Make the Common Name validator for JWK provisioners accept values from SANs too in (smallstep/certificates#1609) ### Fixed - Registration Authority token creation relied on values from CSR. 
Fixed to rely on template in (smallstep/certificates#1608) - Use same glibc version for running the CA when built using CGo in (smallstep/certificates#1616) ## [0.25.0] - 2023-09-26 ### Added - Added support for configuring SCEP decrypters in the provisioner (smallstep/certificates#1414) - Added support for TPM KMS (smallstep/crypto#253) - Added support for disableSmallstepExtensions provisioner claim (smallstep/certificates#1484) - Added script to migrate a badger DB to MySQL or PostgreSQL (smallstep/certificates#1477) - Added AWS public certificates for me-central-1 and ap-southeast-3 (smallstep/certificates#1404) - Added namespace field to VaultCAS JSON config (smallstep/certificates#1424) - Added AWS public certificates for me-central-1 and ap-southeast-3 (smallstep/certificates#1404) - Added unversioned filenames to GitHub release assets (smallstep/certificates#1435) - Send X5C leaf certificate to webhooks (smallstep/certificates#1485) - Added support for disableSmallstepExtensions claim (smallstep/certificates#1484) - Added all AWS Identity Document Certificates (smallstep/certificates#1404, smallstep/certificates#1510) - Added Winget release automation (smallstep/certificates#1519) - Added CSR to SCEPCHALLENGE webhook request body (smallstep/certificates#1523) - Added SCEP issuance notification webhook (smallstep/certificates#1544) - Added ability to disable color in the log text formatter (smallstep/certificates#1559) ### Changed - Changed the Makefile to produce cgo-enabled builds running `make build GO_ENVS="CGO_ENABLED=1"` (smallstep/certificates#1446) - Return more detailed errors to ACME clients using device-attest-01 (smallstep/certificates#1495) - Change SCEP password type to string (smallstep/certificates#1555) ### Removed - Removed OIDC user regexp check (smallstep/certificates#1481) - Removed automatic initialization of $STEPPATH (smallstep/certificates#1493) - Removed db datasource from error msg to prevent leaking of secrets to logs
(smallstep/certificates#1528) ### Fixed - Improved authentication for ACME requests using kid and provisioner name (smallstep/certificates#1386). - Fixed indentation of KMS configuration in helm charts (smallstep/certificates#1405) - Fixed simultaneous sign or decrypt operation on a YubiKey (smallstep/certificates#1476, smallstep/crypto#288) - Fixed adding certificate templates with ASN.1 functions (smallstep/certificates#1500, smallstep/crypto#302) - Fixed a problem when the ca.json is truncated if the encoding of the configuration fails (e.g., new provisioner with bad template data) (smallstep/cli#994, smallstep/certificates#1501) - Fixed provisionerOptionsToLinkedCA missing template and templateData (smallstep/certificates#1520) - Fix calculation of webhook signature (smallstep/certificates#1546) ## [v0.24.2] - 2023-05-11 ### Added - Log SSH certificates (smallstep/certificates#1374) - CRL endpoints on the HTTP server (smallstep/certificates#1372) - Dynamic SCEP challenge validation using webhooks (smallstep/certificates#1366) - For Docker deployments, added DOCKER_STEPCA_INIT_PASSWORD_FILE. Useful for pointing to a Docker Secret in the container (smallstep/certificates#1384) ### Changed - Depend on [smallstep/go-attestation](https://github.com/smallstep/go-attestation) instead of [google/go-attestation](https://github.com/google/go-attestation) - Render CRLs into http.ResponseWriter instead of memory (smallstep/certificates#1373) - Redaction of SCEP static challenge when listing provisioners (smallstep/certificates#1204) ### Fixed - VaultCAS certificate lifetime (smallstep/certificates#1376) ## [v0.24.1] - 2023-04-14 ### Fixed - Docker image name for HSM support (smallstep/certificates#1348) ## [v0.24.0] - 2023-04-12 ### Added - Add ACME `device-attest-01` support with TPM 2.0 (smallstep/certificates#1063). 
- Add support for new Azure SDK, sovereign clouds, and HSM keys on Azure KMS (smallstep/crypto#192, smallstep/crypto#197, smallstep/crypto#198, smallstep/certificates#1323, smallstep/certificates#1309). - Add support for ASN.1 functions on certificate templates (smallstep/crypto#208, smallstep/certificates#1345) - Add `DOCKER_STEPCA_INIT_ADDRESS` to configure the address to use in a docker container (smallstep/certificates#1262). - Make sure that the CSR used matches the attested key when using ACME `device-attest-01` challenge (smallstep/certificates#1265). - Add support for compacting the Badger DB (smallstep/certificates#1298). - Build and release cleanups (smallstep/certificates#1322, smallstep/certificates#1329, smallstep/certificates#1340). ### Fixed - Fix support for PKCS #7 RSA-OAEP decryption through [smallstep/pkcs7#4](https://github.com/smallstep/pkcs7/pull/4), as used in SCEP. - Fix RA installation using `scripts/install-step-ra.sh` (smallstep/certificates#1255). - Clarify error messages on policy errors (smallstep/certificates#1287, smallstep/certificates#1278). - Clarify error message on OIDC email validation (smallstep/certificates#1290). - Mark the IDP critical in the generated CRL data (smallstep/certificates#1293). - Disable database if CA is initialized with the `--no-db` flag (smallstep/certificates#1294). ## [v0.23.2] - 2023-02-02 ### Added - Added [`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) to docker images, and a new image, `smallstep/step-ca-hsm`, compiled with cgo (smallstep/certificates#1243). - Added [`scoop`](https://scoop.sh) packages back to the release (smallstep/certificates#1250). - Added optional flag `--pidfile` which allows passing a filename where step-ca will write its process id (smallstep/certificates#1251). - Added helpful message on CA startup when config can't be opened (smallstep/certificates#1252). - Improved validation and error messages on `device-attest-01` orders (smallstep/certificates#1235).
### Removed - The deprecated CLI utils `step-awskms-init`, `step-cloudkms-init`, `step-pkcs11-init`, `step-yubikey-init` have been removed. [`step`](https://github.com/smallstep/cli) and [`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) should be used instead (smallstep/certificates#1240). ### Fixed - Fixed remote management flags in docker images (smallstep/certificates#1228). ## [v0.23.1] - 2023-01-10 ### Added - Added configuration property `.crl.idpURL` to be able to set a custom Issuing Distribution Point in the CRL (smallstep/certificates#1178). - Added WithContext methods to the CA client (smallstep/certificates#1211). - Docker: Added environment variables for enabling Remote Management and ACME provisioner (smallstep/certificates#1201). - Docker: The entrypoint script now generates and displays an initial JWK provisioner password by default when the CA is being initialized (smallstep/certificates#1223). ### Changed - Ignore SSH principals validation when using an OIDC provisioner. The provisioner will ignore the principals passed and set the defaults or the ones included using webhooks or templates (smallstep/certificates#1206). ## [v0.23.0] - 2022-11-11 ### Added - Added support for ACME device-attest-01 challenge on iOS, iPadOS, tvOS and YubiKey. - Ability to disable ACME challenges and attestation formats. - Added flags to change ACME challenge ports for testing purposes. - Added name constraints evaluation and enforcement when issuing or renewing X.509 certificates. - Added provisioner webhooks for augmenting template data and authorizing certificate requests before signing. - Added automatic migration of provisioners when enabling remote management. - Added experimental support for CRLs. - Add certificate renewal support on RA mode. The `step ca renew` command must use the flag `--mtls=false` to use the token renewal flow. - Added support for initializing remote management using `step ca init`.
- Added support for renewing X.509 certificates on RAs. - Added support for using SCEP with keys in a KMS. - Added client support to set the dialer's local address with the environment variable `STEP_CLIENT_ADDR`. ### Changed - Remove the email requirement for issuing SSH certificates with an OIDC provisioner. - Root files can contain more than one certificate. ### Fixed - Fixed MySQL DSN parsing issues with an upgrade to [smallstep/nosql@v0.5.0](https://github.com/smallstep/nosql/releases/tag/v0.5.0). - Fixed renewal of certificates with missing subject attributes. - Fixed ACME support with [ejabberd](https://github.com/processone/ejabberd). ### Deprecated - The CLIs `step-awskms-init`, `step-cloudkms-init`, `step-pkcs11-init`, `step-yubikey-init` are deprecated. Now you can use [`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) in combination with `step certificates create` to initialize your PKI. ## [0.22.1] - 2022-08-31 ### Fixed - Fixed signature algorithm on EC (root) + RSA (intermediate) PKIs. ## [0.22.0] - 2022-08-26 ### Added - Added automatic configuration of Linked RAs. - Send provisioner configuration on Linked RAs. ### Changed - Certificates signed by an issuer using an RSA key will be signed using the same algorithm used to sign the issuer certificate. The signature will no longer default to PKCS #1. For example, if the issuer certificate was signed using RSA-PSS with SHA-256, a new certificate will also be signed using RSA-PSS with SHA-256. - Support two latest versions of Go (1.18, 1.19). - Validate revocation serial number (either base 10 or prefixed with an appropriate base). - Sanitize TLS options. ## [0.20.0] - 2022-05-26 ### Added - Added Kubernetes auth method for Vault RAs. - Added support for reporting provisioners to linkedca. - Added support for certificate policies on authority level. - Added a Dockerfile with a step-ca build with HSM support. 
- A few new WithXX methods for instantiating authorities ### Changed - Context usage in HTTP APIs. - Changed authentication for Vault RAs. - Error message returned to client when authenticating with expired certificate. - Strip padding from ACME CSRs. ### Deprecated - HTTP API handler types. ### Fixed - Fixed SSH revocation. - CA client dial context for js/wasm target. - Incomplete `extraNames` support in templates. - SCEP GET request support. - Large SCEP request handling. ## [0.19.0] - 2022-04-19 ### Added - Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`. - Added support for `extraNames` in X.509 templates. - Added `armv5` builds. - Added RA support using a Vault instance as the CA. - Added `WithX509SignerFunc` authority option. - Added a new `/roots.pem` endpoint to download the CA roots in PEM format. - Added support for Azure `Managed Identity` tokens. - Added support for automatic configuration of linked RAs. - Added support for the `--context` flag. It's now possible to start the CA with `step-ca --context=abc` to use the configuration from context `abc`. When a context has been configured and no configuration file is provided on startup, the configuration for the current context is used. - Added startup info logging and option to skip it (`--quiet`). - Added support for renaming the CA (Common Name). ### Changed - Made SCEP CA URL paths dynamic. - Support two latest versions of Go (1.17, 1.18). - Upgrade go.step.sm/crypto to v0.16.1. - Upgrade go.step.sm/linkedca to v0.15.0. ### Deprecated - Go 1.16 support. ### Removed ### Fixed - Fixed admin credentials on RAs. - Fixed ACME HTTP-01 challenges for IPv6 identifiers. - Various improvements under the hood. ### Security ## [0.18.2] - 2022-03-01 ### Added - Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner. - [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering out database drivers using Go tags. 
For example, using the Go flag `--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx driver for PostgreSQL. ### Changed - IPv6 addresses are normalized as IP addresses instead of hostnames. - More descriptive JWK decryption error message. - Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`. ### Fixed - During provisioner add - validate provisioner configuration before storing to DB. ## [0.18.1] - 2022-02-03 ### Added - Support for ACME revocation. - Replace hash function with an RSA SSH CA to "rsa-sha2-256". - Support Nebula provisioners. - Example Ansible configurations. - Support PKCS#11 as a decrypter, as used by SCEP. ### Changed - Automatically create database directory on `step ca init`. - Slightly improve errors reported when a template has invalid content. - Error reporting in logs and to clients. ### Fixed - SCEP renewal using HTTPS on macOS. ## [0.18.0] - 2021-11-17 ### Added - Support for multiple certificate authority contexts. - Support for generating extractable keys and certificates on a pkcs#11 module. ### Changed - Support two latest versions of Go (1.16, 1.17) ### Deprecated - go 1.15 support ## [0.17.6] - 2021-10-20 ### Notes - 0.17.5 failed in CI/CD ## [0.17.5] - 2021-10-20 ### Added - Support for Azure Key Vault as a KMS. - Adapt `pki` package to support key managers. - gocritic linter ### Fixed - gocritic warnings ## [0.17.4] - 2021-09-28 ### Fixed - Support host-only or user-only SSH CA. ## [0.17.3] - 2021-09-24 ### Added - go 1.17 to github action test matrix - Support for CloudKMS RSA-PSS signers without using templates. - Add flags to support individual passwords for the intermediate and SSH keys. - Global support for group admins in the OIDC provisioner. ### Changed - Using go 1.17 for binaries ### Fixed - Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys. ### Security - Use cosign to sign and upload signatures for multi-arch Docker container. 
- Add debian checksum ## [0.17.2] - 2021-08-30 ### Added - Additional way to distinguish Azure IID and Azure OIDC tokens. ### Security - Sign over all goreleaser github artifacts using cosign ## [0.17.1] - 2021-08-26 ## [0.17.0] - 2021-08-25 ### Added - Add support for Linked CAs using protocol buffers and gRPC - `step-ca init` adds support for - configuring a StepCAS RA - configuring a Linked CA - configuring a `step-ca` using Helm ### Changed - Update badger driver to use v2 by default - Update TLS cipher suites to include 1.3 ### Security - Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to `step certificates` We welcome contributions to `step certificates` of any kind including documentation, themes, organization, tutorials, blog posts, bug reports, issues, feature requests, feature implementations, pull requests, helping to manage issues, etc. ## Table of Contents - [Contributing to `step certificates`](#contributing-to-step-certificates) - [Table of Contents](#table-of-contents) - [Building From Source](#building-from-source) - [Build a standard `step-ca`](#build-a-standard-step-ca) - [Build `step-ca` using CGO](#build-step-ca-using-cgo) - [The CGO build enables PKCS #11 and YubiKey PIV support](#the-cgo-build-enables-pkcs-11-and-yubikey-piv-support) - [1. Install PCSC support](#1-install-pcsc-support) - [2.
Build `step-ca`](#2-build-step-ca) - [Asking Support Questions](#asking-support-questions) - [Reporting Issues](#reporting-issues) - [Code Contribution](#code-contribution) - [Submitting Patches](#submitting-patches) - [Code Contribution Guidelines](#code-contribution-guidelines) - [Git Commit Message Guidelines](#git-commit-message-guidelines) ## Building From Source Clone this repository to get a bleeding-edge build, or download the source archive for [the latest stable release](https://github.com/smallstep/certificates/releases/latest). ### Build a standard `step-ca` The only prerequisites are [`go`](https://golang.org/) and make. To build from source: make bootstrap && make Find your binaries in `bin/`. ### Build `step-ca` using CGO #### The CGO build enables PKCS #11 and YubiKey PIV support To build the CGO version of `step-ca`, you will need [`go`](https://golang.org/), make, and a C compiler. You'll also need PCSC support on your operating system, as required by the `go-piv` module. On Linux, the [`libpcsclite-dev`](https://pcsclite.apdu.fr/) package provides PCSC support. On macOS and Windows, PCSC support is built into the OS. #### 1. Install PCSC support On Debian-based distributions, run: ```shell sudo apt-get install libpcsclite-dev ``` On Fedora: ```shell sudo yum install pcsc-lite-devel ``` On CentOS: ``` sudo yum install 'dnf-command(config-manager)' sudo yum config-manager --set-enabled PowerTools sudo yum install pcsc-lite-devel ``` #### 2. Build `step-ca` To build `step-ca`, clone this repository and run the following: ```shell make bootstrap && make build GO_ENVS="CGO_ENABLED=1" ``` When the build is complete, you will find binaries in `bin/`. ## Asking Support Questions Feel free to post a question on our [GitHub Discussions](https://github.com/smallstep/certificates/discussions) page, or find us on [Discord](https://bit.ly/step-discord). 
## Reporting Issues If you believe you have found a defect in `step certificates` or its documentation, use the GitHub [issue tracker](https://github.com/smallstep/certificates/issues) to report the problem. When reporting the issue, please provide the version of `step certificates` in use (`step-ca version`) and your operating system. ## Code Contribution `step certificates` aims to become a fully featured online Certificate Authority. We encourage all contributions that meet the following criteria: * fit naturally into a Certificate Authority. * strive not to break existing functionality. * close or update an open [`step certificates` issue](https://github.com/smallstep/certificates/issues) **Bug fixes are, of course, always welcome.** ## Submitting Patches `step certificates` welcomes all contributors and contributions. If you are interested in helping with the project, please reach out to us or, better yet, submit a PR :). ### Code Contribution Guidelines Because we want to create the best possible product for our users and the best contribution experience for our developers, we have a set of guidelines which ensure that all contributions are acceptable. The guidelines are not intended as a filter or barrier to participation. If you are unfamiliar with the contribution process, the Smallstep team will guide you in order to get your contribution in accordance with the guidelines. To make the contribution process as seamless as possible, we ask for the following: * Go ahead and fork the project and make your changes. We encourage pull requests to allow for review and discussion of code changes. * When you’re ready to create a pull request, be sure to: * Sign the [CLA](https://cla-assistant.io/smallstep/certificates). * Have test cases for the new code. If you have questions about how to do this, please ask in your pull request. * Run `go fmt`. * Add documentation if you are adding new features or changing functionality. * Squash your commits into a single commit. 
`git rebase -i`. It’s okay to force update your pull request with `git push -f`. * Follow the **Git Commit Message Guidelines** below. ### Git Commit Message Guidelines This [blog article](http://chris.beams.io/posts/git-commit/) is a good resource for learning how to write good commit messages, the most important part being that each commit message should have a title/subject in imperative mood starting with a capital letter and no trailing period: *"Return error on wrong use of the Paginator"*, **NOT** *"returning some error."* Also, if your commit references one or more GitHub issues, always end your commit message body with *See #1234* or *Fixes #1234*. Replace *1234* with the GitHub issue ID. The last example will close the issue when the commit is merged into *master*. Please use a short and descriptive branch name, e.g. **NOT** "patch-1". It's very common but creates a naming conflict each time when a submission is pulled for a review. An example: ```text Add step certificate install Add a command line utility for installing (and uninstalling) certificates to the local system truststores. This should help developers with local development flows. Fixes #75 ``` ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2020 Smallstep Labs, Inc. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ PKG?=github.com/smallstep/certificates/cmd/step-ca BINNAME?=step-ca # Set V to 1 for verbose output from the Makefile Q=$(if $V,,@) PREFIX?= SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*") GOOS_OVERRIDE ?= all: lint test build ci: testcgo build .PHONY: all ci ######################################### # Bootstrapping ######################################### bootstra%: $Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin latest $Q go install golang.org/x/vuln/cmd/govulncheck@latest $Q go install gotest.tools/gotestsum@latest $Q go install github.com/goreleaser/goreleaser/v2@latest $Q go install github.com/sigstore/cosign/v2/cmd/cosign@latest .PHONY: bootstra% ################################################# # Determine the type of `push` and `version` ################################################# # GITHUB Actions ifdef GITHUB_REF VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///') NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc) ifeq ($(NOT_RC),) PUSHTYPE := release-candidate else PUSHTYPE := release endif else VERSION ?= $(shell [ -d .git ] && git describe --tags --always --dirty="-dev") # If we are not in an active git dir then try reading the version from .VERSION. # .VERSION contains a slug populated by `git archive`. 
VERSION := $(or $(VERSION),$(shell ./.version.sh .VERSION)) PUSHTYPE := branch endif VERSION := $(shell echo $(VERSION) | sed 's/^v//') ifdef V $(info GITHUB_REF is $(GITHUB_REF)) $(info VERSION is $(VERSION)) $(info PUSHTYPE is $(PUSHTYPE)) endif ######################################### # Build ######################################### DATE := $(shell date -u '+%Y-%m-%d %H:%M UTC') LDFLAGS := -ldflags='-w -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"' # Always explicitly enable or disable cgo, # so that go doesn't silently fall back on # non-cgo when gcc is not found. ifeq (,$(findstring CGO_ENABLED,$(GO_ENVS))) ifneq ($(origin GOFLAGS),undefined) # This section is for backward compatibility with # # $ make build GOFLAGS="" # # which is how we recommended building step-ca with cgo support # until June 2023. GO_ENVS := $(GO_ENVS) CGO_ENABLED=1 else GO_ENVS := $(GO_ENVS) CGO_ENABLED=0 endif endif download: $Q go mod download build: $(PREFIX)bin/$(BINNAME) @echo "Build Complete!" $(PREFIX)bin/$(BINNAME): download $(call rwildcard,*.go) $Q mkdir -p $(@D) $Q $(GOOS_OVERRIDE) GOFLAGS="$(GOFLAGS)" $(GO_ENVS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG) # Target to force a build of step-ca without running tests simple: build .PHONY: download build simple ######################################### # Go generate ######################################### generate: $Q go generate ./... .PHONY: generate ######################################### # Test ######################################### test: testdefault testtpmsimulator combinecoverage testdefault: $Q $(GO_ENVS) gotestsum -- -coverprofile=defaultcoverage.out -short -covermode=atomic ./... testtpmsimulator: $Q CGO_ENABLED=1 gotestsum -- -coverprofile=tpmsimulatorcoverage.out -short -covermode=atomic -tags tpmsimulator ./acme testcgo: $Q gotestsum -- -coverprofile=coverage.out -short -covermode=atomic ./... 
combinecoverage: cat defaultcoverage.out tpmsimulatorcoverage.out > coverage.out .PHONY: test testdefault testtpmsimulator testcgo combinecoverage integrate: integration integration: bin/$(BINNAME) $Q $(GO_ENVS) gotestsum -- -tags=integration ./integration/... .PHONY: integrate integration ######################################### # Linting ######################################### fmt: $Q goimports -l -w $(SRC) lint: SHELL:=/bin/bash lint: $Q LOG_LEVEL=error golangci-lint run --config <(curl -s https://raw.githubusercontent.com/smallstep/workflows/master/.golangci.yml) --timeout=30m $Q govulncheck ./... .PHONY: fmt lint ######################################### # Install ######################################### INSTALL_PREFIX?=/usr/local/ install: $(PREFIX)bin/$(BINNAME) $Q install -D $(PREFIX)bin/$(BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(BINNAME) uninstall: $Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(BINNAME) .PHONY: install uninstall ######################################### # Clean ######################################### clean: ifneq ($(BINNAME),"") $Q rm -f bin/$(BINNAME) endif .PHONY: clean ######################################### # Dev ######################################### run: $Q go run cmd/step-ca/main.go $(shell step path)/config/ca.json .PHONY: run ================================================ FILE: README.md ================================================ # step-ca [![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest) [![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates) [![Build Status](https://github.com/smallstep/certificates/actions/workflows/test.yml/badge.svg)](https://github.com/smallstep/certificates) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![CLA 
assistant](https://cla-assistant.io/readme/badge/smallstep/certificates)](https://cla-assistant.io/smallstep/certificates) `step-ca` is an online certificate authority for secure, automated certificate management for DevOps. It's the server counterpart to the [`step` CLI tool](https://github.com/smallstep/cli) for working with certificates and keys. Both projects are maintained by [Smallstep Labs](https://smallstep.com). You can use `step-ca` to: - Issue HTTPS server and client certificates that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html) ([RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliance) - Issue TLS certificates for DevOps: VMs, containers, APIs, database connections, Kubernetes pods... - Issue SSH certificates: - For people, in exchange for single sign-on identity tokens - For hosts, in exchange for cloud instance identity documents - Easily automate certificate management: - It's an [ACME server](https://smallstep.com/docs/step-ca/acme-basics/) that supports all [popular ACME challenge types](https://smallstep.com/docs/step-ca/acme-basics/#acme-challenge-types) - It comes with a [Go wrapper](./examples#user-content-basic-client-usage) - ... and there's a [command-line client](https://github.com/smallstep/cli) you can use in scripts! --- ### Comparison with Smallstep's commercial product `step-ca` is optimized for a two-tier PKI serving common DevOps use cases. 
As you design your PKI, if you need any of the following, [consider our commercial CA](http://smallstep.com): - Multiple certificate authorities - Active revocation (CRL, OCSP) - Turnkey high-volume, high availability CA - An API for seamless IaC management of your PKI - Integrated support for SCEP & NDES, for migrating from legacy Active Directory Certificate Services deployments - Device identity — cross-platform device inventory and attestation using Secure Enclave & TPM 2.0 - Highly automated PKI — managed certificate renewal, monitoring, TPM-based attested enrollment - Seamless client deployments of EAP-TLS Wi-Fi, VPN, SSH, and browser certificates - Jamf, Intune, or other MDM for root distribution and client enrollment - Web Admin UI — history, issuance, and metrics - ACME External Account Binding (EAB) - Deep integration with an identity provider - Fine-grained, role-based access control - FIPS-compliant software - HSM-bound private keys See our [full feature comparison](https://smallstep.com/step-ca-vs-smallstep-certificate-manager/) for more. You can [start a free trial](https://smallstep.com/signup) or [set up a call with us](https://go.smallstep.com/request-demo) to learn more. --- **Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).** [Website](https://smallstep.com/certificates) | [Documentation](https://smallstep.com/docs/step-ca) | [Installation](https://smallstep.com/docs/step-ca/installation) | [Contributor's Guide](./CONTRIBUTING.md) ## Features ### 🦾 A fast, stable, flexible private CA Setting up a *public key infrastructure* (PKI) is out of reach for many small teams. `step-ca` makes it easier. 
- Choose key types (RSA, ECDSA, EdDSA) and lifetimes to suit your needs - [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation - Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca) - [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases) ### ⚙️ Many ways to automate There are several ways to authorize a request with the CA and establish a chain of trust that suits your flow. You can issue certificates in exchange for: - [ACME challenge responses](#your-own-private-acme-server) from any ACMEv2 client - [OAuth OIDC single sign-on tokens](https://smallstep.com/blog/easily-curl-services-secured-by-https-tls.html), eg: - ID tokens from Okta, GSuite, Azure AD, Auth0. - ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex) - [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure - [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc. - A trusted X.509 certificate (X5C provisioner) - A host certificate from your Nebula network - A SCEP challenge (SCEP provisioner) - An SSH host certificate needing renewal (the SSHPOP provisioner) - Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners) ### 🏔 Your own private ACME server ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS certificates. It's _super easy_ to issue certificates to any ACMEv2 ([RFC8555](https://tools.ietf.org/html/rfc8555)) client. 
- [Use ACME in development & pre-production](https://smallstep.com/blog/private-acme-server/#local-development--pre-production) - Supports the most popular [ACME challenge types](https://letsencrypt.org/docs/challenge-types/): - For `http-01`, place a token at a well-known URL to prove that you control the web server - For `dns-01`, add a `TXT` record to prove that you control the DNS record set - For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server - Works with any ACME client. We've written examples for: - [certbot](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#certbot) - [acme.sh](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#acmesh) - [win-acme](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#win-acme) - [Caddy](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#caddy-v2) - [Traefik](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#traefik) - [Apache](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#apache) - [nginx](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#nginx) - Get certificates programmatically using ACME, using these libraries: - [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#golang)) - certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#python)) - [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#node)) - Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client! 
- See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more ### 👩🏽‍💻 An online SSH Certificate Authority - Delegate SSH authentication to `step-ca` by using [SSH certificates](https://smallstep.com/blog/use-ssh-certificates/) instead of public keys and `authorized_keys` files - For user certificates, [connect SSH to your single sign-on provider](https://smallstep.com/blog/diy-single-sign-on-for-ssh/), to improve security with short-lived certificates and MFA (or other security policies) via any OAuth OIDC provider. - For host certificates, improve security, [eliminate TOFU warnings](https://smallstep.com/blog/use-ssh-certificates/), and set up automated host certificate renewal. ### 🤓 A general purpose PKI tool, via [`step` CLI](https://github.com/smallstep/cli) [integration](https://smallstep.com/docs/step-cli/reference/ca/) - Generate key pairs where they're needed so private keys are never transmitted across the network - [Authenticate and obtain a certificate](https://smallstep.com/docs/step-cli/reference/ca/certificate/) using any provisioner supported by `step-ca` - Securely [distribute root certificates](https://smallstep.com/docs/step-cli/reference/ca/root/) and [bootstrap](https://smallstep.com/docs/step-cli/reference/ca/bootstrap/) PKI relying parties - [Renew](https://smallstep.com/docs/step-cli/reference/ca/renew/) and [revoke](https://smallstep.com/docs/step-cli/reference/ca/revoke/) certificates issued by `step-ca` - [Install root certificates](https://smallstep.com/docs/step-cli/reference/certificate/install/) on your machine and browsers, so your CA is trusted - [Inspect](https://smallstep.com/docs/step-cli/reference/certificate/inspect/) and [lint](https://smallstep.com/docs/step-cli/reference/certificate/lint/) certificates ## Installation See our installation docs [here](https://smallstep.com/docs/step-ca/installation). 
## Documentation * [Official documentation](https://smallstep.com/docs/step-ca) is on smallstep.com * The `step` command reference is available via `step help`, [on smallstep.com](https://smallstep.com/docs/step-cli/reference/), or by running `step help --http=:8080` from the command line and visiting http://localhost:8080. ## Feedback? * Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space. [Join our Discord](https://u.step.sm/discord) or [GitHub Discussions](https://github.com/smallstep/certificates/discussions) * Tell us about a feature you'd like to see! [Request a Feature](https://github.com/smallstep/certificates/issues/new?assignees=&labels=enhancement%2C+needs+triage&template=enhancement.md&title=) ================================================ FILE: SECURITY.md ================================================ We appreciate any effort to discover and disclose security vulnerabilities responsibly. If you would like to report a vulnerability in one of our projects, or have security concerns regarding Smallstep software, please email security@smallstep.com. In order for us to best respond to your report, please include any of the following: * Steps to reproduce or proof-of-concept * Any relevant tools, including versions used * Tool output ================================================ FILE: acme/account.go ================================================ package acme import ( "crypto" "encoding/base64" "encoding/json" "time" "go.step.sm/crypto/jose" "github.com/smallstep/certificates/authority/policy" ) // Account is a subset of the internal account type containing only those // attributes required for responses in the ACME protocol. 
type Account struct { ID string `json:"-"` Key *jose.JSONWebKey `json:"-"` Contact []string `json:"contact,omitempty"` Status Status `json:"status"` OrdersURL string `json:"orders"` ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"` LocationPrefix string `json:"-"` ProvisionerID string `json:"-"` ProvisionerName string `json:"-"` } // GetLocation returns the URL location of the given account. func (a *Account) GetLocation() string { if a.LocationPrefix == "" { return "" } return a.LocationPrefix + a.ID } // ToLog enables response logging. func (a *Account) ToLog() (interface{}, error) { b, err := json.Marshal(a) if err != nil { return nil, WrapErrorISE(err, "error marshaling account for logging") } return string(b), nil } // IsValid returns true if the Account is valid. func (a *Account) IsValid() bool { return a.Status == StatusValid } // KeyToID converts a JWK to a thumbprint. func KeyToID(jwk *jose.JSONWebKey) (string, error) { kid, err := jwk.Thumbprint(crypto.SHA256) if err != nil { return "", WrapErrorISE(err, "error generating jwk thumbprint") } return base64.RawURLEncoding.EncodeToString(kid), nil } // PolicyNames contains ACME account level policy names type PolicyNames struct { DNSNames []string `json:"dns"` IPRanges []string `json:"ips"` } // X509Policy contains ACME account level X.509 policy type X509Policy struct { Allowed PolicyNames `json:"allow"` Denied PolicyNames `json:"deny"` AllowWildcardNames bool `json:"allowWildcardNames"` } // Policy is an ACME Account level policy type Policy struct { X509 X509Policy `json:"x509"` } func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions { if p == nil { return nil } return &policy.X509NameOptions{ DNSDomains: p.X509.Allowed.DNSNames, IPRanges: p.X509.Allowed.IPRanges, } } func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions { if p == nil { return nil } return &policy.X509NameOptions{ DNSDomains: p.X509.Denied.DNSNames, IPRanges: p.X509.Denied.IPRanges, } } // 
AreWildcardNamesAllowed returns if wildcard names // like *.example.com are allowed to be signed. // Defaults to false. func (p *Policy) AreWildcardNamesAllowed() bool { if p == nil { return false } return p.X509.AllowWildcardNames } // ExternalAccountKey is an ACME External Account Binding key. type ExternalAccountKey struct { ID string `json:"id"` ProvisionerID string `json:"provisionerID"` Reference string `json:"reference"` AccountID string `json:"-"` HmacKey []byte `json:"-"` CreatedAt time.Time `json:"createdAt"` BoundAt time.Time `json:"boundAt,omitempty"` Policy *Policy `json:"policy,omitempty"` } // AlreadyBound returns whether this EAK is already bound to // an ACME Account or not. func (eak *ExternalAccountKey) AlreadyBound() bool { return !eak.BoundAt.IsZero() } // BindTo binds the EAK to an Account. // It returns an error if it's already bound. func (eak *ExternalAccountKey) BindTo(account *Account) error { if eak.AlreadyBound() { return NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", eak.ID, eak.AccountID, eak.BoundAt) } eak.AccountID = account.ID eak.BoundAt = time.Now() eak.HmacKey = []byte{} // clearing the key bytes; can only be used once return nil } ================================================ FILE: acme/account_test.go ================================================ package acme import ( "crypto" "encoding/base64" "testing" "time" "github.com/pkg/errors" "go.step.sm/crypto/jose" "github.com/smallstep/assert" ) func TestKeyToID(t *testing.T) { type test struct { jwk *jose.JSONWebKey exp string err *Error } tests := map[string]func(t *testing.T) test{ "fail/error-generating-thumbprint": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) jwk.Key = "foo" return test{ jwk: jwk, err: NewErrorISE("error generating jwk thumbprint: go-jose/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { 
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) kid, err := jwk.Thumbprint(crypto.SHA256) assert.FatalError(t, err) return test{ jwk: jwk, exp: base64.RawURLEncoding.EncodeToString(kid), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if id, err := KeyToID(tc.jwk); err != nil { if assert.NotNil(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equals(t, k.Type, tc.err.Type) assert.Equals(t, k.Detail, tc.err.Detail) assert.Equals(t, k.Status, tc.err.Status) assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) assert.Equals(t, k.Detail, tc.err.Detail) } else { assert.FatalError(t, errors.New("unexpected error type")) } } } else { if assert.Nil(t, tc.err) { assert.Equals(t, id, tc.exp) } } }) } } func TestAccount_GetLocation(t *testing.T) { locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/" type test struct { acc *Account exp string } tests := map[string]test{ "empty": {acc: &Account{LocationPrefix: ""}, exp: ""}, "not-empty": {acc: &Account{ID: "bar", LocationPrefix: locationPrefix}, exp: locationPrefix + "bar"}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { assert.Equals(t, tc.acc.GetLocation(), tc.exp) }) } } func TestAccount_IsValid(t *testing.T) { type test struct { acc *Account exp bool } tests := map[string]test{ "valid": {acc: &Account{Status: StatusValid}, exp: true}, "invalid": {acc: &Account{Status: StatusInvalid}, exp: false}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { assert.Equals(t, tc.acc.IsValid(), tc.exp) }) } } func TestExternalAccountKey_BindTo(t *testing.T) { boundAt := time.Now() tests := []struct { name string eak *ExternalAccountKey acct *Account err *Error }{ { name: "ok", eak: &ExternalAccountKey{ ID: "eakID", ProvisionerID: "provID", Reference: "ref", HmacKey: []byte{1, 3, 3, 7}, }, acct: &Account{ ID: "accountID", }, err: nil, }, { name: "fail/already-bound", eak: &ExternalAccountKey{ ID: "eakID", 
ProvisionerID: "provID", Reference: "ref", HmacKey: []byte{1, 3, 3, 7}, AccountID: "someAccountID", BoundAt: boundAt, }, acct: &Account{ ID: "accountID", }, err: NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", "eakID", "someAccountID", boundAt), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { eak := tt.eak acct := tt.acct err := eak.BindTo(acct) wantErr := tt.err != nil gotErr := err != nil if wantErr != gotErr { t.Errorf("ExternalAccountKey.BindTo() error = %v, wantErr %v", err, tt.err) } if wantErr { assert.NotNil(t, err) var ae *Error if assert.True(t, errors.As(err, &ae)) { assert.Equals(t, ae.Type, tt.err.Type) assert.Equals(t, ae.Detail, tt.err.Detail) assert.Equals(t, ae.Subproblems, tt.err.Subproblems) } } else { assert.Equals(t, eak.AccountID, acct.ID) assert.Equals(t, eak.HmacKey, []byte{}) assert.NotNil(t, eak.BoundAt) } }) } } ================================================ FILE: acme/api/account.go ================================================ package api import ( "context" "encoding/json" "errors" "net/http" "github.com/go-chi/chi/v5" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/logging" ) // NewAccountRequest represents the payload for a new account request. type NewAccountRequest struct { Contact []string `json:"contact"` OnlyReturnExisting bool `json:"onlyReturnExisting"` TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"` ExternalAccountBinding *ExternalAccountBinding `json:"externalAccountBinding,omitempty"` } func validateContacts(cs []string) error { for _, c := range cs { if c == "" { return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string") } } return nil } // Validate validates a new-account request body. 
func (n *NewAccountRequest) Validate() error { if n.OnlyReturnExisting && len(n.Contact) > 0 { return acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone") } return validateContacts(n.Contact) } // UpdateAccountRequest represents an update-account request. type UpdateAccountRequest struct { Contact []string `json:"contact"` Status acme.Status `json:"status"` } // Validate validates a update-account request body. func (u *UpdateAccountRequest) Validate() error { switch { case len(u.Status) > 0 && len(u.Contact) > 0: return acme.NewError(acme.ErrorMalformedType, "incompatible input; contact and "+ "status updates are mutually exclusive") case len(u.Contact) > 0: if err := validateContacts(u.Contact); err != nil { return err } return nil case len(u.Status) > 0: if u.Status != acme.StatusDeactivated { return acme.NewError(acme.ErrorMalformedType, "cannot update account "+ "status to %s, only deactivated", u.Status) } return nil default: // According to the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2) // accountUpdate should ignore any fields not recognized by the server. return nil } } // getAccountLocationPath returns the current account URL location. // Returned location will be of the form: https:///acme//account/ func getAccountLocationPath(ctx context.Context, linker acme.Linker, accID string) string { return linker.GetLink(ctx, acme.AccountLinkType, accID) } // NewAccount is the handler resource for creating new ACME accounts. 
func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, r, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { render.Error(w, r, err) return } prov, err := acmeProvisionerFromContext(ctx) if err != nil { render.Error(w, r, err) return } httpStatus := http.StatusCreated acc, err := accountFromContext(ctx) if err != nil { var acmeErr *acme.Error if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest { // Something went wrong ... render.Error(w, r, err) return } // Account does not exist // if nar.OnlyReturnExisting { render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist")) return } jwk, err := jwkFromContext(ctx) if err != nil { render.Error(w, r, err) return } eak, err := validateExternalAccountBinding(ctx, &nar) if err != nil { render.Error(w, r, err) return } acc = &acme.Account{ Key: jwk, Contact: nar.Contact, Status: acme.StatusValid, LocationPrefix: getAccountLocationPath(ctx, linker, ""), ProvisionerID: prov.ID, ProvisionerName: prov.Name, } if err := db.CreateAccount(ctx, acc); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if err := eak.BindTo(acc); err != nil { render.Error(w, r, err) return } if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error updating external account binding key")) return } acc.ExternalAccountBinding = nar.ExternalAccountBinding } } else { // Account exists 
httpStatus = http.StatusOK } linker.LinkAccount(ctx, acc) w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID)) render.JSONStatus(w, r, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, r, err) return } // If PostAsGet just respond with the account, otherwise process like a // normal Post request. if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { render.Error(w, r, err) return } if len(uar.Status) > 0 || len(uar.Contact) > 0 { if len(uar.Status) > 0 { acc.Status = uar.Status } else if len(uar.Contact) > 0 { acc.Contact = uar.Contact } if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error updating account")) return } } } linker.LinkAccount(ctx, acc) w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) render.JSON(w, r, acc) } func logOrdersByAccount(w http.ResponseWriter, oids []string) { if rl, ok := w.(logging.ResponseLogger); ok { m := map[string]interface{}{ "orders": oids, } rl.WithFields(m) } } // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. 
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, r, err) return } linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, r, orders) logOrdersByAccount(w, orders) } ================================================ FILE: acme/api/account_test.go ================================================ package api import ( "bytes" "context" "crypto/x509" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/pkg/errors" "go.step.sm/crypto/jose" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" ) var ( defaultDisableRenewal = false defaultDisableSmallstepExtensions = false globalProvisionerClaims = provisioner.Claims{ MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, DisableRenewal: &defaultDisableRenewal, DisableSmallstepExtensions: &defaultDisableSmallstepExtensions, } ) type fakeProvisioner struct{} func (*fakeProvisioner) AuthorizeOrderIdentifier(context.Context, provisioner.ACMEIdentifier) error { return nil } func (*fakeProvisioner) AuthorizeSign(context.Context, string) ([]provisioner.SignOption, error) { return nil, nil } func (*fakeProvisioner) IsChallengeEnabled(context.Context, provisioner.ACMEChallenge) bool { return true } func (*fakeProvisioner) IsAttestationFormatEnabled(context.Context, 
provisioner.ACMEAttestationFormat) bool { return true } func (*fakeProvisioner) GetAttestationRoots() (*x509.CertPool, bool) { return nil, false } func (*fakeProvisioner) AuthorizeRevoke(context.Context, string) error { return nil } func (*fakeProvisioner) GetID() string { return "" } func (*fakeProvisioner) GetName() string { return "" } func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 } func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil } func newProv() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ Type: "ACME", Name: "test@acme-provisioner.com", } if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { fmt.Printf("%v", err) } return p } func newProvWithID() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ ID: uuid.NewString(), Type: "ACME", Name: "test@acme-provisioner.com", } if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { fmt.Printf("%v", err) } return p } func newProvWithOptions(options *provisioner.Options) acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ Type: "ACME", Name: "test@acme-provisioner.com", Options: options, } if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { fmt.Printf("%v", err) } return p } func newACMEProv(t *testing.T) *provisioner.ACME { p := newProv() a, ok := p.(*provisioner.ACME) if !ok { t.Fatal("not a valid ACME provisioner") } return a } func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME { p := newProvWithOptions(options) a, ok := p.(*provisioner.ACME) if !ok { t.Fatal("not a valid ACME provisioner") } return a } func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) { signer, err := jose.NewSigner( jose.SigningKey{ Algorithm: jose.SignatureAlgorithm("HS256"), Key: hmacKey, }, &jose.SignerOptions{ ExtraHeaders: 
map[jose.HeaderKey]interface{}{ "kid": keyID, "url": u, }, EmbedJWK: false, }, ) if err != nil { return nil, err } jwkJSONBytes, err := jwk.Public().MarshalJSON() if err != nil { return nil, err } jws, err := signer.Sign(jwkJSONBytes) if err != nil { return nil, err } raw, err := jws.CompactSerialize() if err != nil { return nil, err } parsedJWS, err := jose.ParseJWS(raw) if err != nil { return nil, err } return parsedJWS, nil } func createRawEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) ([]byte, error) { jws, err := createEABJWS(jwk, hmacKey, keyID, u) if err != nil { return nil, err } rawJWS := jws.FullSerialize() return []byte(rawJWS), nil } func TestNewAccountRequest_Validate(t *testing.T) { type test struct { nar *NewAccountRequest err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/incompatible-input": func(t *testing.T) test { return test{ nar: &NewAccountRequest{ OnlyReturnExisting: true, Contact: []string{"foo", "bar"}, }, err: acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone"), } }, "fail/bad-contact": func(t *testing.T) test { return test{ nar: &NewAccountRequest{ Contact: []string{"foo", ""}, }, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "ok": func(t *testing.T) test { return test{ nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, }, } }, "ok/onlyReturnExisting": func(t *testing.T) test { return test{ nar: &NewAccountRequest{ OnlyReturnExisting: true, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { if err := tc.nar.Validate(); err != nil { if assert.NotNil(t, err) { var ae *acme.Error if assert.True(t, errors.As(err, &ae)) { assert.HasPrefix(t, ae.Error(), tc.err.Error()) assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) assert.Equals(t, ae.Type, tc.err.Type) } } } else { assert.Nil(t, tc.err) } }) } } func TestUpdateAccountRequest_Validate(t *testing.T) { type test struct { uar 
*UpdateAccountRequest err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/incompatible-input": func(t *testing.T) test { return test{ uar: &UpdateAccountRequest{ Contact: []string{"foo", "bar"}, Status: "foo", }, err: acme.NewError(acme.ErrorMalformedType, "incompatible input; "+ "contact and status updates are mutually exclusive"), } }, "fail/bad-contact": func(t *testing.T) test { return test{ uar: &UpdateAccountRequest{ Contact: []string{"foo", ""}, }, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/bad-status": func(t *testing.T) test { return test{ uar: &UpdateAccountRequest{ Status: "foo", }, err: acme.NewError(acme.ErrorMalformedType, "cannot update account "+ "status to foo, only deactivated"), } }, "ok/contact": func(t *testing.T) test { return test{ uar: &UpdateAccountRequest{ Contact: []string{"foo", "bar"}, }, } }, "ok/status": func(t *testing.T) test { return test{ uar: &UpdateAccountRequest{ Status: "deactivated", }, } }, "ok/accept-empty": func(t *testing.T) test { return test{ uar: &UpdateAccountRequest{}, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { if err := tc.uar.Validate(); err != nil { if assert.NotNil(t, err) { var ae *acme.Error if assert.True(t, errors.As(err, &ae)) { assert.HasPrefix(t, ae.Error(), tc.err.Error()) assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) assert.Equals(t, ae.Type, tc.err.Type) } } } else { assert.Nil(t, tc.err) } }) } } func TestHandler_GetOrdersByAccountID(t *testing.T) { accID := "account-id" // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("accID", accID) prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) oids := []string{"foo", "bar"} oidURLs := []string{ fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), 
provName), fmt.Sprintf("%s/acme/%s/order/bar", baseURL.String(), provName), } type test struct { db acme.DB ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), accContextKey, http.NoBody), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "foo"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"), } }, "fail/db.GetOrdersByAccountID-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { assert.Equals(t, id, acc.ID) return oids, nil }, }, ctx: ctx, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := 
httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(oidURLs) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func TestHandler_NewAccount(t *testing.T) { prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} provID := prov.GetID() type test struct { db acme.DB acc *acme.Account ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to "+ "unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { 
nar := &NewAccountRequest{ Contact: []string{"foo", ""}, } b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/no-existing-account": func(t *testing.T) test { nar := &NewAccountRequest{ OnlyReturnExisting: true, } b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-jwk": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/new-account-no-eab-provided": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: nil, } b, err := json.Marshal(nar) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) prov := newACMEProv(t) 
prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), } }, "fail/db.CreateAccount-error": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } b, err := json.Marshal(nar) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { assert.Equals(t, acc.Contact, nar.Contact) assert.Equals(t, acc.Key, jwk) return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "fail/acmeProvisionerFromContext": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewError(acme.ErrorServerInternalType, "provisioner in 
context is not an ACME provisioner"), } }, "fail/db.UpdateExternalAccountKey-error": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) eak := &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), } return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { acc.ID = "accountID" assert.Equals(t, acc.Contact, nar.Contact) assert.Equals(t, acc.Key, jwk) return nil }, MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return eak, nil }, MockUpdateExternalAccountKey: func(ctx context.Context, provisionerName string, eak *acme.ExternalAccountKey) error { 
return errors.New("force") }, }, acc: &acme.Account{ ID: "accountID", Key: jwk, Status: acme.StatusValid, Contact: []string{"foo", "bar"}, OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), ExternalAccountBinding: eab, }, ctx: ctx, statusCode: 500, err: acme.NewError(acme.ErrorServerInternalType, "error updating external account binding key"), } }, "ok/new-account": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } b, err := json.Marshal(nar) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { acc.ID = "accountID" assert.Equals(t, acc.Contact, nar.Contact) assert.Equals(t, acc.Key, jwk) return nil }, }, acc: &acme.Account{ ID: "accountID", Key: jwk, Status: acme.StatusValid, Contact: []string{"foo", "bar"}, OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), }, ctx: ctx, statusCode: 201, } }, "ok/return-existing": func(t *testing.T) test { nar := &NewAccountRequest{ OnlyReturnExisting: true, } b, err := json.Marshal(nar) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ ID: "accountID", Key: jwk, Status: acme.StatusValid, Contact: []string{"foo", "bar"}, } ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, acc: acc, statusCode: 200, } }, "ok/new-account-no-eab-required": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", 
"P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } b, err := json.Marshal(nar) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = false ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { acc.ID = "accountID" assert.Equals(t, acc.Contact, nar.Contact) assert.Equals(t, acc.Key, jwk) return nil }, }, acc: &acme.Account{ ID: "accountID", Key: jwk, Status: acme.StatusValid, Contact: []string{"foo", "bar"}, OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), }, ctx: ctx, statusCode: 201, } }, "ok/new-account-with-eab": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) 
assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { acc.ID = "accountID" assert.Equals(t, acc.Contact, nar.Contact) assert.Equals(t, acc.Key, jwk) return nil }, MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, MockUpdateExternalAccountKey: func(ctx context.Context, provisionerName string, eak *acme.ExternalAccountKey) error { return nil }, }, acc: &acme.Account{ ID: "accountID", Key: jwk, Status: acme.StatusValid, Contact: []string{"foo", "bar"}, OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), ExternalAccountBinding: eab, }, ctx: ctx, statusCode: 201, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() NewAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, 
ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				expB, err := json.Marshal(tc.acc)
				assert.FatalError(t, err)
				assert.Equals(t, bytes.TrimSpace(body), expB)
				assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), escProvName, "accountID")})
				assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
			}
		})
	}
}

// TestHandler_GetOrUpdateAccount drives the GetOrUpdateAccount handler through
// a table of scenarios. Every scenario supplies a request context and a mock
// database; the test then checks the HTTP status code and either the returned
// ACME problem document (status >= 400) or the serialized account.
func TestHandler_GetOrUpdateAccount(t *testing.T) {
	accID := "accountID"
	// acc is the fixture account. It is placed in the request context by most
	// scenarios and is also what the success branch marshals as the expected
	// response body.
	acc := acme.Account{
		ID:        accID,
		Status:    "valid",
		OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
	}
	prov := newProv()
	// escProvName and baseURL are used to build the expected Location header.
	escProvName := url.PathEscape(prov.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	// test describes one scenario: the mock DB and context handed to the
	// handler, the expected status code, and (for failures) the expected error.
	type test struct {
		db         acme.DB
		ctx        context.Context
		statusCode int
		err        *acme.Error
	}
	// Each entry is a constructor so that per-scenario setup can use t.
	var tests = map[string]func(t *testing.T) test{
		// No account value in the context at all.
		"fail/no-account": func(t *testing.T) test {
			return test{
				db:         &acme.MockDB{},
				ctx:        context.Background(),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
			}
		},
		// The account key is present in the context but holds nil.
		"fail/nil-account": func(t *testing.T) test {
			ctx := context.WithValue(context.Background(), accContextKey, nil)
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
			}
		},
		// Account present, payload key missing from the context.
		"fail/no-payload": func(t *testing.T) test {
			ctx := context.WithValue(context.Background(), accContextKey, &acc)
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 500,
				err:        acme.NewErrorISE("payload expected in request context"),
			}
		},
		// Account present, payload key present but nil.
		"fail/nil-payload": func(t *testing.T) test {
			ctx := context.WithValue(context.Background(), accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, nil)
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 500,
				err:        acme.NewErrorISE("payload expected in request context"),
			}
		},
		// Zero-value payloadInfo: there are no payload bytes to decode.
		"fail/unmarshal-payload-error": func(t *testing.T) test {
			ctx := context.WithValue(context.Background(), accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
			}
		},
		// The update request carries an empty contact string.
		"fail/malformed-payload-error": func(t *testing.T) test {
			uar := &UpdateAccountRequest{
				Contact: []string{"foo", ""},
			}
			b, err := json.Marshal(uar)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.Background(), accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
			}
		},
		// Deactivation request where the mock UpdateAccount returns an error.
		"fail/db.UpdateAccount-error": func(t *testing.T) test {
			uar := &UpdateAccountRequest{
				Status: "deactivated",
			}
			b, err := json.Marshal(uar)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.Background(), accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
			return test{
				db: &acme.MockDB{
					MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
						assert.Equals(t, upd.Status, acme.StatusDeactivated)
						assert.Equals(t, upd.ID, acc.ID)
						return acme.NewErrorISE("force")
					},
				},
				ctx:        ctx,
				statusCode: 500,
				err:        acme.NewErrorISE("force"),
			}
		},
		// Deactivation request that the mock DB accepts; the mock also checks
		// the status and ID of the account it is asked to persist.
		"ok/deactivate": func(t *testing.T) test {
			uar := &UpdateAccountRequest{
				Status: "deactivated",
			}
			b, err := json.Marshal(uar)
			assert.FatalError(t, err)
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
			return test{
				db: &acme.MockDB{
					MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
						assert.Equals(t, upd.Status, acme.StatusDeactivated)
						assert.Equals(t, upd.ID, acc.ID)
						return nil
					},
				},
				ctx:        ctx,
				statusCode: 200,
			}
		},
		// Empty update request; the mock DB has no UpdateAccount hook configured.
		"ok/update-empty": func(t *testing.T) test {
			uar := &UpdateAccountRequest{}
			b, err := json.Marshal(uar)
			assert.FatalError(t, err)
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 200,
			}
		},
		// Contact update; the mock checks that the requested contacts arrive.
		"ok/update-contacts": func(t *testing.T) test {
			uar := &UpdateAccountRequest{
				Contact: []string{"foo", "bar"},
			}
			b, err := json.Marshal(uar)
			assert.FatalError(t, err)
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
			return test{
				db: &acme.MockDB{
					MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
						assert.Equals(t, upd.Contact, uar.Contact)
						assert.Equals(t, upd.ID, acc.ID)
						return nil
					},
				},
				ctx:        ctx,
				statusCode: 200,
			}
		},
		// Payload flagged as POST-as-GET, with no payload bytes.
		"ok/post-as-get": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, accContextKey, &acc)
			ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 200,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			// The scenario's context is wrapped with the mock DB and a linker
			// and attached to the request. The method and path below are
			// fixed placeholders; presumably the handler takes its inputs
			// from the context (confirm against the handler).
			ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
			req := httptest.NewRequest("GET", "/foo/bar", http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()
			GetOrUpdateAccount(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			// Error responses are decoded as an ACME error and compared
			// field by field with the expected one.
			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems,
tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(acc) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), escProvName, accID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } ================================================ FILE: acme/api/eab.go ================================================ package api import ( "context" "encoding/json" "errors" "go.step.sm/crypto/jose" "github.com/smallstep/certificates/acme" ) // ExternalAccountBinding represents the ACME externalAccountBinding JWS type ExternalAccountBinding struct { Protected string `json:"protected"` Payload string `json:"payload"` Sig string `json:"signature"` } // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") } if !acmeProv.RequireEAB { //nolint:nilnil // legacy return nil, nil } if nar.ExternalAccountBinding == nil { return nil, acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided") } eabJSONBytes, err := json.Marshal(nar.ExternalAccountBinding) if err != nil { return nil, acme.WrapErrorISE(err, "error marshaling externalAccountBinding into bytes") } eabJWS, err := jose.ParseJWS(string(eabJSONBytes)) if err != nil { return nil, acme.WrapErrorISE(err, "error parsing externalAccountBinding jws") } // TODO(hs): implement strategy pattern to allow for different ways of verification (i.e. webhook call) based on configuration? 
keyID, acmeErr := validateEABJWS(ctx, eabJWS) if acmeErr != nil { return nil, acmeErr } db := acme.MustDatabaseFromContext(ctx) externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { var ae *acme.Error if errors.As(err, &ae) { return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") } return nil, acme.WrapErrorISE(err, "error retrieving external account key") } if externalAccountKey == nil { return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key") } if len(externalAccountKey.HmacKey) == 0 { return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID) } if externalAccountKey.AlreadyBound() { return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt) } payload, err := eabJWS.Verify(externalAccountKey.HmacKey) if err != nil { return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature") } jwk, err := jwkFromContext(ctx) if err != nil { return nil, err } var payloadJWK *jose.JSONWebKey if err = json.Unmarshal(payload, &payloadJWK); err != nil { return nil, acme.WrapError(acme.ErrorMalformedType, err, "error unmarshaling payload into jwk") } if !keysAreEqual(jwk, payloadJWK) { return nil, acme.NewError(acme.ErrorUnauthorizedType, "keys in jws and eab payload do not match") } return externalAccountKey, nil } // keysAreEqual performs an equality check on two JWKs by comparing // the (base64 encoding) of the Key IDs. 
func keysAreEqual(x, y *jose.JSONWebKey) bool { if x == nil || y == nil { return false } digestX, errX := acme.KeyToID(x) digestY, errY := acme.KeyToID(y) if errX != nil || errY != nil { return false } return digestX == digestY } // validateEABJWS verifies the contents of the External Account Binding JWS. // The protected header of the JWS MUST meet the following criteria: // // - The "alg" field MUST indicate a MAC-based algorithm // - The "kid" field MUST contain the key identifier provided by the CA // - The "nonce" field MUST NOT be present // - The "url" field MUST be set to the same value as the outer JWS func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { if jws == nil { return "", acme.NewErrorISE("no JWS provided") } if len(jws.Signatures) != 1 { return "", acme.NewError(acme.ErrorMalformedType, "JWS must have one signature") } header := jws.Signatures[0].Protected algorithm := header.Algorithm keyID := header.KeyID nonce := header.Nonce if algorithm != jose.HS256 && algorithm != jose.HS384 && algorithm != jose.HS512 { return "", acme.NewError(acme.ErrorMalformedType, "'alg' field set to invalid algorithm '%s'", algorithm) } if keyID == "" { return "", acme.NewError(acme.ErrorMalformedType, "'kid' field is required") } if nonce != "" { return "", acme.NewError(acme.ErrorMalformedType, "'nonce' must not be present") } jwsURL, ok := header.ExtraHeaders["url"] if !ok { return "", acme.NewError(acme.ErrorMalformedType, "'url' field is required") } outerJWS, err := jwsFromContext(ctx) if err != nil { return "", acme.WrapErrorISE(err, "could not retrieve outer JWS from context") } if len(outerJWS.Signatures) != 1 { return "", acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature") } outerJWSURL, ok := outerJWS.Signatures[0].Protected.ExtraHeaders["url"] if !ok { return "", acme.NewError(acme.ErrorMalformedType, "'url' field must be set in outer JWS") } if jwsURL != outerJWSURL { return "", 
acme.NewError(acme.ErrorMalformedType, "'url' field is not the same value as the outer JWS") } return keyID, nil } ================================================ FILE: acme/api/eab_test.go ================================================ package api import ( "context" "encoding/json" "fmt" "net/url" "testing" "time" "github.com/pkg/errors" "go.step.sm/crypto/jose" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" ) func Test_keysAreEqual(t *testing.T) { jwkX, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) jwkY, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) wrongJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) wrongJWK.Key = struct{}{} type args struct { x *jose.JSONWebKey y *jose.JSONWebKey } tests := []struct { name string args args want bool }{ { name: "ok/nil", args: args{ x: jwkX, y: nil, }, want: false, }, { name: "ok/equal", args: args{ x: jwkX, y: jwkX, }, want: true, }, { name: "ok/not-equal", args: args{ x: jwkX, y: jwkY, }, want: false, }, { name: "ok/wrong-key-type", args: args{ x: wrongJWK, y: jwkY, }, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := keysAreEqual(tt.args.x, tt.args.y); got != tt.want { t.Errorf("keysAreEqual() = %v, want %v", got, tt.want) } }) } } func TestHandler_validateExternalAccountBinding(t *testing.T) { acmeProv := newACMEProv(t) escProvName := url.PathEscape(acmeProv.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} provID := acmeProv.GetID() type test struct { db acme.DB ctx context.Context nar *NewAccountRequest eak *acme.ExternalAccountKey err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok/no-eab-required-but-provided": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", 
baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) prov := newACMEProv(t) ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: nil, } }, "ok/eab": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, 
Reference: "testeak", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: createdAt, }, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: createdAt, }, err: nil, } }, "fail/acmeProvisionerFromContext": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{}) return test{ ctx: ctx, err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), } }, "fail/parse-eab-jose": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) eab.Payload += "{}" prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, 
ExternalAccountBinding: eab, }, eak: nil, err: acme.NewErrorISE("error parsing externalAccountBinding jws"), } }, "fail/validate-eab-jws-no-signatures": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) parsedJWS.Signatures = []jose.Signature{} prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{}, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature"), } }, "fail/retrieve-eab-key-db-failure": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := 
&ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockError: errors.New("db failure"), }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewErrorISE("error retrieving external account key"), } }, "fail/db.GetExternalAccountKey-not-found": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) 
assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return nil, acme.ErrNotFound }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewErrorISE("error retrieving external account key"), } }, "fail/db.GetExternalAccountKey-error": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = 
context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return nil, errors.New("force") }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewErrorISE("error retrieving external account key"), } }, "fail/db.GetExternalAccountKey-nil": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return nil, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: 
acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"), } }, "fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", CreatedAt: createdAt, AccountID: "some-account-id", HmacKey: []byte{}, }, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"), } }, "fail/db.GetExternalAccountKey-wrong-provisioner": 
func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockError: acme.NewError(acme.ErrorUnauthorizedType, "name of provisioner does not match provisioner for which the EAB key was created"), }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key: name of provisioner does not match provisioner for which the EAB key was created"), } }, "fail/eab-already-bound": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) 
assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() boundAt := time.Now().Add(1 * time.Second) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", CreatedAt: createdAt, AccountID: "some-account-id", HmacKey: []byte{1, 3, 3, 7}, BoundAt: boundAt, }, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", "eakID", "some-account-id", boundAt), } }, "fail/eab-verify": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := 
&ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", HmacKey: []byte{1, 2, 3, 4}, CreatedAt: time.Now(), }, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewErrorISE("error verifying externalAccountBinding signature"), } }, "fail/eab-non-matching-keys": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) differentJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(differentJWK, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", 
"bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewError(acme.ErrorUnauthorizedType, "keys in jws and eab payload do not match"), } }, "fail/no-jwk": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := 
jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { return &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: provID, Reference: "testeak", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Now(), }, nil }, }, ctx: ctx, nar: &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, eak: nil, err: acme.NewError(acme.ErrorServerInternalType, "jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) assert.FatalError(t, err) eab := &ExternalAccountBinding{} err = json.Unmarshal(rawEABJWS, &eab) assert.FatalError(t, err) nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, } payloadBytes, err := json.Marshal(nar) assert.FatalError(t, err) so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) so.WithHeader("url", url) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign(payloadBytes) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) 
assert.FatalError(t, err)
			prov := newACMEProv(t)
			prov.RequireEAB = true
			ctx := context.WithValue(context.Background(), jwkContextKey, nil)
			ctx = acme.NewProvisionerContext(ctx, prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				db: &acme.MockDB{
					MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
						return &acme.ExternalAccountKey{
							ID:            "eakID",
							ProvisionerID: provID,
							Reference:     "testeak",
							HmacKey:       []byte{1, 3, 3, 7},
							CreatedAt:     time.Now(),
						}, nil
					},
				},
				ctx: ctx,
				nar: &NewAccountRequest{
					Contact:                []string{"foo", "bar"},
					ExternalAccountBinding: eab,
				},
				eak: nil,
				err: acme.NewError(acme.ErrorServerInternalType, "jwk expected in request context"),
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
			got, err := validateExternalAccountBinding(ctx, tc.nar)
			wantErr := tc.err != nil
			gotErr := err != nil
			if wantErr != gotErr {
				t.Errorf("Handler.validateExternalAccountBinding() error = %v, want %v", err, tc.err)
			}
			if wantErr {
				assert.NotNil(t, err)
				assert.Type(t, &acme.Error{}, err)
				var ae *acme.Error
				if assert.True(t, errors.As(err, &ae)) {
					assert.Equals(t, ae.Type, tc.err.Type)
					assert.Equals(t, ae.Status, tc.err.Status)
					assert.HasPrefix(t, ae.Err.Error(), tc.err.Err.Error())
					assert.Equals(t, ae.Detail, tc.err.Detail)
					assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				}
			} else {
				if got == nil {
					assert.Nil(t, tc.eak)
				} else {
					assert.NotNil(t, tc.eak)
					assert.Equals(t, got.ID, tc.eak.ID)
					assert.Equals(t, got.HmacKey, tc.eak.HmacKey)
					assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID)
					assert.Equals(t, got.Reference, tc.eak.Reference)
					assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt)
					assert.Equals(t, got.AccountID, tc.eak.AccountID)
					assert.Equals(t, got.BoundAt, tc.eak.BoundAt)
				}
			}
		})
	}
}

// Test_validateEABJWS runs table-driven cases against validateEABJWS. Each
// case builds an EAB JWS (and, where needed, an outer JWS stored in the
// context under jwsContextKey). On failure the Type, Status, Err prefix,
// Detail and Subproblems of the returned error are compared with the expected
// acme.Error; on success the returned key ID is compared with the expected one.
func Test_validateEABJWS(t *testing.T) {
	acmeProv := newACMEProv(t)
	escProvName := url.PathEscape(acmeProv.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	// test describes a single case: the inputs passed to validateEABJWS and
	// the expected key ID or error.
	type test struct {
		ctx   context.Context
		jws   *jose.JSONWebSignature
		keyID string
		err   *acme.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/nil-jws": func(t *testing.T) test {
			return test{
				jws: nil,
				err: acme.NewErrorISE("no JWS provided"),
			}
		},
		"fail/invalid-number-of-signatures": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			// Append an extra, empty signature to the EAB JWS.
			eabJWS.Signatures = append(eabJWS.Signatures, jose.Signature{})
			return test{
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "JWS must have one signature"),
			}
		},
		"fail/invalid-algorithm": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			eabJWS.Signatures[0].Protected.Algorithm = "HS42"
			return test{
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "'alg' field set to invalid algorithm 'HS42'"),
			}
		},
		"fail/kid-not-set": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			eabJWS.Signatures[0].Protected.KeyID = ""
			return test{
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "'kid' field is required"),
			}
		},
		"fail/nonce-not-empty": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			eabJWS.Signatures[0].Protected.Nonce = "some-bogus-nonce"
			return test{
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "'nonce' must not be present"),
			}
		},
		"fail/url-not-set": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			// Drop the "url" entry from the protected extra headers.
			delete(eabJWS.Signatures[0].Protected.ExtraHeaders, "url")
			return test{
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "'url' field is required"),
			}
		},
		"fail/no-outer-jws": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			// A nil value is stored under jwsContextKey instead of a JWS.
			ctx := context.WithValue(context.TODO(), jwsContextKey, nil)
			return test{
				ctx: ctx,
				jws: eabJWS,
				err: acme.NewErrorISE("could not retrieve outer JWS from context"),
			}
		},
		"fail/outer-jws-multiple-signatures": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			rawEABJWS := eabJWS.FullSerialize()
			// NOTE(review): FullSerialize yields a single value here, so the
			// check below re-tests the err returned by createEABJWS above.
			assert.FatalError(t, err)
			eab := &ExternalAccountBinding{}
			err = json.Unmarshal([]byte(rawEABJWS), &eab)
			assert.FatalError(t, err)
			nar := &NewAccountRequest{
				Contact:                []string{"foo", "bar"},
				ExternalAccountBinding: eab,
			}
			payloadBytes, err := json.Marshal(nar)
			assert.FatalError(t, err)
			so := new(jose.SignerOptions)
			so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
			signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, so)
			assert.FatalError(t, err)
			jws, err := signer.Sign(payloadBytes)
			assert.FatalError(t, err)
			raw, err := jws.CompactSerialize()
			assert.FatalError(t, err)
			outerJWS, err := jose.ParseJWS(raw)
			assert.FatalError(t, err)
			// Append an extra, empty signature to the outer JWS.
			outerJWS.Signatures = append(outerJWS.Signatures, jose.Signature{})
			ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS)
			return test{
				ctx: ctx,
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature"),
			}
		},
		"fail/outer-jws-no-url": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			rawEABJWS := eabJWS.FullSerialize()
			assert.FatalError(t, err)
			eab := &ExternalAccountBinding{}
			err = json.Unmarshal([]byte(rawEABJWS), &eab)
			assert.FatalError(t, err)
			nar := &NewAccountRequest{
				Contact:                []string{"foo", "bar"},
				ExternalAccountBinding: eab,
			}
			payloadBytes, err := json.Marshal(nar)
			assert.FatalError(t, err)
			// Only the "alg" header is set on the outer JWS; no "url" header.
			so := new(jose.SignerOptions)
			so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
			signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, so)
			assert.FatalError(t, err)
			jws, err := signer.Sign(payloadBytes)
			assert.FatalError(t, err)
			raw, err := jws.CompactSerialize()
			assert.FatalError(t, err)
			outerJWS, err := jose.ParseJWS(raw)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS)
			return test{
				ctx: ctx,
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "'url' field must be set in outer JWS"),
			}
		},
		"fail/outer-jws-with-different-url": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			rawEABJWS := eabJWS.FullSerialize()
			assert.FatalError(t, err)
			eab := &ExternalAccountBinding{}
			err = json.Unmarshal([]byte(rawEABJWS), &eab)
			assert.FatalError(t, err)
			nar := &NewAccountRequest{
				Contact:                []string{"foo", "bar"},
				ExternalAccountBinding: eab,
			}
			payloadBytes, err := json.Marshal(nar)
			assert.FatalError(t, err)
			so := new(jose.SignerOptions)
			so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
			// The outer JWS carries a "url" header that differs from the one
			// used for the EAB JWS.
			so.WithHeader("url", "this-is-not-the-same-url-as-in-the-eab-jws")
			signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, so)
			assert.FatalError(t, err)
			jws, err := signer.Sign(payloadBytes)
			assert.FatalError(t, err)
			raw, err := jws.CompactSerialize()
			assert.FatalError(t, err)
			outerJWS, err := jose.ParseJWS(raw)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS)
			return test{
				ctx: ctx,
				jws: eabJWS,
				err: acme.NewError(acme.ErrorMalformedType, "'url' field is not the same value as the outer JWS"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
			eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
			assert.FatalError(t, err)
			rawEABJWS := eabJWS.FullSerialize()
			assert.FatalError(t, err)
			eab := &ExternalAccountBinding{}
			err = json.Unmarshal([]byte(rawEABJWS), &eab)
			assert.FatalError(t, err)
			nar := &NewAccountRequest{
				Contact:                []string{"foo", "bar"},
				ExternalAccountBinding: eab,
			}
			payloadBytes, err := json.Marshal(nar)
			assert.FatalError(t, err)
			so := new(jose.SignerOptions)
			so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
			// Same "url" header value as used for the EAB JWS.
			so.WithHeader("url", url)
			signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, so)
			assert.FatalError(t, err)
			jws, err := signer.Sign(payloadBytes)
			assert.FatalError(t, err)
			raw, err := jws.CompactSerialize()
			assert.FatalError(t, err)
			outerJWS, err := jose.ParseJWS(raw)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS)
			return test{
				ctx:   ctx,
				jws:   eabJWS,
				keyID: "eakID",
				err:   nil,
			}
		},
	}
	for name, prep := range tests {
		// Build the case outside of the subtest; prep may abort via
		// assert.FatalError.
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			keyID, err := validateEABJWS(tc.ctx, tc.jws)
			wantErr := tc.err != nil
			gotErr := err != nil
			if wantErr != gotErr {
				t.Errorf("validateEABJWS() error = %v, want %v", err, tc.err)
			}
			if wantErr {
				assert.NotNil(t, err)
				assert.Equals(t, tc.err.Type, err.Type)
				assert.Equals(t, tc.err.Status, err.Status)
				// Only the prefix of the underlying error text is compared.
				assert.HasPrefix(t, err.Err.Error(), tc.err.Err.Error())
				assert.Equals(t, tc.err.Detail, err.Detail)
				assert.Equals(t, tc.err.Subproblems, err.Subproblems)
			} else {
				assert.Nil(t, err)
				assert.Equals(t, tc.keyID, keyID)
			}
		})
	}
}

================================================
FILE: acme/api/handler.go
================================================
package api

import (
	"context"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"net/http"
	"time"

	"github.com/go-chi/chi/v5"

	"github.com/smallstep/certificates/acme"
	"github.com/smallstep/certificates/api"
	"github.com/smallstep/certificates/api/render"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
)

// link formats url and typ as `<url>;rel="typ"`; typ is quoted with %q.
func link(url, typ string) string {
	return fmt.Sprintf("<%s>;rel=%q", url, typ)
}

// Clock returns the current time in UTC truncated to whole seconds.
type Clock struct{}

// Now returns the current UTC time truncated (not rounded) to the second.
func (c *Clock) Now() time.Time {
	return time.Now().UTC().Truncate(time.Second)
}

// clock is the package-level Clock value.
var clock Clock

// payloadInfo holds the raw payload bytes of a request together with two
// flags describing it. NOTE(review): the code that populates these fields is
// outside this view; presumably isPostAsGet marks an empty payload and
// isEmptyJSON a "{}" payload — confirm against the middleware.
type payloadInfo struct {
	value       []byte
	isPostAsGet bool
	isEmptyJSON bool
}

// HandlerOptions contains the options required to create a new ACME API
// request handler.
type HandlerOptions struct {
	// DB storage backend that implements the acme.DB interface.
	//
	// Deprecated: use acme.NewContext to set the acme.DB on the context.
	DB acme.DB

	// CA is the certificate authority interface.
	//
	// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
	CA acme.CertificateAuthority

	// Backdate is the duration that the CA will subtract from the current time
	// to set the NotBefore in the certificate.
	Backdate provisioner.Duration

	// DNS the host used to generate accurate ACME links. By default the authority
	// will use the Host from the request, so this value will only be used if
	// request.Host is empty.
	DNS string

	// Prefix is a URL path prefix under which the ACME api is served. This
	// prefix is required to generate accurate ACME links.
	// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
	// "acme" is the prefix from which the ACME api is accessed.
	Prefix string

	// PrerequisitesChecker checks if all prerequisites for serving ACME are
	// met by the CA configuration.
	PrerequisitesChecker func(ctx context.Context) (bool, error)
}

// mustAuthority returns the authority stored in ctx, obtained through
// authority.MustFromContext, as an acme.CertificateAuthority.
// NOTE(review): declared as a variable rather than a function, presumably so
// that tests can replace it — confirm.
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
	return authority.MustFromContext(ctx)
}

// handler is the ACME API request handler.
type handler struct {
	// opts holds the options the handler was created with.
	opts *HandlerOptions
}

// Route traffic and implement the Router interface. For backward compatibility
// this route adds a new middleware that sets the ACME components on the
// context.
//
// Note: this method is deprecated in step-ca, other applications can still use
// this to support ACME, but the recommendation is to use
// api.Route(api.Router) and acme.NewContext() instead.
func (h *handler) Route(r api.Router) { client := acme.NewClient() linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix) route(r, func(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { ctx = authority.NewContext(ctx, ca) } ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker) next(w, r.WithContext(ctx)) } }) } // NewHandler returns a new ACME API handler. // // Note: this method is deprecated in step-ca, other applications can still use // this to support ACME, but the recommendation is to use use // api.Route(api.Router) and acme.NewContext() instead. func NewHandler(opts HandlerOptions) api.RouterHandler { return &handler{ opts: &opts, } } // Route traffic and implement the Router interface. This method requires that // all the acme components, authority, db, client, linker, and prerequisite // checker to be present in the context. func Route(r api.Router) { route(r, nil) } func route(r api.Router, middleware func(next nextHTTP) nextHTTP) { commonMiddleware := func(next nextHTTP) nextHTTP { handler := func(w http.ResponseWriter, r *http.Request) { // Linker middleware gets the provisioner and current url from the // request and sets them in the context. 
linker := acme.MustLinkerFromContext(r.Context()) linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r) } if middleware != nil { handler = middleware(handler) } return handler } validatingMiddleware := func(next nextHTTP) nextHTTP { return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))) } getPath := acme.GetUnescapedPathSuffix // Standard ACME API r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"), commonMiddleware(addNonce(addDirLink(GetNonce)))) r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"), commonMiddleware(addNonce(addDirLink(GetNonce)))) r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"), commonMiddleware(GetDirectory)) r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"), commonMiddleware(GetDirectory)) r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(NewAccount)) r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(GetOrUpdateAccount)) r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(NotImplemented)) r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(NewOrder)) r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(isPostAsGet(GetOrder))) r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) 
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(FinalizeOrder)) r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(isPostAsGet(GetAuthorization))) r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(GetChallenge)) r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(isPostAsGet(GetCertificate))) r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(RevokeCert)) } // GetNonce just sets the right header since a Nonce is added to each response // by middleware by default. func GetNonce(w http.ResponseWriter, r *http.Request) { if r.Method == "HEAD" { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusNoContent) } } type Meta struct { TermsOfService string `json:"termsOfService,omitempty"` Website string `json:"website,omitempty"` CaaIdentities []string `json:"caaIdentities,omitempty"` ExternalAccountRequired bool `json:"externalAccountRequired,omitempty"` } // Directory represents an ACME directory for configuring clients. type Directory struct { NewNonce string `json:"newNonce"` NewAccount string `json:"newAccount"` NewOrder string `json:"newOrder"` RevokeCert string `json:"revokeCert"` KeyChange string `json:"keyChange"` Meta *Meta `json:"meta,omitempty"` } // ToLog enables response logging for the Directory type. func (d *Directory) ToLog() (interface{}, error) { b, err := json.Marshal(d) if err != nil { return nil, acme.WrapErrorISE(err, "error marshaling directory for logging") } return string(b), nil } // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. 
func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { render.Error(w, r, err) return } linker := acme.MustLinkerFromContext(ctx) render.JSON(w, r, &Directory{ NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType), KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType), Meta: createMetaObject(acmeProv), }) } // createMetaObject creates a Meta object if the ACME provisioner // has one or more properties that are written in the ACME directory output. // It returns nil if none of the properties are set. func createMetaObject(p *provisioner.ACME) *Meta { if shouldAddMetaObject(p) { return &Meta{ TermsOfService: p.TermsOfService, Website: p.Website, CaaIdentities: p.CaaIdentities, ExternalAccountRequired: p.RequireEAB, } } return nil } // shouldAddMetaObject returns whether or not the ACME provisioner // has properties configured that must be added to the ACME directory object. func shouldAddMetaObject(p *provisioner.ACME) bool { switch { case p.TermsOfService != "": return true case p.Website != "": return true case len(p.CaaIdentities) > 0: return true case p.RequireEAB: return true default: return false } } // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func NotImplemented(w http.ResponseWriter, r *http.Request) { render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. 
func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving authorization")) return } if acc.ID != az.AccountID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } if err = az.UpdateStatus(ctx, db); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error updating authorization status")) return } linker.LinkAuthorization(ctx, az) w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) render.JSON(w, r, az) } // GetChallenge ACME api for retrieving a Challenge. func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, r, err) return } // NOTE: We should be checking that the request is either a POST-as-GET, or // that for all challenges except for device-attest-01, the payload is an // empty JSON block ({}). However, older ACME clients still send a vestigial // body (rather than an empty JSON block) and strict enforcement would // render these clients broken. 
azID := chi.URLParam(r, "authzID") ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving challenge")) return } ch.AuthorizationID = azID if acc.ID != ch.AccountID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } jwk, err := jwkFromContext(ctx) if err != nil { render.Error(w, r, err) return } if err = ch.Validate(ctx, db, jwk, payload.value); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error validating challenge")) return } linker.LinkChallenge(ctx, ch, azID) w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) render.JSON(w, r, ch) } // GetCertificate ACME api for retrieving a Certificate. func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } certID := chi.URLParam(r, "certID") cert, err := db.GetCertificate(ctx, certID) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate")) return } if cert.AccountID != acc.ID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own certificate '%s'", acc.ID, certID)) return } var certBytes []byte for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: c.Raw, })...) 
} api.LogCertificate(w, cert.Leaf) w.Header().Set("Content-Type", "application/pem-certificate-chain") w.Write(certBytes) } ================================================ FILE: acme/api/handler_test.go ================================================ package api import ( "bytes" "context" "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "io" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/go-chi/chi/v5" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" ) type mockClient struct { get func(url string) (*http.Response, error) lookupTxt func(name string) ([]string, error) tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) } func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) } func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { return m.tlsDial(network, addr, config) } func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) { t.Helper() fn := mustAuthority t.Cleanup(func() { mustAuthority = fn }) mustAuthority = func(ctx context.Context) acme.CertificateAuthority { return a } } func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string statusCode int }{ {"GET", 204}, {"HEAD", 200}, } // Request with chi context req := httptest.NewRequest("GET", "http://ca.smallstep.com/nonce", http.NoBody) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name GetNonce(w, req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("Handler.GetNonce StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } }) } } func TestHandler_GetDirectory(t *testing.T) { linker := 
acme.NewLinker("ca.smallstep.com", "acme") _ = linker type test struct { ctx context.Context statusCode int dir Directory err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-provisioner": func(t *testing.T) test { return test{ ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("provisioner is not in context"), } }, "fail/different-provisioner": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{}) return test{ ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), } }, "ok": func(t *testing.T) test { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), } return test{ ctx: ctx, dir: expDir, statusCode: 200, } }, "ok/eab-required": func(t *testing.T) test { prov := newACMEProv(t) prov.RequireEAB = true provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), Meta: &Meta{ 
ExternalAccountRequired: true, }, } return test{ ctx: ctx, dir: expDir, statusCode: 200, } }, "ok/full-meta": func(t *testing.T) test { prov := newACMEProv(t) prov.TermsOfService = "https://terms.ca.local/" prov.Website = "https://ca.local/" prov.CaaIdentities = []string{"ca.local"} prov.RequireEAB = true provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), Meta: &Meta{ TermsOfService: "https://terms.ca.local/", Website: "https://ca.local/", CaaIdentities: []string{"ca.local"}, ExternalAccountRequired: true, }, } return test{ ctx: ctx, dir: expDir, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", "/foo/bar", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() GetDirectory(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) if !cmp.Equal(tc.dir, dir) 
{ t.Errorf("GetDirectory() diff =\n%s", cmp.Diff(tc.dir, dir)) } assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func TestHandler_GetAuthorization(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) az := acme.Authorization{ ID: "authzID", AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "example.com", }, Status: "pending", ExpiresAt: expiry, Wildcard: false, Challenges: []*acme.Challenge{ { Type: "http-01", Status: "pending", Token: "tok2", URL: "https://ca.smallstep.com/acme/challenge/chHTTPID", ID: "chHTTP01ID", }, { Type: "dns-01", Status: "pending", Token: "tok2", URL: "https://ca.smallstep.com/acme/challenge/chDNSID", ID: "chDNSID", }, }, } prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", az.ID) u := fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, az.ID) type test struct { db acme.DB ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/db.GetAuthorization-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, err: 
acme.NewErrorISE("force"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { assert.Equals(t, id, az.ID) return &acme.Authorization{ AccountID: "foo", }, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, "fail/db.UpdateAuthorization-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { assert.Equals(t, id, az.ID) return &acme.Authorization{ AccountID: "accID", Status: acme.StatusPending, ExpiresAt: time.Now().Add(-1 * time.Hour), }, nil }, MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error { assert.Equals(t, az.Status, acme.StatusInvalid) return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { assert.Equals(t, id, az.ID) return &az, nil }, }, ctx: ctx, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", http.NoBody) req = req.WithContext(ctx) w := 
httptest.NewRecorder() GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { //var gotAz acme.Authz //assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &gotAz)) expB, err := json.Marshal(az) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func TestHandler_GetCertificate(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) root, err := pemutil.ReadCertificate("../../authority/testdata/certs/root_ca.crt") assert.FatalError(t, err) certBytes := append(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: leaf.Raw, }), pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: inter.Raw, })...) certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: root.Raw, })...) 
certID := "certID" prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("certID", certID) u := fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, certID) type test struct { db acme.DB ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/db.GetCertificate-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) return &acme.Certificate{AccountID: "foo"}, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: 
&acme.MockDB{ MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) return &acme.Certificate{ AccountID: "accID", OrderID: "ordID", Leaf: leaf, Intermediates: []*x509.Certificate{inter, root}, ID: id, }, nil }, }, ctx: ctx, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := acme.NewDatabaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() GetCertificate(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.HasPrefix(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes)) assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain"}) } }) } } func TestHandler_GetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") chiCtx.URLParams.Add("authzID", "authzID") prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL.String(), provName, "authzID", "chID") type test struct { db acme.DB vc acme.Client ctx context.Context statusCode int ch *acme.Challenge err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, 
"fail/nil-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), } }, "fail/db.GetChallenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { assert.Equals(t, chID, "chID") assert.Equals(t, azID, "authzID") return nil, acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { assert.Equals(t, chID, 
"chID") assert.Equals(t, azID, "authzID") return &acme.Challenge{AccountID: "foo"}, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "accout id mismatch"), } }, "fail/no-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { assert.Equals(t, chID, "chID") assert.Equals(t, azID, "authzID") return &acme.Challenge{AccountID: "accID"}, nil }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("missing jwk"), } }, "fail/nil-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { assert.Equals(t, chID, "chID") assert.Equals(t, azID, "authzID") return &acme.Challenge{AccountID: "accID"}, nil }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("nil jwk"), } }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, 
chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { assert.Equals(t, chID, "chID") assert.Equals(t, azID, "authzID") return &acme.Challenge{ Status: acme.StatusPending, Type: acme.HTTP01, AccountID: "accID", }, nil }, MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Type, acme.HTTP01) assert.Equals(t, ch.AccountID, "accID") assert.Equals(t, ch.AuthorizationID, "authzID") assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) return acme.NewErrorISE("force") }, }, vc: &mockClient{ get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { assert.Equals(t, chID, "chID") assert.Equals(t, azID, "authzID") return &acme.Challenge{ ID: "chID", Status: acme.StatusPending, Type: acme.HTTP01, AccountID: "accID", }, nil }, MockUpdateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Type, acme.HTTP01) assert.Equals(t, ch.AccountID, "accID") assert.Equals(t, ch.AuthorizationID, "authzID") assert.HasSuffix(t, ch.Error.Type, acme.ErrorConnectionType.String()) return nil }, }, ch: &acme.Challenge{ ID: "chID", Status: 
acme.StatusPending, AuthorizationID: "authzID", Type: acme.HTTP01, AccountID: "accID", URL: u, Error: acme.NewError(acme.ErrorConnectionType, "force"), }, vc: &mockClient{ get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, ctx: ctx, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() GetChallenge(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(tc.ch) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")}) assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func Test_createMetaObject(t *testing.T) { tests := []struct { name string p *provisioner.ACME want *Meta }{ { name: "no-meta", p: &provisioner.ACME{ Type: "ACME", Name: "acme", }, want: nil, }, { name: "terms-of-service", p: &provisioner.ACME{ Type: "ACME", Name: "acme", TermsOfService: "https://terms.ca.local", }, want: &Meta{ TermsOfService: "https://terms.ca.local", }, }, { name: "website", p: &provisioner.ACME{ Type: "ACME", Name: "acme", Website: "https://ca.local", }, want: &Meta{ Website: "https://ca.local", }, }, { 
name: "caa", p: &provisioner.ACME{ Type: "ACME", Name: "acme", CaaIdentities: []string{"ca.local", "ca.remote"}, }, want: &Meta{ CaaIdentities: []string{"ca.local", "ca.remote"}, }, }, { name: "require-eab", p: &provisioner.ACME{ Type: "ACME", Name: "acme", RequireEAB: true, }, want: &Meta{ ExternalAccountRequired: true, }, }, { name: "full-meta", p: &provisioner.ACME{ Type: "ACME", Name: "acme", TermsOfService: "https://terms.ca.local", Website: "https://ca.local", CaaIdentities: []string{"ca.local", "ca.remote"}, RequireEAB: true, }, want: &Meta{ TermsOfService: "https://terms.ca.local", Website: "https://ca.local", CaaIdentities: []string{"ca.local", "ca.remote"}, ExternalAccountRequired: true, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := createMetaObject(tt.p) if !cmp.Equal(tt.want, got) { t.Errorf("createMetaObject() diff =\n%s", cmp.Diff(tt.want, got)) } }) } } ================================================ FILE: acme/api/middleware.go ================================================ package api import ( "context" "crypto/rsa" "errors" "io" "net/http" "net/url" "path" "strings" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" ) type nextHTTP = func(http.ResponseWriter, *http.Request) func logNonce(w http.ResponseWriter, nonce string) { if rl, ok := w.(logging.ResponseLogger); ok { m := map[string]interface{}{ "nonce": nonce, } rl.WithFields(m) } } // addNonce is a middleware that adds a nonce to the response header. 
func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { db := acme.MustDatabaseFromContext(r.Context()) nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, r, err) return } w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Cache-Control", "no-store") logNonce(w, string(nonce)) next(w, r) } } // addDirLink is a middleware that adds a 'Link' response reader with the // directory index url. func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() linker := acme.MustLinkerFromContext(ctx) w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) next(w, r) } } // verifyContentType is a middleware that verifies that content type is // application/jose+json. func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { p, err := provisionerFromContext(r.Context()) if err != nil { render.Error(w, r, err) return } u := &url.URL{ Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), } var expected []string if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { // By default every request should have content-type applictaion/jose+json. expected = []string{"application/jose+json"} } ct := r.Header.Get("Content-Type") for _, e := range expected { if ct == e { next(w, r) return } } render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } // parseJWS is a middleware that parses a request body into a JSONWebSignature struct. 
func parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) next(w, r.WithContext(ctx)) } } // validateJWS checks the request body for to verify that it meets ACME // requirements for a JWS. // // The JWS MUST NOT have multiple signatures // The JWS Unencoded Payload Option [RFC7797] MUST NOT be used // The JWS Unprotected Header [RFC7515] MUST NOT be used // The JWS Payload MUST NOT be detached // The JWS Protected Header MUST include the following fields: // - “alg” (Algorithm). // This field MUST NOT contain “none” or a Message Authentication Code // (MAC) algorithm (e.g. one in which the algorithm registry description // mentions MAC/HMAC). 
// - “nonce” (defined in Section 6.5) // - “url” (defined in Section 6.4) // - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, r, err) return } if len(jws.Signatures) == 0 { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } sig := jws.Signatures[0] uh := sig.Unprotected if uh.KeyID != "" || uh.JSONWebKey != nil || uh.Algorithm != "" || uh.Nonce != "" || len(uh.ExtraHeaders) > 0 { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected switch hdr.Algorithm { case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512: if hdr.JSONWebKey != nil { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } } case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: render.Error(w, r, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { render.Error(w, r, err) return } // Check that the JWS url matches the requested url. 
jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && hdr.KeyID != "" { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) } } // extractJWK is a middleware that extracts the JWK from the JWS and saves it // in the context. Make sure to parse and validate the JWS before running this // middleware. func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, r, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } // Store the JWK in the context. ctx = context.WithValue(ctx, jwkContextKey, jwk) // Get Account OR continue to generate a new one OR continue Revoke with certificate private key acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case acme.IsErrNotFound(err): // For NewAccount and Revoke requests ... 
break case err != nil: render.Error(w, r, err) return default: if !acc.IsValid() { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) } next(w, r.WithContext(ctx)) } } // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // If the function is not set assume that all prerequisites are met. checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx) if ok { ok, err := checkFunc(ctx) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return } if !ok { render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return } } next(w, r) } } // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. 
func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, r, err) return } kid := jws.Signatures[0].Protected.KeyID if kid == "" { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'")) return } accID := path.Base(kid) acc, err := db.GetAccount(ctx, accID) switch { case acme.IsErrNotFound(err): render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: render.Error(w, r, err) return default: if !acc.IsValid() { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } if storedLocation := acc.GetLocation(); storedLocation != "" { if kid != storedLocation { // ACME accounts should have a stored location equivalent to the // kid in the ACME request. render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected %s, but got %s", storedLocation, kid)) return } // Verify that the provisioner with which the account was created // matches the provisioner in the request URL. reqProv := acme.MustProvisionerFromContext(ctx) switch { case acc.ProvisionerID == "" && acc.ProvisionerName != reqProv.GetName(): render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", acc.ProvisionerName, reqProv.GetName())) return case acc.ProvisionerID != "" && acc.ProvisionerID != reqProv.GetID(): render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", acc.ProvisionerID, reqProv.GetID())) return } } else { // This code will only execute for old ACME accounts that do // not have a cached location. 
The following validation was // the original implementation of the `kid` check which has // since been deprecated. However, the code will remain to // ensure consistent behavior for old ACME accounts. linker := acme.MustLinkerFromContext(ctx) kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return } } ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, jwkContextKey, acc.Key) next(w, r.WithContext(ctx)) return } } } // extractOrLookupJWK forwards handling to either extractJWK or // lookupJWK based on the presence of a JWK or a KID, respectively. func extractOrLookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, r, err) return } // at this point the JWS has already been verified (if correctly configured in middleware), // and it can be used to check if a JWK exists. This flow is used when the ACME client // signed the payload with a certificate private key. if canExtractJWKFrom(jws) { extractJWK(next)(w, r) return } // default to looking up the JWK based on KeyID. This flow is used when the ACME client // signed the payload with an account private key. lookupJWK(next)(w, r) } } // canExtractJWKFrom checks if the JWS has a JWK that can be extracted func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { if jws == nil { return false } if len(jws.Signatures) == 0 { return false } return jws.Signatures[0].Protected.JSONWebKey != nil } // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. 
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, r, err) return } jwk, err := jwkFromContext(ctx) if err != nil { render.Error(w, r, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) switch { case errors.Is(err, jose.ErrCryptoFailure): payload, err = retryVerificationWithPatchedSignatures(jws, jwk) if err != nil { render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)")) return } case err != nil: render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ value: payload, isPostAsGet: len(payload) == 0, isEmptyJSON: string(payload) == "{}", }) next(w, r.WithContext(ctx)) } } // retryVerificationWithPatchedSignatures retries verification of the JWS using // the JWK by patching the JWS signatures if they're determined to be too short. // // Generally this shouldn't happen, but we've observed this to be the case with // the macOS ACME client, which seems to omit (at least one) leading null // byte(s). The error returned is `go-jose/go-jose: error in cryptographic // primitive`, which is a sentinel error that hides the details of the actual // underlying error, which is as follows: `go-jose/go-jose: invalid signature // size, have 63 bytes, wanted 64`, for ES256. 
func retryVerificationWithPatchedSignatures(jws *jose.JSONWebSignature, jwk *jose.JSONWebKey) (data []byte, err error) { originalSignatureValues := make([][]byte, len(jws.Signatures)) patched := false defer func() { if patched && err != nil { for i, sig := range jws.Signatures { sig.Signature = originalSignatureValues[i] jws.Signatures[i] = sig } } }() for i, sig := range jws.Signatures { var expectedSize int alg := strings.ToUpper(sig.Header.Algorithm) switch alg { case jose.ES256: expectedSize = 64 case jose.ES384: expectedSize = 96 case jose.ES512: expectedSize = 132 default: // other cases are (currently) ignored continue } switch diff := expectedSize - len(sig.Signature); diff { case 0: // expected length; nothing to do; will result in just doing the // same verification (as done before calling this function) again, // and thus an error will be returned. continue case 1: patched = true original := make([]byte, expectedSize-diff) copy(original, sig.Signature) originalSignatureValues[i] = original patchedR := make([]byte, expectedSize) copy(patchedR[1:], original) // [0x00, R.0:31, S.0:32], for expectedSize 64 sig.Signature = patchedR jws.Signatures[i] = sig // verify it with a patched R; return early if successful; continue // with patching S if not. data, err = jws.Verify(jwk) if err == nil { return } patchedS := make([]byte, expectedSize) halfSize := expectedSize / 2 copy(patchedS, original[:halfSize]) // [R.0:32], for expectedSize 64 copy(patchedS[halfSize+1:], original[halfSize:]) // [R.0:32, 0x00, S.0:31] sig.Signature = patchedS jws.Signatures[i] = sig case 2: // assumption is currently the Apple case, in which only the // first null byte of R and/or S are removed, and thus not a case in // which two first bytes of either R or S are removed. 
patched = true original := make([]byte, expectedSize-diff) copy(original, sig.Signature) originalSignatureValues[i] = original patchedRS := make([]byte, expectedSize) halfSize := expectedSize / 2 copy(patchedRS[1:], original[:halfSize-1]) // [0x00, R.0:31], for expectedSize 64 copy(patchedRS[halfSize+1:], original[halfSize-1:]) // [0x00, R.0:31, 0x00, S.0:31] sig.Signature = patchedRS jws.Signatures[i] = sig default: // Technically, there can be multiple null bytes in either R or S, // so when the difference is larger than 2, there is more than one // option to pick. Apple's ACME client seems to only cut off the // first null byte of either R or S, so we don't do anything in this // case. Will result in just doing the same verification (as done // before calling this function) again, and thus an error will be // returned. // TODO(hs): log this specific case? It might mean some other ACME // client is doing weird things. continue } } data, err = jws.Verify(jwk) return } // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). func isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { render.Error(w, r, err) return } if !payload.isPostAsGet { render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) } } // ContextKey is the key type for storing and searching for ACME request // essentials in the context of a request. type ContextKey string const ( // accContextKey account key accContextKey = ContextKey("acc") // jwsContextKey jws key jwsContextKey = ContextKey("jws") // jwkContextKey jwk key jwkContextKey = ContextKey("jwk") // payloadContextKey payload key payloadContextKey = ContextKey("payload") ) // accountFromContext searches the context for an ACME account. Returns the // account or an error. 
func accountFromContext(ctx context.Context) (*acme.Account, error) { val, ok := ctx.Value(accContextKey).(*acme.Account) if !ok || val == nil { return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context") } return val, nil } // jwkFromContext searches the context for a JWK. Returns the JWK or an error. func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) if !ok || val == nil { return nil, acme.NewErrorISE("jwk expected in request context") } return val, nil } // jwsFromContext searches the context for a JWS. Returns the JWS or an error. func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature) if !ok || val == nil { return nil, acme.NewErrorISE("jws expected in request context") } return val, nil } // provisionerFromContext searches the context for a provisioner. Returns the // provisioner or an error. func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { p, ok := acme.ProvisionerFromContext(ctx) if !ok || p == nil { return nil, acme.NewErrorISE("provisioner expected in request context") } return p, nil } // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { p, err := provisionerFromContext(ctx) if err != nil { return nil, err } ap, ok := p.(*provisioner.ACME) if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } return ap, nil } // payloadFromContext searches the context for a payload. Returns the payload // or an error. 
func payloadFromContext(ctx context.Context) (*payloadInfo, error) { val, ok := ctx.Value(payloadContextKey).(*payloadInfo) if !ok || val == nil { return nil, acme.NewErrorISE("payload expected in request context") } return val, nil } ================================================ FILE: acme/api/middleware_test.go ================================================ package api import ( "bytes" "context" "crypto" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "net/url" "strings" "testing" "github.com/google/uuid" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" tassert "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" ) var testBody = []byte("foo") func testNext(w http.ResponseWriter, _ *http.Request) { w.Write(testBody) } func newBaseContext(ctx context.Context, args ...interface{}) context.Context { for _, a := range args { switch v := a.(type) { case acme.DB: ctx = acme.NewDatabaseContext(ctx, v) case acme.Linker: ctx = acme.NewLinkerContext(ctx, v) case acme.PrerequisitesChecker: ctx = acme.NewPrerequisitesCheckerContext(ctx, v) } } return ctx } func TestHandler_addNonce(t *testing.T) { u := "https://ca.smallstep.com/acme/new-nonce" type test struct { db acme.DB err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/AddNonce-error": func(t *testing.T) test { return test{ db: &acme.MockDB{ MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { return acme.Nonce(""), acme.NewErrorISE("force") }, }, statusCode: 500, err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { return test{ db: &acme.MockDB{ MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { return "bar", nil }, }, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := newBaseContext(context.Background(), tc.db) req := httptest.NewRequest("GET", u, 
http.NoBody).WithContext(ctx) w := httptest.NewRecorder() addNonce(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Replay-Nonce"], []string{"bar"}) assert.Equals(t, res.Header["Cache-Control"], []string{"no-store"}) assert.Equals(t, bytes.TrimSpace(body), testBody) } }) } } func TestHandler_addDirLink(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { link string statusCode int ctx context.Context err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) return test{ ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("GET", "/foo", http.NoBody) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() addDirLink(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, 
tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s>;rel=\"index\"", tc.link)}) assert.Equals(t, bytes.TrimSpace(body), testBody) } }) } } func TestHandler_verifyContentType(t *testing.T) { prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { ctx context.Context contentType string err *acme.Error statusCode int url string } var tests = map[string]func(t *testing.T) test{ "fail/provisioner-not-set": func(t *testing.T) test { return test{ url: u, ctx: context.Background(), contentType: "foo", statusCode: 500, err: acme.NewErrorISE("provisioner expected in request context"), } }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ url: u, ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), } }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), } }, "ok": func(t *testing.T) test { return test{ ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkix-cert", statusCode: 200, } }, "ok/certificate/jose+json": func(t *testing.T) test { return test{ ctx: acme.NewProvisionerContext(context.Background(), 
prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkcs7-mime", statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { _u := u if tc.url != "" { _u = tc.url } req := httptest.NewRequest("GET", _u, http.NoBody) req = req.WithContext(tc.ctx) req.Header.Add("Content-Type", tc.contentType) w := httptest.NewRecorder() verifyContentType(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) } }) } } func TestHandler_isPostAsGet(t *testing.T) { u := "https://ca.smallstep.com/acme/new-account" type test struct { ctx context.Context err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { return test{ ctx: context.WithValue(context.Background(), payloadContextKey, nil), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), } }, "fail/not-post-as-get": func(t *testing.T) test { return test{ ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"), } }, "ok": func(t *testing.T) test { 
return test{ ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { // h := &Handler{} req := httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() isPostAsGet(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) } }) } } type errReader int func (errReader) Read([]byte) (int, error) { return 0, errors.New("force") } func (errReader) Close() error { return nil } func TestHandler_parseJWS(t *testing.T) { u := "https://ca.smallstep.com/acme/new-account" type test struct { next nextHTTP body io.Reader err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/read-body-error": func(t *testing.T) test { return test{ body: errReader(0), statusCode: 500, err: acme.NewErrorISE("failed to read request body: force"), } }, "fail/parse-jws-error": func(t *testing.T) test { return test{ body: strings.NewReader("foo"), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to parse JWS from request body: go-jose/go-jose: compact JWS format must have three parts"), } }, "ok": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, new(jose.SignerOptions)) 
assert.FatalError(t, err) signed, err := signer.Sign([]byte("baz")) assert.FatalError(t, err) expRaw, err := signed.CompactSerialize() assert.FatalError(t, err) return test{ body: strings.NewReader(expRaw), next: func(w http.ResponseWriter, r *http.Request) { jws, err := jwsFromContext(r.Context()) assert.FatalError(t, err) gotRaw, err := jws.CompactSerialize() assert.FatalError(t, err) assert.Equals(t, gotRaw, expRaw) w.Write(testBody) }, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { // h := &Handler{} req := httptest.NewRequest("GET", u, tc.body) w := httptest.NewRecorder() parseJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) } }) } } func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := jwk.Public() pub := &_pub so := new(jose.SignerOptions) so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk.Key, }, so) assert.FatalError(t, err) jws, err := signer.Sign([]byte("baz")) assert.FatalError(t, err) raw, err := jws.CompactSerialize() assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) u := "https://ca.smallstep.com/acme/account/1234" type test struct { ctx context.Context next func(http.ResponseWriter, *http.Request) err *acme.Error 
statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-jwk": func(t *testing.T) test { return test{ ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) return test{ ctx: context.WithValue(ctx, jwsContextKey, nil), statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/verify-jws-failure-wrong-jwk": func(t *testing.T) test { _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwkContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: go-jose/go-jose: error in cryptographic primitive"), } }, "fail/verify-jws-failure-too-many-signatures": func(t *testing.T) test { newParsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) newParsedJWS.Signatures = append(newParsedJWS.Signatures, newParsedJWS.Signatures...) 
ctx := context.WithValue(context.Background(), jwsContextKey, newParsedJWS) ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: go-jose/go-jose: too many signatures in payload; expecting only one"), } }, "fail/apple-acmeclient-omitting-leading-null-byte-in-signature-with-wrong-jwk": func(t *testing.T) test { _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() appleNullByteCaseBody := `{"payload":"dGVzdC0xMTA1","protected":"eyJhbGciOiJFUzI1NiJ9","signature":"rQPYKYflfKnlgBKqDeWsJH2TJ6iHAnou7sFzXlmYD4ArXqLfYuqotWERKrna2wfzh0pu7USWO2gzlOqRK9qq"}` appleNullByteCaseJWS, err := jose.ParseJWS(appleNullByteCaseBody) require.NoError(t, err) ctx := context.WithValue(context.Background(), jwsContextKey, appleNullByteCaseJWS) ctx = context.WithValue(ctx, jwkContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: go-jose/go-jose: error in cryptographic primitive"), } }, "fail/algorithm-mismatch": func(t *testing.T) test { _pub := *pub clone := &_pub clone.Algorithm = jose.HS256 ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwkContextKey, clone) return test{ ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"), } }, "ok": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, next: func(w http.ResponseWriter, r *http.Request) { p, err := payloadFromContext(r.Context()) assert.FatalError(t, err) if assert.NotNil(t, p) { assert.Equals(t, p.value, []byte("baz")) assert.False(t, p.isPostAsGet) assert.False(t, p.isEmptyJSON) } w.Write(testBody) }, } }, "ok/empty-algorithm-in-jwk": func(t *testing.T) 
test {
			ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
			ctx = context.WithValue(ctx, jwkContextKey, pub)
			return test{
				ctx:        ctx,
				statusCode: 200,
				next: func(w http.ResponseWriter, r *http.Request) {
					p, err := payloadFromContext(r.Context())
					assert.FatalError(t, err)
					if assert.NotNil(t, p) {
						assert.Equals(t, p.value, []byte("baz"))
						assert.False(t, p.isPostAsGet)
						assert.False(t, p.isEmptyJSON)
					}
					w.Write(testBody)
				},
			}
		},
		"ok/post-as-get": func(t *testing.T) test {
			_jws, err := signer.Sign([]byte(""))
			assert.FatalError(t, err)
			_raw, err := _jws.CompactSerialize()
			assert.FatalError(t, err)
			_parsed, err := jose.ParseJWS(_raw)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.Background(), jwsContextKey, _parsed)
			ctx = context.WithValue(ctx, jwkContextKey, pub)
			return test{
				ctx:        ctx,
				statusCode: 200,
				next: func(w http.ResponseWriter, r *http.Request) {
					p, err := payloadFromContext(r.Context())
					assert.FatalError(t, err)
					if assert.NotNil(t, p) {
						assert.Equals(t, p.value, []byte{})
						assert.True(t, p.isPostAsGet)
						assert.False(t, p.isEmptyJSON)
					}
					w.Write(testBody)
				},
			}
		},
		"ok/empty-json": func(t *testing.T) test {
			_jws, err := signer.Sign([]byte("{}"))
			assert.FatalError(t, err)
			_raw, err := _jws.CompactSerialize()
			assert.FatalError(t, err)
			_parsed, err := jose.ParseJWS(_raw)
			assert.FatalError(t, err)
			ctx := context.WithValue(context.Background(), jwsContextKey, _parsed)
			ctx = context.WithValue(ctx, jwkContextKey, pub)
			return test{
				ctx:        ctx,
				statusCode: 200,
				next: func(w http.ResponseWriter, r *http.Request) {
					p, err := payloadFromContext(r.Context())
					assert.FatalError(t, err)
					if assert.NotNil(t, p) {
						assert.Equals(t, p.value, []byte("{}"))
						assert.False(t, p.isPostAsGet)
						assert.True(t, p.isEmptyJSON)
					}
					w.Write(testBody)
				},
			}
		},
		"ok/apple-acmeclient-omitting-leading-null-byte-in-signature": func(t *testing.T) test {
			appleNullByteCaseKey := []byte(`{ "kid": "uioinbiTlJICL0MYsb6ar1totfRA2tiPqWgntF8xUdo", "crv": "P-256", "alg": "ES256", "kty": "EC", "x": "wlz-Kv9X0h32fzLq-cogls9HxoZQqV-GuWxdb2MCeUY", "y": "xzP6zRrg_jynYljZTxfJuql_QWtdQR6lpJ52q_6Vavg" }`)
			appleNullByteCaseJWK := &jose.JSONWebKey{}
			err = json.Unmarshal(appleNullByteCaseKey, appleNullByteCaseJWK)
			require.NoError(t, err)
			appleNullByteCaseBody := `{"payload":"dGVzdC0xMTA1","protected":"eyJhbGciOiJFUzI1NiJ9","signature":"rQPYKYflfKnlgBKqDeWsJH2TJ6iHAnou7sFzXlmYD4ArXqLfYuqotWERKrna2wfzh0pu7USWO2gzlOqRK9qq"}`
			appleNullByteCaseJWS, err := jose.ParseJWS(appleNullByteCaseBody)
			require.NoError(t, err)
			ctx := context.WithValue(context.Background(), jwsContextKey, appleNullByteCaseJWS)
			ctx = context.WithValue(ctx, jwkContextKey, appleNullByteCaseJWK)
			return test{
				ctx:        ctx,
				statusCode: 200,
				next: func(w http.ResponseWriter, r *http.Request) {
					p, err := payloadFromContext(r.Context())
					tassert.NoError(t, err)
					if tassert.NotNil(t, p) {
						tassert.Equal(t, []byte(`test-1105`), p.value)
						tassert.False(t, p.isPostAsGet)
						tassert.False(t, p.isEmptyJSON)
					}
					w.Write(testBody)
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			// h := &Handler{}
			req := httptest.NewRequest("GET", u, http.NoBody)
			req = req.WithContext(tc.ctx)
			w := httptest.NewRecorder()
			verifyAndExtractJWSPayload(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				assert.Equals(t, bytes.TrimSpace(body), testBody)
			}
		})
	}
}

// TestHandler_lookupJWK is a table-driven test of the lookupJWK middleware.
//
// A JWS is signed once with a freshly generated EC P-256 key and a "kid"
// protected header of the form <baseURL>/acme/<provisioner>/account/<accID>.
// Each case supplies a request context, a mock DB and a linker, and states
// the expected HTTP status. For statuses >= 400 with an expected error, the
// response body is decoded as an acme.Error and its Type, Detail and
// Subproblems plus the problem+json content type are compared; otherwise the
// body must equal testBody, i.e. what the `next` handler wrote.
func TestHandler_lookupJWK(t *testing.T) {
	prov := newProv()
	provName := url.PathEscape(prov.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName)
	jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
	assert.FatalError(t, err)
	accID := "account-id"
	prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
	so := new(jose.SignerOptions)
	so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID))
	signer, err := jose.NewSigner(jose.SigningKey{
		Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
		Key:       jwk.Key,
	}, so)
	assert.FatalError(t, err)
	jws, err := signer.Sign([]byte("baz"))
	assert.FatalError(t, err)
	raw, err := jws.CompactSerialize()
	assert.FatalError(t, err)
	parsedJWS, err := jose.ParseJWS(raw)
	assert.FatalError(t, err)
	type test struct {
		linker     acme.Linker
		db         acme.DB
		ctx        context.Context
		next       func(http.ResponseWriter, *http.Request)
		err        *acme.Error
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/no-jws": func(t *testing.T) test {
			return test{
				db:         &acme.MockDB{},
				linker:     acme.NewLinker("test.ca.smallstep.com", "acme"),
				ctx:        acme.NewProvisionerContext(context.Background(), prov),
				statusCode: 500,
				err:        acme.NewErrorISE("jws expected in request context"),
			}
		},
		"fail/nil-jws": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, nil)
			return test{
				db:         &acme.MockDB{},
				linker:     acme.NewLinker("test.ca.smallstep.com", "acme"),
				ctx:        ctx,
				statusCode: 500,
				err:        acme.NewErrorISE("jws expected in request context"),
			}
		},
		"fail/no-kid": func(t *testing.T) test {
			// Signer without the "kid" header option.
			_signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, new(jose.SignerOptions))
			assert.FatalError(t, err)
			_jws, err := _signer.Sign([]byte("baz"))
			assert.FatalError(t, err)
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, _jws)
			return test{
				db:         &acme.MockDB{},
				linker:     acme.NewLinker("test.ca.smallstep.com", "acme"),
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"),
			}
		},
		"fail/account-not-found": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					// The parameter is deliberately not named accID: that
					// would shadow the outer accID and reduce the assertion
					// below to comparing the argument with itself.
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return nil, acme.ErrNotFound
					},
				},
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
			}
		},
		"fail/GetAccount-error": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return nil, acme.NewErrorISE("force")
					},
				},
				ctx:        ctx,
				statusCode: 500,
				err:        acme.NewErrorISE("force"),
			}
		},
		"fail/account-not-valid": func(t *testing.T) test {
			acc := &acme.Account{Status: "deactivated"}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx:        ctx,
				statusCode: 401,
				err:        acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
			}
		},
		"fail/account-with-location-prefix/bad-kid": func(t *testing.T) test {
			acc := &acme.Account{LocationPrefix: "foobar", Status: "valid"}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx:        ctx,
				statusCode: http.StatusUnauthorized,
				err:        acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected foobar, but %q", prefix+accID),
			}
		},
		"fail/account-with-location-prefix/bad-provisioner": func(t *testing.T) test {
			acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: "other"}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx: ctx,
				next: func(w http.ResponseWriter, r *http.Request) {
					_acc, err := accountFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _acc, acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk, jwk)
					w.Write(testBody)
				},
				statusCode: http.StatusUnauthorized,
				err:        acme.NewError(acme.ErrorUnauthorizedType, "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", "other", prov.GetName()),
			}
		},
		"fail/account-with-location-prefix/bad-provisioner-id": func(t *testing.T) test {
			p := newProvWithID()
			acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerID: uuid.NewString()}
			ctx := acme.NewProvisionerContext(context.Background(), p)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx: ctx,
				next: func(w http.ResponseWriter, r *http.Request) {
					_acc, err := accountFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _acc, acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk, jwk)
					w.Write(testBody)
				},
				statusCode: http.StatusUnauthorized,
				err:        acme.NewError(acme.ErrorUnauthorizedType, "account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s", acc.ProvisionerID, p.GetID()),
			}
		},
		"ok/account-with-location-prefix": func(t *testing.T) test {
			acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: prov.GetName()}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx: ctx,
				next: func(w http.ResponseWriter, r *http.Request) {
					_acc, err := accountFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _acc, acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk, jwk)
					w.Write(testBody)
				},
				statusCode: http.StatusOK,
			}
		},
		"ok/account-without-location-prefix": func(t *testing.T) test {
			acc := &acme.Account{Status: "valid", Key: jwk}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx: ctx,
				next: func(w http.ResponseWriter, r *http.Request) {
					_acc, err := accountFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _acc, acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk, jwk)
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
		"ok/account-with-provisioner-id": func(t *testing.T) test {
			p := newProvWithID()
			acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerID: p.GetID()}
			ctx := acme.NewProvisionerContext(context.Background(), p)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
						assert.Equals(t, id, accID)
						return acc, nil
					},
				},
				ctx: ctx,
				next: func(w http.ResponseWriter, r *http.Request) {
					_acc, err := accountFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _acc, acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk, jwk)
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
	}
	for name, run := range tests {
		// The case constructor runs on the parent t, before the subtest starts.
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
			req := httptest.NewRequest("GET", u, http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()
			lookupJWK(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				assert.Equals(t, bytes.TrimSpace(body), testBody)
			}
		})
	}
}

// TestHandler_extractJWK is a table-driven test of the extractJWK middleware.
//
// The shared JWS carries the signer's public key in its "jwk" protected
// header; that key's KeyID is set to the base64url (unpadded) encoding of its
// SHA-256 thumbprint, which the mock DB callbacks expect to receive as kid.
func TestHandler_extractJWK(t *testing.T) {
	prov := newProv()
	provName := url.PathEscape(prov.GetName())
	jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
	assert.FatalError(t, err)
	kid, err := jwk.Thumbprint(crypto.SHA256)
	assert.FatalError(t, err)
	pub := jwk.Public()
	pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
	so := new(jose.SignerOptions)
	so.WithHeader("jwk", pub)
	signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
		Key:       jwk.Key,
	}, so)
	assert.FatalError(t, err)
	jws, err := signer.Sign([]byte("baz"))
	assert.FatalError(t, err)
	raw, err := jws.CompactSerialize()
	assert.FatalError(t, err)
	parsedJWS, err := jose.ParseJWS(raw)
	assert.FatalError(t, err)
	u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provName)
	type test struct {
		db         acme.DB
		ctx        context.Context
		next       func(http.ResponseWriter, *http.Request)
		err        *acme.Error
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/no-jws": func(t *testing.T) test {
			return test{
				db:         &acme.MockDB{},
				ctx:        acme.NewProvisionerContext(context.Background(), prov),
				statusCode: 500,
				err:        acme.NewErrorISE("jws expected in request context"),
			}
		},
		"fail/nil-jws": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, nil)
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 500,
				err:        acme.NewErrorISE("jws expected in request context"),
			}
		},
		"fail/nil-jwk": func(t *testing.T) test {
			// Hand-built JWS whose single signature has no JWK in its protected header.
			_jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							JSONWebKey: nil,
						},
					},
				},
			}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, _jws)
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
			}
		},
		"fail/invalid-jwk": func(t *testing.T) test {
			// The JWK's Key is a plain string rather than a key value.
			_jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							JSONWebKey: &jose.JSONWebKey{Key: "foo"},
						},
					},
				},
			}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, _jws)
			return test{
				db:         &acme.MockDB{},
				ctx:        ctx,
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
			}
		},
		"fail/GetAccountByKey-error": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				ctx: ctx,
				db: &acme.MockDB{
					MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
						assert.Equals(t, kid, pub.KeyID)
						return nil, acme.NewErrorISE("force")
					},
				},
				statusCode: 500,
				err:        acme.NewErrorISE("force"),
			}
		},
		"fail/account-not-valid": func(t *testing.T) test {
			acc := &acme.Account{Status: "deactivated"}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				ctx: ctx,
				db: &acme.MockDB{
					MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
						assert.Equals(t, kid, pub.KeyID)
						return acc, nil
					},
				},
				statusCode: 401,
				err:        acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
			}
		},
		"ok": func(t *testing.T) test {
			acc := &acme.Account{Status: "valid"}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				ctx: ctx,
				db: &acme.MockDB{
					MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
						assert.Equals(t, kid, pub.KeyID)
						return acc, nil
					},
				},
				next: func(w http.ResponseWriter, r *http.Request) {
					_acc, err := accountFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _acc, acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk.KeyID, pub.KeyID)
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
		"ok/no-account": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				ctx: ctx,
				db: &acme.MockDB{
					MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
						assert.Equals(t, kid, pub.KeyID)
						return nil, acme.ErrNotFound
					},
				},
				next: func(w http.ResponseWriter, r *http.Request) {
					// No account is expected in the context here, but the JWK must still be present.
					_acc, err := accountFromContext(r.Context())
					assert.NotNil(t, err)
					assert.Nil(t, _acc)
					_jwk, err := jwkFromContext(r.Context())
					assert.FatalError(t, err)
					assert.Equals(t, _jwk.KeyID, pub.KeyID)
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			ctx := newBaseContext(tc.ctx, tc.db)
			req := httptest.NewRequest("GET", u, http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()
			extractJWK(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				assert.Equals(t, bytes.TrimSpace(body), testBody)
			}
		})
	}
}

// TestHandler_validateJWS is a table-driven test of the validateJWS
// middleware.
//
// Every case places a hand-built jose.JSONWebSignature (or nothing / nil) in
// the request context and states the expected status. The cases cover a
// missing or nil JWS, zero and multiple signatures, a populated unprotected
// header, the "none" and HS256 algorithms, an EC key paired with RS256, a
// 1024-bit RSA key, a failing MockDeleteNonce, a missing or mismatching "url"
// header, both or neither of jwk/kid, and three accepted combinations. The
// request URL u is also the value used for the "url" header in the cases
// that are expected to pass the URL check.
func TestHandler_validateJWS(t *testing.T) {
	u := "https://ca.smallstep.com/acme/account/1234"
	type test struct {
		db         acme.DB
		ctx        context.Context
		next       func(http.ResponseWriter, *http.Request)
		err        *acme.Error
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/no-jws": func(t *testing.T) test {
			return test{
				db:         &acme.MockDB{},
				ctx:        context.Background(),
				statusCode: 500,
				err:        acme.NewErrorISE("jws expected in request context"),
			}
		},
		"fail/nil-jws": func(t *testing.T) test {
			return test{
				db:         &acme.MockDB{},
				ctx:        context.WithValue(context.Background(), jwsContextKey, nil),
				statusCode: 500,
				err:        acme.NewErrorISE("jws expected in request context"),
			}
		},
		"fail/no-signature": func(t *testing.T) test {
			return test{
				db:         &acme.MockDB{},
				ctx:        context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
			}
		},
		"fail/more-than-one-signature": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{},
					{},
				},
			}
			return test{
				db:         &acme.MockDB{},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
			}
		},
		"fail/unprotected-header-not-empty": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{Unprotected: jose.Header{Nonce: "abc"}},
				},
			}
			return test{
				db:         &acme.MockDB{},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
			}
		},
		"fail/unsuitable-algorithm-none": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{Protected: jose.Header{Algorithm: "none"}},
				},
			}
			return test{
				db:         &acme.MockDB{},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
			}
		},
		"fail/unsuitable-algorithm-mac": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{Protected: jose.Header{Algorithm: jose.HS256}},
				},
			}
			return test{
				db:         &acme.MockDB{},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
			}
		},
		"fail/rsa-key-&-alg-mismatch": func(t *testing.T) test {
			// An EC P-256 key is advertised together with the RS256 algorithm.
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			pub := jwk.Public()
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm:  jose.RS256,
							JSONWebKey: &pub,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match"),
			}
		},
		"fail/rsa-key-too-small": func(t *testing.T) test {
			// The deferred revert fires when this constructor returns, i.e.
			// before the subtest invokes the middleware, so whatever
			// keyutil.Insecure changes is only in effect while the key is
			// generated.
			// NOTE(review): presumably Insecure lifts the minimum RSA key
			// size so that a 1024-bit key can be generated — confirm.
			revert := keyutil.Insecure()
			defer revert()

			jwk, err := jose.GenerateJWK("RSA", "", "", "sig", "", 1024)
			assert.FatalError(t, err)
			pub := jwk.Public()
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm:  jose.RS256,
							JSONWebKey: &pub,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least 2048 bits (256 bytes) in size"),
			}
		},
		"fail/UseNonce-error": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{Protected: jose.Header{Algorithm: jose.ES256}},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return acme.NewErrorISE("force")
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 500,
				err:        acme.NewErrorISE("force"),
			}
		},
		"fail/no-url-header": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{Protected: jose.Header{Algorithm: jose.ES256}},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"),
			}
		},
		"fail/url-mismatch": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm: jose.ES256,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": "foo",
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u),
			}
		},
		"fail/both-jwk-kid": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			pub := jwk.Public()
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm:  jose.ES256,
							KeyID:      "bar",
							JSONWebKey: &pub,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"),
			}
		},
		"fail/no-jwk-kid": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm: jose.ES256,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, jws),
				statusCode: 400,
				err:        acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"),
			}
		},
		"ok/kid": func(t *testing.T) test {
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm: jose.ES256,
							KeyID:     "bar",
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx: context.WithValue(context.Background(), jwsContextKey, jws),
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
		"ok/jwk/ecdsa": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			pub := jwk.Public()
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm:  jose.ES256,
							JSONWebKey: &pub,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx: context.WithValue(context.Background(), jwsContextKey, jws),
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
		"ok/jwk/rsa": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("RSA", "", "", "sig", "", 2048)
			assert.FatalError(t, err)
			pub := jwk.Public()
			jws := &jose.JSONWebSignature{
				Signatures: []jose.Signature{
					{
						Protected: jose.Header{
							Algorithm:  jose.RS256,
							JSONWebKey: &pub,
							ExtraHeaders: map[jose.HeaderKey]interface{}{
								"url": u,
							},
						},
					},
				},
			}
			return test{
				db: &acme.MockDB{
					MockDeleteNonce: func(ctx context.Context, n acme.Nonce) error {
						return nil
					},
				},
				ctx: context.WithValue(context.Background(), jwsContextKey, jws),
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			ctx := newBaseContext(tc.ctx, tc.db)
			req := httptest.NewRequest("GET", u, http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()
			validateJWS(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				assert.Equals(t, bytes.TrimSpace(body), testBody)
			}
		})
	}
}

// Test_canExtractJWKFrom checks canExtractJWKFrom against four inputs: a nil
// JWS, a JWS without signatures, a JWS whose protected header has no JWK
// (all expected false), and a JWS whose protected header carries a JWK
// (expected true).
func Test_canExtractJWKFrom(t
*testing.T) {
	jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
	assert.FatalError(t, err)
	type args struct {
		jws *jose.JSONWebSignature
	}
	tests := []struct {
		name string
		args args
		want bool
	}{
		{
			name: "no-jws",
			args: args{
				jws: nil,
			},
			want: false,
		},
		{
			name: "no-signatures",
			args: args{
				jws: &jose.JSONWebSignature{
					Signatures: []jose.Signature{},
				},
			},
			want: false,
		},
		{
			name: "no-jwk",
			args: args{
				jws: &jose.JSONWebSignature{
					Signatures: []jose.Signature{
						{
							Protected: jose.Header{},
						},
					},
				},
			},
			want: false,
		},
		{
			name: "ok",
			args: args{
				jws: &jose.JSONWebSignature{
					Signatures: []jose.Signature{
						{
							Protected: jose.Header{
								JSONWebKey: jwk,
							},
						},
					},
				},
			},
			want: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
				t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestHandler_extractOrLookupJWK checks that the extractOrLookupJWK
// middleware reaches `next` (status 200, body testBody) for both kinds of
// request: a JWS with the public key embedded in the "jwk" protected header
// ("ok/extract") and a JWS that references an account through the "kid"
// protected header ("ok/lookup").
func TestHandler_extractOrLookupJWK(t *testing.T) {
	u := "https://ca.smallstep.com/acme/account"
	type test struct {
		db         acme.DB
		linker     acme.Linker
		statusCode int
		ctx        context.Context
		err        *acme.Error
		next       func(w http.ResponseWriter, r *http.Request)
	}
	var tests = map[string]func(t *testing.T) test{
		"ok/extract": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			kid, err := jwk.Thumbprint(crypto.SHA256)
			assert.FatalError(t, err)
			pub := jwk.Public()
			pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
			so := new(jose.SignerOptions)
			so.WithHeader("jwk", pub) // JWK for certificate private key flow
			signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, so)
			assert.FatalError(t, err)
			signed, err := signer.Sign([]byte("foo"))
			assert.FatalError(t, err)
			raw, err := signed.CompactSerialize()
			assert.FatalError(t, err)
			parsedJWS, err := jose.ParseJWS(raw)
			assert.FatalError(t, err)
			return test{
				linker: acme.NewLinker("dns", "acme"),
				db: &acme.MockDB{
					// The mock reports no account for the key; the case still expects status 200.
					MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
						assert.Equals(t, kid, pub.KeyID)
						return nil, acme.ErrNotFound
					},
				},
				ctx:        context.WithValue(context.Background(), jwsContextKey, parsedJWS),
				statusCode: 200,
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
			}
		},
		"ok/lookup": func(t *testing.T) test {
			prov := newProv()
			provName := url.PathEscape(prov.GetName())
			baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			accID := "accID"
			prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
			so := new(jose.SignerOptions)
			so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) // KID for account private key flow
			signer, err := jose.NewSigner(jose.SigningKey{
				Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
				Key:       jwk.Key,
			}, so)
			assert.FatalError(t, err)
			jws, err := signer.Sign([]byte("baz"))
			assert.FatalError(t, err)
			raw, err := jws.CompactSerialize()
			assert.FatalError(t, err)
			parsedJWS, err := jose.ParseJWS(raw)
			assert.FatalError(t, err)
			acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
			return test{
				linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
				db: &acme.MockDB{
					// The parameter shadows the outer accID; the comparison
					// is against acc.ID, so it still checks the value passed in.
					MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
						assert.Equals(t, accID, acc.ID)
						return acc, nil
					},
				},
				ctx:        ctx,
				statusCode: 200,
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
			}
		},
	}
	for name, prep := range tests {
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
			req := httptest.NewRequest("GET", u, http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()
			extractOrLookupJWK(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				assert.Equals(t, bytes.TrimSpace(body), testBody)
			}
		})
	}
}

// TestHandler_checkPrerequisites is a table-driven test of the
// checkPrerequisites middleware. The prerequisites checker is installed in
// the request context; a checker that returns an error is expected to yield
// status 500, one that returns false status 501, and one that returns true
// lets `next` write testBody with status 200.
func TestHandler_checkPrerequisites(t *testing.T) {
	prov := newProv()
	provName := url.PathEscape(prov.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName)
	type test struct {
		linker               acme.Linker
		ctx                  context.Context
		prerequisitesChecker func(context.Context) (bool, error)
		next                 func(http.ResponseWriter, *http.Request)
		err                  *acme.Error
		statusCode           int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/error": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			return test{
				linker:               acme.NewLinker("dns", "acme"),
				ctx:                  ctx,
				prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
				err:        acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"),
				statusCode: 500,
			}
		},
		"fail/prerequisites-nok": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			return test{
				linker:               acme.NewLinker("dns", "acme"),
				ctx:                  ctx,
				prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
				err:        acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"),
				statusCode: 501,
			}
		},
		"ok": func(t *testing.T) test {
			ctx := acme.NewProvisionerContext(context.Background(), prov)
			return test{
				linker: acme.NewLinker("dns", "acme"),
				ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
				next: func(w http.ResponseWriter, r *http.Request) {
					w.Write(testBody)
				},
				statusCode: 200,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			// Only the prerequisites checker is added to the case's context
			// here; the case's linker field is not read by this loop.
			ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
			req := httptest.NewRequest("GET", u, http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()
			checkPrerequisites(tc.next)(w, req)
			res := w.Result()

			assert.Equals(t, res.StatusCode, tc.statusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
				var ae acme.Error
				assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equals(t, ae.Type, tc.err.Type)
				assert.Equals(t, ae.Detail, tc.err.Detail)
				assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
				assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
			} else {
				assert.Equals(t, bytes.TrimSpace(body), testBody)
			}
		})
	}
}

// Test_retryVerificationWithPatchedSignatures runs
// retryVerificationWithPatchedSignatures over fixed ES256 fixtures: three
// key/JWS pairs (patched-r, patched-s, patched-rs) and one JWS paired with a
// key from a different fixture.
//
// For the ok cases the call must succeed, return expectedData, and leave the
// first signature of the JWS 64 bytes long and equal to the decoded
// expectedSignature. For the failing case the call must return the expected
// error text and no data, and the signature must equal expectedSignature,
// which for that case is the same base64 text as in the fixture body.
func Test_retryVerificationWithPatchedSignatures(t *testing.T) {
	patchedRKey := []byte(`{ "kid": "uioinbiTlJICL0MYsb6ar1totfRA2tiPqWgntF8xUdo", "crv": "P-256", "alg": "ES256", "kty": "EC", "x": "wlz-Kv9X0h32fzLq-cogls9HxoZQqV-GuWxdb2MCeUY", "y": "xzP6zRrg_jynYljZTxfJuql_QWtdQR6lpJ52q_6Vavg" }`)
	patchedRJWK := &jose.JSONWebKey{}
	err := json.Unmarshal(patchedRKey, patchedRJWK)
	require.NoError(t, err)
	patchedRBody := `{"payload":"dGVzdC0xMTA1","protected":"eyJhbGciOiJFUzI1NiJ9","signature":"rQPYKYflfKnlgBKqDeWsJH2TJ6iHAnou7sFzXlmYD4ArXqLfYuqotWERKrna2wfzh0pu7USWO2gzlOqRK9qq"}`
	patchedR, err := jose.ParseJWS(patchedRBody)
	require.NoError(t, err)

	patchedSKey := []byte(`{ "kid": "PblXsnK59uTiF5k3mmAN2B6HDPPxqBL_4UGhEG8ZO6g", "crv": "P-256", "alg": "ES256", "kty": "EC", "x": "T5aM_TOSattXNeUkH1VHZXh8URzdjZTI2zLvVgI0cy0", "y": "Lf8h8qZnURXIxm6OnQ69kxGC91YtTZRD2GAroEf1UA8" }`)
	patchedSJWK := &jose.JSONWebKey{}
	err = json.Unmarshal(patchedSKey, patchedSJWK)
	require.NoError(t, err)
	patchedSBody := `{"payload":"dGVzdC02Ng","protected":"eyJhbGciOiJFUzI1NiJ9","signature":"krtSKSgVB04oqx6i9QLeal_wZSnjV1_PSIM3AubT0WRIxnhl_yYbVpa3i53p3dUW56TtP6_SUZboH6SvLHMz"}`
	patchedS, err := jose.ParseJWS(patchedSBody)
	require.NoError(t, err)

	patchedRSKey := []byte(`{ "kid": "U8BmBVbZsNUawvhOomJQPa6uYj1rdxCPQWF_nOLVsc4", "crv": "P-256", "alg": "ES256", "kty": "EC", "x": "Ym0l3GMS6aHBLo-xe73Kub4kafnOBu_QAfOsx5y-bV0", "y": "wKijX9Cu67HbK94StPcI18WulgRfIMbP2ZU7gQuf3-M" }`)
	patchedRSJWK := &jose.JSONWebKey{}
	err = json.Unmarshal(patchedRSKey, patchedRSJWK)
	require.NoError(t, err)
	patchedRSBody := `{"payload":"dGVzdC05MDY3","protected":"eyJhbGciOiJFUzI1NiJ9","signature":"2r_My19oRg7mWf9I5JTkNYp8otfEMz-yXRA8ltZTAKZxyJLurpVEgicmNItu7lfcCrGrTgI3Obye_gSaIyc"}`
	patchedRS, err := jose.ParseJWS(patchedRSBody)
	require.NoError(t, err)

	// Parsed separately from patchedR so the failing case works on its own
	// JWS value rather than sharing one with "ok/patched-r".
	patchedRWithWrongJWK, err := jose.ParseJWS(patchedRBody)
	require.NoError(t, err)

	tests := []struct {
		name              string
		jws               *jose.JSONWebSignature
		jwk               *jose.JSONWebKey
		expectedData      []byte
		expectedSignature string
		expectedError     error
	}{
		{"ok/patched-r", patchedR, patchedRJWK, []byte(`test-1105`), `AK0D2CmH5Xyp5YASqg3lrCR9kyeohwJ6Lu7Bc15ZmA-AK16i32LqqLVhESq52tsH84dKbu1EljtoM5TqkSvaqg`, nil},
		{"ok/patched-s", patchedS, patchedSJWK, []byte(`test-66`), `krtSKSgVB04oqx6i9QLeal_wZSnjV1_PSIM3AubT0WQASMZ4Zf8mG1aWt4ud6d3VFuek7T-v0lGW6B-kryxzMw`, nil},
		{"ok/patched-rs", patchedRS, patchedRSJWK, []byte(`test-9067`), `ANq_zMtfaEYO5ln_SOSU5DWKfKLXxDM_sl0QPJbWUwAApnHIku6ulUSCJyY0i27uV9wKsatOAjc5vJ7-BJojJw`, nil},
		{"fail/patched-r-wrong-jwk", patchedRWithWrongJWK, patchedRSJWK, nil, `rQPYKYflfKnlgBKqDeWsJH2TJ6iHAnou7sFzXlmYD4ArXqLfYuqotWERKrna2wfzh0pu7USWO2gzlOqRK9qq`, errors.New("go-jose/go-jose: error in cryptographic primitive")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			expectedSignature, decodeErr := base64.RawURLEncoding.DecodeString(tt.expectedSignature)
			require.NoError(t, decodeErr)

			data, err := retryVerificationWithPatchedSignatures(tt.jws, tt.jwk)
			if tt.expectedError != nil {
				tassert.EqualError(t, err, tt.expectedError.Error())
				tassert.Equal(t, expectedSignature, tt.jws.Signatures[0].Signature)
				tassert.Empty(t, data)
				return
			}

			tassert.NoError(t, err)
			tassert.Len(t, tt.jws.Signatures[0].Signature, 64)
			tassert.Equal(t, expectedSignature, tt.jws.Signatures[0].Signature)
			tassert.Equal(t, tt.expectedData, data)
		})
	}
}

================================================
FILE: acme/api/order.go
================================================
package api

import (
	"context"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"go.step.sm/crypto/randutil"
	"go.step.sm/crypto/x509util"

	"github.com/smallstep/certificates/acme"
	"github.com/smallstep/certificates/acme/wire"
	"github.com/smallstep/certificates/api/render"
	"github.com/smallstep/certificates/authority/policy"
	"github.com/smallstep/certificates/authority/provisioner"
)

// NewOrderRequest represents the body for a NewOrder request.
// Its fields are mapped to the JSON keys "identifiers", "notBefore" and
// "notAfter".
type NewOrderRequest struct {
	Identifiers []acme.Identifier `json:"identifiers"`
	NotBefore   time.Time         `json:"notBefore,omitempty"`
	NotAfter    time.Time         `json:"notAfter,omitempty"`
}

// Validate validates a new-order request body.
func (n *NewOrderRequest) Validate() error { if len(n.Identifiers) == 0 { return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty") } for _, id := range n.Identifiers { switch id.Type { case acme.IP: if net.ParseIP(id.Value) == nil { return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) } case acme.DNS: value, _ := trimIfWildcard(id.Value) if _, err := x509util.SanitizeName(value); err != nil { return acme.NewError(acme.ErrorMalformedType, "invalid DNS name: %s", id.Value) } case acme.PermanentIdentifier: if id.Value == "" { return acme.NewError(acme.ErrorMalformedType, "permanent identifier cannot be empty") } case acme.WireUser, acme.WireDevice: // validation of Wire identifiers is performed in `validateWireIdentifiers`, but // marked here as known and supported types. continue default: return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } } if err := n.validateWireIdentifiers(); err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "failed validating Wire identifiers") } // TODO(hs): add some validations for DNS domains? 
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 return nil } func (n *NewOrderRequest) validateWireIdentifiers() error { if !n.hasWireIdentifiers() { return nil } userIdentifiers := identifiersOfType(acme.WireUser, n.Identifiers) deviceIdentifiers := identifiersOfType(acme.WireDevice, n.Identifiers) if len(userIdentifiers) != 1 { return fmt.Errorf("expected exactly one Wire UserID identifier; got %d", len(userIdentifiers)) } if len(deviceIdentifiers) != 1 { return fmt.Errorf("expected exactly one Wire DeviceID identifier, got %d", len(deviceIdentifiers)) } wireUserID, err := wire.ParseUserID(userIdentifiers[0].Value) if err != nil { return fmt.Errorf("failed parsing Wire UserID: %w", err) } wireDeviceID, err := wire.ParseDeviceID(deviceIdentifiers[0].Value) if err != nil { return fmt.Errorf("failed parsing Wire DeviceID: %w", err) } if _, err := wire.ParseClientID(wireDeviceID.ClientID); err != nil { return fmt.Errorf("invalid Wire client ID %q: %w", wireDeviceID.ClientID, err) } switch { case wireUserID.Domain != wireDeviceID.Domain: return fmt.Errorf("UserID domain %q does not match DeviceID domain %q", wireUserID.Domain, wireDeviceID.Domain) case wireUserID.Name != wireDeviceID.Name: return fmt.Errorf("UserID name %q does not match DeviceID name %q", wireUserID.Name, wireDeviceID.Name) case wireUserID.Handle != wireDeviceID.Handle: return fmt.Errorf("UserID handle %q does not match DeviceID handle %q", wireUserID.Handle, wireDeviceID.Handle) } return nil } // hasWireIdentifiers returns whether the [NewOrderRequest] contains // Wire identifiers. func (n *NewOrderRequest) hasWireIdentifiers() bool { for _, i := range n.Identifiers { if i.Type == acme.WireUser || i.Type == acme.WireDevice { return true } } return false } // identifiersOfType returns the Identifiers that are of type typ. 
func identifiersOfType(typ acme.IdentifierType, ids []acme.Identifier) (result []acme.Identifier) { for _, id := range ids { if id.Type == typ { result = append(result, id) } } return } // FinalizeRequest captures the body for a Finalize order request. type FinalizeRequest struct { CSR string `json:"csr"` csr *x509.CertificateRequest } // Validate validates a finalize request body. func (f *FinalizeRequest) Validate() error { var err error // RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the // CSR specifically, instead of "normal" base64-url encoding (incl. padding). // By trimming the padding from CSRs submitted by ACME clients that use // base64-url encoding instead of raw base64-url encoding, these are also // supported. This was reported in https://github.com/smallstep/certificates/issues/939 // to be the case for a Synology DSM NAS system. csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "=")) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr") } f.csr, err = x509.ParseCertificateRequest(csrBytes) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr") } if err = f.csr.CheckSignature(); err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check") } return nil } var defaultOrderExpiry = time.Hour * 24 var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. 
func NewOrder(w http.ResponseWriter, r *http.Request) {
	// Everything the handler needs is taken from the request context.
	ctx := r.Context()
	ca := mustAuthority(ctx)
	db := acme.MustDatabaseFromContext(ctx)
	linker := acme.MustLinkerFromContext(ctx)

	acc, err := accountFromContext(ctx)
	if err != nil {
		render.Error(w, r, err)
		return
	}
	prov, err := provisionerFromContext(ctx)
	if err != nil {
		render.Error(w, r, err)
		return
	}
	payload, err := payloadFromContext(ctx)
	if err != nil {
		render.Error(w, r, err)
		return
	}

	// Decode and validate the new-order request body.
	var nor NewOrderRequest
	if err := json.Unmarshal(payload.value, &nor); err != nil {
		render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
			"failed to unmarshal new-order request payload"))
		return
	}

	if err := nor.Validate(); err != nil {
		render.Error(w, r, err)
		return
	}

	// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
	// include the nor.Validate() error here too, like in the example in the ACME RFC?

	acmeProv, err := acmeProvisionerFromContext(ctx)
	if err != nil {
		render.Error(w, r, err)
		return
	}

	// The external account key is only looked up when the provisioner requires
	// EAB; otherwise eak stays nil and no account level policy is created.
	var eak *acme.ExternalAccountKey
	if acmeProv.RequireEAB {
		if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
			render.Error(w, r, acme.WrapErrorISE(err, "error retrieving external account binding key"))
			return
		}
	}

	acmePolicy, err := newACMEPolicyEngine(eak)
	if err != nil {
		render.Error(w, r, acme.WrapErrorISE(err, "error creating ACME policy engine"))
		return
	}

	// Every identifier has to pass all three policy levels; the first rejection
	// ends the request with a rejectedIdentifier error.
	for _, identifier := range nor.Identifiers {
		// evaluate the ACME account level policy
		if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
			render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
			return
		}
		// evaluate the provisioner level policy
		orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
		if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
			render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
			return
		}
		// evaluate the authority level policy
		if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
			render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
			return
		}
	}

	now := clock.Now()
	// New order.
	o := &acme.Order{
		AccountID:        acc.ID,
		ProvisionerID:    prov.GetID(),
		Status:           acme.StatusPending,
		Identifiers:      nor.Identifiers,
		ExpiresAt:        now.Add(defaultOrderExpiry),
		AuthorizationIDs: make([]string, len(nor.Identifiers)),
		NotBefore:        nor.NotBefore,
		NotAfter:         nor.NotAfter,
	}

	// One authorization is created per identifier, before the order itself is
	// stored; each authorization gets the order's expiry.
	for i, identifier := range o.Identifiers {
		az := &acme.Authorization{
			AccountID:  acc.ID,
			Identifier: identifier,
			ExpiresAt:  o.ExpiresAt,
			Status:     acme.StatusPending,
		}
		if err := newAuthorization(ctx, az); err != nil {
			render.Error(w, r, err)
			return
		}
		o.AuthorizationIDs[i] = az.ID
	}

	// Fill in validity bounds the client left out: NotBefore defaults to now and
	// NotAfter to NotBefore plus the provisioner's default TLS cert duration.
	if o.NotBefore.IsZero() {
		o.NotBefore = now
	}
	if o.NotAfter.IsZero() {
		o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration())
	}

	// if request NotBefore was empty, then backdate the order.NotBefore (now)
	// to avoid timing issues.
	if nor.NotBefore.IsZero() {
		// The authority's backdate is used when it is set; otherwise the
		// package default applies.
		backdate := defaultOrderBackdate
		if bd := ca.GetBackdate(); bd != nil {
			backdate = *bd
		}
		o.NotBefore = o.NotBefore.Add(-backdate)
	}

	if err := db.CreateOrder(ctx, o); err != nil {
		render.Error(w, r, acme.WrapErrorISE(err, "error creating order"))
		return
	}

	linker.LinkOrder(ctx, o)

	// Respond with 201 Created and the order's link in the Location header.
	w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
	render.JSONStatus(w, r, o, http.StatusCreated)
}

// isIdentifierAllowed evaluates the identifier value against the ACME account
// level policy. A nil policy allows every identifier.
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
	if acmePolicy == nil {
		return nil
	}
	return acmePolicy.AreSANsAllowed([]string{identifier.Value})
}

// newACMEPolicyEngine creates the X.509 policy engine for the policy attached
// to the external account key. It returns a nil engine and a nil error when eak
// is nil.
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
	if eak == nil {
		//nolint:nilnil,nolintlint // expected values
		return nil, nil
	}
	return policy.NewX509PolicyEngine(eak.Policy)
}

// trimIfWildcard removes a leading "*." from value. The boolean result reports
// whether the prefix was present.
func trimIfWildcard(value string) (string, bool) {
	if strings.HasPrefix(value, "*.") {
		return strings.TrimPrefix(value, "*."), true
	}
	return value, false
}

// newAuthorization completes az and stores it together with its challenges.
// The identifier value is stripped of a wildcard prefix (recorded in
// az.Wildcard), a random token is generated, and one challenge is created for
// each challenge type that applies to the identifier and is enabled on the
// provisioner from ctx. The challenges are stored first, then the authorization.
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
	value, isWildcard := trimIfWildcard(az.Identifier.Value)
	az.Wildcard = isWildcard
	az.Identifier = acme.Identifier{
		Value: value,
		Type:  az.Identifier.Type,
	}

	chTypes := challengeTypes(az)

	// The same token is shared by all challenges of this authorization.
	var err error
	az.Token, err = randutil.Alphanumeric(32)
	if err != nil {
		return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
	}

	db := acme.MustDatabaseFromContext(ctx)
	prov := acme.MustProvisionerFromContext(ctx)
	az.Challenges = make([]*acme.Challenge, 0, len(chTypes))
	for _, typ := range chTypes {
		// Challenge types that are not enabled on the provisioner are skipped.
		if !prov.IsChallengeEnabled(ctx, provisioner.ACMEChallenge(typ)) {
			continue
		}

		// target is only set for Wire identifiers; it stays empty for all
		// other identifier types.
		var target string
		switch az.Identifier.Type {
		case acme.WireUser:
			wireOptions, err := prov.GetOptions().GetWireOptions()
			if err != nil {
				return acme.WrapErrorISE(err, "failed getting Wire options")
			}
			target, err = wireOptions.GetOIDCOptions().EvaluateTarget("") // TODO(hs): determine if required by Wire
			if err != nil {
				return acme.WrapError(acme.ErrorMalformedType, err, "invalid Go template registered for 'target'")
			}
		case acme.WireDevice:
			wireID, err := wire.ParseDeviceID(az.Identifier.Value)
			if err != nil {
				return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing WireDevice")
			}
			clientID, err := wire.ParseClientID(wireID.ClientID)
			if err != nil {
				return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing ClientID")
			}
			wireOptions, err := prov.GetOptions().GetWireOptions()
			if err != nil {
				return acme.WrapErrorISE(err, "failed getting Wire options")
			}
			// The DPoP target is evaluated with the device ID taken from the
			// parsed client ID.
			target, err = wireOptions.GetDPOPOptions().EvaluateTarget(clientID.DeviceID)
			if err != nil {
				return acme.WrapError(acme.ErrorMalformedType, err, "invalid Go template registered for 'target'")
			}
		}

		ch := &acme.Challenge{
			AccountID: az.AccountID,
			Value:     az.Identifier.Value,
			Type:      typ,
			Token:     az.Token,
			Status:    acme.StatusPending,
			Target:    target,
		}
		if err := db.CreateChallenge(ctx, ch); err != nil {
			return acme.WrapErrorISE(err, "error creating challenge")
		}
		az.Challenges = append(az.Challenges, ch)
	}
	if err = db.CreateAuthorization(ctx, az); err != nil {
		return acme.WrapErrorISE(err, "error creating authorization")
	}
	return nil
}

// GetOrder ACME api for retrieving an order.
func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { render.Error(w, r, err) return } o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.UpdateStatus(ctx, db); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error updating order status")) return } linker.LinkOrder(ctx, o) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, r, o) } // FinalizeOrder attempts to finalize an order and create a certificate. 
func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, r, err) return } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { render.Error(w, r, err) return } o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } ca := mustAuthority(ctx) if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error finalizing order")) return } linker.LinkOrder(ctx, o) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, r, o) } // challengeTypes determines the types of challenges that should be used // for the ACME authorization request. func challengeTypes(az *acme.Authorization) []acme.ChallengeType { var chTypes []acme.ChallengeType switch az.Identifier.Type { case acme.IP: chTypes = []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01} case acme.DNS: chTypes = []acme.ChallengeType{acme.DNS01} // HTTP and TLS challenges can only be used for identifiers without wildcards. 
if !az.Wildcard { chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...) } case acme.PermanentIdentifier: chTypes = []acme.ChallengeType{acme.DEVICEATTEST01} case acme.WireUser: chTypes = []acme.ChallengeType{acme.WIREOIDC01} case acme.WireDevice: chTypes = []acme.ChallengeType{acme.WIREDPOP01} default: chTypes = []acme.ChallengeType{} } return chTypes } ================================================ FILE: acme/api/order_test.go ================================================ package api import ( "bytes" "context" "crypto/x509" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "net/url" "reflect" "testing" "time" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "go.step.sm/crypto/pemutil" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner/wire" sassert "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewOrderRequest_Validate(t *testing.T) { type test struct { nor *NewOrderRequest nbf, naf time.Time err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-identifiers": func(t *testing.T) test { return test{ nor: &NewOrderRequest{}, err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, "fail/bad-identifier": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "foo", Value: "bar.com"}, }, }, err: acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: foo"), } }, "fail/bad-identifier/bad-dns": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "xn--bücher.example.com"}, }, }, err: acme.NewError(acme.ErrorMalformedType, "invalid DNS name: xn--bücher.example.com"), } }, 
"fail/bad-identifier/dns-port": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com:8080"}, }, }, err: acme.NewError(acme.ErrorMalformedType, "invalid DNS name: example.com:8080"), } }, "fail/bad-identifier/dns-wildcard-port": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "*.example.com:8080"}, }, }, err: acme.NewError(acme.ErrorMalformedType, "invalid DNS name: *.example.com:8080"), } }, "fail/bad-identifier/ip": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "ip", Value: "192.168.42.1000"}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, err: acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", "192.168.42.1000"), } }, "fail/bad-identifier/wireapp-invalid-uri": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "wireapp-user", Value: `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, }, err: acme.NewError(acme.ErrorMalformedType, `failed validating Wire identifiers: invalid Wire client ID "example.com": invalid Wire client ID scheme ""; expected "wireapp"`), } }, "fail/bad-identifier/wireapp-wrong-scheme": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "wireapp-user", Value: `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "nowireapp://example.com", "handle": 
"wireapp://%40alice.smith.qa@example.com"}`}, }, }, err: acme.NewError(acme.ErrorMalformedType, `failed validating Wire identifiers: invalid Wire client ID "nowireapp://example.com": invalid Wire client ID scheme "nowireapp"; expected "wireapp"`), } }, "fail/bad-identifier/wireapp-invalid-user-parts": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "wireapp-user", Value: `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://user-device@example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, }, err: acme.NewError(acme.ErrorMalformedType, `failed validating Wire identifiers: invalid Wire client ID "wireapp://user-device@example.com": invalid Wire client ID username "user-device"`), } }, "ok": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "dns", Value: "*.bar.com"}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, } }, "ok/ipv4": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "ip", Value: "192.168.42.42"}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, } }, "ok/ipv6": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "ip", Value: "2001:db8::1"}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, } }, "ok/mixed-dns-and-ipv4": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ 
Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "ip", Value: "192.168.42.42"}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, } }, "ok/mixed-ipv4-and-ipv6": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "2001:db8::1"}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, } }, "ok/wireapp": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "wireapp-user", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, NotAfter: naf, NotBefore: nbf, }, nbf: nbf, naf: naf, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { err := tc.nor.Validate() if tc.err != nil { assert.Error(t, err) var ae *acme.Error if assert.True(t, errors.As(err, &ae)) { assert.HasPrefix(t, ae.Error(), tc.err.Error()) assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) assert.Equals(t, ae.Type, tc.err.Type) } return } assert.NoError(t, err) if tc.nbf.IsZero() { assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute))) assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute))) } else { assert.Equals(t, tc.nor.NotBefore, tc.nbf) } if tc.naf.IsZero() { assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour))) assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute))) } else { assert.Equals(t, tc.nor.NotAfter, tc.naf) } }) } } func TestFinalizeRequestValidate(t 
*testing.T) { _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") assert.FatalError(t, err) csr, ok := _csr.(*x509.CertificateRequest) assert.Fatal(t, ok) type test struct { fr *FinalizeRequest err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/parse-csr-error": func(t *testing.T) test { return test{ fr: &FinalizeRequest{}, err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, "fail/invalid-csr-signature": func(t *testing.T) test { b, err := pemutil.Read("../../authority/testdata/certs/badsig.csr") assert.FatalError(t, err) c, ok := b.(*x509.CertificateRequest) assert.Fatal(t, ok) return test{ fr: &FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(c.Raw), }, err: acme.NewError(acme.ErrorMalformedType, "csr failed signature check: x509: ECDSA verification failure"), } }, "ok": func(t *testing.T) test { return test{ fr: &FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), }, } }, "ok/padding": func(t *testing.T) test { return test{ fr: &FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(csr.Raw) + "==", // add intentional padding }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { if err := tc.fr.Validate(); err != nil { if assert.NotNil(t, err) { var ae *acme.Error if assert.True(t, errors.As(err, &ae)) { assert.HasPrefix(t, ae.Error(), tc.err.Error()) assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) assert.Equals(t, ae.Type, tc.err.Type) } } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.fr.csr.Raw, csr.Raw) } } }) } } func TestHandler_GetOrder(t *testing.T) { prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} now := clock.Now() nbf := now naf := now.Add(24 * time.Hour) expiry := now.Add(-time.Hour) o := acme.Order{ ID: "orderID", NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ { Type: "dns", 
Value: "example.com", }, { Type: "dns", Value: "*.smallstep.com", }, }, ExpiresAt: expiry, Status: acme.StatusInvalid, Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), AuthorizationURLs: []string{ fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName), fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName), fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName), }, FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName), } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) u := fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), escProvName, o.ID) type test struct { db acme.DB ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/db.GetOrder-error": func(t 
*testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{AccountID: "foo"}, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), } }, "fail/order-update-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{ AccountID: "accountID", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: clock.Now().Add(-time.Hour), Status: acme.StatusReady, }, nil }, MockUpdateOrder: func(ctx context.Context, 
o *acme.Order) error { return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{ ID: "orderID", AccountID: "accountID", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: expiry, Status: acme.StatusReady, AuthorizationIDs: []string{"foo", "bar", "baz"}, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ { Type: "dns", Value: "example.com", }, { Type: "dns", Value: "*.smallstep.com", }, }, }, nil }, MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { return nil }, }, ctx: ctx, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() GetOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func 
TestHandler_newAuthorization(t *testing.T) { defaultProvisioner := newProv() fakeKey := `-----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----` wireProvisioner := newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wire.Options{ OIDC: &wire.OIDCOptions{ Provider: &wire.Provider{ IssuerURL: "https://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wire.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wire.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, }) wireProvisionerFailOptions := &provisioner.ACME{ Type: "ACME", Name: "test@acme-provisioner.com", Options: &provisioner.Options{}, Challenges: []provisioner.ACMEChallenge{ provisioner.WIREOIDC_01, provisioner.WIREDPOP_01, }, } type test struct { az *acme.Authorization prov acme.Provisioner db acme.DB err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/error-db.CreateChallenge": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "zap.internal", }, } return test{ prov: defaultProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, az.Identifier.Value) return errors.New("force") }, }, az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Err: errors.New("error creating challenge: force"), Detail: "The server experienced an internal error", Status: 500, }, } }, "fail/error-db.CreateAuthorization": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "zap.internal", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } count := 0 var ch1, ch2, ch3 
**acme.Challenge return test{ prov: defaultProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, az.Identifier.Value) return nil }, MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, az.Identifier) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, _az.Wildcard, false) return errors.New("force") }, }, az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Err: errors.New("error creating authorization: force"), Detail: "The server experienced an internal error", Status: 500, }, } }, "fail/wireapp-user-options": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "wireapp-user", Value: "wireapp://%40alice.smith.qa@example.com", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } return test{ prov: wireProvisionerFailOptions, db: &acme.MockDB{}, az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Err: errors.New("failed getting Wire options: no Wire options available"), Detail: "The server experienced an internal error", Status: 500, }, } }, "fail/wireapp-device-parse-id": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: 
acme.Identifier{ Type: "wireapp-device", Value: `{"name}`, }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } return test{ prov: wireProvisioner, db: &acme.MockDB{}, az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:malformed", Err: errors.New("failed parsing WireDevice: unexpected end of JSON input"), Detail: "The request message was malformed", Status: 400, }, } }, "fail/wireapp-device-parse-client-id": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "wireapp-device", Value: `{"name": "device", "domain": "wire.com", "client-id": "CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`, }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } return test{ prov: wireProvisioner, db: &acme.MockDB{}, az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:malformed", Err: errors.New(`failed parsing ClientID: invalid Wire client ID scheme ""; expected "wireapp"`), Detail: "The request message was malformed", Status: 400, }, } }, "fail/wireapp-device-options": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "wireapp-device", Value: `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`, }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } return test{ prov: wireProvisionerFailOptions, db: &acme.MockDB{}, az: az, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Err: errors.New("failed getting Wire options: no Wire options available"), Detail: "The server experienced an internal error", Status: 500, }, } }, "ok/no-wildcard": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "zap.internal", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } count := 0 var ch1, ch2, ch3 **acme.Challenge 
return test{ prov: defaultProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, az.Identifier.Value) return nil }, MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, az.Identifier) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, _az.Wildcard, false) return nil }, }, az: az, } }, "ok/wildcard": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "*.zap.internal", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } var ch1 **acme.Challenge return test{ prov: defaultProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") ch1 = &ch return nil }, MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, acme.Identifier{ Type: 
"dns", Value: "zap.internal", }) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1}) assert.Equals(t, _az.Wildcard, true) return nil }, }, az: az, } }, "ok/permanent-identifier-disabled": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "permanent-identifier", Value: "7b53aa19-26f7-4fac-824f-7a781de0dab0", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } return test{ prov: defaultProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { t.Errorf("createChallenge should not be called") return nil }, MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, az.Identifier) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) assert.Equals(t, _az.Challenges, []*acme.Challenge{}) assert.Equals(t, _az.Wildcard, false) return nil }, }, az: az, } }, "ok/permanent-identifier-enabled": func(t *testing.T) test { var ch1 *acme.Challenge az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "permanent-identifier", Value: "7b53aa19-26f7-4fac-824f-7a781de0dab0", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } deviceAttestProv := newProv() deviceAttestProv.(*provisioner.ACME).Challenges = []provisioner.ACMEChallenge{provisioner.DEVICE_ATTEST_01} return test{ prov: deviceAttestProv, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { ch.ID = "997bacc2-c175-4214-a3b4-a229ada5f671" assert.Equals(t, ch.Type, acme.DEVICEATTEST01) assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "7b53aa19-26f7-4fac-824f-7a781de0dab0") ch1 = ch return nil }, 
MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, az.Identifier) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) assert.Equals(t, _az.Challenges, []*acme.Challenge{ch1}) assert.Equals(t, _az.Wildcard, false) return nil }, }, az: az, } }, "ok/wireapp-user": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "wireapp-user", Value: "wireapp://%40alice.smith.qa@example.com", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } count := 0 var ch1 **acme.Challenge return test{ prov: wireProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "wireapp-user" assert.Equals(t, ch.Type, acme.WIREOIDC01) ch1 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, az.Identifier.Value) return nil }, MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, az.Identifier) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) _ = ch1 // assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1}) assert.Equals(t, _az.Wildcard, false) return nil }, }, az: az, } }, "ok/wireapp-device": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ Type: "wireapp-device", Value: `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": 
"wireapp://%40alice_wire@wire.com"}`, }, Status: acme.StatusPending, ExpiresAt: clock.Now(), } count := 0 var ch1 **acme.Challenge return test{ prov: wireProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "wireapp-device" assert.Equals(t, ch.Type, acme.WIREDPOP01) ch1 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, az.AccountID) assert.Equals(t, ch.Token, az.Token) assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, az.Identifier.Value) return nil }, MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { assert.Equals(t, _az.AccountID, az.AccountID) assert.Equals(t, _az.Token, az.Token) assert.Equals(t, _az.Status, acme.StatusPending) assert.Equals(t, _az.Identifier, az.Identifier) assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) _ = ch1 // assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1}) assert.Equals(t, _az.Wildcard, false) return nil }, }, az: az, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) ctx := newBaseContext(context.Background(), tc.db) ctx = acme.NewProvisionerContext(ctx, tc.prov) err := newAuthorization(ctx, tc.az) if tc.err != nil { sassert.Error(t, err) var k *acme.Error if sassert.True(t, errors.As(err, &k)) { sassert.Equal(t, tc.err.Type, k.Type) sassert.Equal(t, tc.err.Detail, k.Detail) sassert.Equal(t, tc.err.Status, k.Status) sassert.EqualError(t, k.Err, tc.err.Error()) } return } sassert.NoError(t, err) }) } } func TestHandler_NewOrder(t *testing.T) { // Request with chi context prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/order/ordID", baseURL.String(), escProvName) fakeWireSigningKey := `-----BEGIN PUBLIC KEY----- 
MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----` type test struct { ca acme.CertificateAuthority db acme.DB ctx context.Context nor *NewOrderRequest statusCode int vr func(t *testing.T, o *acme.Order) err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) 
return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), } }, "fail/acmeProvisionerFromContext-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{}) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, ca: &mockCA{}, db: &acme.MockDB{ MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, errors.New("force") }, }, err: acme.NewErrorISE("error retrieving external account binding key: force"), } }, "fail/db.GetExternalAccountKeyByAccountID-error": func(t 
*testing.T) test { acmeProv := newACMEProv(t) acmeProv.RequireEAB = true acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), acmeProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, ca: &mockCA{}, db: &acme.MockDB{ MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, errors.New("force") }, }, err: acme.NewErrorISE("error retrieving external account binding key: force"), } }, "fail/newACMEPolicyEngine-error": func(t *testing.T) test { acmeProv := newACMEProv(t) acmeProv.RequireEAB = true acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), acmeProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, ca: &mockCA{}, db: &acme.MockDB{ MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return &acme.ExternalAccountKey{ Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"**.local"}, }, }, }, }, nil }, }, err: acme.NewErrorISE("error creating ACME policy engine"), } }, "fail/isIdentifierAllowed-error": func(t *testing.T) test { acmeProv := newACMEProv(t) acmeProv.RequireEAB = true acc := &acme.Account{ID: 
"accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), acmeProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, ca: &mockCA{}, db: &acme.MockDB{ MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return &acme.ExternalAccountKey{ Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.local"}, }, }, }, }, nil }, }, err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), } }, "fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test { options := &provisioner.Options{ X509: &provisioner.X509Options{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, }, } provWithPolicy := newACMEProvWithOptions(t, options) provWithPolicy.RequireEAB = true acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, ca: &mockCA{}, db: &acme.MockDB{ MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return &acme.ExternalAccountKey{ Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.internal"}, }, }, }, }, nil }, 
}, err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), } }, "fail/ca.AreSANsAllowed-error": func(t *testing.T) test { options := &provisioner.Options{ X509: &provisioner.X509Options{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.internal"}, }, }, } provWithPolicy := newACMEProvWithOptions(t, options) provWithPolicy.RequireEAB = true acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, ca: &mockCA{ MockAreSANsallowed: func(ctx context.Context, sans []string) error { return errors.New("force: not authorized by authority") }, }, db: &acme.MockDB{ MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return &acme.ExternalAccountKey{ Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.internal"}, }, }, }, }, nil }, }, err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"), } }, "fail/error-h.newAuthorization": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { assert.Equals(t, 
ch.AccountID, "accID") assert.Equals(t, ch.Type, acme.DNS01) assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return errors.New("force") }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, err: acme.NewErrorISE("error creating challenge: force"), } }, "fail/error-db.CreateOrder": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string count = 0 ) return test{ ctx: ctx, statusCode: 500, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, fr.Identifiers[0]) assert.Equals(t, az.Challenges, 
[]*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, fr.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return errors.New("force") }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, err: acme.NewErrorISE("error creating order: force"), } }, "ok/multiple-authz": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, {Type: "dns", Value: "*.zar.internal"}, }, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3, ch4 **acme.Challenge az1ID, az2ID *string chCount, azCount = 0, 0 ) return test{ ctx: ctx, statusCode: 201, nor: nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch chCount { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.Value, "zap.internal") ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) assert.Equals(t, ch.Value, "zap.internal") ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) assert.Equals(t, ch.Value, "zap.internal") ch3 = &ch case 3: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) assert.Equals(t, ch.Value, "zar.internal") ch4 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } chCount++ 
assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { switch azCount { case 0: az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Wildcard, false) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) case 1: az.ID = "az2ID" az2ID = &az.ID assert.Equals(t, az.Identifier, acme.Identifier{ Type: acme.DNS, Value: "zar.internal", }) assert.Equals(t, az.Wildcard, true) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch4}) default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } azCount++ assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNbf := now.Add(-defaultOrderBackdate) expNaf := now.Add(prov.DefaultTLSCertDuration()) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName), fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), 
escProvName), }) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/default-naf-nbf": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string count = 0 ) return test{ ctx: ctx, statusCode: 201, nor: nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx 
context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNbf := now.Add(-defaultOrderBackdate) expNaf := now.Add(prov.DefaultTLSCertDuration()) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/nbf-no-naf": func(t *testing.T) test { now := clock.Now() expNbf := now.Add(10 * time.Minute) acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, NotBefore: expNbf, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string count = 0 ) return test{ ctx: ctx, statusCode: 201, nor: 
nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNaf := expNbf.Add(prov.DefaultTLSCertDuration()) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, 
[]string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/naf-no-nbf": func(t *testing.T) test { now := clock.Now() expNaf := now.Add(15 * time.Minute) acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, NotAfter: expNaf, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string count = 0 ) return test{ ctx: ctx, statusCode: 201, nor: nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, 
az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNbf := now.Add(-defaultOrderBackdate) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/naf-nbf-from-ca": func(t *testing.T) test { now := clock.Now() expNaf := now.Add(15 * time.Minute) acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, NotAfter: expNaf, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID 
*string count = 0 ) return test{ ctx: ctx, statusCode: 201, nor: nor, ca: &mockCA{ MockGetBackdate: func() *time.Duration { d, err := time.ParseDuration("1h") require.NoError(t, err) return &d }, }, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNbf := now.Add(-1 * time.Hour) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, 
acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/default-naf-nbf-wireapp": func(t *testing.T) test { acmeWireProv := newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wire.Options{ OIDC: &wire.OIDCOptions{ Provider: &wire.Provider{ IssuerURL: "https://issuer.example.com", AuthURL: "", TokenURL: "", JWKSURL: "", UserInfoURL: "", Algorithms: []string{"ES256"}, }, Config: &wire.Config{ ClientID: "integration test", SignatureAlgorithms: []string{"ES256"}, SkipClientIDCheck: true, SkipExpiryCheck: true, SkipIssuerCheck: true, InsecureSkipSignatureCheck: true, Now: time.Now, }, }, DPOP: &wire.DPOPOptions{ SigningKey: []byte(fakeWireSigningKey), }, }, }) acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "wireapp-user", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, }, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), acmeWireProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2 **acme.Challenge az1ID, az2ID *string chCount, azCount = 0, 0 ) return test{ ctx: ctx, 
statusCode: 201, nor: nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch chCount { case 0: assert.Equals(t, ch.Type, acme.WIREOIDC01) assert.Equals(t, ch.Value, `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice_wire@wire.com"}`) ch.ID = "wireapp-oidc" ch1 = &ch case 1: assert.Equals(t, ch.Type, acme.WIREDPOP01) assert.Equals(t, ch.Value, `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice_wire@wire.com"}`) ch.ID = "wireapp-dpop" ch2 = &ch default: require.Fail(t, "test logic error") } chCount++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) _, _ = ch1, ch2 return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { switch azCount { case 0: az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1}) case 1: az.ID = "az2ID" az2ID = &az.ID assert.Equals(t, az.Identifier, nor.Identifiers[1]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch2}) default: require.Fail(t, "test logic error") } azCount++ assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, 
prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNbf := now.Add(-defaultOrderBackdate) expNaf := now.Add(prov.DefaultTLSCertDuration()) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName), fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), escProvName), }) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/naf-nbf": func(t *testing.T) test { now := clock.Now() expNbf := now.Add(5 * time.Minute) expNaf := now.Add(15 * time.Minute) acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, NotBefore: expNbf, NotAfter: expNaf, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string count = 0 ) return test{ ctx: ctx, statusCode: 201, nor: nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) 
ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, 
o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, "ok/default-naf-nbf-with-policy": func(t *testing.T) test { options := &provisioner.Options{ X509: &provisioner.X509Options{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.internal"}, }, }, } provWithPolicy := newACMEProvWithOptions(t, options) provWithPolicy.RequireEAB = true acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "zap.internal"}, }, } b, err := json.Marshal(nor) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string count = 0 ) return test{ ctx: ctx, statusCode: 201, nor: nor, ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: ch.ID = "dns" assert.Equals(t, ch.Type, acme.DNS01) ch1 = &ch case 1: ch.ID = "http" assert.Equals(t, ch.Type, acme.HTTP01) ch2 = &ch case 2: ch.ID = "tls" assert.Equals(t, ch.Type, acme.TLSALPN01) ch3 = &ch default: assert.FatalError(t, errors.New("test logic error")) return errors.New("force") } count++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) assert.Equals(t, ch.Value, "zap.internal") return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { az.ID = "az1ID" az1ID = &az.ID assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) assert.Equals(t, az.Identifier, nor.Identifiers[0]) assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3}) assert.Equals(t, az.Wildcard, false) return nil }, MockCreateOrder: func(ctx context.Context, o *acme.Order) error { o.ID = "ordID" assert.Equals(t, o.AccountID, "accID") 
assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, prov.GetID(), provisionerID) assert.Equals(t, "accID", accountID) return nil, nil }, }, vr: func(t *testing.T, o *acme.Order) { now := clock.Now() testBufferDur := 5 * time.Second orderExpiry := now.Add(defaultOrderExpiry) expNbf := now.Add(-defaultOrderBackdate) expNaf := now.Add(prov.DefaultTLSCertDuration()) assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.ca) ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() NewOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) 
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { ro := new(acme.Order) assert.FatalError(t, json.Unmarshal(body, ro)) if tc.vr != nil { tc.vr(t, ro) } assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func TestHandler_FinalizeOrder(t *testing.T) { mockMustAuthority(t, &mockCA{}) prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} now := clock.Now() nbf := now naf := now.Add(24 * time.Hour) o := acme.Order{ ID: "orderID", NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ { Type: "dns", Value: "example.com", }, { Type: "dns", Value: "*.smallstep.com", }, }, ExpiresAt: naf, Status: acme.StatusValid, AuthorizationURLs: []string{ fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName), fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName), fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName), }, FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName), CertificateURL: fmt.Sprintf("%s/acme/%s/certificate/certID", baseURL.String(), escProvName), } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) u := fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), escProvName, o.ID) _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") assert.FatalError(t, err) csr, ok := _csr.(*x509.CertificateRequest) assert.Fatal(t, ok) nor := &FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), } payloadBytes, err := json.Marshal(nor) assert.FatalError(t, err) type test struct { db acme.DB ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, ctx: 
acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return 
test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{AccountID: "foo"}, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = 
context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil }, }, ctx: ctx, statusCode: 401, err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), } }, "fail/order-finalize-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{ AccountID: "accountID", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: clock.Now().Add(-time.Hour), Status: acme.StatusReady, }, nil }, MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { return &acme.Order{ ID: "orderID", AccountID: "accountID", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()), ExpiresAt: naf, Status: acme.StatusValid, AuthorizationIDs: []string{"foo", "bar", "baz"}, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ { Type: "dns", Value: "example.com", }, { Type: "dns", Value: "*.smallstep.com", }, }, 
CertificateID: "certID", }, nil }, }, ctx: ctx, statusCode: 200, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() FinalizeOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(o) assert.FatalError(t, err) ro := new(acme.Order) assert.FatalError(t, json.Unmarshal(body, ro)) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{u}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) } } func TestHandler_challengeTypes(t *testing.T) { type args struct { az *acme.Authorization } tests := []struct { name string args args want []acme.ChallengeType }{ { name: "ok/dns", args: args{ az: &acme.Authorization{ Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, Wildcard: false, }, }, want: []acme.ChallengeType{acme.DNS01, acme.HTTP01, acme.TLSALPN01}, }, { name: "ok/wildcard", args: args{ az: &acme.Authorization{ Identifier: acme.Identifier{Type: "dns", Value: "*.example.com"}, Wildcard: true, }, }, want: []acme.ChallengeType{acme.DNS01}, }, { name: "ok/ip", args: args{ az: &acme.Authorization{ Identifier: acme.Identifier{Type: "ip", Value: "192.168.42.42"}, Wildcard: false, }, }, want: []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}, }, } for _, tt := range tests { t.Run(tt.name, func(t 
*testing.T) { if got := challengeTypes(tt.args.az); !reflect.DeepEqual(got, tt.want) { t.Errorf("Handler.challengeTypes() = %v, want %v", got, tt.want) } }) } } func TestTrimIfWildcard(t *testing.T) { tests := []struct { name string arg string wantValue string wantBool bool }{ { name: "no trim", arg: "smallstep.com", wantValue: "smallstep.com", wantBool: false, }, { name: "trim", arg: "*.smallstep.com", wantValue: "smallstep.com", wantBool: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { v, ok := trimIfWildcard(tt.arg) assert.Equals(t, v, tt.wantValue) assert.Equals(t, ok, tt.wantBool) }) } } ================================================ FILE: acme/api/revoke.go ================================================ package api import ( "bytes" "context" "crypto/x509" "encoding/base64" "encoding/json" "fmt" "net/http" "strings" "go.step.sm/crypto/jose" "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" ) type revokePayload struct { Certificate string `json:"certificate"` ReasonCode *int `json:"reason,omitempty"` } // RevokeCert attempts to revoke a certificate. 
func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() db := acme.MustDatabaseFromContext(ctx) linker := acme.MustLinkerFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, r, err) return } prov, err := provisionerFromContext(ctx) if err != nil { render.Error(w, r, err) return } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, r, err) return } var p revokePayload err = json.Unmarshal(payload.value, &p) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error unmarshaling payload")) return } certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) if err != nil { // in this case the most likely cause is a client that didn't properly encode the certificate render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) return } certToBeRevoked, err := x509.ParseCertificate(certBytes) if err != nil { // in this case a client may have encoded something different than a certificate render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) return } serial := certToBeRevoked.SerialNumber.String() dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { // this should never happen render.Error(w, r, acme.NewErrorISE("certificate raw bytes are not equal")) return } if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { render.Error(w, r, err) return } acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { render.Error(w, r, acmeErr) return } } else { // if account doesn't need to be checked, the JWS should be verified to be signed by the // private key that belongs to the public key in the certificate to be revoked. 
_, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? render.Error(w, r, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } ca := mustAuthority(ctx) hasBeenRevokedBefore, err := ca.IsRevoked(serial) if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return } if hasBeenRevokedBefore { render.Error(w, r, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) return } reasonCode := p.ReasonCode acmeErr := validateReasonCode(reasonCode) if acmeErr != nil { render.Error(w, r, acmeErr) return } // Authorize revocation by ACME provisioner ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) err = prov.AuthorizeRevoke(ctx, "") if err != nil { render.Error(w, r, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) return } options := revokeOptions(serial, certToBeRevoked, reasonCode) err = ca.Revoke(ctx, options) if err != nil { render.Error(w, r, wrapRevokeErr(err)) return } logRevoke(w, options) w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) w.Write(nil) } // isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized // to revoke the certificate. An Account must always be valid in order to revoke a certificate. // In case the certificate retrieved from the database belongs to the Account, the Account is // authorized. If the certificate retrieved from the database doesn't belong to the Account, // the identifiers in the certificate are extracted and compared against the (valid) Authorizations // that are stored for the ACME Account. If these sets match, the Account is considered authorized // to revoke the certificate. If this check fails, the client will receive an unauthorized error. 
func isAccountAuthorized(_ context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { if !account.IsValid() { return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) } certificateBelongsToAccount := dbCert.AccountID == account.ID if certificateBelongsToAccount { return nil // return early } // TODO(hs): according to RFC8555: 7.6, a server MUST consider the following accounts authorized // to revoke a certificate: // // o the account that issued the certificate. // o an account that holds authorizations for all of the identifiers in the certificate. // // We currently only support the first case. The second might result in step going OOM when // large numbers of Authorizations are involved when the current nosql interface is in use. // We want to protect users from this failure scenario, so that's why it hasn't been added yet. // This issue is tracked in https://github.com/smallstep/certificates/issues/767 // not authorized; fail closed. return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' is not authorized", account.ID), nil) } // wrapRevokeErr is a best effort implementation to transform an error during // revocation into an ACME error, so that clients can understand the error. func wrapRevokeErr(err error) *acme.Error { t := err.Error() if strings.Contains(t, "is already revoked") { return acme.NewError(acme.ErrorAlreadyRevokedType, "%s", t) } return acme.WrapErrorISE(err, "error when revoking certificate") } // unauthorizedError returns an ACME error indicating the request was // not authorized to revoke the certificate. 
func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error { var acmeErr *acme.Error if err == nil { acmeErr = acme.NewError(acme.ErrorUnauthorizedType, "%s", msg) } else { acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, "%s", msg) } acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403 switch { case len(unauthorizedIdentifiers) > 0: identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too? acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value) case cert.Subject.String() != "": acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName) default: acmeErr.Detail = "No authorization provided" } return acmeErr } // logRevoke logs successful revocation of certificate func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { if rl, ok := w.(logging.ResponseLogger); ok { rl.WithFields(map[string]interface{}{ "serial": ri.Serial, "reasonCode": ri.ReasonCode, "reason": ri.Reason, "passiveOnly": ri.PassiveOnly, "ACME": ri.ACME, }) } } // validateReasonCode validates the revocation reason func validateReasonCode(reasonCode *int) *acme.Error { if reasonCode != nil && ((*reasonCode < ocsp.Unspecified || *reasonCode > ocsp.AACompromise) || *reasonCode == 7) { return acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds") } // NOTE: it's possible to add additional requirements to the reason code: // The server MAY disallow a subset of reasonCodes from being // used by the user. If a request contains a disallowed reasonCode, // then the server MUST reject it with the error type // "urn:ietf:params:acme:error:badRevocationReason" // No additional checks have been implemented so far. 
return nil } // revokeOptions determines the RevokeOptions for the Authority to use in revocation func revokeOptions(serial string, certToBeRevoked *x509.Certificate, reasonCode *int) *authority.RevokeOptions { opts := &authority.RevokeOptions{ Serial: serial, ACME: true, Crt: certToBeRevoked, } if reasonCode != nil { // NOTE: when implementing CRL and/or OCSP, and reason code is missing, CRL entry extension should be omitted opts.Reason = reason(*reasonCode) opts.ReasonCode = *reasonCode } return opts } // reason transforms an integer reason code to a // textual description of the revocation reason. func reason(reasonCode int) string { switch reasonCode { case ocsp.Unspecified: return "unspecified reason" case ocsp.KeyCompromise: return "key compromised" case ocsp.CACompromise: return "ca compromised" case ocsp.AffiliationChanged: return "affiliation changed" case ocsp.Superseded: return "superseded" case ocsp.CessationOfOperation: return "cessation of operation" case ocsp.CertificateHold: return "certificate hold" case ocsp.RemoveFromCRL: return "remove from crl" case ocsp.PrivilegeWithdrawn: return "privilege withdrawn" case ocsp.AACompromise: return "aa compromised" default: return "unspecified reason" } } // shouldCheckAccountFrom indicates whether an account should be // retrieved from the context, so that it can be used for // additional checks. This should only be done when no JWK // can be extracted from the request, as that would indicate // that the revocation request was signed with a certificate // key pair (and not an account key pair). Looking up such // a JWK would result in no Account being found. 
func shouldCheckAccountFrom(jws *jose.JSONWebSignature) bool { return !canExtractJWKFrom(jws) } ================================================ FILE: acme/api/revoke_test.go ================================================ package api import ( "bytes" "context" "crypto" "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "fmt" "io" "math/big" "net" "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/go-chi/chi/v5" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "golang.org/x/crypto/ocsp" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) // v is a test helper that returns a pointer to a copy of the given integer. func v(v int) *int { return &v } // generateSerial returns a random serial number in the range [0, 10^18). func generateSerial() (*big.Int, error) { return rand.Int(rand.Reader, big.NewInt(1000000000000000000)) } // generateCertKeyPair generates a fresh EC P-256 key pair and a test certificate for it. // The certificate template is passed as its own parent; it has common name and IP address // 127.0.0.1, a random serial number and is valid for one hour from now. func generateCertKeyPair() (*x509.Certificate, crypto.Signer, error) { pub, priv, err := keyutil.GenerateKeyPair("EC", "P-256", 0) if err != nil { return nil, nil, err } serial, err := generateSerial() if err != nil { return nil, nil, err } now := time.Now() template := &x509.Certificate{ Subject: pkix.Name{CommonName: "127.0.0.1"}, Issuer: pkix.Name{CommonName: "Test ACME Revoke Certificate"}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, IsCA: false, MaxPathLen: 0, KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, NotBefore: now, NotAfter: now.Add(time.Hour), SerialNumber: serial, } signer, ok := priv.(crypto.Signer) if !ok { return nil, nil, errors.Errorf("result is not a crypto.Signer: type %T", priv) } cert, err := x509util.CreateCertificate(template, template, pub, signer) return cert, signer, err } // errUnsupportedKey is returned by the JWS helpers below for keys they do not support. var errUnsupportedKey = fmt.Errorf("unknown key type; only
RSA and ECDSA are supported") // keyID is the account identity provided by a CA during registration. type keyID string // noKeyID indicates that jwsEncodeJSON should compute and use JWK instead of a KID. // See jwsEncodeJSON for details. const noKeyID = keyID("") // jwsEncodeJSON signs claimset with the provided key and embeds the nonce and URL u // in the protected header. The result is serialized in JSON format containing either // kid or jwk fields based on the provided keyID value. // // If kid is non-empty, its quoted value is inserted in the protected header // as "kid" field value. Otherwise, JWK is computed using jwkEncode and inserted // as "jwk" field value. The "jwk" and "kid" fields are mutually exclusive. // // See https://tools.ietf.org/html/rfc7515#section-7. // // A nil claimset results in an empty payload. If nonce is empty, it will not be // encoded into the header. errUnsupportedKey is returned when jwsHasher does not // support the key or the selected hash function is not available. // Implementation taken from github.com/mholt/acmez, which seems to be based on // https://github.com/golang/crypto/blob/master/acme/jws.go. func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid keyID, nonce, u string) ([]byte, error) { alg, sha := jwsHasher(key.Public()) if alg == "" || !sha.Available() { return nil, errUnsupportedKey } phead, err := jwsHead(alg, nonce, u, kid, key) if err != nil { return nil, err } var payload string if claimset != nil { cs, err := json.Marshal(claimset) if err != nil { return nil, err } payload = base64.RawURLEncoding.EncodeToString(cs) } payloadToSign := []byte(phead + "." + payload) hash := sha.New() _, _ = hash.Write(payloadToSign) digest := hash.Sum(nil) sig, err := jwsSign(key, sha, digest) if err != nil { return nil, err } return jwsFinal(sha, sig, phead, payload) } // jwsHasher indicates a suitable JWS algorithm name and a hash function // to use for signing a digest with the provided key. // It returns ("", 0) if the key is not supported. // Implementation taken from github.com/mholt/acmez, which seems to be based on // https://github.com/golang/crypto/blob/master/acme/jws.go.
func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) {
	switch pub := pub.(type) {
	case *rsa.PublicKey:
		return "RS256", crypto.SHA256
	case *ecdsa.PublicKey:
		// The algorithm is selected by the name of the curve; any other
		// curve falls through to the "unsupported" result below.
		switch pub.Params().Name {
		case "P-256":
			return "ES256", crypto.SHA256
		case "P-384":
			return "ES384", crypto.SHA384
		case "P-521":
			return "ES512", crypto.SHA512
		}
	}
	return "", 0
}

// jwsSign signs the digest using the given key.
// The hash is unused for ECDSA keys.
//
// For an *ecdsa.PrivateKey the result is the concatenation of R and S, each
// encoded big-endian and left-padded with zeros to the byte length of the
// curve (the bit size rounded up to whole bytes). Any other key is signed
// with its own Sign method.
//
// Note: non-stdlib crypto.Signer implementations are expected to return
// the signature in the format as specified in RFC7518.
// See https://tools.ietf.org/html/rfc7518 for more details.
// Implementation taken from github.com/mholt/acmez, which seems to be based on
// https://github.com/golang/crypto/blob/master/acme/jws.go.
func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) {
	ecKey, ok := key.(*ecdsa.PrivateKey)
	if !ok {
		return key.Sign(rand.Reader, digest, hash)
	}

	// The key.Sign method of ecdsa returns ASN1-encoded signature.
	// So, we use the package Sign function instead
	// to get R and S values directly and format the result accordingly.
	r, s, err := ecdsa.Sign(rand.Reader, ecKey, digest)
	if err != nil {
		return nil, err
	}

	// Round the bit size up to whole bytes. The rounding has to look at the
	// bit size itself: testing the already divided byte count for a
	// remainder only gives the right answer by coincidence for some curves.
	size := (ecKey.Params().BitSize + 7) / 8
	sig := make([]byte, 2*size)
	r.FillBytes(sig[:size])
	s.FillBytes(sig[size:])
	return sig, nil
}

// jwsHead constructs the protected JWS header for the given fields.
// Since jwk and kid are mutually-exclusive, the jwk will be encoded
// only if kid is empty. If nonce is empty, it will not be encoded.
// Implementation taken from github.com/mholt/acmez, which seems to be based on
// https://github.com/golang/crypto/blob/master/acme/jws.go.
func jwsHead(alg, nonce, u string, kid keyID, key crypto.Signer) (string, error) { phead := fmt.Sprintf(`{"alg":%q`, alg) if kid == noKeyID { jwk, err := jwkEncode(key.Public()) if err != nil { return "", err } phead += fmt.Sprintf(`,"jwk":%s`, jwk) } else { phead += fmt.Sprintf(`,"kid":%q`, kid) } if nonce != "" { phead += fmt.Sprintf(`,"nonce":%q`, nonce) } phead += fmt.Sprintf(`,"url":%q}`, u) phead = base64.RawURLEncoding.EncodeToString([]byte(phead)) return phead, nil } // jwkEncode encodes public part of an RSA or ECDSA key into a JWK. // The result is also suitable for creating a JWK thumbprint. // https://tools.ietf.org/html/rfc7517 // Implementation taken from github.com/mholt/acmez, which seems to be based on // https://github.com/golang/crypto/blob/master/acme/jws.go. func jwkEncode(pub crypto.PublicKey) (string, error) { switch pub := pub.(type) { case *rsa.PublicKey: // https://tools.ietf.org/html/rfc7518#section-6.3.1 n := pub.N e := big.NewInt(int64(pub.E)) // Field order is important. // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. return fmt.Sprintf(`{"e":%q,"kty":"RSA","n":%q}`, base64.RawURLEncoding.EncodeToString(e.Bytes()), base64.RawURLEncoding.EncodeToString(n.Bytes()), ), nil case *ecdsa.PublicKey: // https://tools.ietf.org/html/rfc7518#section-6.2.1 p := pub.Curve.Params() n := p.BitSize / 8 if p.BitSize%8 != 0 { n++ } x := pub.X.Bytes() if n > len(x) { x = append(make([]byte, n-len(x)), x...) } y := pub.Y.Bytes() if n > len(y) { y = append(make([]byte, n-len(y)), y...) } // Field order is important. // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. return fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, p.Name, base64.RawURLEncoding.EncodeToString(x), base64.RawURLEncoding.EncodeToString(y), ), nil } return "", errUnsupportedKey } // jwsFinal constructs the final JWS object. 
// Implementation taken from github.com/mholt/acmez, which seems to be based on // https://github.com/golang/crypto/blob/master/acme/jws.go. func jwsFinal(_ crypto.Hash, sig []byte, phead, payload string) ([]byte, error) { enc := struct { Protected string `json:"protected"` Payload string `json:"payload"` Sig string `json:"signature"` }{ Protected: phead, Payload: payload, Sig: base64.RawURLEncoding.EncodeToString(sig), } result, err := json.Marshal(&enc) if err != nil { return nil, err } return result, nil } type mockCA struct { MockIsRevoked func(sn string) (bool, error) MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error MockAreSANsallowed func(ctx context.Context, sans []string) error MockGetBackdate func() *time.Duration } func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { return nil, nil } func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { if m.MockAreSANsallowed != nil { return m.MockAreSANsallowed(ctx, sans) } return nil } func (m *mockCA) IsRevoked(sn string) (bool, error) { if m.MockIsRevoked != nil { return m.MockIsRevoked(sn) } return false, nil } func (m *mockCA) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { if m.MockRevoke != nil { return m.MockRevoke(ctx, opts) } return nil } func (m *mockCA) LoadProvisionerByName(string) (provisioner.Interface, error) { return nil, nil } func (m *mockCA) GetBackdate() *time.Duration { if m.MockGetBackdate != nil { return m.MockGetBackdate() } return nil } func Test_validateReasonCode(t *testing.T) { tests := []struct { name string reasonCode *int want *acme.Error }{ { name: "ok", reasonCode: v(ocsp.Unspecified), want: nil, }, { name: "fail/too-low", reasonCode: v(-1), want: acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"), }, { name: "fail/too-high", reasonCode: v(11), want: 
acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"), }, { name: "fail/missing-7", reasonCode: v(7), want: acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := validateReasonCode(tt.reasonCode) if (err != nil) != (tt.want != nil) { t.Errorf("validateReasonCode() = %v, want %v", err, tt.want) } if err != nil { assert.Equals(t, err.Type, tt.want.Type) assert.Equals(t, err.Detail, tt.want.Detail) assert.Equals(t, err.Status, tt.want.Status) assert.Equals(t, err.Err.Error(), tt.want.Err.Error()) assert.Equals(t, err.Detail, tt.want.Detail) } }) } } func Test_reason(t *testing.T) { tests := []struct { name string reasonCode int want string }{ { name: "unspecified reason", reasonCode: ocsp.Unspecified, want: "unspecified reason", }, { name: "key compromised", reasonCode: ocsp.KeyCompromise, want: "key compromised", }, { name: "ca compromised", reasonCode: ocsp.CACompromise, want: "ca compromised", }, { name: "affiliation changed", reasonCode: ocsp.AffiliationChanged, want: "affiliation changed", }, { name: "superseded", reasonCode: ocsp.Superseded, want: "superseded", }, { name: "cessation of operation", reasonCode: ocsp.CessationOfOperation, want: "cessation of operation", }, { name: "certificate hold", reasonCode: ocsp.CertificateHold, want: "certificate hold", }, { name: "remove from crl", reasonCode: ocsp.RemoveFromCRL, want: "remove from crl", }, { name: "privilege withdrawn", reasonCode: ocsp.PrivilegeWithdrawn, want: "privilege withdrawn", }, { name: "aa compromised", reasonCode: ocsp.AACompromise, want: "aa compromised", }, { name: "default", reasonCode: -1, want: "unspecified reason", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := reason(tt.reasonCode); got != tt.want { t.Errorf("reason() = %v, want %v", got, tt.want) } }) } } func Test_revokeOptions(t *testing.T) { cert, _, err := generateCertKeyPair() 
assert.FatalError(t, err) type args struct { serial string certToBeRevoked *x509.Certificate reasonCode *int } tests := []struct { name string args args want *authority.RevokeOptions }{ { name: "ok/no-reasoncode", args: args{ serial: "1234", certToBeRevoked: cert, }, want: &authority.RevokeOptions{ Serial: "1234", Crt: cert, ACME: true, }, }, { name: "ok/including-reasoncode", args: args{ serial: "1234", certToBeRevoked: cert, reasonCode: v(ocsp.KeyCompromise), }, want: &authority.RevokeOptions{ Serial: "1234", Crt: cert, ACME: true, ReasonCode: 1, Reason: "key compromised", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := revokeOptions(tt.args.serial, tt.args.certToBeRevoked, tt.args.reasonCode); !cmp.Equal(got, tt.want) { t.Errorf("revokeOptions() diff =\n%s", cmp.Diff(got, tt.want)) } }) } } func TestHandler_RevokeCert(t *testing.T) { prov := &provisioner.ACME{ Type: "ACME", Name: "testprov", } escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} chiCtx := chi.NewRouteContext() revokeURL := fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), escProvName) cert, key, err := generateCertKeyPair() assert.FatalError(t, err) rp := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), } payloadBytes, err := json.Marshal(rp) assert.FatalError(t, err) jws := &jose.JSONWebSignature{ Signatures: []jose.Signature{ { Protected: jose.Header{ Algorithm: jose.ES256, KeyID: "bar", ExtraHeaders: map[jose.HeaderKey]interface{}{ "url": revokeURL, }, }, }, }, } type test struct { db acme.DB ca acme.CertificateAuthority ctx context.Context statusCode int err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { ctx := context.Background() return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { ctx := 
context.WithValue(context.Background(), jwsContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/nil-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = acme.NewProvisionerContext(ctx, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), } }, "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload": func(t *testing.T) test { malformedPayload := []byte(`{"payload":malformed?}`) ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("error unmarshaling payload"), } }, "fail/wrong-certificate-encoding": func(t *testing.T) test { wrongPayload := &revokePayload{ Certificate: base64.StdEncoding.EncodeToString(cert.Raw), } wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) ctx := 
acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ Type: "urn:ietf:params:acme:error:malformed", Status: 400, Detail: "The request message was malformed", }, } }, "fail/no-certificate-encoded": func(t *testing.T) test { emptyPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString([]byte{}), } emptyPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ Type: "urn:ietf:params:acme:error:malformed", Status: 400, Detail: "The request message was malformed", }, } }, "fail/db.GetCertificateBySerial": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { return nil, errors.New("force") }, } return test{ db: db, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("error retrieving certificate by serial"), } }, "fail/different-certificate-contents": func(t *testing.T) test { aDifferentCert, _, err := generateCertKeyPair() assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { 
assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ Leaf: aDifferentCert, }, nil }, } return test{ db: db, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("certificate raw bytes are not equal"), } }, "fail/no-account": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ Leaf: cert, }, nil }, } return test{ db: db, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context"), } }, "fail/nil-account": func(t *testing.T) test { ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, accContextKey, nil) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ Leaf: cert, }, nil }, } return test{ db: db, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { 
assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{} return test{ db: db, ca: ca, ctx: ctx, statusCode: 403, err: &acme.Error{ Type: "urn:ietf:params:acme:error:unauthorized", Detail: "No authorization provided for name 127.0.0.1", Status: 403, }, } }, "fail/account-not-authorized": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "differentAccountID", Leaf: cert, }, nil }, MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { assert.Equals(t, "accountID", accountID) return []*acme.Authorization{ { AccountID: "accountID", Status: acme.StatusValid, Identifier: acme.Identifier{ Type: acme.IP, Value: "127.0.1.0", }, }, }, nil }, } ca := &mockCA{} return test{ db: db, ca: ca, ctx: ctx, statusCode: 403, err: &acme.Error{ Type: "urn:ietf:params:acme:error:unauthorized", Detail: "No authorization provided for name 127.0.0.1", Status: 403, }, } }, "fail/unauthorized-certificate-key": func(t *testing.T) test { _, unauthorizedKey, err := generateCertKeyPair() assert.FatalError(t, err) jwsPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), ReasonCode: v(2), } jwsBytes, err := jwsEncodeJSON(rp, unauthorizedKey, "", "nonce", revokeURL) assert.FatalError(t, err) parsedJWS, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) unauthorizedPayloadBytes, 
err := json.Marshal(jwsPayload) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{} acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "verification of jws using certificate public key failed") acmeErr.Detail = "No authorization provided for name 127.0.0.1" return test{ db: db, ca: ca, ctx: ctx, statusCode: 403, err: acmeErr, } }, "fail/certificate-revoked-check-fails": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{ MockIsRevoked: func(sn string) (bool, error) { return false, errors.New("force") }, } return test{ db: db, ca: ca, ctx: ctx, statusCode: 500, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, }, } }, "fail/certificate-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx 
= context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{ MockIsRevoked: func(sn string) (bool, error) { return true, nil }, } return test{ db: db, ca: ca, ctx: ctx, statusCode: 400, err: &acme.Error{ Type: "urn:ietf:params:acme:error:alreadyRevoked", Detail: "Certificate already revoked", Status: 400, }, } }, "fail/invalid-reasoncode": func(t *testing.T) test { invalidReasonPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), ReasonCode: v(7), } invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{ MockIsRevoked: func(sn string) (bool, error) { return false, nil }, } return test{ db: db, ca: ca, ctx: ctx, statusCode: 400, err: &acme.Error{ Type: "urn:ietf:params:acme:error:badRevocationReason", Detail: "The revocation reason provided is not allowed by the server", Status: 400, }, } }, "fail/prov.AuthorizeRevoke": func(t *testing.T) test { assert.FatalError(t, err) mockACMEProv := &acme.MockProvisioner{ MauthorizeRevoke: func(ctx 
context.Context, token string) error { return errors.New("force") }, } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{ MockIsRevoked: func(sn string) (bool, error) { return false, nil }, } return test{ db: db, ca: ca, ctx: ctx, statusCode: 500, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, }, } }, "fail/ca.Revoke": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{ MockRevoke: func(ctx context.Context, opts *authority.RevokeOptions) error { return errors.New("force") }, } return test{ db: db, ca: ca, ctx: ctx, statusCode: 500, err: &acme.Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, }, } }, "fail/ca.Revoke-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := 
acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{ MockIsRevoked: func(sn string) (bool, error) { return false, nil }, MockRevoke: func(ctx context.Context, opts *authority.RevokeOptions) error { return fmt.Errorf("certificate with serial number '%s' is already revoked", cert.SerialNumber.String()) }, } return test{ db: db, ca: ca, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAlreadyRevokedType, "certificate with serial number '%s' is already revoked", cert.SerialNumber.String()), } }, "ok/using-account-key": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", Leaf: cert, }, nil }, } ca := &mockCA{} return test{ db: db, ca: ca, ctx: ctx, statusCode: 200, } }, "ok/using-certificate-key": func(t *testing.T) test { jwsBytes, err := jwsEncodeJSON(rp, key, "", "nonce", revokeURL) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, 
payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "someDifferentAccountID", Leaf: cert, }, nil }, } ca := &mockCA{} return test{ db: db, ca: ca, ctx: ctx, statusCode: 200, } }, } for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) mockMustAuthority(t, tc.ca) req := httptest.NewRequest("POST", revokeURL, http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() RevokeCert(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.True(t, bytes.Equal(bytes.TrimSpace(body), []byte{})) assert.Equals(t, []string{fmt.Sprintf("<%s/acme/%s/directory>;rel=\"index\"", baseURL.String(), escProvName)}, res.Header["Link"]) } }) } } func TestHandler_isAccountAuthorized(t *testing.T) { type test struct { db acme.DB ctx context.Context existingCert *acme.Certificate certToBeRevoked *x509.Certificate account *acme.Account err *acme.Error } accountID := "accountID" var tests = map[string]func(t *testing.T) test{ "fail/account-invalid": func(t *testing.T) test { account := &acme.Account{ ID: accountID, Status: acme.StatusInvalid, } certToBeRevoked := &x509.Certificate{ Subject: pkix.Name{ 
CommonName: "127.0.0.1", }, } return test{ ctx: context.TODO(), certToBeRevoked: certToBeRevoked, account: account, err: &acme.Error{ Type: "urn:ietf:params:acme:error:unauthorized", Status: http.StatusForbidden, Detail: "No authorization provided for name 127.0.0.1", Err: errors.New("account 'accountID' has status 'invalid'"), }, } }, "fail/different-account": func(t *testing.T) test { account := &acme.Account{ ID: accountID, Status: acme.StatusValid, } certToBeRevoked := &x509.Certificate{ IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, } existingCert := &acme.Certificate{ AccountID: "differentAccountID", } return test{ db: &acme.MockDB{ MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { assert.Equals(t, "accountID", accountID) return []*acme.Authorization{ { AccountID: accountID, Status: acme.StatusValid, Identifier: acme.Identifier{ Type: acme.IP, Value: "127.0.0.1", }, }, }, nil }, }, ctx: context.TODO(), existingCert: existingCert, certToBeRevoked: certToBeRevoked, account: account, err: &acme.Error{ Type: "urn:ietf:params:acme:error:unauthorized", Status: http.StatusForbidden, Detail: "No authorization provided", Err: errors.New("account 'accountID' is not authorized"), }, } }, "ok": func(t *testing.T) test { account := &acme.Account{ ID: accountID, Status: acme.StatusValid, } certToBeRevoked := &x509.Certificate{ IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, } existingCert := &acme.Certificate{ AccountID: "accountID", } return test{ db: &acme.MockDB{ MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { assert.Equals(t, "accountID", accountID) return []*acme.Authorization{ { AccountID: accountID, Status: acme.StatusValid, Identifier: acme.Identifier{ Type: acme.IP, Value: "127.0.0.1", }, }, }, nil }, }, ctx: context.TODO(), existingCert: existingCert, certToBeRevoked: certToBeRevoked, account: account, err: nil, } }, } for name, setup := 
range tests { tc := setup(t) t.Run(name, func(t *testing.T) { // h := &Handler{db: tc.db} acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) expectError := tc.err != nil gotError := acmeErr != nil if expectError != gotError { t.Errorf("expected: %t, got: %t", expectError, gotError) return } if !gotError { return // nothing to check; return early } assert.Equals(t, acmeErr.Err.Error(), tc.err.Err.Error()) assert.Equals(t, acmeErr.Type, tc.err.Type) assert.Equals(t, acmeErr.Status, tc.err.Status) assert.Equals(t, acmeErr.Detail, tc.err.Detail) assert.Equals(t, acmeErr.Subproblems, tc.err.Subproblems) }) } } func Test_wrapUnauthorizedError(t *testing.T) { type test struct { cert *x509.Certificate unauthorizedIdentifiers []acme.Identifier msg string err error want *acme.Error } var tests = map[string]func(t *testing.T) test{ "unauthorizedIdentifiers": func(t *testing.T) test { acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "account 'accountID' is not authorized") acmeErr.Status = http.StatusForbidden acmeErr.Detail = "No authorization provided for name 127.0.0.1" return test{ err: nil, cert: nil, unauthorizedIdentifiers: []acme.Identifier{ { Type: acme.IP, Value: "127.0.0.1", }, }, msg: "account 'accountID' is not authorized", want: acmeErr, } }, "subject": func(t *testing.T) test { acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "account 'accountID' is not authorized") acmeErr.Status = http.StatusForbidden acmeErr.Detail = "No authorization provided for name test.example.com" cert := &x509.Certificate{ Subject: pkix.Name{ CommonName: "test.example.com", }, } return test{ err: nil, cert: cert, unauthorizedIdentifiers: []acme.Identifier{}, msg: "account 'accountID' is not authorized", want: acmeErr, } }, "wrap-subject": func(t *testing.T) test { acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "verification of jws using certificate public key failed: go-jose/go-jose: error in cryptographic primitive") acmeErr.Status = 
http.StatusForbidden acmeErr.Detail = "No authorization provided for name test.example.com" cert := &x509.Certificate{ Subject: pkix.Name{ CommonName: "test.example.com", }, } return test{ err: errors.New("go-jose/go-jose: error in cryptographic primitive"), cert: cert, unauthorizedIdentifiers: []acme.Identifier{}, msg: "verification of jws using certificate public key failed", want: acmeErr, } }, "default": func(t *testing.T) test { acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "account 'accountID' is not authorized") acmeErr.Status = http.StatusForbidden acmeErr.Detail = "No authorization provided" cert := &x509.Certificate{ Subject: pkix.Name{ CommonName: "", }, } return test{ err: nil, cert: cert, unauthorizedIdentifiers: []acme.Identifier{}, msg: "account 'accountID' is not authorized", want: acmeErr, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { acmeErr := wrapUnauthorizedError(tc.cert, tc.unauthorizedIdentifiers, tc.msg, tc.err) assert.Equals(t, acmeErr.Err.Error(), tc.want.Err.Error()) assert.Equals(t, acmeErr.Type, tc.want.Type) assert.Equals(t, acmeErr.Status, tc.want.Status) assert.Equals(t, acmeErr.Detail, tc.want.Detail) assert.Equals(t, acmeErr.Subproblems, tc.want.Subproblems) }) } } ================================================ FILE: acme/api/wire_integration_test.go ================================================ package api import ( "bytes" "context" "crypto/ed25519" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/base64" "encoding/json" "encoding/pem" "errors" "io" "net/http" "net/http/httptest" "net/url" "os" "strings" "testing" "time" "github.com/go-chi/chi/v5" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner/wire" 
nosqlDB "github.com/smallstep/nosql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" ) const ( baseURL = "test.ca.smallstep.com" linkerPrefix = "acme" ) func newWireProvisionerWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME { t.Helper() prov := &provisioner.ACME{ Type: "ACME", Name: "test@acme-provisioner.com", Options: options, Challenges: []provisioner.ACMEChallenge{ provisioner.WIREOIDC_01, provisioner.WIREDPOP_01, }, } err := prov.Init(provisioner.Config{ Claims: config.GlobalProvisionerClaims, }) require.NoError(t, err) return prov } // TODO(hs): replace with test CA server + acmez based test client for // more realistic integration test? func TestWireIntegration(t *testing.T) { accessTokenSignerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) accessTokenSignerPEMBlock, err := pemutil.Serialize(accessTokenSignerJWK.Public().Key) require.NoError(t, err) accessTokenSignerPEMBytes := pem.EncodeToMemory(accessTokenSignerPEMBlock) accessTokenSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(accessTokenSignerJWK.Algorithm), Key: accessTokenSignerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) oidcTokenSignerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) oidcTokenSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(oidcTokenSignerJWK.Algorithm), Key: oidcTokenSignerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) prov := newWireProvisionerWithOptions(t, &provisioner.Options{ X509: &provisioner.X509Options{ Template: `{ "subject": { "organization": "WireTest", "commonName": {{ toJson .Oidc.name }} }, "uris": [{{ toJson .Oidc.preferred_username }}, {{ toJson .Dpop.sub }}], "keyUsage": ["digitalSignature"], "extKeyUsage": ["clientAuth"] }`, }, Wire: 
&wire.Options{ OIDC: &wire.OIDCOptions{ Provider: &wire.Provider{ IssuerURL: "https://issuer.example.com", AuthURL: "", TokenURL: "", JWKSURL: "", UserInfoURL: "", Algorithms: []string{"ES256"}, }, Config: &wire.Config{ ClientID: "integration test", SignatureAlgorithms: []string{"ES256"}, SkipClientIDCheck: true, SkipExpiryCheck: true, SkipIssuerCheck: true, InsecureSkipSignatureCheck: true, // NOTE: this skips actual token verification Now: time.Now, }, TransformTemplate: "", }, DPOP: &wire.DPOPOptions{ SigningKey: accessTokenSignerPEMBytes, }, }, }) // mock provisioner and linker ctx := context.Background() ctx = acme.NewProvisionerContext(ctx, prov) ctx = acme.NewLinkerContext(ctx, acme.NewLinker(baseURL, linkerPrefix)) // create temporary BoltDB file file, err := os.CreateTemp(os.TempDir(), "integration-db-") require.NoError(t, err) t.Log("database file name:", file.Name()) dbFn := file.Name() err = file.Close() require.NoError(t, err) // open BoltDB rawDB, err := nosqlDB.New(nosqlDB.BBoltDriver, dbFn) require.NoError(t, err) // create tables db, err := nosql.New(rawDB) require.NoError(t, err) // make DB available to handlers ctx = acme.NewDatabaseContext(ctx, db) // simulate signed payloads by making the signing key available in ctx jwk, err := jose.GenerateJWK("OKP", "", "EdDSA", "sig", "", 0) require.NoError(t, err) ed25519PrivKey, ok := jwk.Key.(ed25519.PrivateKey) require.True(t, ok) dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) ed25519PubKey, ok := ed25519PrivKey.Public().(ed25519.PublicKey) require.True(t, ok) jwk.Key = ed25519PubKey ctx = context.WithValue(ctx, jwkContextKey, jwk) // get directory dir := func(ctx context.Context) (dir Directory) { req := httptest.NewRequest(http.MethodGet, "/foo/bar", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() GetDirectory(w, req) res := w.Result() require.Equal(t, http.StatusOK, 
res.StatusCode) body, err := io.ReadAll(res.Body) require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &dir) require.NoError(t, err) return }(ctx) t.Log("directory:", dir) // get nonce nonce := func(ctx context.Context) (nonce string) { req := httptest.NewRequest(http.MethodGet, dir.NewNonce, http.NoBody).WithContext(ctx) w := httptest.NewRecorder() addNonce(GetNonce)(w, req) res := w.Result() require.Equal(t, http.StatusNoContent, res.StatusCode) nonce = res.Header["Replay-Nonce"][0] return }(ctx) t.Log("nonce:", nonce) // create new account acc := func(ctx context.Context) (acc *acme.Account) { // create payload nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, } rawNar, err := json.Marshal(nar) require.NoError(t, err) // create account ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: rawNar}) req := httptest.NewRequest(http.MethodGet, dir.NewAccount, http.NoBody).WithContext(ctx) w := httptest.NewRecorder() NewAccount(w, req) res := w.Result() require.Equal(t, http.StatusCreated, res.StatusCode) body, err := io.ReadAll(res.Body) defer res.Body.Close() require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &acc) require.NoError(t, err) locationParts := strings.Split(res.Header["Location"][0], "/") acc, err = db.GetAccount(ctx, locationParts[len(locationParts)-1]) require.NoError(t, err) return }(ctx) ctx = context.WithValue(ctx, accContextKey, acc) t.Log("account ID:", acc.ID) // new order order := func(ctx context.Context) (order *acme.Order) { mockMustAuthority(t, &mockCA{}) nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ { Type: "wireapp-user", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`, }, { Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`, }, }, } b, 
err := json.Marshal(nor) require.NoError(t, err) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) req := httptest.NewRequest("POST", "https://random.local/", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() NewOrder(w, req) res := w.Result() require.Equal(t, http.StatusCreated, res.StatusCode) body, err := io.ReadAll(res.Body) defer res.Body.Close() require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &order) require.NoError(t, err) order, err = db.GetOrder(ctx, order.ID) require.NoError(t, err) return }(ctx) t.Log("authzs IDs:", order.AuthorizationIDs) // get authorization getAuthz := func(ctx context.Context, authzID string) (az *acme.Authorization) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", authzID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() GetAuthorization(w, req) res := w.Result() require.Equal(t, http.StatusOK, res.StatusCode) body, err := io.ReadAll(res.Body) defer res.Body.Close() require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &az) require.NoError(t, err) az, err = db.GetAuthorization(ctx, authzID) require.NoError(t, err) return } var azs []*acme.Authorization for _, azID := range order.AuthorizationIDs { az := getAuthz(ctx, azID) azs = append(azs, az) for _, challenge := range az.Challenges { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", challenge.ID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) var payload []byte switch challenge.Type { case acme.WIREDPOP01: dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", }, Challenge: "token", Handle: 
"wireapp://%40alice.smith.qa@example.com", Nonce: "nonce", HTU: "http://issuer.example.com", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: []string{"test"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := accessTokenSigner.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) p, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) payload = p case acme.WIREOIDC01: keyAuth, err := acme.KeyAuthorization("token", jwk) require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Name string `json:"name,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"` KeyAuth string `json:"keyauth"` }{ Claims: jose.Claims{ Issuer: "https://issuer.example.com", Audience: []string{"test"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Name: "Alice Smith", PreferredUsername: "wireapp://%40alice_wire@wire.com", KeyAuth: keyAuth, }) require.NoError(t, err) signed, err := oidcTokenSigner.Sign(tokenBytes) require.NoError(t, err) idToken, err := signed.CompactSerialize() require.NoError(t, err) p, err := json.Marshal(struct { IDToken string `json:"id_token"` }{ IDToken: 
idToken, }) require.NoError(t, err) payload = p default: require.Fail(t, "unexpected challenge payload type") } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payload}) req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() GetChallenge(w, req) res := w.Result() require.Equal(t, http.StatusOK, res.StatusCode) body, err := io.ReadAll(res.Body) defer res.Body.Close() //nolint:gocritic // close the body require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &challenge) require.NoError(t, err) t.Log("challenge:", challenge.ID, challenge.Status) } } // get/validate challenge simulation updateAz := func(ctx context.Context, az *acme.Authorization) (updatedAz *acme.Authorization) { now := clock.Now().Format(time.RFC3339) for _, challenge := range az.Challenges { challenge.Status = acme.StatusValid challenge.ValidatedAt = now err := db.UpdateChallenge(ctx, challenge) if err != nil { t.Error("updating challenge", challenge.ID, ":", err) } } updatedAz, err = db.GetAuthorization(ctx, az.ID) require.NoError(t, err) return } for _, az := range azs { updatedAz := updateAz(ctx, az) for _, challenge := range updatedAz.Challenges { t.Log("updated challenge:", challenge.ID, challenge.Status) switch challenge.Type { case acme.WIREOIDC01: err = db.CreateOidcToken(ctx, order.ID, map[string]any{"name": "Smith, Alice M (QA)", "preferred_username": "wireapp://%40alice.smith.qa@example.com"}) require.NoError(t, err) case acme.WIREDPOP01: err = db.CreateDpopToken(ctx, order.ID, map[string]any{"sub": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com"}) require.NoError(t, err) default: require.Fail(t, "unexpected challenge type") } } } // get order updatedOrder := func(ctx context.Context) (updatedOrder *acme.Order) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", order.ID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) req := 
httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() GetOrder(w, req) res := w.Result() require.Equal(t, http.StatusOK, res.StatusCode) body, err := io.ReadAll(res.Body) defer res.Body.Close() require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &updatedOrder) require.NoError(t, err) require.Equal(t, acme.StatusReady, updatedOrder.Status) return }(ctx) t.Log("updated order status:", updatedOrder.Status) // finalize order finalizedOrder := func(ctx context.Context) (finalizedOrder *acme.Order) { ca, err := minica.New(minica.WithName("WireTestCA")) require.NoError(t, err) mockMustAuthority(t, &mockCASigner{ signer: func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { var ( certOptions []x509util.Option ) for _, op := range extraOpts { if k, ok := op.(provisioner.CertificateOptions); ok { certOptions = append(certOptions, k.Options(signOpts)...) } } x509utilTemplate, err := x509util.NewCertificate(csr, certOptions...) 
require.NoError(t, err) template := x509utilTemplate.GetCertificate() require.NotNil(t, template) cert, err := ca.Sign(template) require.NoError(t, err) u1, err := url.Parse("wireapp://%40alice.smith.qa@example.com") require.NoError(t, err) u2, err := url.Parse("wireapp://lJGYPz0ZRq2kvc_XpdaDlA%21ed416ce8ecdd9fad@example.com") require.NoError(t, err) assert.Equal(t, []*url.URL{u1, u2}, cert.URIs) assert.Equal(t, "Smith, Alice M (QA)", cert.Subject.CommonName) return []*x509.Certificate{cert, ca.Intermediate}, nil }, }) qUserID, err := url.Parse("wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com") require.NoError(t, err) qUserName, err := url.Parse("wireapp://%40alice.smith.qa@example.com") require.NoError(t, err) _, priv, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) csrTemplate := &x509.CertificateRequest{ Subject: pkix.Name{ Organization: []string{"example.com"}, ExtraNames: []pkix.AttributeTypeAndValue{ { Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "Smith, Alice M (QA)", }, }, }, URIs: []*url.URL{ qUserName, qUserID, }, SignatureAlgorithm: x509.PureEd25519, } csr, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, priv) require.NoError(t, err) fr := FinalizeRequest{CSR: base64.RawURLEncoding.EncodeToString(csr)} frRaw, err := json.Marshal(fr) require.NoError(t, err) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: frRaw}) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", order.ID) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) req := httptest.NewRequest(http.MethodGet, "https://random.local/", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() FinalizeOrder(w, req) res := w.Result() require.Equal(t, http.StatusOK, res.StatusCode) body, err := io.ReadAll(res.Body) defer res.Body.Close() require.NoError(t, err) err = json.Unmarshal(bytes.TrimSpace(body), &finalizedOrder) require.NoError(t, err) require.Equal(t, acme.StatusValid, finalizedOrder.Status) 
finalizedOrder, err = db.GetOrder(ctx, order.ID) require.NoError(t, err) return }(ctx) t.Log("finalized order status:", finalizedOrder.Status) } type mockCASigner struct { signer func(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) } func (m *mockCASigner) SignWithContext(_ context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signer == nil { return nil, errors.New("unimplemented") } return m.signer(cr, opts, signOpts...) } func (m *mockCASigner) AreSANsAllowed(ctx context.Context, sans []string) error { return nil } func (m *mockCASigner) IsRevoked(sn string) (bool, error) { return false, nil } func (m *mockCASigner) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { return nil } func (m *mockCASigner) LoadProvisionerByName(string) (provisioner.Interface, error) { return nil, nil } func (m *mockCASigner) GetBackdate() *time.Duration { return nil } ================================================ FILE: acme/authorization.go ================================================ package acme import ( "context" "encoding/json" "time" ) // Authorization representst an ACME Authorization. type Authorization struct { ID string `json:"-"` AccountID string `json:"-"` Token string `json:"-"` Fingerprint string `json:"-"` Identifier Identifier `json:"identifier"` Status Status `json:"status"` Challenges []*Challenge `json:"challenges"` Wildcard bool `json:"wildcard"` ExpiresAt time.Time `json:"expires"` Error *Error `json:"error,omitempty"` } // ToLog enables response logging. func (az *Authorization) ToLog() (interface{}, error) { b, err := json.Marshal(az) if err != nil { return nil, WrapErrorISE(err, "error marshaling authz for logging") } return string(b), nil } // UpdateStatus updates the ACME Authorization Status if necessary. // Changes to the Authorization are saved using the database interface. 
func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { now := clock.Now() switch az.Status { case StatusInvalid: return nil case StatusValid: return nil case StatusPending: // check expiry if now.After(az.ExpiresAt) { az.Status = StatusInvalid break } var isValid = false for _, ch := range az.Challenges { if ch.Status == StatusValid { isValid = true break } } if !isValid { return nil } az.Status = StatusValid az.Error = nil default: return NewErrorISE("unrecognized authorization status: %s", az.Status) } if err := db.UpdateAuthorization(ctx, az); err != nil { return WrapErrorISE(err, "error updating authorization") } return nil } ================================================ FILE: acme/authorization_test.go ================================================ package acme import ( "context" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" ) func TestAuthorization_UpdateStatus(t *testing.T) { type test struct { az *Authorization err *Error db DB } tests := map[string]func(t *testing.T) test{ "ok/already-invalid": func(t *testing.T) test { az := &Authorization{ Status: StatusInvalid, } return test{ az: az, } }, "ok/already-valid": func(t *testing.T) test { az := &Authorization{ Status: StatusInvalid, } return test{ az: az, } }, "fail/error-unexpected-status": func(t *testing.T) test { az := &Authorization{ Status: "foo", } return test{ az: az, err: NewErrorISE("unrecognized authorization status: %s", az.Status), } }, "ok/expired": func(t *testing.T) test { now := clock.Now() az := &Authorization{ ID: "azID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(-5 * time.Minute), } return test{ az: az, db: &MockDB{ MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { assert.Equals(t, updaz.ID, az.ID) assert.Equals(t, updaz.AccountID, az.AccountID) assert.Equals(t, updaz.Status, StatusInvalid) assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) return nil }, }, } }, 
"fail/db.UpdateAuthorization-error": func(t *testing.T) test { now := clock.Now() az := &Authorization{ ID: "azID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(-5 * time.Minute), } return test{ az: az, db: &MockDB{ MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { assert.Equals(t, updaz.ID, az.ID) assert.Equals(t, updaz.AccountID, az.AccountID) assert.Equals(t, updaz.Status, StatusInvalid) assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) return errors.New("force") }, }, err: NewErrorISE("error updating authorization: force"), } }, "ok/no-valid-challenges": func(t *testing.T) test { now := clock.Now() az := &Authorization{ ID: "azID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), Challenges: []*Challenge{ {Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending}, }, } return test{ az: az, } }, "ok/valid": func(t *testing.T) test { now := clock.Now() az := &Authorization{ ID: "azID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), Challenges: []*Challenge{ {Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid}, }, } return test{ az: az, db: &MockDB{ MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { assert.Equals(t, updaz.ID, az.ID) assert.Equals(t, updaz.AccountID, az.AccountID) assert.Equals(t, updaz.Status, StatusValid) assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) assert.Equals(t, updaz.Error, nil) return nil }, }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil { if assert.NotNil(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equals(t, k.Type, tc.err.Type) assert.Equals(t, k.Detail, tc.err.Detail) assert.Equals(t, k.Status, tc.err.Status) assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) assert.Equals(t, k.Detail, tc.err.Detail) } else { assert.FatalError(t, 
errors.New("unexpected error type")) } } } else { assert.Nil(t, tc.err) } }) } } ================================================ FILE: acme/certificate.go ================================================ package acme import ( "crypto/x509" ) // Certificate options with which to create and store a cert object. type Certificate struct { ID string AccountID string OrderID string Leaf *x509.Certificate Intermediates []*x509.Certificate } ================================================ FILE: acme/challenge.go ================================================ package acme import ( "bytes" "context" "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rsa" "crypto/sha256" "crypto/subtle" "crypto/tls" "crypto/x509" "encoding/asn1" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net" "net/url" "reflect" "slices" "strconv" "strings" "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/fxamacker/cbor/v2" "github.com/google/go-tpm/legacy/tpm2" "github.com/smallstep/go-attestation/attest" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/acme/wire" "github.com/smallstep/certificates/authority/provisioner" wireprovisioner "github.com/smallstep/certificates/authority/provisioner/wire" "github.com/smallstep/certificates/internal/cast" ) type ChallengeType string const ( // HTTP01 is the http-01 ACME challenge type HTTP01 ChallengeType = "http-01" // DNS01 is the dns-01 ACME challenge type DNS01 ChallengeType = "dns-01" // TLSALPN01 is the tls-alpn-01 ACME challenge type TLSALPN01 ChallengeType = "tls-alpn-01" // DEVICEATTEST01 is the device-attest-01 ACME challenge type DEVICEATTEST01 ChallengeType = "device-attest-01" // WIREOIDC01 is the Wire OIDC challenge type WIREOIDC01 ChallengeType = "wire-oidc-01" // WIREDPOP01 is the Wire DPoP challenge type WIREDPOP01 ChallengeType = "wire-dpop-01" ) var ( // InsecurePortHTTP01 is the port used to verify 
http-01 challenges. If not set it // defaults to 80. InsecurePortHTTP01 int // InsecurePortTLSALPN01 is the port used to verify tls-alpn-01 challenges. If not // set it defaults to 443. // // This variable can be used for testing purposes. InsecurePortTLSALPN01 int // StrictFQDN allows to enforce a fully qualified domain name in the DNS // resolution. By default it allows domain resolution using a search list // defined in the resolv.conf or similar configuration. StrictFQDN bool ) // Challenge represents an ACME response Challenge type. type Challenge struct { ID string `json:"-"` AccountID string `json:"-"` AuthorizationID string `json:"-"` Value string `json:"-"` Type ChallengeType `json:"type"` Status Status `json:"status"` Token string `json:"token"` ValidatedAt string `json:"validated,omitempty"` URL string `json:"url"` Target string `json:"target,omitempty"` Error *Error `json:"error,omitempty"` Payload []byte `json:"-"` PayloadFormat string `json:"-"` } // ToLog enables response logging. func (ch *Challenge) ToLog() (interface{}, error) { b, err := json.Marshal(ch) if err != nil { return nil, WrapErrorISE(err, "error marshaling challenge for logging") } return string(b), nil } // Validate attempts to validate the Challenge. Stores changes to the Challenge // type using the DB interface. If the Challenge is validated, the 'status' and // 'validated' attributes are updated. func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, payload []byte) error { // If already valid or invalid then return without performing validation. 
if ch.Status != StatusPending { return nil } switch ch.Type { case HTTP01: return http01Validate(ctx, ch, db, jwk) case DNS01: return dns01Validate(ctx, ch, db, jwk) case TLSALPN01: return tlsalpn01Validate(ctx, ch, db, jwk) case DEVICEATTEST01: return deviceAttest01Validate(ctx, ch, db, jwk, payload) case WIREOIDC01: wireDB, ok := db.(WireDB) if !ok { return NewErrorISE("db %T is not a WireDB", db) } return wireOIDC01Validate(ctx, ch, wireDB, jwk, payload) case WIREDPOP01: wireDB, ok := db.(WireDB) if !ok { return NewErrorISE("db %T is not a WireDB", db) } return wireDPOP01Validate(ctx, ch, wireDB, jwk, payload) default: return NewErrorISE("unexpected challenge type %q", ch.Type) } } func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} challengeURL := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} // Append insecure port if set. // Only used for testing purposes. 
if InsecurePortHTTP01 != 0 { insecurePort := strconv.Itoa(InsecurePortHTTP01) u.Host += ":" + insecurePort challengeURL.Host += ":" + insecurePort } vc := MustClientFromContext(ctx) resp, err := vc.Get(challengeURL.String()) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", u)) } defer resp.Body.Close() if resp.StatusCode >= 400 { return storeError(ctx, db, ch, false, NewError(ErrorConnectionType, "error doing http GET for url %s with status code %d", u, resp.StatusCode)) } body, err := io.ReadAll(resp.Body) if err != nil { return WrapErrorISE(err, "error reading "+ "response body for url %s", u) } keyAuth := strings.TrimSpace(string(body)) expected, err := KeyAuthorization(ch.Token, jwk) if err != nil { return err } if keyAuth != expected { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expected, keyAuth)) } // Update and store the challenge. ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") } return nil } // rootedName adds a trailing "." to a given domain name. func rootedName(name string) string { if StrictFQDN { if name == "" || name[len(name)-1] != '.' { return name + "." } } return name } // http01ChallengeHost checks if a Challenge value is an IPv6 address // and adds square brackets if that's the case, so that it can be used // as a hostname. Returns the original Challenge value as the host to // use in other cases. func http01ChallengeHost(value string) string { if ip := net.ParseIP(value); ip != nil { if ip.To4() == nil { value = "[" + value + "]" } return value } return rootedName(value) } // tlsAlpn01ChallengeHost returns the rooted DNS used on TLS-ALPN-01 // validations. 
func tlsAlpn01ChallengeHost(name string) string { if ip := net.ParseIP(name); ip != nil { return name } return rootedName(name) } // dns01ChallengeHost returns the TXT record used in DNS-01 validations. func dns01ChallengeHost(domain string) string { return "_acme-challenge." + rootedName(domain) } func tlsAlert(err error) uint8 { var opErr *net.OpError if errors.As(err, &opErr) { v := reflect.ValueOf(opErr.Err) if v.Kind() == reflect.Uint8 { return cast.Uint8(v.Uint()) } } return 0 } func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, // https://tools.ietf.org/html/rfc8737#section-4 // ACME servers that implement "acme-tls/1" MUST only negotiate TLS 1.2 // [RFC5246] or higher when connecting to clients for validation. MinVersion: tls.VersionTLS12, ServerName: serverName(ch), InsecureSkipVerify: true, //nolint:gosec // we expect a self-signed challenge certificate } // Allow to change TLS port for testing purposes. hostPort := tlsAlpn01ChallengeHost(ch.Value) if port := InsecurePortTLSALPN01; port == 0 { hostPort = net.JoinHostPort(hostPort, "443") } else { hostPort = net.JoinHostPort(hostPort, strconv.Itoa(port)) } vc := MustClientFromContext(ctx) conn, err := vc.TLSDial("tcp", hostPort, config) if err != nil { // With Go 1.17+ tls.Dial fails if there's no overlap between configured // client and server protocols. When this happens the connection is // closed with the error no_application_protocol(120) as required by // RFC7301. 
See https://golang.org/doc/go1.17#ALPN if tlsAlert(err) == 120 { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) } return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing TLS dial for %s", ch.Value)) } defer conn.Close() cs := conn.ConnectionState() certs := cs.PeerCertificates if len(certs) == 0 { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "%s challenge for %s resulted in no certificates", ch.Type, ch.Value)) } if cs.NegotiatedProtocol != "acme-tls/1" { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge")) } leafCert := certs[0] // if no DNS names present, look for IP address and verify that exactly one exists if len(leafCert.DNSNames) == 0 { if len(leafCert.IPAddresses) != 1 || !leafCert.IPAddresses[0].Equal(net.ParseIP(ch.Value)) { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)) } } else { if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value)) } } idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} foundIDPeAcmeIdentifierV1Obsolete := false keyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { return err } hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) for _, ext := range leafCert.Extensions { if idPeAcmeIdentifier.Equal(ext.Id) { if !ext.Critical { return storeError(ctx, db, ch, true, 
NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical")) } var extValue []byte rest, err := asn1.Unmarshal(ext.Value, &extValue) if err != nil || len(rest) > 0 || len(hashedKeyAuth) != len(extValue) { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value")) } if subtle.ConstantTimeCompare(hashedKeyAuth[:], extValue) != 1 { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ "expected acmeValidationV1 extension value %s for this challenge but got %s", hex.EncodeToString(hashedKeyAuth[:]), hex.EncodeToString(extValue))) } ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge") } return nil } if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { foundIDPeAcmeIdentifierV1Obsolete = true } } if foundIDPeAcmeIdentifierV1Obsolete { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: obsolete id-pe-acmeIdentifier in acmeValidationV1 extension")) } return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") vc := MustClientFromContext(ctx) txtRecords, err := vc.LookupTxt(dns01ChallengeHost(domain)) if err != nil { return storeError(ctx, db, ch, false, 
WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) } expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { return err } h := sha256.Sum256([]byte(expectedKeyAuth)) expected := base64.RawURLEncoding.EncodeToString(h[:]) var found bool for _, r := range txtRecords { if r == expected { found = true break } } if !found { return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords)) } // Update and store the challenge. ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") } return nil } type wireOidcPayload struct { // IDToken contains the OIDC identity token IDToken string `json:"id_token"` } func wireOIDC01Validate(ctx context.Context, ch *Challenge, db WireDB, jwk *jose.JSONWebKey, payload []byte) error { prov, ok := ProvisionerFromContext(ctx) if !ok { return NewErrorISE("missing provisioner") } wireOptions, err := prov.GetOptions().GetWireOptions() if err != nil { return WrapErrorISE(err, "failed getting Wire options") } linker, ok := LinkerFromContext(ctx) if !ok { return NewErrorISE("missing linker") } var oidcPayload wireOidcPayload if err := json.Unmarshal(payload, &oidcPayload); err != nil { return WrapError(ErrorMalformedType, err, "error unmarshalling Wire OIDC challenge payload") } wireID, err := wire.ParseUserID(ch.Value) if err != nil { return WrapErrorISE(err, "error unmarshalling challenge data") } oidcOptions := wireOptions.GetOIDCOptions() verifier, err := oidcOptions.GetVerifier(ctx) if err != nil { return WrapErrorISE(err, "no OIDC verifier available") } idToken, err := verifier.Verify(ctx, oidcPayload.IDToken) if err != nil { return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err, "error verifying ID token signature")) } var claims 
struct { Name string `json:"preferred_username,omitempty"` Handle string `json:"name"` Issuer string `json:"iss,omitempty"` GivenName string `json:"given_name,omitempty"` KeyAuth string `json:"keyauth"` ACMEAudience string `json:"acme_aud,omitempty"` } if err := idToken.Claims(&claims); err != nil { return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err, "error retrieving claims from ID token")) } // TODO(hs): move this into validation below? expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { return WrapErrorISE(err, "error determining key authorization") } if expectedKeyAuth != claims.KeyAuth { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %q, but got %q", expectedKeyAuth, claims.KeyAuth)) } // audience is the full URL to the challenge acmeAudience := linker.GetLink(ctx, ChallengeLinkType, ch.AuthorizationID, ch.ID) if claims.ACMEAudience != acmeAudience { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "invalid 'acme_aud' %q", claims.ACMEAudience)) } transformedIDToken, err := validateWireOIDCClaims(oidcOptions, idToken, wireID) if err != nil { return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err, "claims in OIDC ID token don't match")) } // Update and store the challenge. 
ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") } orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID) if err != nil { return WrapErrorISE(err, "could not retrieve current order by account id") } if len(orders) == 0 { return NewErrorISE("there are not enough orders for this account for this custom OIDC challenge") } order := orders[len(orders)-1] if err := db.CreateOidcToken(ctx, order, transformedIDToken); err != nil { return WrapErrorISE(err, "failed storing OIDC id token") } return nil } func validateWireOIDCClaims(o *wireprovisioner.OIDCOptions, token *oidc.IDToken, wireID wire.UserID) (map[string]any, error) { var m map[string]any if err := token.Claims(&m); err != nil { return nil, fmt.Errorf("failed extracting OIDC ID token claims: %w", err) } transformed, err := o.Transform(m) if err != nil { return nil, fmt.Errorf("failed transforming OIDC ID token: %w", err) } name, ok := transformed["name"] if !ok { return nil, fmt.Errorf("transformed OIDC ID token does not contain 'name'") } if wireID.Name != name { return nil, fmt.Errorf("invalid 'name' %q after transformation", name) } preferredUsername, ok := transformed["preferred_username"] if !ok { return nil, fmt.Errorf("transformed OIDC ID token does not contain 'preferred_username'") } if wireID.Handle != preferredUsername { return nil, fmt.Errorf("invalid 'preferred_username' %q after transformation", preferredUsername) } return transformed, nil } type wireDpopPayload struct { // AccessToken is the token generated by wire-server AccessToken string `json:"access_token"` } func wireDPOP01Validate(ctx context.Context, ch *Challenge, db WireDB, accountJWK *jose.JSONWebKey, payload []byte) error { prov, ok := ProvisionerFromContext(ctx) if !ok { return NewErrorISE("missing provisioner") } wireOptions, err := prov.GetOptions().GetWireOptions() if err != nil { return 
WrapErrorISE(err, "failed getting Wire options") } linker, ok := LinkerFromContext(ctx) if !ok { return NewErrorISE("missing linker") } var dpopPayload wireDpopPayload if err := json.Unmarshal(payload, &dpopPayload); err != nil { return WrapError(ErrorMalformedType, err, "error unmarshalling Wire DPoP challenge payload") } wireID, err := wire.ParseDeviceID(ch.Value) if err != nil { return WrapErrorISE(err, "error unmarshalling challenge data") } clientID, err := wire.ParseClientID(wireID.ClientID) if err != nil { return WrapErrorISE(err, "error parsing device id") } dpopOptions := wireOptions.GetDPOPOptions() issuer, err := dpopOptions.EvaluateTarget(clientID.DeviceID) if err != nil { return WrapErrorISE(err, "invalid Go template registered for 'target'") } // audience is the full URL to the challenge audience := linker.GetLink(ctx, ChallengeLinkType, ch.AuthorizationID, ch.ID) params := wireVerifyParams{ token: dpopPayload.AccessToken, tokenKey: dpopOptions.GetSigningKey(), dpopKey: accountJWK.Public(), dpopKeyID: accountJWK.KeyID, issuer: issuer, audience: audience, wireID: wireID, chToken: ch.Token, t: clock.Now().UTC(), } _, dpop, err := parseAndVerifyWireAccessToken(params) if err != nil { return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err, "failed validating Wire access token")) } // Update and store the challenge. 
ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) if err = db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") } orders, err := db.GetAllOrdersByAccountID(ctx, ch.AccountID) if err != nil { return WrapErrorISE(err, "could not find current order by account id") } if len(orders) == 0 { return NewErrorISE("there are not enough orders for this account for this custom OIDC challenge") } order := orders[len(orders)-1] if err := db.CreateDpopToken(ctx, order, map[string]any(*dpop)); err != nil { return WrapErrorISE(err, "failed storing DPoP token") } return nil } type wireCnf struct { Kid string `json:"kid"` } type wireAccessToken struct { jose.Claims Challenge string `json:"chal"` Nonce string `json:"nonce"` Cnf wireCnf `json:"cnf"` Proof string `json:"proof"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` } type wireDpopJwt struct { jose.Claims ClientID string `json:"client_id"` Challenge string `json:"chal"` Nonce string `json:"nonce"` HTU string `json:"htu"` } type wireDpopToken map[string]any type wireVerifyParams struct { token string tokenKey crypto.PublicKey dpopKey crypto.PublicKey dpopKeyID string issuer string audience string wireID wire.DeviceID chToken string t time.Time } func parseAndVerifyWireAccessToken(v wireVerifyParams) (*wireAccessToken, *wireDpopToken, error) { jwt, err := jose.ParseSigned(v.token) if err != nil { return nil, nil, fmt.Errorf("failed parsing token: %w", err) } if len(jwt.Headers) != 1 { return nil, nil, fmt.Errorf("token has wrong number of headers %d", len(jwt.Headers)) } keyID, err := KeyToID(&jose.JSONWebKey{Key: v.tokenKey}) if err != nil { return nil, nil, fmt.Errorf("failed calculating token key ID: %w", err) } jwtKeyID := jwt.Headers[0].KeyID if jwtKeyID == "" { if jwtKeyID, err = KeyToID(jwt.Headers[0].JSONWebKey); err != nil { return nil, nil, fmt.Errorf("failed extracting token key ID: %w", 
err) } } if jwtKeyID != keyID { return nil, nil, fmt.Errorf("invalid token key ID %q", jwtKeyID) } var accessToken wireAccessToken if err = jwt.Claims(v.tokenKey, &accessToken); err != nil { return nil, nil, fmt.Errorf("failed validating Wire DPoP token claims: %w", err) } if err := accessToken.ValidateWithLeeway(jose.Expected{ Time: v.t, Issuer: v.issuer, Audience: jose.Audience{v.audience}, }, 1*time.Minute); err != nil { return nil, nil, fmt.Errorf("failed validation: %w", err) } if accessToken.Challenge == "" { return nil, nil, errors.New("access token challenge 'chal' must not be empty") } if accessToken.Cnf.Kid == "" || accessToken.Cnf.Kid != v.dpopKeyID { return nil, nil, fmt.Errorf("expected 'kid' %q; got %q", v.dpopKeyID, accessToken.Cnf.Kid) } if accessToken.ClientID != v.wireID.ClientID { return nil, nil, fmt.Errorf("invalid Wire 'client_id' %q", accessToken.ClientID) } if accessToken.Expiry.Time().After(v.t.Add(time.Hour)) { return nil, nil, fmt.Errorf("token expiry 'exp' %s is too far into the future", accessToken.Expiry.Time().String()) } if accessToken.Scope != "wire_client_id" { return nil, nil, fmt.Errorf("invalid Wire 'scope' %q", accessToken.Scope) } dpopJWT, err := jose.ParseSigned(accessToken.Proof) if err != nil { return nil, nil, fmt.Errorf("invalid Wire DPoP token: %w", err) } if len(dpopJWT.Headers) != 1 { return nil, nil, fmt.Errorf("DPoP token has wrong number of headers %d", len(jwt.Headers)) } dpopJwtKeyID := dpopJWT.Headers[0].KeyID if dpopJwtKeyID == "" { if dpopJwtKeyID, err = KeyToID(dpopJWT.Headers[0].JSONWebKey); err != nil { return nil, nil, fmt.Errorf("failed extracting DPoP token key ID: %w", err) } } if dpopJwtKeyID != v.dpopKeyID { return nil, nil, fmt.Errorf("invalid DPoP token key ID %q", dpopJWT.Headers[0].KeyID) } var wireDpop wireDpopJwt if err := dpopJWT.Claims(v.dpopKey, &wireDpop); err != nil { return nil, nil, fmt.Errorf("failed validating Wire DPoP token claims: %w", err) } if err := 
wireDpop.ValidateWithLeeway(jose.Expected{ Time: v.t, Audience: jose.Audience{v.audience}, }, 1*time.Minute); err != nil { return nil, nil, fmt.Errorf("failed DPoP validation: %w", err) } if wireDpop.HTU == "" || wireDpop.HTU != v.issuer { // DPoP doesn't contains "iss" claim, but has it in the "htu" claim return nil, nil, fmt.Errorf("DPoP contains invalid issuer 'htu' %q", wireDpop.HTU) } if wireDpop.Expiry.Time().After(v.t.Add(time.Hour)) { return nil, nil, fmt.Errorf("'exp' %s is too far into the future", wireDpop.Expiry.Time().String()) } if wireDpop.Subject != v.wireID.ClientID { return nil, nil, fmt.Errorf("DPoP contains invalid Wire client ID %q", wireDpop.ClientID) } if wireDpop.Nonce == "" || wireDpop.Nonce != accessToken.Nonce { return nil, nil, fmt.Errorf("DPoP contains invalid 'nonce' %q", wireDpop.Nonce) } if wireDpop.Challenge == "" || wireDpop.Challenge != accessToken.Challenge { return nil, nil, fmt.Errorf("DPoP contains invalid challenge 'chal' %q", wireDpop.Challenge) } // TODO(hs): can we use the wireDpopJwt and map that instead of doing Claims() twice? 
var dpopToken wireDpopToken if err := dpopJWT.Claims(v.dpopKey, &dpopToken); err != nil { return nil, nil, fmt.Errorf("failed validating Wire DPoP token claims: %w", err) } challenge, ok := dpopToken["chal"].(string) if !ok { return nil, nil, fmt.Errorf("invalid challenge 'chal' in Wire DPoP token") } if challenge == "" || challenge != v.chToken { return nil, nil, fmt.Errorf("invalid Wire DPoP challenge 'chal' %q", challenge) } handle, ok := dpopToken["handle"].(string) if !ok { return nil, nil, fmt.Errorf("invalid 'handle' in Wire DPoP token") } if handle == "" || handle != v.wireID.Handle { return nil, nil, fmt.Errorf("invalid Wire client 'handle' %q", handle) } name, ok := dpopToken["name"].(string) if !ok { return nil, nil, fmt.Errorf("invalid display 'name' in Wire DPoP token") } if name == "" || name != v.wireID.Name { return nil, nil, fmt.Errorf("invalid Wire client display 'name' %q", name) } return &accessToken, &dpopToken, nil } type payloadType struct { AttObj string `json:"attObj"` Error string `json:"error"` } type attestationObject struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` } // TODO(bweeks): move attestation verification to a shared package. func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error { // Update challenge with the payload ch.Payload = payload // Load authorization to store the key fingerprint. az, err := db.GetAuthorization(ctx, ch.AuthorizationID) if err != nil { return WrapErrorISE(err, "error loading authorization") } // Parse payload. 
var p payloadType if err := json.Unmarshal(payload, &p); err != nil { return WrapErrorISE(err, "error unmarshalling JSON") } if p.Error != "" { return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "payload contained error: %v", p.Error)) } attObj, err := base64.RawURLEncoding.DecodeString(p.AttObj) if err != nil { return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "failed base64 decoding attObj %q", p.AttObj)) } if len(attObj) == 0 || bytes.Equal(attObj, []byte("{}")) { return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty")) } cborDecoderOptions := cbor.DecOptions{} cborDecoder, err := cborDecoderOptions.DecMode() if err != nil { return WrapErrorISE(err, "failed creating CBOR decoder") } if err := cborDecoder.Wellformed(attObj); err != nil { return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "attObj is not well formed CBOR: %v", err)) } att := attestationObject{} if err := cborDecoder.Unmarshal(attObj, &att); err != nil { return WrapErrorISE(err, "failed unmarshalling CBOR") } format := att.Format prov := MustProvisionerFromContext(ctx) if !prov.IsAttestationFormatEnabled(ctx, provisioner.ACMEAttestationFormat(format)) { if format != "apple" && format != "step" && format != "tpm" { return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "unsupported attestation object format %q", format)) } return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "attestation format %q is not enabled", format)) } switch format { case "apple": data, err := doAppleAttestationFormat(ctx, prov, ch, &att) if err != nil { var acmeError *Error if errors.As(err, &acmeError) { if acmeError.Status == 500 { return acmeError } return storeError(ctx, db, ch, true, acmeError) } return WrapErrorISE(err, "error validating attestation") } // Validate nonce with SHA-256 of the token. 
if len(data.Nonce) != 0 { sum := sha256.Sum256([]byte(ch.Token)) if subtle.ConstantTimeCompare(data.Nonce, sum[:]) != 1 { return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "challenge token does not match")) } } // Validate Apple's ClientIdentifier (Identifier.Value) with device // identifiers. // // Note: We might want to use an external service for this. if data.UDID != ch.Value && data.SerialNumber != ch.Value { subproblem := NewSubproblemWithIdentifier( ErrorRejectedIdentifierType, Identifier{Type: "permanent-identifier", Value: ch.Value}, "challenge identifier %q doesn't match any of the attested hardware identifiers %q", ch.Value, []string{data.UDID, data.SerialNumber}, ) return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match").AddSubproblems(subproblem)) } // Update attestation key fingerprint to compare against the CSR az.Fingerprint = data.Fingerprint case "step": data, err := doStepAttestationFormat(ctx, prov, ch, jwk, &att) if err != nil { var acmeError *Error if errors.As(err, &acmeError) { if acmeError.Status == 500 { return acmeError } return storeError(ctx, db, ch, true, acmeError) } return WrapErrorISE(err, "error validating attestation") } // Validate the YubiKey serial number from the attestation // certificate with the challenged Order value. // // Note: We might want to use an external service for this. 
if data.SerialNumber != ch.Value { subproblem := NewSubproblemWithIdentifier( ErrorRejectedIdentifierType, Identifier{Type: "permanent-identifier", Value: ch.Value}, "challenge identifier %q doesn't match the attested hardware identifier %q", ch.Value, data.SerialNumber, ) return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match").AddSubproblems(subproblem)) } // Update attestation key fingerprint to compare against the CSR az.Fingerprint = data.Fingerprint case "tpm": data, err := doTPMAttestationFormat(ctx, prov, ch, jwk, &att) if err != nil { var acmeError *Error if errors.As(err, &acmeError) { if acmeError.Status == 500 { return acmeError } return storeError(ctx, db, ch, true, acmeError) } return WrapErrorISE(err, "error validating attestation") } // TODO(hs): currently this will allow a request for which no PermanentIdentifiers have been // extracted from the AK certificate. This is currently the case for AK certs from the CLI, as we // haven't implemented a way for AK certs requested by the CLI to always contain the requested // PermanentIdentifier. Omitting the check below doesn't allow just any request, as the Order can // still fail if the challenge value isn't equal to the CSR subject. 
if len(data.PermanentIdentifiers) > 0 && !slices.Contains(data.PermanentIdentifiers, ch.Value) { // TODO(hs): add support for HardwareModuleName subproblem := NewSubproblemWithIdentifier( ErrorRejectedIdentifierType, Identifier{Type: "permanent-identifier", Value: ch.Value}, "challenge identifier %q doesn't match any of the attested hardware identifiers %q", ch.Value, data.PermanentIdentifiers, ) return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match").AddSubproblems(subproblem)) } // Update attestation key fingerprint to compare against the CSR az.Fingerprint = data.Fingerprint default: return storeError(ctx, db, ch, true, NewDetailedError(ErrorBadAttestationStatementType, "unsupported attestation object format %q", format)) } // Update and store the challenge. ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) ch.PayloadFormat = format // Store the fingerprint in the authorization. // // TODO: add method to update authorization and challenge atomically. if az.Fingerprint != "" { if err := db.UpdateAuthorization(ctx, az); err != nil { return WrapErrorISE(err, "error updating authorization") } } if err := db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "error updating challenge") } return nil } var ( oidSubjectAlternativeName = asn1.ObjectIdentifier{2, 5, 29, 17} ) type tpmAttestationData struct { Certificate *x509.Certificate VerifiedChains [][]*x509.Certificate PermanentIdentifiers []string Fingerprint string } // coseAlgorithmIdentifier models a COSEAlgorithmIdentifier. // Also see https://www.w3.org/TR/webauthn-2/#sctn-alg-identifier. 
type coseAlgorithmIdentifier int32 const ( coseAlgES256 = coseAlgorithmIdentifier(-7) coseAlgRS256 = coseAlgorithmIdentifier(-257) coseAlgRS1 = coseAlgorithmIdentifier(-65535) // deprecated, but (still) often used in TPMs ) func doTPMAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*tpmAttestationData, error) { ver, ok := att.AttStatement["ver"].(string) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "ver not present") } if ver != "2.0" { return nil, NewDetailedError(ErrorBadAttestationStatementType, "version %q is not supported", ver) } x5c, ok := att.AttStatement["x5c"].([]interface{}) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") } if len(x5c) == 0 { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is empty") } akCertBytes, ok := x5c[0].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is malformed") } akCert, err := x509.ParseCertificate(akCertBytes) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is malformed") } intermediates := x509.NewCertPool() for _, v := range x5c[1:] { intCertBytes, vok := v.([]byte) if !vok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is malformed") } intCert, err := x509.ParseCertificate(intCertBytes) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is malformed") } intermediates.AddCert(intCert) } // TODO(hs): this can be removed when permanent-identifier/hardware-module-name are handled correctly in // the stdlib in https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/x509/parser.go;drc=b5b2cf519fe332891c165077f3723ee74932a647;l=362, // but I doubt that will happen. 
if len(akCert.UnhandledCriticalExtensions) > 0 { unhandledCriticalExtensions := akCert.UnhandledCriticalExtensions[:0] for _, extOID := range akCert.UnhandledCriticalExtensions { if !extOID.Equal(oidSubjectAlternativeName) { // critical extensions other than the Subject Alternative Name remain unhandled unhandledCriticalExtensions = append(unhandledCriticalExtensions, extOID) } } akCert.UnhandledCriticalExtensions = unhandledCriticalExtensions } roots, ok := prov.GetAttestationRoots() if !ok { return nil, NewErrorISE("no root CA bundle available to verify the attestation certificate") } // verify that the AK certificate was signed by a trusted root, // chained to by the intermediates provided by the client. As part // of building the verified certificate chain, the signature over the // AK certificate is checked to be a valid signature of one of the // provided intermediates. Signatures over the intermediates are in // turn also verified to be valid signatures from one of the trusted // roots. verifiedChains, err := akCert.Verify(x509.VerifyOptions{ Roots: roots, Intermediates: intermediates, CurrentTime: time.Now().Truncate(time.Second), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, }) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is not valid") } // validate additional AK certificate requirements if err := validateAKCertificate(akCert); err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "AK certificate is not valid") } // TODO(hs): implement revocation check; Verify() doesn't perform CRL check nor OCSP lookup. 
sans, err := x509util.ParseSubjectAlternativeNames(akCert) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "failed parsing AK certificate Subject Alternative Names") } permanentIdentifiers := make([]string, len(sans.PermanentIdentifiers)) for i, pi := range sans.PermanentIdentifiers { permanentIdentifiers[i] = pi.Identifier } // extract and validate pubArea, sig, certInfo and alg properties from the request body pubArea, ok := att.AttStatement["pubArea"].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "invalid pubArea in attestation statement") } if len(pubArea) == 0 { return nil, NewDetailedError(ErrorBadAttestationStatementType, "pubArea is empty") } sig, ok := att.AttStatement["sig"].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "invalid sig in attestation statement") } if len(sig) == 0 { return nil, NewDetailedError(ErrorBadAttestationStatementType, "sig is empty") } certInfo, ok := att.AttStatement["certInfo"].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "invalid certInfo in attestation statement") } if len(certInfo) == 0 { return nil, NewDetailedError(ErrorBadAttestationStatementType, "certInfo is empty") } alg, ok := att.AttStatement["alg"].(int64) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "invalid alg in attestation statement") } algI32, err := cast.SafeInt32(alg) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "invalid alg %d in attestation statement", alg) } var hash crypto.Hash switch coseAlgorithmIdentifier(algI32) { case coseAlgRS256, coseAlgES256: hash = crypto.SHA256 case coseAlgRS1: hash = crypto.SHA1 default: return nil, NewDetailedError(ErrorBadAttestationStatementType, "invalid alg %d in attestation statement", alg) } // recreate the generated key certification parameter values and verify // the attested key using the public key of the AK. 
certificationParameters := &attest.CertificationParameters{ Public: pubArea, // the public key that was attested CreateAttestation: certInfo, // the attested properties of the key CreateSignature: sig, // signature over the attested properties } verifyOpts := attest.VerifyOpts{ Public: akCert.PublicKey, // public key of the AK that attested the key Hash: hash, } if err = certificationParameters.Verify(verifyOpts); err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "invalid certification parameters") } // decode the "certInfo" data. This won't fail, as it's also done as part of Verify(). tpmCertInfo, err := tpm2.DecodeAttestationData(certInfo) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "failed decoding attestation data") } keyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { return nil, WrapErrorISE(err, "failed creating key auth digest") } hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) // verify the WebAuthn object contains the expect key authorization digest, which is carried // within the encoded `certInfo` property of the attestation statement. if subtle.ConstantTimeCompare(hashedKeyAuth[:], []byte(tpmCertInfo.ExtraData)) == 0 { return nil, NewDetailedError(ErrorBadAttestationStatementType, "key authorization invalid") } // decode the (attested) public key and determine its fingerprint. This won't fail, as it's also done as part of Verify(). 
pub, err := tpm2.DecodePublic(pubArea) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "failed decoding pubArea") } publicKey, err := pub.Key() if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "failed getting public key") } data := &tpmAttestationData{ Certificate: akCert, VerifiedChains: verifiedChains, PermanentIdentifiers: permanentIdentifiers, } if data.Fingerprint, err = keyutil.Fingerprint(publicKey); err != nil { return nil, WrapErrorISE(err, "error calculating key fingerprint") } // TODO(hs): pass more attestation data, so that that can be used/recorded too? return data, nil } var ( oidExtensionExtendedKeyUsage = asn1.ObjectIdentifier{2, 5, 29, 37} oidTCGKpAIKCertificate = asn1.ObjectIdentifier{2, 23, 133, 8, 3} ) // validateAKCertificate validates the X.509 AK certificate to be // in accordance with the required properties. The requirements come from: // https://www.w3.org/TR/webauthn-2/#sctn-tpm-cert-requirements. // // - Version MUST be set to 3. // - Subject field MUST be set to empty. // - The Subject Alternative Name extension MUST be set as defined // in [TPMv2-EK-Profile] section 3.2.9. // - The Extended Key Usage extension MUST contain the OID 2.23.133.8.3 // ("joint-iso-itu-t(2) international-organizations(23) 133 tcg-kp(8) tcg-kp-AIKCertificate(3)"). // - The Basic Constraints extension MUST have the CA component set to false. // - An Authority Information Access (AIA) extension with entry id-ad-ocsp // and a CRL Distribution Point extension [RFC5280] are both OPTIONAL as // the status of many attestation certificates is available through metadata // services. See, for example, the FIDO Metadata Service. 
func validateAKCertificate(c *x509.Certificate) error { if c.Version != 3 { return fmt.Errorf("AK certificate has invalid version %d; only version 3 is allowed", c.Version) } if c.Subject.String() != "" { return fmt.Errorf("AK certificate subject must be empty; got %q", c.Subject) } if c.IsCA { return errors.New("AK certificate must not be a CA") } if err := validateAKCertificateExtendedKeyUsage(c); err != nil { return err } return validateAKCertificateSubjectAlternativeNames(c) } // validateAKCertificateSubjectAlternativeNames checks if the AK certificate // has TPM hardware details set. func validateAKCertificateSubjectAlternativeNames(c *x509.Certificate) error { sans, err := x509util.ParseSubjectAlternativeNames(c) if err != nil { return fmt.Errorf("failed parsing AK certificate Subject Alternative Names: %w", err) } details := sans.TPMHardwareDetails manufacturer, model, version := details.Manufacturer, details.Model, details.Version switch { case manufacturer == "": return errors.New("missing TPM manufacturer") case model == "": return errors.New("missing TPM model") case version == "": return errors.New("missing TPM version") } return nil } // validateAKCertificateExtendedKeyUsage checks if the AK certificate // has the "tcg-kp-AIKCertificate" Extended Key Usage set. 
func validateAKCertificateExtendedKeyUsage(c *x509.Certificate) error { var ( valid = false ekus []asn1.ObjectIdentifier ) for _, ext := range c.Extensions { if ext.Id.Equal(oidExtensionExtendedKeyUsage) { if _, err := asn1.Unmarshal(ext.Value, &ekus); err != nil || len(ekus) == 0 || !ekus[0].Equal(oidTCGKpAIKCertificate) { return errors.New("AK certificate is missing Extended Key Usage value tcg-kp-AIKCertificate (2.23.133.8.3)") } valid = true } } if !valid { return errors.New("AK certificate is missing Extended Key Usage extension") } return nil } // Apple Enterprise Attestation Root CA from // https://www.apple.com/certificateauthority/private/ const appleEnterpriseAttestationRootCA = `-----BEGIN CERTIFICATE----- MIICJDCCAamgAwIBAgIUQsDCuyxyfFxeq/bxpm8frF15hzcwCgYIKoZIzj0EAwMw UTEtMCsGA1UEAwwkQXBwbGUgRW50ZXJwcmlzZSBBdHRlc3RhdGlvbiBSb290IENB MRMwEQYDVQQKDApBcHBsZSBJbmMuMQswCQYDVQQGEwJVUzAeFw0yMjAyMTYxOTAx MjRaFw00NzAyMjAwMDAwMDBaMFExLTArBgNVBAMMJEFwcGxlIEVudGVycHJpc2Ug QXR0ZXN0YXRpb24gUm9vdCBDQTETMBEGA1UECgwKQXBwbGUgSW5jLjELMAkGA1UE BhMCVVMwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAT6Jigq+Ps9Q4CoT8t8q+UnOe2p oT9nRaUfGhBTbgvqSGXPjVkbYlIWYO+1zPk2Sz9hQ5ozzmLrPmTBgEWRcHjA2/y7 7GEicps9wn2tj+G89l3INNDKETdxSPPIZpPj8VmjQjBAMA8GA1UdEwEB/wQFMAMB Af8wHQYDVR0OBBYEFPNqTQGd8muBpV5du+UIbVbi+d66MA4GA1UdDwEB/wQEAwIB BjAKBggqhkjOPQQDAwNpADBmAjEA1xpWmTLSpr1VH4f8Ypk8f3jMUKYz4QPG8mL5 8m9sX/b2+eXpTv2pH4RZgJjucnbcAjEA4ZSB6S45FlPuS/u4pTnzoz632rA+xW/T ZwFEh9bhKjJ+5VQ9/Do1os0u3LEkgN/r -----END CERTIFICATE-----` var ( oidAppleSerialNumber = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 9, 1} oidAppleUniqueDeviceIdentifier = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 9, 2} oidAppleSecureEnclaveProcessorOSVersion = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 10, 2} oidAppleNonce = asn1.ObjectIdentifier{1, 2, 840, 113635, 100, 8, 11, 1} ) type appleAttestationData struct { Nonce []byte SerialNumber string UDID string SEPVersion string Certificate *x509.Certificate Fingerprint string } func 
doAppleAttestationFormat(_ context.Context, prov Provisioner, _ *Challenge, att *attestationObject) (*appleAttestationData, error) { // Use configured or default attestation roots if none is configured. roots, ok := prov.GetAttestationRoots() if !ok { root, err := pemutil.ParseCertificate([]byte(appleEnterpriseAttestationRootCA)) if err != nil { return nil, WrapErrorISE(err, "error parsing apple enterprise ca") } roots = x509.NewCertPool() roots.AddCert(root) } x5c, ok := att.AttStatement["x5c"].([]interface{}) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") } if len(x5c) == 0 { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is empty") } der, ok := x5c[0].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is malformed") } leaf, err := x509.ParseCertificate(der) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is malformed") } intermediates := x509.NewCertPool() for _, v := range x5c[1:] { der, ok = v.([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is malformed") } cert, err := x509.ParseCertificate(der) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is malformed") } intermediates.AddCert(cert) } if _, err := leaf.Verify(x509.VerifyOptions{ Intermediates: intermediates, Roots: roots, CurrentTime: time.Now().Truncate(time.Second), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, }); err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is not valid") } data := &appleAttestationData{ Certificate: leaf, } if data.Fingerprint, err = keyutil.Fingerprint(leaf.PublicKey); err != nil { return nil, WrapErrorISE(err, "error calculating key fingerprint") } for _, ext := range leaf.Extensions { switch { case ext.Id.Equal(oidAppleSerialNumber): data.SerialNumber = string(ext.Value) case 
ext.Id.Equal(oidAppleUniqueDeviceIdentifier): data.UDID = string(ext.Value) case ext.Id.Equal(oidAppleSecureEnclaveProcessorOSVersion): data.SEPVersion = string(ext.Value) case ext.Id.Equal(oidAppleNonce): data.Nonce = ext.Value } } return data, nil } // Yubico PIV Root CA Serial 263751 // https://developers.yubico.com/PIV/Introduction/piv-attestation-ca.pem const yubicoPIVRootCA = `-----BEGIN CERTIFICATE----- MIIDFzCCAf+gAwIBAgIDBAZHMA0GCSqGSIb3DQEBCwUAMCsxKTAnBgNVBAMMIFl1 YmljbyBQSVYgUm9vdCBDQSBTZXJpYWwgMjYzNzUxMCAXDTE2MDMxNDAwMDAwMFoY DzIwNTIwNDE3MDAwMDAwWjArMSkwJwYDVQQDDCBZdWJpY28gUElWIFJvb3QgQ0Eg U2VyaWFsIDI2Mzc1MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMN2 cMTNR6YCdcTFRxuPy31PabRn5m6pJ+nSE0HRWpoaM8fc8wHC+Tmb98jmNvhWNE2E ilU85uYKfEFP9d6Q2GmytqBnxZsAa3KqZiCCx2LwQ4iYEOb1llgotVr/whEpdVOq joU0P5e1j1y7OfwOvky/+AXIN/9Xp0VFlYRk2tQ9GcdYKDmqU+db9iKwpAzid4oH BVLIhmD3pvkWaRA2H3DA9t7H/HNq5v3OiO1jyLZeKqZoMbPObrxqDg+9fOdShzgf wCqgT3XVmTeiwvBSTctyi9mHQfYd2DwkaqxRnLbNVyK9zl+DzjSGp9IhVPiVtGet X02dxhQnGS7K6BO0Qe8CAwEAAaNCMEAwHQYDVR0OBBYEFMpfyvLEojGc6SJf8ez0 1d8Cv4O/MA8GA1UdEwQIMAYBAf8CAQEwDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3 DQEBCwUAA4IBAQBc7Ih8Bc1fkC+FyN1fhjWioBCMr3vjneh7MLbA6kSoyWF70N3s XhbXvT4eRh0hvxqvMZNjPU/VlRn6gLVtoEikDLrYFXN6Hh6Wmyy1GTnspnOvMvz2 lLKuym9KYdYLDgnj3BeAvzIhVzzYSeU77/Cupofj093OuAswW0jYvXsGTyix6B3d bW5yWvyS9zNXaqGaUmP3U9/b6DlHdDogMLu3VLpBB9bm5bjaKWWJYgWltCVgUbFq Fqyi4+JE014cSgR57Jcu3dZiehB6UtAPgad9L5cNvua/IWRmm+ANy3O2LH++Pyl8 SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7 -----END CERTIFICATE-----` // Yubico Attestation Root 1 (YubiKey 5.7.4+) // https://developers.yubico.com/PKI/yubico-ca-1.pem const yubicoAttestationRootCA = `-----BEGIN CERTIFICATE----- MIIDPjCCAiagAwIBAgIUXzeiEDJEOTt14F5n0o6Zf/bBwiUwDQYJKoZIhvcNAQEN BQAwJDEiMCAGA1UEAwwZWXViaWNvIEF0dGVzdGF0aW9uIFJvb3QgMTAgFw0yNDEy MDEwMDAwMDBaGA85OTk5MTIzMTIzNTk1OVowJDEiMCAGA1UEAwwZWXViaWNvIEF0 dGVzdGF0aW9uIFJvb3QgMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB 
AMZ6/TxM8rIT+EaoPvG81ontMOo/2mQ2RBwJHS0QZcxVaNXvl12LUhBZ5LmiBScI Zd1Rnx1od585h+/dhK7hEm7JAALkKKts1fO53KGNLZujz5h3wGncr4hyKF0G74b/ U3K9hE5mGND6zqYchCRAHfrYMYRDF4YL0X4D5nGdxvppAy6nkEmtWmMnwO3i0TAu csrbE485HvGM4r0VpgVdJpvgQjiTJCTIq+D35hwtT8QDIv+nGvpcyi5wcIfCkzyC imJukhYy6KoqNMKQEdpNiSOvWyDMTMt1bwCvEzpw91u+msUt4rj0efnO9s0ZOwdw MRDnH4xgUl5ZLwrrPkfC1/0CAwEAAaNmMGQwHQYDVR0OBBYEFNLu71oijTptXCOX PfKF1SbxJXuSMB8GA1UdIwQYMBaAFNLu71oijTptXCOXPfKF1SbxJXuSMBIGA1Ud EwEB/wQIMAYBAf8CAQMwDgYDVR0PAQH/BAQDAgGGMA0GCSqGSIb3DQEBDQUAA4IB AQC3IW/sgB9pZ8apJNjxuGoX+FkILks0wMNrdXL/coUvsrhzsvl6mePMrbGJByJ1 XnquB5sgcRENFxdQFma3mio8Upf1owM1ZreXrJ0mADG2BplqbJnxiyYa+R11reIF TWeIhMNcZKsDZrFAyPuFjCWSQvJmNWe9mFRYFgNhXJKkXIb5H1XgEDlwiedYRM7V olBNlld6pRFKlX8ust6OTMOeADl2xNF0m1LThSdeuXvDyC1g9+ILfz3S6OIYgc3i roRcFD354g7rKfu67qFAw9gC4yi0xBTPrY95rh4/HqaUYCA/L8ldRk6H7Xk35D+W Vpmq2Sh/xT5HiFuhf4wJb0bK -----END CERTIFICATE-----` var ( // serial number of the YubiKey, encoded as an integer. // https://developers.yubico.com/PIV/Introduction/PIV_attestation.html oidYubicoSerialNumber = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 41482, 3, 7} // custom Smallstep managed device extension carrying a device ID or serial number oidStepManagedDevice = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 4} ) type stepAttestationData struct { Certificate *x509.Certificate SerialNumber string Fingerprint string } func doStepAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*stepAttestationData, error) { // Use configured or default attestation roots if none is configured. 
roots, ok := prov.GetAttestationRoots() if !ok { pivRoot, err := pemutil.ParseCertificate([]byte(yubicoPIVRootCA)) if err != nil { return nil, WrapErrorISE(err, "error parsing root ca") } attRoot, err := pemutil.ParseCertificate([]byte(yubicoAttestationRootCA)) if err != nil { return nil, WrapErrorISE(err, "error parsing root ca") } roots = x509.NewCertPool() roots.AddCert(pivRoot) roots.AddCert(attRoot) } // Extract x5c and verify certificate x5c, ok := att.AttStatement["x5c"].([]interface{}) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") } if len(x5c) == 0 { return nil, NewDetailedError(ErrorRejectedIdentifierType, "x5c is empty") } der, ok := x5c[0].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is malformed") } leaf, err := x509.ParseCertificate(der) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is malformed") } intermediates := x509.NewCertPool() for _, v := range x5c[1:] { der, ok = v.([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "x5c is malformed") } cert, err := x509.ParseCertificate(der) if err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is malformed") } intermediates.AddCert(cert) } if _, err := leaf.Verify(x509.VerifyOptions{ Intermediates: intermediates, Roots: roots, CurrentTime: time.Now().Truncate(time.Second), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, }); err != nil { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "x5c is not valid") } // Verify proof of possession of private key validating the key // authorization. Per recommendation at // https://w3c.github.io/webauthn/#sctn-signature-attestation-types the // signature is CBOR-encoded. 
var sig []byte csig, ok := att.AttStatement["sig"].([]byte) if !ok { return nil, NewDetailedError(ErrorBadAttestationStatementType, "sig not present") } if err := cbor.Unmarshal(csig, &sig); err != nil { return nil, NewDetailedError(ErrorBadAttestationStatementType, "sig is malformed") } keyAuth, err := KeyAuthorization(ch.Token, jwk) if err != nil { return nil, err } switch pub := leaf.PublicKey.(type) { case *ecdsa.PublicKey: if pub.Curve != elliptic.P256() { return nil, WrapDetailedError(ErrorBadAttestationStatementType, err, "unsupported elliptic curve %s", pub.Curve) } sum := sha256.Sum256([]byte(keyAuth)) if !ecdsa.VerifyASN1(pub, sum[:], sig) { return nil, NewDetailedError(ErrorBadAttestationStatementType, "failed to validate signature") } case *rsa.PublicKey: sum := sha256.Sum256([]byte(keyAuth)) if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sig); err != nil { return nil, NewDetailedError(ErrorBadAttestationStatementType, "failed to validate signature") } case ed25519.PublicKey: if !ed25519.Verify(pub, []byte(keyAuth), sig) { return nil, NewDetailedError(ErrorBadAttestationStatementType, "failed to validate signature") } default: return nil, NewDetailedError(ErrorBadAttestationStatementType, "unsupported public key type %T", pub) } // Parse attestation data: // TODO(mariano): add support for other extensions. data := &stepAttestationData{ Certificate: leaf, } if data.Fingerprint, err = keyutil.Fingerprint(leaf.PublicKey); err != nil { return nil, WrapErrorISE(err, "error calculating key fingerprint") } if data.SerialNumber, err = searchSerialNumber(leaf); err != nil { return nil, WrapErrorISE(err, "error finding serial number") } return data, nil } // searchSerialNumber searches the certificate extensions, looking for a serial // number encoded in one of them. It is not guaranteed that a certificate contains // an extension carrying a serial number, so the result can be empty. 
func searchSerialNumber(cert *x509.Certificate) (string, error) { for _, ext := range cert.Extensions { if ext.Id.Equal(oidYubicoSerialNumber) { var serialNumber int rest, err := asn1.Unmarshal(ext.Value, &serialNumber) if err != nil || len(rest) > 0 { return "", WrapError(ErrorBadAttestationStatementType, err, "error parsing serial number") } return strconv.Itoa(serialNumber), nil } if ext.Id.Equal(oidStepManagedDevice) { type stepManagedDevice struct { DeviceID string } var md stepManagedDevice rest, err := asn1.Unmarshal(ext.Value, &md) if err != nil || len(rest) > 0 { return "", WrapError(ErrorBadAttestationStatementType, err, "error parsing serial number") } return md.DeviceID, nil } } return "", nil } // serverName determines the SNI HostName to set based on an acme.Challenge // for TLS-ALPN-01 challenges RFC8738 states that, if HostName is an IP, it // should be the ARPA address https://datatracker.ietf.org/doc/html/rfc8738#section-6. // It also references TLS Extensions [RFC6066]. func serverName(ch *Challenge) string { if ip := net.ParseIP(ch.Value); ip != nil { return reverseAddr(ip) } return ch.Value } // reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP // address addr suitable for rDNS (PTR) record lookup or an error if it fails // to parse the IP address. // Implementation taken and adapted from https://golang.org/src/net/dnsclient.go?s=780:834#L20 func reverseAddr(ip net.IP) (arpa string) { if ip.To4() != nil { return uitoa(uint(ip[15])) + "." + uitoa(uint(ip[14])) + "." + uitoa(uint(ip[13])) + "." + uitoa(uint(ip[12])) + ".in-addr.arpa." } // Must be IPv6 buf := make([]byte, 0, len(ip)*4+len("ip6.arpa.")) // Add it, in reverse, to the buffer for i := len(ip) - 1; i >= 0; i-- { v := ip[i] buf = append(buf, hexit[v&0xF], '.', hexit[v>>4], '.') } // Append "ip6.arpa." and return (buf already has the final .) buf = append(buf, "ip6.arpa."...) return string(buf) } // Convert unsigned integer to decimal string. 
// Implementation taken from https://golang.org/src/net/parse.go func uitoa(val uint) string { if val == 0 { // avoid string allocation return "0" } var buf [20]byte // big enough for 64bit value base 10 i := len(buf) - 1 for val >= 10 { v := val / 10 buf[i] = byte('0' + val - v*10) //nolint:gosec // val - v*10 is always 0-9 i-- val = v } // val < 10 buf[i] = byte('0' + val) return string(buf[i:]) } const hexit = "0123456789abcdef" // KeyAuthorization creates the ACME key authorization value from a token // and a jwk. func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { thumbprint, err := jwk.Thumbprint(crypto.SHA256) if err != nil { return "", WrapErrorISE(err, "error generating JWK thumbprint") } encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) return fmt.Sprintf("%s.%s", token, encPrint), nil } // storeError the given error to an ACME error and saves using the DB interface. func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err *Error) error { ch.Error = err if markInvalid { ch.Status = StatusInvalid } if err := db.UpdateChallenge(ctx, ch); err != nil { return WrapErrorISE(err, "failure saving error to acme challenge") } return nil } ================================================ FILE: acme/challenge_test.go ================================================ package acme import ( "bytes" "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/base64" "encoding/hex" "encoding/json" "encoding/pem" "errors" "fmt" "io" "math/big" "net" "net/http" "net/http/httptest" "reflect" "strconv" "strings" "testing" "time" "github.com/fxamacker/cbor/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" 
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" wireprovisioner "github.com/smallstep/certificates/authority/provisioner/wire" ) type mockClient struct { get func(url string) (*http.Response, error) lookupTxt func(name string) ([]string, error) tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) } func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } func (m *mockClient) TLSDial(network, addr string, tlsConfig *tls.Config) (*tls.Conn, error) { return m.tlsDial(network, addr, tlsConfig) } func fatalError(t *testing.T, err error) { t.Helper() if err != nil { t.Fatal(err) } } func mustNonAttestationProvisioner(t *testing.T) Provisioner { t.Helper() prov := &provisioner.ACME{ Type: "ACME", Name: "acme", Challenges: []provisioner.ACMEChallenge{provisioner.HTTP_01}, } if err := prov.Init(provisioner.Config{ Claims: config.GlobalProvisionerClaims, }); err != nil { t.Fatal(err) } prov.AttestationFormats = []provisioner.ACMEAttestationFormat{"bogus-format"} // results in no attestation formats enabled return prov } func mustAttestationProvisioner(t *testing.T, roots []byte) Provisioner { t.Helper() prov := &provisioner.ACME{ Type: "ACME", Name: "acme", Challenges: []provisioner.ACMEChallenge{provisioner.DEVICE_ATTEST_01}, AttestationRoots: roots, } if err := prov.Init(provisioner.Config{ Claims: config.GlobalProvisionerClaims, }); err != nil { t.Fatal(err) } return prov } func mustAccountAndKeyAuthorization(t *testing.T, token string) (*jose.JSONWebKey, string) { t.Helper() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) fatalError(t, err) keyAuth, err := KeyAuthorization(token, jwk) fatalError(t, err) return jwk, keyAuth } func mustAttestApple(t *testing.T, nonce string) ([]byte, *x509.Certificate, *x509.Certificate) { t.Helper() ca, err := minica.New() 
fatalError(t, err) signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalError(t, err) nonceSum := sha256.Sum256([]byte(nonce)) leaf, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "attestation cert"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ {Id: oidAppleSerialNumber, Value: []byte("serial-number")}, {Id: oidAppleUniqueDeviceIdentifier, Value: []byte("udid")}, {Id: oidAppleSecureEnclaveProcessorOSVersion, Value: []byte("16.0")}, {Id: oidAppleNonce, Value: nonceSum[:]}, }, }) fatalError(t, err) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "apple", AttStatement: map[string]interface{}{ "x5c": []interface{}{leaf.Raw, ca.Intermediate.Raw}, }, }) fatalError(t, err) payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) fatalError(t, err) return payload, leaf, ca.Root } func mustAttestYubikey(t *testing.T, _, keyAuthorization string, serial int) ([]byte, *x509.Certificate, *x509.Certificate) { t.Helper() ca, err := minica.New() fatalError(t, err) keyAuthSum := sha256.Sum256([]byte(keyAuthorization)) signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalError(t, err) sig, err := signer.Sign(rand.Reader, keyAuthSum[:], crypto.SHA256) fatalError(t, err) cborSig, err := cbor.Marshal(sig) fatalError(t, err) serialNumber, err := asn1.Marshal(serial) fatalError(t, err) leaf, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "attestation cert"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ {Id: oidYubicoSerialNumber, Value: serialNumber}, }, }) fatalError(t, err) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "step", AttStatement: map[string]interface{}{ "x5c": []interface{}{leaf.Raw, ca.Intermediate.Raw}, "alg": -7, "sig": 
cborSig, }, }) fatalError(t, err) payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) fatalError(t, err) return payload, leaf, ca.Root } type stepManagedDevice struct { DeviceID string } func mustAttestStepManagedDeviceID(t *testing.T, _, keyAuthorization, serialNumber string) ([]byte, *x509.Certificate, *x509.Certificate) { t.Helper() ca, err := minica.New() require.NoError(t, err) keyAuthSum := sha256.Sum256([]byte(keyAuthorization)) signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) sig, err := signer.Sign(rand.Reader, keyAuthSum[:], crypto.SHA256) require.NoError(t, err) cborSig, err := cbor.Marshal(sig) require.NoError(t, err) v, err := asn1.Marshal(stepManagedDevice{DeviceID: serialNumber}) require.NoError(t, err) leaf, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "attestation cert"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ {Id: oidStepManagedDevice, Value: v}, }, }) require.NoError(t, err) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "step", AttStatement: map[string]interface{}{ "x5c": []interface{}{leaf.Raw, ca.Intermediate.Raw}, "alg": -7, "sig": cborSig, }, }) require.NoError(t, err) payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) require.NoError(t, err) return payload, leaf, ca.Root } func newWireProvisionerWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME { t.Helper() prov := &provisioner.ACME{ Type: "ACME", Name: "wire", Options: options, Challenges: []provisioner.ACMEChallenge{ provisioner.WIREOIDC_01, provisioner.WIREDPOP_01, }, } if err := prov.Init(provisioner.Config{ Claims: config.GlobalProvisionerClaims, }); err != nil { t.Fatal(err) } return prov } func Test_storeError(t *testing.T) { type test 
struct { ch *Challenge db DB markInvalid bool err *Error } err := NewError(ErrorMalformedType, "foo") tests := map[string]func(t *testing.T) test{ "fail/db.UpdateChallenge-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusValid, } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "fail/db.UpdateChallenge-acme-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusValid, } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return NewError(ErrorMalformedType, "bar") }, }, err: NewError(ErrorMalformedType, "failure saving error to acme challenge: bar"), } }, "ok": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusValid, } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) 
assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "ok/mark-invalid": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusValid, } return test{ ch: ch, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusInvalid, updch.Status) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, markInvalid: true, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if err := storeError(context.Background(), tc.db, tc.ch, tc.markInvalid, err); err != nil { if assert.Error(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equal(t, tc.err.Type, k.Type) assert.Equal(t, tc.err.Detail, k.Detail) assert.Equal(t, tc.err.Status, k.Status) assert.Equal(t, tc.err.Err.Error(), k.Err.Error()) } else { assert.Fail(t, "unexpected error type") } } } else { assert.Nil(t, tc.err) } }) } } func TestKeyAuthorization(t *testing.T) { type test struct { token string jwk *jose.JSONWebKey exp string err *Error } tests := map[string]func(t *testing.T) test{ "fail/jwk-thumbprint-error": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) jwk.Key = "foo" return test{ token: "1234", jwk: jwk, err: NewErrorISE("error generating 
JWK thumbprint: go-jose/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { token := "1234" jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) thumbprint, err := jwk.Thumbprint(crypto.SHA256) require.NoError(t, err) encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) return test{ token: token, jwk: jwk, exp: fmt.Sprintf("%s.%s", token, encPrint), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { if assert.Error(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equal(t, tc.err.Type, k.Type) assert.Equal(t, tc.err.Detail, k.Detail) assert.Equal(t, tc.err.Status, k.Status) assert.Equal(t, tc.err.Err.Error(), k.Err.Error()) } else { assert.Fail(t, "unexpected error type") } } } else { if assert.Nil(t, tc.err) { assert.Equal(t, tc.exp, ka) } } }) } } func TestChallenge_Validate(t *testing.T) { fakeKey := `-----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----` type test struct { ch *Challenge vc Client jwk *jose.JSONWebKey db DB srv *httptest.Server payload []byte ctx context.Context err *Error } tests := map[string]func(t *testing.T) test{ "ok/already-valid": func(t *testing.T) test { ch := &Challenge{ Status: StatusValid, } return test{ ch: ch, } }, "fail/already-invalid": func(t *testing.T) test { ch := &Challenge{ Status: StatusInvalid, } return test{ ch: ch, } }, "fail/unexpected-type": func(t *testing.T) test { ch := &Challenge{ Status: StatusPending, Type: "foo", } return test{ ch: ch, err: NewErrorISE(`unexpected challenge type "foo"`), } }, "fail/http-01": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Status: StatusPending, Type: "http-01", Token: "token", Value: "zap.internal", } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, db: &MockDB{ 
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("http-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/http-01": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Status: StatusPending, Type: "http-01", Token: "token", Value: "zap.internal", } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("http-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "ok/http-01-insecure": func(t *testing.T) test { t.Cleanup(func() { InsecurePortHTTP01 = 0 }) ch := &Challenge{ ID: "chID", Status: StatusPending, Type: "http-01", Token: "token", Value: "zap.internal", } InsecurePortHTTP01 = 
8080 return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("http-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal:8080/.well-known/acme-challenge/%s: force", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/dns-01": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Type: "dns-01", Status: StatusPending, Token: "token", Value: "zap.internal", } return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("dns-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/dns-01": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Type: "dns-01", Status: StatusPending, Token: "token", 
Value: "zap.internal", } return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("dns-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/tls-alpn-01": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Type: "tls-alpn-01", Status: StatusPending, Value: "zap.internal", } return test{ ch: ch, vc: &mockClient{ tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing TLS dial for %v: force", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/tls-alpn-01": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Type: "tls-alpn-01", Status: 
StatusPending, Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.Nil(t, updch.Error) return nil }, }, srv: srv, jwk: jwk, } }, "ok/tls-alpn-01-insecure": func(t *testing.T) test { t.Cleanup(func() { InsecurePortTLSALPN01 = 0 }) ch := &Challenge{ ID: "chID", Token: "token", Type: "tls-alpn-01", Status: StatusPending, Value: "zap.internal", } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) require.NoError(t, err) l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { t.Fatalf("failed to listen on a port: %v", err) } } _, port, err := net.SplitHostPort(l.Addr().String()) if err != nil { t.Fatalf("failed to split host port: %v", err) } // Use an insecure port InsecurePortTLSALPN01, err = strconv.Atoi(port) if err != nil { t.Fatalf("failed to convert port to int: %v", err) } srv, tlsDial := newTestTLSALPNServer(cert, func(srv *httptest.Server) { srv.Listener.Close() srv.Listener = l }) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, 
db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.Nil(t, updch.Error) return nil }, }, srv: srv, jwk: jwk, } }, "fail/device-attest-01": func(t *testing.T) test { payload, err := json.Marshal(struct { Error string `json:"error"` }{ Error: "an error", }) assert.NoError(t, err) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) return errors.New("force") }, }, err: NewError(ErrorServerInternalType, "failure saving error to acme challenge: force"), } }, "ok/device-attest-01": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, leaf, root := mustAttestYubikey(t, "nonce", keyAuth, 1234) caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) 
return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "1234", }, payload: payload, ctx: ctx, jwk: jwk, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error { fingerprint, err := keyutil.Fingerprint(leaf.PublicKey) assert.NoError(t, err) assert.Equal(t, "azID", az.ID) assert.Equal(t, fingerprint, az.Fingerprint) return nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "1234", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Equal(t, "step", updch.PayloadFormat) return nil }, }, } }, "ok/wire-oidc-01": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) srv := mustJWKServer(t, signerJWK.Public()) tokenBytes, err := json.Marshal(struct { jose.Claims Name string `json:"name,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"` KeyAuth string `json:"keyauth"` ACMEAudience string `json:"acme_aud"` }{ Claims: jose.Claims{ Issuer: srv.URL, Audience: []string{"test"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Name: "Alice Smith", PreferredUsername: "wireapp://%40alice_wire@wire.com", KeyAuth: keyAuth, ACMEAudience: "https://ca.example.com/acme/wire/challenge/azID/chID", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) 
require.NoError(t, err) idToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { IDToken string `json:"id_token"` }{ IDToken: idToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: srv.URL, JWKSURL: srv.URL + "/keys", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-oidc-01", Status: StatusPending, Value: string(valueBytes), }, srv: srv, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return nil }, }, MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) { assert.Equal(t, "accID", accountID) return []string{"orderID"}, nil }, MockCreateOidcToken: func(ctx context.Context, orderID string, idToken map[string]interface{}) 
error { assert.Equal(t, "orderID", orderID) assert.Equal(t, "Alice Smith", idToken["name"].(string)) assert.Equal(t, "wireapp://%40alice_wire@wire.com", idToken["preferred_username"].(string)) return nil }, }, } }, "fail/wire-oidc-01-no-wire-db": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) srv := mustJWKServer(t, signerJWK.Public()) tokenBytes, err := json.Marshal(struct { jose.Claims Name string `json:"name,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"` KeyAuth string `json:"keyauth"` ACMEAudience string `json:"acme_aud"` }{ Claims: jose.Claims{ Issuer: srv.URL, Audience: []string{"test"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Name: "Alice Smith", PreferredUsername: "wireapp://%40alice_wire@wire.com", KeyAuth: keyAuth, ACMEAudience: "https://ca.example.com/acme/wire/challenge/azID/chID", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) idToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { IDToken string `json:"id_token"` }{ IDToken: idToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: 
&wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: srv.URL, JWKSURL: srv.URL + "/keys", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-oidc-01", Status: StatusPending, Value: string(valueBytes), }, srv: srv, payload: payload, ctx: ctx, jwk: jwk, db: &MockDB{}, err: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New("db *acme.MockDB is not a WireDB"), }, } }, "ok/wire-dpop-01": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") _ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation? 
dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string `json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: 
"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuerexample.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return nil }, }, MockGetAllOrdersByAccountID: func(ctx context.Context, 
accountID string) ([]string, error) { assert.Equal(t, "accID", accountID) return []string{"orderID"}, nil }, MockCreateDpopToken: func(ctx context.Context, orderID string, dpop map[string]interface{}) error { assert.Equal(t, "orderID", orderID) assert.Equal(t, "token", dpop["chal"].(string)) assert.Equal(t, "wireapp://%40alice_wire@wire.com", dpop["handle"].(string)) assert.Equal(t, "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", dpop["sub"].(string)) return nil }, }, } }, "fail/wire-dpop-01-no-wire-db": func(t *testing.T) test { jwk, _ := mustAccountAndKeyAuthorization(t, "token") dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string 
`json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuerexample.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: 
"azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockDB{}, err: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New("db *acme.MockDB is not a WireDB"), }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if tc.srv != nil { defer tc.srv.Close() } ctx := tc.ctx if ctx == nil { ctx = context.Background() } ctx = NewClientContext(ctx, tc.vc) err := tc.ch.Validate(ctx, tc.db, tc.jwk, tc.payload) if tc.err != nil { var k *Error if errors.As(err, &k) { assert.Equal(t, tc.err.Type, k.Type) assert.Equal(t, tc.err.Detail, k.Detail) assert.Equal(t, tc.err.Status, k.Status) assert.Equal(t, tc.err.Err.Error(), k.Err.Error()) } else { assert.Fail(t, "unexpected error type") } return } assert.NoError(t, err) }) } } func mustJWKServer(t *testing.T, pub jose.JSONWebKey) *httptest.Server { t.Helper() mux := http.NewServeMux() server := httptest.NewServer(mux) b, err := json.Marshal(struct { Keys []jose.JSONWebKey `json:"keys,omitempty"` }{ Keys: []jose.JSONWebKey{pub}, }) require.NoError(t, err) jwks := string(b) wellKnown := fmt.Sprintf(`{ "issuer": "%[1]s", "authorization_endpoint": "%[1]s/auth", "token_endpoint": "%[1]s/token", "jwks_uri": "%[1]s/keys", "userinfo_endpoint": "%[1]s/userinfo", "id_token_signing_alg_values_supported": ["ES256"] }`, server.URL) mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) { _, err := io.WriteString(w, wellKnown) if err != nil { w.WriteHeader(500) } }) mux.HandleFunc("/keys", func(w http.ResponseWriter, req *http.Request) { _, err := io.WriteString(w, jwks) if err != nil { w.WriteHeader(500) } }) t.Cleanup(server.Close) return server } type errReader int func (errReader) Read([]byte) (int, error) { return 0, errors.New("force") } func (errReader) Close() error { 
return nil } func TestHTTP01Validate(t *testing.T) { type test struct { vc Client ch *Challenge jwk *jose.JSONWebKey db DB err *Error } tests := map[string]func(t *testing.T) test{ "fail/http-get-error-store-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/http-get-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) 
assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/http-get->=400-store-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), }, nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/http-get->=400": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), }, nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) assert.EqualError(t, updch.Error.Err, err.Err.Error()) 
assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/read-body": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil }, }, err: NewErrorISE("error reading response body for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token), } }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) jwk.Key = "foo" return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, err: NewErrorISE("error generating JWK thumbprint: go-jose/go-jose: unknown key type 'string'"), } }, "ok/key-auth-mismatch": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusInvalid, updch.Status) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, 
but got foo", expKeyAuth) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusInvalid, updch.Status) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "fail/update-challenge-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ Body: 
io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil }, }, jwk: jwk, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.Nil(t, updch.Error) va, err := time.Parse(time.RFC3339, updch.ValidatedAt) require.NoError(t, err) now := clock.Now() assert.True(t, va.Add(-time.Minute).Before(now)) assert.True(t, va.Add(time.Minute).After(now)) return errors.New("force") }, }, err: NewErrorISE("error updating challenge: force"), } }, "ok": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: "zap.internal", Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) return test{ ch: ch, vc: &mockClient{ get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil }, }, jwk: jwk, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, "zap.internal", updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.Nil(t, updch.Error) va, err := time.Parse(time.RFC3339, updch.ValidatedAt) require.NoError(t, err) now := clock.Now() assert.True(t, va.Add(-time.Minute).Before(now)) assert.True(t, va.Add(time.Minute).After(now)) return nil }, }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) ctx := NewClientContext(context.Background(), tc.vc) if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.Error(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equal(t, tc.err.Type, k.Type) assert.Equal(t, tc.err.Detail, k.Detail) assert.Equal(t, tc.err.Status, k.Status) assert.Equal(t, 
tc.err.Err.Error(), k.Err.Error()) } else { assert.Fail(t, "unexpected error type") } } } else { assert.Nil(t, tc.err) } }) } } func TestDNS01Validate(t *testing.T) { fulldomain := "*.zap.internal" domain := strings.TrimPrefix(fulldomain, "*.") type test struct { vc Client ch *Challenge jwk *jose.JSONWebKey db DB err *Error } tests := map[string]func(t *testing.T) test{ "fail/lookupTXT-store-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: fulldomain, Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, fulldomain, updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/lookupTXT-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: fulldomain, Status: StatusPending, } return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, fulldomain, updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorDNSType, "error looking up TXT records for domain %s: force", domain) assert.EqualError(t, updch.Error.Err, err.Err.Error()) 
assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: fulldomain, Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) jwk.Key = "foo" return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return []string{"foo"}, nil }, }, jwk: jwk, err: NewErrorISE("error generating JWK thumbprint: go-jose/go-jose: unknown key type 'string'"), } }, "fail/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: fulldomain, Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, fulldomain, updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/key-auth-mismatch-store-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", 
Token: "token", Value: fulldomain, Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, fulldomain, updch.Value) assert.Equal(t, StatusPending, updch.Status) err := NewError(ErrorRejectedIdentifierType, "keyAuthorization does not match; expected %s, but got %s", expKeyAuth, []string{"foo", "bar"}) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, jwk: jwk, } }, "fail/update-challenge-error": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: fulldomain, Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) h := sha256.Sum256([]byte(expKeyAuth)) expected := base64.RawURLEncoding.EncodeToString(h[:]) return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, fulldomain, ch.Value) assert.Equal(t, StatusValid, updch.Status) assert.Nil(t, updch.Error) va, err := time.Parse(time.RFC3339, updch.ValidatedAt) require.NoError(t, err) now := clock.Now() assert.True(t, va.Add(-time.Minute).Before(now)) assert.True(t, 
va.Add(time.Minute).After(now)) return errors.New("force") }, }, jwk: jwk, err: NewErrorISE("error updating challenge: force"), } }, "ok": func(t *testing.T) test { ch := &Challenge{ ID: "chID", Token: "token", Value: fulldomain, Status: StatusPending, } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) h := sha256.Sum256([]byte(expKeyAuth)) expected := base64.RawURLEncoding.EncodeToString(h[:]) return test{ ch: ch, vc: &mockClient{ lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, fulldomain, updch.Value) assert.Equal(t, StatusValid, updch.Status) assert.Nil(t, updch.Error) va, err := time.Parse(time.RFC3339, updch.ValidatedAt) require.NoError(t, err) now := clock.Now() assert.True(t, va.Add(-time.Minute).Before(now)) assert.True(t, va.Add(time.Minute).After(now)) return nil }, }, jwk: jwk, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) ctx := NewClientContext(context.Background(), tc.vc) if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.Error(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equal(t, tc.err.Type, k.Type) assert.Equal(t, tc.err.Detail, k.Detail) assert.Equal(t, tc.err.Status, k.Status) assert.Equal(t, tc.err.Err.Error(), k.Err.Error()) } else { assert.Fail(t, "unexpected error type") } } } else { assert.Nil(t, tc.err) } }) } } type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) func newTestTLSALPNServer(validationCert *tls.Certificate, opts ...func(*httptest.Server)) (*httptest.Server, tlsDialer) { srv := httptest.NewUnstartedServer(http.NewServeMux()) srv.Config.TLSNextProto = map[string]func(*http.Server, 
// noopConn is a mock net.Conn that does nothing.
type noopConn struct{}

// Compile-time check that *noopConn satisfies net.Conn.
var _ net.Conn = (*noopConn)(nil)

// Read never yields data and reports io.EOF immediately.
func (c *noopConn) Read(_ []byte) (n int, err error) { return 0, io.EOF }

// Write discards nothing and accepts nothing; it reports io.EOF immediately.
func (c *noopConn) Write(_ []byte) (n int, err error) { return 0, io.EOF }

// Close always succeeds.
func (c *noopConn) Close() error { return nil }

// LocalAddr returns the IPv4 zero address.
func (c *noopConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} }

// RemoteAddr returns the IPv4 zero address.
func (c *noopConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4zero, Zone: ""} }

// SetDeadline ignores the deadline and always succeeds.
func (c *noopConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline ignores the deadline and always succeeds.
func (c *noopConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline ignores the deadline and always succeeds.
func (c *noopConn) SetWriteDeadline(time.Time) error { return nil }

// newTLSALPNValidationCert creates a self-signed certificate for tls-alpn-01
// tests, valid for one day and backed by a freshly generated 2048-bit RSA key.
//
// Parameters:
//   - keyAuthHash: when non-nil, it is ASN.1-encoded and embedded as an extra
//     extension; when nil, no extra extension is added.
//   - obsoleteOID: selects OID 1.3.6.1.5.5.7.1.30.1 for that extension
//     instead of 1.3.6.1.5.5.7.1.31.
//   - critical: sets the critical flag of that extension.
//   - names: DNS names placed in the certificate.
//
// The returned tls.Certificate carries the private key and the DER-encoded
// certificate. An error is returned when key generation, extension encoding,
// or certificate creation fails.
func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, names ...string) (*tls.Certificate, error) {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, err
	}

	certTemplate := &x509.Certificate{
		SerialNumber: big.NewInt(1337),
		Subject: pkix.Name{
			Organization: []string{"Test"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(0, 0, 1),
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              names,
	}

	if keyAuthHash != nil {
		oid := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
		if obsoleteOID {
			oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
		}

		// Propagate encoding failures instead of silently embedding an
		// empty extension value.
		keyAuthHashEnc, err := asn1.Marshal(keyAuthHash)
		if err != nil {
			return nil, err
		}

		certTemplate.ExtraExtensions = []pkix.Extension{
			{
				Id:       oid,
				Critical: critical,
				Value:    keyAuthHashEnc,
			},
		}
	}

	// The template doubles as the parent, which makes the certificate
	// self-signed.
	cert, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, privateKey.Public(), privateKey)
	if err != nil {
		return nil, err
	}

	return &tls.Certificate{
		PrivateKey:  privateKey,
		Certificate: [][]byte{cert},
	}, nil
}
test { ch := makeTLSCh() return test{ ch: ch, vc: &mockClient{ tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusPending, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorConnectionType, "error doing TLS dial for %v: force", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "ok/tlsDial-timeout": func(t *testing.T) test { ch := makeTLSCh() srv, tlsDial := newTestTLSALPNServer(nil) // srv.Start() - do not start server to cause timeout return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusPending, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorConnectionType, "error doing TLS dial for %v: context deadline exceeded", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, } }, "ok/no-certificates-error": func(t *testing.T) test { ch := makeTLSCh() return test{ ch: ch, vc: &mockClient{ tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, 
db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, } }, "fail/no-certificates-store-error": func(t *testing.T) test { ch := makeTLSCh() return test{ ch: ch, vc: &mockClient{ tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "tls-alpn-01 challenge for %v resulted in no certificates", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/error-no-protocol": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) srv := httptest.NewTLSServer(nil) return test{ ch: ch, vc: &mockClient{ tlsDial: func(network, addr string, config 
*tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/no-protocol-store-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) srv := httptest.NewTLSServer(nil) return test{ ch: ch, vc: &mockClient{ tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, 
srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/no-names-nor-ips-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/no-names-store-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, 
"token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/too-many-names-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value, "other.internal") require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: 
srv, jwk: jwk, } }, "ok/wrong-name": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, "other.internal") require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: leaf certificate must contain a single IP address or DNS name, %v", ch.Value) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/key-auth-gen-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) jwk.Key = "foo" cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, srv: srv, jwk: jwk, err: NewErrorISE("error generating JWK thumbprint: go-jose/go-jose: unknown key type 'string'"), } }, "ok/error-no-extension": func(t *testing.T) test { ch 
:= makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/no-extension-store-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) cert, err := newTLSALPNValidationCert(nil, false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) 
assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/error-extension-not-critical": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/extension-not-critical-store-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, false, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: 
tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: acmeValidationV1 extension not critical") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/error-malformed-extension": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/malformed-extension-store-error": 
func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) cert, err := newTLSALPNValidationCert([]byte{1, 2, 3}, false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: malformed acmeValidationV1 extension value") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/error-keyauth-mismatch": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) incorrectTokenHash := sha256.Sum256([]byte("mismatched")) cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), 
updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ "expected acmeValidationV1 extension value %s for this challenge but got %s", hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/keyauth-mismatch-store-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) incorrectTokenHash := sha256.Sum256([]byte("mismatched")) cert, err := newTLSALPNValidationCert(incorrectTokenHash[:], false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ "expected acmeValidationV1 extension value %s for this challenge but got %s", hex.EncodeToString(expKeyAuthHash[:]), hex.EncodeToString(incorrectTokenHash[:])) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, 
updch.Error.Subproblems) return errors.New("force") }, }, srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/error-obsolete-oid": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, srv: srv, jwk: jwk, } }, "fail/obsolete-oid-store-error": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], true, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch 
*Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) err := NewError(ErrorRejectedIdentifierType, "incorrect certificate for tls-alpn-01 challenge: "+ "obsolete id-pe-acmeIdentifier in acmeValidationV1 extension") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, srv: srv, jwk: jwk, err: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok": func(t *testing.T) test { ch := makeTLSCh() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, true, ch.Value) require.NoError(t, err) srv, tlsDial := newTestTLSALPNServer(cert) srv.Start() return test{ ch: ch, vc: &mockClient{ tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type) assert.Equal(t, "zap.internal", updch.Value) assert.Nil(t, updch.Error) return nil }, }, srv: srv, jwk: jwk, } }, "ok/ip": func(t *testing.T) test { ch := makeTLSCh() ch.Value = "127.0.0.1" jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) expKeyAuth, err := KeyAuthorization(ch.Token, jwk) require.NoError(t, err) expKeyAuthHash := sha256.Sum256([]byte(expKeyAuth)) cert, err := newTLSALPNValidationCert(expKeyAuthHash[:], false, 
true, ch.Value)
			require.NoError(t, err)

			srv, tlsDial := newTestTLSALPNServer(cert)
			srv.Start()

			return test{
				ch: ch,
				vc: &mockClient{
					tlsDial: tlsDial,
				},
				db: &MockDB{
					MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
						assert.Equal(t, "chID", updch.ID)
						assert.Equal(t, "token", updch.Token)
						assert.Equal(t, StatusValid, updch.Status)
						assert.Equal(t, ChallengeType("tls-alpn-01"), updch.Type)
						assert.Equal(t, "127.0.0.1", updch.Value)
						assert.Nil(t, updch.Error)

						return nil
					},
				},
				srv: srv,
				jwk: jwk,
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)

			if tc.srv != nil {
				defer tc.srv.Close()
			}

			ctx := NewClientContext(context.Background(), tc.vc)
			if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
				if assert.Error(t, tc.err) {
					var k *Error
					if errors.As(err, &k) {
						assert.Equal(t, tc.err.Type, k.Type)
						assert.Equal(t, tc.err.Detail, k.Detail)
						assert.Equal(t, tc.err.Status, k.Status)
						assert.Equal(t, tc.err.Err.Error(), k.Err.Error())
						assert.Equal(t, tc.err.Subproblems, k.Subproblems)
					} else {
						assert.Fail(t, "unexpected error type")
					}
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// Test_reverseAddr checks that reverseAddr renders an IPv4 and an IPv6
// address as the expected ".arpa." reverse-lookup name.
func Test_reverseAddr(t *testing.T) {
	cases := []struct {
		name string
		ip   net.IP
		want string
	}{
		{"ok/ipv4", net.ParseIP("127.0.0.1"), "1.0.0.127.in-addr.arpa."},
		{"ok/ipv6", net.ParseIP("2001:db8::567:89ab"), "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := reverseAddr(tc.ip)
			if got == tc.want {
				return
			}
			t.Errorf("reverseAddr() = %v, want %v", got, tc.want)
		})
	}
}

// Test_serverName checks the value serverName derives from a challenge: a DNS
// identifier is expected back unchanged, while IPv4/IPv6 identifiers are
// expected in their reverse-lookup form.
func Test_serverName(t *testing.T) {
	cases := []struct {
		name  string
		value string
		want  string
	}{
		{"ok/dns", "example.com", "example.com"},
		{"ok/ipv4", "127.0.0.1", "1.0.0.127.in-addr.arpa."},
		{"ok/ipv6", "2001:db8::567:89ab", "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa."},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ch := &Challenge{Value: tc.value}
			got := serverName(ch)
			if got == tc.want {
				return
			}
			t.Errorf("serverName() = %v, want %v", got, tc.want)
		})
	}
}

// Test_http01ChallengeHost checks the host http01ChallengeHost produces for
// DNS names (with and without a trailing dot) and for IP literals, under both
// settings of the package-level StrictFQDN flag.
func Test_http01ChallengeHost(t *testing.T) {
	cases := []struct {
		name       string
		strictFQDN bool
		value      string
		want       string
	}{
		{"dns", false, "www.example.com", "www.example.com"},
		{"dns strict", true, "www.example.com", "www.example.com."},
		{"rooted dns", false, "www.example.com.", "www.example.com."},
		{"rooted dns strict", true, "www.example.com.", "www.example.com."},
		{"ipv4", false, "127.0.0.1", "127.0.0.1"},
		{"ipv6", false, "::1", "[::1]"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// StrictFQDN is shared package state; put the previous value back
			// once this subtest is done.
			previous := StrictFQDN
			t.Cleanup(func() {
				StrictFQDN = previous
			})
			StrictFQDN = tc.strictFQDN

			got := http01ChallengeHost(tc.value)
			if got == tc.want {
				return
			}
			t.Errorf("http01ChallengeHost() = %v, want %v", got, tc.want)
		})
	}
}

// Test_doAppleAttestationFormat runs doAppleAttestationFormat against "apple"
// attestation statements. The leaf certificate is signed by a throwaway
// minica CA and carries the Apple serial number, UDID, SEP OS version and
// nonce extensions; the table covers the success case and malformed, missing
// or unverifiable x5c chains.
func Test_doAppleAttestationFormat(t *testing.T) {
	ctx := context.Background()

	ca, err := minica.New()
	if err != nil {
		t.Fatal(err)
	}
	caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Root.Raw})

	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	template := &x509.Certificate{
		Subject:   pkix.Name{CommonName: "attestation cert"},
		PublicKey: signer.Public(),
		ExtraExtensions: []pkix.Extension{
			{Id: oidAppleSerialNumber, Value: []byte("serial-number")},
			{Id: oidAppleUniqueDeviceIdentifier, Value: []byte("udid")},
			{Id: oidAppleSecureEnclaveProcessorOSVersion, Value: []byte("16.0")},
			{Id: oidAppleNonce, Value: []byte("nonce")},
		},
	}
	leaf, err := ca.Sign(template)
	if err != nil {
		t.Fatal(err)
	}

	fingerprint, err := keyutil.Fingerprint(signer.Public())
	if err != nil {
		t.Fatal(err)
	}

	// appleAtt builds an "apple" attestation object whose statement holds a
	// single key/value pair.
	appleAtt := func(key string, value interface{}) *attestationObject {
		return &attestationObject{
			Format:       "apple",
			AttStatement: map[string]interface{}{key: value},
		}
	}

	cases := []struct {
		name    string
		prov    Provisioner
		att     *attestationObject
		want    *appleAttestationData
		wantErr bool
	}{
		{
			name: "ok",
			prov: mustAttestationProvisioner(t, caRoot),
			att:  appleAtt("x5c", []interface{}{leaf.Raw, ca.Intermediate.Raw}),
			want: &appleAttestationData{
				Nonce:        []byte("nonce"),
				SerialNumber: "serial-number",
				UDID:         "udid",
				SEPVersion:   "16.0",
				Certificate:  leaf,
				Fingerprint:  fingerprint,
			},
			wantErr: false,
		},
		{
			name:    "fail apple issuer",
			prov:    mustAttestationProvisioner(t, nil),
			att:     appleAtt("x5c", []interface{}{leaf.Raw, ca.Intermediate.Raw}),
			wantErr: true,
		},
		{
			name:    "fail missing x5c",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("foo", "bar"),
			wantErr: true,
		},
		{
			name:    "fail empty issuer",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("x5c", []interface{}{}),
			wantErr: true,
		},
		{
			name:    "fail leaf type",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("x5c", []interface{}{"leaf", ca.Intermediate.Raw}),
			wantErr: true,
		},
		{
			name:    "fail leaf parse",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("x5c", []interface{}{leaf.Raw[:100], ca.Intermediate.Raw}),
			wantErr: true,
		},
		{
			name:    "fail intermediate type",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("x5c", []interface{}{leaf.Raw, "intermediate"}),
			wantErr: true,
		},
		{
			name:    "fail intermediate parse",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("x5c", []interface{}{leaf.Raw, ca.Intermediate.Raw[:100]}),
			wantErr: true,
		},
		{
			name:    "fail verify",
			prov:    mustAttestationProvisioner(t, caRoot),
			att:     appleAtt("x5c", []interface{}{leaf.Raw}),
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := doAppleAttestationFormat(ctx, tc.prov, &Challenge{}, tc.att)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("doAppleAttestationFormat() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("doAppleAttestationFormat() = %v, want %v", got, tc.want)
		})
	}
}

// Test_doStepAttestationFormat runs doStepAttestationFormat against "step"
// attestation statements. Leaf certificates are signed by a throwaway minica
// CA and carry either the Yubico serial number extension or the step managed
// device extension; the signature is made over the SHA-256 digest of the key
// authorization for the token "token". The table covers both success
// variants and failures in the x5c chain, the signature, the account key and
// the serial number extension.
func Test_doStepAttestationFormat(t *testing.T) {
	ctx := context.Background()

	ca, err := minica.New()
	require.NoError(t, err)
	caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Root.Raw})

	// makeLeaf issues an attestation certificate for key whose Yubico serial
	// number extension holds the given raw bytes.
	makeLeaf := func(key crypto.Signer, rawSerial []byte) *x509.Certificate {
		crt, err := ca.Sign(&x509.Certificate{
			Subject:   pkix.Name{CommonName: "attestation cert"},
			PublicKey: key.Public(),
			ExtraExtensions: []pkix.Extension{
				{Id: oidYubicoSerialNumber, Value: rawSerial},
			},
		})
		require.NoError(t, err)
		return crt
	}

	// makeManagedLeaf issues an attestation certificate for key whose step
	// managed device extension carries deviceID.
	makeManagedLeaf := func(key crypto.Signer, deviceID string) *x509.Certificate {
		ext, err := asn1.Marshal(stepManagedDevice{DeviceID: deviceID})
		require.NoError(t, err)
		crt, err := ca.Sign(&x509.Certificate{
			Subject:   pkix.Name{CommonName: "attestation cert"},
			PublicKey: key.Public(),
			ExtraExtensions: []pkix.Extension{
				{Id: oidStepManagedDevice, Value: ext},
			},
		})
		require.NoError(t, err)
		return crt
	}

	// mustSigner generates a signer of the requested kind or fails the test.
	mustSigner := func(kty, crv string, size int) crypto.Signer {
		key, err := keyutil.GenerateSigner(kty, crv, size)
		require.NoError(t, err)
		return key
	}

	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	fingerprint, err := keyutil.Fingerprint(signer.Public())
	require.NoError(t, err)
	serialNumber, err := asn1.Marshal(1234)
	require.NoError(t, err)

	leaf := makeLeaf(signer, serialNumber)
	managedLeaf := makeManagedLeaf(signer, "1234")

	jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
	require.NoError(t, err)

	keyAuth, err := KeyAuthorization("token", jwk)
	require.NoError(t, err)
	digest := sha256.Sum256([]byte(keyAuth))

	sig, err := signer.Sign(rand.Reader, digest[:], crypto.SHA256)
	require.NoError(t, err)
	cborSig, err := cbor.Marshal(sig)
	require.NoError(t, err)

	// A signature over the same digest made with an unrelated key.
	otherSigner, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	otherSig, err := otherSigner.Sign(rand.Reader, digest[:], crypto.SHA256)
	require.NoError(t, err)
	otherCBORSig, err := cbor.Marshal(otherSig)
	require.NoError(t, err)

	// stepAtt builds a "step" attestation object with the given x5c and sig
	// entries; "alg" is -7 in every case.
	stepAtt := func(x5c, sig interface{}) *attestationObject {
		return &attestationObject{
			Format: "step",
			AttStatement: map[string]interface{}{
				"x5c": x5c,
				"alg": -7,
				"sig": sig,
			},
		}
	}
	// chain returns a fresh x5c slice made of crt followed by the CA
	// intermediate.
	chain := func(crt *x509.Certificate) []interface{} {
		return []interface{}{crt.Raw, ca.Intermediate.Raw}
	}

	cases := []struct {
		name    string
		prov    Provisioner
		jwk     *jose.JSONWebKey
		att     *attestationObject
		want    *stepAttestationData
		wantErr bool
	}{
		{
			name: "ok",
			prov: mustAttestationProvisioner(t, caRoot),
			jwk:  jwk,
			att:  stepAtt(chain(leaf), cborSig),
			want: &stepAttestationData{
				SerialNumber: "1234",
				Certificate:  leaf,
				Fingerprint:  fingerprint,
			},
			wantErr: false,
		},
		{
			name: "ok/step-managed-device-id",
			prov: mustAttestationProvisioner(t, caRoot),
			jwk:  jwk,
			att:  stepAtt(chain(managedLeaf), cborSig),
			want: &stepAttestationData{
				SerialNumber: "1234",
				Certificate:  managedLeaf,
				Fingerprint:  fingerprint,
			},
			wantErr: false,
		},
		{
			name:    "fail yubico issuer",
			prov:    mustAttestationProvisioner(t, nil),
			jwk:     jwk,
			att:     stepAtt(chain(leaf), cborSig),
			wantErr: true,
		},
		{
			name:    "fail x5c type",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([][]byte{leaf.Raw, ca.Intermediate.Raw}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail x5c empty",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([]interface{}{}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail leaf type",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([]interface{}{"leaf", ca.Intermediate.Raw}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail leaf parse",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([]interface{}{leaf.Raw[:100], ca.Intermediate.Raw}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail intermediate type",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([]interface{}{leaf.Raw, "intermediate"}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail intermediate parse",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([]interface{}{leaf.Raw, ca.Intermediate.Raw[:100]}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail verify",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt([]interface{}{leaf.Raw}, cborSig),
			wantErr: true,
		},
		{
			name:    "fail sig type",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(leaf), string(cborSig)),
			wantErr: true,
		},
		{
			name:    "fail sig unmarshal",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(leaf), []byte("bad-sig")),
			wantErr: true,
		},
		{
			name:    "fail keyAuthorization",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     &jose.JSONWebKey{Key: []byte("not an asymmetric key")},
			att:     stepAtt(chain(leaf), cborSig),
			wantErr: true,
		},
		{
			name:    "fail sig verify P-256",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(leaf), otherCBORSig),
			wantErr: true,
		},
		{
			name:    "fail sig verify P-384",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(makeLeaf(mustSigner("EC", "P-384", 0), serialNumber)), cborSig),
			wantErr: true,
		},
		{
			name:    "fail sig verify RSA",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(makeLeaf(mustSigner("RSA", "", 2048), serialNumber)), cborSig),
			wantErr: true,
		},
		{
			name:    "fail sig verify Ed25519",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(makeLeaf(mustSigner("OKP", "Ed25519", 0), serialNumber)), cborSig),
			wantErr: true,
		},
		{
			name:    "fail unmarshal serial number",
			prov:    mustAttestationProvisioner(t, caRoot),
			jwk:     jwk,
			att:     stepAtt(chain(makeLeaf(signer, []byte("bad-serial"))), cborSig),
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ch := &Challenge{Token: "token"}
			got, err := doStepAttestationFormat(ctx, tc.prov, ch, tc.jwk, tc.att)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("doStepAttestationFormat() error = %#v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("doStepAttestationFormat() = %v, want %v", got, tc.want)
		})
	}
}

func Test_doStepAttestationFormat_noCAIntermediate(t *testing.T) {
	ctx := context.Background()

	// This CA simulates a YubiKey v5.2.4, where the attestation intermediate in
	// the CA does not have the basic constraint extension. With the current
	// validation of the certificate the test case below returns an error. If
	// we change the validation to support this use case, the test case below
	// should change.
// // See https://github.com/Yubico/yubikey-manager/issues/522 ca, err := minica.New(minica.WithIntermediateTemplate(`{"subject": {{ toJson .Subject }}}`)) if err != nil { t.Fatal(err) } caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Root.Raw}) makeLeaf := func(signer crypto.Signer, serialNumber []byte) *x509.Certificate { leaf, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "attestation cert"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ {Id: oidYubicoSerialNumber, Value: serialNumber}, }, }) if err != nil { t.Fatal(err) } return leaf } signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } serialNumber, err := asn1.Marshal(1234) if err != nil { t.Fatal(err) } leaf := makeLeaf(signer, serialNumber) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) if err != nil { t.Fatal(err) } keyAuth, err := KeyAuthorization("token", jwk) if err != nil { t.Fatal(err) } keyAuthSum := sha256.Sum256([]byte(keyAuth)) sig, err := signer.Sign(rand.Reader, keyAuthSum[:], crypto.SHA256) if err != nil { t.Fatal(err) } cborSig, err := cbor.Marshal(sig) if err != nil { t.Fatal(err) } type args struct { ctx context.Context prov Provisioner ch *Challenge jwk *jose.JSONWebKey att *attestationObject } tests := []struct { name string args args want *stepAttestationData wantErr bool }{ {"fail no intermediate", args{ctx, mustAttestationProvisioner(t, caRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "step", AttStatement: map[string]interface{}{ "x5c": []interface{}{leaf.Raw, ca.Intermediate.Raw}, "alg": -7, "sig": cborSig, }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := doStepAttestationFormat(tt.args.ctx, tt.args.prov, tt.args.ch, tt.args.jwk, tt.args.att) if (err != nil) != tt.wantErr { t.Errorf("doStepAttestationFormat() error = %#v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { 
t.Errorf("doStepAttestationFormat() = %v, want %v", got, tt.want) } }) } } func Test_deviceAttest01Validate(t *testing.T) { invalidPayload := "!?" errorPayload, err := json.Marshal(struct { Error string `json:"error"` }{ Error: "an error", }) require.NoError(t, err) errorBase64Payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: "?!", }) require.NoError(t, err) emptyPayload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString([]byte("")), }) require.NoError(t, err) emptyObjectPayload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString([]byte("{}")), }) require.NoError(t, err) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "step", AttStatement: map[string]interface{}{ "alg": -7, "sig": "", }, }) require.NoError(t, err) errorNonWellformedCBORPayload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj[:len(attObj)-1]), // cut the CBOR encoded data off }) require.NoError(t, err) unsupportedFormatAttObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "unsupported-format", AttStatement: map[string]interface{}{ "alg": -7, "sig": "", }, }) require.NoError(t, err) errorUnsupportedFormat, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(unsupportedFormatAttObj), }) require.NoError(t, err) type args struct { ctx context.Context ch *Challenge db DB jwk *jose.JSONWebKey payload []byte } type test struct { args args wantErr *Error } tests := map[string]func(t *testing.T) test{ "fail/getAuthorization": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: 
StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return nil, errors.New("not found") }, }, payload: []byte(invalidPayload), }, wantErr: NewErrorISE("error loading authorization: not found"), } }, "fail/json.Unmarshal": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, }, payload: []byte(invalidPayload), }, wantErr: NewErrorISE("error unmarshalling JSON: invalid character '!' looking for beginning of value"), } }, "fail/storeError": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: errorPayload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return errors.New("force") }, }, }, wantErr: NewErrorISE("failure saving error to acme challenge: force"), } }, "ok/storeError-return-nil": func(t 
*testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: errorPayload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, errorPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/base64-decode": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, errorBase64Payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "failed base64 decoding attObj %q", "?!") 
assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, payload: errorBase64Payload, }, } }, "ok/empty-attobj": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, emptyPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, payload: emptyPayload, }, } }, "ok/empty-json-attobj": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, 
StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, emptyObjectPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, payload: emptyObjectPayload, }, } }, "ok/cborDecoder.Wellformed": func(t *testing.T) test { return test{ args: args{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, errorNonWellformedCBORPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "attObj is not well formed CBOR: unexpected EOF") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, payload: errorNonWellformedCBORPayload, }, } }, "ok/unsupported-attestation-format": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), mustNonAttestationProvisioner(t)) return test{ args: args{ ctx: 
ctx, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, errorUnsupportedFormat, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "unsupported attestation object format %q", "unsupported-format") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, payload: errorUnsupportedFormat, }, } }, "ok/prov.IsAttestationFormatEnabled": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, _, _ := mustAttestYubikey(t, "nonce", keyAuth, 12345678) ctx := NewProvisionerContext(context.Background(), mustNonAttestationProvisioner(t)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, 
ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewError(ErrorBadAttestationStatementType, "attestation format %q is not enabled", "step") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/doAppleAttestationFormat-storeError": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, nil)) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "apple", AttStatement: map[string]interface{}{}, }) require.NoError(t, err) payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) require.NoError(t, err) return test{ args: args{ ctx: ctx, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, 
err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/doAppleAttestationFormat-non-matching-nonce": func(t *testing.T) test { jwk, _ := mustAccountAndKeyAuthorization(t, "token") payload, _, root := mustAttestApple(t, "bad-nonce") caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "serial-number", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "serial-number", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "challenge token does not match") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/doAppleAttestationFormat-non-matching-challenge-value": func(t *testing.T) test { jwk, _ := mustAccountAndKeyAuthorization(t, "token") payload, _, root := mustAttestApple(t, "nonce") caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), 
mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "nonce", Type: "device-attest-01", Status: StatusPending, Value: "non-matching-value", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "nonce", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "non-matching-value", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) subproblem := NewSubproblemWithIdentifier( ErrorRejectedIdentifierType, Identifier{Type: "permanent-identifier", Value: "non-matching-value"}, `challenge identifier "non-matching-value" doesn't match any of the attested hardware identifiers ["udid" "serial-number"]`, ) err := NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match").AddSubproblems(subproblem) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/doStepAttestationFormat-storeError": func(t *testing.T) test { ca, err := minica.New() require.NoError(t, err) caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Root.Raw}) signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) token := "token" keyAuth, err := KeyAuthorization(token, jwk) require.NoError(t, err) keyAuthSum := sha256.Sum256([]byte(keyAuth)) sig, 
err := signer.Sign(rand.Reader, keyAuthSum[:], crypto.SHA256) require.NoError(t, err) cborSig, err := cbor.Marshal(sig) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "step", AttStatement: map[string]interface{}{ "alg": -7, "sig": cborSig, }, }) require.NoError(t, err) payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) require.NoError(t, err) return test{ args: args{ ctx: ctx, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/doStepAttestationFormat-non-matching-identifier": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, leaf, root := mustAttestYubikey(t, "nonce", keyAuth, 87654321) caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", 
Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error { fingerprint, err := keyutil.Fingerprint(leaf.PublicKey) assert.NoError(t, err) assert.Equal(t, "azID", az.ID) assert.Equal(t, fingerprint, az.Fingerprint) return nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match"). 
AddSubproblems(NewSubproblemWithIdentifier( ErrorRejectedIdentifierType, Identifier{Type: "permanent-identifier", Value: "12345678"}, "challenge identifier \"12345678\" doesn't match the attested hardware identifier \"87654321\"", )) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, }, wantErr: nil, } }, "ok/unknown-attestation-format": func(t *testing.T) test { ca, err := minica.New() require.NoError(t, err) signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) token := "token" keyAuth, err := KeyAuthorization(token, jwk) require.NoError(t, err) keyAuthSum := sha256.Sum256([]byte(keyAuth)) sig, err := signer.Sign(rand.Reader, keyAuthSum[:], crypto.SHA256) require.NoError(t, err) cborSig, err := cbor.Marshal(sig) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), mustNonAttestationProvisioner(t)) makeLeaf := func(signer crypto.Signer, serialNumber []byte) *x509.Certificate { leaf, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "attestation cert"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ {Id: oidYubicoSerialNumber, Value: serialNumber}, }, }) if err != nil { t.Fatal(err) } return leaf } require.NoError(t, err) serialNumber, err := asn1.Marshal(87654321) require.NoError(t, err) leaf := makeLeaf(signer, serialNumber) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "bogus-format", AttStatement: map[string]interface{}{ "x5c": []interface{}{leaf.Raw, ca.Intermediate.Raw}, "alg": -7, "sig": cborSig, }, }) require.NoError(t, err) payload, err := json.Marshal(struct { AttObj string 
`json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) require.NoError(t, err) return test{ args: args{ ctx: ctx, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, `unsupported attestation object format "bogus-format"`) assert.EqualError(t, updch.Error.Err, err.Err.Error()) assert.Equal(t, err.Type, updch.Error.Type) assert.Equal(t, err.Detail, updch.Error.Detail) assert.Equal(t, err.Status, updch.Error.Status) assert.Equal(t, err.Subproblems, updch.Error.Subproblems) return nil }, }, jwk: jwk, }, wantErr: nil, } }, "fail/db.UpdateAuthorization": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, leaf, root := mustAttestYubikey(t, "nonce", keyAuth, 12345678) caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, 
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error { fingerprint, err := keyutil.Fingerprint(leaf.PublicKey) assert.NoError(t, err) assert.Equal(t, "azID", az.ID) assert.Equal(t, fingerprint, az.Fingerprint) return errors.New("force") }, }, }, wantErr: NewError(ErrorServerInternalType, "error updating authorization: force"), } }, "fail/db.UpdateChallenge": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, leaf, root := mustAttestYubikey(t, "nonce", keyAuth, 12345678) caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error { fingerprint, err := keyutil.Fingerprint(leaf.PublicKey) assert.NoError(t, err) assert.Equal(t, "azID", az.ID) assert.Equal(t, fingerprint, az.Fingerprint) return nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Equal(t, "step", updch.PayloadFormat) return errors.New("force") }, }, }, wantErr: NewError(ErrorServerInternalType, "error updating challenge: force"), } }, "ok": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, leaf, root := mustAttestYubikey(t, "nonce", keyAuth, 12345678) caRoot := 
pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error { fingerprint, err := keyutil.Fingerprint(leaf.PublicKey) assert.NoError(t, err) assert.Equal(t, "azID", az.ID) assert.Equal(t, fingerprint, az.Fingerprint) return nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Equal(t, "step", updch.PayloadFormat) return nil }, }, }, wantErr: nil, } }, "ok/step-managed-device-id": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, leaf, root := mustAttestStepManagedDeviceID(t, "nonce", keyAuth, "12345678") caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw}) ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot)) return test{ args: args{ ctx: ctx, jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", Token: "token", Type: "device-attest-01", Status: StatusPending, Value: "12345678", }, payload: payload, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { assert.Equal(t, "azID", id) return &Authorization{ID: "azID"}, nil }, MockUpdateAuthorization: func(ctx context.Context, az 
*Authorization) error { fingerprint, err := keyutil.Fingerprint(leaf.PublicKey) assert.NoError(t, err) assert.Equal(t, "azID", az.ID) assert.Equal(t, fingerprint, az.Fingerprint) return nil }, MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) assert.Equal(t, payload, updch.Payload) assert.Equal(t, "step", updch.PayloadFormat) return nil }, }, }, wantErr: nil, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if err := deviceAttest01Validate(tc.args.ctx, tc.args.ch, tc.args.db, tc.args.jwk, tc.args.payload); err != nil { if assert.Error(t, tc.wantErr) { assert.ErrorContains(t, err, tc.wantErr.Error()) } return } assert.Nil(t, tc.wantErr) }) } } var ( oidTPMManufacturer = asn1.ObjectIdentifier{2, 23, 133, 2, 1} oidTPMModel = asn1.ObjectIdentifier{2, 23, 133, 2, 2} oidTPMVersion = asn1.ObjectIdentifier{2, 23, 133, 2, 3} ) func generateValidAKCertificate(t *testing.T) *x509.Certificate { t.Helper() signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) template := &x509.Certificate{ PublicKey: signer.Public(), Version: 3, IsCA: false, UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate}, } asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55")) sans := []x509util.SubjectAlternativeName{ {Type: x509util.DirectoryNameType, ASN1Value: asn1Value}, } ext, err := createSubjectAltNameExtension(nil, nil, nil, nil, sans, true) require.NoError(t, err) ext.Set(template) ca, err := minica.New() require.NoError(t, err) cert, err := ca.Sign(template) require.NoError(t, err) return cert } func 
Test_validateAKCertificate(t *testing.T) {
	// Exercises validateAKCertificate with one valid certificate and one
	// minimal certificate per rejection reason.
	cert := generateValidAKCertificate(t)
	tests := []struct {
		name   string
		c      *x509.Certificate
		expErr error
	}{
		{
			name:   "ok",
			c:      cert,
			expErr: nil,
		},
		{
			name: "fail/version",
			c: &x509.Certificate{
				Version: 1,
			},
			expErr: errors.New("AK certificate has invalid version 1; only version 3 is allowed"),
		},
		{
			name: "fail/subject",
			c: &x509.Certificate{
				Version: 3,
				Subject: pkix.Name{CommonName: "fail!"},
			},
			expErr: errors.New(`AK certificate subject must be empty; got "CN=fail!"`),
		},
		{
			name: "fail/isCA",
			c: &x509.Certificate{
				Version: 3,
				IsCA:    true,
			},
			expErr: errors.New("AK certificate must not be a CA"),
		},
		{
			name: "fail/extendedKeyUsage",
			c: &x509.Certificate{
				Version: 3,
			},
			expErr: errors.New("AK certificate is missing Extended Key Usage extension"),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateAKCertificate(tt.c)
			if tt.expErr == nil {
				assert.NoError(t, err)
				return
			}
			// EqualError also fails when err is nil.
			assert.EqualError(t, err, tt.expErr.Error())
		})
	}
}

// Test_validateAKCertificateSubjectAlternativeNames checks that an AK
// certificate whose directoryName SAN lacks the TPM manufacturer, model or
// version is rejected with the matching error, and that a complete one passes.
func Test_validateAKCertificateSubjectAlternativeNames(t *testing.T) {
	ok := generateValidAKCertificate(t)

	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	ca, err := minica.New()
	require.NoError(t, err)

	// signWithDirectoryName issues a certificate for signer whose only SAN is
	// a directoryName built from the given JSON description.
	signWithDirectoryName := func(asn1Value []byte) *x509.Certificate {
		t.Helper()

		template := &x509.Certificate{
			PublicKey:          signer.Public(),
			Version:            3,
			IsCA:               false,
			UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
		}
		sans := []x509util.SubjectAlternativeName{
			{Type: x509util.DirectoryNameType, ASN1Value: asn1Value},
		}
		ext, err := createSubjectAltNameExtension(nil, nil, nil, nil, sans, true)
		require.NoError(t, err)
		ext.Set(template)

		signed, err := ca.Sign(template)
		require.NoError(t, err)
		return signed
	}

	// Each fixture omits exactly one of the three TPM attributes.
	missingManufacturer := signWithDirectoryName([]byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55")))
	missingModel := signWithDirectoryName([]byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMVersion, "7.55")))
	missingFirmwareVersion := signWithDirectoryName([]byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0")))

	tests := []struct {
		name   string
		c      *x509.Certificate
		expErr error
	}{
		{"ok", ok, nil},
		{"fail/missing-manufacturer", missingManufacturer, errors.New("missing TPM manufacturer")},
		{"fail/missing-model", missingModel, errors.New("missing TPM model")},
		{"fail/missing-firmware-version", missingFirmwareVersion, errors.New("missing TPM version")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateAKCertificateSubjectAlternativeNames(tt.c)
			if tt.expErr == nil {
				assert.NoError(t, err)
				return
			}
			assert.EqualError(t, err, tt.expErr.Error())
		})
	}
}

// Test_validateAKCertificateExtendedKeyUsage checks the Extended Key Usage
// validation against a valid AK certificate, a certificate with a different
// EKU, one with an empty EKU extension and one without the extension.
func Test_validateAKCertificateExtendedKeyUsage(t *testing.T) {
	ok := generateValidAKCertificate(t)
	missingEKU := &x509.Certificate{}

	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	ca, err := minica.New()
	require.NoError(t, err)

	wrongEKU, err := ca.Sign(&x509.Certificate{
		PublicKey:   signer.Public(),
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	})
	require.NoError(t, err)

	emptyEKU, err := ca.Sign(&x509.Certificate{
		PublicKey: signer.Public(),
		ExtraExtensions: []pkix.Extension{{
			Id:    oidExtensionExtendedKeyUsage,
			Value: []byte{0x30, 0x00}, // DER: empty SEQUENCE
		}},
	})
	require.NoError(t, err)

	tests := []struct {
		name   string
		c      *x509.Certificate
		expErr error
	}{
		{"ok", ok, nil},
		{"fail/wrong-eku", wrongEKU, errors.New("AK certificate is missing Extended Key Usage value tcg-kp-AIKCertificate (2.23.133.8.3)")},
		{"fail/empty-eku", emptyEKU, errors.New("AK certificate is missing Extended Key Usage value tcg-kp-AIKCertificate (2.23.133.8.3)")},
		{"fail/missing-eku", missingEKU, errors.New("AK certificate is missing Extended Key Usage extension")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateAKCertificateExtendedKeyUsage(tt.c)
			if tt.expErr == nil {
				assert.NoError(t, err)
				return
			}
			assert.EqualError(t, err, tt.expErr.Error())
		})
	}
}

// createSubjectAltNameExtension will construct an Extension containing all
// SubjectAlternativeNames held in a Certificate. It implements more types than
// the golang x509 library, so it is used whenever OtherName or RegisteredID
// type SANs are present in the certificate.
//
// See also https://datatracker.ietf.org/doc/html/rfc5280.html#section-4.2.1.6
//
// TODO(hs): this was copied from go.step.sm/crypto/x509util to make it easier
// to create the SAN extension for testing purposes. Should it be exposed instead?
func createSubjectAltNameExtension(dnsNames, emailAddresses x509util.MultiString, ipAddresses x509util.MultiIP, uris x509util.MultiURL, sans []x509util.SubjectAlternativeName, subjectIsEmpty bool) (x509util.Extension, error) { var zero x509util.Extension var rawValues []asn1.RawValue for _, dnsName := range dnsNames { rawValue, err := x509util.SubjectAlternativeName{ Type: x509util.DNSType, Value: dnsName, }.RawValue() if err != nil { return zero, err } rawValues = append(rawValues, rawValue) } for _, emailAddress := range emailAddresses { rawValue, err := x509util.SubjectAlternativeName{ Type: x509util.EmailType, Value: emailAddress, }.RawValue() if err != nil { return zero, err } rawValues = append(rawValues, rawValue) } for _, ip := range ipAddresses { rawValue, err := x509util.SubjectAlternativeName{ Type: x509util.IPType, Value: ip.String(), }.RawValue() if err != nil { return zero, err } rawValues = append(rawValues, rawValue) } for _, uri := range uris { rawValue, err := x509util.SubjectAlternativeName{ Type: x509util.URIType, Value: uri.String(), }.RawValue() if err != nil { return zero, err } rawValues = append(rawValues, rawValue) } for _, san := range sans { rawValue, err := san.RawValue() if err != nil { return zero, err } rawValues = append(rawValues, rawValue) } // Now marshal the rawValues into the ASN1 sequence, and create an Extension object to hold the extension rawBytes, err := asn1.Marshal(rawValues) if err != nil { return zero, fmt.Errorf("error marshaling SubjectAlternativeName extension to ASN1: %w", err) } return x509util.Extension{ ID: x509util.ObjectIdentifier(oidSubjectAlternativeName), Critical: subjectIsEmpty, Value: rawBytes, }, nil } func Test_tlsAlpn01ChallengeHost(t *testing.T) { type args struct { name string } tests := []struct { name string strictFQDN bool args args want string }{ {"dns", false, args{"smallstep.com"}, "smallstep.com"}, {"dns strict", true, args{"smallstep.com"}, "smallstep.com."}, {"rooted dns", false, 
args{"smallstep.com."}, "smallstep.com."}, {"rooted dns strict", true, args{"smallstep.com."}, "smallstep.com."}, {"ipv4", true, args{"1.2.3.4"}, "1.2.3.4"}, {"ipv6", true, args{"2607:f8b0:4023:1009::71"}, "2607:f8b0:4023:1009::71"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tmp := StrictFQDN t.Cleanup(func() { StrictFQDN = tmp }) StrictFQDN = tt.strictFQDN assert.Equal(t, tt.want, tlsAlpn01ChallengeHost(tt.args.name)) }) } } func Test_dns01ChallengeHost(t *testing.T) { type args struct { domain string } tests := []struct { name string strictFQDN bool args args want string }{ {"dns", false, args{"smallstep.com"}, "_acme-challenge.smallstep.com"}, {"dns strict", true, args{"smallstep.com"}, "_acme-challenge.smallstep.com."}, {"rooted dns", false, args{"smallstep.com."}, "_acme-challenge.smallstep.com."}, {"rooted dns strict", true, args{"smallstep.com."}, "_acme-challenge.smallstep.com."}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tmp := StrictFQDN t.Cleanup(func() { StrictFQDN = tmp }) StrictFQDN = tt.strictFQDN assert.Equal(t, tt.want, dns01ChallengeHost(tt.args.domain)) }) } } ================================================ FILE: acme/challenge_tpmsimulator_test.go ================================================ //go:build tpmsimulator package acme import ( "context" "crypto" "crypto/sha256" "crypto/x509" "encoding/asn1" "encoding/base64" "encoding/json" "encoding/pem" "errors" "fmt" "net/url" "testing" "github.com/fxamacker/cbor/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/smallstep/go-attestation/attest" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/tpm" "go.step.sm/crypto/tpm/simulator" tpmstorage "go.step.sm/crypto/tpm/storage" "go.step.sm/crypto/x509util" ) func newSimulatedTPM(t *testing.T) *tpm.TPM { t.Helper() tmpDir := t.TempDir() tpm, err := tpm.New(withSimulator(t), 
tpm.WithStore(tpmstorage.NewDirstore(tmpDir))) // TODO: provide in-memory storage implementation instead require.NoError(t, err) return tpm } func withSimulator(t *testing.T) tpm.NewTPMOption { t.Helper() var sim simulator.Simulator t.Cleanup(func() { if sim == nil { return } err := sim.Close() require.NoError(t, err) }) sim, err := simulator.New() require.NoError(t, err) err = sim.Open() require.NoError(t, err) return tpm.WithSimulator(sim) } func generateKeyID(t *testing.T, pub crypto.PublicKey) []byte { t.Helper() b, err := x509.MarshalPKIXPublicKey(pub) require.NoError(t, err) hash := sha256.Sum256(b) return hash[:] } func mustAttestTPM(t *testing.T, keyAuthorization string, permanentIdentifiers []string) ([]byte, crypto.Signer, *x509.Certificate) { t.Helper() aca, err := minica.New( minica.WithName("TPM Testing"), minica.WithGetSignerFunc( func() (crypto.Signer, error) { return keyutil.GenerateSigner("RSA", "", 2048) }, ), ) require.NoError(t, err) // prepare simulated TPM and create an AK stpm := newSimulatedTPM(t) eks, err := stpm.GetEKs(context.Background()) require.NoError(t, err) ak, err := stpm.CreateAK(context.Background(), "first-ak") require.NoError(t, err) require.NotNil(t, ak) // extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. 
ak.Public()) ap, err := ak.AttestationParameters(context.Background()) require.NoError(t, err) akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public) require.NoError(t, err) // create template and sign certificate for the AK public key keyID := generateKeyID(t, eks[0].Public()) template := &x509.Certificate{ PublicKey: akp.Public, IsCA: false, UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate}, } sans := []x509util.SubjectAlternativeName{} uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}} for _, pi := range permanentIdentifiers { sans = append(sans, x509util.SubjectAlternativeName{ Type: x509util.PermanentIdentifierType, Value: pi, }) } asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55")) sans = append(sans, x509util.SubjectAlternativeName{ Type: x509util.DirectoryNameType, ASN1Value: asn1Value, }) ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true) require.NoError(t, err) ext.Set(template) akCert, err := aca.Sign(template) require.NoError(t, err) require.NotNil(t, akCert) // create a new key attested by the AK, while including // the key authorization bytes as qualifying data. 
keyAuthSum := sha256.Sum256([]byte(keyAuthorization)) config := tpm.AttestKeyConfig{ Algorithm: "RSA", Size: 2048, QualifyingData: keyAuthSum[:], } key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config) require.NoError(t, err) require.NotNil(t, key) require.Equal(t, "first-key", key.Name()) require.NotEqual(t, 0, len(key.Data())) require.Equal(t, "first-ak", key.AttestedBy()) require.True(t, key.WasAttested()) require.True(t, key.WasAttestedBy(ak)) signer, err := key.Signer(context.Background()) require.NoError(t, err) // prepare the attestation object with the AK certificate chain, // the attested key, its metadata and the signature signed by the // AK. params, err := key.CertificationParameters(context.Background()) require.NoError(t, err) attObj, err := cbor.Marshal(struct { Format string `json:"fmt"` AttStatement map[string]interface{} `json:"attStmt,omitempty"` }{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }) require.NoError(t, err) // marshal the ACME payload payload, err := json.Marshal(struct { AttObj string `json:"attObj"` }{ AttObj: base64.RawURLEncoding.EncodeToString(attObj), }) require.NoError(t, err) return payload, signer, aca.Root } func Test_deviceAttest01ValidateWithTPMSimulator(t *testing.T) { type args struct { ctx context.Context ch *Challenge db DB jwk *jose.JSONWebKey payload []byte } type test struct { args args wantErr *Error } tests := map[string]func(t *testing.T) test{ "ok/doTPMAttestationFormat-storeError": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") payload, _, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert? 
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
			ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))

			// parse payload, set invalid "ver", remarshal
			var p payloadType
			err := json.Unmarshal(payload, &p)
			require.NoError(t, err)
			attObj, err := base64.RawURLEncoding.DecodeString(p.AttObj)
			require.NoError(t, err)
			att := attestationObject{}
			err = cbor.Unmarshal(attObj, &att)
			require.NoError(t, err)
			att.AttStatement["ver"] = "bogus"
			attObj, err = cbor.Marshal(struct {
				Format       string                 `json:"fmt"`
				AttStatement map[string]interface{} `json:"attStmt,omitempty"`
			}{
				Format:       "tpm",
				AttStatement: att.AttStatement,
			})
			require.NoError(t, err)
			payload, err = json.Marshal(struct {
				AttObj string `json:"attObj"`
			}{
				AttObj: base64.RawURLEncoding.EncodeToString(attObj),
			})
			require.NoError(t, err)
			return test{
				args: args{
					ctx: ctx,
					jwk: jwk,
					ch: &Challenge{
						ID:              "chID",
						AuthorizationID: "azID",
						Token:           "token",
						Type:            "device-attest-01",
						Status:          StatusPending,
						Value:           "device.id.12345678",
					},
					payload: payload,
					db: &MockDB{
						MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
							assert.Equal(t, "azID", id)
							return &Authorization{ID: "azID"}, nil
						},
						MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
							assert.Equal(t, "chID", updch.ID)
							assert.Equal(t, "token", updch.Token)
							assert.Equal(t, StatusInvalid, updch.Status)
							assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
							assert.Equal(t, "device.id.12345678", updch.Value)

							err := NewDetailedError(ErrorBadAttestationStatementType, `version "bogus" is not supported`)

							assert.EqualError(t, updch.Error.Err, err.Err.Error())
							assert.Equal(t, err.Type, updch.Error.Type)
							assert.Equal(t, err.Detail, updch.Error.Detail)
							assert.Equal(t, err.Status, updch.Error.Status)
							assert.Equal(t, err.Subproblems, updch.Error.Subproblems)

							return nil
						},
					},
				},
				wantErr: nil,
			}
		},
		"ok with invalid PermanentIdentifier SAN": func(t *testing.T) test {
			jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
			payload, _, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
			caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
			ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
			return test{
				args: args{
					ctx: ctx,
					jwk: jwk,
					ch: &Challenge{
						ID:              "chID",
						AuthorizationID: "azID",
						Token:           "token",
						Type:            "device-attest-01",
						Status:          StatusPending,
						Value:           "device.id.99999999",
					},
					payload: payload,
					db: &MockDB{
						MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
							assert.Equal(t, "azID", id)
							return &Authorization{ID: "azID"}, nil
						},
						MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
							assert.Equal(t, "chID", updch.ID)
							assert.Equal(t, "token", updch.Token)
							assert.Equal(t, StatusInvalid, updch.Status)
							assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
							assert.Equal(t, "device.id.99999999", updch.Value)

							err := NewDetailedError(ErrorBadAttestationStatementType, `permanent identifier does not match`).
								AddSubproblems(NewSubproblemWithIdentifier(
									ErrorRejectedIdentifierType,
									Identifier{Type: "permanent-identifier", Value: "device.id.99999999"},
									`challenge identifier "device.id.99999999" doesn't match any of the attested hardware identifiers ["device.id.12345678"]`,
								))

							assert.EqualError(t, updch.Error.Err, err.Err.Error())
							assert.Equal(t, err.Type, updch.Error.Type)
							assert.Equal(t, err.Detail, updch.Error.Detail)
							assert.Equal(t, err.Status, updch.Error.Status)
							assert.Equal(t, err.Subproblems, updch.Error.Subproblems)

							return nil
						},
					},
				},
				wantErr: nil,
			}
		},
		"ok": func(t *testing.T) test {
			jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
			payload, signer, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
			caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
			ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
			return test{
				args: args{
					ctx: ctx,
					jwk: jwk,
					ch: &Challenge{
						ID:              "chID",
						AuthorizationID: "azID",
						Token:           "token",
						Type:            "device-attest-01",
						Status:          StatusPending,
						Value:           "device.id.12345678",
					},
					payload: payload,
					db: &MockDB{
						MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
							assert.Equal(t, "azID", id)
							return &Authorization{ID: "azID"}, nil
						},
						MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
							fingerprint, err := keyutil.Fingerprint(signer.Public())
							assert.NoError(t, err)
							assert.Equal(t, "azID", az.ID)
							assert.Equal(t, fingerprint, az.Fingerprint)
							return nil
						},
						MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
							assert.Equal(t, "chID", updch.ID)
							assert.Equal(t, "token", updch.Token)
							assert.Equal(t, StatusValid, updch.Status)
							assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
							assert.Equal(t, "device.id.12345678", updch.Value)
							return nil
						},
					},
				},
				wantErr: nil,
			}
		},
		"ok with PermanentIdentifier SAN": func(t *testing.T) test {
			jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
			payload, signer, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
			caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
			ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
			return test{
				args: args{
					ctx: ctx,
					jwk: jwk,
					ch: &Challenge{
						ID:              "chID",
						AuthorizationID: "azID",
						Token:           "token",
						Type:            "device-attest-01",
						Status:          StatusPending,
						Value:           "device.id.12345678",
					},
					payload: payload,
					db: &MockDB{
						MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
							assert.Equal(t, "azID", id)
							return &Authorization{ID: "azID"}, nil
						},
						MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
							fingerprint, err := keyutil.Fingerprint(signer.Public())
							assert.NoError(t, err)
							assert.Equal(t, "azID", az.ID)
							assert.Equal(t, fingerprint, az.Fingerprint)
							return nil
						},
						MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
							assert.Equal(t, "chID", updch.ID)
							assert.Equal(t, "token", updch.Token)
							assert.Equal(t, StatusValid, updch.Status)
							assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
							assert.Equal(t, "device.id.12345678", updch.Value)
							return nil
						},
					},
				},
				wantErr: nil,
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			err := deviceAttest01Validate(tc.args.ctx, tc.args.ch, tc.args.db, tc.args.jwk, tc.args.payload)

			// wantErr is a *Error. Compare the pointer itself: converting a
			// nil *Error to the error interface yields a non-nil interface
			// value, so assert.Error would pass and the following call to
			// wantErr.Error() would run on a nil pointer.
			if tc.wantErr == nil {
				assert.NoError(t, err)
				return
			}
			// EqualError also fails when err is nil.
			assert.EqualError(t, err, tc.wantErr.Error())
		})
	}
}

// newBadAttestationStatementError builds the *Error the tests expect for a
// rejected attestation statement: badAttestationStatement type, status 400
// and msg as the wrapped error text.
func newBadAttestationStatementError(msg string) *Error {
	return &Error{
		Type:   "urn:ietf:params:acme:error:badAttestationStatement",
		Status: 400,
		Err:    errors.New(msg),
	}
}

// newInternalServerError builds the *Error the tests expect for a server-side
// failure: serverInternal type, status 500 and msg as the wrapped error text.
func newInternalServerError(msg string) *Error {
	return &Error{
		Type:   "urn:ietf:params:acme:error:serverInternal",
		Status: 500,
		Err:    errors.New(msg),
	}
}

var (
	oidPermanentIdentifier          = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}
	oidHardwareModuleNameIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 4}
)

func Test_doTPMAttestationFormat(t *testing.T) {
	ctx :=
context.Background() aca, err := minica.New( minica.WithName("TPM Testing"), minica.WithGetSignerFunc( func() (crypto.Signer, error) { return keyutil.GenerateSigner("RSA", "", 2048) }, ), ) require.NoError(t, err) acaRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: aca.Root.Raw}) // prepare simulated TPM and create an AK stpm := newSimulatedTPM(t) eks, err := stpm.GetEKs(context.Background()) require.NoError(t, err) ak, err := stpm.CreateAK(context.Background(), "first-ak") require.NoError(t, err) require.NotNil(t, ak) // extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public()) ap, err := ak.AttestationParameters(context.Background()) require.NoError(t, err) akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public) require.NoError(t, err) // create template and sign certificate for the AK public key keyID := generateKeyID(t, eks[0].Public()) template := &x509.Certificate{ PublicKey: akp.Public, IsCA: false, UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate}, } sans := []x509util.SubjectAlternativeName{} uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}} asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55")) sans = append(sans, x509util.SubjectAlternativeName{ Type: x509util.DirectoryNameType, ASN1Value: asn1Value, }) ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true) require.NoError(t, err) ext.Set(template) akCert, err := aca.Sign(template) require.NoError(t, err) require.NotNil(t, akCert) invalidTemplate := &x509.Certificate{ PublicKey: akp.Public, IsCA: false, UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate}, } invalidAKCert, err := aca.Sign(invalidTemplate) require.NoError(t, err) require.NotNil(t, 
invalidAKCert) // generate a JWK and the key authorization value jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) keyAuthorization, err := KeyAuthorization("token", jwk) require.NoError(t, err) // create a new key attested by the AK, while including // the key authorization bytes as qualifying data. keyAuthSum := sha256.Sum256([]byte(keyAuthorization)) config := tpm.AttestKeyConfig{ Algorithm: "RSA", Size: 2048, QualifyingData: keyAuthSum[:], } key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config) require.NoError(t, err) require.NotNil(t, key) params, err := key.CertificationParameters(context.Background()) require.NoError(t, err) signer, err := key.Signer(context.Background()) require.NoError(t, err) fingerprint, err := keyutil.Fingerprint(signer.Public()) require.NoError(t, err) // attest another key and get its certification parameters anotherKey, err := stpm.AttestKey(context.Background(), "first-ak", "another-key", config) require.NoError(t, err) require.NotNil(t, key) anotherKeyParams, err := anotherKey.CertificationParameters(context.Background()) require.NoError(t, err) type args struct { ctx context.Context prov Provisioner ch *Challenge jwk *jose.JSONWebKey att *attestationObject } tests := []struct { name string args args want *tpmAttestationData expErr *Error }{ {"ok", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, nil}, {"fail ver not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // 
RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("ver not present")}, {"fail ver type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": []interface{}{}, "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("ver not present")}, {"fail bogus ver", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "bogus", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError(`version "bogus" is not supported`)}, {"fail x5c not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c not present")}, {"fail x5c type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": [][]byte{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c not present")}, {"fail x5c empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ 
Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c is empty")}, {"fail leaf type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "step", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{"leaf", aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c is malformed")}, {"fail leaf parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "step", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw[:100], aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")}, {"fail intermediate type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "step", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, "intermediate"}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c is malformed")}, {"fail intermediate parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "step", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw[:100]}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": 
params.Public, }, }}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")}, {"fail roots", args{ctx, mustAttestationProvisioner(t, nil), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newInternalServerError("no root CA bundle available to verify the attestation certificate")}, {"fail verify", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "step", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("x5c is not valid: x509: certificate signed by unknown authority")}, {"fail validateAKCertificate", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{invalidAKCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("AK certificate is not valid: missing TPM manufacturer")}, {"fail pubArea not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, }, }}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")}, {"fail pubArea type", 
args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": []interface{}{}, }, }}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")}, {"fail pubArea empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": []byte{}, }, }}, nil, newBadAttestationStatementError("pubArea is empty")}, {"fail sig not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("invalid sig in attestation statement")}, {"fail sig type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": []interface{}{}, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("invalid sig in attestation statement")}, {"fail sig empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, 
"alg": int64(-257), // RS256 "sig": []byte{}, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("sig is empty")}, {"fail certInfo not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")}, {"fail certInfo type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": []interface{}{}, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")}, {"fail certInfo empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": []byte{}, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("certInfo is empty")}, {"fail alg not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("invalid alg in attestation statement")}, {"fail alg type", args{ctx, mustAttestationProvisioner(t, acaRoot), 
&Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(0), // invalid alg "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("invalid alg 0 in attestation statement")}, {"fail attestation verification", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": anotherKeyParams.Public, }, }}, nil, newBadAttestationStatementError("invalid certification parameters: certification refers to a different key")}, {"fail keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, &jose.JSONWebKey{Key: []byte("not an asymmetric key")}, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // RS256 "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newInternalServerError("failed creating key auth digest: error generating JWK thumbprint: go-jose/go-jose: unknown key type '[]uint8'")}, {"fail different keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "aDifferentToken"}, jwk, &attestationObject{ Format: "tpm", AttStatement: map[string]interface{}{ "ver": "2.0", "x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw}, "alg": int64(-257), // "sig": params.CreateSignature, "certInfo": params.CreateAttestation, "pubArea": params.Public, }, }}, nil, newBadAttestationStatementError("key authorization invalid")}, } for _, tt := range tests { t.Run(tt.name, 
func(t *testing.T) { got, err := doTPMAttestationFormat(tt.args.ctx, tt.args.prov, tt.args.ch, tt.args.jwk, tt.args.att) if tt.expErr != nil { var ae *Error if assert.True(t, errors.As(err, &ae)) { assert.EqualError(t, err, tt.expErr.Error()) assert.Equal(t, ae.StatusCode(), tt.expErr.StatusCode()) assert.Equal(t, ae.Type, tt.expErr.Type) } assert.Nil(t, got) return } assert.NoError(t, err) if assert.NotNil(t, got) { assert.Equal(t, akCert, got.Certificate) assert.Equal(t, [][]*x509.Certificate{ { akCert, aca.Intermediate, aca.Root, }, }, got.VerifiedChains) assert.Equal(t, fingerprint, got.Fingerprint) assert.Empty(t, got.PermanentIdentifiers) // currently expected to be always empty } }) } } ================================================ FILE: acme/challenge_wire_test.go ================================================ package acme import ( "context" "crypto" "crypto/ed25519" "encoding/base64" "encoding/json" "encoding/pem" "errors" "net/http/httptest" "strconv" "testing" "time" "github.com/smallstep/certificates/acme/wire" "github.com/smallstep/certificates/authority/provisioner" wireprovisioner "github.com/smallstep/certificates/authority/provisioner/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) func Test_wireDPOP01Validate(t *testing.T) { fakeKey := `-----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----` type test struct { ch *Challenge jwk *jose.JSONWebKey db WireDB payload []byte ctx context.Context expectedErr *Error } tests := map[string]func(t *testing.T) test{ "fail/no-provisioner": func(t *testing.T) test { return test{ ctx: context.Background(), db: &MockWireDB{}, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New("missing provisioner"), }, } }, "fail/no-linker": func(t *testing.T) test { ctx := 
NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "https://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, })) return test{ ctx: ctx, db: &MockWireDB{}, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New("missing linker"), }, } }, "fail/unmarshal": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "https://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ctx: ctx, payload: []byte("?!"), ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: "1234", }, db: &MockWireDB{}, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:malformed", Detail: "The request message was malformed", Status: 400, Err: errors.New(`error unmarshalling Wire DPoP challenge payload: invalid character '?' 
looking for beginning of value`), }, } }, "fail/wire-parse-id": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "https://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ctx: ctx, payload: []byte("{}"), ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: "1234", }, db: &MockWireDB{}, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New(`error unmarshalling challenge data: json: cannot unmarshal number into Go value of type wire.DeviceID`), }, } }, "fail/wire-parse-client-id": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "https://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", 
ClientID: "wireapp://594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) return test{ ctx: ctx, payload: []byte("{}"), ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, db: &MockWireDB{}, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New(`error parsing device id: invalid Wire client ID username "594930e9d50bb175"`), }, } }, "fail/parse-and-verify": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "{{ .DeviceID }}", SigningKey: []byte(fakeKey), }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) jwk, _ := mustAccountAndKeyAuthorization(t, "token") return test{ ctx: ctx, payload: []byte("{}"), jwk: jwk, ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, ch *Challenge) error { 
assert.Equal(t, "chID", ch.ID) assert.Equal(t, "azID", ch.AuthorizationID) assert.Equal(t, "accID", ch.AccountID) assert.Equal(t, "token", ch.Token) assert.Equal(t, ChallengeType("wire-dpop-01"), ch.Type) assert.Equal(t, StatusInvalid, ch.Status) assert.Equal(t, string(valueBytes), ch.Value) if assert.NotNil(t, ch.Error) { var k *Error // NOTE: the error is not returned up, but stored with the challenge instead if errors.As(ch.Error, &k) { assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type) assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail) assert.Equal(t, 400, k.Status) assert.Equal(t, `failed validating Wire access token: failed parsing token: go-jose/go-jose: compact JWS format must have three parts`, k.Err.Error()) } } return nil }, }, }, } }, "fail/db.UpdateChallenge": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") _ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation? 
dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string `json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: 
"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return errors.New("fail") }, }, }, expectedErr: &Error{ Type: 
"urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New(`error updating challenge: fail`), }, } }, "fail/db.GetAllOrdersByAccountID": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") _ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation? dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string `json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ 
Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx 
context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return nil }, }, MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) { assert.Equal(t, "accID", accountID) return nil, errors.New("fail") }, }, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New(`could not find current order by account id: fail`), }, } }, "fail/db.GetAllOrdersByAccountID-zero": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") _ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation? dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: 
"Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string `json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, 
Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return nil }, }, MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) { assert.Equal(t, "accID", accountID) return []string{}, nil }, }, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New(`there are not enough orders for this account for this custom OIDC challenge`), }, } }, "fail/db.CreateDpopToken": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") _ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation? 
dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string `json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: 
"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return nil }, }, MockGetAllOrdersByAccountID: func(ctx context.Context, 
accountID string) ([]string, error) { assert.Equal(t, "accID", accountID) return []string{"orderID"}, nil }, MockCreateDpopToken: func(ctx context.Context, orderID string, dpop map[string]interface{}) error { assert.Equal(t, "orderID", orderID) assert.Equal(t, "token", dpop["chal"].(string)) assert.Equal(t, "wireapp://%40alice_wire@wire.com", dpop["handle"].(string)) assert.Equal(t, "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", dpop["sub"].(string)) return errors.New("fail") }, }, expectedErr: &Error{ Type: "urn:ietf:params:acme:error:serverInternal", Detail: "The server experienced an internal error", Status: 500, Err: errors.New(`failed storing DPoP token: fail`), }, } }, "ok": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") _ = keyAuth // TODO(hs): keyAuth (not) required for DPoP? Or needs to be added to validation? dpopSigner, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), Key: jwk, }, new(jose.SignerOptions)) require.NoError(t, err) signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(signerJWK.Algorithm), Key: signerJWK, }, new(jose.SignerOptions)) require.NoError(t, err) signerPEMBlock, err := pemutil.Serialize(signerJWK.Public().Key) require.NoError(t, err) signerPEMBytes := pem.EncodeToMemory(signerPEMBlock) dpopBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, }, Challenge: "token", Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", Name: 
"Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) require.NoError(t, err) proof, err := dpop.CompactSerialize() require.NoError(t, err) tokenBytes, err := json.Marshal(struct { jose.Claims Challenge string `json:"chal,omitempty"` Nonce string `json:"nonce,omitempty"` Cnf struct { Kid string `json:"kid,omitempty"` } `json:"cnf"` Proof string `json:"proof,omitempty"` ClientID string `json:"client_id"` APIVersion int `json:"api_version"` Scope string `json:"scope"` }{ Claims: jose.Claims{ Issuer: "http://issuer.example.com", Audience: jose.Audience{"https://ca.example.com/acme/wire/challenge/azID/chID"}, Expiry: jose.NewNumericDate(time.Now().Add(1 * time.Minute)), }, Challenge: "token", Nonce: "nonce", Cnf: struct { Kid string `json:"kid,omitempty"` }{ Kid: jwk.KeyID, }, Proof: proof, ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", APIVersion: 5, Scope: "wire_client_id", }) require.NoError(t, err) signed, err := signer.Sign(tokenBytes) require.NoError(t, err) accessToken, err := signed.CompactSerialize() require.NoError(t, err) payload, err := json.Marshal(struct { AccessToken string `json:"access_token"` }{ AccessToken: accessToken, }) require.NoError(t, err) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` Domain string `json:"domain,omitempty"` ClientID string `json:"client-id,omitempty"` Handle string `json:"handle,omitempty"` }{ Name: "Alice Smith", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com", }) require.NoError(t, err) ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "test", SignatureAlgorithms: []string{"ES256"}, 
Now: time.Now, }, TransformTemplate: "", }, DPOP: &wireprovisioner.DPOPOptions{ Target: "http://issuer.example.com", SigningKey: signerPEMBytes, }, }, })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ch: &Challenge{ ID: "chID", AuthorizationID: "azID", AccountID: "accID", Token: "token", Type: "wire-dpop-01", Status: StatusPending, Value: string(valueBytes), }, payload: payload, ctx: ctx, jwk: jwk, db: &MockWireDB{ MockDB: MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { assert.Equal(t, "chID", updch.ID) assert.Equal(t, "token", updch.Token) assert.Equal(t, StatusValid, updch.Status) assert.Equal(t, ChallengeType("wire-dpop-01"), updch.Type) assert.Equal(t, string(valueBytes), updch.Value) return nil }, }, MockGetAllOrdersByAccountID: func(ctx context.Context, accountID string) ([]string, error) { assert.Equal(t, "accID", accountID) return []string{"orderID"}, nil }, MockCreateDpopToken: func(ctx context.Context, orderID string, dpop map[string]interface{}) error { assert.Equal(t, "orderID", orderID) assert.Equal(t, "token", dpop["chal"].(string)) assert.Equal(t, "wireapp://%40alice_wire@wire.com", dpop["handle"].(string)) assert.Equal(t, "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", dpop["sub"].(string)) return nil }, }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) err := wireDPOP01Validate(tc.ctx, tc.ch, tc.db, tc.jwk, tc.payload) if tc.expectedErr != nil { var k *Error if errors.As(err, &k) { assert.Equal(t, tc.expectedErr.Type, k.Type) assert.Equal(t, tc.expectedErr.Detail, k.Detail) assert.Equal(t, tc.expectedErr.Status, k.Status) assert.Equal(t, tc.expectedErr.Err.Error(), k.Err.Error()) } else { assert.Fail(t, "unexpected error type") } return } assert.NoError(t, err) }) } }

// Test_wireOIDC01Validate runs wireOIDC01Validate against a table of cases:
// missing provisioner or linker in the context, an unparsable payload, an
// unparsable challenge value, ID tokens that are rejected, database failures
// after a successful validation, and the fully successful path.
func Test_wireOIDC01Validate(t *testing.T) {
	// PEM public key configured as the DPoP signing key of every provisioner
	// built in this test.
	dpopKeyPEM := `-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
-----END PUBLIC KEY-----`

	// Identifier values shared by the challenge value and the ID token claims.
	const (
		wireClientID = "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com"
		wireHandle   = "wireapp://%40alice_wire@wire.com"
	)

	type test struct {
		ch          *Challenge
		jwk         *jose.JSONWebKey
		db          WireDB
		payload     []byte
		srv         *httptest.Server
		ctx         context.Context
		expectedErr *Error
	}

	// tokenCase describes how the signed ID token of a case is put together.
	type tokenCase struct {
		preferredUsername string // value of the preferred_username claim
		keyAuth           string // when non-empty, used as the keyauth claim instead of the account key authorization
		foreignJWKS       bool   // when true, the JWKS server publishes a key other than the one signing the token
	}

	// serverInternal builds the expected error for the serverInternal cases.
	serverInternal := func(msg string) *Error {
		return &Error{
			Type:   "urn:ietf:params:acme:error:serverInternal",
			Detail: "The server experienced an internal error",
			Status: 500,
			Err:    errors.New(msg),
		}
	}

	// provisionerCtx returns a context carrying a Wire provisioner whose OIDC
	// provider uses the given issuer and JWKS URLs.
	provisionerCtx := func(t *testing.T, issuerURL, jwksURL string) context.Context {
		return NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{
			Wire: &wireprovisioner.Options{
				OIDC: &wireprovisioner.OIDCOptions{
					Provider: &wireprovisioner.Provider{
						IssuerURL:  issuerURL,
						JWKSURL:    jwksURL,
						Algorithms: []string{"ES256"},
					},
					Config: &wireprovisioner.Config{
						ClientID:            "test",
						SignatureAlgorithms: []string{"ES256"},
						Now:                 time.Now,
					},
					TransformTemplate: "",
				},
				DPOP: &wireprovisioner.DPOPOptions{
					SigningKey: []byte(dpopKeyPEM),
				},
			},
		}))
	}

	// withLinker adds the linker used by all cases that get past the
	// provisioner lookup.
	withLinker := func(ctx context.Context) context.Context {
		return NewLinkerContext(ctx, NewLinker("ca.example.com", "acme"))
	}

	// pendingChallenge returns the pending wire-oidc-01 challenge under test
	// with the given identifier value.
	pendingChallenge := func(value string) *Challenge {
		return &Challenge{
			ID:              "chID",
			AuthorizationID: "azID",
			AccountID:       "accID",
			Token:           "token",
			Type:            "wire-oidc-01",
			Status:          StatusPending,
			Value:           value,
		}
	}

	// tokenFixture creates the account key, a signed ID token, the JWKS server
	// and the provisioner/linker context for a case. The returned test still
	// needs its db (and, where applicable, expectedErr) to be filled in.
	tokenFixture := func(t *testing.T, tk tokenCase) test {
		accountJWK, keyAuth := mustAccountAndKeyAuthorization(t, "token")
		if tk.keyAuth != "" {
			keyAuth = tk.keyAuth
		}

		signingJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
		require.NoError(t, err)
		signer, err := jose.NewSigner(jose.SigningKey{
			Algorithm: jose.SignatureAlgorithm(signingJWK.Algorithm),
			Key:       signingJWK,
		}, new(jose.SignerOptions))
		require.NoError(t, err)

		publishedJWK := signingJWK
		if tk.foreignJWKS {
			publishedJWK, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			require.NoError(t, err)
		}
		srv := mustJWKServer(t, publishedJWK.Public())

		claims, err := json.Marshal(struct {
			jose.Claims
			Name              string `json:"name,omitempty"`
			PreferredUsername string `json:"preferred_username,omitempty"`
			KeyAuth           string `json:"keyauth"`
			ACMEAudience      string `json:"acme_aud"`
		}{
			Claims: jose.Claims{
				Issuer:   srv.URL,
				Audience: []string{"test"},
				Expiry:   jose.NewNumericDate(time.Now().Add(1 * time.Minute)),
			},
			Name:              "Alice Smith",
			PreferredUsername: tk.preferredUsername,
			KeyAuth:           keyAuth,
			ACMEAudience:      "https://ca.example.com/acme/wire/challenge/azID/chID",
		})
		require.NoError(t, err)
		jws, err := signer.Sign(claims)
		require.NoError(t, err)
		idToken, err := jws.CompactSerialize()
		require.NoError(t, err)

		payload, err := json.Marshal(struct {
			IDToken string `json:"id_token"`
		}{
			IDToken: idToken,
		})
		require.NoError(t, err)

		value, err := json.Marshal(struct {
			Name     string `json:"name,omitempty"`
			Domain   string `json:"domain,omitempty"`
			ClientID string `json:"client-id,omitempty"`
			Handle   string `json:"handle,omitempty"`
		}{
			Name:     "Alice Smith",
			Domain:   "wire.com",
			ClientID: wireClientID,
			Handle:   wireHandle,
		})
		require.NoError(t, err)

		return test{
			ch:      pendingChallenge(string(value)),
			srv:     srv,
			payload: payload,
			ctx:     withLinker(provisionerCtx(t, srv.URL, srv.URL+"/keys")),
			jwk:     accountJWK,
		}
	}

	// validUpdate returns an UpdateChallenge mock that expects the challenge
	// to be stored as valid and then returns ret.
	validUpdate := func(t *testing.T, value string, ret error) func(context.Context, *Challenge) error {
		return func(_ context.Context, updch *Challenge) error {
			assert.Equal(t, "chID", updch.ID)
			assert.Equal(t, "token", updch.Token)
			assert.Equal(t, StatusValid, updch.Status)
			assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
			assert.Equal(t, value, updch.Value)
			return ret
		}
	}

	// rejectedUpdate returns an UpdateChallenge mock that expects the
	// challenge to be stored as invalid with a rejectedIdentifier error;
	// checkMsg inspects the message of the wrapped error.
	rejectedUpdate := func(t *testing.T, value string, checkMsg func(t *testing.T, msg string)) func(context.Context, *Challenge) error {
		return func(_ context.Context, updch *Challenge) error {
			assert.Equal(t, "chID", updch.ID)
			assert.Equal(t, "token", updch.Token)
			assert.Equal(t, StatusInvalid, updch.Status)
			assert.Equal(t, ChallengeType("wire-oidc-01"), updch.Type)
			assert.Equal(t, value, updch.Value)
			if assert.NotNil(t, updch.Error) {
				var k *Error // NOTE: the error is not returned up, but stored with the challenge instead
				if errors.As(updch.Error, &k) {
					assert.Equal(t, "urn:ietf:params:acme:error:rejectedIdentifier", k.Type)
					assert.Equal(t, "The server will not issue certificates for the identifier", k.Detail)
					assert.Equal(t, 400, k.Status)
					checkMsg(t, k.Err.Error())
				}
			}
			return nil
		}
	}

	// ordersFor returns a GetAllOrdersByAccountID mock that checks the account
	// ID and then returns the given result.
	ordersFor := func(t *testing.T, orderIDs []string, ret error) func(context.Context, string) ([]string, error) {
		return func(_ context.Context, accountID string) ([]string, error) {
			assert.Equal(t, "accID", accountID)
			return orderIDs, ret
		}
	}

	// storeIDToken returns a CreateOidcToken mock that checks the order ID and
	// the stored claims and then returns ret.
	storeIDToken := func(t *testing.T, ret error) func(context.Context, string, map[string]interface{}) error {
		return func(_ context.Context, orderID string, idToken map[string]interface{}) error {
			assert.Equal(t, "orderID", orderID)
			assert.Equal(t, "Alice Smith", idToken["name"].(string))
			assert.Equal(t, "wireapp://%40alice_wire@wire.com", idToken["preferred_username"].(string))
			return ret
		}
	}

	tests := map[string]func(t *testing.T) test{
		"fail/no-provisioner": func(t *testing.T) test {
			return test{
				ctx:         context.Background(),
				db:          &MockWireDB{},
				expectedErr: serverInternal("missing provisioner"),
			}
		},
		"fail/no-linker": func(t *testing.T) test {
			return test{
				ctx:         provisionerCtx(t, "https://issuer.example.com", ""),
				db:          &MockWireDB{},
				expectedErr: serverInternal("missing linker"),
			}
		},
		"fail/unmarshal": func(t *testing.T) test {
			return test{
				ctx:     withLinker(provisionerCtx(t, "https://issuer.example.com", "")),
				payload: []byte("?!"),
				ch:      pendingChallenge("1234"),
				db: &MockWireDB{
					MockDB: MockDB{
						MockUpdateChallenge: func(_ context.Context, ch *Challenge) error {
							assert.Equal(t, "chID", ch.ID)
							return nil
						},
					},
				},
				expectedErr: &Error{
					Type:   "urn:ietf:params:acme:error:malformed",
					Detail: "The request message was malformed",
					Status: 400,
					Err:    errors.New(`error unmarshalling Wire OIDC challenge payload: invalid character '?' looking for beginning of value`),
				},
			}
		},
		"fail/wire-parse-id": func(t *testing.T) test {
			return test{
				ctx:         withLinker(provisionerCtx(t, "https://issuer.example.com", "")),
				payload:     []byte("{}"),
				ch:          pendingChallenge("1234"),
				db:          &MockWireDB{},
				expectedErr: serverInternal(`error unmarshalling challenge data: json: cannot unmarshal number into Go value of type wire.UserID`),
			}
		},
		"fail/verify": func(t *testing.T) test {
			// The JWKS server publishes a different key than the one that
			// signed the token.
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle, foreignJWKS: true})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: rejectedUpdate(t, tc.ch.Value, func(t *testing.T, msg string) {
						assert.Equal(t, `error verifying ID token signature: failed to verify signature: failed to verify id token signature`, msg)
					}),
				},
			}
			return tc
		},
		"fail/keyauth-mismatch": func(t *testing.T) test {
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle, keyAuth: "wrong-keyauth"})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: rejectedUpdate(t, tc.ch.Value, func(t *testing.T, msg string) {
						assert.Contains(t, msg, "keyAuthorization does not match")
					}),
				},
			}
			return tc
		},
		"fail/validateWireOIDCClaims": func(t *testing.T) test {
			// The preferred_username claim differs from the challenge handle.
			tc := tokenFixture(t, tokenCase{preferredUsername: "wireapp://%40bob@wire.com"})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: rejectedUpdate(t, tc.ch.Value, func(t *testing.T, msg string) {
						assert.Equal(t, `claims in OIDC ID token don't match: invalid 'preferred_username' "wireapp://%40bob@wire.com" after transformation`, msg)
					}),
				},
			}
			return tc
		},
		"fail/db.UpdateChallenge": func(t *testing.T) test {
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: validUpdate(t, tc.ch.Value, errors.New("fail")),
				},
			}
			tc.expectedErr = serverInternal(`error updating challenge: fail`)
			return tc
		},
		"fail/db.GetAllOrdersByAccountID": func(t *testing.T) test {
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: validUpdate(t, tc.ch.Value, nil),
				},
				MockGetAllOrdersByAccountID: ordersFor(t, nil, errors.New("fail")),
			}
			tc.expectedErr = serverInternal(`could not retrieve current order by account id: fail`)
			return tc
		},
		"fail/db.GetAllOrdersByAccountID-zero": func(t *testing.T) test {
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: validUpdate(t, tc.ch.Value, nil),
				},
				MockGetAllOrdersByAccountID: ordersFor(t, []string{}, nil),
			}
			tc.expectedErr = serverInternal(`there are not enough orders for this account for this custom OIDC challenge`)
			return tc
		},
		"fail/db.CreateOidcToken": func(t *testing.T) test {
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: validUpdate(t, tc.ch.Value, nil),
				},
				MockGetAllOrdersByAccountID: ordersFor(t, []string{"orderID"}, nil),
				MockCreateOidcToken:         storeIDToken(t, errors.New("fail")),
			}
			tc.expectedErr = serverInternal(`failed storing OIDC id token: fail`)
			return tc
		},
		"ok/wire-oidc-01": func(t *testing.T) test {
			tc := tokenFixture(t, tokenCase{preferredUsername: wireHandle})
			tc.db = &MockWireDB{
				MockDB: MockDB{
					MockUpdateChallenge: validUpdate(t, tc.ch.Value, nil),
				},
				MockGetAllOrdersByAccountID: ordersFor(t, []string{"orderID"}, nil),
				MockCreateOidcToken:         storeIDToken(t, nil),
			}
			return tc
		},
	}

	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			if tc.srv != nil {
				defer tc.srv.Close()
			}

			err := wireOIDC01Validate(tc.ctx, tc.ch, tc.db, tc.jwk, tc.payload)
			if tc.expectedErr == nil {
				assert.NoError(t, err)
				return
			}

			var k *Error
			if !errors.As(err, &k) {
				assert.Fail(t, "unexpected error type")
				return
			}
			assert.Equal(t, tc.expectedErr.Type, k.Type)
			assert.Equal(t, tc.expectedErr.Detail, k.Detail)
			assert.Equal(t, tc.expectedErr.Status, k.Status)
			assert.Equal(t, tc.expectedErr.Err.Error(), k.Err.Error())
		})
	}
}

func Test_parseAndVerifyWireAccessToken(t *testing.T) { t.Skip("skip until we can retrieve public key from e2e
test, so that we can actually verify the token") key := ` -----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEAB2IYqBWXAouDt3WcCZgCM3t9gumMEKMlgMsGenSu+fA= -----END PUBLIC KEY-----` publicKey, err := pemutil.Parse([]byte(key)) require.NoError(t, err) pk, ok := publicKey.(ed25519.PublicKey) require.True(t, ok) issuer := "http://wire.com:19983/clients/7a41cf5b79683410/access-token" wireID := wire.DeviceID{ ClientID: "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", Handle: "wireapp://%40alice_wire@wire.com", } token := `eyJhbGciOiJFZERTQSIsInR5cCI6ImF0K2p3dCIsImp3ayI6eyJrdHkiOiJPS1AiLCJjcnYiOiJFZDI1NTE5IiwieCI6Im8zcWZhQ045a2FzSnZJRlhPdFNMTGhlYW0wTE5jcVF5MHdBMk9PeFRRNW8ifX0.eyJpYXQiOjE3MDU0OTc3MzksImV4cCI6MTcwNTUwMTY5OSwibmJmIjoxNzA1NDk3NzM5LCJpc3MiOiJodHRwOi8vd2lyZS5jb206MTY4MjQvY2xpZW50cy8zN2ZlOThiZDQwZDBkZmUvYWNjZXNzLXRva2VuIiwic3ViIjoid2lyZWFwcDovLzE4NXdIUmtRVHdTOTVGODhaZTQ1SlEhMzdmZTk4YmQ0MGQwZGZlQHdpcmUuY29tIiwiYXVkIjoiaHR0cHM6Ly9zdGVwY2E6NTUwMjMvYWNtZS93aXJlL2NoYWxsZW5nZS9SeEdSWGVoRGxCcHcxNTJQTVUzem0xY2M0cEtGcHVWRi9RWnRFazdQNUVFRXhadHBSYngydjVoYlc3QXB1S2NOSSIsImp0aSI6ImU1MzllODYzLTRkNTgtNGMwMS1iYjk3LTYwODdiNTEzOWIyMCIsIm5vbmNlIjoiUzJKYWVWcExkV28wUkZKaFFrWndXR0ZKY0VoVlFrNUxXVGd4WkhkRFVqQSIsImNoYWwiOiIyaDFPdUdxbTBKUXd6bHVsWGtLSTJEMGZiRDgzRUIxdyIsImNuZiI6eyJraWQiOiJhSEY3MVhYeG0tTWE5Q05zSjNaU1RKTjlYS0ZxOFFmOGh2UTJLN3NLQmQ4In0sInByb29mIjoiZXlKaGJHY2lPaUpGWkVSVFFTSXNJblI1Y0NJNkltUndiM0FyYW5kMElpd2lhbmRySWpwN0ltdDBlU0k2SWs5TFVDSXNJbU55ZGlJNklrVmtNalUxTVRraUxDSjRJam9pWVVsaVMwcFBha0poWXpZeVF6TnRhVmhHVjAxb09ITTJkRXQzUkROaGNHRnVSMHBQZURaVVFYVklRU0o5ZlEuZXlKcFlYUWlPakUzTURVME9UYzNNemtzSW1WNGNDSTZNVGN3TlRVd05Ea3pPU3dpYm1KbUlqb3hOekExTkRrM056TTVMQ0p6ZFdJaU9pSjNhWEpsWVhCd09pOHZNVGcxZDBoU2ExRlVkMU01TlVZNE9GcGxORFZLVVNFek4yWmxPVGhpWkRRd1pEQmtabVZBZDJseVpTNWpiMjBpTENKaGRXUWlPaUpvZEhSd2N6b3ZMM04wWlhCallUbzFOVEF5TXk5aFkyMWxMM2RwY21VdlkyaGhiR3hsYm1kbEwxSjRSMUpZWldoRWJFSndkekUxTWxCTlZUTjZiVEZqWXpSd1MwWndkVlpHTDFGYWRFVnJOMUExUlVWRmVGcDBjRkppZURKMk5XaGlWemRCY0hWTFkwNUpJaXdpYW5ScElqb2
lNV1kxTUdRM1lUQXRaamt6WmkwME5XWXdMV0V3TWpBdE1ETm1NREJpTlRreVlUUmtJaXdpYm05dVkyVWlPaUpUTWtwaFpWWndUR1JYYnpCU1JrcG9VV3RhZDFkSFJrcGpSV2hXVVdzMVRGZFVaM2hhU0dSRVZXcEJJaXdpYUhSdElqb2lVRTlUVkNJc0ltaDBkU0k2SW1oMGRIQTZMeTkzYVhKbExtTnZiVG94TmpneU5DOWpiR2xsYm5Sekx6TTNabVU1T0dKa05EQmtNR1JtWlM5aFkyTmxjM010ZEc5clpXNGlMQ0pqYUdGc0lqb2lNbWd4VDNWSGNXMHdTbEYzZW14MWJGaHJTMGt5UkRCbVlrUTRNMFZDTVhjaUxDSm9ZVzVrYkdVaU9pSjNhWEpsWVhCd09pOHZKVFF3WVd4cFkyVmZkMmx5WlVCM2FYSmxMbU52YlNJc0luUmxZVzBpT2lKM2FYSmxJbjAuZlNmQnFuWWlfMTRhZEc5MDAyZ0RJdEgybXNyYW55eXVnR0g5bHpFcmprdmRGbkRPOFRVWWRYUXJKUzdlX3BlU0lzcGxlRUVkaGhzc0gwM3FBWHY2QXciLCJjbGllbnRfaWQiOiJ3aXJlYXBwOi8vMTg1d0hSa1FUd1M5NUY4OFplNDVKUSEzN2ZlOThiZDQwZDBkZmVAd2lyZS5jb20iLCJhcGlfdmVyc2lvbiI6NSwic2NvcGUiOiJ3aXJlX2NsaWVudF9pZCJ9.GKK7ZsJ8EWJjeaHqf8P48H9mluJhxyXUmI0FO3xstda3XDJIK7Z5Ur4hi1OIJB0ZsS5BqRVT2q5whL4KP9hZCA` ch := &Challenge{ Token: "bXUGNpUfcRx3EhB34xP3y62aQZoGZS6j", } issuedAtUnix, err := strconv.ParseInt("1704985205", 10, 64) require.NoError(t, err) issuedAt := time.Unix(issuedAtUnix, 0) jwkBytes := []byte(`{"crv": "Ed25519", "kty": "OKP", "x": "1L1eH2a6AgVvzTp5ZalKRfq6pVPOtEjI7h8TPzBYFgM"}`) var accountJWK jose.JSONWebKey json.Unmarshal(jwkBytes, &accountJWK) rawKid, err := accountJWK.Thumbprint(crypto.SHA256) require.NoError(t, err) accountJWK.KeyID = base64.RawURLEncoding.EncodeToString(rawKid) at, dpop, err := parseAndVerifyWireAccessToken(wireVerifyParams{ token: token, tokenKey: pk, dpopKey: accountJWK.Public(), dpopKeyID: accountJWK.KeyID, issuer: issuer, wireID: wireID, chToken: ch.Token, t: issuedAt.Add(1 * time.Minute), // set validation time to be one minute after issuance }) if assert.NoError(t, err) { // token assertions assert.Equal(t, "42c46d4c-e510-4175-9fb5-d055e125a49d", at.ID) assert.Equal(t, "http://wire.com:19983/clients/7a41cf5b79683410/access-token", at.Issuer) assert.Equal(t, "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", at.Subject) assert.Contains(t, at.Audience, 
"http://wire.com:19983/clients/7a41cf5b79683410/access-token") assert.Equal(t, "bXUGNpUfcRx3EhB34xP3y62aQZoGZS6j", at.Challenge) assert.Equal(t, "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", at.ClientID) assert.Equal(t, 5, at.APIVersion) assert.Equal(t, "wire_client_id", at.Scope) if assert.NotNil(t, at.Cnf) { assert.Equal(t, "oMWfNDJQsI5cPlXN5UoBNncKtc4f2dq2vwCjjXsqw7Q", at.Cnf.Kid) } // dpop proof assertions dt := *dpop assert.Equal(t, "bXUGNpUfcRx3EhB34xP3y62aQZoGZS6j", dt["chal"].(string)) assert.Equal(t, "wireapp://%40alice_wire@wire.com", dt["handle"].(string)) assert.Equal(t, "POST", dt["htm"].(string)) assert.Equal(t, "http://wire.com:19983/clients/7a41cf5b79683410/access-token", dt["htu"].(string)) assert.Equal(t, "5e6684cb-6b48-468d-b091-ff04bed6ec2e", dt["jti"].(string)) assert.Equal(t, "UEJyR2dqOEhzZFJEYWJBaTkyODNEYTE2aEs0dHIxcEc", dt["nonce"].(string)) assert.Equal(t, "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", dt["sub"].(string)) assert.Equal(t, "wire", dt["team"].(string)) } } func Test_validateWireOIDCClaims(t *testing.T) { fakeKey := ` -----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----` opts := &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "http://dex:15818/dex", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "wireapp", SignatureAlgorithms: []string{"RS256"}, Now: func() time.Time { return time.Date(2024, 1, 12, 18, 32, 41, 0, time.UTC) // (Token Expiry: 2024-01-12 21:32:42 +0100 CET) }, InsecureSkipSignatureCheck: true, // skipping signature check for this specific test }, TransformTemplate: `{"name": "{{ .preferred_username }}", "preferred_username": "{{ .name }}"}`, }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, } err := opts.Validate() require.NoError(t, err) idTokenString := 
`eyJhbGciOiJSUzI1NiIsImtpZCI6IjZhNDZlYzQ3YTQzYWI1ZTc4NzU3MzM5NWY1MGY4ZGQ5MWI2OTM5MzcifQ.eyJpc3MiOiJodHRwOi8vZGV4OjE1ODE4L2RleCIsInN1YiI6IkNqcDNhWEpsWVhCd09pOHZTMmh0VjBOTFpFTlRXakoyT1dWTWFHRk9XVlp6WnlFeU5UZzFNVEpoT0RRek5qTXhaV1V6UUhkcGNtVXVZMjl0RWdSc1pHRnciLCJhdWQiOiJ3aXJlYXBwIiwiZXhwIjoxNzA1MDkxNTYyLCJpYXQiOjE3MDUwMDUxNjIsIm5vbmNlIjoib0VjUzBRQUNXLVIyZWkxS09wUmZ2QSIsImF0X2hhc2giOiJoYzk0NmFwS25FeEV5TDVlSzJZMzdRIiwiY19oYXNoIjoidmRubFp2V1d1bVd1Z2NYR1JpOU5FUSIsIm5hbWUiOiJ3aXJlYXBwOi8vJTQwYWxpY2Vfd2lyZUB3aXJlLmNvbSIsInByZWZlcnJlZF91c2VybmFtZSI6IkFsaWNlIFNtaXRoIn0.aEBhWJugBJ9J_0L_4odUCg8SR8HMXVjd__X8uZRo42BSJQQO7-wdpy0jU3S4FOX9fQKr68wD61gS_QsnhfiT7w9U36mLpxaYlNVDCYfpa-gklVFit_0mjUOukXajTLK6H527TGiSss8z22utc40ckS1SbZa2BzKu3yOcqnFHUQwQc5sLYfpRABTB6WBoYFtnWDzdpyWJDaOzz7lfKYv2JBnf9vV8u8SYm-6gNKgtiQ3UUnjhIVUjdfHet2BMvmV2ooZ8V441RULCzKKG_sWZba-D_k_TOnSholGobtUOcKHlmVlmfUe8v7kuyBdhbPcembfgViaNldLQGKZjZfgvLg` ctx := context.Background() o := opts.GetOIDCOptions() verifier, err := o.GetVerifier(ctx) require.NoError(t, err) idToken, err := verifier.Verify(ctx, idTokenString) require.NoError(t, err) wireID := wire.UserID{ Name: "Alice Smith", Handle: "wireapp://%40alice_wire@wire.com", } got, err := validateWireOIDCClaims(o, idToken, wireID) assert.NoError(t, err) assert.Equal(t, "wireapp://%40alice_wire@wire.com", got["preferred_username"].(string)) assert.Equal(t, "Alice Smith", got["name"].(string)) assert.Equal(t, "http://dex:15818/dex", got["iss"].(string)) } func createWireOptions(t *testing.T, transformTemplate string) *wireprovisioner.Options { t.Helper() fakeKey := ` -----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----` opts := &wireprovisioner.Options{ OIDC: &wireprovisioner.OIDCOptions{ Provider: &wireprovisioner.Provider{ IssuerURL: "https://issuer.example.com", Algorithms: []string{"ES256"}, }, Config: &wireprovisioner.Config{ ClientID: "unit test", SignatureAlgorithms: []string{"ES256"}, Now: time.Now, }, 
TransformTemplate: transformTemplate, }, DPOP: &wireprovisioner.DPOPOptions{ SigningKey: []byte(fakeKey), }, } err := opts.Validate() require.NoError(t, err) return opts } func Test_idTokenTransformation(t *testing.T) { // {"name": "wireapp://%40alice_wire@wire.com", "preferred_username": "Alice Smith", "iss": "http://dex:15818/dex", ...} idTokenString := `eyJhbGciOiJSUzI1NiIsImtpZCI6IjZhNDZlYzQ3YTQzYWI1ZTc4NzU3MzM5NWY1MGY4ZGQ5MWI2OTM5MzcifQ.eyJpc3MiOiJodHRwOi8vZGV4OjE1ODE4L2RleCIsInN1YiI6IkNqcDNhWEpsWVhCd09pOHZTMmh0VjBOTFpFTlRXakoyT1dWTWFHRk9XVlp6WnlFeU5UZzFNVEpoT0RRek5qTXhaV1V6UUhkcGNtVXVZMjl0RWdSc1pHRnciLCJhdWQiOiJ3aXJlYXBwIiwiZXhwIjoxNzA1MDkxNTYyLCJpYXQiOjE3MDUwMDUxNjIsIm5vbmNlIjoib0VjUzBRQUNXLVIyZWkxS09wUmZ2QSIsImF0X2hhc2giOiJoYzk0NmFwS25FeEV5TDVlSzJZMzdRIiwiY19oYXNoIjoidmRubFp2V1d1bVd1Z2NYR1JpOU5FUSIsIm5hbWUiOiJ3aXJlYXBwOi8vJTQwYWxpY2Vfd2lyZUB3aXJlLmNvbSIsInByZWZlcnJlZF91c2VybmFtZSI6IkFsaWNlIFNtaXRoIn0.aEBhWJugBJ9J_0L_4odUCg8SR8HMXVjd__X8uZRo42BSJQQO7-wdpy0jU3S4FOX9fQKr68wD61gS_QsnhfiT7w9U36mLpxaYlNVDCYfpa-gklVFit_0mjUOukXajTLK6H527TGiSss8z22utc40ckS1SbZa2BzKu3yOcqnFHUQwQc5sLYfpRABTB6WBoYFtnWDzdpyWJDaOzz7lfKYv2JBnf9vV8u8SYm-6gNKgtiQ3UUnjhIVUjdfHet2BMvmV2ooZ8V441RULCzKKG_sWZba-D_k_TOnSholGobtUOcKHlmVlmfUe8v7kuyBdhbPcembfgViaNldLQGKZjZfgvLg` var claims struct { Name string `json:"name,omitempty"` Handle string `json:"preferred_username,omitempty"` Issuer string `json:"iss,omitempty"` } idToken, err := jose.ParseSigned(idTokenString) require.NoError(t, err) err = idToken.UnsafeClaimsWithoutVerification(&claims) require.NoError(t, err) // original token contains "Alice Smith" as handle, and name as "wireapp://%40alice_wire@wire.com" assert.Equal(t, "Alice Smith", claims.Handle) assert.Equal(t, "wireapp://%40alice_wire@wire.com", claims.Name) assert.Equal(t, "http://dex:15818/dex", claims.Issuer) var m map[string]any err = idToken.UnsafeClaimsWithoutVerification(&m) require.NoError(t, err) opts := createWireOptions(t, "") // uses default transformation template 
result, err := opts.GetOIDCOptions().Transform(m) require.NoError(t, err) // default transformation sets preferred username to handle; name as name assert.Equal(t, "Alice Smith", result["preferred_username"].(string)) assert.Equal(t, "wireapp://%40alice_wire@wire.com", result["name"].(string)) assert.Equal(t, "http://dex:15818/dex", result["iss"].(string)) // swap the preferred_name and the name swap := `{"name": "{{ .preferred_username }}", "preferred_username": "{{ .name }}"}` opts = createWireOptions(t, swap) result, err = opts.GetOIDCOptions().Transform(m) require.NoError(t, err) // with the transformation, handle now contains wireapp://%40alice_wire@wire.com, name contains Alice Smith assert.Equal(t, "wireapp://%40alice_wire@wire.com", result["preferred_username"].(string)) assert.Equal(t, "Alice Smith", result["name"].(string)) assert.Equal(t, "http://dex:15818/dex", result["iss"].(string)) } ================================================ FILE: acme/client.go ================================================ package acme import ( "context" "crypto/tls" "net" "net/http" "time" ) // Client is the interface used to verify ACME challenges. type Client interface { // Get issues an HTTP GET to the specified URL. Get(url string) (*http.Response, error) // LookupTXT returns the DNS TXT records for the given domain name. LookupTxt(name string) ([]string, error) // TLSDial connects to the given network address using net.Dialer and then // initiates a TLS handshake, returning the resulting TLS connection. TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) } type clientKey struct{} // NewClientContext adds the given client to the context. func NewClientContext(ctx context.Context, c Client) context.Context { return context.WithValue(ctx, clientKey{}, c) } // ClientFromContext returns the current client from the given context. 
func ClientFromContext(ctx context.Context) (c Client, ok bool) { c, ok = ctx.Value(clientKey{}).(Client) return } // MustClientFromContext returns the current client from the given context. It will // return a new instance of the client if it does not exist. func MustClientFromContext(ctx context.Context) Client { c, ok := ClientFromContext(ctx) if !ok { return NewClient() } return c } type client struct { http *http.Client dialer *net.Dialer } // NewClient returns an implementation of Client for verifying ACME challenges. func NewClient() Client { return &client{ http: &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ //nolint:gosec // used on tls-alpn-01 challenge InsecureSkipVerify: true, // lgtm[go/disabled-certificate-check] }, }, }, dialer: &net.Dialer{ Timeout: 30 * time.Second, }, } } func (c *client) Get(url string) (*http.Response, error) { return c.http.Get(url) } func (c *client) LookupTxt(name string) ([]string, error) { return net.LookupTXT(name) } func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(c.dialer, network, addr, config) } ================================================ FILE: acme/common.go ================================================ package acme import ( "context" "crypto/x509" "time" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) // Clock that returns time in UTC rounded to seconds. type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { return time.Now().UTC().Truncate(time.Second) } var clock Clock // CertificateAuthority is the interface implemented by a CA authority. 
type CertificateAuthority interface { SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error LoadProvisionerByName(string) (provisioner.Interface, error) GetBackdate() *time.Duration } // NewContext adds the given acme components to the context. func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { ctx = NewDatabaseContext(ctx, db) ctx = NewClientContext(ctx, client) ctx = NewLinkerContext(ctx, linker) // Prerequisite checker is optional. if fn != nil { ctx = NewPrerequisitesCheckerContext(ctx, fn) } return ctx } // PrerequisitesChecker is a function that checks if all prerequisites for // serving ACME are met by the CA configuration. type PrerequisitesChecker func(ctx context.Context) (bool, error) // DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns // always true. func DefaultPrerequisitesChecker(context.Context) (bool, error) { return true, nil } type prerequisitesKey struct{} // NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the // context. func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context { return context.WithValue(ctx, prerequisitesKey{}, fn) } // PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the // context. func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) { fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker) return fn, ok && fn != nil } // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. 
type Provisioner interface { AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool GetAttestationRoots() (*x509.CertPool, bool) GetID() string GetName() string DefaultTLSCertDuration() time.Duration GetOptions() *provisioner.Options } type provisionerKey struct{} // NewProvisionerContext adds the given provisioner to the context. func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context { return context.WithValue(ctx, provisionerKey{}, v) } // ProvisionerFromContext returns the current provisioner from the given context. func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) { v, ok = ctx.Value(provisionerKey{}).(Provisioner) return } // MustProvisionerFromContext returns the current provisioner from the given context. // It will panic if it's not in the context. 
func MustProvisionerFromContext(ctx context.Context) Provisioner { var ( v Provisioner ok bool ) if v, ok = ProvisionerFromContext(ctx); !ok { panic("acme provisioner is not the context") } return v } // MockProvisioner for testing type MockProvisioner struct { Mret1 interface{} Merr error MgetID func() string MgetName func() string MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MauthorizeRevoke func(ctx context.Context, token string) error MisChallengeEnabled func(ctx context.Context, challenge provisioner.ACMEChallenge) bool MisAttFormatEnabled func(ctx context.Context, format provisioner.ACMEAttestationFormat) bool MgetAttestationRoots func() (*x509.CertPool, bool) MdefaultTLSCertDuration func() time.Duration MgetOptions func() *provisioner.Options } // GetName mock func (m *MockProvisioner) GetName() string { if m.MgetName != nil { return m.MgetName() } return m.Mret1.(string) } // AuthorizeOrderIdentifier mock func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { if m.MauthorizeOrderIdentifier != nil { return m.MauthorizeOrderIdentifier(ctx, identifier) } return m.Merr } // AuthorizeSign mock func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.MauthorizeSign != nil { return m.MauthorizeSign(ctx, ott) } return m.Mret1.([]provisioner.SignOption), m.Merr } // AuthorizeRevoke mock func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { if m.MauthorizeRevoke != nil { return m.MauthorizeRevoke(ctx, token) } return m.Merr } // IsChallengeEnabled mock func (m *MockProvisioner) IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool { if m.MisChallengeEnabled != nil { return m.MisChallengeEnabled(ctx, challenge) } return m.Merr == nil } // 
IsAttestationFormatEnabled mock func (m *MockProvisioner) IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool { if m.MisAttFormatEnabled != nil { return m.MisAttFormatEnabled(ctx, format) } return m.Merr == nil } func (m *MockProvisioner) GetAttestationRoots() (*x509.CertPool, bool) { if m.MgetAttestationRoots != nil { return m.MgetAttestationRoots() } return m.Mret1.(*x509.CertPool), m.Mret1 != nil } // DefaultTLSCertDuration mock func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration { if m.MdefaultTLSCertDuration != nil { return m.MdefaultTLSCertDuration() } return m.Mret1.(time.Duration) } // GetOptions mock func (m *MockProvisioner) GetOptions() *provisioner.Options { if m.MgetOptions != nil { return m.MgetOptions() } return m.Mret1.(*provisioner.Options) } // GetID mock func (m *MockProvisioner) GetID() string { if m.MgetID != nil { return m.MgetID() } return m.Mret1.(string) } ================================================ FILE: acme/db/nosql/account.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/acme" nosqlDB "github.com/smallstep/nosql" "go.step.sm/crypto/jose" ) // dbAccount represents an ACME account. 
type dbAccount struct { ID string `json:"id"` Key *jose.JSONWebKey `json:"key"` Contact []string `json:"contact,omitempty"` Status acme.Status `json:"status"` LocationPrefix string `json:"locationPrefix"` ProvisionerID string `json:"provisionerID,omitempty"` ProvisionerName string `json:"provisionerName"` CreatedAt time.Time `json:"createdAt"` DeactivatedAt time.Time `json:"deactivatedAt"` } func (dba *dbAccount) clone() *dbAccount { nu := *dba return &nu } func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error) { id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) if err != nil { if nosqlDB.IsErrNotFound(err) { return "", acme.ErrNotFound } return "", errors.Wrapf(err, "error loading key-account index for key %s", kid) } return string(id), nil } // getDBAccount retrieves and unmarshals dbAccount. func (db *DB) getDBAccount(_ context.Context, id string) (*dbAccount, error) { data, err := db.db.Get(accountTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { return nil, acme.ErrNotFound } return nil, errors.Wrapf(err, "error loading account %s", id) } dbacc := new(dbAccount) if err = json.Unmarshal(data, dbacc); err != nil { return nil, errors.Wrapf(err, "error unmarshaling account %s into dbAccount", id) } return dbacc, nil } // GetAccount retrieves an ACME account by ID. func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) { dbacc, err := db.getDBAccount(ctx, id) if err != nil { return nil, err } return &acme.Account{ Status: dbacc.Status, Contact: dbacc.Contact, Key: dbacc.Key, ID: dbacc.ID, LocationPrefix: dbacc.LocationPrefix, ProvisionerID: dbacc.ProvisionerID, ProvisionerName: dbacc.ProvisionerName, }, nil } // GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). 
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) { id, err := db.getAccountIDByKeyID(ctx, kid) if err != nil { return nil, err } return db.GetAccount(ctx, id) } // CreateAccount imlements the AcmeDB.CreateAccount interface. func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error { var err error acc.ID, err = randID() if err != nil { return err } dba := &dbAccount{ ID: acc.ID, Key: acc.Key, Contact: acc.Contact, Status: acc.Status, CreatedAt: clock.Now(), LocationPrefix: acc.LocationPrefix, ProvisionerID: acc.ProvisionerID, ProvisionerName: acc.ProvisionerName, } kid, err := acme.KeyToID(dba.Key) if err != nil { return err } kidB := []byte(kid) // Set the jwkID -> acme account ID index _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID)) switch { case err != nil: return errors.Wrap(err, "error storing keyID to accountID index") case !swapped: return errors.Errorf("key-id to account-id index already exists") default: if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil { db.db.Del(accountByKeyIDTable, kidB) return err } return nil } } // UpdateAccount imlements the AcmeDB.UpdateAccount interface. func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error { old, err := db.getDBAccount(ctx, acc.ID) if err != nil { return err } nu := old.clone() nu.Contact = acc.Contact nu.Status = acc.Status // If the status has changed to 'deactivated', then set deactivatedAt timestamp. 
if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated { nu.DeactivatedAt = clock.Now() } return db.save(ctx, old.ID, nu, old, "account", accountTable) } ================================================ FILE: acme/db/nosql/account_test.go ================================================ package nosql import ( "context" "encoding/json" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" "go.step.sm/crypto/jose" ) func TestDB_getDBAccount(t *testing.T) { accID := "accID" type test struct { db nosql.DB err error acmeErr *acme.Error dbacc *dbAccount } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return nil, nosqldb.ErrNotFound }, }, err: acme.ErrNotFound, } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return nil, errors.New("force") }, }, err: errors.New("error loading account accID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling account accID into dbAccount"), } }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) dbacc := &dbAccount{ ID: accID, Status: acme.StatusDeactivated, CreatedAt: now, DeactivatedAt: now, Contact: []string{"foo", "bar"}, Key: jwk, ProvisionerID: 
"73d2c0f1-9753-448b-9b48-bf00fe434681", ProvisionerName: "acme", } b, err := json.Marshal(dbacc) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return b, nil }, }, dbacc: dbacc, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, dbacc.ID, tc.dbacc.ID) assert.Equals(t, dbacc.Status, tc.dbacc.Status) assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID) } }) } } func TestDB_getAccountIDByKeyID(t *testing.T) { accID := "accID" kid := "kid" type test struct { db nosql.DB err error acmeErr *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountByKeyIDTable) assert.Equals(t, string(key), kid) return nil, nosqldb.ErrNotFound }, }, err: acme.ErrNotFound, } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountByKeyIDTable) assert.Equals(t, string(key), kid) return nil, errors.New("force") 
}, }, err: errors.New("error loading key-account index for key kid: force"), } }, "ok": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountByKeyIDTable) assert.Equals(t, string(key), kid) return []byte(accID), nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, retAccID, accID) } }) } } func TestDB_GetAccount(t *testing.T) { accID := "accID" locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/" provisionerName := "foo" type test struct { db nosql.DB err error acmeErr *acme.Error dbacc *dbAccount } var tests = map[string]func(t *testing.T) test{ "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return nil, errors.New("force") }, }, err: errors.New("error loading account accID: force"), } }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) dbacc := &dbAccount{ ID: accID, Status: acme.StatusDeactivated, CreatedAt: now, DeactivatedAt: now, Contact: []string{"foo", "bar"}, Key: jwk, LocationPrefix: locationPrefix, ProvisionerName: provisionerName, } b, err := json.Marshal(dbacc) 
assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return b, nil }, }, dbacc: dbacc, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if acc, err := d.GetAccount(context.Background(), accID); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.LocationPrefix, tc.dbacc.LocationPrefix) assert.Equals(t, acc.ProvisionerName, tc.dbacc.ProvisionerName) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) } }) } } func TestDB_GetAccountByKeyID(t *testing.T) { accID := "accID" kid := "kid" type test struct { db nosql.DB err error acmeErr *acme.Error dbacc *dbAccount } var tests = map[string]func(t *testing.T) test{ "fail/db.getAccountIDByKeyID-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, string(bucket), string(accountByKeyIDTable)) assert.Equals(t, string(key), kid) return nil, errors.New("force") }, }, err: errors.New("error loading key-account index for key kid: force"), } }, "fail/db.GetAccount-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { switch string(bucket) { case string(accountByKeyIDTable): assert.Equals(t, string(key), kid) return 
[]byte(accID), nil case string(accountTable): assert.Equals(t, string(key), accID) return nil, errors.New("force") default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, errors.New("force") } }, }, err: errors.New("error loading account accID: force"), } }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) dbacc := &dbAccount{ ID: accID, Status: acme.StatusDeactivated, CreatedAt: now, DeactivatedAt: now, Contact: []string{"foo", "bar"}, Key: jwk, } b, err := json.Marshal(dbacc) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { switch string(bucket) { case string(accountByKeyIDTable): assert.Equals(t, string(key), kid) return []byte(accID), nil case string(accountTable): assert.Equals(t, string(key), accID) return b, nil default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, errors.New("force") } }, }, dbacc: dbacc, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) } }) } } func TestDB_CreateAccount(t *testing.T) { locationPrefix := 
"https://test.ca.smallstep.com/acme/foo/account/" type test struct { db nosql.DB acc *acme.Account err error _id *string } var tests = map[string]func(t *testing.T) test{ "fail/keyID-cmpAndSwap-error": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ Status: acme.StatusValid, Contact: []string{"foo", "bar"}, Key: jwk, LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, accountByKeyIDTable) assert.Equals(t, string(key), jwk.KeyID) assert.Equals(t, old, nil) assert.Equals(t, nu, []byte(acc.ID)) return nil, false, errors.New("force") }, }, acc: acc, err: errors.New("error storing keyID to accountID index: force"), } }, "fail/keyID-cmpAndSwap-false": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ Status: acme.StatusValid, Contact: []string{"foo", "bar"}, Key: jwk, LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, accountByKeyIDTable) assert.Equals(t, string(key), jwk.KeyID) assert.Equals(t, old, nil) assert.Equals(t, nu, []byte(acc.ID)) return nil, false, nil }, }, acc: acc, err: errors.New("key-id to account-id index already exists"), } }, "fail/account-save-error": func(t *testing.T) test { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ Status: acme.StatusValid, Contact: []string{"foo", "bar"}, Key: jwk, LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { switch string(bucket) { case string(accountByKeyIDTable): assert.Equals(t, string(key), jwk.KeyID) assert.Equals(t, old, nil) return nu, true, nil case 
string(accountTable): assert.Equals(t, string(key), acc.ID) assert.Equals(t, old, nil) dbacc := new(dbAccount) assert.FatalError(t, json.Unmarshal(nu, dbacc)) assert.Equals(t, dbacc.ID, string(key)) assert.Equals(t, dbacc.Contact, acc.Contact) assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix) assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName) assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) assert.True(t, dbacc.DeactivatedAt.IsZero()) return nil, false, errors.New("force") default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force") } }, }, acc: acc, err: errors.New("error saving acme account: force"), } }, "ok": func(t *testing.T) test { var ( id string idPtr = &id ) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ Status: acme.StatusValid, Contact: []string{"foo", "bar"}, Key: jwk, LocationPrefix: locationPrefix, } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { id = string(key) switch string(bucket) { case string(accountByKeyIDTable): assert.Equals(t, string(key), jwk.KeyID) assert.Equals(t, old, nil) return nu, true, nil case string(accountTable): assert.Equals(t, string(key), acc.ID) assert.Equals(t, old, nil) dbacc := new(dbAccount) assert.FatalError(t, json.Unmarshal(nu, dbacc)) assert.Equals(t, dbacc.ID, string(key)) assert.Equals(t, dbacc.Contact, acc.Contact) assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix) assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName) assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) assert.True(t, dbacc.DeactivatedAt.IsZero()) return nu, 
true, nil default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force") } }, }, acc: acc, _id: idPtr, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.CreateAccount(context.Background(), tc.acc); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.acc.ID, *tc._id) } } }) } } func TestDB_UpdateAccount(t *testing.T) { accID := "accID" now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) dbacc := &dbAccount{ ID: accID, Status: acme.StatusDeactivated, CreatedAt: now, DeactivatedAt: now, Contact: []string{"foo", "bar"}, LocationPrefix: "foo", ProvisionerName: "alpha", Key: jwk, } b, err := json.Marshal(dbacc) assert.FatalError(t, err) type test struct { db nosql.DB acc *acme.Account err error } var tests = map[string]func(t *testing.T) test{ "fail/db.Get-error": func(t *testing.T) test { return test{ acc: &acme.Account{ ID: accID, }, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return nil, errors.New("force") }, }, err: errors.New("error loading account accID: force"), } }, "fail/already-deactivated": func(t *testing.T) test { clone := dbacc.clone() clone.Status = acme.StatusDeactivated clone.DeactivatedAt = now dbaccb, err := json.Marshal(clone) assert.FatalError(t, err) acc := &acme.Account{ ID: accID, Status: acme.StatusDeactivated, Contact: []string{"foo", "bar"}, } return test{ acc: acc, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return dbaccb, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, old, b) 
dbNew := new(dbAccount) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, clone.ID) assert.Equals(t, dbNew.Status, clone.Status) assert.Equals(t, dbNew.Contact, clone.Contact) assert.Equals(t, dbNew.Key.KeyID, clone.Key.KeyID) assert.Equals(t, dbNew.CreatedAt, clone.CreatedAt) assert.Equals(t, dbNew.DeactivatedAt, clone.DeactivatedAt) return nil, false, errors.New("force") }, }, err: errors.New("error saving acme account: force"), } }, "fail/db.CmpAndSwap-error": func(t *testing.T) test { acc := &acme.Account{ ID: accID, Status: acme.StatusDeactivated, Contact: []string{"foo", "bar"}, } return test{ acc: acc, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, old, b) dbNew := new(dbAccount) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbacc.ID) assert.Equals(t, dbNew.Status, acc.Status) assert.Equals(t, dbNew.Contact, dbacc.Contact) assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) return nil, false, errors.New("force") }, }, err: errors.New("error saving acme account: force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ ID: accID, Status: acme.StatusDeactivated, Contact: []string{"baz", "zap"}, LocationPrefix: "bar", ProvisionerName: "beta", Key: jwk, } return test{ acc: acc, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, string(key), accID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, accountTable) assert.Equals(t, old, b) dbNew := 
new(dbAccount) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbacc.ID) assert.Equals(t, dbNew.Status, acc.Status) assert.Equals(t, dbNew.Contact, acc.Contact) // LocationPrefix should not change. assert.Equals(t, dbNew.LocationPrefix, dbacc.LocationPrefix) assert.Equals(t, dbNew.ProvisionerName, dbacc.ProvisionerName) assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) assert.True(t, dbNew.DeactivatedAt.Add(time.Minute).After(now)) return nu, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.UpdateAccount(context.Background(), tc.acc); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) } }) } } ================================================ FILE: acme/db/nosql/authz.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" ) // dbAuthz is the base authz type that others build from. type dbAuthz struct { ID string `json:"id"` AccountID string `json:"accountID"` Identifier acme.Identifier `json:"identifier"` Status acme.Status `json:"status"` Token string `json:"token"` Fingerprint string `json:"fingerprint,omitempty"` ChallengeIDs []string `json:"challengeIDs"` Wildcard bool `json:"wildcard"` CreatedAt time.Time `json:"createdAt"` ExpiresAt time.Time `json:"expiresAt"` Error *acme.Error `json:"error"` } func (ba *dbAuthz) clone() *dbAuthz { u := *ba return &u } // getDBAuthz retrieves and unmarshals a database representation of the // ACME Authorization type. 
func (db *DB) getDBAuthz(_ context.Context, id string) (*dbAuthz, error) { data, err := db.db.Get(authzTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading authz %s", id) } var dbaz dbAuthz if err = json.Unmarshal(data, &dbaz); err != nil { return nil, errors.Wrapf(err, "error unmarshaling authz %s into dbAuthz", id) } return &dbaz, nil } // GetAuthorization retrieves and unmarshals an ACME authz type from the database. // Implements acme.DB GetAuthorization interface. func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) { dbaz, err := db.getDBAuthz(ctx, id) if err != nil { return nil, err } var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs)) for i, chID := range dbaz.ChallengeIDs { chs[i], err = db.GetChallenge(ctx, chID, id) if err != nil { return nil, err } } return &acme.Authorization{ ID: dbaz.ID, AccountID: dbaz.AccountID, Identifier: dbaz.Identifier, Status: dbaz.Status, Challenges: chs, Wildcard: dbaz.Wildcard, ExpiresAt: dbaz.ExpiresAt, Token: dbaz.Token, Fingerprint: dbaz.Fingerprint, Error: dbaz.Error, }, nil } // CreateAuthorization creates an entry in the database for the Authorization. // Implements the acme.DB.CreateAuthorization interface. func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error { var err error az.ID, err = randID() if err != nil { return err } chIDs := make([]string, len(az.Challenges)) for i, ch := range az.Challenges { chIDs[i] = ch.ID } now := clock.Now() dbaz := &dbAuthz{ ID: az.ID, AccountID: az.AccountID, Status: az.Status, CreatedAt: now, ExpiresAt: az.ExpiresAt, Identifier: az.Identifier, ChallengeIDs: chIDs, Token: az.Token, Fingerprint: az.Fingerprint, Wildcard: az.Wildcard, } return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable) } // UpdateAuthorization saves an updated ACME Authorization to the database. 
func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error { old, err := db.getDBAuthz(ctx, az.ID) if err != nil { return err } nu := old.clone() nu.Status = az.Status nu.Fingerprint = az.Fingerprint nu.Error = az.Error return db.save(ctx, old.ID, nu, old, "authz", authzTable) } // GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database. func (db *DB) GetAuthorizationsByAccountID(_ context.Context, accountID string) ([]*acme.Authorization, error) { entries, err := db.db.List(authzTable) if err != nil { return nil, errors.Wrapf(err, "error listing authz") } authzs := []*acme.Authorization{} for _, entry := range entries { dbaz := new(dbAuthz) if err = json.Unmarshal(entry.Value, dbaz); err != nil { return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key)) } // Filter out all dbAuthzs that don't belong to the accountID. This // could be made more efficient with additional data structures mapping the // Account ID to authorizations. Not trivial to do, though. 
if dbaz.AccountID != accountID { continue } authzs = append(authzs, &acme.Authorization{ ID: dbaz.ID, AccountID: dbaz.AccountID, Identifier: dbaz.Identifier, Status: dbaz.Status, Challenges: nil, // challenges not required for current use case Wildcard: dbaz.Wildcard, ExpiresAt: dbaz.ExpiresAt, Token: dbaz.Token, Fingerprint: dbaz.Fingerprint, Error: dbaz.Error, }) } return authzs, nil } ================================================ FILE: acme/db/nosql/authz_test.go ================================================ package nosql import ( "context" "encoding/json" "fmt" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" ) func TestDB_getDBAuthz(t *testing.T) { azID := "azID" type test struct { db nosql.DB err error acmeErr *acme.Error dbaz *dbAuthz } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return nil, nosqldb.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return nil, errors.New("force") }, }, err: errors.New("error loading authz azID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling authz azID into dbAuthz"), } }, "ok": func(t *testing.T) test { now := clock.Now() dbaz := 
&dbAuthz{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), Error: acme.NewErrorISE("The server experienced an internal error"), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return b, nil }, }, dbaz: dbaz, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, dbaz.ID, tc.dbaz.ID) assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) assert.Equals(t, dbaz.Status, tc.dbaz.Status) assert.Equals(t, dbaz.Token, tc.dbaz.Token) assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard) } }) } } func TestDB_GetAuthorization(t *testing.T) { azID := "azID" type test struct { db nosql.DB err error acmeErr *acme.Error dbaz *dbAuthz } var tests = map[string]func(t *testing.T) test{ "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) 
([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return nil, errors.New("force") }, }, err: errors.New("error loading authz azID: force"), } }, "fail/forward-acme-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return nil, nosqldb.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "authz azID not found"), } }, "fail/db.GetChallenge-error": func(t *testing.T) test { now := clock.Now() dbaz := &dbAuthz{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), Error: acme.NewErrorISE("force"), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { switch string(bucket) { case string(authzTable): assert.Equals(t, string(key), azID) return b, nil case string(challengeTable): assert.Equals(t, string(key), "foo") return nil, errors.New("force") default: assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) return nil, errors.New("force") } }, }, err: errors.New("error loading acme challenge foo: force"), } }, "fail/db.GetChallenge-not-found": func(t *testing.T) test { now := clock.Now() dbaz := &dbAuthz{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), Error: acme.NewErrorISE("force"), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { switch string(bucket) { case 
string(authzTable): assert.Equals(t, string(key), azID) return b, nil case string(challengeTable): assert.Equals(t, string(key), "foo") return nil, nosqldb.ErrNotFound default: assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) return nil, errors.New("force") } }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge foo not found"), } }, "ok": func(t *testing.T) test { now := clock.Now() dbaz := &dbAuthz{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), Error: acme.NewErrorISE("The server experienced an internal error"), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) chCount := 0 fooChb, err := json.Marshal(&dbChallenge{ID: "foo"}) assert.FatalError(t, err) barChb, err := json.Marshal(&dbChallenge{ID: "bar"}) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { switch string(bucket) { case string(authzTable): assert.Equals(t, string(key), azID) return b, nil case string(challengeTable): if chCount == 0 { chCount++ assert.Equals(t, string(key), "foo") return fooChb, nil } assert.Equals(t, string(key), "bar") return barChb, nil default: assert.FatalError(t, errors.Errorf("unexpected bucket '%s'", string(bucket))) return nil, errors.New("force") } }, }, dbaz: dbaz, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if az, err := d.GetAuthorization(context.Background(), azID); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, 
acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, az.ID, tc.dbaz.ID) assert.Equals(t, az.AccountID, tc.dbaz.AccountID) assert.Equals(t, az.Identifier, tc.dbaz.Identifier) assert.Equals(t, az.Status, tc.dbaz.Status) assert.Equals(t, az.Token, tc.dbaz.Token) assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) assert.Equals(t, az.Challenges, []*acme.Challenge{ {ID: "foo"}, {ID: "bar"}, }) assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error()) } }) } } func TestDB_CreateAuthorization(t *testing.T) { azID := "azID" type test struct { db nosql.DB az *acme.Authorization err error _id *string } var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { now := clock.Now() az := &acme.Authorization{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", ExpiresAt: now.Add(5 * time.Minute), Challenges: []*acme.Challenge{ {ID: "foo"}, {ID: "bar"}, }, Wildcard: true, Error: acme.NewErrorISE("force"), } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), az.ID) assert.Equals(t, old, nil) dbaz := new(dbAuthz) assert.FatalError(t, json.Unmarshal(nu, dbaz)) assert.Equals(t, dbaz.ID, string(key)) assert.Equals(t, dbaz.AccountID, az.AccountID) assert.Equals(t, dbaz.Identifier, acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }) assert.Equals(t, dbaz.Status, az.Status) assert.Equals(t, dbaz.Token, az.Token) assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) assert.Equals(t, dbaz.Wildcard, az.Wildcard) assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) assert.Nil(t, dbaz.Error) assert.True(t, 
clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt)) return nil, false, errors.New("force") }, }, az: az, err: errors.New("error saving acme authz: force"), } }, "ok": func(t *testing.T) test { var ( id string idPtr = &id now = clock.Now() az = &acme.Authorization{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", ExpiresAt: now.Add(5 * time.Minute), Challenges: []*acme.Challenge{ {ID: "foo"}, {ID: "bar"}, }, Wildcard: true, Error: acme.NewErrorISE("force"), } ) return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { *idPtr = string(key) assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), az.ID) assert.Equals(t, old, nil) dbaz := new(dbAuthz) assert.FatalError(t, json.Unmarshal(nu, dbaz)) assert.Equals(t, dbaz.ID, string(key)) assert.Equals(t, dbaz.AccountID, az.AccountID) assert.Equals(t, dbaz.Identifier, acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }) assert.Equals(t, dbaz.Status, az.Status) assert.Equals(t, dbaz.Token, az.Token) assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) assert.Equals(t, dbaz.Wildcard, az.Wildcard) assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) assert.Nil(t, dbaz.Error) assert.True(t, clock.Now().Add(-time.Minute).Before(dbaz.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbaz.CreatedAt)) return nu, true, nil }, }, az: az, _id: idPtr, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.CreateAuthorization(context.Background(), tc.az); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.az.ID, *tc._id) } } }) } } func TestDB_UpdateAuthorization(t *testing.T) { azID := "azID" now := clock.Now() dbaz := 
&dbAuthz{ ID: azID, AccountID: "accountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusPending, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, Fingerprint: "fingerprint", } b, err := json.Marshal(dbaz) assert.FatalError(t, err) type test struct { db nosql.DB az *acme.Authorization err error } var tests = map[string]func(t *testing.T) test{ "fail/db.Get-error": func(t *testing.T) test { return test{ az: &acme.Authorization{ ID: azID, }, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return nil, errors.New("force") }, }, err: errors.New("error loading authz azID: force"), } }, "fail/db.CmpAndSwap-error": func(t *testing.T) test { updAz := &acme.Authorization{ ID: azID, Status: acme.StatusValid, Error: acme.NewError(acme.ErrorMalformedType, "malformed"), } return test{ az: updAz, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, old, b) dbOld := new(dbAuthz) assert.FatalError(t, json.Unmarshal(old, dbOld)) assert.Equals(t, dbaz, dbOld) dbNew := new(dbAuthz) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbaz.ID) assert.Equals(t, dbNew.AccountID, dbaz.AccountID) assert.Equals(t, dbNew.Identifier, dbaz.Identifier) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.Token, dbaz.Token) assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request 
message was malformed").Error()) return nil, false, errors.New("force") }, }, err: errors.New("error saving acme authz: force"), } }, "ok": func(t *testing.T) test { updAz := &acme.Authorization{ ID: azID, AccountID: dbaz.AccountID, Status: acme.StatusValid, Identifier: dbaz.Identifier, Challenges: []*acme.Challenge{ {ID: "foo"}, {ID: "bar"}, }, Token: dbaz.Token, Wildcard: dbaz.Wildcard, ExpiresAt: dbaz.ExpiresAt, Fingerprint: "fingerprint", Error: acme.NewError(acme.ErrorMalformedType, "malformed"), } return test{ az: updAz, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, string(key), azID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authzTable) assert.Equals(t, old, b) dbOld := new(dbAuthz) assert.FatalError(t, json.Unmarshal(old, dbOld)) assert.Equals(t, dbaz, dbOld) dbNew := new(dbAuthz) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbaz.ID) assert.Equals(t, dbNew.AccountID, dbaz.AccountID) assert.Equals(t, dbNew.Identifier, dbaz.Identifier) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.Token, dbaz.Token) assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) assert.Equals(t, dbNew.Fingerprint, dbaz.Fingerprint) assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error()) return nu, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.az.ID, dbaz.ID) assert.Equals(t, tc.az.AccountID, 
dbaz.AccountID) assert.Equals(t, tc.az.Identifier, dbaz.Identifier) assert.Equals(t, tc.az.Status, acme.StatusValid) assert.Equals(t, tc.az.Wildcard, dbaz.Wildcard) assert.Equals(t, tc.az.Token, dbaz.Token) assert.Equals(t, tc.az.ExpiresAt, dbaz.ExpiresAt) assert.Equals(t, tc.az.Challenges, []*acme.Challenge{ {ID: "foo"}, {ID: "bar"}, }) assert.Equals(t, tc.az.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) } } }) } } func TestDB_GetAuthorizationsByAccountID(t *testing.T) { azID := "azID" accountID := "accountID" type test struct { db nosql.DB err error acmeErr *acme.Error authzs []*acme.Authorization } var tests = map[string]func(t *testing.T) test{ "fail/db.List-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, authzTable) return nil, errors.New("force") }, }, err: errors.New("error listing authz: force"), } }, "fail/unmarshal": func(t *testing.T) test { b := []byte(`{malformed}`) return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, authzTable) return []*nosqldb.Entry{ { Bucket: bucket, Key: []byte(azID), Value: b, }, }, nil }, }, authzs: nil, err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID), } }, "ok": func(t *testing.T) test { now := clock.Now() dbaz := &dbAuthz{ ID: azID, AccountID: accountID, Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusValid, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, authzTable) return []*nosqldb.Entry{ { Bucket: bucket, Key: []byte(azID), Value: b, }, }, nil }, }, authzs: []*acme.Authorization{ { ID: dbaz.ID, AccountID: 
dbaz.AccountID, Token: dbaz.Token, Identifier: dbaz.Identifier, Status: dbaz.Status, Challenges: nil, Wildcard: dbaz.Wildcard, ExpiresAt: dbaz.ExpiresAt, Error: dbaz.Error, }, }, } }, "ok/skip-different-account": func(t *testing.T) test { now := clock.Now() dbaz := &dbAuthz{ ID: azID, AccountID: "differentAccountID", Identifier: acme.Identifier{ Type: "dns", Value: "test.ca.smallstep.com", }, Status: acme.StatusValid, Token: "token", CreatedAt: now, ExpiresAt: now.Add(5 * time.Minute), ChallengeIDs: []string{"foo", "bar"}, Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, authzTable) return []*nosqldb.Entry{ { Bucket: bucket, Key: []byte(azID), Value: b, }, }, nil }, }, authzs: []*acme.Authorization{}, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { if !cmp.Equal(azs, tc.authzs) { t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs)) } } }) } } ================================================ FILE: acme/db/nosql/certificate.go ================================================ package nosql import ( "context" "crypto/x509" "encoding/json" "encoding/pem" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" ) type dbCert struct { ID string 
`json:"id"` CreatedAt time.Time `json:"createdAt"` AccountID string `json:"accountID"` OrderID string `json:"orderID"` Leaf []byte `json:"leaf"` Intermediates []byte `json:"intermediates"` } type dbSerial struct { Serial string `json:"serial"` CertificateID string `json:"certificateID"` } // CreateCertificate creates and stores an ACME certificate type. func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error { var err error cert.ID, err = randID() if err != nil { return err } leaf := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Leaf.Raw, }) var intermediates []byte for _, cert := range cert.Intermediates { intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, })...) } dbch := &dbCert{ ID: cert.ID, AccountID: cert.AccountID, OrderID: cert.OrderID, Leaf: leaf, Intermediates: intermediates, CreatedAt: time.Now().UTC(), } err = db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) if err != nil { return err } serial := cert.Leaf.SerialNumber.String() dbSerial := &dbSerial{ Serial: serial, CertificateID: cert.ID, } return db.save(ctx, serial, dbSerial, nil, "serial", certBySerialTable) } // GetCertificate retrieves and unmarshals an ACME certificate type from the // datastore. 
func (db *DB) GetCertificate(_ context.Context, id string) (*acme.Certificate, error) { b, err := db.db.Get(certTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading certificate %s", id) } dbC := new(dbCert) if err := json.Unmarshal(b, dbC); err != nil { return nil, errors.Wrapf(err, "error unmarshaling certificate %s", id) } certs, err := parseBundle(append(dbC.Leaf, dbC.Intermediates...)) if err != nil { return nil, errors.Wrapf(err, "error parsing certificate chain for ACME certificate with ID %s", id) } return &acme.Certificate{ ID: dbC.ID, AccountID: dbC.AccountID, OrderID: dbC.OrderID, Leaf: certs[0], Intermediates: certs[1:], }, nil } // GetCertificateBySerial retrieves and unmarshals an ACME certificate type from the // datastore based on a certificate serial number. func (db *DB) GetCertificateBySerial(ctx context.Context, serial string) (*acme.Certificate, error) { b, err := db.db.Get(certBySerialTable, []byte(serial)) if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial) } else if err != nil { return nil, errors.Wrapf(err, "error loading certificate ID for serial %s", serial) } dbSerial := new(dbSerial) if err := json.Unmarshal(b, dbSerial); err != nil { return nil, errors.Wrapf(err, "error unmarshaling certificate with serial %s", serial) } return db.GetCertificate(ctx, dbSerial.CertificateID) } func parseBundle(b []byte) ([]*x509.Certificate, error) { var ( err error block *pem.Block bundle []*x509.Certificate ) for len(b) > 0 { block, b = pem.Decode(b) if block == nil { break } if block.Type != "CERTIFICATE" { return nil, errors.New("error decoding PEM: data contains block that is not a certificate") } var crt *x509.Certificate crt, err = x509.ParseCertificate(block.Bytes) if err != nil { return nil, errors.Wrapf(err, "error parsing 
x509 certificate") } bundle = append(bundle, crt) } if len(b) > 0 { return nil, errors.New("error decoding PEM: unexpected data") } return bundle, nil } ================================================ FILE: acme/db/nosql/certificate_test.go ================================================ package nosql import ( "bytes" "context" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" "go.step.sm/crypto/pemutil" ) func TestDB_CreateCertificate(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") assert.FatalError(t, err) type test struct { db nosql.DB cert *acme.Certificate err error _id *string } var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { cert := &acme.Certificate{ AccountID: "accountID", OrderID: "orderID", Leaf: leaf, Intermediates: []*x509.Certificate{inter, root}, } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, certTable) assert.Equals(t, key, []byte(cert.ID)) assert.Equals(t, old, nil) dbc := new(dbCert) assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.ID, cert.ID) assert.Equals(t, dbc.AccountID, cert.AccountID) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, false, errors.New("force") }, }, cert: cert, err: errors.New("error saving acme certificate: force"), } }, "ok": func(t 
*testing.T) test { cert := &acme.Certificate{ AccountID: "accountID", OrderID: "orderID", Leaf: leaf, Intermediates: []*x509.Certificate{inter, root}, } var ( id string idPtr = &id ) return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) { t.Fail() } if bytes.Equal(bucket, certTable) { *idPtr = string(key) assert.Equals(t, bucket, certTable) assert.Equals(t, key, []byte(cert.ID)) assert.Equals(t, old, nil) dbc := new(dbCert) assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.ID, cert.ID) assert.Equals(t, dbc.AccountID, cert.AccountID) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) } if bytes.Equal(bucket, certBySerialTable) { assert.Equals(t, bucket, certBySerialTable) assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String())) assert.Equals(t, old, nil) dbs := new(dbSerial) assert.FatalError(t, json.Unmarshal(nu, dbs)) assert.Equals(t, dbs.Serial, string(key)) assert.Equals(t, dbs.CertificateID, cert.ID) *idPtr = cert.ID } return nil, true, nil }, }, _id: idPtr, cert: cert, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.CreateCertificate(context.Background(), tc.cert); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.cert.ID, *tc._id) } } }) } } func TestDB_GetCertificate(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") assert.FatalError(t, err) certID := 
"certID" type test struct { db nosql.DB err error acmeErr *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, certTable) assert.Equals(t, string(key), certID) return nil, nosqldb.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, certTable) assert.Equals(t, string(key), certID) return nil, errors.Errorf("force") }, }, err: errors.New("error loading certificate certID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, certTable) assert.Equals(t, string(key), certID) return []byte("foobar"), nil }, }, err: errors.New("error unmarshaling certificate certID"), } }, "fail/parseBundle-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, certTable) assert.Equals(t, string(key), certID) cert := dbCert{ ID: certID, AccountID: "accountID", OrderID: "orderID", Leaf: pem.EncodeToMemory(&pem.Block{ Type: "Public Key", Bytes: leaf.Raw, }), CreatedAt: clock.Now(), } b, err := json.Marshal(cert) assert.FatalError(t, err) return b, nil }, }, err: errors.Errorf("error parsing certificate chain for ACME certificate with ID certID"), } }, "ok": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, certTable) assert.Equals(t, string(key), certID) cert := dbCert{ ID: certID, AccountID: "accountID", OrderID: "orderID", Leaf: pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: leaf.Raw, }), Intermediates: 
append(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: inter.Raw, }), pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: root.Raw, })...), CreatedAt: clock.Now(), } b, err := json.Marshal(cert) assert.FatalError(t, err) return b, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} cert, err := d.GetCertificate(context.Background(), certID) if err != nil { var acmeErr *acme.Error if errors.As(err, &acmeErr) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, acmeErr.Type, tc.acmeErr.Type) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) assert.Equals(t, acmeErr.Status, tc.acmeErr.Status) assert.Equals(t, acmeErr.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, acmeErr.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, cert.ID, certID) assert.Equals(t, cert.AccountID, "accountID") assert.Equals(t, cert.OrderID, "orderID") assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) } }) } } func Test_parseBundle(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") assert.FatalError(t, err) var certs []byte for _, cert := range []*x509.Certificate{leaf, inter, root} { certs = append(certs, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, })...) 
} type test struct { b []byte err error } var tests = map[string]test{ "fail/bad-type-error": { b: pem.EncodeToMemory(&pem.Block{ Type: "Public Key", Bytes: leaf.Raw, }), err: errors.Errorf("error decoding PEM: data contains block that is not a certificate"), }, "fail/bad-pem-error": { b: pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: []byte("foo"), }), err: errors.Errorf("error parsing x509 certificate"), }, "fail/unexpected-data": { b: append(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: leaf.Raw, }), []byte("foo")...), err: errors.Errorf("error decoding PEM: unexpected data"), }, "ok": { b: certs, }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { ret, err := parseBundle(tc.b) if err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, ret, []*x509.Certificate{leaf, inter, root}) } } }) } } func TestDB_GetCertificateBySerial(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") assert.FatalError(t, err) certID := "certID" serial := "" type test struct { db nosql.DB err error acmeErr *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { if bytes.Equal(bucket, certBySerialTable) { return nil, nosqldb.ErrNotFound } return nil, errors.New("wrong table") }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial), } }, "fail/db-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { if bytes.Equal(bucket, certBySerialTable) { return 
nil, errors.New("force") } return nil, errors.New("wrong table") }, }, err: fmt.Errorf("error loading certificate ID for serial %s", serial), } }, "fail/unmarshal-dbSerial": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { if bytes.Equal(bucket, certBySerialTable) { return []byte(`{"serial":malformed!}`), nil } return nil, errors.New("wrong table") }, }, err: fmt.Errorf("error unmarshaling certificate with serial %s", serial), } }, "ok": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { if bytes.Equal(bucket, certBySerialTable) { certSerial := dbSerial{ Serial: serial, CertificateID: certID, } b, err := json.Marshal(certSerial) assert.FatalError(t, err) return b, nil } if bytes.Equal(bucket, certTable) { cert := dbCert{ ID: certID, AccountID: "accountID", OrderID: "orderID", Leaf: pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: leaf.Raw, }), Intermediates: append(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: inter.Raw, }), pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: root.Raw, })...), CreatedAt: clock.Now(), } b, err := json.Marshal(cert) assert.FatalError(t, err) return b, nil } return nil, errors.New("wrong table") }, }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} cert, err := d.GetCertificateBySerial(context.Background(), serial) if err != nil { var ae *acme.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, ae.Type, tc.acmeErr.Type) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) assert.Equals(t, ae.Status, tc.acmeErr.Status) assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, cert.ID, certID) assert.Equals(t, 
cert.AccountID, "accountID") assert.Equals(t, cert.OrderID, "orderID") assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) } }) } } ================================================ FILE: acme/db/nosql/challenge.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" "github.com/smallstep/nosql" "github.com/smallstep/certificates/acme" ) type dbChallenge struct { ID string `json:"id"` AccountID string `json:"accountID"` Type acme.ChallengeType `json:"type"` Status acme.Status `json:"status"` Token string `json:"token"` Value string `json:"value"` Target string `json:"target,omitempty"` ValidatedAt string `json:"validatedAt"` CreatedAt time.Time `json:"createdAt"` Error *acme.Error `json:"error"` // TODO(hs): a bit dangerous; should become db-specific type } func (dbc *dbChallenge) clone() *dbChallenge { u := *dbc return &u } func (db *DB) getDBChallenge(_ context.Context, id string) (*dbChallenge, error) { data, err := db.db.Get(challengeTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading acme challenge %s", id) } dbch := new(dbChallenge) if err := json.Unmarshal(data, dbch); err != nil { return nil, errors.Wrap(err, "error unmarshaling dbChallenge") } return dbch, nil } // CreateChallenge creates a new ACME challenge data structure in the database. // Implements acme.DB.CreateChallenge interface. 
func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error { var err error ch.ID, err = randID() if err != nil { return errors.Wrap(err, "error generating random id for ACME challenge") } dbch := &dbChallenge{ ID: ch.ID, AccountID: ch.AccountID, Value: ch.Value, Status: acme.StatusPending, Token: ch.Token, CreatedAt: clock.Now(), Type: ch.Type, Target: ch.Target, } return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable) } // GetChallenge retrieves and unmarshals an ACME challenge type from the database. // Implements the acme.DB GetChallenge interface. func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) { _ = authzID // unused input dbch, err := db.getDBChallenge(ctx, id) if err != nil { return nil, err } ch := &acme.Challenge{ ID: dbch.ID, AccountID: dbch.AccountID, Type: dbch.Type, Value: dbch.Value, Status: dbch.Status, Token: dbch.Token, Error: dbch.Error, ValidatedAt: dbch.ValidatedAt, Target: dbch.Target, } return ch, nil } // UpdateChallenge updates an ACME challenge type in the database. func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error { old, err := db.getDBChallenge(ctx, ch.ID) if err != nil { return err } nu := old.clone() // These should be the only values changing in an Update request. 
nu.Status = ch.Status nu.Error = ch.Error nu.ValidatedAt = ch.ValidatedAt return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) } ================================================ FILE: acme/db/nosql/challenge_test.go ================================================ package nosql import ( "context" "encoding/json" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" ) func TestDB_getDBChallenge(t *testing.T) { chID := "chID" type test struct { db nosql.DB err error acmeErr *acme.Error dbc *dbChallenge } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return nil, nosqldb.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return nil, errors.New("force") }, }, err: errors.New("error loading acme challenge chID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling dbChallenge"), } }, "ok": func(t *testing.T) test { dbc := &dbChallenge{ ID: chID, AccountID: "accountID", Type: "dns-01", Status: acme.StatusPending, Token: "token", Value: "test.ca.smallstep.com", CreatedAt: clock.Now(), ValidatedAt: "foobar", Error: acme.NewErrorISE("The server experienced an internal error"), } b, err := json.Marshal(dbc) 
assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return b, nil }, }, dbc: dbc, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if ch, err := d.getDBChallenge(context.Background(), chID); err != nil { var ae *acme.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, ae.Type, tc.acmeErr.Type) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) assert.Equals(t, ae.Status, tc.acmeErr.Status) assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID) assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token) assert.Equals(t, ch.Value, tc.dbc.Value) assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) } }) } } func TestDB_CreateChallenge(t *testing.T) { type test struct { db nosql.DB ch *acme.Challenge err error _id *string } var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { ch := &acme.Challenge{ AccountID: "accountID", Type: "dns-01", Status: acme.StatusPending, Token: "token", Value: "test.ca.smallstep.com", } return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), ch.ID) assert.Equals(t, old, nil) dbc := new(dbChallenge) assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.AccountID, ch.AccountID) assert.Equals(t, dbc.Type, ch.Type) assert.Equals(t, dbc.Status, 
ch.Status) assert.Equals(t, dbc.Token, ch.Token) assert.Equals(t, dbc.Value, ch.Value) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, false, errors.New("force") }, }, ch: ch, err: errors.New("error saving acme challenge: force"), } }, "ok": func(t *testing.T) test { var ( id string idPtr = &id ch = &acme.Challenge{ AccountID: "accountID", Type: "dns-01", Status: acme.StatusPending, Token: "token", Value: "test.ca.smallstep.com", } ) return test{ ch: ch, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { *idPtr = string(key) assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), ch.ID) assert.Equals(t, old, nil) dbc := new(dbChallenge) assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.AccountID, ch.AccountID) assert.Equals(t, dbc.Type, ch.Type) assert.Equals(t, dbc.Status, ch.Status) assert.Equals(t, dbc.Token, ch.Token) assert.Equals(t, dbc.Value, ch.Value) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, true, nil }, }, _id: idPtr, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.CreateChallenge(context.Background(), tc.ch); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.ch.ID, *tc._id) } } }) } } func TestDB_GetChallenge(t *testing.T) { chID := "chID" azID := "azID" type test struct { db nosql.DB err error acmeErr *acme.Error dbc *dbChallenge } var tests = map[string]func(t *testing.T) test{ "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) 
return nil, errors.New("force") }, }, err: errors.New("error loading acme challenge chID: force"), } }, "fail/forward-acme-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return nil, nosqldb.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"), } }, "ok": func(t *testing.T) test { dbc := &dbChallenge{ ID: chID, AccountID: "accountID", Type: "dns-01", Status: acme.StatusPending, Token: "token", Value: "test.ca.smallstep.com", CreatedAt: clock.Now(), ValidatedAt: "foobar", Error: acme.NewErrorISE("The server experienced an internal error"), } b, err := json.Marshal(dbc) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return b, nil }, }, dbc: dbc, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil { var ae *acme.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, ae.Type, tc.acmeErr.Type) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) assert.Equals(t, ae.Status, tc.acmeErr.Status) assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID) assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token) assert.Equals(t, ch.Value, tc.dbc.Value) assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error()) } }) } } func 
TestDB_UpdateChallenge(t *testing.T) { chID := "chID" dbc := &dbChallenge{ ID: chID, AccountID: "accountID", Type: "dns-01", Status: acme.StatusPending, Token: "token", Value: "test.ca.smallstep.com", CreatedAt: clock.Now(), } b, err := json.Marshal(dbc) assert.FatalError(t, err) type test struct { db nosql.DB ch *acme.Challenge err error } var tests = map[string]func(t *testing.T) test{ "fail/db.Get-error": func(t *testing.T) test { return test{ ch: &acme.Challenge{ ID: chID, }, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return nil, errors.New("force") }, }, err: errors.New("error loading acme challenge chID: force"), } }, "fail/db.CmpAndSwap-error": func(t *testing.T) test { updCh := &acme.Challenge{ ID: chID, Status: acme.StatusValid, ValidatedAt: "foobar", Error: acme.NewError(acme.ErrorMalformedType, "The request message was malformed"), } return test{ ch: updCh, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, old, b) dbOld := new(dbChallenge) assert.FatalError(t, json.Unmarshal(old, dbOld)) assert.Equals(t, dbc, dbOld) dbNew := new(dbChallenge) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbc.ID) assert.Equals(t, dbNew.AccountID, dbc.AccountID) assert.Equals(t, dbNew.Type, dbc.Type) assert.Equals(t, dbNew.Status, updCh.Status) assert.Equals(t, dbNew.Token, dbc.Token) assert.Equals(t, dbNew.Value, dbc.Value) assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error()) assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt) assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt) return nil, false, errors.New("force") }, }, err: errors.New("error saving acme challenge: force"), } }, "ok": func(t 
*testing.T) test { updCh := &acme.Challenge{ ID: dbc.ID, AccountID: dbc.AccountID, Type: dbc.Type, Token: dbc.Token, Value: dbc.Value, Status: acme.StatusValid, ValidatedAt: "foobar", Error: acme.NewError(acme.ErrorMalformedType, "malformed"), } return test{ ch: updCh, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, string(key), chID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, challengeTable) assert.Equals(t, old, b) dbOld := new(dbChallenge) assert.FatalError(t, json.Unmarshal(old, dbOld)) assert.Equals(t, dbc, dbOld) dbNew := new(dbChallenge) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbc.ID) assert.Equals(t, dbNew.AccountID, dbc.AccountID) assert.Equals(t, dbNew.Type, dbc.Type) assert.Equals(t, dbNew.Token, dbc.Token) assert.Equals(t, dbNew.Value, dbc.Value) assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.ValidatedAt, "foobar") assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error()) return nu, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.ch.ID, dbc.ID) assert.Equals(t, tc.ch.AccountID, dbc.AccountID) assert.Equals(t, tc.ch.Type, dbc.Type) assert.Equals(t, tc.ch.Token, dbc.Token) assert.Equals(t, tc.ch.Value, dbc.Value) assert.Equals(t, tc.ch.ValidatedAt, "foobar") assert.Equals(t, tc.ch.Status, acme.StatusValid) assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error()) } } }) } } ================================================ 
var (
	// externalAccountKeyMutex for read/write locking of EAK operations.
	externalAccountKeyMutex sync.RWMutex

	// referencesByProvisionerIndexMutex for locking referencesByProvisioner index operations.
	referencesByProvisionerIndexMutex sync.Mutex
)

// dbExternalAccountKey is the JSON representation of an External Account
// Binding (EAB) key as stored in externalAccountKeyTable.
type dbExternalAccountKey struct {
	ID            string    `json:"id"`
	ProvisionerID string    `json:"provisionerID"`
	Reference     string    `json:"reference"`
	AccountID     string    `json:"accountID,omitempty"`
	HmacKey       []byte    `json:"key"`
	CreatedAt     time.Time `json:"createdAt"`
	BoundAt       time.Time `json:"boundAt"`
}

// dbExternalAccountKeyReference is the index record that maps a reference to
// the ID of the external account key it belongs to.
type dbExternalAccountKeyReference struct {
	Reference            string `json:"reference"`
	ExternalAccountKeyID string `json:"externalAccountKeyID"`
}
func (db *DB) getDBExternalAccountKey(_ context.Context, id string) (*dbExternalAccountKey, error) { data, err := db.db.Get(externalAccountKeyTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { return nil, acme.ErrNotFound } return nil, errors.Wrapf(err, "error loading external account key %s", id) } dbeak := new(dbExternalAccountKey) if err = json.Unmarshal(data, dbeak); err != nil { return nil, errors.Wrapf(err, "error unmarshaling external account key %s into dbExternalAccountKey", id) } return dbeak, nil } // CreateExternalAccountKey creates a new External Account Binding key with a name func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { externalAccountKeyMutex.Lock() defer externalAccountKeyMutex.Unlock() keyID, err := randID() if err != nil { return nil, err } random := make([]byte, 32) _, err = rand.Read(random) if err != nil { return nil, err } dbeak := &dbExternalAccountKey{ ID: keyID, ProvisionerID: provisionerID, Reference: reference, HmacKey: random, CreatedAt: clock.Now(), } if err := db.save(ctx, keyID, dbeak, nil, "external_account_key", externalAccountKeyTable); err != nil { return nil, err } if err := db.addEAKID(ctx, provisionerID, dbeak.ID); err != nil { return nil, err } if dbeak.Reference != "" { dbExternalAccountKeyReference := &dbExternalAccountKeyReference{ Reference: dbeak.Reference, ExternalAccountKeyID: dbeak.ID, } if err := db.save(ctx, referenceKey(provisionerID, dbeak.Reference), dbExternalAccountKeyReference, nil, "external_account_key_reference", externalAccountKeyIDsByReferenceTable); err != nil { return nil, err } } return &acme.ExternalAccountKey{ ID: dbeak.ID, ProvisionerID: dbeak.ProvisionerID, Reference: dbeak.Reference, AccountID: dbeak.AccountID, HmacKey: dbeak.HmacKey, CreatedAt: dbeak.CreatedAt, BoundAt: dbeak.BoundAt, }, nil } // GetExternalAccountKey retrieves an External Account Binding key by KeyID func (db *DB) 
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { externalAccountKeyMutex.RLock() defer externalAccountKeyMutex.RUnlock() dbeak, err := db.getDBExternalAccountKey(ctx, keyID) if err != nil { return nil, err } if dbeak.ProvisionerID != provisionerID { return nil, acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created") } return &acme.ExternalAccountKey{ ID: dbeak.ID, ProvisionerID: dbeak.ProvisionerID, Reference: dbeak.Reference, AccountID: dbeak.AccountID, HmacKey: dbeak.HmacKey, CreatedAt: dbeak.CreatedAt, BoundAt: dbeak.BoundAt, }, nil } func (db *DB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { externalAccountKeyMutex.Lock() defer externalAccountKeyMutex.Unlock() dbeak, err := db.getDBExternalAccountKey(ctx, keyID) if err != nil { return errors.Wrapf(err, "error loading ACME EAB Key with Key ID %s", keyID) } if dbeak.ProvisionerID != provisionerID { return errors.New("provisioner does not match provisioner for which the EAB key was created") } if dbeak.Reference != "" { if err := db.db.Del(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, dbeak.Reference))); err != nil { return errors.Wrapf(err, "error deleting ACME EAB Key reference with Key ID %s and reference %s", keyID, dbeak.Reference) } } if err := db.db.Del(externalAccountKeyTable, []byte(keyID)); err != nil { return errors.Wrapf(err, "error deleting ACME EAB Key with Key ID %s", keyID) } if err := db.deleteEAKID(ctx, provisionerID, keyID); err != nil { return errors.Wrapf(err, "error removing ACME EAB Key ID %s", keyID) } return nil } // GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) { _, _ = cursor, limit // unused input 
externalAccountKeyMutex.RLock() defer externalAccountKeyMutex.RUnlock() // cursor and limit are ignored in open source, at least for now. var eakIDs []string r, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID)) if err != nil { if !nosqlDB.IsErrNotFound(err) { return nil, "", errors.Wrapf(err, "error loading ACME EAB Key IDs for provisioner %s", provisionerID) } // it may happen that no record is found; we'll continue with an empty slice } else { if err := json.Unmarshal(r, &eakIDs); err != nil { return nil, "", errors.Wrapf(err, "error unmarshaling ACME EAB Key IDs for provisioner %s", provisionerID) } } keys := []*acme.ExternalAccountKey{} for _, eakID := range eakIDs { if eakID == "" { continue // shouldn't happen; just in case } eak, err := db.getDBExternalAccountKey(ctx, eakID) if err != nil { if !nosqlDB.IsErrNotFound(err) { return nil, "", errors.Wrapf(err, "error retrieving ACME EAB Key for provisioner %s and keyID %s", provisionerID, eakID) } } keys = append(keys, &acme.ExternalAccountKey{ ID: eak.ID, HmacKey: eak.HmacKey, ProvisionerID: eak.ProvisionerID, Reference: eak.Reference, AccountID: eak.AccountID, CreatedAt: eak.CreatedAt, BoundAt: eak.BoundAt, }) } return keys, "", nil } // GetExternalAccountKeyByReference retrieves an External Account Binding key with unique reference func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { externalAccountKeyMutex.RLock() defer externalAccountKeyMutex.RUnlock() if reference == "" { //nolint:nilnil // legacy return nil, nil } k, err := db.db.Get(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, reference))) if nosqlDB.IsErrNotFound(err) { return nil, acme.ErrNotFound } else if err != nil { return nil, errors.Wrapf(err, "error loading ACME EAB key for reference %s", reference) } dbExternalAccountKeyReference := new(dbExternalAccountKeyReference) if err := json.Unmarshal(k, 
dbExternalAccountKeyReference); err != nil { return nil, errors.Wrapf(err, "error unmarshaling ACME EAB key for reference %s", reference) } return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID) } func (db *DB) GetExternalAccountKeyByAccountID(context.Context, string, string) (*acme.ExternalAccountKey, error) { //nolint:nilnil // legacy return nil, nil } func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { externalAccountKeyMutex.Lock() defer externalAccountKeyMutex.Unlock() old, err := db.getDBExternalAccountKey(ctx, eak.ID) if err != nil { return err } if old.ProvisionerID != provisionerID { return errors.New("provisioner does not match provisioner for which the EAB key was created") } if old.ProvisionerID != eak.ProvisionerID { return errors.New("cannot change provisioner for an existing ACME EAB Key") } if old.Reference != eak.Reference { return errors.New("cannot change reference for an existing ACME EAB Key") } nu := dbExternalAccountKey{ ID: eak.ID, ProvisionerID: eak.ProvisionerID, Reference: eak.Reference, AccountID: eak.AccountID, HmacKey: eak.HmacKey, CreatedAt: eak.CreatedAt, BoundAt: eak.BoundAt, } return db.save(ctx, nu.ID, nu, old, "external_account_key", externalAccountKeyTable) } func (db *DB) addEAKID(ctx context.Context, provisionerID, eakID string) error { referencesByProvisionerIndexMutex.Lock() defer referencesByProvisionerIndexMutex.Unlock() if eakID == "" { return errors.Errorf("can't add empty eakID for provisioner %s", provisionerID) } var eakIDs []string b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID)) if err != nil { if !nosqlDB.IsErrNotFound(err) { return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID) } // it may happen that no record is found; we'll continue with an empty slice } else { if err := json.Unmarshal(b, &eakIDs); err != nil { return 
errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID) } } for _, id := range eakIDs { if id == eakID { // return an error when a duplicate ID is found return errors.Errorf("eakID %s already exists for provisioner %s", eakID, provisionerID) } } var newEAKIDs []string newEAKIDs = append(newEAKIDs, eakIDs...) newEAKIDs = append(newEAKIDs, eakID) var ( _old interface{} = eakIDs _new interface{} = newEAKIDs ) // ensure that the DB gets the expected value when the slice is empty; otherwise // it'll return with an error that indicates that the DBs view of the data is // different from the last read (i.e. _old is different from what the DB has). if len(eakIDs) == 0 { _old = nil } if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil { return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID) } return nil } func (db *DB) deleteEAKID(ctx context.Context, provisionerID, eakID string) error { referencesByProvisionerIndexMutex.Lock() defer referencesByProvisionerIndexMutex.Unlock() var eakIDs []string b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID)) if err != nil { if !nosqlDB.IsErrNotFound(err) { return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID) } // it may happen that no record is found; we'll continue with an empty slice } else { if err := json.Unmarshal(b, &eakIDs); err != nil { return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID) } } newEAKIDs := removeElement(eakIDs, eakID) var ( _old interface{} = eakIDs _new interface{} = newEAKIDs ) // ensure that the DB gets the expected value when the slice is empty; otherwise // it'll return with an error that indicates that the DBs view of the data is // different from the last read (i.e. _old is different from what the DB has). 
// referenceKey builds the lookup key of a reference record: the
// provisioner ID and the reference joined by a dot, which keeps references
// unique per provisioner.
func referenceKey(provisionerID, reference string) string {
	const separator = "."
	return provisionerID + separator + reference
}

// sliceIndex reports the position of the first element of slice equal to
// item, or -1 when there is none.
func sliceIndex(slice []string, item string) int {
	for pos, candidate := range slice {
		if candidate == item {
			return pos
		}
	}
	return -1
}

// removeElement returns a copy of slice without the first occurrence of
// item. The input is never modified and the result is always a non-nil
// slice, even when it is empty.
func removeElement(slice []string, item string) []string {
	kept := make([]string, 0)
	dropped := false
	for _, candidate := range slice {
		if !dropped && candidate == item {
			// only the first match is removed
			dropped = true
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// TestDB_getDBExternalAccountKey exercises DB.getDBExternalAccountKey
// against a mocked NoSQL database: a successful load, a missing record, a
// failing Get and a record that cannot be unmarshaled.
func TestDB_getDBExternalAccountKey(t *testing.T) {
	keyID := "keyID"
	provID := "provID"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
		dbeak   *dbExternalAccountKey
	}
	// Each entry builds its own mock DB and expectations lazily.
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     "ref",
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						return b, nil
					},
				},
				err:   nil,
				dbeak: dbeak,
			}
		},
		"fail/not-found": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						return nil, nosqldb.ErrNotFound
					},
				},
				err: acme.ErrNotFound,
			}
		},
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading external account key keyID: force"),
			}
		},
		"fail/unmarshal-error": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						// not valid JSON
						return []byte("foo"), nil
					},
				},
				err: errors.New("error unmarshaling external account key keyID into dbExternalAccountKey"),
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if dbeak, err := d.getDBExternalAccountKey(context.Background(), keyID); err != nil {
				// ACME errors are compared field by field; any other error
				// only has to start with the expected message.
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) {
				assert.Equals(t, dbeak.ID, tc.dbeak.ID)
				assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey)
				assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID)
				assert.Equals(t, dbeak.Reference, tc.dbeak.Reference)
				assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt)
				assert.Equals(t, dbeak.AccountID, tc.dbeak.AccountID)
				assert.Equals(t, dbeak.BoundAt, tc.dbeak.BoundAt)
			}
		})
	}
}
// TestDB_GetExternalAccountKey exercises DB.GetExternalAccountKey against a
// mocked NoSQL database: a successful load, a failing Get, and a stored key
// that belongs to a different provisioner than the one asked for.
func TestDB_GetExternalAccountKey(t *testing.T) {
	keyID := "keyID"
	provID := "provID"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
		eak     *acme.ExternalAccountKey
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     "ref",
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						return b, nil
					},
				},
				eak: &acme.ExternalAccountKey{
					ID:            keyID,
					ProvisionerID: provID,
					Reference:     "ref",
					AccountID:     "",
					HmacKey:       []byte{1, 3, 3, 7},
					CreatedAt:     now,
				},
			}
		},
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading external account key keyID: force"),
			}
		},
		"fail/non-matching-provisioner": func(t *testing.T) test {
			now := clock.Now()
			// stored under a provisioner other than the one the test asks for
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: "aDifferentProvID",
				Reference:     "ref",
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, externalAccountKeyTable)
						assert.Equals(t, string(key), keyID)
						return b, nil
					},
				},
				eak: &acme.ExternalAccountKey{
					ID:            keyID,
					ProvisionerID: provID,
					Reference:     "ref",
					AccountID:     "",
					HmacKey:       []byte{1, 3, 3, 7},
					CreatedAt:     now,
				},
				acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"),
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if eak, err := d.GetExternalAccountKey(context.Background(), provID, keyID); err != nil {
				// ACME errors are compared field by field; any other error
				// only has to start with the expected message.
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) {
				assert.Equals(t, eak.ID, tc.eak.ID)
				assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
				assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
				assert.Equals(t, eak.Reference, tc.eak.Reference)
				assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
				assert.Equals(t, eak.AccountID, tc.eak.AccountID)
				assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
			}
		})
	}
}
// TestDB_GetExternalAccountKeyByReference exercises
// DB.GetExternalAccountKeyByReference against a mocked NoSQL database. The
// mock serves two buckets: the reference records, keyed by
// provID+"."+ref, and the key records, keyed by key ID.
func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
	keyID := "keyID"
	provID := "provID"
	ref := "ref"
	type test struct {
		db      nosql.DB
		err     error
		ref     string
		acmeErr *acme.Error
		eak     *acme.ExternalAccountKey
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			dbref := &dbExternalAccountKeyReference{
				Reference:            ref,
				ExternalAccountKeyID: keyID,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			dbrefBytes, err := json.Marshal(dbref)
			assert.FatalError(t, err)
			return test{
				ref: ref,
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return dbrefBytes, nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return b, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
				},
				eak: &acme.ExternalAccountKey{
					ID:            keyID,
					ProvisionerID: provID,
					Reference:     ref,
					AccountID:     "",
					HmacKey:       []byte{1, 3, 3, 7},
					CreatedAt:     now,
				},
				err: nil,
			}
		},
		// An empty reference needs no DB: no error and no key are expected.
		"ok/no-reference": func(t *testing.T) test {
			return test{
				ref: "",
				eak: nil,
				err: nil,
			}
		},
		"fail/reference-not-found": func(t *testing.T) test {
			return test{
				ref: ref,
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyIDsByReferenceTable))
						assert.Equals(t, string(key), provID+"."+ref)
						return nil, nosqldb.ErrNotFound
					},
				},
				err: errors.New("not found"),
			}
		},
		"fail/reference-load-error": func(t *testing.T) test {
			return test{
				ref: ref,
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyIDsByReferenceTable))
						assert.Equals(t, string(key), provID+"."+ref)
						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading ACME EAB key for reference ref: force"),
			}
		},
		"fail/reference-unmarshal-error": func(t *testing.T) test {
			return test{
				ref: ref,
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyIDsByReferenceTable))
						assert.Equals(t, string(key), provID+"."+ref)
						// not valid JSON
						return []byte{0}, nil
					},
				},
				err: errors.New("error unmarshaling ACME EAB key for reference ref"),
			}
		},
		// The reference resolves, but loading the key it points to fails.
		"fail/db.GetExternalAccountKey-error": func(t *testing.T) test {
			dbref := &dbExternalAccountKeyReference{
				Reference:            ref,
				ExternalAccountKeyID: keyID,
			}
			dbrefBytes, err := json.Marshal(dbref)
			assert.FatalError(t, err)
			return test{
				ref: ref,
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return dbrefBytes, nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return nil, errors.New("force")
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
				},
				err: errors.New("error loading external account key keyID: force"),
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if eak, err := d.GetExternalAccountKeyByReference(context.Background(), provID, tc.ref); err != nil {
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) && tc.eak != nil {
				// fields are only compared when the case expects a key
				assert.Equals(t, eak.ID, tc.eak.ID)
				assert.Equals(t, eak.AccountID, tc.eak.AccountID)
				assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
				assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
				assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
				assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
				assert.Equals(t, eak.Reference, tc.eak.Reference)
			}
		})
	}
}
// TestDB_GetExternalAccountKeys exercises DB.GetExternalAccountKeys against
// a mocked NoSQL database: a successful listing (with an empty key ID in
// the index), a failing index load, an index that cannot be unmarshaled,
// and a failing load of an individual key.
func TestDB_GetExternalAccountKeys(t *testing.T) {
	keyID1 := "keyID1"
	keyID2 := "keyID2"
	keyID3 := "keyID3"
	provID := "provID"
	ref := "ref"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
		eaks    []*acme.ExternalAccountKey
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbeak1 := &dbExternalAccountKey{
				ID:            keyID1,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b1, err := json.Marshal(dbeak1)
			assert.FatalError(t, err)
			dbeak2 := &dbExternalAccountKey{
				ID:            keyID2,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b2, err := json.Marshal(dbeak2)
			assert.FatalError(t, err)
			// third key belongs to another provisioner; only served by MList
			dbeak3 := &dbExternalAccountKey{
				ID:            keyID3,
				ProvisionerID: "aDifferentProvID",
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b3, err := json.Marshal(dbeak3)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByProvisionerIDTable):
							keys := []string{"", keyID1, keyID2} // includes an empty keyID
							b, err := json.Marshal(keys)
							assert.FatalError(t, err)
							return b, nil
						case string(externalAccountKeyTable):
							switch string(key) {
							case keyID1:
								return b1, nil
							case keyID2:
								return b2, nil
							default:
								assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
								return nil, errors.New("force default")
							}
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force default")
						}
					},
					// TODO: remove the MList
					MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
						switch string(bucket) {
						case string(externalAccountKeyTable):
							return []*nosqldb.Entry{
								{
									Bucket: bucket,
									Key:    []byte(keyID1),
									Value:  b1,
								},
								{
									Bucket: bucket,
									Key:    []byte(keyID2),
									Value:  b2,
								},
								{
									Bucket: bucket,
									Key:    []byte(keyID3),
									Value:  b3,
								},
							}, nil
						case string(externalAccountKeyIDsByProvisionerIDTable):
							keys := []string{keyID1, keyID2}
							b, err := json.Marshal(keys)
							assert.FatalError(t, err)
							return []*nosqldb.Entry{
								{
									Bucket: bucket,
									Key:    []byte(provID),
									Value:  b,
								},
							}, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force default")
						}
					},
				},
				eaks: []*acme.ExternalAccountKey{
					{
						ID:            keyID1,
						ProvisionerID: provID,
						Reference:     ref,
						AccountID:     "",
						HmacKey:       []byte{1, 3, 3, 7},
						CreatedAt:     now,
					},
					{
						ID:            keyID2,
						ProvisionerID: provID,
						Reference:     ref,
						AccountID:     "",
						HmacKey:       []byte{1, 3, 3, 7},
						CreatedAt:     now,
					},
				},
			}
		},
		"fail/db.Get-externalAccountKeysByProvisionerIDTable": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable))
						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading ACME EAB Key IDs for provisioner provID: force"),
			}
		},
		"fail/db.Get-externalAccountKeysByProvisionerIDTable-unmarshal": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable))
						// a JSON number where a list of strings is expected
						b, _ := json.Marshal(1)
						return b, nil
					},
				},
				err: errors.New("error unmarshaling ACME EAB Key IDs for provisioner provID: json: cannot unmarshal number into Go value of type []string"),
			}
		},
		"fail/db.getDBExternalAccountKey": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByProvisionerIDTable):
							keys := []string{keyID1, keyID2}
							b, err := json.Marshal(keys)
							assert.FatalError(t, err)
							return b, nil
						case string(externalAccountKeyTable):
							return nil, errors.New("force")
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force bucket")
						}
					},
				},
				err: errors.New("error retrieving ACME EAB Key for provisioner provID and keyID keyID1: error loading external account key keyID1: force"),
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			cursor, limit := "", 0
			if eaks, nextCursor, err := d.GetExternalAccountKeys(context.Background(), provID, cursor, limit); err != nil {
				assert.Equals(t, "", nextCursor)
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					// unlike the sibling tests, the full message must match
					if assert.NotNil(t, tc.err) {
						assert.Equals(t, tc.err.Error(), err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) {
				assert.Equals(t, len(eaks), len(tc.eaks))
				assert.Equals(t, "", nextCursor)
				for i, eak := range eaks {
					assert.Equals(t, eak.ID, tc.eaks[i].ID)
					assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey)
					assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID)
					assert.Equals(t, eak.Reference, tc.eaks[i].Reference)
					assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt)
					assert.Equals(t, eak.AccountID, tc.eaks[i].AccountID)
					assert.Equals(t, eak.BoundAt, tc.eaks[i].BoundAt)
				}
			}
		})
	}
}
// TestDB_DeleteExternalAccountKey exercises DB.DeleteExternalAccountKey
// against a mocked NoSQL database: a successful delete, a missing key, a
// key of a different provisioner, and a failure in each of the three
// delete steps (reference record, key record, per-provisioner index).
func TestDB_DeleteExternalAccountKey(t *testing.T) {
	keyID := "keyID"
	provID := "provID"
	ref := "ref"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			dbref := &dbExternalAccountKeyReference{
				Reference:            ref,
				ExternalAccountKeyID: keyID,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			dbrefBytes, err := json.Marshal(dbref)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return dbrefBytes, nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return b, nil
						case string(externalAccountKeyIDsByProvisionerIDTable):
							assert.Equals(t, provID, string(key))
							b, err := json.Marshal([]string{keyID})
							assert.FatalError(t, err)
							return b, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force default")
						}
					},
					MDel: func(bucket, key []byte) error {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return errors.New("force default")
						}
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						// NOTE(review): looks like a leftover debug print — confirm before removing
						fmt.Println(string(bucket))
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, provID+"."+ref, string(key))
							return nil, true, nil
						case string(externalAccountKeyIDsByProvisionerIDTable):
							assert.Equals(t, provID, string(key))
							return nil, true, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, false, errors.New("force default")
						}
					},
				},
			}
		},
		"fail/not-found": func(t *testing.T) test {
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyTable))
						assert.Equals(t, string(key), keyID)
						return nil, nosqldb.ErrNotFound
					},
				},
				err: errors.New("error loading ACME EAB Key with Key ID keyID: not found"),
			}
		},
		"fail/non-matching-provisioner": func(t *testing.T) test {
			now := clock.Now()
			// stored under a provisioner other than the one the test asks for
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: "aDifferentProvID",
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(externalAccountKeyTable))
						assert.Equals(t, string(key), keyID)
						return b, nil
					},
				},
				err: errors.New("provisioner does not match provisioner for which the EAB key was created"),
			}
		},
		// Deleting the reference record fails.
		"fail/delete-reference": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			dbref := &dbExternalAccountKeyReference{
				Reference:            ref,
				ExternalAccountKeyID: keyID,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			dbrefBytes, err := json.Marshal(dbref)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), ref)
							return dbrefBytes, nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return b, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force default")
						}
					},
					MDel: func(bucket, key []byte) error {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return errors.New("force")
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return errors.New("force default")
						}
					},
				},
				err: errors.New("error deleting ACME EAB Key reference with Key ID keyID and reference ref: force"),
			}
		},
		// Deleting the key record fails after the reference was removed.
		"fail/delete-eak": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			dbref := &dbExternalAccountKeyReference{
				Reference:            ref,
				ExternalAccountKeyID: keyID,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			dbrefBytes, err := json.Marshal(dbref)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), ref)
							return dbrefBytes, nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return b, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force default")
						}
					},
					MDel: func(bucket, key []byte) error {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return errors.New("force")
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return errors.New("force default")
						}
					},
				},
				err: errors.New("error deleting ACME EAB Key with Key ID keyID: force"),
			}
		},
		// Both deletes succeed, but loading the per-provisioner index fails.
		"fail/delete-eakID": func(t *testing.T) test {
			now := clock.Now()
			dbeak := &dbExternalAccountKey{
				ID:            keyID,
				ProvisionerID: provID,
				Reference:     ref,
				AccountID:     "",
				HmacKey:       []byte{1, 3, 3, 7},
				CreatedAt:     now,
			}
			dbref := &dbExternalAccountKeyReference{
				Reference:            ref,
				ExternalAccountKeyID: keyID,
			}
			b, err := json.Marshal(dbeak)
			assert.FatalError(t, err)
			dbrefBytes, err := json.Marshal(dbref)
			assert.FatalError(t, err)
			return test{
				db: &certdb.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), ref)
							return dbrefBytes, nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return b, nil
						case string(externalAccountKeyIDsByProvisionerIDTable):
							return b, errors.New("force")
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force default")
						}
					},
					MDel: func(bucket, key []byte) error {
						switch string(bucket) {
						case string(externalAccountKeyIDsByReferenceTable):
							assert.Equals(t, string(key), provID+"."+ref)
							return nil
						case string(externalAccountKeyTable):
							assert.Equals(t, string(key), keyID)
							return nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return errors.New("force default")
						}
					},
				},
				err: errors.New("error removing ACME EAB Key ID keyID: error loading eakIDs for provisioner provID: force"),
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if err := d.DeleteExternalAccountKey(context.Background(), provID, keyID); err != nil {
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					// the full message must match, not just a prefix
					if assert.NotNil(t, tc.err) {
						assert.Equals(t, err.Error(), tc.err.Error())
					}
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}
&certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) assert.Equals(t, provID, string(key)) b, _ := json.Marshal([]string{}) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { switch string(bucket) { case string(externalAccountKeyIDsByProvisionerIDTable): assert.Equals(t, provID, string(key)) return nu, true, nil case string(externalAccountKeyIDsByReferenceTable): assert.Equals(t, provID+"."+ref, string(key)) assert.Equals(t, nil, old) return nu, true, nil case string(externalAccountKeyTable): assert.Equals(t, nil, old) id = string(key) dbeak := new(dbExternalAccountKey) assert.FatalError(t, json.Unmarshal(nu, dbeak)) assert.Equals(t, string(key), dbeak.ID) assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID) assert.Equals(t, eak.Reference, dbeak.Reference) assert.Equals(t, 32, len(dbeak.HmacKey)) assert.False(t, dbeak.CreatedAt.IsZero()) assert.Equals(t, dbeak.AccountID, eak.AccountID) assert.True(t, dbeak.BoundAt.IsZero()) return nu, true, nil default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force default") } }, }, eak: eak, _id: idPtr, } }, "fail/externalAccountKeyID-cmpAndSwap-error": func(t *testing.T) test { return test{ db: &certdb.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { switch string(bucket) { case string(externalAccountKeyIDsByReferenceTable): assert.Equals(t, string(key), ref) assert.Equals(t, old, nil) return nu, true, nil case string(externalAccountKeyTable): assert.Equals(t, old, nil) return nu, true, errors.New("force") default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force default") } }, }, err: errors.New("error saving acme external_account_key: force"), } }, "fail/addEAKID-error": func(t *testing.T) test { return test{ db: 
&certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) assert.Equals(t, provID, string(key)) return nil, errors.New("force") }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { switch string(bucket) { case string(externalAccountKeyIDsByReferenceTable): assert.Equals(t, string(key), ref) assert.Equals(t, old, nil) return nu, true, nil case string(externalAccountKeyTable): assert.Equals(t, old, nil) return nu, true, nil default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force default") } }, }, err: errors.New("error loading eakIDs for provisioner provID: force"), } }, "fail/externalAccountKeyReference-cmpAndSwap-error": func(t *testing.T) test { return test{ db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) assert.Equals(t, provID, string(key)) b, _ := json.Marshal([]string{}) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { switch string(bucket) { case string(externalAccountKeyIDsByProvisionerIDTable): assert.Equals(t, provID, string(key)) return nu, true, nil case string(externalAccountKeyIDsByReferenceTable): assert.Equals(t, provID+"."+ref, string(key)) assert.Equals(t, old, nil) return nu, true, errors.New("force") case string(externalAccountKeyTable): assert.Equals(t, old, nil) return nu, true, nil default: assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force default") } }, }, err: errors.New("error saving acme external_account_key_reference: force"), } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} eak, err := d.CreateExternalAccountKey(context.Background(), provID, ref) if err != nil { if assert.NotNil(t, tc.err) { 
assert.Equals(t, err.Error(), tc.err.Error()) } } else if assert.Nil(t, tc.err) { assert.Equals(t, *tc._id, eak.ID) assert.Equals(t, provID, eak.ProvisionerID) assert.Equals(t, ref, eak.Reference) assert.Equals(t, "", eak.AccountID) assert.False(t, eak.CreatedAt.IsZero()) assert.False(t, eak.AlreadyBound()) assert.True(t, eak.BoundAt.IsZero()) } }) } } func TestDB_UpdateExternalAccountKey(t *testing.T) { keyID := "keyID" provID := "provID" ref := "ref" now := clock.Now() dbeak := &dbExternalAccountKey{ ID: keyID, ProvisionerID: provID, Reference: ref, AccountID: "", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(dbeak) assert.FatalError(t, err) type test struct { db nosql.DB eak *acme.ExternalAccountKey err error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { eak := &acme.ExternalAccountKey{ ID: keyID, ProvisionerID: provID, Reference: ref, AccountID: "", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } return test{ eak: eak, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyTable) assert.Equals(t, string(key), keyID) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyTable) assert.Equals(t, old, b) dbNew := new(dbExternalAccountKey) assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.Equals(t, dbNew.ID, dbeak.ID) assert.Equals(t, dbNew.ProvisionerID, dbeak.ProvisionerID) assert.Equals(t, dbNew.Reference, dbeak.Reference) assert.Equals(t, dbNew.AccountID, dbeak.AccountID) assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt) assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt) assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey) return nu, true, nil }, }, } }, "fail/db.Get-error": func(t *testing.T) test { return test{ eak: &acme.ExternalAccountKey{ ID: keyID, }, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, 
externalAccountKeyTable) assert.Equals(t, string(key), keyID) return nil, errors.New("force") }, }, err: errors.New("error loading external account key keyID: force"), } }, "fail/provisioner-mismatch": func(t *testing.T) test { newDBEAK := &dbExternalAccountKey{ ID: keyID, ProvisionerID: "aDifferentProvID", Reference: ref, AccountID: "", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) assert.FatalError(t, err) return test{ eak: &acme.ExternalAccountKey{ ID: keyID, }, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyTable) assert.Equals(t, string(key), keyID) return b, nil }, }, err: errors.New("provisioner does not match provisioner for which the EAB key was created"), } }, "fail/provisioner-change": func(t *testing.T) test { newDBEAK := &dbExternalAccountKey{ ID: keyID, ProvisionerID: provID, Reference: ref, AccountID: "", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) assert.FatalError(t, err) return test{ eak: &acme.ExternalAccountKey{ ID: keyID, ProvisionerID: "aDifferentProvisionerID", }, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyTable) assert.Equals(t, string(key), keyID) return b, nil }, }, err: errors.New("cannot change provisioner for an existing ACME EAB Key"), } }, "fail/reference-change": func(t *testing.T) test { newDBEAK := &dbExternalAccountKey{ ID: keyID, ProvisionerID: provID, Reference: ref, AccountID: "", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: now, } b, err := json.Marshal(newDBEAK) assert.FatalError(t, err) return test{ eak: &acme.ExternalAccountKey{ ID: keyID, ProvisionerID: provID, Reference: "aDifferentReference", }, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyTable) assert.Equals(t, string(key), keyID) return b, nil }, }, err: errors.New("cannot change reference 
for an existing ACME EAB Key"), } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.UpdateExternalAccountKey(context.Background(), provID, tc.eak); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else if assert.Nil(t, tc.err) { assert.Equals(t, dbeak.ID, tc.eak.ID) assert.Equals(t, dbeak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, dbeak.Reference, tc.eak.Reference) assert.Equals(t, dbeak.AccountID, tc.eak.AccountID) assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt) assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey) } }) } } func TestDB_addEAKID(t *testing.T) { provID := "provID" eakID := "eakID" type test struct { ctx context.Context provisionerID string eakID string db nosql.DB err error } var tests = map[string]func(t *testing.T) test{ "fail/empty-eakID": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: "", err: errors.New("can't add empty eakID for provisioner provID"), } }, "fail/db.Get": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) return nil, errors.New("force") }, }, err: errors.New("error loading eakIDs for provisioner provID: force"), } }, "fail/unmarshal": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal(1) return b, nil }, }, err: errors.New("error unmarshaling eakIDs for provisioner provID: json: cannot unmarshal number into Go value of type []string"), } }, 
"fail/eakID-already-exists": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal([]string{eakID}) return b, nil }, }, err: errors.New("eakID eakID already exists for provisioner provID"), } }, "fail/db.save": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal([]string{"id1"}) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) oldB, _ := json.Marshal([]string{"id1"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{"id1", eakID}) assert.Equals(t, nu, newB) return newB, true, errors.New("force") }, }, err: errors.New("error saving eakIDs index for provisioner provID: error saving acme externalAccountKeyIDsByProvisionerID: force"), } }, "ok/db.Get-not-found": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) return nil, nosqldb.ErrNotFound }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) assert.Equals(t, old, nil) b, _ := json.Marshal([]string{eakID}) assert.Equals(t, nu, b) return b, true, nil }, }, err: nil, } }, "ok": func(t *testing.T) test { return test{ ctx: 
context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal([]string{"id1", "id2"}) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) oldB, _ := json.Marshal([]string{"id1", "id2"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{"id1", "id2", eakID}) assert.Equals(t, nu, newB) return newB, true, nil }, }, err: nil, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { db := &DB{ db: tc.db, } wantErr := tc.err != nil err := db.addEAKID(tc.ctx, tc.provisionerID, tc.eakID) if (err != nil) != wantErr { t.Errorf("DB.addEAKID() error = %v, wantErr %v", err, wantErr) } if err != nil { assert.Equals(t, tc.err.Error(), err.Error()) } }) } } func TestDB_deleteEAKID(t *testing.T) { provID := "provID" eakID := "eakID" type test struct { ctx context.Context provisionerID string eakID string db nosql.DB err error } var tests = map[string]func(t *testing.T) test{ "fail/db.Get": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) return nil, errors.New("force") }, }, err: errors.New("error loading eakIDs for provisioner provID: force"), } }, "fail/unmarshal": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal(1) return b, nil 
}, }, err: errors.New("error unmarshaling eakIDs for provisioner provID: json: cannot unmarshal number into Go value of type []string"), } }, "fail/db.save": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal([]string{"id1", eakID}) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) oldB, _ := json.Marshal([]string{"id1", eakID}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{"id1"}) assert.Equals(t, nu, newB) return newB, true, errors.New("force") }, }, err: errors.New("error saving eakIDs index for provisioner provID: error saving acme externalAccountKeyIDsByProvisionerID: force"), } }, "ok/db.Get-not-found": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) return nil, nosqldb.ErrNotFound }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) assert.Equals(t, old, nil) b, _ := json.Marshal([]string{}) assert.Equals(t, nu, b) return b, true, nil }, }, err: nil, } }, "ok": func(t *testing.T) test { return test{ ctx: context.Background(), provisionerID: provID, eakID: eakID, db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) b, _ := json.Marshal([]string{"id1", eakID, "id2"}) return b, 
nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) oldB, _ := json.Marshal([]string{"id1", eakID, "id2"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{"id1", "id2"}) assert.Equals(t, nu, newB) return newB, true, nil }, }, err: nil, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { db := &DB{ db: tc.db, } wantErr := tc.err != nil err := db.deleteEAKID(tc.ctx, tc.provisionerID, tc.eakID) if (err != nil) != wantErr { t.Errorf("DB.deleteEAKID() error = %v, wantErr %v", err, wantErr) } if err != nil { assert.Equals(t, tc.err.Error(), err.Error()) } }) } } func TestDB_addAndDeleteEAKID(t *testing.T) { provID := "provID" callCounter := 0 type test struct { ctx context.Context db nosql.DB err error } var tests = map[string]func(t *testing.T) test{ "ok/multi": func(t *testing.T) test { return test{ ctx: context.Background(), db: &certdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) switch callCounter { case 0: return nil, nosqldb.ErrNotFound case 1: b, _ := json.Marshal([]string{"eakID"}) return b, nil case 2: b, _ := json.Marshal([]string{}) return b, nil case 3: b, _ := json.Marshal([]string{"eakID1"}) return b, nil case 4: b, _ := json.Marshal([]string{"eakID1", "eakID2"}) return b, nil case 5: b, _ := json.Marshal([]string{"eakID2"}) return b, nil default: assert.FatalError(t, errors.New("unexpected get iteration")) return nil, errors.New("force get default") } }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) assert.Equals(t, string(key), provID) switch callCounter { case 0: assert.Equals(t, old, nil) newB, _ := json.Marshal([]string{"eakID"}) assert.Equals(t, nu, newB) return 
newB, true, nil case 1: oldB, _ := json.Marshal([]string{"eakID"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{}) return newB, true, nil case 2: assert.Equals(t, old, nil) newB, _ := json.Marshal([]string{"eakID1"}) assert.Equals(t, nu, newB) return newB, true, nil case 3: oldB, _ := json.Marshal([]string{"eakID1"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{"eakID1", "eakID2"}) assert.Equals(t, nu, newB) return newB, true, nil case 4: oldB, _ := json.Marshal([]string{"eakID1", "eakID2"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{"eakID2"}) assert.Equals(t, nu, newB) return newB, true, nil case 5: oldB, _ := json.Marshal([]string{"eakID2"}) assert.Equals(t, old, oldB) newB, _ := json.Marshal([]string{}) assert.Equals(t, nu, newB) return newB, true, nil default: assert.FatalError(t, errors.New("unexpected get iteration")) return nil, true, errors.New("force save default") } }, }, err: nil, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { // goal of this test is to simulate multiple calls; no errors expected. 
db := &DB{ db: tc.db, } err := db.addEAKID(tc.ctx, provID, "eakID") if err != nil { t.Errorf("DB.addEAKID() error = %v", err) } callCounter++ err = db.deleteEAKID(tc.ctx, provID, "eakID") if err != nil { t.Errorf("DB.deleteEAKID() error = %v", err) } callCounter++ err = db.addEAKID(tc.ctx, provID, "eakID1") if err != nil { t.Errorf("DB.addEAKID() error = %v", err) } callCounter++ err = db.addEAKID(tc.ctx, provID, "eakID2") if err != nil { t.Errorf("DB.addEAKID() error = %v", err) } callCounter++ err = db.deleteEAKID(tc.ctx, provID, "eakID1") if err != nil { t.Errorf("DB.deleteEAKID() error = %v", err) } callCounter++ err = db.deleteEAKID(tc.ctx, provID, "eakID2") if err != nil { t.Errorf("DB.deleteAKID() error = %v", err) } }) } } func Test_removeElement(t *testing.T) { tests := []struct { name string slice []string item string want []string }{ { name: "remove-first", slice: []string{"id1", "id2", "id3"}, item: "id1", want: []string{"id2", "id3"}, }, { name: "remove-last", slice: []string{"id1", "id2", "id3"}, item: "id3", want: []string{"id1", "id2"}, }, { name: "remove-middle", slice: []string{"id1", "id2", "id3"}, item: "id2", want: []string{"id1", "id3"}, }, { name: "remove-non-existing", slice: []string{"id1", "id2", "id3"}, item: "none", want: []string{"id1", "id2", "id3"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := removeElement(tt.slice, tt.item) if !cmp.Equal(tt.want, got) { t.Errorf("removeElement() diff =\n %s", cmp.Diff(tt.want, got)) } }) } } ================================================ FILE: acme/db/nosql/nonce.go ================================================ package nosql import ( "context" "encoding/base64" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" ) // dbNonce contains nonce metadata used in the ACME protocol. 
type dbNonce struct { ID string CreatedAt time.Time DeletedAt time.Time } // CreateNonce creates, stores, and returns an ACME replay-nonce. // Implements the acme.DB interface. func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) { _id, err := randID() if err != nil { return "", err } id := base64.RawURLEncoding.EncodeToString([]byte(_id)) n := &dbNonce{ ID: id, CreatedAt: clock.Now(), } if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { return "", err } return acme.Nonce(id), nil } // DeleteNonce verifies that the nonce is valid (by checking if it exists), // and if so, consumes the nonce resource by deleting it from the database. func (db *DB) DeleteNonce(_ context.Context, nonce acme.Nonce) error { err := db.db.Update(&database.Tx{ Operations: []*database.TxEntry{ { Bucket: nonceTable, Key: []byte(nonce), Cmd: database.Get, }, { Bucket: nonceTable, Key: []byte(nonce), Cmd: database.Delete, }, }, }) switch { case nosql.IsErrNotFound(err): return acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", string(nonce)) case err != nil: return errors.Wrapf(err, "error deleting nonce %s", string(nonce)) default: return nil } } ================================================ FILE: acme/db/nosql/nonce_test.go ================================================ package nosql import ( "context" "encoding/json" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" ) func TestDB_CreateNonce(t *testing.T) { type test struct { db nosql.DB err error _id *string } var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, nonceTable) assert.Equals(t, old, nil) dbn := new(dbNonce) assert.FatalError(t, 
json.Unmarshal(nu, dbn)) assert.Equals(t, dbn.ID, string(key)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt)) return nil, false, errors.New("force") }, }, err: errors.New("error saving acme nonce: force"), } }, "ok": func(t *testing.T) test { var ( id string idPtr = &id ) return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { *idPtr = string(key) assert.Equals(t, bucket, nonceTable) assert.Equals(t, old, nil) dbn := new(dbNonce) assert.FatalError(t, json.Unmarshal(nu, dbn)) assert.Equals(t, dbn.ID, string(key)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt)) return nil, true, nil }, }, _id: idPtr, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if n, err := d.CreateNonce(context.Background()); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, string(n), *tc._id) } } }) } } func TestDB_DeleteNonce(t *testing.T) { nonceID := "nonceID" type test struct { db nosql.DB err error acmeErr *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { assert.Equals(t, tx.Operations[0].Bucket, nonceTable) assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) assert.Equals(t, tx.Operations[0].Cmd, database.Get) assert.Equals(t, tx.Operations[1].Bucket, nonceTable) assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) assert.Equals(t, tx.Operations[1].Cmd, database.Delete) return database.ErrNotFound }, }, acmeErr: acme.NewError(acme.ErrorBadNonceType, "nonce %s not found", nonceID), } }, "fail/db.Update-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MUpdate: func(tx 
*database.Tx) error { assert.Equals(t, tx.Operations[0].Bucket, nonceTable) assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) assert.Equals(t, tx.Operations[0].Cmd, database.Get) assert.Equals(t, tx.Operations[1].Bucket, nonceTable) assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) assert.Equals(t, tx.Operations[1].Cmd, database.Delete) return errors.New("force") }, }, err: errors.New("error deleting nonce nonceID: force"), } }, "ok": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { assert.Equals(t, tx.Operations[0].Bucket, nonceTable) assert.Equals(t, tx.Operations[0].Key, []byte(nonceID)) assert.Equals(t, tx.Operations[0].Cmd, database.Get) assert.Equals(t, tx.Operations[1].Bucket, nonceTable) assert.Equals(t, tx.Operations[1].Key, []byte(nonceID)) assert.Equals(t, tx.Operations[1].Cmd, database.Delete) return nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { var ae *acme.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.acmeErr) { assert.Equals(t, ae.Type, tc.acmeErr.Type) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) assert.Equals(t, ae.Status, tc.acmeErr.Status) assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error()) assert.Equals(t, ae.Detail, tc.acmeErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else { assert.Nil(t, tc.err) } }) } } ================================================ FILE: acme/db/nosql/nosql.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" nosqlDB "github.com/smallstep/nosql" "go.step.sm/crypto/randutil" ) var ( accountTable = []byte("acme_accounts") accountByKeyIDTable = []byte("acme_keyID_accountID_index") authzTable = []byte("acme_authzs") challengeTable = []byte("acme_challenges") 
nonceTable = []byte("nonces") orderTable = []byte("acme_orders") ordersByAccountIDTable = []byte("acme_account_orders_index") certTable = []byte("acme_certs") certBySerialTable = []byte("acme_serial_certs_index") externalAccountKeyTable = []byte("acme_external_account_keys") externalAccountKeyIDsByReferenceTable = []byte("acme_external_account_keyID_reference_index") externalAccountKeyIDsByProvisionerIDTable = []byte("acme_external_account_keyID_provisionerID_index") wireDpopTokenTable = []byte("wire_acme_dpop_token") wireOidcTokenTable = []byte("wire_acme_oidc_token") ) // DB is a struct that implements the AcmeDB interface. type DB struct { db nosqlDB.DB } // New configures and returns a new ACME DB backend implemented using a nosql DB. func New(db nosqlDB.DB) (*DB, error) { tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable, certBySerialTable, externalAccountKeyTable, externalAccountKeyIDsByReferenceTable, externalAccountKeyIDsByProvisionerIDTable, wireDpopTokenTable, wireOidcTokenTable, } for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", string(b)) } } return &DB{db}, nil } // save writes the new data to the database, overwriting the old data if it // existed. 
func (db *DB) save(_ context.Context, id string, nu, old interface{}, typ string, table []byte) error { var ( err error newB []byte ) if nu == nil { newB = nil } else { newB, err = json.Marshal(nu) if err != nil { return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) } } var oldB []byte if old == nil { oldB = nil } else { oldB, err = json.Marshal(old) if err != nil { return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old) } } _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) switch { case err != nil: return errors.Wrapf(err, "error saving acme %s", typ) case !swapped: return errors.Errorf("error saving acme %s; changed since last read", typ) default: return nil } } var idLen = 32 func randID() (val string, err error) { val, err = randutil.Alphanumeric(idLen) if err != nil { return "", errors.Wrap(err, "error generating random alphanumeric ID") } return val, nil } // Clock that returns time in UTC rounded to seconds. type Clock struct{} // Now returns the UTC time rounded to seconds. 
func (c *Clock) Now() time.Time {
	// Truncate drops the sub-second part; it does not round to the nearest second.
	return time.Now().UTC().Truncate(time.Second)
}

// clock is the package-wide time source.
var clock = new(Clock)

================================================
FILE: acme/db/nosql/nosql_test.go
================================================
package nosql

import (
	"context"
	"testing"

	"github.com/pkg/errors"
	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/db"
	"github.com/smallstep/nosql"
)

// TestNew checks that New surfaces a CreateTable failure (wrapped with the
// table name) and succeeds when every table can be created.
func TestNew(t *testing.T) {
	type test struct {
		db  nosql.DB
		err error
	}
	var tests = map[string]test{
		"fail/db.CreateTable-error": {
			db: &db.MockNoSQLDB{
				MCreateTable: func(bucket []byte) error {
					// New stops at the first failure, so only the first
					// table in its list is ever requested here.
					assert.Equals(t, string(bucket), string(accountTable))
					return errors.New("force")
				},
			},
			err: errors.Errorf("error creating table %s: force", string(accountTable)),
		},
		"ok": {
			db: &db.MockNoSQLDB{
				MCreateTable: func(bucket []byte) error {
					return nil
				},
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			if _, err := New(tc.db); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// errorThrower is a value whose JSON encoding always fails; it is used to
// drive the marshaling error paths of save.
type errorThrower string

// MarshalJSON always returns an error.
func (et errorThrower) MarshalJSON() ([]byte, error) {
	return nil, errors.New("force")
}

// TestDB_save exercises save: marshaling failures for either value, a
// CmpAndSwap error, a CmpAndSwap that reports no swap, and the success paths
// (including nil old and new values).
func TestDB_save(t *testing.T) {
	type test struct {
		db  nosql.DB
		nu  interface{}
		old interface{}
		err error
	}
	var tests = map[string]test{
		"fail/error-marshaling-new": {
			nu:  errorThrower("foo"),
			err: errors.New("error marshaling acme type: challenge"),
		},
		"fail/error-marshaling-old": {
			nu:  "new",
			old: errorThrower("foo"),
			err: errors.New("error marshaling acme type: challenge"),
		},
		"fail/db.CmpAndSwap-error": {
			nu:  "new",
			old: "old",
			db: &db.MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
					assert.Equals(t, bucket, challengeTable)
					assert.Equals(t, string(key), "id")
					// Values arrive JSON-encoded, hence the quotes.
					assert.Equals(t, string(old), "\"old\"")
					assert.Equals(t, string(nu), "\"new\"")
					return nil, false, errors.New("force")
				},
			},
			err: errors.New("error saving acme challenge: force"),
		},
		"fail/db.CmpAndSwap-false-marshaling-old": {
			nu:  "new",
			old: "old",
			db: &db.MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
					assert.Equals(t, bucket, challengeTable)
					assert.Equals(t, string(key), "id")
					assert.Equals(t, string(old), "\"old\"")
					assert.Equals(t, string(nu), "\"new\"")
					// No error, but the swap did not take place.
					return nil, false, nil
				},
			},
			err: errors.New("error saving acme challenge; changed since last read"),
		},
		"ok": {
			nu:  "new",
			old: "old",
			db: &db.MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
					assert.Equals(t, bucket, challengeTable)
					assert.Equals(t, string(key), "id")
					assert.Equals(t, string(old), "\"old\"")
					assert.Equals(t, string(nu), "\"new\"")
					return nu, true, nil
				},
			},
		},
		"ok/nils": {
			nu:  nil,
			old: nil,
			db: &db.MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
					assert.Equals(t, bucket, challengeTable)
					assert.Equals(t, string(key), "id")
					// nil values are passed through as nil bytes, not "null".
					assert.Equals(t, old, nil)
					assert.Equals(t, nu, nil)
					return nu, true, nil
				},
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			d := &DB{db: tc.db}
			if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

================================================
FILE: acme/db/nosql/order.go
================================================
package nosql

import (
	"context"
	"encoding/json"
	"sync"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/certificates/acme"
	"github.com/smallstep/nosql"
)

// Mutex for locking ordersByAccount index operations.
var ordersByAccountMux sync.Mutex type dbOrder struct { ID string `json:"id"` AccountID string `json:"accountID"` ProvisionerID string `json:"provisionerID"` Identifiers []acme.Identifier `json:"identifiers"` AuthorizationIDs []string `json:"authorizationIDs"` Status acme.Status `json:"status"` NotBefore time.Time `json:"notBefore,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"` CreatedAt time.Time `json:"createdAt"` ExpiresAt time.Time `json:"expiresAt,omitempty"` CertificateID string `json:"certificate,omitempty"` Error *acme.Error `json:"error,omitempty"` } func (a *dbOrder) clone() *dbOrder { b := *a return &b } // getDBOrder retrieves and unmarshals an ACME Order type from the database. func (db *DB) getDBOrder(_ context.Context, id string) (*dbOrder, error) { b, err := db.db.Get(orderTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading order %s", id) } o := new(dbOrder) if err := json.Unmarshal(b, &o); err != nil { return nil, errors.Wrapf(err, "error unmarshaling order %s into dbOrder", id) } return o, nil } // GetOrder retrieves an ACME Order from the database. func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { dbo, err := db.getDBOrder(ctx, id) if err != nil { return nil, err } o := &acme.Order{ ID: dbo.ID, AccountID: dbo.AccountID, ProvisionerID: dbo.ProvisionerID, CertificateID: dbo.CertificateID, Status: dbo.Status, ExpiresAt: dbo.ExpiresAt, Identifiers: dbo.Identifiers, NotBefore: dbo.NotBefore, NotAfter: dbo.NotAfter, AuthorizationIDs: dbo.AuthorizationIDs, Error: dbo.Error, } return o, nil } // CreateOrder creates ACME Order resources and saves them to the DB. 
func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { var err error o.ID, err = randID() if err != nil { return err } now := clock.Now() dbo := &dbOrder{ ID: o.ID, AccountID: o.AccountID, ProvisionerID: o.ProvisionerID, Status: o.Status, CreatedAt: now, ExpiresAt: o.ExpiresAt, Identifiers: o.Identifiers, NotBefore: o.NotBefore, NotAfter: o.NotAfter, AuthorizationIDs: o.AuthorizationIDs, } if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil { return err } _, err = db.updateAddOrderIDs(ctx, o.AccountID, false, o.ID) if err != nil { return err } return nil } // UpdateOrder saves an updated ACME Order to the database. func (db *DB) UpdateOrder(ctx context.Context, o *acme.Order) error { old, err := db.getDBOrder(ctx, o.ID) if err != nil { return err } nu := old.clone() nu.Status = o.Status nu.Error = o.Error nu.CertificateID = o.CertificateID return db.save(ctx, old.ID, nu, old, "order", orderTable) } func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, includeReadyOrders bool, addOids ...string) ([]string, error) { ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() var oldOids []string b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) if err != nil { if !nosql.IsErrNotFound(err) { return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) } } else { if err := json.Unmarshal(b, &oldOids); err != nil { return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } } // Remove any order that is not in PENDING state and update the stored list // before returning. // // According to RFC 8555: // The server SHOULD include pending orders and SHOULD NOT include orders // that are invalid in the array of URLs. 
pendOids := []string{} for _, oid := range oldOids { o, err := db.GetOrder(ctx, oid) if err != nil { return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID) } if err = o.UpdateStatus(ctx, db); err != nil { return nil, acme.WrapErrorISE(err, "error updating order %s for account %s", oid, accID) } if o.Status == acme.StatusPending || (o.Status == acme.StatusReady && includeReadyOrders) { pendOids = append(pendOids, oid) } } pendOids = append(pendOids, addOids...) var ( _old interface{} = oldOids _new interface{} = pendOids ) switch { case len(oldOids) == 0 && len(pendOids) == 0: // If list has not changed from empty, then no need to write the DB. return []string{}, nil case len(oldOids) == 0: _old = nil case len(pendOids) == 0: _new = nil } if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { // Delete all orders that may have been previously stored if orderIDsByAccountID update fails. for _, oid := range addOids { // Ignore error from delete -- we tried our best. // TODO when we have logging w/ request ID tracking, logging this error. db.db.Del(orderTable, []byte(oid)) } return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID) } return pendOids, nil } // GetOrdersByAccountID returns a list of order IDs owned by the account. func (db *DB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { return db.updateAddOrderIDs(ctx, accID, false) } // GetAllOrdersByAccountID returns a list of any order IDs owned by the account. 
func (db *DB) GetAllOrdersByAccountID(ctx context.Context, accID string) ([]string, error) {
	// Same refresh as GetOrdersByAccountID, but ready orders are kept too.
	return db.updateAddOrderIDs(ctx, accID, true)
}

================================================
FILE: acme/db/nosql/order_test.go
================================================
package nosql

import (
	"context"
	"encoding/json"
	"reflect"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/acme"
	"github.com/smallstep/certificates/db"
	"github.com/smallstep/nosql"
	"github.com/smallstep/nosql/database"
)

// TestDB_getDBOrder covers getDBOrder: a missing key, a database error, an
// undecodable record, and a successful load.
func TestDB_getDBOrder(t *testing.T) {
	orderID := "orderID"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
		dbo     *dbOrder
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/not-found": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return nil, database.ErrNotFound
					},
				},
				acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
			}
		},
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading order orderID: force"),
			}
		},
		"fail/unmarshal-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						// Not valid JSON.
						return []byte("foo"), nil
					},
				},
				err: errors.New("error unmarshaling order orderID into dbOrder"),
			}
		},
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbo := &dbOrder{
				ID:            orderID,
				AccountID:     "accID",
				ProvisionerID: "provID",
				CertificateID: "certID",
				Status:        acme.StatusValid,
				ExpiresAt:     now,
				CreatedAt:     now,
				NotBefore:     now,
				NotAfter:      now,
				Identifiers: []acme.Identifier{
					{Type: "dns", Value: "test.ca.smallstep.com"},
					{Type: "dns", Value: "example.foo.com"},
				},
				AuthorizationIDs: []string{"foo", "bar"},
				Error:            acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
			}
			b, err := json.Marshal(dbo)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return b, nil
					},
				},
				dbo: dbo,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
				// ACME errors and plain errors are checked against
				// different expectations.
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) {
				assert.Equals(t, dbo.ID, tc.dbo.ID)
				assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
				assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
				assert.Equals(t, dbo.Status, tc.dbo.Status)
				assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
				assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
				assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
				assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
				assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
				assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
				assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
			}
		})
	}
}

// TestDB_GetOrder covers GetOrder: a database error, an ACME error forwarded
// from getDBOrder, and a successful conversion to acme.Order.
func TestDB_GetOrder(t *testing.T) {
	orderID := "orderID"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
		dbo     *dbOrder
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading order orderID: force"),
			}
		},
		"fail/forward-acme-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return nil, database.ErrNotFound
					},
				},
				acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
			}
		},
		"ok": func(t *testing.T) test {
			now := clock.Now()
			dbo := &dbOrder{
				ID:            orderID,
				AccountID:     "accID",
				ProvisionerID: "provID",
				CertificateID: "certID",
				Status:        acme.StatusValid,
				ExpiresAt:     now,
				CreatedAt:     now,
				NotBefore:     now,
				NotAfter:      now,
				Identifiers: []acme.Identifier{
					{Type: "dns", Value: "test.ca.smallstep.com"},
					{Type: "dns", Value: "example.foo.com"},
				},
				AuthorizationIDs: []string{"foo", "bar"},
				Error:            acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
			}
			b, err := json.Marshal(dbo)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return b, nil
					},
				},
				dbo: dbo,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if o, err := d.GetOrder(context.Background(), orderID); err != nil {
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) {
				assert.Equals(t, o.ID, tc.dbo.ID)
				assert.Equals(t, o.AccountID, tc.dbo.AccountID)
				assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
				assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
				assert.Equals(t, o.Status, tc.dbo.Status)
				assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
				assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
				assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
				assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
				assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
				assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
			}
		})
	}
}

// TestDB_UpdateOrder covers UpdateOrder: a load failure, a save failure, and
// a successful update. The CmpAndSwap mocks verify that only Status, Error
// and CertificateID come from the updated order while every other field
// keeps its stored value.
func TestDB_UpdateOrder(t *testing.T) {
	orderID := "orderID"
	now := clock.Now()
	// dbo is the record "stored" before the update; b is its encoding.
	dbo := &dbOrder{
		ID:            orderID,
		AccountID:     "accID",
		ProvisionerID: "provID",
		Status:        acme.StatusPending,
		ExpiresAt:     now,
		CreatedAt:     now,
		NotBefore:     now,
		NotAfter:      now,
		Identifiers: []acme.Identifier{
			{Type: "dns", Value: "test.ca.smallstep.com"},
			{Type: "dns", Value: "example.foo.com"},
		},
		AuthorizationIDs: []string{"foo", "bar"},
	}
	b, err := json.Marshal(dbo)
	assert.FatalError(t, err)
	type test struct {
		db  nosql.DB
		o   *acme.Order
		err error
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				o: &acme.Order{
					ID: orderID,
				},
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading order orderID: force"),
			}
		},
		"fail/save-error": func(t *testing.T) test {
			o := &acme.Order{
				ID:            orderID,
				Status:        acme.StatusValid,
				CertificateID: "certID",
				Error:         acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
			}
			return test{
				o: o,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return b, nil
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, old, b)

						dbNew := new(dbOrder)
						assert.FatalError(t, json.Unmarshal(nu, dbNew))
						assert.Equals(t, dbNew.ID, dbo.ID)
						assert.Equals(t, dbNew.AccountID, dbo.AccountID)
						assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID)
						assert.Equals(t, dbNew.CertificateID, o.CertificateID)
						assert.Equals(t, dbNew.Status, o.Status)
						assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt)
						assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt)
						assert.Equals(t, dbNew.NotBefore, dbo.NotBefore)
						assert.Equals(t, dbNew.NotAfter, dbo.NotAfter)
						assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs)
						assert.Equals(t, dbNew.Identifiers, dbo.Identifiers)
						assert.Equals(t, dbNew.Error.Error(), o.Error.Error())
						return nil, false, errors.New("force")
					},
				},
				err: errors.New("error saving acme order: force"),
			}
		},
		"ok": func(t *testing.T) test {
			o := &acme.Order{
				ID:            orderID,
				Status:        acme.StatusValid,
				CertificateID: "certID",
				Error:         acme.NewError(acme.ErrorMalformedType, "The request message was malformed"),
			}
			return test{
				o: o,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, string(key), orderID)

						return b, nil
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, old, b)

						dbNew := new(dbOrder)
						assert.FatalError(t, json.Unmarshal(nu, dbNew))
						assert.Equals(t, dbNew.ID, dbo.ID)
						assert.Equals(t, dbNew.AccountID, dbo.AccountID)
						assert.Equals(t, dbNew.ProvisionerID, dbo.ProvisionerID)
						assert.Equals(t, dbNew.CertificateID, o.CertificateID)
						assert.Equals(t, dbNew.Status, o.Status)
						assert.Equals(t, dbNew.CreatedAt, dbo.CreatedAt)
						assert.Equals(t, dbNew.ExpiresAt, dbo.ExpiresAt)
						assert.Equals(t, dbNew.NotBefore, dbo.NotBefore)
						assert.Equals(t, dbNew.NotAfter, dbo.NotAfter)
						assert.Equals(t, dbNew.AuthorizationIDs, dbo.AuthorizationIDs)
						assert.Equals(t, dbNew.Identifiers, dbo.Identifiers)
						assert.Equals(t, dbNew.Error.Error(), o.Error.Error())
						return nu, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if err := d.UpdateOrder(context.Background(), tc.o); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, tc.o.ID, dbo.ID)
					assert.Equals(t, tc.o.CertificateID, "certID")
					assert.Equals(t, tc.o.Status, acme.StatusValid)
					assert.Equals(t, tc.o.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
				}
			}
		})
	}
}

// TestDB_CreateOrder covers CreateOrder: a failure saving the order, a
// failure updating the account's order index, and the success path. The
// CmpAndSwap mocks verify that the stored record never carries a certificate
// ID or an error, whatever the input order contains.
func TestDB_CreateOrder(t *testing.T) {
	now := clock.Now()
	nbf := now.Add(5 * time.Minute)
	naf := now.Add(15 * time.Minute)
	type test struct {
		db  nosql.DB
		o   *acme.Order
		err error
		_id *string
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/order-save-error": func(t *testing.T) test {
			o := &acme.Order{
				AccountID:     "accID",
				ProvisionerID: "provID",
				CertificateID: "certID",
				Status:        acme.StatusValid,
				ExpiresAt:     now,
				NotBefore:     nbf,
				NotAfter:      naf,
				Identifiers: []acme.Identifier{
					{Type: "dns", Value: "test.ca.smallstep.com"},
					{Type: "dns", Value: "example.foo.com"},
				},
				AuthorizationIDs: []string{"foo", "bar"},
			}
			return test{
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, string(bucket), string(orderTable))
						assert.Equals(t, string(key), o.ID)
						assert.Equals(t, old, nil)

						dbo := new(dbOrder)
						assert.FatalError(t, json.Unmarshal(nu, dbo))
						assert.Equals(t, dbo.ID, o.ID)
						assert.Equals(t, dbo.AccountID, o.AccountID)
						assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID)
						assert.Equals(t, dbo.CertificateID, "")
						assert.Equals(t, dbo.Status, o.Status)
						assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now))
						assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now))
						assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt)
						assert.Equals(t, dbo.NotBefore, o.NotBefore)
						assert.Equals(t, dbo.NotAfter, o.NotAfter)
						assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs)
						assert.Equals(t, dbo.Identifiers, o.Identifiers)
						assert.Equals(t, dbo.Error, nil)
						return nil, false, errors.New("force")
					},
				},
				o:   o,
				err: errors.New("error saving acme order: force"),
			}
		},
		"fail/orderIDsByOrderUpdate-error": func(t *testing.T) test {
			o := &acme.Order{
				AccountID:     "accID",
				ProvisionerID: "provID",
				CertificateID: "certID",
				Status:        acme.StatusValid,
				ExpiresAt:     now,
				NotBefore:     nbf,
				NotAfter:      naf,
				Identifiers: []acme.Identifier{
					{Type: "dns", Value: "test.ca.smallstep.com"},
					{Type: "dns", Value: "example.foo.com"},
				},
				AuthorizationIDs: []string{"foo", "bar"},
			}
			return test{
				db: &db.MockNoSQLDB{
					// The order itself is saved; loading the account index
					// is what fails in this case.
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
						assert.Equals(t, string(key), o.AccountID)
						return nil, errors.New("force")
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, string(bucket), string(orderTable))
						assert.Equals(t, string(key), o.ID)
						assert.Equals(t, old, nil)

						dbo := new(dbOrder)
						assert.FatalError(t, json.Unmarshal(nu, dbo))
						assert.Equals(t, dbo.ID, o.ID)
						assert.Equals(t, dbo.AccountID, o.AccountID)
						assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID)
						assert.Equals(t, dbo.CertificateID, "")
						assert.Equals(t, dbo.Status, o.Status)
						assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now))
						assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now))
						assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt)
						assert.Equals(t, dbo.NotBefore, o.NotBefore)
						assert.Equals(t, dbo.NotAfter, o.NotAfter)
						assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs)
						assert.Equals(t, dbo.Identifiers, o.Identifiers)
						assert.Equals(t, dbo.Error, nil)
						return nu, true, nil
					},
				},
				o:   o,
				err: errors.New("error loading orderIDs for account accID: force"),
			}
		},
		"ok": func(t *testing.T) test {
			var (
				id    string
				idptr = &id
			)
			o := &acme.Order{
				AccountID:     "accID",
				ProvisionerID: "provID",
				Status:        acme.StatusValid,
				ExpiresAt:     now,
				NotBefore:     nbf,
				NotAfter:      naf,
				Identifiers: []acme.Identifier{
					{Type: "dns", Value: "test.ca.smallstep.com"},
					{Type: "dns", Value: "example.foo.com"},
				},
				AuthorizationIDs: []string{"foo", "bar"},
			}
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
						assert.Equals(t, string(key), o.AccountID)
						return nil, database.ErrNotFound
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							// The index must end up holding just the new order.
							b, err := json.Marshal([]string{o.ID})
							assert.FatalError(t, err)
							assert.Equals(t, string(key), "accID")
							assert.Equals(t, old, nil)
							assert.Equals(t, nu, b)
							return nu, true, nil
						case string(orderTable):
							// Record the key so the generated ID can be
							// compared with the one set on the order.
							*idptr = string(key)
							assert.Equals(t, string(key), o.ID)
							assert.Equals(t, old, nil)

							dbo := new(dbOrder)
							assert.FatalError(t, json.Unmarshal(nu, dbo))
							assert.Equals(t, dbo.ID, o.ID)
							assert.Equals(t, dbo.AccountID, o.AccountID)
							assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID)
							assert.Equals(t, dbo.CertificateID, "")
							assert.Equals(t, dbo.Status, o.Status)
							assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now))
							assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now))
							assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt)
							assert.Equals(t, dbo.NotBefore, o.NotBefore)
							assert.Equals(t, dbo.NotAfter, o.NotAfter)
							assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs)
							assert.Equals(t, dbo.Identifiers, o.Identifiers)
							assert.Equals(t, dbo.Error, nil)
							return nu, true, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, false, errors.New("force")
						}
					},
				},
				o:   o,
				_id: idptr,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			if err := d.CreateOrder(context.Background(), tc.o); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, tc.o.ID, *tc._id)
				}
			}
		})
	}
}

// TestDB_updateAddOrderIDs covers updateAddOrderIDs (always with
// includeReadyOrders=false): index load and decode failures, order load and
// status-update failures, an index save failure with cleanup of the added
// orders, and several success paths mixing existing and newly added IDs.
func TestDB_updateAddOrderIDs(t *testing.T) {
	accID := "accID"
	type test struct {
		db      nosql.DB
		err     error
		acmeErr *acme.Error
		addOids []string
		res     []string
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						return nil, errors.New("force")
					},
				},
				err: errors.Errorf("error loading orderIDs for account %s", accID),
			}
		},
		"fail/unmarshal-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						return []byte("foo"), nil
					},
				},
				err: errors.Errorf("error unmarshaling orderIDs for account %s", accID),
			}
		},
		"fail/db.Get-order-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							assert.Equals(t, key, []byte(accID))
							b, err := json.Marshal([]string{"foo", "bar"})
							assert.FatalError(t, err)
							return b, nil
						case string(orderTable):
							assert.Equals(t, key, []byte("foo"))
							return nil, errors.New("force")
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
				},
				acmeErr: acme.NewErrorISE("error loading order foo for account accID: error loading order foo: force"),
			}
		},
		"fail/update-order-status-error": func(t *testing.T) test {
			// The order is already expired; the status update then tries to
			// persist it as invalid (see the CmpAndSwap assertions), and
			// that write is made to fail.
			expiry := clock.Now().Add(-5 * time.Minute)
			ofoo := &dbOrder{
				ID:        "foo",
				Status:    acme.StatusPending,
				ExpiresAt: expiry,
			}
			bfoo, err := json.Marshal(ofoo)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							assert.Equals(t, key, []byte(accID))
							b, err := json.Marshal([]string{"foo", "bar"})
							assert.FatalError(t, err)
							return b, nil
						case string(orderTable):
							assert.Equals(t, key, []byte("foo"))
							return bfoo, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, key, []byte("foo"))
						assert.Equals(t, old, bfoo)

						newdbo := new(dbOrder)
						assert.FatalError(t, json.Unmarshal(nu, newdbo))
						assert.Equals(t, newdbo.ID, "foo")
						assert.Equals(t, newdbo.Status, acme.StatusInvalid)
						assert.Equals(t, newdbo.ExpiresAt, expiry)
						assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
						return nil, false, errors.New("force")
					},
				},
				acmeErr: acme.NewErrorISE("error updating order foo for account accID: error updating order: error saving acme order: force"),
			}
		},
		"fail/db.save-order-error": func(t *testing.T) test {
			addOids := []string{"foo", "bar"}
			b, err := json.Marshal(addOids)
			assert.FatalError(t, err)
			// Counts Del calls: each added order must be deleted exactly once.
			delCount := 0
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						return nil, database.ErrNotFound
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						assert.Equals(t, old, nil)
						assert.Equals(t, nu, b)
						return nil, false, errors.New("force")
					},
					MDel: func(bucket, key []byte) error {
						delCount++
						switch delCount {
						case 1:
							assert.Equals(t, bucket, orderTable)
							assert.Equals(t, key, []byte("foo"))
							return nil
						case 2:
							assert.Equals(t, bucket, orderTable)
							assert.Equals(t, key, []byte("bar"))
							return nil
						default:
							assert.FatalError(t, errors.New("delete should only be called twice"))
							return errors.New("force")
						}
					},
				},
				addOids: addOids,
				err:     errors.Errorf("error saving orderIDs index for account %s", accID),
			}
		},
		"ok/no-old": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							return nil, database.ErrNotFound
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
					// updateAddOrderIDs returns before writing when the list
					// is empty and stays empty, so this mock is presumably
					// never invoked in this case.
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							assert.Equals(t, key, []byte(accID))
							assert.Equals(t, old, nil)
							assert.Equals(t, nu, nil)
							return nil, true, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, false, errors.New("force")
						}
					},
				},
				res: []string{},
			}
		},
		"ok/all-old-not-pending": func(t *testing.T) test {
			oldOids := []string{"foo", "bar"}
			bOldOids, err := json.Marshal(oldOids)
			assert.FatalError(t, err)
			// Both orders are past their expiry, so neither is kept and the
			// index is expected to be replaced by a nil value.
			expiry := clock.Now().Add(-5 * time.Minute)
			ofoo := &dbOrder{
				ID:        "foo",
				Status:    acme.StatusPending,
				ExpiresAt: expiry,
			}
			bfoo, err := json.Marshal(ofoo)
			assert.FatalError(t, err)
			obar := &dbOrder{
				ID:        "bar",
				Status:    acme.StatusPending,
				ExpiresAt: expiry,
			}
			bbar, err := json.Marshal(obar)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							return bOldOids, nil
						case string(orderTable):
							switch string(key) {
							case "foo":
								assert.Equals(t, key, []byte("foo"))
								return bfoo, nil
							case "bar":
								assert.Equals(t, key, []byte("bar"))
								return bbar, nil
							default:
								assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
								return nil, errors.New("force")
							}
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						switch string(bucket) {
						case string(orderTable):
							return nil, true, nil
						case string(ordersByAccountIDTable):
							assert.Equals(t, key, []byte(accID))
							assert.Equals(t, old, bOldOids)
							assert.Equals(t, nu, nil)
							return nil, true, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, false, errors.New("force")
						}
					},
				},
				res: []string{},
			}
		},
		"ok/old-and-new": func(t *testing.T) test {
			oldOids := []string{"foo", "bar"}
			bOldOids, err := json.Marshal(oldOids)
			assert.FatalError(t, err)
			addOids := []string{"zap", "zar"}
			bAddOids, err := json.Marshal(addOids)
			assert.FatalError(t, err)
			// Both existing orders are expired, so only the added IDs remain.
			expiry := clock.Now().Add(-5 * time.Minute)
			ofoo := &dbOrder{
				ID:        "foo",
				Status:    acme.StatusPending,
				ExpiresAt: expiry,
			}
			bfoo, err := json.Marshal(ofoo)
			assert.FatalError(t, err)
			obar := &dbOrder{
				ID:        "bar",
				Status:    acme.StatusPending,
				ExpiresAt: expiry,
			}
			bbar, err := json.Marshal(obar)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(ordersByAccountIDTable):
							return bOldOids, nil
						case string(orderTable):
							switch string(key) {
							case "foo":
								assert.Equals(t, key, []byte("foo"))
								return bfoo, nil
							case "bar":
								assert.Equals(t, key, []byte("bar"))
								return bbar, nil
							default:
								assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
								return nil, errors.New("force")
							}
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						switch string(bucket) {
						case string(orderTable):
							return nil, true, nil
						case string(ordersByAccountIDTable):
							assert.Equals(t, key, []byte(accID))
							assert.Equals(t, old, bOldOids)
							assert.Equals(t, nu, bAddOids)
							return nil, true, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, false, errors.New("force")
						}
					},
				},
				addOids: addOids,
				res:     addOids,
			}
		},
		"ok/old-and-new-2": func(t *testing.T) test {
			oldOids := []string{"foo", "bar", "baz"}
			bOldOids, err := json.Marshal(oldOids)
			assert.FatalError(t, err)
			addOids := []string{"zap", "zar"}
			now := clock.Now()
			min5 := now.Add(5 * time.Minute)
			expiry := now.Add(-5 * time.Minute)

			// "foo" and "baz" are unexpired with pending authorizations and
			// challenges; "bar" is expired. The expected index therefore
			// keeps foo and baz and appends the added IDs.
			o1 := &dbOrder{
				ID:               "foo",
				Status:           acme.StatusPending,
				ExpiresAt:        min5,
				AuthorizationIDs: []string{"a"},
			}
			bo1, err := json.Marshal(o1)
			assert.FatalError(t, err)
			o2 := &dbOrder{
				ID:        "bar",
				Status:    acme.StatusPending,
				ExpiresAt: expiry,
			}
			bo2, err := json.Marshal(o2)
			assert.FatalError(t, err)
			o3 := &dbOrder{
				ID:               "baz",
				Status:           acme.StatusPending,
				ExpiresAt:        min5,
				AuthorizationIDs: []string{"b"},
			}
			bo3, err := json.Marshal(o3)
			assert.FatalError(t, err)

			az1 := &dbAuthz{
				ID:           "a",
				Status:       acme.StatusPending,
				ExpiresAt:    min5,
				ChallengeIDs: []string{"aa"},
			}
			baz1, err := json.Marshal(az1)
			assert.FatalError(t, err)
			az2 := &dbAuthz{
				ID:           "b",
				Status:       acme.StatusPending,
				ExpiresAt:    min5,
				ChallengeIDs: []string{"bb"},
			}
			baz2, err := json.Marshal(az2)
			assert.FatalError(t, err)

			ch1 := &dbChallenge{
				ID:     "aa",
				Status: acme.StatusPending,
			}
			bch1, err := json.Marshal(ch1)
			assert.FatalError(t, err)
			ch2 := &dbChallenge{
				ID:     "bb",
				Status: acme.StatusPending,
			}
			bch2, err := json.Marshal(ch2)
			assert.FatalError(t, err)

			newOids := append([]string{"foo", "baz"}, addOids...)
			bNewOids, err := json.Marshal(newOids)
			assert.FatalError(t, err)

			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						switch string(bucket) {
						case string(authzTable):
							switch string(key) {
							case "a":
								return baz1, nil
							case "b":
								return baz2, nil
							default:
								assert.FatalError(t, errors.Errorf("unexpected authz key %s", string(key)))
								return nil, errors.New("force")
							}
						case string(challengeTable):
							switch string(key) {
							case "aa":
								return bch1, nil
							case "bb":
								return bch2, nil
							default:
								assert.FatalError(t, errors.Errorf("unexpected challenge key %s", string(key)))
								return nil, errors.New("force")
							}
						case string(ordersByAccountIDTable):
							return bOldOids, nil
						case string(orderTable):
							switch string(key) {
							case "foo":
								return bo1, nil
							case "bar":
								return bo2, nil
							case "baz":
								return bo3, nil
							default:
								assert.FatalError(t, errors.Errorf("unexpected key %s", string(key)))
								return nil, errors.New("force")
							}
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, errors.New("force")
						}
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						switch string(bucket) {
						case string(orderTable):
							return nil, true, nil
						case string(ordersByAccountIDTable):
							assert.Equals(t, key, []byte(accID))
							assert.Equals(t, old, bOldOids)
							assert.Equals(t, nu, bNewOids)
							return nil, true, nil
						default:
							assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
							return nil, false, errors.New("force")
						}
					},
				},
				addOids: addOids,
				res:     newOids,
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db}
			var (
				res []string
				err error
			)
			if tc.addOids == nil {
				res, err = d.updateAddOrderIDs(context.Background(), accID, false)
			} else {
				res, err = d.updateAddOrderIDs(context.Background(), accID, false, tc.addOids...)
			}

			if err != nil {
				var ae *acme.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.acmeErr) {
						assert.Equals(t, ae.Type, tc.acmeErr.Type)
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
						assert.Equals(t, ae.Status, tc.acmeErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.acmeErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.acmeErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) {
				assert.True(t, reflect.DeepEqual(res, tc.res))
			}
		})
	}
}

================================================
FILE: acme/db/nosql/wire.go
================================================
package nosql

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/smallstep/certificates/acme"
	"github.com/smallstep/nosql"
)

// dbDpopToken is the JSON representation of a stored Wire DPoP token.
type dbDpopToken struct {
	ID        string    `json:"id"`
	Content   []byte    `json:"content"`
	CreatedAt time.Time `json:"createdAt"`
}

// getDBDpopToken retrieves and unmarshals a DPoP type from the database.
func (db *DB) getDBDpopToken(_ context.Context, orderID string) (*dbDpopToken, error) { b, err := db.db.Get(wireDpopTokenTable, []byte(orderID)) if err != nil { if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "dpop token %q not found", orderID) } return nil, fmt.Errorf("failed loading dpop token %q: %w", orderID, err) } d := new(dbDpopToken) if err := json.Unmarshal(b, d); err != nil { return nil, fmt.Errorf("failed unmarshaling dpop token %q into dbDpopToken: %w", orderID, err) } return d, nil } // GetDpopToken retrieves an DPoP from the database. func (db *DB) GetDpopToken(ctx context.Context, orderID string) (map[string]any, error) { dbDpop, err := db.getDBDpopToken(ctx, orderID) if err != nil { return nil, err } dpop := make(map[string]any) err = json.Unmarshal(dbDpop.Content, &dpop) return dpop, err } // CreateDpopToken creates DPoP resources and saves them to the DB. func (db *DB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error { content, err := json.Marshal(dpop) if err != nil { return fmt.Errorf("failed marshaling dpop token: %w", err) } now := clock.Now() dbDpop := &dbDpopToken{ ID: orderID, Content: content, CreatedAt: now, } if err := db.save(ctx, orderID, dbDpop, nil, "dpop", wireDpopTokenTable); err != nil { return fmt.Errorf("failed saving dpop token: %w", err) } return nil } type dbOidcToken struct { ID string `json:"id"` Content []byte `json:"content"` CreatedAt time.Time `json:"createdAt"` } // getDBOidcToken retrieves and unmarshals an OIDC id token type from the database. 
func (db *DB) getDBOidcToken(_ context.Context, orderID string) (*dbOidcToken, error) { b, err := db.db.Get(wireOidcTokenTable, []byte(orderID)) if err != nil { if nosql.IsErrNotFound(err) { return nil, acme.NewError(acme.ErrorMalformedType, "oidc token %q not found", orderID) } return nil, fmt.Errorf("failed loading oidc token %q: %w", orderID, err) } o := new(dbOidcToken) if err := json.Unmarshal(b, o); err != nil { return nil, fmt.Errorf("failed unmarshaling oidc token %q into dbOidcToken: %w", orderID, err) } return o, nil } // GetOidcToken retrieves an oidc token from the database. func (db *DB) GetOidcToken(ctx context.Context, orderID string) (map[string]any, error) { dbOidc, err := db.getDBOidcToken(ctx, orderID) if err != nil { return nil, err } idToken := make(map[string]any) err = json.Unmarshal(dbOidc.Content, &idToken) return idToken, err } // CreateOidcToken creates oidc token resources and saves them to the DB. func (db *DB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error { content, err := json.Marshal(idToken) if err != nil { return fmt.Errorf("failed marshaling oidc token: %w", err) } now := clock.Now() dbOidc := &dbOidcToken{ ID: orderID, Content: content, CreatedAt: now, } if err := db.save(ctx, orderID, dbOidc, nil, "oidc", wireOidcTokenTable); err != nil { return fmt.Errorf("failed saving oidc token: %w", err) } return nil } ================================================ FILE: acme/db/nosql/wire_test.go ================================================ package nosql import ( "context" "encoding/json" "errors" "testing" "time" "github.com/smallstep/certificates/acme" certificatesdb "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDB_GetDpopToken(t *testing.T) { type test struct { db *DB orderID string expected map[string]any expectedErr error } var tests = map[string]func(t *testing.T) test{ 
"fail/acme-not-found": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", expectedErr: &acme.Error{ Type: "urn:ietf:params:acme:error:malformed", Status: 400, Detail: "The request message was malformed", Err: errors.New(`dpop token "orderID" not found`), }, } }, "fail/unmarshal-error": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) token := dbDpopToken{ ID: "orderID", Content: []byte("{}"), CreatedAt: time.Now(), } b, err := json.Marshal(token) require.NoError(t, err) err = db.Set(wireDpopTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", expectedErr: errors.New(`failed unmarshaling dpop token "orderID" into dbDpopToken: invalid character ':' after top-level value`), } }, "fail/db.Get": func(t *testing.T) test { db := &certificatesdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equal(t, wireDpopTokenTable, bucket) assert.Equal(t, []byte("orderID"), key) return nil, errors.New("fail") }, } return test{ db: &DB{ db: db, }, orderID: "orderID", expectedErr: errors.New(`failed loading dpop token "orderID": fail`), } }, "ok": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) token := dbDpopToken{ ID: "orderID", Content: []byte(`{"sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com"}`), CreatedAt: time.Now(), } b, err := json.Marshal(token) require.NoError(t, err) err = db.Set(wireDpopTokenTable, []byte("orderID"), b) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", expected: map[string]any{ "sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { got, err := 
tc.db.GetDpopToken(context.Background(), tc.orderID) if tc.expectedErr != nil { assert.EqualError(t, err, tc.expectedErr.Error()) ae := &acme.Error{} if errors.As(err, &ae) { ee := &acme.Error{} require.True(t, errors.As(tc.expectedErr, &ee)) assert.Equal(t, ee.Detail, ae.Detail) assert.Equal(t, ee.Type, ae.Type) assert.Equal(t, ee.Status, ae.Status) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tc.expected, got) }) } } func TestDB_CreateDpopToken(t *testing.T) { type test struct { db *DB orderID string dpop map[string]any expectedErr error } var tests = map[string]func(t *testing.T) test{ "fail/db.Save": func(t *testing.T) test { db := &certificatesdb.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equal(t, wireDpopTokenTable, bucket) assert.Equal(t, []byte("orderID"), key) return nil, false, errors.New("fail") }, } return test{ db: &DB{ db: db, }, orderID: "orderID", dpop: map[string]any{ "sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", }, expectedErr: errors.New("failed saving dpop token: error saving acme dpop: fail"), } }, "ok": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", dpop: map[string]any{ "sub": "wireapp://guVX5xeFS3eTatmXBIyA4A!7a41cf5b79683410@wire.com", }, } }, "ok/nil": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", dpop: nil, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { err := tc.db.CreateDpopToken(context.Background(), tc.orderID, tc.dpop) if tc.expectedErr != nil { assert.EqualError(t, err, tc.expectedErr.Error()) return } assert.NoError(t, err) dpop, err := tc.db.getDBDpopToken(context.Background(), tc.orderID) require.NoError(t, err) assert.Equal(t, tc.orderID, dpop.ID) var m map[string]any err 
= json.Unmarshal(dpop.Content, &m) require.NoError(t, err) assert.Equal(t, tc.dpop, m) }) } } func TestDB_GetOidcToken(t *testing.T) { type test struct { db *DB orderID string expected map[string]any expectedErr error } var tests = map[string]func(t *testing.T) test{ "fail/acme-not-found": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", expectedErr: &acme.Error{ Type: "urn:ietf:params:acme:error:malformed", Status: 400, Detail: "The request message was malformed", Err: errors.New(`oidc token "orderID" not found`), }, } }, "fail/unmarshal-error": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) token := dbOidcToken{ ID: "orderID", Content: []byte("{}"), CreatedAt: time.Now(), } b, err := json.Marshal(token) require.NoError(t, err) err = db.Set(wireOidcTokenTable, []byte("orderID"), b[1:]) // start at index 1; corrupt JSON data require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", expectedErr: errors.New(`failed unmarshaling oidc token "orderID" into dbOidcToken: invalid character ':' after top-level value`), } }, "fail/db.Get": func(t *testing.T) test { db := &certificatesdb.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equal(t, wireOidcTokenTable, bucket) assert.Equal(t, []byte("orderID"), key) return nil, errors.New("fail") }, } return test{ db: &DB{ db: db, }, orderID: "orderID", expectedErr: errors.New(`failed loading oidc token "orderID": fail`), } }, "ok": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) token := dbOidcToken{ ID: "orderID", Content: []byte(`{"name": "Alice Smith", "preferred_username": "@alice.smith"}`), CreatedAt: time.Now(), } b, err := json.Marshal(token) require.NoError(t, err) err = db.Set(wireOidcTokenTable, []byte("orderID"), b) require.NoError(t, err) return test{ 
db: &DB{ db: db, }, orderID: "orderID", expected: map[string]any{ "name": "Alice Smith", "preferred_username": "@alice.smith", }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { got, err := tc.db.GetOidcToken(context.Background(), tc.orderID) if tc.expectedErr != nil { assert.EqualError(t, err, tc.expectedErr.Error()) ae := &acme.Error{} if errors.As(err, &ae) { ee := &acme.Error{} require.True(t, errors.As(tc.expectedErr, &ee)) assert.Equal(t, ee.Detail, ae.Detail) assert.Equal(t, ee.Type, ae.Type) assert.Equal(t, ee.Status, ae.Status) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tc.expected, got) }) } } func TestDB_CreateOidcToken(t *testing.T) { type test struct { db *DB orderID string oidc map[string]any expectedErr error } var tests = map[string]func(t *testing.T) test{ "fail/db.Save": func(t *testing.T) test { db := &certificatesdb.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { assert.Equal(t, wireOidcTokenTable, bucket) assert.Equal(t, []byte("orderID"), key) return nil, false, errors.New("fail") }, } return test{ db: &DB{ db: db, }, orderID: "orderID", oidc: map[string]any{ "name": "Alice Smith", "preferred_username": "@alice.smith", }, expectedErr: errors.New("failed saving oidc token: error saving acme oidc: fail"), } }, "ok": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", oidc: map[string]any{ "name": "Alice Smith", "preferred_username": "@alice.smith", }, } }, "ok/nil": func(t *testing.T) test { dir := t.TempDir() db, err := nosql.New("badgerv2", dir) require.NoError(t, err) return test{ db: &DB{ db: db, }, orderID: "orderID", oidc: nil, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { err := tc.db.CreateOidcToken(context.Background(), tc.orderID, tc.oidc) if tc.expectedErr != nil { assert.EqualError(t, err, 
tc.expectedErr.Error()) return } assert.NoError(t, err) oidc, err := tc.db.getDBOidcToken(context.Background(), tc.orderID) require.NoError(t, err) assert.Equal(t, tc.orderID, oidc.ID) var m map[string]any err = json.Unmarshal(oidc.Content, &m) require.NoError(t, err) assert.Equal(t, tc.oidc, m) }) } } ================================================ FILE: acme/db.go ================================================ package acme import ( "context" "database/sql" "github.com/pkg/errors" ) // ErrNotFound is an error that should be used by the acme.DB interface to // indicate that an entity does not exist. For example, in the new-account // endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new // account. var ErrNotFound = errors.New("not found") // IsErrNotFound returns true if the error is a "not found" error. Returns false // otherwise. func IsErrNotFound(err error) bool { return errors.Is(err, ErrNotFound) || errors.Is(err, sql.ErrNoRows) } // DB is the DB interface expected by the step-ca ACME API. 
type DB interface { CreateAccount(ctx context.Context, acc *Account) error GetAccount(ctx context.Context, id string) (*Account, error) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) UpdateAccount(ctx context.Context, acc *Account) error CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error CreateNonce(ctx context.Context) (Nonce, error) DeleteNonce(ctx context.Context, nonce Nonce) error CreateAuthorization(ctx context.Context, az *Authorization) error GetAuthorization(ctx context.Context, id string) (*Authorization, error) UpdateAuthorization(ctx context.Context, az *Authorization) error GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) CreateCertificate(ctx context.Context, cert *Certificate) error GetCertificate(ctx context.Context, id string) (*Certificate, error) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) CreateChallenge(ctx context.Context, ch *Challenge) error GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error) UpdateChallenge(ctx context.Context, ch *Challenge) error CreateOrder(ctx context.Context, o *Order) error GetOrder(ctx context.Context, id string) (*Order, error) GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) 
UpdateOrder(ctx context.Context, o *Order) error } // WireDB is the interface used for operations on ACME Orders for Wire identifiers. This // is not a general purpose interface, and it should only be used when Wire identifiers // are enabled in the CA configuration. Currently it provides a runtime assertion only; // not at compile time. type WireDB interface { DB GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]interface{}) error GetDpopToken(ctx context.Context, orderID string) (map[string]interface{}, error) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]interface{}) error GetOidcToken(ctx context.Context, orderID string) (map[string]interface{}, error) } type dbKey struct{} // NewDatabaseContext adds the given acme database to the context. func NewDatabaseContext(ctx context.Context, db DB) context.Context { return context.WithValue(ctx, dbKey{}, db) } // DatabaseFromContext returns the current acme database from the given context. func DatabaseFromContext(ctx context.Context) (db DB, ok bool) { db, ok = ctx.Value(dbKey{}).(DB) return } // MustDatabaseFromContext returns the current database from the given context. // It will panic if it's not in the context. func MustDatabaseFromContext(ctx context.Context) DB { var ( db DB ok bool ) if db, ok = DatabaseFromContext(ctx); !ok { panic("acme database is not in the context") } return db } // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. 
type MockDB struct { MockCreateAccount func(ctx context.Context, acc *Account) error MockGetAccount func(ctx context.Context, id string) (*Account, error) MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error) MockUpdateAccount func(ctx context.Context, acc *Account) error MockCreateExternalAccountKey func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error MockCreateNonce func(ctx context.Context) (Nonce, error) MockDeleteNonce func(ctx context.Context, nonce Nonce) error MockCreateAuthorization func(ctx context.Context, az *Authorization) error MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) MockUpdateAuthorization func(ctx context.Context, az *Authorization) error MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error) MockCreateCertificate func(ctx context.Context, cert *Certificate) error MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) MockGetCertificateBySerial func(ctx context.Context, serial string) (*Certificate, error) MockCreateChallenge func(ctx context.Context, ch *Challenge) error MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error) MockUpdateChallenge func(ctx context.Context, ch 
*Challenge) error MockCreateOrder func(ctx context.Context, o *Order) error MockGetOrder func(ctx context.Context, id string) (*Order, error) MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) MockUpdateOrder func(ctx context.Context, o *Order) error MockRet1 interface{} MockError error } // MockWireDB is an implementation of the WireDB interface that should only be used as // a mock in tests. It embeds the MockDB, as it is an extension of the existing database // methods. type MockWireDB struct { MockDB MockGetAllOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) MockGetDpopToken func(ctx context.Context, orderID string) (map[string]interface{}, error) MockCreateDpopToken func(ctx context.Context, orderID string, dpop map[string]interface{}) error MockGetOidcToken func(ctx context.Context, orderID string) (map[string]interface{}, error) MockCreateOidcToken func(ctx context.Context, orderID string, idToken map[string]interface{}) error } // CreateAccount mock. func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error { if m.MockCreateAccount != nil { return m.MockCreateAccount(ctx, acc) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetAccount mock. 
func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) { if m.MockGetAccount != nil { return m.MockGetAccount(ctx, id) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Account), m.MockError } // GetAccountByKeyID mock func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { if m.MockGetAccountByKeyID != nil { return m.MockGetAccountByKeyID(ctx, kid) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Account), m.MockError } // UpdateAccount mock func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error { if m.MockUpdateAccount != nil { return m.MockUpdateAccount(ctx, acc) } else if m.MockError != nil { return m.MockError } return m.MockError } // CreateExternalAccountKey mock func (m *MockDB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) { if m.MockCreateExternalAccountKey != nil { return m.MockCreateExternalAccountKey(ctx, provisionerID, reference) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*ExternalAccountKey), m.MockError } // GetExternalAccountKey mock func (m *MockDB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) { if m.MockGetExternalAccountKey != nil { return m.MockGetExternalAccountKey(ctx, provisionerID, keyID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*ExternalAccountKey), m.MockError } // GetExternalAccountKeys mock func (m *MockDB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) { if m.MockGetExternalAccountKeys != nil { return m.MockGetExternalAccountKeys(ctx, provisionerID, cursor, limit) } else if m.MockError != nil { return nil, "", m.MockError } return m.MockRet1.([]*ExternalAccountKey), "", m.MockError } // GetExternalAccountKeyByReference mock func (m *MockDB) 
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) { if m.MockGetExternalAccountKeyByReference != nil { return m.MockGetExternalAccountKeyByReference(ctx, provisionerID, reference) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*ExternalAccountKey), m.MockError } // GetExternalAccountKeyByAccountID mock func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) { if m.MockGetExternalAccountKeyByAccountID != nil { return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*ExternalAccountKey), m.MockError } // DeleteExternalAccountKey mock func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { if m.MockDeleteExternalAccountKey != nil { return m.MockDeleteExternalAccountKey(ctx, provisionerID, keyID) } else if m.MockError != nil { return m.MockError } return m.MockError } // UpdateExternalAccountKey mock func (m *MockDB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error { if m.MockUpdateExternalAccountKey != nil { return m.MockUpdateExternalAccountKey(ctx, provisionerID, eak) } else if m.MockError != nil { return m.MockError } return m.MockError } // CreateNonce mock func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) { if m.MockCreateNonce != nil { return m.MockCreateNonce(ctx) } else if m.MockError != nil { return Nonce(""), m.MockError } return m.MockRet1.(Nonce), m.MockError } // DeleteNonce mock func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error { if m.MockDeleteNonce != nil { return m.MockDeleteNonce(ctx, nonce) } else if m.MockError != nil { return m.MockError } return m.MockError } // CreateAuthorization mock func (m *MockDB) CreateAuthorization(ctx context.Context, az 
*Authorization) error { if m.MockCreateAuthorization != nil { return m.MockCreateAuthorization(ctx, az) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetAuthorization mock func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) { if m.MockGetAuthorization != nil { return m.MockGetAuthorization(ctx, id) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Authorization), m.MockError } // UpdateAuthorization mock func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error { if m.MockUpdateAuthorization != nil { return m.MockUpdateAuthorization(ctx, az) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetAuthorizationsByAccountID mock func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) { if m.MockGetAuthorizationsByAccountID != nil { return m.MockGetAuthorizationsByAccountID(ctx, accountID) } else if m.MockError != nil { return nil, m.MockError } return nil, m.MockError } // CreateCertificate mock func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { if m.MockCreateCertificate != nil { return m.MockCreateCertificate(ctx, cert) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetCertificate mock func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { if m.MockGetCertificate != nil { return m.MockGetCertificate(ctx, id) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Certificate), m.MockError } // GetCertificateBySerial mock func (m *MockDB) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) { if m.MockGetCertificateBySerial != nil { return m.MockGetCertificateBySerial(ctx, serial) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Certificate), m.MockError } // CreateChallenge mock func (m *MockDB) 
CreateChallenge(ctx context.Context, ch *Challenge) error { if m.MockCreateChallenge != nil { return m.MockCreateChallenge(ctx, ch) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetChallenge mock func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) { if m.MockGetChallenge != nil { return m.MockGetChallenge(ctx, chID, azID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Challenge), m.MockError } // UpdateChallenge mock func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error { if m.MockUpdateChallenge != nil { return m.MockUpdateChallenge(ctx, ch) } else if m.MockError != nil { return m.MockError } return m.MockError } // CreateOrder mock func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error { if m.MockCreateOrder != nil { return m.MockCreateOrder(ctx, o) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetOrder mock func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) { if m.MockGetOrder != nil { return m.MockGetOrder(ctx, id) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*Order), m.MockError } // UpdateOrder mock func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error { if m.MockUpdateOrder != nil { return m.MockUpdateOrder(ctx, o) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetOrdersByAccountID mock func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { if m.MockGetOrdersByAccountID != nil { return m.MockGetOrdersByAccountID(ctx, accID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.([]string), m.MockError } // GetAllOrdersByAccountID returns a list of any order IDs owned by the account. 
func (m *MockWireDB) GetAllOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) { if m.MockGetAllOrdersByAccountID != nil { return m.MockGetAllOrdersByAccountID(ctx, accountID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.([]string), m.MockError } // GetDpop retrieves a DPoP from the database. func (m *MockWireDB) GetDpopToken(ctx context.Context, orderID string) (map[string]any, error) { if m.MockGetDpopToken != nil { return m.MockGetDpopToken(ctx, orderID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(map[string]any), m.MockError } // CreateDpop creates DPoP resources and saves them to the DB. func (m *MockWireDB) CreateDpopToken(ctx context.Context, orderID string, dpop map[string]any) error { if m.MockCreateDpopToken != nil { return m.MockCreateDpopToken(ctx, orderID, dpop) } return m.MockError } // GetOidcToken retrieves an oidc token from the database. func (m *MockWireDB) GetOidcToken(ctx context.Context, orderID string) (map[string]any, error) { if m.MockGetOidcToken != nil { return m.MockGetOidcToken(ctx, orderID) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(map[string]any), m.MockError } // CreateOidcToken creates oidc token resources and saves them to the DB. 
func (m *MockWireDB) CreateOidcToken(ctx context.Context, orderID string, idToken map[string]any) error {
	// The injected hook, when present, fully decides the outcome.
	if m.MockCreateOidcToken != nil {
		return m.MockCreateOidcToken(ctx, orderID, idToken)
	}
	// Without a hook the call reports whatever MockError holds (possibly nil).
	return m.MockError
}

================================================
FILE: acme/db_test.go
================================================
package acme

import (
	"database/sql"
	"errors"
	"fmt"
	"testing"
)

// TestIsErrNotFound checks that IsErrNotFound reports true for ErrNotFound
// and sql.ErrNoRows, both bare and wrapped with %w, and false for an
// unrelated error that merely has a similar message.
func TestIsErrNotFound(t *testing.T) {
	type args struct {
		err error
	}
	tests := []struct {
		name string
		args args
		want bool
	}{
		{"true ErrNotFound", args{ErrNotFound}, true},
		{"true sql.ErrNoRows", args{sql.ErrNoRows}, true},
		{"true wrapped ErrNotFound", args{fmt.Errorf("something failed: %w", ErrNotFound)}, true},
		{"true wrapped sql.ErrNoRows", args{fmt.Errorf("something failed: %w", sql.ErrNoRows)}, true},
		{"false other", args{errors.New("not found")}, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := IsErrNotFound(tt.args.err); got != tt.want {
				t.Errorf("IsErrNotFound() = %v, want %v", got, tt.want)
			}
		})
	}
}

================================================
FILE: acme/errors.go
================================================
package acme

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/pkg/errors"

	"github.com/smallstep/certificates/api/render"
)

// ProblemType is the type of the ACME problem.
type ProblemType int

// The ProblemType values below are assigned consecutively with iota, so the
// order of this list determines each constant's numeric value.
const (
	// ErrorAccountDoesNotExistType indicates that the request specified an
	// account that does not exist.
	ErrorAccountDoesNotExistType ProblemType = iota
	// ErrorAlreadyRevokedType indicates that the request specified a
	// certificate to be revoked that has already been revoked.
	ErrorAlreadyRevokedType
	// ErrorBadAttestationStatementType indicates that the WebAuthn
	// attestation statement could not be verified.
	ErrorBadAttestationStatementType
	// ErrorBadCSRType indicates that the CSR is unacceptable (e.g., due to a
	// short key).
	ErrorBadCSRType
	// ErrorBadNonceType indicates that the client sent an unacceptable
	// anti-replay nonce.
	ErrorBadNonceType
	// ErrorBadPublicKeyType indicates that the JWS was signed by a public key
	// the server does not support.
	ErrorBadPublicKeyType
	// ErrorBadRevocationReasonType indicates that the revocation reason
	// provided is not allowed by the server.
	ErrorBadRevocationReasonType
	// ErrorBadSignatureAlgorithmType indicates that the JWS was signed with an
	// algorithm the server does not support.
	ErrorBadSignatureAlgorithmType
	// ErrorCaaType indicates that Certification Authority Authorization (CAA)
	// records forbid the CA from issuing a certificate.
	ErrorCaaType
	// ErrorCompoundType indicates that the specific error conditions are
	// listed in the "subproblems" array.
	ErrorCompoundType
	// ErrorConnectionType indicates that the server could not connect to the
	// validation target.
	ErrorConnectionType
	// ErrorDNSType indicates that there was a problem with a DNS query during
	// identifier validation.
	ErrorDNSType
	// ErrorExternalAccountRequiredType indicates that the request must
	// include a value for the "externalAccountBinding" field.
	ErrorExternalAccountRequiredType
	// ErrorIncorrectResponseType indicates that the response received didn't
	// match the challenge's requirements.
	ErrorIncorrectResponseType
	// ErrorInvalidContactType indicates that a contact URL for an account was
	// invalid.
	ErrorInvalidContactType
	// ErrorMalformedType indicates that the request message was malformed.
	ErrorMalformedType
	// ErrorOrderNotReadyType indicates that the request attempted to finalize
	// an order that is not ready to be finalized.
	ErrorOrderNotReadyType
	// ErrorRateLimitedType indicates that the request exceeds a rate limit.
	ErrorRateLimitedType
	// ErrorRejectedIdentifierType indicates that the server will not issue
	// certificates for the identifier.
	ErrorRejectedIdentifierType
	// ErrorServerInternalType indicates that the server experienced an
	// internal error.
	ErrorServerInternalType
	// ErrorTLSType indicates that the server received a TLS error during
	// validation.
	ErrorTLSType
	// ErrorUnauthorizedType indicates that the client lacks sufficient
	// authorization.
	ErrorUnauthorizedType
	// ErrorUnsupportedContactType indicates that a contact URL for an account
	// used an unsupported protocol scheme.
	ErrorUnsupportedContactType
	// ErrorUnsupportedIdentifierType indicates that an identifier is of an
	// unsupported type.
	ErrorUnsupportedIdentifierType
	// ErrorUserActionRequiredType indicates that the client should visit the
	// "instance" URL and take the actions specified there.
	ErrorUserActionRequiredType
	// ErrorNotImplementedType indicates that the requested operation is not
	// implemented.
	ErrorNotImplementedType
)

// String returns the string representation of the acme problem type,
// fulfilling the Stringer interface.
func (ap ProblemType) String() string { switch ap { case ErrorAccountDoesNotExistType: return "accountDoesNotExist" case ErrorAlreadyRevokedType: return "alreadyRevoked" case ErrorBadAttestationStatementType: return "badAttestationStatement" case ErrorBadCSRType: return "badCSR" case ErrorBadNonceType: return "badNonce" case ErrorBadPublicKeyType: return "badPublicKey" case ErrorBadRevocationReasonType: return "badRevocationReason" case ErrorBadSignatureAlgorithmType: return "badSignatureAlgorithm" case ErrorCaaType: return "caa" case ErrorCompoundType: return "compound" case ErrorConnectionType: return "connection" case ErrorDNSType: return "dns" case ErrorExternalAccountRequiredType: return "externalAccountRequired" case ErrorInvalidContactType: return "incorrectResponse" case ErrorMalformedType: return "malformed" case ErrorOrderNotReadyType: return "orderNotReady" case ErrorRateLimitedType: return "rateLimited" case ErrorRejectedIdentifierType: return "rejectedIdentifier" case ErrorServerInternalType: return "serverInternal" case ErrorTLSType: return "tls" case ErrorUnauthorizedType: return "unauthorized" case ErrorUnsupportedContactType: return "unsupportedContact" case ErrorUnsupportedIdentifierType: return "unsupportedIdentifier" case ErrorUserActionRequiredType: return "userActionRequired" case ErrorNotImplementedType: return "notImplemented" default: return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap)) } } type errorMetadata struct { details string status int typ string String string } var ( officialACMEPrefix = "urn:ietf:params:acme:error:" errorServerInternalMetadata = errorMetadata{ typ: officialACMEPrefix + ErrorServerInternalType.String(), details: "The server experienced an internal error", status: 500, } errorMap = map[ProblemType]errorMetadata{ ErrorAccountDoesNotExistType: { typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(), details: "Account does not exist", status: 400, }, ErrorAlreadyRevokedType: { typ: 
officialACMEPrefix + ErrorAlreadyRevokedType.String(), details: "Certificate already revoked", status: 400, }, ErrorBadCSRType: { typ: officialACMEPrefix + ErrorBadCSRType.String(), details: "The CSR is unacceptable", status: 400, }, ErrorBadNonceType: { typ: officialACMEPrefix + ErrorBadNonceType.String(), details: "Unacceptable anti-replay nonce", status: 400, }, ErrorBadPublicKeyType: { typ: officialACMEPrefix + ErrorBadPublicKeyType.String(), details: "The jws was signed by a public key the server does not support", status: 400, }, ErrorBadRevocationReasonType: { typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(), details: "The revocation reason provided is not allowed by the server", status: 400, }, ErrorBadSignatureAlgorithmType: { typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(), details: "The JWS was signed with an algorithm the server does not support", status: 400, }, ErrorBadAttestationStatementType: { typ: officialACMEPrefix + ErrorBadAttestationStatementType.String(), details: "Attestation statement cannot be verified", status: 400, }, ErrorCaaType: { typ: officialACMEPrefix + ErrorCaaType.String(), details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate", status: 400, }, ErrorCompoundType: { typ: officialACMEPrefix + ErrorCompoundType.String(), details: "Specific error conditions are indicated in the “subproblems” array", status: 400, }, ErrorConnectionType: { typ: officialACMEPrefix + ErrorConnectionType.String(), details: "The server could not connect to validation target", status: 400, }, ErrorDNSType: { typ: officialACMEPrefix + ErrorDNSType.String(), details: "There was a problem with a DNS query during identifier validation", status: 400, }, ErrorExternalAccountRequiredType: { typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(), details: "The request must include a value for the \"externalAccountBinding\" field", status: 400, }, 
ErrorIncorrectResponseType: { typ: officialACMEPrefix + ErrorIncorrectResponseType.String(), details: "Response received didn't match the challenge's requirements", status: 400, }, ErrorInvalidContactType: { typ: officialACMEPrefix + ErrorInvalidContactType.String(), details: "A contact URL for an account was invalid", status: 400, }, ErrorMalformedType: { typ: officialACMEPrefix + ErrorMalformedType.String(), details: "The request message was malformed", status: 400, }, ErrorOrderNotReadyType: { typ: officialACMEPrefix + ErrorOrderNotReadyType.String(), details: "The request attempted to finalize an order that is not ready to be finalized", status: 400, }, ErrorRateLimitedType: { typ: officialACMEPrefix + ErrorRateLimitedType.String(), details: "The request exceeds a rate limit", status: 400, }, ErrorRejectedIdentifierType: { typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(), details: "The server will not issue certificates for the identifier", status: 400, }, ErrorNotImplementedType: { typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(), details: "The requested operation is not implemented", status: 501, }, ErrorTLSType: { typ: officialACMEPrefix + ErrorTLSType.String(), details: "The server received a TLS error during validation", status: 400, }, ErrorUnauthorizedType: { typ: officialACMEPrefix + ErrorUnauthorizedType.String(), details: "The client lacks sufficient authorization", status: 401, }, ErrorUnsupportedContactType: { typ: officialACMEPrefix + ErrorUnsupportedContactType.String(), details: "A contact URL for an account used an unsupported protocol scheme", status: 400, }, ErrorUnsupportedIdentifierType: { typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(), details: "An identifier is of an unsupported type", status: 400, }, ErrorUserActionRequiredType: { typ: officialACMEPrefix + ErrorUserActionRequiredType.String(), details: "Visit the “instance” URL and take actions specified there", status: 400, }, 
ErrorServerInternalType: errorServerInternalMetadata, } ) // Error represents an ACME Error type Error struct { Type string `json:"type"` Detail string `json:"detail"` Subproblems []Subproblem `json:"subproblems,omitempty"` Err error `json:"-"` Status int `json:"-"` } // Subproblem represents an ACME subproblem. It's fairly // similar to an ACME error, but differs in that it can't // include subproblems itself, the error is reflected // in the Detail property and doesn't have a Status. type Subproblem struct { Type string `json:"type"` Detail string `json:"detail"` // The "identifier" field MUST NOT be present at the top level in ACME // problem documents. It can only be present in subproblems. // Subproblems need not all have the same type, and they do not need to // match the top level type. Identifier *Identifier `json:"identifier,omitempty"` } // NewError creates a new Error. func NewError(pt ProblemType, msg string, args ...any) *Error { return newError(pt, errors.Errorf(msg, args...)) } // NewDetailedError creates a new Error that includes the error // message in the details, providing more information to the // ACME client. func NewDetailedError(pt ProblemType, msg string, args ...any) *Error { return NewError(pt, msg, args...).withDetail() } func (e *Error) withDetail() *Error { if e == nil || e.Status >= 500 || e.Err == nil { return e } e.Detail = fmt.Sprintf("%s: %s", e.Detail, e.Err) return e } // AddSubproblems adds the Subproblems to Error. It // returns the Error, allowing for fluent addition. func (e *Error) AddSubproblems(subproblems ...Subproblem) *Error { e.Subproblems = append(e.Subproblems, subproblems...) return e } // NewSubproblem creates a new Subproblem. The msg and args // are used to create a new error, which is set as the Detail, allowing // for more detailed error messages to be returned to the ACME client. 
func NewSubproblem(pt ProblemType, msg string, args ...any) Subproblem { e := newError(pt, fmt.Errorf(msg, args...)) s := Subproblem{ Type: e.Type, Detail: e.Err.Error(), } return s } // NewSubproblemWithIdentifier creates a new Subproblem with a specific ACME // Identifier. It calls NewSubproblem and sets the Identifier. func NewSubproblemWithIdentifier(pt ProblemType, identifier Identifier, msg string, args ...any) Subproblem { s := NewSubproblem(pt, msg, args...) s.Identifier = &identifier return s } func newError(pt ProblemType, err error) *Error { meta, ok := errorMap[pt] if !ok { meta = errorServerInternalMetadata return &Error{ Type: meta.typ, Detail: meta.details, Status: meta.status, Err: err, } } return &Error{ Type: meta.typ, Detail: meta.details, Status: meta.status, Err: err, } } // NewErrorISE creates a new ErrorServerInternalType Error. func NewErrorISE(msg string, args ...any) *Error { return NewError(ErrorServerInternalType, msg, args...) } // WrapError attempts to wrap the internal error. func WrapError(typ ProblemType, err error, msg string, args ...any) *Error { var e *Error switch { case err == nil: return nil case errors.As(err, &e): if e.Err == nil { e.Err = errors.Errorf(msg+"; "+e.Detail, args...) } else { e.Err = errors.Wrapf(e.Err, msg, args...) } return e default: return newError(typ, errors.Wrapf(err, msg, args...)) } } func WrapDetailedError(typ ProblemType, err error, msg string, args ...any) *Error { return WrapError(typ, err, msg, args...).withDetail() } // WrapErrorISE shortcut to wrap an internal server error type. func WrapErrorISE(err error, msg string, args ...any) *Error { return WrapError(ErrorServerInternalType, err, msg, args...) } // StatusCode returns the status code and implements the StatusCoder interface. func (e *Error) StatusCode() int { return e.Status } // Error implements the error interface. 
func (e *Error) Error() string {
	// Without an underlying error the Detail text is the message.
	if e.Err == nil {
		return e.Detail
	}
	return e.Err.Error()
}

// Cause returns the internal error and implements the Causer interface.
// When there is no internal error, a new error carrying Detail is returned.
func (e *Error) Cause() error {
	if e.Err == nil {
		return errors.New(e.Detail)
	}
	return e.Err
}

// ToLog implements the EnableLogger interface. It returns the JSON encoding
// of the error as a string.
func (e *Error) ToLog() (any, error) {
	b, err := json.Marshal(e)
	if err != nil {
		return nil, WrapErrorISE(err, "error marshaling acme.Error for logging")
	}
	return string(b), nil
}

// Render implements render.RenderableError for Error. It sets the
// Content-Type header to application/problem+json and writes the error as
// JSON with the error's status code.
func (e *Error) Render(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/problem+json")
	render.JSONStatus(w, r, e, e.StatusCode())
}

================================================
FILE: acme/errors_test.go
================================================
package acme

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// mustJSON marshals m to a JSON string and fails the test on error.
func mustJSON(t *testing.T, m map[string]interface{}) string {
	t.Helper()

	b, err := json.Marshal(m)
	require.NoError(t, err)

	return string(b)
}

// TestError_WithAdditionalErrorDetail compares the JSON encoding of errors
// created with NewError and NewDetailedError against the expected problem
// documents.
func TestError_WithAdditionalErrorDetail(t *testing.T) {
	internalJSON := mustJSON(t, map[string]interface{}{
		"detail": "The server experienced an internal error",
		"type":   "urn:ietf:params:acme:error:serverInternal",
	})
	malformedErr := NewError(ErrorMalformedType, "malformed error") // will result in Err == nil behavior
	malformedJSON := mustJSON(t, map[string]interface{}{
		"detail": "The request message was malformed",
		"type":   "urn:ietf:params:acme:error:malformed",
	})
	withDetailJSON := mustJSON(t, map[string]interface{}{
		"detail": "Attestation statement cannot be verified: invalid property",
		"type":   "urn:ietf:params:acme:error:badAttestationStatement",
	})

	tests := []struct {
		name string
		err  *Error
		want string
	}{
		{"internal", NewDetailedError(ErrorServerInternalType, ""), internalJSON},
		{"nil err", malformedErr, malformedJSON},
		{"detailed", NewDetailedError(ErrorBadAttestationStatementType, "invalid property"), withDetailJSON},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			b, err := json.Marshal(tt.err)
			require.NoError(t, err)

			// tests if the additional error detail is included in the JSON representation
			// of the ACME error. This is what is returned to ACME clients and being logged
			// by the CA.
			assert.JSONEq(t, tt.want, string(b))
		})
	}
}

================================================
FILE: acme/linker.go
================================================
package acme

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/chi/v5"

	"github.com/smallstep/certificates/api/render"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
)

// LinkType captures the link type.
type LinkType int

const (
	// NewNonceLinkType new-nonce
	NewNonceLinkType LinkType = iota
	// NewAccountLinkType new-account
	NewAccountLinkType
	// AccountLinkType account
	AccountLinkType
	// OrderLinkType order
	OrderLinkType
	// NewOrderLinkType new-order
	NewOrderLinkType
	// OrdersByAccountLinkType list of orders owned by account
	OrdersByAccountLinkType
	// FinalizeLinkType finalize order
	FinalizeLinkType
	// NewAuthzLinkType new-authz
	NewAuthzLinkType
	// AuthzLinkType authz
	AuthzLinkType
	// ChallengeLinkType challenge
	ChallengeLinkType
	// CertificateLinkType certificate
	CertificateLinkType
	// DirectoryLinkType directory
	DirectoryLinkType
	// RevokeCertLinkType revoke certificate
	RevokeCertLinkType
	// KeyChangeLinkType key rollover
	KeyChangeLinkType
)

// String returns the path segment used for the link type. There is no case
// for OrdersByAccountLinkType or FinalizeLinkType, so those values produce
// the "unexpected LinkType" text.
func (l LinkType) String() string {
	switch l {
	case NewNonceLinkType:
		return "new-nonce"
	case NewAccountLinkType:
		return "new-account"
	case AccountLinkType:
		return "account"
	case NewOrderLinkType:
		return "new-order"
	case OrderLinkType:
		return "order"
	case NewAuthzLinkType:
		return "new-authz"
	case AuthzLinkType:
		return "authz"
	case ChallengeLinkType:
		return "challenge"
	case CertificateLinkType:
		return "certificate"
	case DirectoryLinkType:
		return "directory"
	case RevokeCertLinkType:
		return "revoke-cert"
	case KeyChangeLinkType:
		return "key-change"
	default:
		return fmt.Sprintf("unexpected LinkType '%d'", int(l))
	}
}

// GetUnescapedPathSuffix returns the path of a link of the given type below
// the linker prefix, starting with "/<provisionerName>/". Link types that
// address a resource read their IDs from inputs: one ID for account, order,
// authz, certificate, orders-by-account and finalize links, and two IDs for
// challenge links. Passing fewer inputs than a type needs panics with an
// index out of range. An unknown type yields an empty string.
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
	switch typ {
	case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
		return fmt.Sprintf("/%s/%s", provisionerName, typ)
	case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
		return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
	case ChallengeLinkType:
		return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) //nolint:gosec // operating on internally defined inputs
	case OrdersByAccountLinkType:
		// Built on the account path rather than on typ's own string form.
		return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
	case FinalizeLinkType:
		// Built on the order path rather than on typ's own string form.
		return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
	default:
		return ""
	}
}

// NewLinker returns a new Linker that uses dns as the default host and
// prefix as the leading path of every link. A dns value that is an IPv6
// address without brackets, with or without a trailing ":port", is rewritten
// to the bracketed form; any other value is used unchanged.
func NewLinker(dns, prefix string) Linker {
	_, _, err := net.SplitHostPort(dns)
	if err != nil && strings.Contains(err.Error(), "too many colons in address") {
		// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
		// in case a port was appended to this wrong format, we try to extract the port, then check if it's
		// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
		// these cases, then the input dns is not changed.
		lastIndex := strings.LastIndex(dns, ":")
		hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
		if ip := net.ParseIP(hostPart); ip != nil {
			dns = "[" + hostPart + "]:" + portPart
		} else if ip := net.ParseIP(dns); ip != nil {
			dns = "[" + dns + "]"
		}
	}
	return &linker{prefix: prefix, dns: dns}
}

// Linker interface for generating links for ACME resources.
type Linker interface { GetLink(ctx context.Context, typ LinkType, inputs ...string) string Middleware(http.Handler) http.Handler LinkOrder(ctx context.Context, o *Order) LinkAccount(ctx context.Context, o *Account) LinkChallenge(ctx context.Context, o *Challenge, azID string) LinkAuthorization(ctx context.Context, o *Authorization) LinkOrdersByAccountID(ctx context.Context, orders []string) } type linkerKey struct{} // NewLinkerContext adds the given linker to the context. func NewLinkerContext(ctx context.Context, v Linker) context.Context { return context.WithValue(ctx, linkerKey{}, v) } // LinkerFromContext returns the current linker from the given context. func LinkerFromContext(ctx context.Context) (v Linker, ok bool) { v, ok = ctx.Value(linkerKey{}).(Linker) return } // MustLinkerFromContext returns the current linker from the given context. It // will panic if it's not in the context. func MustLinkerFromContext(ctx context.Context) Linker { var ( v Linker ok bool ) if v, ok = LinkerFromContext(ctx); !ok { panic("acme linker is not the context") } return v } type baseURLKey struct{} func newBaseURLContext(ctx context.Context, r *http.Request) context.Context { var u *url.URL if r.Host != "" { u = &url.URL{Scheme: "https", Host: r.Host} } return context.WithValue(ctx, baseURLKey{}, u) } func baseURLFromContext(ctx context.Context) *url.URL { if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok { return u } return nil } // linker generates ACME links. type linker struct { prefix string dns string } // Middleware gets the provisioner and current url from the request and sets // them in the context so we can use the linker to create ACME links. func (l *linker) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Add base url to the context. ctx := newBaseURLContext(r.Context(), r) // Add provisioner to the context. 
nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { render.Error(w, r, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) if err != nil { render.Error(w, r, err) return } var acmeProv *provisioner.ACME switch prov := p.(type) { case *provisioner.ACME: acmeProv = prov case provisioner.Uninitialized: render.Error(w, r, NewDetailedError(ErrorUnauthorizedType, "provisioner is disabled due to an initialization error")) return default: render.Error(w, r, NewDetailedError(ErrorUnauthorizedType, "provisioner must be of type ACME")) return } ctx = NewProvisionerContext(ctx, Provisioner(acmeProv)) next.ServeHTTP(w, r.WithContext(ctx)) }) } // GetLink is a helper for GetLinkExplicit. func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { var name string if p, ok := ProvisionerFromContext(ctx); ok { name = p.GetName() } var u url.URL if baseURL := baseURLFromContext(ctx); baseURL != nil { u = *baseURL } if u.Scheme == "" { u.Scheme = "https" } if u.Host == "" { u.Host = l.dns } u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...) return u.String() } // LinkOrder sets the ACME links required by an ACME order. func (l *linker) LinkOrder(ctx context.Context, o *Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) } o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, o.ID) if o.CertificateID != "" { o.CertificateURL = l.GetLink(ctx, CertificateLinkType, o.CertificateID) } } // LinkAccount sets the ACME links required by an ACME account. func (l *linker) LinkAccount(ctx context.Context, acc *Account) { acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. 
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
	// Challenge links carry both the authorization ID and the challenge ID.
	ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
}

// LinkAuthorization sets the ACME links required by an ACME authorization.
// It links every challenge of the authorization using the authorization ID.
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
	for _, ch := range az.Challenges {
		l.LinkChallenge(ctx, ch, az.ID)
	}
}

// LinkOrdersByAccountID converts each order ID to an ACME link. The slice is
// modified in place: every element is overwritten with its order URL.
func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
	for i, id := range orders {
		orders[i] = l.GetLink(ctx, OrderLinkType, id)
	}
}

================================================
FILE: acme/linker_test.go
================================================
package acme

import (
	"context"
	"fmt"
	"net/url"
	"testing"
	"time"

	"github.com/smallstep/assert"

	"github.com/smallstep/certificates/authority/provisioner"
)

// mockProvisioner returns an ACME provisioner named
// "test@acme-provisioner.com" initialized with fixed TLS duration claims.
// An initialization error is only printed; the provisioner is returned
// regardless.
func mockProvisioner(t *testing.T) Provisioner {
	t.Helper()
	var defaultDisableRenewal = false

	// Initialize provisioners
	p := &provisioner.ACME{
		Type: "ACME",
		Name: "test@acme-provisioner.com",
	}
	if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
		MinTLSDur:      &provisioner.Duration{Duration: 5 * time.Minute},
		MaxTLSDur:      &provisioner.Duration{Duration: 24 * time.Hour},
		DefaultTLSDur:  &provisioner.Duration{Duration: 24 * time.Hour},
		DisableRenewal: &defaultDisableRenewal,
	}}); err != nil {
		fmt.Printf("%v", err)
	}
	return p
}

// TestGetUnescapedPathSuffix checks the path suffix produced for each link
// type using placeholder provisioner and resource IDs.
func TestGetUnescapedPathSuffix(t *testing.T) {
	getPath := GetUnescapedPathSuffix

	assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
	assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
	assert.Equals(t, getPath(NewAccountLinkType, "{provisionerID}"), "/{provisionerID}/new-account")
	assert.Equals(t, getPath(AccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}")
	assert.Equals(t, getPath(KeyChangeLinkType, "{provisionerID}"), "/{provisionerID}/key-change")
	assert.Equals(t, getPath(NewOrderLinkType, "{provisionerID}"), "/{provisionerID}/new-order")
	assert.Equals(t, getPath(OrderLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}")
	assert.Equals(t, getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}/orders")
	assert.Equals(t, getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}/finalize")
	assert.Equals(t, getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), "/{provisionerID}/authz/{authzID}")
	assert.Equals(t, getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), "/{provisionerID}/challenge/{authzID}/{chID}")
	assert.Equals(t, getPath(CertificateLinkType, "{provisionerID}", "{certID}"), "/{provisionerID}/certificate/{certID}")
}

// TestLinker_DNS checks the directory link generated for different forms of
// the dns value given to NewLinker: domains, IPv4 and IPv6 addresses, with
// and without ports, and IPv6 addresses without brackets.
func TestLinker_DNS(t *testing.T) {
	prov := mockProvisioner(t)
	escProvName := url.PathEscape(prov.GetName())
	ctx := NewProvisionerContext(context.Background(), prov)
	type test struct {
		name                  string
		dns                   string
		prefix                string
		expectedDirectoryLink string
	}
	tests := []test{
		{
			name:                  "domain",
			dns:                   "ca.smallstep.com",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", escProvName),
		},
		{
			name:                  "domain-port",
			dns:                   "ca.smallstep.com:8443",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://ca.smallstep.com:8443/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv4",
			dns:                   "127.0.0.1",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://127.0.0.1/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv4-port",
			dns:                   "127.0.0.1:8443",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://127.0.0.1:8443/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv6",
			dns:                   "[::1]",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://[::1]/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv6-port",
			dns:                   "[::1]:8443",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://[::1]:8443/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv6-no-brackets",
			dns:                   "::1",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://[::1]/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv6-port-no-brackets",
			dns:                   "::1:8443",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://[::1]:8443/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv6-long-no-brackets",
			dns:                   "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]/acme/%s/directory", escProvName),
		},
		{
			name:                  "ipv6-long-port-no-brackets",
			dns:                   "2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443",
			prefix:                "acme",
			expectedDirectoryLink: fmt.Sprintf("https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8443/acme/%s/directory", escProvName),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			linker := NewLinker(tt.dns, tt.prefix)
			assert.Equals(t, tt.expectedDirectoryLink, linker.GetLink(ctx, DirectoryLinkType))
		})
	}
}

// TestLinker_GetLink checks the link generated for every link type, and how
// the result changes depending on whether a provisioner and a base URL are
// present in the context.
func TestLinker_GetLink(t *testing.T) {
	dns := "ca.smallstep.com"
	prefix := "acme"
	linker := NewLinker(dns, prefix)
	id := "1234"

	prov := mockProvisioner(t)
	escProvName := url.PathEscape(prov.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	ctx := NewProvisionerContext(context.Background(), prov)
	ctx = context.WithValue(ctx, baseURLKey{}, baseURL)

	// No provisioner and no BaseURL from request
	assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
	// Provisioner: yes, BaseURL: no
	assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
	// Provisioner: no, BaseURL: yes
	assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))

	assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
	assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, NewAccountLinkType), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, AccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, NewOrderLinkType), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, OrderLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, OrdersByAccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, FinalizeLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, NewAuthzLinkType), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, AuthzLinkType, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, DirectoryLinkType), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, RevokeCertLinkType, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, KeyChangeLinkType), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))

	assert.Equals(t, linker.GetLink(ctx, ChallengeLinkType, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))

	assert.Equals(t, linker.GetLink(ctx, CertificateLinkType, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
}

// TestLinker_LinkOrder checks the finalize, authorization and certificate
// URLs set by LinkOrder for orders with and without authorizations and a
// certificate.
func TestLinker_LinkOrder(t *testing.T) {
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	prov := mockProvisioner(t)
	provName := url.PathEscape(prov.GetName())
	ctx := NewProvisionerContext(context.Background(), prov)
	ctx = context.WithValue(ctx, baseURLKey{}, baseURL)

	oid := "orderID"
	certID := "certID"
	linkerPrefix := "acme"
	l := NewLinker("dns", linkerPrefix)
	type test struct {
		o        *Order
		validate func(o *Order)
	}
	var tests = map[string]test{
		"no-authz-and-no-cert": {
			o: &Order{
				ID: oid,
			},
			validate: func(o *Order) {
				assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
				assert.Equals(t, o.AuthorizationURLs, []string{})
				assert.Equals(t, o.CertificateURL, "")
			},
		},
		"one-authz-and-cert": {
			o: &Order{
				ID:               oid,
				CertificateID:    certID,
				AuthorizationIDs: []string{"foo"},
			},
			validate: func(o *Order) {
				assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
				assert.Equals(t, o.AuthorizationURLs, []string{
					fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
				})
				assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
			},
		},
		"many-authz": {
			o: &Order{
				ID:               oid,
				CertificateID:    certID,
				AuthorizationIDs: []string{"foo", "bar", "zap"},
			},
			validate: func(o *Order) {
				assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
				assert.Equals(t, o.AuthorizationURLs, []string{
					fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
					fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"),
					fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"),
				})
				assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			l.LinkOrder(ctx, tc.o)
			tc.validate(tc.o)
		})
	}
}

// TestLinker_LinkAccount checks the orders URL set by LinkAccount.
func TestLinker_LinkAccount(t *testing.T) {
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	prov := mockProvisioner(t)
	provName := url.PathEscape(prov.GetName())
	ctx := NewProvisionerContext(context.Background(), prov)
	ctx = context.WithValue(ctx, baseURLKey{}, baseURL)

	accID := "accountID"
	linkerPrefix := "acme"
	l := NewLinker("dns", linkerPrefix)
	type test struct {
		a        *Account
		validate func(o *Account)
	}
	var tests = map[string]test{
		"ok": {
			a: &Account{
				ID: accID,
			},
			validate: func(a *Account) {
				assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			l.LinkAccount(ctx, tc.a)
			tc.validate(tc.a)
		})
	}
}

// TestLinker_LinkChallenge checks the URL set by LinkChallenge.
func TestLinker_LinkChallenge(t *testing.T) {
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	prov := mockProvisioner(t)
	provName := url.PathEscape(prov.GetName())
	ctx := NewProvisionerContext(context.Background(), prov)
	ctx = context.WithValue(ctx, baseURLKey{}, baseURL)

	chID := "chID"
	azID := "azID"
	linkerPrefix := "acme"
	l := NewLinker("dns", linkerPrefix)
	type test struct {
		ch       *Challenge
		validate func(o *Challenge)
	}
	var tests = map[string]test{
		"ok": {
			ch: &Challenge{
				ID: chID,
			},
			validate: func(ch *Challenge) {
				assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			l.LinkChallenge(ctx, tc.ch, azID)
			tc.validate(tc.ch)
		})
	}
}

// TestLinker_LinkAuthorization checks that LinkAuthorization sets the URL of
// every challenge of the authorization.
func TestLinker_LinkAuthorization(t *testing.T) {
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	prov := mockProvisioner(t)
	provName := url.PathEscape(prov.GetName())
	ctx := NewProvisionerContext(context.Background(), prov)
	ctx = context.WithValue(ctx, baseURLKey{}, baseURL)

	chID0 := "chID-0"
	chID1 := "chID-1"
	chID2 := "chID-2"
	azID := "azID"
	linkerPrefix := "acme"
	l := NewLinker("dns", linkerPrefix)
	type test struct {
		az       *Authorization
		validate func(o *Authorization)
	}
	var tests = map[string]test{
		"ok": {
			az: &Authorization{
				ID: azID,
				Challenges: []*Challenge{
					{ID: chID0},
					{ID: chID1},
					{ID: chID2},
				},
			},
			validate: func(az *Authorization) {
				assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
				assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
				assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			l.LinkAuthorization(ctx, tc.az)
			tc.validate(tc.az)
		})
	}
}

// TestLinker_LinkOrdersByAccountID checks that LinkOrdersByAccountID replaces
// each order ID in the slice with its order URL.
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	prov := mockProvisioner(t)
	provName := url.PathEscape(prov.GetName())
	ctx := NewProvisionerContext(context.Background(), prov)
	ctx = context.WithValue(ctx, baseURLKey{}, baseURL)

	linkerPrefix := "acme"
	l := NewLinker("dns", linkerPrefix)
	type test struct {
		oids []string
	}
	var tests = map[string]test{
		"ok": {
			oids: []string{"foo", "bar", "baz"},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			l.LinkOrdersByAccountID(ctx, tc.oids)
			assert.Equals(t, tc.oids, []string{
				fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"),
				fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"),
				fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"),
			})
		})
	}
}

================================================
FILE: acme/nonce.go
================================================
package acme

// Nonce represents an ACME nonce type.
type Nonce string

// String implements the ToString interface.
func (n Nonce) String() string {
	return string(n)
}

================================================
FILE: acme/order.go
================================================
package acme

import (
	"bytes"
	"context"
	"crypto/subtle"
	"crypto/x509"
	"encoding/asn1"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/url"
	"sort"
	"strings"
	"time"

	"go.step.sm/crypto/keyutil"
	"go.step.sm/crypto/x509util"

	"github.com/smallstep/certificates/acme/wire"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/webhook"
)

// IdentifierType is the string name of an identifier kind, as carried in the
// "type" member of an Identifier.
type IdentifierType string

const (
	// IP is the ACME ip identifier type
	IP IdentifierType = "ip"
	// DNS is the ACME dns identifier type
	DNS IdentifierType = "dns"
	// PermanentIdentifier is the ACME permanent-identifier identifier type
	// defined in https://datatracker.ietf.org/doc/html/draft-bweeks-acme-device-attest-00
	PermanentIdentifier IdentifierType = "permanent-identifier"
	// WireUser is the Wire user identifier type
	WireUser IdentifierType = "wireapp-user"
	// WireDevice is the Wire device identifier type
	WireDevice IdentifierType = "wireapp-device"
)

// Identifier encodes the type that an order pertains to.
type Identifier struct {
	// Type selects how Value is interpreted (one of the IdentifierType
	// constants above).
	Type IdentifierType `json:"type"`
	// Value is the identifier itself, serialized as "value".
	Value string `json:"value"`
}

// Order contains order metadata for the ACME protocol order type.
//
// Fields tagged `json:"-"` are internal bookkeeping and are left out of the
// JSON representation of the order.
type Order struct {
	ID            string `json:"id"`
	AccountID     string `json:"-"`
	ProvisionerID string `json:"-"`
	Status        Status `json:"status"`
	// ExpiresAt is compared against the current time in UpdateStatus; an
	// order that is past it is marked invalid.
	ExpiresAt   time.Time    `json:"expires"`
	Identifiers []Identifier `json:"identifiers"`
	// NotBefore and NotAfter are passed to the CA as the requested validity
	// when the order is finalized.
	NotBefore time.Time `json:"notBefore"`
	NotAfter  time.Time `json:"notAfter"`
	Error     *Error    `json:"error,omitempty"`
	// AuthorizationIDs are the stored IDs; AuthorizationURLs is what is
	// serialized under "authorizations".
	AuthorizationIDs  []string `json:"-"`
	AuthorizationURLs []string `json:"authorizations"`
	FinalizeURL       string   `json:"finalize"`
	// CertificateID is set by Finalize once a certificate has been stored.
	CertificateID  string `json:"-"`
	CertificateURL string `json:"certificate,omitempty"`
}

// ToLog enables response logging.
func (o *Order) ToLog() (interface{}, error) { b, err := json.Marshal(o) if err != nil { return nil, WrapErrorISE(err, "error marshaling order for logging") } return string(b), nil } // UpdateStatus updates the ACME Order Status if necessary. // Changes to the order are saved using the database interface. func (o *Order) UpdateStatus(ctx context.Context, db DB) error { now := clock.Now() switch o.Status { case StatusInvalid: return nil case StatusValid: return nil case StatusReady: // Check expiry if now.After(o.ExpiresAt) { o.Status = StatusInvalid o.Error = NewError(ErrorMalformedType, "order has expired") break } return nil case StatusPending: // Check expiry if now.After(o.ExpiresAt) { o.Status = StatusInvalid o.Error = NewError(ErrorMalformedType, "order has expired") break } var count = map[Status]int{ StatusValid: 0, StatusInvalid: 0, StatusPending: 0, } for _, azID := range o.AuthorizationIDs { az, err := db.GetAuthorization(ctx, azID) if err != nil { return WrapErrorISE(err, "error getting authorization ID %s", azID) } if err = az.UpdateStatus(ctx, db); err != nil { return WrapErrorISE(err, "error updating authorization ID %s", azID) } st := az.Status count[st]++ } switch { case count[StatusInvalid] > 0: o.Status = StatusInvalid // No change in the order status, so just return the order as is - // without writing any changes. case count[StatusPending] > 0: return nil case count[StatusValid] == len(o.AuthorizationIDs): o.Status = StatusReady default: return NewErrorISE("unexpected authz status") } default: return NewErrorISE("unrecognized order status: %s", o.Status) } if err := db.UpdateOrder(ctx, o); err != nil { return WrapErrorISE(err, "error updating order") } return nil } // getAuthorizationFingerprint returns a fingerprint from the list of authorizations. This // fingerprint is used on the device-attest-01 flow to verify the attestation // certificate public key with the CSR public key. 
// // There's no point on reading all the authorizations as there will be only one // for a permanent identifier. func (o *Order) getAuthorizationFingerprint(ctx context.Context, db DB) (string, error) { for _, azID := range o.AuthorizationIDs { az, err := db.GetAuthorization(ctx, azID) if err != nil { return "", WrapErrorISE(err, "error getting authorization %q", azID) } // There's no point on reading all the authorizations as there will // be only one for a permanent identifier. if az.Fingerprint != "" { return az.Fingerprint, nil } } return "", nil } // Finalize signs a certificate if the necessary conditions for Order completion // have been met. // // TODO(mariano): Here or in the challenge validation we should perform some // external validation using the identifier value and the attestation data. From // a validation service we can get the list of SANs to set in the final // certificate. func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error { if err := o.UpdateStatus(ctx, db); err != nil { return err } switch o.Status { case StatusInvalid: return NewError(ErrorOrderNotReadyType, "order %s has been abandoned", o.ID) case StatusValid: return nil case StatusPending: return NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID) case StatusReady: break default: return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID) } // Get key fingerprint if any. And then compare it with the CSR fingerprint. // // In device-attest-01 challenges we should check that the keys in the CSR // and the attestation certificate are the same. 
fingerprint, err := o.getAuthorizationFingerprint(ctx, db) if err != nil { return err } if fingerprint != "" { fp, err := keyutil.Fingerprint(csr.PublicKey) if err != nil { return WrapErrorISE(err, "error calculating key fingerprint") } if subtle.ConstantTimeCompare([]byte(fingerprint), []byte(fp)) == 0 { return NewError(ErrorUnauthorizedType, "order %s csr does not match the attested key", o.ID) } } // canonicalize the CSR to allow for comparison csr = canonicalize(csr) // Template data data := x509util.NewTemplateData() if o.containsWireIdentifiers() { wireDB, ok := db.(WireDB) if !ok { return fmt.Errorf("db %T is not a WireDB", db) } subject, err := createWireSubject(o, csr) if err != nil { return fmt.Errorf("failed creating Wire subject: %w", err) } data.SetSubject(subject) // Inject Wire's custom challenges into the template once they have been validated dpop, err := wireDB.GetDpopToken(ctx, o.ID) if err != nil { return fmt.Errorf("failed getting Wire DPoP token: %w", err) } data.Set("Dpop", dpop) oidc, err := wireDB.GetOidcToken(ctx, o.ID) if err != nil { return fmt.Errorf("failed getting Wire OIDC token: %w", err) } data.Set("Oidc", oidc) } else { data.SetCommonName(csr.Subject.CommonName) } // Custom sign options passed to authority.Sign var extraOptions []provisioner.SignOption // TODO: support for multiple identifiers? var permanentIdentifier string for i := range o.Identifiers { if o.Identifiers[i].Type == PermanentIdentifier { permanentIdentifier = o.Identifiers[i].Value // the first (and only) Permanent Identifier that gets added to the certificate // should be equal to the Subject Common Name if it's set. If not equal, the CSR // is rejected, because the Common Name hasn't been challenged in that case. This // could result in unauthorized access if a relying system relies on the Common // Name in its authorization logic. 
if csr.Subject.CommonName != "" && csr.Subject.CommonName != permanentIdentifier { return NewError(ErrorBadCSRType, "CSR Subject Common Name does not match identifiers exactly: "+ "CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, permanentIdentifier) } break } } var defaultTemplate string if permanentIdentifier != "" { defaultTemplate = x509util.DefaultAttestedLeafTemplate data.SetSubjectAlternativeNames(x509util.SubjectAlternativeName{ Type: x509util.PermanentIdentifierType, Value: permanentIdentifier, }) extraOptions = append(extraOptions, provisioner.AttestationData{ PermanentIdentifier: permanentIdentifier, }) } else { defaultTemplate = x509util.DefaultLeafTemplate sans, err := o.sans(csr) if err != nil { return err } data.SetSubjectAlternativeNames(sans...) } // Get authorizations from the ACME provisioner. ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner") } // Unlike most of the provisioners, ACME's AuthorizeSign method doesn't // define the templates, and the template data used in WebHooks is not // available. for _, signOp := range signOps { if wc, ok := signOp.(*provisioner.WebhookController); ok { wc.TemplateData = data } } templateOptions, err := provisioner.CustomTemplateOptions(p.GetOptions(), data, defaultTemplate) if err != nil { return WrapErrorISE(err, "error creating template options from ACME provisioner") } // Build extra signing options. signOps = append(signOps, templateOptions) signOps = append(signOps, extraOptions...) // Sign a new certificate. certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) if err != nil { // Add subproblem for webhook errors, others can be added later. 
var webhookErr *webhook.Error if errors.As(err, &webhookErr) { acmeError := NewDetailedError(ErrorUnauthorizedType, "%s", webhookErr.Error()) acmeError.AddSubproblems(Subproblem{ Type: fmt.Sprintf("urn:smallstep:acme:error:%s", webhookErr.Code), Detail: webhookErr.Message, }) return acmeError } return WrapErrorISE(err, "error signing certificate for order %s", o.ID) } cert := &Certificate{ AccountID: o.AccountID, OrderID: o.ID, Leaf: certChain[0], Intermediates: certChain[1:], } if err := db.CreateCertificate(ctx, cert); err != nil { return WrapErrorISE(err, "error creating certificate for order %s", o.ID) } o.CertificateID = cert.ID o.Status = StatusValid if err = db.UpdateOrder(ctx, o); err != nil { return WrapErrorISE(err, "error updating order %s", o.ID) } return nil } // containsWireIdentifiers checks if [Order] contains ACME // identifiers for the WireUser or WireDevice types. func (o *Order) containsWireIdentifiers() bool { for _, i := range o.Identifiers { if i.Type == WireUser || i.Type == WireDevice { return true } } return false } // createWireSubject creates the subject for an [Order] with WireUser identifiers. func createWireSubject(o *Order, csr *x509.CertificateRequest) (subject x509util.Subject, err error) { wireUserIDs, wireDeviceIDs, otherIDs := 0, 0, 0 for _, identifier := range o.Identifiers { switch identifier.Type { case WireUser: wireID, err := wire.ParseUserID(identifier.Value) if err != nil { return subject, NewErrorISE("unmarshal wireID: %s", err) } // TODO: temporarily using a custom OIDC for carrying the display name without having it listed as a DNS SAN. 
// reusing LDAP's OID for diplay name see http://oid-info.com/get/2.16.840.1.113730.3.1.241 displayNameOid := asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241} var foundDisplayName = false for _, entry := range csr.Subject.Names { if entry.Type.Equal(displayNameOid) { foundDisplayName = true displayName := entry.Value.(string) if displayName != wireID.Name { return subject, NewErrorISE("expected displayName %v, found %v", wireID.Name, displayName) } } } if !foundDisplayName { return subject, NewErrorISE("CSR must contain the display name in '2.16.840.1.113730.3.1.241' OID") } if len(csr.Subject.Organization) == 0 || !strings.EqualFold(csr.Subject.Organization[0], wireID.Domain) { return subject, NewErrorISE("expected Organization [%s], found %v", wireID.Domain, csr.Subject.Organization) } subject.CommonName = wireID.Name subject.Organization = []string{wireID.Domain} wireUserIDs++ case WireDevice: wireDeviceIDs++ default: otherIDs++ } } if otherIDs > 0 || wireUserIDs != 1 && wireDeviceIDs != 1 { return subject, NewErrorISE("order must have exactly one WireUser and WireDevice identifier") } return } func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativeName, error) { var sans []x509util.SubjectAlternativeName if len(csr.EmailAddresses) > 0 { return sans, NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed") } // order the DNS names and IP addresses, so that they can be compared against the canonicalized CSR orderNames := make([]string, numberOfIdentifierType(DNS, o.Identifiers)) orderIPs := make([]net.IP, numberOfIdentifierType(IP, o.Identifiers)) orderPIDs := make([]string, numberOfIdentifierType(PermanentIdentifier, o.Identifiers)) tmpOrderURIs := make([]*url.URL, numberOfIdentifierType(WireUser, o.Identifiers)+numberOfIdentifierType(WireDevice, o.Identifiers)) indexDNS, indexIP, indexPID, indexURI := 0, 0, 0, 0 for _, n := range o.Identifiers { switch n.Type { case DNS: orderNames[indexDNS] = n.Value indexDNS++ 
case IP: orderIPs[indexIP] = net.ParseIP(n.Value) // NOTE: this assumes are all valid IPs at this time; or will result in nil entries indexIP++ case PermanentIdentifier: orderPIDs[indexPID] = n.Value indexPID++ case WireUser: wireID, err := wire.ParseUserID(n.Value) if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } handle, err := url.Parse(wireID.Handle) if err != nil { return sans, NewErrorISE("handle must be a URI: %s", wireID.Handle) } tmpOrderURIs[indexURI] = handle indexURI++ case WireDevice: wireID, err := wire.ParseDeviceID(n.Value) if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } clientID, err := url.Parse(wireID.ClientID) if err != nil { return sans, NewErrorISE("clientId must be a URI: %s", wireID.ClientID) } tmpOrderURIs[indexURI] = clientID indexURI++ default: return sans, NewErrorISE("unsupported identifier type in order: %s", n.Type) } } orderNames = uniqueSortedLowerNames(orderNames) orderIPs = uniqueSortedIPs(orderIPs) orderURIs := uniqueSortedURIStrings(tmpOrderURIs) totalNumberOfSANs := len(csr.DNSNames) + len(csr.IPAddresses) + len(csr.URIs) sans = make([]x509util.SubjectAlternativeName, totalNumberOfSANs) index := 0 // Validate identifier names against CSR alternative names. // // Note that with certificate templates we are not going to check for the // absence of other SANs as they will only be set if the template allows // them. 
if len(csr.DNSNames) != len(orderNames) { return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) } for i := range csr.DNSNames { if csr.DNSNames[i] != orderNames[i] { return sans, NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ "CSR names = %v, Order names = %v", csr.DNSNames, orderNames) } sans[index] = x509util.SubjectAlternativeName{ Type: x509util.DNSType, Value: csr.DNSNames[i], } index++ } if len(csr.IPAddresses) != len(orderIPs) { return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ "CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs) } for i := range csr.IPAddresses { if !ipsAreEqual(csr.IPAddresses[i], orderIPs[i]) { return sans, NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ "CSR IPs = %v, Order IPs = %v", csr.IPAddresses, orderIPs) } sans[index] = x509util.SubjectAlternativeName{ Type: x509util.IPType, Value: csr.IPAddresses[i].String(), } index++ } if len(csr.URIs) != len(tmpOrderURIs) { return sans, NewError(ErrorBadCSRType, "CSR URIs do not match identifiers exactly: "+ "CSR URIs = %v, Order URIs = %v", csr.URIs, tmpOrderURIs) } // sort URI list csrURIs := uniqueSortedURIStrings(csr.URIs) for i := range csrURIs { if csrURIs[i] != orderURIs[i] { return sans, NewError(ErrorBadCSRType, "CSR URIs do not match identifiers exactly: "+ "CSR URIs = %v, Order URIs = %v", csr.URIs, tmpOrderURIs) } sans[index] = x509util.SubjectAlternativeName{ Type: x509util.URIType, Value: orderURIs[i], } index++ } return sans, nil } // numberOfIdentifierType returns the number of Identifiers that // are of type typ. 
func numberOfIdentifierType(typ IdentifierType, ids []Identifier) int { c := 0 for _, id := range ids { if id.Type == typ { c++ } } return c } // canonicalize canonicalizes a CSR so that it can be compared against an Order // NOTE: this effectively changes the order of SANs in the CSR, which may be OK, // but may not be expected. It also adds a Subject Common Name to either the IP // addresses or DNS names slice, depending on whether it can be parsed as an IP // or not. This might result in an additional SAN in the final certificate. func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.CertificateRequest) { // for clarity only; we're operating on the same object by pointer canonicalized = csr // RFC8555: The CSR MUST indicate the exact same set of requested // identifiers as the initial newOrder request. Identifiers of type "dns" // MUST appear either in the commonName portion of the requested subject // name or in an extensionRequest attribute [RFC2985] requesting a // subjectAltName extension, or both. Subject Common Names that can be // parsed as an IP are included as an IP address for the equality check. // If these were excluded, a certificate could contain an IP as the // common name without having been challenged. if csr.Subject.CommonName != "" { if ip := net.ParseIP(csr.Subject.CommonName); ip != nil { canonicalized.IPAddresses = append(canonicalized.IPAddresses, ip) } else { canonicalized.DNSNames = append(canonicalized.DNSNames, csr.Subject.CommonName) } } canonicalized.DNSNames = uniqueSortedLowerNames(canonicalized.DNSNames) canonicalized.IPAddresses = uniqueSortedIPs(canonicalized.IPAddresses) return canonicalized } // ipsAreEqual compares IPs to be equal. Nil values (i.e. invalid IPs) are // not considered equal. IPv6 representations of IPv4 addresses are // considered equal to the IPv4 address in this implementation, which is // standard Go behavior. An example is "::ffff:192.168.42.42", which // is equal to "192.168.42.42". 
This is considered a known issue within // step and is tracked here too: https://github.com/golang/go/issues/37921. func ipsAreEqual(x, y net.IP) bool { if x == nil || y == nil { return false } return x.Equal(y) } // uniqueSortedLowerNames returns the set of all unique names in the input after all // of them are lowercased. The returned names will be in their lowercased form // and sorted alphabetically. func uniqueSortedLowerNames(names []string) (unique []string) { nameMap := make(map[string]int, len(names)) for _, name := range names { nameMap[strings.ToLower(name)] = 1 } unique = make([]string, 0, len(nameMap)) for name := range nameMap { if name != "" { unique = append(unique, name) } } sort.Strings(unique) return } func uniqueSortedURIStrings(uris []*url.URL) (unique []string) { uriMap := make(map[string]struct{}, len(uris)) for _, name := range uris { uriMap[name.String()] = struct{}{} } unique = make([]string, 0, len(uriMap)) for name := range uriMap { unique = append(unique, name) } sort.Strings(unique) return } // uniqueSortedIPs returns the set of all unique net.IPs in the input. They // are sorted by their bytes (octet) representation. func uniqueSortedIPs(ips []net.IP) (unique []net.IP) { type entry struct { ip net.IP } ipEntryMap := make(map[string]entry, len(ips)) for _, ip := range ips { // reparsing the IP results in the IP being represented using 16 bytes // for both IPv4 as well as IPv6, even when the ips slice contains IPs that // are represented by 4 bytes. This ensures a fair comparison and thus ordering. 
ipEntryMap[ip.String()] = entry{ip: net.ParseIP(ip.String())} } unique = make([]net.IP, 0, len(ipEntryMap)) for _, entry := range ipEntryMap { unique = append(unique, entry.ip) } sort.Slice(unique, func(i, j int) bool { return bytes.Compare(unique[i], unique[j]) < 0 }) return } ================================================ FILE: acme/order_test.go ================================================ package acme import ( "context" "crypto" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/json" "fmt" "net" "net/url" "reflect" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/webhook" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" ) func TestOrder_UpdateStatus(t *testing.T) { type test struct { o *Order err *Error db DB } tests := map[string]func(t *testing.T) test{ "ok/already-invalid": func(t *testing.T) test { o := &Order{ Status: StatusInvalid, } return test{ o: o, } }, "ok/already-valid": func(t *testing.T) test { o := &Order{ Status: StatusInvalid, } return test{ o: o, } }, "fail/error-unexpected-status": func(t *testing.T) test { o := &Order{ Status: "foo", } return test{ o: o, err: NewErrorISE("unrecognized order status: %s", o.Status), } }, "ok/ready-expired": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(-5 * time.Minute), } return test{ o: o, db: &MockDB{ MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.Status, StatusInvalid) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) return nil }, }, } }, "fail/ready-expired-db.UpdateOrder-error": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: 
"accID", Status: StatusReady, ExpiresAt: now.Add(-5 * time.Minute), } return test{ o: o, db: &MockDB{ MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.Status, StatusInvalid) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) return errors.New("force") }, }, err: NewErrorISE("error updating order: force"), } }, "ok/pending-expired": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(-5 * time.Minute), } return test{ o: o, db: &MockDB{ MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.Status, StatusInvalid) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) err := NewError(ErrorMalformedType, "order has expired") assert.HasPrefix(t, updo.Error.Err.Error(), err.Err.Error()) assert.Equals(t, updo.Error.Type, err.Type) assert.Equals(t, updo.Error.Detail, err.Detail) assert.Equals(t, updo.Error.Status, err.Status) assert.Equals(t, updo.Error.Detail, err.Detail) return nil }, }, } }, "ok/invalid": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, } az1 := &Authorization{ ID: "a", Status: StatusValid, } az2 := &Authorization{ ID: "b", Status: StatusInvalid, } return test{ o: o, db: &MockDB{ MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.Status, StatusInvalid) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) return nil }, MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { switch id { case az1.ID: return az1, nil case az2.ID: return az2, nil default: assert.FatalError(t, errors.Errorf("unexpected authz 
key %s", id)) return nil, errors.New("force") } }, }, } }, "ok/still-pending": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, } az1 := &Authorization{ ID: "a", Status: StatusValid, } az2 := &Authorization{ ID: "b", Status: StatusPending, } return test{ o: o, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { switch id { case az1.ID: return az1, nil case az2.ID: return az2, nil default: assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) return nil, errors.New("force") } }, }, } }, "ok/valid": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, } az1 := &Authorization{ ID: "a", Status: StatusValid, } az2 := &Authorization{ ID: "b", Status: StatusValid, } return test{ o: o, db: &MockDB{ MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.Status, StatusReady) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) return nil }, MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { switch id { case az1.ID: return az1, nil case az2.ID: return az2, nil default: assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) return nil, errors.New("force") } }, }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil { if assert.NotNil(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equals(t, k.Type, tc.err.Type) assert.Equals(t, k.Detail, tc.err.Detail) assert.Equals(t, k.Status, tc.err.Status) assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) assert.Equals(t, k.Detail, tc.err.Detail) } else { 
assert.FatalError(t, errors.New("unexpected error type")) } } } else { assert.Nil(t, tc.err) } }) } } type mockSignAuth struct { signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} err error } func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signWithContext != nil { return m.signWithContext(ctx, csr, signOpts, extraOpts...) } else if m.err != nil { return nil, m.err } return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { if m.areSANsAllowed != nil { return m.areSANsAllowed(ctx, sans) } return m.err } func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) { if m.loadProvisionerByName != nil { return m.loadProvisionerByName(name) } return m.ret1.(provisioner.Interface), m.err } func (m *mockSignAuth) IsRevoked(string) (bool, error) { return false, nil } func (m *mockSignAuth) Revoke(context.Context, *authority.RevokeOptions) error { return nil } func (m *mockSignAuth) GetBackdate() *time.Duration { return nil } func TestOrder_Finalize(t *testing.T) { mustSigner := func(kty, crv string, size int) crypto.Signer { s, err := keyutil.GenerateSigner(kty, crv, size) if err != nil { t.Fatal(err) } return s } type test struct { o *Order err *Error db DB ca CertificateAuthority csr *x509.CertificateRequest prov Provisioner } tests := map[string]func(t *testing.T) test{ "fail/invalid": func(t *testing.T) test { o := &Order{ ID: "oid", Status: StatusInvalid, } return test{ o: o, err: NewError(ErrorOrderNotReadyType, 
"order %s has been abandoned", o.ID), } }, "fail/pending": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, } az1 := &Authorization{ ID: "a", Status: StatusValid, } az2 := &Authorization{ ID: "b", Status: StatusPending, ExpiresAt: now.Add(5 * time.Minute), } return test{ o: o, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { switch id { case az1.ID: return az1, nil case az2.ID: return az2, nil default: assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) return nil, errors.New("force") } }, }, err: NewError(ErrorOrderNotReadyType, "order %s is not ready", o.ID), } }, "ok/already-valid": func(t *testing.T) test { o := &Order{ ID: "oid", Status: StatusValid, } return test{ o: o, } }, "fail/error-unexpected-status": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: "foo", ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, } return test{ o: o, err: NewErrorISE("unrecognized order status: %s", o.Status), } }, "fail/non-matching-permanent-identifier-common-name": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "permanent-identifier", Value: "a-permanent-identifier"}, }, } signer := mustSigner("EC", "P-256", 0) fingerprint, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "a-different-identifier", }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, }, } return test{ o: o, csr: csr, db: &MockDB{ MockGetAuthorization: func(ctx 
context.Context, id string) (*Authorization, error) { switch id { case "a": return &Authorization{ ID: id, Status: StatusValid, }, nil case "b": return &Authorization{ ID: id, Fingerprint: fingerprint, Status: StatusValid, }, nil default: assert.FatalError(t, errors.Errorf("unexpected authorization %s", id)) return nil, errors.New("force") } }, MockUpdateOrder: func(ctx context.Context, o *Order) error { return nil }, }, err: &Error{ Type: "urn:ietf:params:acme:error:badCSR", Detail: "The CSR is unacceptable", Status: 400, Err: fmt.Errorf("CSR Subject Common Name does not match identifiers exactly: "+ "CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, "a-permanent-identifier"), }, } }, "fail/error-provisioner-auth": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, errors.New("force") }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"), } }, "fail/error-template-options": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := 
&x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return &provisioner.Options{ X509: &provisioner.X509Options{ TemplateData: json.RawMessage([]byte("fo{o")), }, } }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"), } }, "fail/error-ca-sign": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return nil, errors.New("force") }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, err: NewErrorISE("error signing certificate for order oID: force"), } }, 
"fail/webhook-error": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return nil, errs.ForbiddenErr(&webhook.Error{Code: "theCode", Message: "The message"}, "forbidden error") }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, err: NewDetailedError(ErrorUnauthorizedType, "The message (theCode)").AddSubproblems(Subproblem{ Type: "urn:smallstep:acme:error:theCode", Detail: "The message", }), } }, "fail/error-db.CreateCertificate": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} 
return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, foo) assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) return errors.New("force") }, }, err: NewErrorISE("error creating certificate for order oID: force"), } }, "fail/error-db.UpdateOrder": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, 
_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, foo) assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) return nil }, MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return errors.New("force") }, }, err: NewErrorISE("error updating order oID: force"), } }, "fail/csr-fingerprint": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "permanent-identifier", Value: "a-permanent-identifier"}, }, } signer := mustSigner("EC", "P-256", 0) csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "a-permanent-identifier", }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-permanent-identifier"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, 
}, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ ID: id, Fingerprint: "other-fingerprint", Status: StatusValid, }, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) return nil }, MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return nil }, }, err: NewError(ErrorUnauthorizedType, "order oID csr does not match the attested key"), } }, "ok/permanent-identifier": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "permanent-identifier", Value: "a-permanent-identifier"}, }, } signer := mustSigner("EC", 
"P-256", 0) fingerprint, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "a-permanent-identifier", }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-permanent-identifier"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { switch id { case "a": return &Authorization{ ID: id, Status: StatusValid, }, nil case "b": return &Authorization{ ID: id, Fingerprint: fingerprint, Status: StatusValid, }, nil default: assert.FatalError(t, errors.Errorf("unexpected authorization %s", id)) return nil, errors.New("force") } }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) return nil }, MockUpdateOrder: func(ctx 
context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return nil }, }, } }, "ok/permanent-identifier-only": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "permanent-identifier", Value: "a-permanent-identifier"}, }, } signer := mustSigner("EC", "P-256", 0) fingerprint, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "a-permanent-identifier", }, DNSNames: []string{"foo.internal"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-permanent-identifier"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}, Value: []byte("a-permanent-identifier"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, // TODO(hs): we should work on making the mocks more realistic. 
Ideally, we should get rid of // the mock entirely, relying on an instances of provisioner, authority and DB (possibly hardest), so // that behavior of the tests is what an actual CA would do. We could gradually phase them out by // using the mocking functions as a wrapper for actual test helpers generated per test case or per // function that's tested. ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ ID: id, Fingerprint: fingerprint, Status: StatusValid, }, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) return nil }, MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return nil }, }, } }, "fail/csr-wire-id-csr-uri-missing": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-device", Value: "{\"name\": \"device\", \"domain\": \"wire.com\", \"client-id\": \"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } 
signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "device"}, }, Organization: []string{"wire.com"}, }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, MockGetDpopToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) dpopMap := map[string]interface{}{ "dpop": "a-dpop-token", } return dpopMap, nil }, MockGetOidcToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) oidcMap := map[string]interface{}{ "oidc": "a-oidc-token", } return oidcMap, 
nil }, }, err: NewError(ErrorBadCSRType, "CSR URIs do not match identifiers exactly: CSR URIs = [], Order URIs = [wireapp://CzbfFjDOQrenCbDxVmgnFw%%21594930e9d50bb175@wire.com]"), } }, "fail/csr-wire-id-csr-uri-mismatch": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-device", Value: "{\"name\": \"device\", \"domain\": \"wire.com\", \"client-id\": \"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("someurl.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "device"}, }, Organization: []string{"wire.com"}, }, URIs: []*url.URL{ wireURL, }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts 
...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, MockGetDpopToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) dpopMap := map[string]interface{}{ "dpop": "a-dpop-token", } return dpopMap, nil }, MockGetOidcToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) oidcMap := map[string]interface{}{ "oidc": "a-oidc-token", } return oidcMap, nil }, }, err: NewError(ErrorBadCSRType, "CSR URIs do not match identifiers exactly: CSR URIs = [someurl.com], Order URIs = [wireapp://CzbfFjDOQrenCbDxVmgnFw%%21594930e9d50bb175@wire.com]"), } }, "fail/other-than-wire-ids-present": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-device", Value: "{\"name\": \"device\", \"domain\": \"wire.com\", \"client-id\": \"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, {Type: "permanent-identifier", Value: "a-permanent-identifier"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "device"}, }, Organization: []string{"wire.com"}, }, URIs: []*url.URL{ wireURL, }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: 
asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, }, err: NewError(ErrorServerInternalType, "order must have exactly one WireUser and WireDevice identifier"), } }, "fail/wire-id-org-missing": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-user", Value: "{\"name\": \"Alice Smith\", \"domain\": \"wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("wireapp://%40alice_wire@wire.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 
113730, 3, 1, 241}, Value: "Alice Smith"}, }, }, URIs: []*url.URL{ wireURL, }, PublicKey: signer.Public(), } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, }, err: NewError(ErrorServerInternalType, "expected Organization [wire.com], found []"), } }, "fail/wire-id-display-name-missing": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-user", Value: "{\"name\": \"Alice Smith\", \"domain\": \"wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("wireapp://%40alice_wire@wire.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Organization: []string{"wire.com"}, }, URIs: []*url.URL{ wireURL, }, PublicKey: 
signer.Public(), } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, }, err: NewError(ErrorServerInternalType, "CSR must contain the display name in '2.16.840.1.113730.3.1.241' OID"), } }, "fail/wire-id-display-name-mismatch": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-user", Value: "{\"name\": \"Alice Smith\", \"domain\": \"wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("wireapp://%40alice_wire@wire.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "Someone else"}, }, Organization: 
[]string{"wire.com"}, }, URIs: []*url.URL{ wireURL, }, PublicKey: signer.Public(), } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, }, err: NewError(ErrorServerInternalType, "expected displayName Alice Smith, found Someone else"), } }, "ok/wire-id-user": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "wireapp-user", Value: "{\"name\": \"Alice Smith\", \"domain\": \"wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("wireapp://%40alice_wire@wire.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "Alice 
Smith"}, }, Organization: []string{"wire.com"}, }, URIs: []*url.URL{ wireURL, }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, MockGetDpopToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) dpopMap := map[string]interface{}{ "dpop": "a-dpop-token", } return dpopMap, nil }, MockGetOidcToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) oidcMap := map[string]interface{}{ "oidc": "a-oidc-token", } return oidcMap, nil }, }, } }, "ok/wire-id-device": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ 
{Type: "wireapp-device", Value: "{\"name\": \"device\", \"domain\": \"wire.com\", \"client-id\": \"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com\", \"handle\": \"wireapp://%40alice_wire@wire.com\"}"}, }, } signer := mustSigner("EC", "P-256", 0) _, err := keyutil.Fingerprint(signer.Public()) if err != nil { t.Fatal(err) } wireURL, _ := url.Parse("wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com") csr := &x509.CertificateRequest{ Subject: pkix.Name{ Names: []pkix.AttributeTypeAndValue{ {Type: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: "device"}, }, Organization: []string{"wire.com"}, }, URIs: []*url.URL{ wireURL, }, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } leaf := &x509.Certificate{ Subject: pkix.Name{CommonName: "a-wireapp-user"}, PublicKey: signer.Public(), ExtraExtensions: []pkix.Extension{ { Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 3, 1, 241}, Value: []byte("a-wireapp-user"), }, }, } inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}} root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, }, db: &MockWireDB{ MockDB: MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }, MockGetDpopToken: func(ctx context.Context, orderID string) 
(map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) dpopMap := map[string]interface{}{ "dpop": "a-dpop-token", } return dpopMap, nil }, MockGetOidcToken: func(ctx context.Context, orderID string) (map[string]interface{}, error) { assert.Equals(t, orderID, o.ID) oidcMap := map[string]interface{}{ "oidc": "a-oidc-token", } return oidcMap, nil }, }, } }, "ok/new-cert-dns": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"bar.internal"}, } foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, foo) assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, baz}) return nil }, MockUpdateOrder: func(ctx 
context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return nil }, }, } }, "ok/new-cert-ip": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "192.168.43.42"}, }, } csr := &x509.CertificateRequest{ IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, // in case of IPs, no Common Name } foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, foo) assert.Equals(t, cert.Intermediates, []*x509.Certificate{bar, 
baz}) return nil }, MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return nil }, }, } }, "ok/new-cert-dns-and-ip": func(t *testing.T) test { now := clock.Now() o := &Order{ ID: "oID", AccountID: "accID", Status: StatusReady, ExpiresAt: now.Add(5 * time.Minute), AuthorizationIDs: []string{"a", "b"}, Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "ip", Value: "192.168.42.42"}, }, } csr := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, IPAddresses: []net.IP{net.ParseIP("192.168.42.42")}, } foo := &x509.Certificate{Subject: pkix.Name{CommonName: "foo"}} bar := &x509.Certificate{Subject: pkix.Name{CommonName: "bar"}} baz := &x509.Certificate{Subject: pkix.Name{CommonName: "baz"}} return test{ o: o, csr: csr, prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { assert.Equals(t, token, "") return nil, nil }, MgetOptions: func() *provisioner.Options { return nil }, }, ca: &mockSignAuth{ signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, }, db: &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { cert.ID = "certID" assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.Leaf, foo) assert.Equals(t, 
cert.Intermediates, []*x509.Certificate{bar, baz}) return nil }, MockUpdateOrder: func(ctx context.Context, updo *Order) error { assert.Equals(t, updo.CertificateID, "certID") assert.Equals(t, updo.Status, StatusValid) assert.Equals(t, updo.ID, o.ID) assert.Equals(t, updo.AccountID, o.AccountID) assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs) assert.Equals(t, updo.Identifiers, o.Identifiers) return nil }, }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) if err := tc.o.Finalize(context.Background(), tc.db, tc.csr, tc.ca, tc.prov); err != nil { if assert.NotNil(t, tc.err) { var k *Error if errors.As(err, &k) { assert.Equals(t, k.Type, tc.err.Type) assert.Equals(t, k.Detail, tc.err.Detail) assert.Equals(t, k.Status, tc.err.Status) assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) assert.Equals(t, k.Detail, tc.err.Detail) assert.Equals(t, k.Subproblems, tc.err.Subproblems) } else { assert.FatalError(t, errors.New("unexpected error type")) } } } else { assert.Nil(t, tc.err) } }) } } func Test_uniqueSortedIPs(t *testing.T) { type args struct { ips []net.IP } tests := []struct { name string args args want []net.IP }{ { name: "ok/empty", args: args{ ips: []net.IP{}, }, want: []net.IP{}, }, { name: "ok/single-ipv4", args: args{ ips: []net.IP{net.ParseIP("192.168.42.42")}, }, want: []net.IP{net.ParseIP("192.168.42.42")}, }, { name: "ok/multiple-ipv4", args: args{ ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1"), net.ParseIP("127.0.0.1")}, }, want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")}, }, { name: "ok/multiple-ipv4-with-varying-byte-representations", args: args{ ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.1"), []byte{0x7f, 0x0, 0x0, 0x1}}, }, want: []net.IP{net.ParseIP("127.0.0.1"), 
net.ParseIP("192.168.42.1"), net.ParseIP("192.168.42.10"), net.ParseIP("192.168.42.42")}, }, { name: "ok/unique-ipv4", args: args{ ips: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")}, }, want: []net.IP{net.ParseIP("192.168.42.42")}, }, { name: "ok/single-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::30")}, }, want: []net.IP{net.ParseIP("2001:db8::30")}, }, { name: "ok/multiple-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::30"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::10")}, }, want: []net.IP{net.ParseIP("2001:db8::10"), net.ParseIP("2001:db8::20"), net.ParseIP("2001:db8::30")}, }, { name: "ok/unique-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1")}, }, want: []net.IP{net.ParseIP("2001:db8::1")}, }, { name: "ok/mixed-ipv4-and-ipv6", args: args{ ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42")}, }, want: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")}, }, { name: "ok/mixed-ipv4-and-ipv6-and-varying-byte-representations", args: args{ ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.42"), []byte{0x7f, 0x0, 0x0, 0x1}}, }, want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1")}, }, { name: "ok/mixed-ipv4-and-ipv6-and-more-varying-byte-representations", args: args{ ips: []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::1"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::2"), net.ParseIP("192.168.42.42"), []byte{0x7f, 0x0, 0x0, 0x1}, []byte{0x7f, 0x0, 0x0, 0x1}, []byte{0x7f, 0x0, 0x0, 0x2}}, }, want: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.2"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::2")}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := uniqueSortedIPs(tt.args.ips) if 
!cmp.Equal(tt.want, got) { t.Errorf("uniqueSortedIPs() diff =\n%s", cmp.Diff(tt.want, got)) } }) } } func Test_numberOfIdentifierType(t *testing.T) { type args struct { typ IdentifierType ids []Identifier } tests := []struct { name string args args want int }{ { name: "ok/no-identifiers", args: args{ typ: DNS, ids: []Identifier{}, }, want: 0, }, { name: "ok/no-dns", args: args{ typ: DNS, ids: []Identifier{ { Type: IP, Value: "192.168.42.42", }, }, }, want: 0, }, { name: "ok/no-ips", args: args{ typ: IP, ids: []Identifier{ { Type: DNS, Value: "example.com", }, }, }, want: 0, }, { name: "ok/one-dns", args: args{ typ: DNS, ids: []Identifier{ { Type: DNS, Value: "example.com", }, { Type: IP, Value: "192.168.42.42", }, }, }, want: 1, }, { name: "ok/one-ip", args: args{ typ: IP, ids: []Identifier{ { Type: DNS, Value: "example.com", }, { Type: IP, Value: "192.168.42.42", }, }, }, want: 1, }, { name: "ok/more-dns", args: args{ typ: DNS, ids: []Identifier{ { Type: DNS, Value: "example.com", }, { Type: DNS, Value: "*.example.com", }, { Type: IP, Value: "192.168.42.42", }, }, }, want: 2, }, { name: "ok/more-ips", args: args{ typ: IP, ids: []Identifier{ { Type: DNS, Value: "example.com", }, { Type: IP, Value: "192.168.42.42", }, { Type: IP, Value: "192.168.42.43", }, }, }, want: 2, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := numberOfIdentifierType(tt.args.typ, tt.args.ids); got != tt.want { t.Errorf("numberOfIdentifierType() = %v, want %v", got, tt.want) } }) } } func Test_ipsAreEqual(t *testing.T) { type args struct { x net.IP y net.IP } tests := []struct { name string args args want bool }{ { name: "ok/ipv4", args: args{ x: net.ParseIP("192.168.42.42"), y: net.ParseIP("192.168.42.42"), }, want: true, }, { name: "fail/ipv4", args: args{ x: net.ParseIP("192.168.42.42"), y: net.ParseIP("192.168.42.43"), }, want: false, }, { name: "ok/ipv6", args: args{ x: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), y: 
net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), }, want: true, }, { name: "fail/ipv6", args: args{ x: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"), }, want: false, }, { name: "fail/ipv4-and-ipv6", args: args{ x: net.ParseIP("192.168.42.42"), y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), }, want: false, }, { name: "ok/ipv4-mapped-to-ipv6", args: args{ x: net.ParseIP("192.168.42.42"), y: net.ParseIP("::ffff:192.168.42.42"), // parsed to the same IPv4 by Go }, want: true, // we expect this to happen; a known issue in which ipv4 mapped ipv6 addresses are considered the same as their ipv4 counterpart }, { name: "fail/invalid-ipv4-and-valid-ipv6", args: args{ x: net.ParseIP("192.168.42.1000"), y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), }, want: false, }, { name: "fail/valid-ipv4-and-invalid-ipv6", args: args{ x: net.ParseIP("192.168.42.42"), y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:733400"), }, want: false, }, { name: "fail/invalid-ipv4-and-invalid-ipv6", args: args{ x: net.ParseIP("192.168.42.1000"), y: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:1000000"), }, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := ipsAreEqual(tt.args.x, tt.args.y); got != tt.want { t.Errorf("ipsAreEqual() = %v, want %v", got, tt.want) } }) } } func Test_canonicalize(t *testing.T) { type args struct { csr *x509.CertificateRequest } tests := []struct { name string args args want *x509.CertificateRequest }{ { name: "ok/dns", args: args{ csr: &x509.CertificateRequest{ DNSNames: []string{"www.example.com", "example.com"}, }, }, want: &x509.CertificateRequest{ DNSNames: []string{"example.com", "www.example.com"}, IPAddresses: []net.IP{}, }, }, { name: "ok/common-name", args: args{ csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "example.com", }, DNSNames: []string{"www.example.com"}, }, }, want: &x509.CertificateRequest{ 
Subject: pkix.Name{ CommonName: "example.com", }, DNSNames: []string{"example.com", "www.example.com"}, IPAddresses: []net.IP{}, }, }, { name: "ok/ipv4", args: args{ csr: &x509.CertificateRequest{ IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, want: &x509.CertificateRequest{ DNSNames: []string{}, IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, }, { name: "ok/mixed", args: args{ csr: &x509.CertificateRequest{ DNSNames: []string{"www.example.com", "example.com"}, IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, want: &x509.CertificateRequest{ DNSNames: []string{"example.com", "www.example.com"}, IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, }, { name: "ok/mixed-common-name", args: args{ csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "example.com", }, DNSNames: []string{"www.example.com"}, IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, want: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "example.com", }, DNSNames: []string{"example.com", "www.example.com"}, IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, }, { name: "ok/ip-common-name", args: args{ csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "127.0.0.1", }, DNSNames: []string{"example.com"}, IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, }, want: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "127.0.0.1", }, DNSNames: []string{"example.com"}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := canonicalize(tt.args.csr) if !cmp.Equal(tt.want, got) { t.Errorf("canonicalize() diff =\n%s", cmp.Diff(tt.want, got)) } }) } } func TestOrder_sans(t *testing.T) { 
type fields struct { Identifiers []Identifier } tests := []struct { name string fields fields csr *x509.CertificateRequest want []x509util.SubjectAlternativeName err *Error }{ { name: "ok/dns", fields: fields{ Identifiers: []Identifier{ {Type: "dns", Value: "example.com"}, }, }, csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "example.com", }, }, want: []x509util.SubjectAlternativeName{ {Type: "dns", Value: "example.com"}, }, err: nil, }, { name: "fail/invalid-alternative-name-email", fields: fields{ Identifiers: []Identifier{}, }, csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, EmailAddresses: []string{"test@example.com"}, }, want: []x509util.SubjectAlternativeName{}, err: NewError(ErrorBadCSRType, "Only DNS names and IP addresses are allowed"), }, { name: "fail/error-names-length-mismatch", fields: fields{ Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, }, csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, }, want: []x509util.SubjectAlternativeName{}, err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ "CSR names = %v, Order names = %v", []string{"foo.internal"}, []string{"bar.internal", "foo.internal"}), }, { name: "fail/error-names-mismatch", fields: fields{ Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, }, }, csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "foo.internal", }, DNSNames: []string{"zap.internal"}, }, want: []x509util.SubjectAlternativeName{}, err: NewError(ErrorBadCSRType, "CSR names do not match identifiers exactly: "+ "CSR names = %v, Order names = %v", []string{"foo.internal", "zap.internal"}, []string{"bar.internal", "foo.internal"}), }, { name: "ok/ipv4", fields: fields{ Identifiers: []Identifier{ {Type: "ip", Value: "192.168.43.42"}, {Type: "ip", Value: "192.168.42.42"}, }, }, csr: &x509.CertificateRequest{ 
IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42")}, }, want: []x509util.SubjectAlternativeName{ {Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "192.168.43.42"}, }, err: nil, }, { name: "ok/ipv6", fields: fields{ Identifiers: []Identifier{ {Type: "ip", Value: "2001:0db8:85a3::8a2e:0370:7335"}, {Type: "ip", Value: "2001:0db8:85a3::8a2e:0370:7334"}, }, }, csr: &x509.CertificateRequest{ IPAddresses: []net.IP{net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7335"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, }, want: []x509util.SubjectAlternativeName{ {Type: "ip", Value: "2001:db8:85a3::8a2e:370:7334"}, {Type: "ip", Value: "2001:db8:85a3::8a2e:370:7335"}, }, err: nil, }, { name: "fail/error-ips-length-mismatch", fields: fields{ Identifiers: []Identifier{ {Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "192.168.43.42"}, }, }, csr: &x509.CertificateRequest{ IPAddresses: []net.IP{net.ParseIP("192.168.42.42")}, }, want: []x509util.SubjectAlternativeName{}, err: NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ "CSR IPs = %v, Order IPs = %v", []net.IP{net.ParseIP("192.168.42.42")}, []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}), }, { name: "fail/error-ips-mismatch", fields: fields{ Identifiers: []Identifier{ {Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "192.168.43.42"}, }, }, csr: &x509.CertificateRequest{ IPAddresses: []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.42.32")}, }, want: []x509util.SubjectAlternativeName{}, err: NewError(ErrorBadCSRType, "CSR IPs do not match identifiers exactly: "+ "CSR IPs = %v, Order IPs = %v", []net.IP{net.ParseIP("192.168.42.32"), net.ParseIP("192.168.42.42")}, []net.IP{net.ParseIP("192.168.42.42"), net.ParseIP("192.168.43.42")}), }, { name: "ok/mixed", fields: fields{ Identifiers: []Identifier{ {Type: "dns", Value: "foo.internal"}, {Type: "dns", Value: "bar.internal"}, {Type: "ip", Value: "192.168.43.42"}, 
{Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, }, }, csr: &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "bar.internal", }, DNSNames: []string{"foo.internal"}, IPAddresses: []net.IP{net.ParseIP("192.168.43.42"), net.ParseIP("192.168.42.42"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, }, want: []x509util.SubjectAlternativeName{ {Type: "dns", Value: "bar.internal"}, {Type: "dns", Value: "foo.internal"}, {Type: "ip", Value: "192.168.42.42"}, {Type: "ip", Value: "192.168.43.42"}, {Type: "ip", Value: "2001:db8:85a3::8a2e:370:7334"}, }, err: nil, }, { name: "fail/unsupported-identifier-type", fields: fields{ Identifiers: []Identifier{ {Type: "ipv4", Value: "192.168.42.42"}, }, }, csr: &x509.CertificateRequest{ IPAddresses: []net.IP{net.ParseIP("192.168.42.42")}, }, want: []x509util.SubjectAlternativeName{}, err: NewError(ErrorServerInternalType, "unsupported identifier type in order: ipv4"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &Order{ Identifiers: tt.fields.Identifiers, } canonicalizedCSR := canonicalize(tt.csr) got, err := o.sans(canonicalizedCSR) if tt.err != nil { if err == nil { t.Errorf("Order.sans() = %v, want error; got none", got) return } var k *Error if errors.As(err, &k) { assert.Equals(t, k.Type, tt.err.Type) assert.Equals(t, k.Detail, tt.err.Detail) assert.Equals(t, k.Status, tt.err.Status) assert.Equals(t, k.Err.Error(), tt.err.Err.Error()) assert.Equals(t, k.Detail, tt.err.Detail) } else { assert.FatalError(t, errors.New("unexpected error type")) } return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Order.sans() = %v, want %v", got, tt.want) } }) } } func TestOrder_getAuthorizationFingerprint(t *testing.T) { ctx := context.Background() type fields struct { AuthorizationIDs []string } type args struct { ctx context.Context db DB } tests := []struct { name string fields fields args args want string wantErr bool }{ {"ok", 
fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return &Authorization{ID: id, Status: StatusValid}, nil }, }}, "", false}, {"ok fingerprint", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { if id == "az1" { return &Authorization{ID: id, Status: StatusValid}, nil } return &Authorization{ID: id, Fingerprint: "fingerprint", Status: StatusValid}, nil }, }}, "fingerprint", false}, {"fail", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { return nil, errors.New("force") }, }}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &Order{ AuthorizationIDs: tt.fields.AuthorizationIDs, } got, err := o.getAuthorizationFingerprint(tt.args.ctx, tt.args.db) if (err != nil) != tt.wantErr { t.Errorf("Order.getAuthorizationFingerprint() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Order.getAuthorizationFingerprint() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: acme/status.go ================================================ package acme // Status represents an ACME status. type Status string var ( // StatusValid -- valid StatusValid = Status("valid") // StatusInvalid -- invalid StatusInvalid = Status("invalid") // StatusPending -- pending; e.g. an Order that is not ready to be finalized. StatusPending = Status("pending") // StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid. StatusDeactivated = Status("deactivated") // StatusReady -- ready; e.g. for an Order that is ready to be finalized. 
// UserID holds the fields of a JSON-encoded Wire user ID, as decoded by
// ParseUserID.
type UserID struct {
	Name   string `json:"name,omitempty"`
	Domain string `json:"domain,omitempty"`
	Handle string `json:"handle,omitempty"`
}

// DeviceID holds the fields of a JSON-encoded Wire device ID, as decoded by
// ParseDeviceID.
type DeviceID struct {
	Name     string `json:"name,omitempty"`
	Domain   string `json:"domain,omitempty"`
	ClientID string `json:"client-id,omitempty"`
	Handle   string `json:"handle,omitempty"`
}

// ParseUserID decodes the JSON document in value into a UserID and checks
// that the handle, name and domain are all present.
//
// It returns the JSON decoding error unchanged, or an error naming the first
// empty field (checked in the order handle, name, domain). On any error the
// returned UserID is the zero value.
func ParseUserID(value string) (id UserID, err error) {
	if err = json.Unmarshal([]byte(value), &id); err != nil {
		return UserID{}, err
	}

	switch {
	case id.Handle == "":
		return UserID{}, errors.New("handle must not be empty")
	case id.Name == "":
		return UserID{}, errors.New("name must not be empty")
	case id.Domain == "":
		return UserID{}, errors.New("domain must not be empty")
	}

	return id, nil
}

// ParseDeviceID decodes the JSON document in value into a DeviceID and checks
// that the handle, name, domain and client-id are all present.
//
// It returns the JSON decoding error unchanged, or an error naming the first
// empty field (checked in the order handle, name, domain, client-id). On any
// error the returned DeviceID is the zero value.
func ParseDeviceID(value string) (id DeviceID, err error) {
	if err = json.Unmarshal([]byte(value), &id); err != nil {
		return DeviceID{}, err
	}

	switch {
	case id.Handle == "":
		return DeviceID{}, errors.New("handle must not be empty")
	case id.Name == "":
		return DeviceID{}, errors.New("name must not be empty")
	case id.Domain == "":
		return DeviceID{}, errors.New("domain must not be empty")
	case id.ClientID == "":
		return DeviceID{}, errors.New("client-id must not be empty")
	}

	return id, nil
}

// ClientID is the parsed form of a Wire client ID URI. See ParseClientID for
// how each field is derived.
type ClientID struct {
	Scheme   string
	Username string
	DeviceID string
	Domain   string
}

// ParseClientID parses a Wire clientID. The ClientID format is as follows:
//
//	"wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com",
//
// where '!' is used as a separator between the user id & device id.
//
// The scheme must be "wireapp". The user part of the URI is split at the
// first '!' into Username and DeviceID, and the URI host becomes Domain. An
// error is returned if the URI cannot be parsed, if the scheme differs, or if
// the user part does not contain both a non-empty user id and a non-empty
// device id.
func ParseClientID(clientID string) (ClientID, error) {
	clientIDURI, err := url.Parse(clientID)
	if err != nil {
		return ClientID{}, fmt.Errorf("invalid Wire client ID URI %q: %w", clientID, err)
	}
	if clientIDURI.Scheme != "wireapp" {
		return ClientID{}, fmt.Errorf("invalid Wire client ID scheme %q; expected \"wireapp\"", clientIDURI.Scheme)
	}

	// Username is safe to call on a nil *url.Userinfo; it returns "".
	fullUsername := clientIDURI.User.Username()
	username, deviceID, found := strings.Cut(fullUsername, "!")
	// An empty half would produce a ClientID that identifies no user or no
	// device, so it is rejected just like a missing separator.
	if !found || username == "" || deviceID == "" {
		return ClientID{}, fmt.Errorf("invalid Wire client ID username %q", fullUsername)
	}

	return ClientID{
		Scheme:   clientIDURI.Scheme,
		Username: username,
		DeviceID: deviceID,
		Domain:   clientIDURI.Host,
	}, nil
}
*testing.T) { ok := `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` failJSON := `{"name": }` emptyHandle := `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": ""}` emptyName := `{"name": "", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` emptyDomain := `{"name": "device", "domain": "", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` emptyClientID := `{"name": "device", "domain": "wire.com", "client-id": "", "handle": "wireapp://%40alice_wire@wire.com"}` tests := []struct { name string value string wantWireID DeviceID wantErr bool }{ {name: "ok", value: ok, wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, {name: "fail/json", value: failJSON, wantErr: true}, {name: "fail/empty-handle", value: emptyHandle, wantErr: true}, {name: "fail/empty-name", value: emptyName, wantErr: true}, {name: "fail/empty-domain", value: emptyDomain, wantErr: true}, {name: "fail/empty-client-id", value: emptyClientID, wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotWireID, err := ParseDeviceID(tt.value) if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) assert.Equal(t, tt.wantWireID, gotWireID) }) } } func TestParseClientID(t *testing.T) { tests := []struct { name string clientID string want ClientID expectedErr error }{ {name: "ok", clientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", want: ClientID{Scheme: "wireapp", Username: "CzbfFjDOQrenCbDxVmgnFw", DeviceID: "594930e9d50bb175", Domain: "wire.com"}}, {name: "fail/uri", clientID: 
"bla", expectedErr: errors.New(`invalid Wire client ID scheme ""; expected "wireapp"`)}, {name: "fail/scheme", clientID: "not-wireapp://bla.com", expectedErr: errors.New(`invalid Wire client ID scheme "not-wireapp"; expected "wireapp"`)}, {name: "fail/username", clientID: "wireapp://user@wire.com", expectedErr: errors.New(`invalid Wire client ID username "user"`)}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := ParseClientID(tt.clientID) if tt.expectedErr != nil { assert.EqualError(t, err, tt.expectedErr.Error()) return } assert.NoError(t, err) assert.Equal(t, tt.want, got) }) } } ================================================ FILE: api/api.go ================================================ package api import ( "bytes" "context" "crypto" "crypto/dsa" //nolint:staticcheck // support legacy algorithms "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/x509" "encoding/asn1" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "net/http" "strconv" "strings" "time" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "go.step.sm/crypto/sshutil" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/log" "github.com/smallstep/certificates/api/models" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" "github.com/smallstep/certificates/logging" ) // Authority is the interface implemented by a CA authority. type Authority interface { SSHAuthority // context specifies the Authorize[Sign|Revoke|etc.] method. 
// mustAuthority returns the Authority for the given context by delegating to
// authority.MustFromContext. It is a package-level variable so that it can be
// replaced on unit tests.
var mustAuthority = func(ctx context.Context) Authority {
	return authority.MustFromContext(ctx)
}

// TimeDuration is an alias of provisioner.TimeDuration, so both names refer to
// the same type.
type TimeDuration = provisioner.TimeDuration

// NewTimeDuration returns a TimeDuration with the defined time. It delegates
// to provisioner.NewTimeDuration.
func NewTimeDuration(t time.Time) TimeDuration {
	return provisioner.NewTimeDuration(t)
}

// ParseTimeDuration returns a new TimeDuration parsing the RFC 3339 time or
// time.Duration string. It delegates to provisioner.ParseTimeDuration and
// returns its result and error unchanged.
func ParseTimeDuration(s string) (TimeDuration, error) {
	return provisioner.ParseTimeDuration(s)
}
type Certificate struct { *x509.Certificate } // NewCertificate is a helper method that returns a Certificate from a // *x509.Certificate. func NewCertificate(cr *x509.Certificate) Certificate { return Certificate{ Certificate: cr, } } // reset sets the inner x509.CertificateRequest to nil func (c *Certificate) reset() { if c != nil { c.Certificate = nil } } // MarshalJSON implements the json.Marshaler interface. The certificate is // quoted string using the PEM encoding. func (c Certificate) MarshalJSON() ([]byte, error) { if c.Certificate == nil { return []byte("null"), nil } block := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: c.Raw, }) return json.Marshal(string(block)) } // UnmarshalJSON implements the json.Unmarshaler interface. The certificate is // expected to be a quoted string using the PEM encoding. func (c *Certificate) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { return errors.Wrap(err, "error decoding certificate") } // Make sure the inner x509.Certificate is nil if s == "null" || s == "" { c.reset() return nil } block, _ := pem.Decode([]byte(s)) if block == nil { return errors.New("error decoding certificate") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return errors.Wrap(err, "error decoding certificate") } c.Certificate = cert return nil } // CertificateRequest wraps a *x509.CertificateRequest and adds the // json.Unmarshaler interface. type CertificateRequest struct { *x509.CertificateRequest } // NewCertificateRequest is a helper method that returns a CertificateRequest // from a *x509.CertificateRequest. func NewCertificateRequest(cr *x509.CertificateRequest) CertificateRequest { return CertificateRequest{ CertificateRequest: cr, } } // reset sets the inner x509.CertificateRequest to nil func (c *CertificateRequest) reset() { if c != nil { c.CertificateRequest = nil } } // MarshalJSON implements the json.Marshaler interface. 
The certificate request // is a quoted string using the PEM encoding. func (c CertificateRequest) MarshalJSON() ([]byte, error) { if c.CertificateRequest == nil { return []byte("null"), nil } block := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE REQUEST", Bytes: c.Raw, }) return json.Marshal(string(block)) } // UnmarshalJSON implements the json.Unmarshaler interface. The certificate // request is expected to be a quoted string using the PEM encoding. func (c *CertificateRequest) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { return errors.Wrap(err, "error decoding csr") } // Make sure the inner x509.CertificateRequest is nil if s == "null" || s == "" { c.reset() return nil } block, _ := pem.Decode([]byte(s)) if block == nil { return errors.New("error decoding csr") } cr, err := x509.ParseCertificateRequest(block.Bytes) if err != nil { return errors.Wrap(err, "error decoding csr") } c.CertificateRequest = cr return nil } // Router defines a common router interface. type Router interface { // MethodFunc adds routes for `pattern` that matches // the `method` HTTP method. MethodFunc(method, pattern string, h http.HandlerFunc) } // RouterHandler is the interface that a HTTP handler that manages multiple // endpoints will implement. type RouterHandler interface { Route(r Router) } // VersionResponse is the response object that returns the version of the // server. type VersionResponse struct { Version string `json:"version"` RequireClientAuthentication bool `json:"requireClientAuthentication,omitempty"` } // HealthResponse is the response object that returns the health of the server. type HealthResponse struct { Status string `json:"status"` } // RootResponse is the response object that returns the PEM of a root certificate. type RootResponse struct { RootPEM Certificate `json:"ca"` } // ProvisionersResponse is the response object that returns the list of // provisioners. 
type ProvisionersResponse struct { Provisioners provisioner.List NextCursor string } const redacted = "*** REDACTED ***" func scepFromProvisioner(p *provisioner.SCEP) *models.SCEP { return &models.SCEP{ ID: p.ID, Type: p.Type, Name: p.Name, ForceCN: p.ForceCN, ChallengePassword: redacted, Capabilities: p.Capabilities, IncludeRoot: p.IncludeRoot, ExcludeIntermediate: p.ExcludeIntermediate, MinimumPublicKeyLength: p.MinimumPublicKeyLength, DecrypterCertificate: []byte(redacted), DecrypterKeyPEM: []byte(redacted), DecrypterKeyURI: redacted, DecrypterKeyPassword: redacted, EncryptionAlgorithmIdentifier: p.EncryptionAlgorithmIdentifier, Options: p.Options, Claims: p.Claims, } } // MarshalJSON implements json.Marshaler. It marshals the ProvisionersResponse // into a byte slice. // // Special treatment is given to the SCEP provisioner, as it contains a // challenge secret that MUST NOT be leaked in (public) HTTP responses. The // challenge value is thus redacted in HTTP responses. func (p ProvisionersResponse) MarshalJSON() ([]byte, error) { var responseProvisioners provisioner.List for _, item := range p.Provisioners { scepProv, ok := item.(*provisioner.SCEP) if !ok { responseProvisioners = append(responseProvisioners, item) continue } responseProvisioners = append(responseProvisioners, scepFromProvisioner(scepProv)) } var list = struct { Provisioners []provisioner.Interface `json:"provisioners"` NextCursor string `json:"nextCursor"` }{ Provisioners: []provisioner.Interface(responseProvisioners), NextCursor: p.NextCursor, } return json.Marshal(list) } // ProvisionerKeyResponse is the response object that returns the encrypted key // of a provisioner. type ProvisionerKeyResponse struct { Key string `json:"key"` } // RootsResponse is the response object of the roots request. type RootsResponse struct { Certificates []Certificate `json:"crts"` } // IntermediatesResponse is the response object of the intermediates request. 
type IntermediatesResponse struct { Certificates []Certificate `json:"crts"` } // FederationResponse is the response object of the federation request. type FederationResponse struct { Certificates []Certificate `json:"crts"` } // caHandler is the type used to implement the different CA HTTP endpoints. type caHandler struct { Authority Authority } // Route configures the http request router. func (h *caHandler) Route(r Router) { Route(r) } // New creates a new RouterHandler with the CA endpoints. // // Deprecated: Use api.Route(r Router) func New(Authority) RouterHandler { return &caHandler{} } func Route(r Router) { r.MethodFunc("GET", "/version", Version) r.MethodFunc("GET", "/health", Health) r.MethodFunc("GET", "/root/{sha}", Root) r.MethodFunc("POST", "/sign", Sign) r.MethodFunc("POST", "/renew", Renew) r.MethodFunc("POST", "/rekey", Rekey) r.MethodFunc("POST", "/revoke", Revoke) r.MethodFunc("GET", "/crl", CRL) r.MethodFunc("GET", "/provisioners", Provisioners) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey) r.MethodFunc("GET", "/roots", Roots) r.MethodFunc("GET", "/roots.pem", RootsPEM) r.MethodFunc("GET", "/intermediates", Intermediates) r.MethodFunc("GET", "/intermediates.pem", IntermediatesPEM) r.MethodFunc("GET", "/federation", Federation) // SSH CA r.MethodFunc("POST", "/ssh/sign", SSHSign) r.MethodFunc("POST", "/ssh/renew", SSHRenew) r.MethodFunc("POST", "/ssh/revoke", SSHRevoke) r.MethodFunc("POST", "/ssh/rekey", SSHRekey) r.MethodFunc("GET", "/ssh/roots", SSHRoots) r.MethodFunc("GET", "/ssh/federation", SSHFederation) r.MethodFunc("POST", "/ssh/config", SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig) r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost) r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts) r.MethodFunc("POST", "/ssh/bastion", SSHBastion) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", Renew) r.MethodFunc("POST", "/sign-ssh", SSHSign) r.MethodFunc("GET", "/ssh/get-hosts", 
SSHGetHosts) } // Version is an HTTP handler that returns the version of the server. func Version(w http.ResponseWriter, r *http.Request) { v := mustAuthority(r.Context()).Version() render.JSON(w, r, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, }) } // Health is an HTTP handler that returns the status of the server. func Health(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root // certificate for the given SHA256. func Root(w http.ResponseWriter, r *http.Request) { sha := chi.URLParam(r, "sha") sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) // Load root certificate with the cert, err := mustAuthority(r.Context()).Root(sum) if err != nil { render.Error(w, r, errs.NotFoundErr(err, errs.WithMessage("root certificate with fingerprint %q was not found", sum))) return } render.JSON(w, r, &RootResponse{RootPEM: Certificate{cert}}) } func certChainToPEM(certChain []*x509.Certificate) []Certificate { certChainPEM := make([]Certificate, 0, len(certChain)) for _, c := range certChain { certChainPEM = append(certChainPEM, Certificate{c}) } return certChainPEM } // Provisioners returns the list of provisioners configured in the authority. func Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { render.Error(w, r, err) return } p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } render.JSON(w, r, &ProvisionersResponse{ Provisioners: p, NextCursor: next, }) } // ProvisionerKey returns the encrypted key of a provisioner by it's key id. 
func ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) if err != nil { render.Error(w, r, errs.NotFoundErr(err)) return } render.JSON(w, r, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. func Roots(w http.ResponseWriter, r *http.Request) { roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error getting roots")) return } certs := make([]Certificate, len(roots)) for i := range roots { certs[i] = Certificate{roots[i]} } render.JSONStatus(w, r, &RootsResponse{ Certificates: certs, }, http.StatusCreated) } // RootsPEM returns all the root certificates for the CA in PEM format. func RootsPEM(w http.ResponseWriter, r *http.Request) { roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } w.Header().Set("Content-Type", "application/x-pem-file") for _, root := range roots { block := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: root.Raw, }) if _, err := w.Write(block); err != nil { log.Error(w, r, err) return } } } // Intermediates returns all the intermediate certificates of the CA. func Intermediates(w http.ResponseWriter, r *http.Request) { intermediates := mustAuthority(r.Context()).GetIntermediateCertificates() if len(intermediates) == 0 { render.Error(w, r, errs.NotImplemented("error getting intermediates: method not implemented")) return } certs := make([]Certificate, len(intermediates)) for i := range intermediates { certs[i] = Certificate{intermediates[i]} } render.JSONStatus(w, r, &IntermediatesResponse{ Certificates: certs, }, http.StatusCreated) } // IntermediatesPEM returns all the intermediate certificates for the CA in PEM format. 
func IntermediatesPEM(w http.ResponseWriter, r *http.Request) { intermediates := mustAuthority(r.Context()).GetIntermediateCertificates() if len(intermediates) == 0 { render.Error(w, r, errs.NotImplemented("error getting intermediates: method not implemented")) return } w.Header().Set("Content-Type", "application/x-pem-file") for _, crt := range intermediates { block := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, }) if _, err := w.Write(block); err != nil { log.Error(w, r, err) return } } } // Federation returns all the public certificates in the federation. func Federation(w http.ResponseWriter, r *http.Request) { federated, err := mustAuthority(r.Context()).GetFederation() if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error getting federated roots")) return } certs := make([]Certificate, len(federated)) for i := range federated { certs[i] = Certificate{federated[i]} } render.JSONStatus(w, r, &FederationResponse{ Certificates: certs, }, http.StatusCreated) } var oidStepProvisioner = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1} type stepProvisioner struct { Type int Name []byte CredentialID []byte } func logOtt(w http.ResponseWriter, token string) { if rl, ok := w.(logging.ResponseLogger); ok { rl.WithFields(map[string]interface{}{ "ott": token, }) } } // LogCertificate adds certificate fields to the log message. 
func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) { if rl, ok := w.(logging.ResponseLogger); ok { m := map[string]interface{}{ "serial": cert.SerialNumber.String(), "subject": cert.Subject.CommonName, "issuer": cert.Issuer.CommonName, "sans": fmtSans(cert), "valid-from": cert.NotBefore.Format(time.RFC3339), "valid-to": cert.NotAfter.Format(time.RFC3339), "public-key": fmtPublicKey(cert), "certificate": base64.StdEncoding.EncodeToString(cert.Raw), } for _, ext := range cert.Extensions { if !ext.Id.Equal(oidStepProvisioner) { continue } val := &stepProvisioner{} rest, err := asn1.Unmarshal(ext.Value, val) if err != nil || len(rest) > 0 { break } if len(val.CredentialID) > 0 { m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID) } else { m["provisioner"] = string(val.Name) } break } rl.WithFields(m) } } // LogSSHCertificate adds SSH certificate fields to the log message. func LogSSHCertificate(w http.ResponseWriter, cert *ssh.Certificate) { if rl, ok := w.(logging.ResponseLogger); ok { mak := bytes.TrimSpace(ssh.MarshalAuthorizedKey(cert)) var certificate string parts := strings.Split(string(mak), " ") if len(parts) > 1 { certificate = parts[1] } var userOrHost string if cert.CertType == ssh.HostCert { userOrHost = "host" } else { userOrHost = "user" } certificateType := fmt.Sprintf("%s %s certificate", parts[0], userOrHost) // e.g. 
ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate m := map[string]interface{}{ "serial": cert.Serial, "principals": cert.ValidPrincipals, "valid-from": time.Unix(cast.Int64(cert.ValidAfter), 0).Format(time.RFC3339), "valid-to": time.Unix(cast.Int64(cert.ValidBefore), 0).Format(time.RFC3339), "certificate": certificate, "certificate-type": certificateType, } fingerprint, err := sshutil.FormatFingerprint(mak, sshutil.DefaultFingerprint) if err == nil { fpParts := strings.Split(fingerprint, " ") if len(fpParts) > 3 { m["public-key"] = fmt.Sprintf("%s %s", fpParts[1], fpParts[len(fpParts)-1]) } } rl.WithFields(m) } } // ParseCursor parses the cursor and limit from the request query params. func ParseCursor(r *http.Request) (cursor string, limit int, err error) { q := r.URL.Query() cursor = q.Get("cursor") if v := q.Get("limit"); v != "" { limit, err = strconv.Atoi(v) if err != nil { return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v) } } return } func fmtSans(cert *x509.Certificate) map[string][]string { sans := make(map[string][]string) if len(cert.DNSNames) > 0 { sans["dns"] = cert.DNSNames } if len(cert.EmailAddresses) > 0 { sans["email"] = cert.EmailAddresses } if size := len(cert.IPAddresses); size > 0 { ips := make([]string, size) for i, ip := range cert.IPAddresses { ips[i] = ip.String() } sans["ip"] = ips } if size := len(cert.URIs); size > 0 { uris := make([]string, size) for i, u := range cert.URIs { uris[i] = u.String() } sans["uri"] = uris } return sans } func fmtPublicKey(cert *x509.Certificate) string { var params string switch pk := cert.PublicKey.(type) { case *ecdsa.PublicKey: params = pk.Curve.Params().Name case *rsa.PublicKey: params = strconv.Itoa(pk.Size() * 8) case ed25519.PublicKey: return cert.PublicKeyAlgorithm.String() case *dsa.PublicKey: params = strconv.Itoa(pk.Q.BitLen() * 8) default: params = "unknown" } return fmt.Sprintf("%s %s", cert.PublicKeyAlgorithm, params) } 
================================================ FILE: api/api_test.go ================================================ package api import ( "bytes" "context" "crypto" "crypto/dsa" //nolint:staticcheck // support legacy algorithms "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "io" "math/big" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/minica" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" ) const ( rootPEM = `-----BEGIN CERTIFICATE----- MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n 
ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY /iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx -----END CERTIFICATE-----` certPEM = `-----BEGIN CERTIFICATE----- MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q 5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC 7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+ gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283 TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq 0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` csrPEM = `-----BEGIN CERTIFICATE REQUEST----- MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn 
wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu 7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2 HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H /qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0 hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x /Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs 0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+ g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8 zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4 2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU 58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` stepCertPEM = `-----BEGIN CERTIFICATE----- MIIChTCCAiugAwIBAgIRAJ3O5T28Rdj2lr/UPjf+GAUwCgYIKoZIzj0EAwIwJDEi MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0xOTAyMjAyMDE1 NDNaFw0xOTAyMjEyMDE1NDNaMHExCzAJBgNVBAYTAlVTMQswCQYDVQQIEwJDQTEW MBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEcMBoGA1UEChMTU21hbGxzdGVwIExhYnMg SW5jLjEfMB0GA1UEAxMWaW50ZXJuYWwuc21hbGxzdGVwLmNvbTBZMBMGByqGSM49 AgEGCCqGSM49AwEHA0IABC0aKrTNl+gXFuNkXisqX4/foLO3VMt+Kphngziim+fz aJhiS9JU+oFYLTNW6HWGUD8CNzfwrmWlVsAmiJwHKlKjgfAwge0wDgYDVR0PAQH/ BAQDAgWgMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4EFgQU 
JheKvlZqNv1IcgaC8WOS1Zg0i1QwHwYDVR0jBBgwFoAUu97PaFQPfuyKOeew7Hg4 5WFIAVMwIQYDVR0RBBowGIIWaW50ZXJuYWwuc21hbGxzdGVwLmNvbTBZBgwrBgEE AYKkZMYoQAEESTBHAgEBBBVtYXJpYW5vQHNtYWxsc3RlcC5jb20EK2pPMzdkdERi a3UtUW5hYnM1VlIwWXc2WUZGdjl3ZUExOGRwM2h0dmRFanMwCgYIKoZIzj0EAwID SAAwRQIhAIrn17fP5CBrGtKuhyPiq6eSwryBCf8ki+k17u5a+E/LAiB24Y2E0Put nIHOI54lAqDeF7A0y73fPRVCiJEWmuxz0g== -----END CERTIFICATE-----` pubKey = `{ "use": "sig", "kty": "EC", "kid": "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00", "crv": "P-256", "alg": "ES256", "x": "p9QX4tzjxUrB0fgqRWLKUuPolDtBW681f2Qyh-uVNhk", "y": "CNSEloc4oLDFTX0Vywj0WiqOlh516sFQwCj6WtM8LT8" }` privKey = "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiNEhBYjE0WDQ5OFM4LWxSb29JTnpqZyJ9.RbkJXGzI3kOsaP20KmZs0ELFLgpRddAE49AJHlEblw-uH_gg6SV3QA.M3MArEpHgI171lhm.gBlFySpzK9F7riBJbtLSNkb4nAw_gWokqs1jS-ZK1qxuqTK-9mtX5yILjRnftx9P9uFp5xt7rvv4Mgom1Ed4V9WtIyfNP_Cz3Pme1Eanp5nY68WCe_yG6iSB1RJdMDBUb2qBDZiBdhJim1DRXsOfgedOrNi7GGbppMlD77DEpId118owR5izA-c6Q_hg08hIE3tnMAnebDNQoF9jfEY99_AReVRH8G4hgwZEPCfXMTb3J-lowKGG4vXIbK5knFLh47SgOqG4M2M51SMS-XJ7oBz1Vjoamc90QIqKV51rvZ5m0N_sPFtxzcfV4E9yYH3XVd4O-CG4ydVKfKVyMtQ.mcKFZqBHp_n7Ytj2jz9rvw" ) func mustJSON(t *testing.T, v any) []byte { t.Helper() var buf bytes.Buffer require.NoError(t, json.NewEncoder(&buf).Encode(v)) return buf.Bytes() } func parseCertificate(data string) *x509.Certificate { block, _ := pem.Decode([]byte(data)) if block == nil { panic("failed to parse certificate PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { panic("failed to parse certificate: " + err.Error()) } return cert } func parseCertificateRequest(data string) *x509.CertificateRequest { block, _ := pem.Decode([]byte(data)) if block == nil { panic("failed to parse certificate request PEM") } csr, err := x509.ParseCertificateRequest(block.Bytes) if err != nil { panic("failed to parse certificate request: " + err.Error()) } return csr } func mockMustAuthority(t 
*testing.T, a Authority) { t.Helper() fn := mustAuthority t.Cleanup(func() { mustAuthority = fn }) mustAuthority = func(ctx context.Context) Authority { return a } } type mockAuthority struct { ret1, ret2 interface{} err error authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByName func(name string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) revoke func(context.Context, *authority.RevokeOptions) error getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getIntermediateCertificates func() []*x509.Certificate getFederation func() ([]*x509.Certificate, error) getCRL func() (*authority.CertificateRevocationListInfo, error) signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) getSSHHosts func(ctx 
context.Context, cert *x509.Certificate) ([]authority.Host, error) getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) checkSSHHost func(ctx context.Context, principal, token string) (bool, error) getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) version func() authority.Version } func (m *mockAuthority) GetCertificateRevocationList() (*authority.CertificateRevocationListInfo, error) { if m.getCRL != nil { return m.getCRL() } return m.ret1.(*authority.CertificateRevocationListInfo), m.err } // TODO: remove once Authorize is deprecated. func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.authorize != nil { return m.authorize(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { if m.authorizeRenewToken != nil { return m.authorizeRenewToken(ctx, ott) } return m.ret1.(*x509.Certificate), m.err } func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { if m.getTLSOptions != nil { return m.getTLSOptions() } return m.ret1.(*authority.TLSOptions) } func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { if m.root != nil { return m.root(shasum) } return m.ret1.(*x509.Certificate), m.err } func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signWithContext != nil { return m.signWithContext(ctx, cr, opts, signOpts...) 
} return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.renew != nil { return m.renew(cert) } return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { if m.renewContext != nil { return m.renewContext(ctx, oldcert, pk) } return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { if m.rekey != nil { return m.rekey(oldcert, pk) } return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { if m.getProvisioners != nil { return m.getProvisioners(nextCursor, limit) } return m.ret1.(provisioner.List), m.ret2.(string), m.err } func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { if m.loadProvisionerByCertificate != nil { return m.loadProvisionerByCertificate(cert) } return m.ret1.(provisioner.Interface), m.err } func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { if m.loadProvisionerByName != nil { return m.loadProvisionerByName(name) } return m.ret1.(provisioner.Interface), m.err } func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { if m.revoke != nil { return m.revoke(ctx, opts) } return m.err } func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { if m.getEncryptedKey != nil { return m.getEncryptedKey(kid) } return m.ret1.(string), m.err } func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { if m.getRoots != nil { return m.getRoots() } 
return m.ret1.([]*x509.Certificate), m.err } func (m *mockAuthority) GetIntermediateCertificates() []*x509.Certificate { if m.getIntermediateCertificates != nil { return m.getIntermediateCertificates() } return m.ret1.([]*x509.Certificate) } func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { if m.getFederation != nil { return m.getFederation() } return m.ret1.([]*x509.Certificate), m.err } func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { if m.signSSH != nil { return m.signSSH(ctx, key, opts, signOpts...) } return m.ret1.(*ssh.Certificate), m.err } func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { if m.signSSHAddUser != nil { return m.signSSHAddUser(ctx, key, cert) } return m.ret1.(*ssh.Certificate), m.err } func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { if m.renewSSH != nil { return m.renewSSH(ctx, cert) } return m.ret1.(*ssh.Certificate), m.err } func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { if m.rekeySSH != nil { return m.rekeySSH(ctx, cert, key, signOpts...) 
} return m.ret1.(*ssh.Certificate), m.err } func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { if m.getSSHHosts != nil { return m.getSSHHosts(ctx, cert) } return m.ret1.([]authority.Host), m.err } func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { if m.getSSHRoots != nil { return m.getSSHRoots(ctx) } return m.ret1.(*authority.SSHKeys), m.err } func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { if m.getSSHFederation != nil { return m.getSSHFederation(ctx) } return m.ret1.(*authority.SSHKeys), m.err } func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { if m.getSSHConfig != nil { return m.getSSHConfig(ctx, typ, data) } return m.ret1.([]templates.Output), m.err } func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { if m.checkSSHHost != nil { return m.checkSSHHost(ctx, principal, token) } return m.ret1.(bool), m.err } func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { if m.getSSHBastion != nil { return m.getSSHBastion(ctx, user, hostname) } return m.ret1.(*authority.Bastion), m.err } func (m *mockAuthority) Version() authority.Version { if m.version != nil { return m.version() } return m.ret1.(authority.Version) } func TestNewCertificate(t *testing.T) { cert := parseCertificate(rootPEM) if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) { t.Errorf("NewCertificate failed, got %v, wants %v", NewCertificate(cert), Certificate{Certificate: cert}) } } func TestCertificate_MarshalJSON(t *testing.T) { type fields struct { Certificate *x509.Certificate } tests := []struct { name string fields fields want []byte wantErr bool }{ {"nil", fields{Certificate: nil}, []byte("null"), false}, {"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, 
[]byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
		{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
		{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := Certificate{
				Certificate: tt.fields.Certificate,
			}
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("Certificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Certificate.MarshalJSON() = %s, want %s", got, tt.want)
			}
		})
	}
}

// TestCertificate_UnmarshalJSON calls Certificate.UnmarshalJSON directly with
// raw byte inputs and checks whether an error is returned and, for the valid
// cases, that a certificate was populated.
func TestCertificate_UnmarshalJSON(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		wantCert bool
		wantErr  bool
	}{
		{"no data", nil, false, true},
		{"incomplete string 1", []byte(`"foobar`), false, true},
		{"incomplete string 2", []byte(`foobar"`), false, true},
		{"invalid string", []byte(`"foobar"`), false, true},
		{"invalid bytes 0", []byte{}, false, true},
		{"invalid bytes 1", []byte{1}, false, true},
		// The END marker below has only four trailing dashes.
		{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
		// A CSR PEM is not accepted where a certificate is expected.
		{"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
		{"empty string", []byte(`""`), false, false},
		{"json null", []byte(`null`), false, false},
		{"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
		{"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var c Certificate
			if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
				t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantCert && c.Certificate == nil {
				t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil")
			}
		})
	}
}

// TestCertificate_UnmarshalJSON_json decodes a Certificate through
// json.Unmarshal as the "crt" field of a wrapping struct, covering non-string
// JSON types as well as null, empty and valid values.
func TestCertificate_UnmarshalJSON_json(t *testing.T) {
	tests := []struct {
		name     string
		data     string
		wantCert bool
		wantErr  bool
	}{
		{"invalid type (bool)", `{"crt":true}`, false, true},
		{"invalid type (number)", `{"crt":123}`, false, true},
		{"invalid type (object)", `{"crt":{}}`, false, true},
		{"empty crt (null)", `{"crt":null}`, false, false},
		{"empty crt (string)", `{"crt":""}`, false, false},
		{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
		{"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
	}

	type request struct {
		Cert Certificate `json:"crt"`
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var body request
			if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
				t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
			}

			// The certificate must be set exactly when the case expects one.
			switch tt.wantCert {
			case true:
				if body.Cert.Certificate == nil {
					t.Error("json.Unmarshal() failed, Certificate is nil")
				}
			case false:
				if body.Cert.Certificate != nil {
					t.Error("json.Unmarshal() failed, Certificate is not nil")
				}
			}
		})
	}
}

// TestNewCertificateRequest checks that NewCertificateRequest wraps the parsed
// CSR in a CertificateRequest value.
func TestNewCertificateRequest(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	if !reflect.DeepEqual(CertificateRequest{CertificateRequest: csr}, NewCertificateRequest(csr)) {
		t.Errorf("NewCertificateRequest failed, got %v, wants %v", NewCertificateRequest(csr), CertificateRequest{CertificateRequest: csr})
	}
}

// TestCertificateRequest_MarshalJSON compares CertificateRequest.MarshalJSON
// output against the expected JSON-quoted PEM for nil, empty and parsed CSRs.
func TestCertificateRequest_MarshalJSON(t *testing.T) {
	type fields struct {
		CertificateRequest *x509.CertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		want    []byte
		wantErr bool
	}{
		{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
		{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
		{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c :=
CertificateRequest{
				CertificateRequest: tt.fields.CertificateRequest,
			}
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("CertificateRequest.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("CertificateRequest.MarshalJSON() = %s, want %s", got, tt.want)
			}
		})
	}
}

// TestCertificateRequest_UnmarshalJSON calls CertificateRequest.UnmarshalJSON
// directly with raw byte inputs and checks whether an error is returned and,
// for the valid case, that a certificate request was populated.
func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		wantCert bool
		wantErr  bool
	}{
		{"no data", nil, false, true},
		{"incomplete string 1", []byte(`"foobar`), false, true},
		{"incomplete string 2", []byte(`foobar"`), false, true},
		{"invalid string", []byte(`"foobar"`), false, true},
		{"invalid bytes 0", []byte{}, false, true},
		{"invalid bytes 1", []byte{1}, false, true},
		// The END marker below has only four trailing dashes.
		{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
		// A certificate PEM is not accepted where a CSR is expected.
		{"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
		{"empty string", []byte(`""`), false, false},
		{"json null", []byte(`null`), false, false},
		{"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var c CertificateRequest
			if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
				t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantCert && c.CertificateRequest == nil {
				t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequest is nil")
			}
		})
	}
}

// TestCertificateRequest_UnmarshalJSON_json decodes a CertificateRequest
// through json.Unmarshal as the "csr" field of a wrapping struct, covering
// non-string JSON types as well as null, empty and valid values.
func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
	tests := []struct {
		name     string
		data     string
		wantCert bool
		wantErr  bool
	}{
		{"invalid type (bool)", `{"csr":true}`, false, true},
		{"invalid type (number)", `{"csr":123}`, false, true},
		{"invalid type (object)", `{"csr":{}}`, false, true},
		{"empty csr (null)", `{"csr":null}`, false, false},
		{"empty csr (string)", `{"csr":""}`, false, false},
		{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
		{"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
	}

	type request struct {
		CSR CertificateRequest `json:"csr"`
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var body request
			if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
				t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
			}

			// The request must be set exactly when the case expects one.
			switch tt.wantCert {
			case true:
				if body.CSR.CertificateRequest == nil {
					t.Error("json.Unmarshal() failed, CertificateRequest is nil")
				}
			case false:
				if body.CSR.CertificateRequest != nil {
					t.Error("json.Unmarshal() failed, CertificateRequest is not nil")
				}
			}
		})
	}
}

// TestSignRequest_Validate runs SignRequest.Validate against requests with a
// missing CSR, a CSR whose signature has been altered, and a missing OTT, and
// checks that the returned error message starts with the expected text.
func TestSignRequest_Validate(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	bad := parseCertificateRequest(csrPEM)
	// Alter the first signature byte of the second copy.
	bad.Signature[0]++
	type fields struct {
		CsrPEM    CertificateRequest
		OTT       string
		NotBefore time.Time
		NotAfter  time.Time
	}
	tests := []struct {
		name   string
		fields fields
		err    error
	}{
		{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
		{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
		{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &SignRequest{
				CsrPEM:    tt.fields.CsrPEM,
				OTT:       tt.fields.OTT,
				NotAfter:  NewTimeDuration(tt.fields.NotAfter),
				NotBefore: NewTimeDuration(tt.fields.NotBefore),
			}
			if err := s.Validate(); err != nil {
				if assert.NotNil(t, tt.err) {
					assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error()))
				}
			} else {
				assert.Nil(t, tt.err)
			}
		})
	}
}

// mockProvisioner is a test double whose methods call the matching function
// field when it is set and otherwise return the canned ret1/ret2/ret3 and err
// values.
type mockProvisioner struct {
	ret1, ret2, ret3 interface{}
	err              error
	getID            func() string
	getIDForToken    func() string
	getTokenID       func(string) (string, error)
	getName          func() string
	getType          func() provisioner.Type
	getEncryptedKey  func() (string, string,
bool) init func(provisioner.Config) error authorizeRenew func(ctx context.Context, cert *x509.Certificate) error authorizeRevoke func(ctx context.Context, token string) error authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewal func(*x509.Certificate) error authorizeSSHSign func(ctx context.Context, token string) ([]provisioner.SignOption, error) authorizeSSHRevoke func(ctx context.Context, token string) error authorizeSSHRenew func(ctx context.Context, token string) (*ssh.Certificate, error) authorizeSSHRekey func(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) } func (m *mockProvisioner) GetID() string { if m.getID != nil { return m.getID() } return m.ret1.(string) } func (m *mockProvisioner) GetIDForToken() string { if m.getIDForToken != nil { return m.getIDForToken() } return m.ret1.(string) } func (m *mockProvisioner) GetTokenID(token string) (string, error) { if m.getTokenID != nil { return m.getTokenID(token) } if m.ret1 == nil { return "", m.err } return m.ret1.(string), m.err } func (m *mockProvisioner) GetName() string { if m.getName != nil { return m.getName() } return m.ret1.(string) } func (m *mockProvisioner) GetType() provisioner.Type { if m.getType != nil { return m.getType() } return m.ret1.(provisioner.Type) } func (m *mockProvisioner) GetEncryptedKey() (string, string, bool) { if m.getEncryptedKey != nil { return m.getEncryptedKey() } return m.ret1.(string), m.ret2.(string), m.ret3.(bool) } func (m *mockProvisioner) Init(c provisioner.Config) error { if m.init != nil { return m.init(c) } return m.err } func (m *mockProvisioner) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if m.authorizeRenew != nil { return m.authorizeRenew(ctx, cert) } return m.err } func (m *mockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { if m.authorizeRevoke != nil { return m.authorizeRevoke(ctx, token) } return m.err } func (m 
*mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.authorizeSign != nil { return m.authorizeSign(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } func (m *mockProvisioner) AuthorizeRenewal(c *x509.Certificate) error { if m.authorizeRenewal != nil { return m.authorizeRenewal(c) } return m.err } func (m *mockProvisioner) AuthorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { if m.authorizeSSHSign != nil { return m.authorizeSSHSign(ctx, token) } return m.ret1.([]provisioner.SignOption), m.err } func (m *mockProvisioner) AuthorizeSSHRevoke(ctx context.Context, token string) error { if m.authorizeSSHRevoke != nil { return m.authorizeSSHRevoke(ctx, token) } return m.err } func (m *mockProvisioner) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { if m.authorizeSSHRenew != nil { return m.authorizeSSHRenew(ctx, token) } return m.ret1.(*ssh.Certificate), m.err } func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) { if m.authorizeSSHRekey != nil { return m.authorizeSSHRekey(ctx, token) } return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err } func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority } type args struct { r Router } tests := []struct { name string fields fields args args }{ {"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &caHandler{ Authority: tt.fields.Authority, } h.Route(tt.args.r) }) } } func Test_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", http.NoBody) w := httptest.NewRecorder() Health(w, req) res := w.Result() if res.StatusCode != 200 { t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode) } body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { 
t.Errorf("caHandler.Health unexpected error = %v", err)
	}

	expected := []byte("{\"status\":\"ok\"}\n")
	if !bytes.Equal(body, expected) {
		t.Errorf("caHandler.Health Body = %s, wants %s", body, expected)
	}
}

// Test_Root calls the Root handler with a chi route context carrying the
// "sha" URL parameter and checks the status code plus either the JSON body
// (200) or the error message contained in the body (404).
func Test_Root(t *testing.T) {
	const sha = "efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36"
	tests := []struct {
		name        string
		root        *x509.Certificate
		err         error
		expectedMsg string
		statusCode  int
	}{
		{"ok", parseCertificate(rootPEM), nil, "", 200},
		{"fail", nil, fmt.Errorf("not found"), fmt.Sprintf(`root certificate with fingerprint \"%s\" was not found`, sha), 404},
	}

	// Request with chi context
	chiCtx := chi.NewRouteContext()
	chiCtx.URLParams.Add("sha", sha)
	req := httptest.NewRequest("GET", "http://example.com/root/"+sha, http.NoBody)
	req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))

	expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
			w := httptest.NewRecorder()
			Root(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Root unexpected error = %v", err)
			}
			if tt.statusCode == 200 {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
				}
			} else {
				require.Contains(t, string(body), tt.expectedMsg)
			}
		})
	}
}

// Test_Sign posts sign requests to the Sign handler and checks the status
// code and, for successful responses, the exact JSON body. Failure messages
// name the Sign handler so a failing run points at the right test.
func Test_Sign(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	valid, err := json.Marshal(SignRequest{
		CsrPEM: CertificateRequest{csr},
		OTT:    "foobarzar",
	})
	require.NoError(t, err)
	// Same request with an empty OTT, used by the "validate error" case.
	invalid, err := json.Marshal(SignRequest{
		CsrPEM: CertificateRequest{csr},
		OTT:    "",
	})
	require.NoError(t, err)

	expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
	expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	tests := []struct {
		name         string
		input        string
		certAttrOpts []provisioner.SignOption
		autherr      error
		cert         *x509.Certificate
		root         *x509.Certificate
		signErr      error
		statusCode   int
		expected     []byte
	}{
		{"ok", string(valid), nil, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected1},
		{"ok with Provisioner", string(valid), nil, nil, parseCertificate(stepCertPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected2},
		{"json read error", "{", nil, nil, nil, nil, nil, http.StatusBadRequest, nil},
		{"validate error", string(invalid), nil, nil, nil, nil, nil, http.StatusBadRequest, nil},
		{"authorize error", string(valid), nil, fmt.Errorf("an error"), nil, nil, nil, http.StatusUnauthorized, nil},
		{"sign error", string(valid), nil, nil, nil, nil, fmt.Errorf("an error"), http.StatusForbidden, nil},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.signErr,
				authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
					return tt.certAttrOpts, tt.autherr
				},
				getTLSOptions: func() *authority.TLSOptions {
					return nil
				},
			})
			req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
			w := httptest.NewRecorder()
			Sign(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Sign StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Sign unexpected error = %v", err)
			}
			// Only successful responses have their body compared.
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
					t.Errorf("caHandler.Sign Body = %s, wants %s", body, tt.expected)
				}
			}
		})
	}
}

// Test_Renew exercises the Renew handler both with a TLS peer certificate and
// with an Authorization header carrying a token signed by the key of a leaf
// certificate that has already expired.
func Test_Renew(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	}

	// Prepare root and leaf for renew after expiry test.
	now := time.Now()
	rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	root := &x509.Certificate{
		Subject:               pkix.Name{CommonName: "Test Root CA"},
		PublicKey:             rootPub,
		KeyUsage:              x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
		NotBefore:             now.Add(-2 * time.Hour),
		NotAfter:              now.Add(time.Hour),
	}
	root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
	if err != nil {
		t.Fatal(err)
	}
	expiredLeaf := &x509.Certificate{
		Subject:        pkix.Name{CommonName: "Leaf certificate"},
		PublicKey:      leafPub,
		KeyUsage:       x509.KeyUsageDigitalSignature,
		ExtKeyUsage:    []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		NotBefore:      now.Add(-time.Hour),
		NotAfter:       now.Add(-time.Minute),
		EmailAddresses: []string{"test@example.org"},
	}
	expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
	if err != nil {
		t.Fatal(err)
	}

	// Generate renew after expiry token
	so := new(jose.SignerOptions)
	so.WithType("JWT")
	so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
	if err != nil {
		t.Fatal(err)
	}
	generateX5cToken := func(claims jose.Claims) string {
		s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
		if err != nil {
			t.Fatal(err)
		}
		return s
	}

	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		header     http.Header
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", cs, nil, parseCertificate(certPEM),
parseCertificate(rootPEM), nil, http.StatusCreated}, {"ok renew after expiry", &tls.ConnectionState{}, http.Header{ "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), })}, }, expiredLeaf, root, nil, http.StatusCreated}, {"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest}, {"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, {"fail expired token", &tls.ConnectionState{}, http.Header{ "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), })}, }, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized}, {"fail invalid root", &tls.ConnectionState{}, http.Header{ "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), })}, }, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) if err != nil { return nil, errs.Unauthorized(err.Error()) } var claims jose.Claims if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil { return nil, errs.Unauthorized(err.Error()) } if err := claims.ValidateWithLeeway(jose.Expected{ Time: now, }, time.Minute); err != nil { return nil, errs.Unauthorized(err.Error()) } return chain[0][0], nil }, getTLSOptions: func() *authority.TLSOptions { return nil }, }) req := httptest.NewRequest("POST", 
"http://example.com/renew", http.NoBody) req.TLS = tt.tls req.Header = tt.header w := httptest.NewRecorder() Renew(logging.NewResponseLogger(w), req) res := w.Result() defer res.Body.Close() body, err := io.ReadAll(res.Body) if err != nil { t.Errorf("caHandler.Renew unexpected error = %v", err) } if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("%s", body) } if tt.statusCode < http.StatusBadRequest { expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` + `"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` + `"certChain":["` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`) if !bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected) } } }) } } func Test_Rekey(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } csr := parseCertificateRequest(csrPEM) valid, err := json.Marshal(RekeyRequest{ CsrPEM: CertificateRequest{csr}, }) if err != nil { t.Fatal(err) } tests := []struct { name string input string tls *tls.ConnectionState cert *x509.Certificate root *x509.Certificate err error statusCode int }{ {"ok", string(valid), cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"no tls", string(valid), nil, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", string(valid), &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, {"rekey error", string(valid), cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, {"json read error", "{", cs, nil, nil, nil, 
http.StatusBadRequest},
	}

	expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.err,
				getTLSOptions: func() *authority.TLSOptions {
					return nil
				},
			})
			req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			Rekey(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Rekey unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Rekey Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_Provisioners calls the Provisioners handler with a valid cursor/limit
// query and with a non-integer limit, and checks the status code together with
// the marshaled provisioner list (200) or the marshaled error (400, 500).
func Test_Provisioners(t *testing.T) {
	type fields struct {
		Authority Authority
	}
	type args struct {
		w http.ResponseWriter
		r *http.Request
	}

	req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", http.NoBody)
	if err != nil {
		t.Fatal(err)
	}

	reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", http.NoBody)
	if err != nil {
		t.Fatal(err)
	}

	var key jose.JSONWebKey
	if err := json.Unmarshal([]byte(pubKey), &key); err != nil {
		t.Fatal(err)
	}

	p := provisioner.List{
		&provisioner.JWK{
			Type:         "JWK",
			Name:         "max",
			EncryptedKey: "abc",
			Key:          &key,
		},
		&provisioner.JWK{
			Type:         "JWK",
			Name:         "mariano",
			EncryptedKey: "def",
			Key:          &key,
		},
	}
	pr := ProvisionersResponse{
		Provisioners: p,
	}

	tests := []struct {
		name       string
		fields     fields
		args       args
		statusCode int
	}{
		{"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200},
		{"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500},
		{"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400},
	}

	expected, err := json.Marshal(pr)
	if err != nil {
		t.Fatal(err)
	}

	expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
	expectedError400Bytes, err := json.Marshal(expectedError400)
	require.NoError(t, err)
	expectedError500 := errs.InternalServer("force")
	expectedError500Bytes, err := json.Marshal(expectedError500)
	require.NoError(t, err)
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, tt.fields.Authority)
			Provisioners(tt.args.w, tt.args.r)

			// Every case supplies an httptest recorder as its writer.
			rec := tt.args.w.(*httptest.ResponseRecorder)
			res := rec.Result()
			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}
			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Provisioners unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
				}
			} else {
				switch tt.statusCode {
				case 400:
					if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) {
						t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes)
					}
				case 500:
					if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) {
						t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes)
					}
				default:
					t.Errorf("caHandler.Provisioners unexpected status code = %d", tt.statusCode)
				}
			}
		})
	}
}

// Test_ProvisionerKey calls the ProvisionerKey handler with a chi route
// context carrying the "kid" URL parameter and checks the status code and
// body for a found key (200) and a lookup error (404). Failure messages name
// the ProvisionerKey handler so they are not confused with Test_Provisioners.
func Test_ProvisionerKey(t *testing.T) {
	type fields struct {
		Authority Authority
	}
	type args struct {
		w http.ResponseWriter
		r *http.Request
	}

	// Request with chi context
	chiCtx := chi.NewRouteContext()
	chiCtx.URLParams.Add("kid", "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00")
	req := httptest.NewRequest("GET", "http://example.com/provisioners/oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00/encrypted-key", http.NoBody)
	req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))

	tests := []struct {
		name       string
		fields     fields
		args       args
		statusCode int
	}{
		{"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200},
		{"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404},
	}

	expected := []byte(`{"key":"` + privKey + `"}`)
	expectedError404 := errs.NotFound("force")
	expectedError404Bytes, err := json.Marshal(expectedError404)
	require.NoError(t, err)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, tt.fields.Authority)
			ProvisionerKey(tt.args.w, tt.args.r)

			// Every case supplies an httptest recorder as its writer.
			rec := tt.args.w.(*httptest.ResponseRecorder)
			res := rec.Result()
			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.ProvisionerKey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}
			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.ProvisionerKey unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.ProvisionerKey Body = %s, wants %s", body, expected)
				}
			} else {
				if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) {
					t.Errorf("caHandler.ProvisionerKey Body = %s, wants %s", body, expectedError404Bytes)
				}
			}
		})
	}
}

// Test_Roots calls the Roots handler and checks the status code and, on
// success, the JSON body listing the root certificate.
func Test_Roots(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	}
	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"fail", cs, nil, nil, fmt.Errorf("an error"),
http.StatusForbidden},
	}

	expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
			req := httptest.NewRequest("GET", "http://example.com/roots", http.NoBody)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			Roots(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Roots unexpected error = %v", err)
			}
			// Only successful responses have their body compared.
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_caHandler_RootsPEM calls the RootsPEM handler with one root, two roots
// and an authority error, and compares the trimmed response body with the
// expected PEM text for the successful cases.
func Test_caHandler_RootsPEM(t *testing.T) {
	parsedRoot := parseCertificate(rootPEM)
	tests := []struct {
		name       string
		roots      []*x509.Certificate
		err        error
		statusCode int
		expect     string
	}{
		{"one root", []*x509.Certificate{parsedRoot}, nil, http.StatusOK, rootPEM},
		{"two roots", []*x509.Certificate{parsedRoot, parsedRoot}, nil, http.StatusOK, rootPEM + "\n" + rootPEM},
		{"fail", nil, errors.New("an error"), http.StatusInternalServerError, ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
			req := httptest.NewRequest("GET", "https://example.com/roots", http.NoBody)
			w := httptest.NewRecorder()
			RootsPEM(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.RootsPEM StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.RootsPEM unexpected error = %v", err)
			}
			// Only successful responses have their body compared.
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), []byte(tt.expect)) {
					t.Errorf("caHandler.RootsPEM Body = %s, wants %s", body, tt.expect)
				}
			}
		})
	}
}

// Test_Federation calls the Federation handler and checks the status code and,
// on success, the JSON body listing the federated root certificate.
func Test_Federation(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	}
	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
	}

	expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
			req := httptest.NewRequest("GET", "http://example.com/federation", http.NoBody)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			Federation(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Federation unexpected error = %v", err)
			}
			// Only successful responses have their body compared.
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_fmtPublicKey checks the description fmtPublicKey returns for ECDSA,
// RSA, Ed25519 and DSA keys, and for a certificate whose PublicKey value does
// not match its declared algorithm.
func Test_fmtPublicKey(t *testing.T) {
	p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}
	edPub, edPriv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	var dsa2048 dsa.PrivateKey
	if err := dsa.GenerateParameters(&dsa2048.Parameters, rand.Reader, dsa.L2048N256); err != nil {
		t.Fatal(err)
	}
	if err := dsa.GenerateKey(&dsa2048, rand.Reader); err != nil {
		t.Fatal(err)
	}

	type args struct {
		pub, priv interface{}
		cert      *x509.Certificate
	}
	tests := []struct {
		name string
		args args
		want string
	}{
{"p256", args{p256.Public(), p256, nil}, "ECDSA P-256"}, {"rsa2048", args{rsa2048.Public(), rsa2048, nil}, "RSA 2048"}, {"ed25519", args{edPub, edPriv, nil}, "Ed25519"}, {"dsa2048", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.DSA, PublicKey: &dsa2048.PublicKey}}, "DSA 2048"}, {"unknown", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.ECDSA, PublicKey: []byte("12345678")}}, "ECDSA unknown"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var cert *x509.Certificate if tt.args.cert != nil { cert = tt.args.cert } else { cert = mustCertificate(t, tt.args.pub, tt.args.priv) } if got := fmtPublicKey(cert); got != tt.want { t.Errorf("fmtPublicKey() = %v, want %v", got, tt.want) } }) } } func mustCertificate(t *testing.T, pub, priv interface{}) *x509.Certificate { template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ Organization: []string{"Acme Co"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } der, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv) if err != nil { t.Fatal(err) } cert, err := x509.ParseCertificate(der) if err != nil { t.Fatal(err) } return cert } func TestProvisionersResponse_MarshalJSON(t *testing.T) { k := map[string]any{ "use": "sig", "kty": "EC", "kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", "crv": "P-256", "alg": "ES256", "x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8", "y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y", } key := jose.JSONWebKey{} b, err := json.Marshal(k) require.NoError(t, err) err = json.Unmarshal(b, &key) require.NoError(t, err) r := ProvisionersResponse{ Provisioners: provisioner.List{ &provisioner.SCEP{ Name: "scep", Type: "scep", ChallengePassword: "not-so-secret", MinimumPublicKeyLength: 2048, EncryptionAlgorithmIdentifier: 2, IncludeRoot: true, 
ExcludeIntermediate: true, DecrypterCertificate: []byte{1, 2, 3, 4}, DecrypterKeyPEM: []byte{5, 6, 7, 8}, DecrypterKeyURI: "softkms:path=/path/to/private.key", DecrypterKeyPassword: "super-secret-password", }, &provisioner.JWK{ EncryptedKey: "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg", Key: &key, Name: "step-cli", Type: "JWK", }, }, NextCursor: "next", } expected := map[string]any{ "provisioners": []map[string]any{ { "type": "scep", "name": "scep", "forceCN": false, "includeRoot": true, "excludeIntermediate": true, "challenge": "*** REDACTED ***", "decrypterCertificate": []byte("*** REDACTED ***"), "decrypterKey": "*** REDACTED ***", "decrypterKeyPEM": []byte("*** REDACTED ***"), "decrypterKeyPassword": "*** REDACTED ***", "minimumPublicKeyLength": 2048, "encryptionAlgorithmIdentifier": 2, }, { "type": "JWK", "name": "step-cli", "key": map[string]any{ "use": "sig", "kty": "EC", "kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", "crv": "P-256", "alg": "ES256", "x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8", "y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y", }, "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg", }, }, "nextCursor": "next", } expBytes, err := json.Marshal(expected) assert.NoError(t, err) br, err := r.MarshalJSON() assert.NoError(t, err) assert.JSONEq(t, string(expBytes), string(br)) keyCopy := key expList := provisioner.List{ &provisioner.SCEP{ Name: "scep", Type: "scep", ChallengePassword: "not-so-secret", MinimumPublicKeyLength: 2048, EncryptionAlgorithmIdentifier: 2, IncludeRoot: true, ExcludeIntermediate: true, DecrypterCertificate: []byte{1, 2, 3, 4}, DecrypterKeyPEM: []byte{5, 6, 7, 8}, DecrypterKeyURI: "softkms:path=/path/to/private.key", DecrypterKeyPassword: "super-secret-password", }, &provisioner.JWK{ EncryptedKey: "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg", Key: &keyCopy, Name: "step-cli", Type: "JWK", }, } // MarshalJSON must not affect the struct properties itself assert.Equal(t, expList, r.Provisioners) } const ( fixtureECDSACertificate = `ecdsa-sha2-nistp256-cert-v01@openssh.com 
AAAAKGVjZHNhLXNoYTItbmlzdHAyNTYtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLnkvSk4odlo3b1R+RDw+LmorL3RkN354IilCIVFVen4AAAAIbmlzdHAyNTYAAABBBHjKHss8WM2ffMYlavisoLXR0I6UEIU+cidV1ogEH1U6+/SYaFPrlzQo0tGLM5CNkMbhInbyasQsrHzn8F1Rt7nHg5/tcSf9qwAAAAEAAAAGaGVybWFuAAAACgAAAAZoZXJtYW4AAAAAY8kvJwAAAABjyhBjAAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAGgAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAAhuaXN0cDI1NgAAAEEE/ayqpPrZZF5uA1UlDt4FreTf15agztQIzpxnWq/XoxAHzagRSkFGkdgFpjgsfiRpP8URHH3BZScqc0ZDCTxhoQAAAGQAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAEkAAAAhAJuP1wCVwoyrKrEtHGfFXrVbRHySDjvXtS1tVTdHyqymAAAAIBa/CSSzfZb4D2NLP+eEmOOMJwSjYOiNM8fiOoAaqglI herman` ) func TestLogSSHCertificate(t *testing.T) { out, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fixtureECDSACertificate)) require.NoError(t, err) cert, ok := out.(*ssh.Certificate) require.True(t, ok) w := httptest.NewRecorder() rl := logging.NewResponseLogger(w) LogSSHCertificate(rl, cert) assert.Equal(t, 200, w.Result().StatusCode) fields := rl.Fields() assert.Equal(t, uint64(14376510277651266987), fields["serial"]) assert.Equal(t, []string{"herman"}, fields["principals"]) assert.Equal(t, "ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate", fields["certificate-type"]) assert.Equal(t, time.Unix(1674129191, 0).Format(time.RFC3339), fields["valid-from"]) assert.Equal(t, time.Unix(1674186851, 0).Format(time.RFC3339), fields["valid-to"]) assert.Equal(t, 
"AAAAKGVjZHNhLXNoYTItbmlzdHAyNTYtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLnkvSk4odlo3b1R+RDw+LmorL3RkN354IilCIVFVen4AAAAIbmlzdHAyNTYAAABBBHjKHss8WM2ffMYlavisoLXR0I6UEIU+cidV1ogEH1U6+/SYaFPrlzQo0tGLM5CNkMbhInbyasQsrHzn8F1Rt7nHg5/tcSf9qwAAAAEAAAAGaGVybWFuAAAACgAAAAZoZXJtYW4AAAAAY8kvJwAAAABjyhBjAAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAGgAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAAhuaXN0cDI1NgAAAEEE/ayqpPrZZF5uA1UlDt4FreTf15agztQIzpxnWq/XoxAHzagRSkFGkdgFpjgsfiRpP8URHH3BZScqc0ZDCTxhoQAAAGQAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAEkAAAAhAJuP1wCVwoyrKrEtHGfFXrVbRHySDjvXtS1tVTdHyqymAAAAIBa/CSSzfZb4D2NLP+eEmOOMJwSjYOiNM8fiOoAaqglI", fields["certificate"]) assert.Equal(t, "SHA256:RvkDPGwl/G9d7LUFm1kmWhvOD9I/moPq4yxcb0STwr0 (ECDSA-CERT)", fields["public-key"]) } func TestIntermediates(t *testing.T) { ca, err := minica.New() require.NoError(t, err) getRequest := func(t *testing.T, crt []*x509.Certificate) *http.Request { mockMustAuthority(t, &mockAuthority{ ret1: crt, }) return httptest.NewRequest("GET", "/intermediates", http.NoBody) } type args struct { crts []*x509.Certificate } tests := []struct { name string args args wantStatusCode int wantBody []byte }{ {"ok", args{[]*x509.Certificate{ca.Intermediate}}, http.StatusCreated, mustJSON(t, IntermediatesResponse{ Certificates: []Certificate{{ca.Intermediate}}, })}, {"ok multiple", args{[]*x509.Certificate{ca.Root, ca.Intermediate}}, http.StatusCreated, mustJSON(t, IntermediatesResponse{ Certificates: []Certificate{{ca.Root}, {ca.Intermediate}}, })}, {"fail", args{}, http.StatusNotImplemented, mustJSON(t, errs.NotImplemented("not implemented"))}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() r := getRequest(t, tt.args.crts) Intermediates(w, r) assert.Equal(t, tt.wantStatusCode, w.Result().StatusCode) assert.Equal(t, tt.wantBody, w.Body.Bytes()) }) } } 
func TestIntermediatesPEM(t *testing.T) { ca, err := minica.New() require.NoError(t, err) getRequest := func(t *testing.T, crt []*x509.Certificate) *http.Request { mockMustAuthority(t, &mockAuthority{ ret1: crt, }) return httptest.NewRequest("GET", "/intermediates.pem", http.NoBody) } type args struct { crts []*x509.Certificate } tests := []struct { name string args args wantStatusCode int wantBody []byte }{ {"ok", args{[]*x509.Certificate{ca.Intermediate}}, http.StatusOK, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: ca.Intermediate.Raw, })}, {"ok multiple", args{[]*x509.Certificate{ca.Root, ca.Intermediate}}, http.StatusOK, append(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: ca.Root.Raw, }), pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: ca.Intermediate.Raw, })...)}, {"fail", args{}, http.StatusNotImplemented, mustJSON(t, errs.NotImplemented("not implemented"))}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() r := getRequest(t, tt.args.crts) IntermediatesPEM(w, r) assert.Equal(t, tt.wantStatusCode, w.Result().StatusCode) assert.Equal(t, tt.wantBody, w.Body.Bytes()) }) } } ================================================ FILE: api/crl.go ================================================ package api import ( "encoding/pem" "net/http" "time" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) // CRL is an HTTP handler that returns the current CRL in DER or PEM format func CRL(w http.ResponseWriter, r *http.Request) { crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList() if err != nil { render.Error(w, r, err) return } if crlInfo == nil { render.Error(w, r, errs.New(http.StatusNotFound, "no CRL available")) return } expires := crlInfo.ExpiresAt if expires.IsZero() { expires = time.Now() } w.Header().Add("Expires", expires.Format(time.RFC1123)) _, formatAsPEM := r.URL.Query()["pem"] if formatAsPEM { 
w.Header().Add("Content-Type", "application/x-pem-file") w.Header().Add("Content-Disposition", "attachment; filename=\"crl.pem\"") _ = pem.Encode(w, &pem.Block{ Type: "X509 CRL", Bytes: crlInfo.Data, }) } else { w.Header().Add("Content-Type", "application/pkix-crl") w.Header().Add("Content-Disposition", "attachment; filename=\"crl.crl\"") w.Write(crlInfo.Data) } } ================================================ FILE: api/crl_test.go ================================================ package api import ( "bytes" "context" "encoding/pem" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_CRL(t *testing.T) { data := []byte{1, 2, 3, 4} pemData := pem.EncodeToMemory(&pem.Block{ Type: "X509 CRL", Bytes: data, }) pemData = bytes.TrimSpace(pemData) emptyPEMData := pem.EncodeToMemory(&pem.Block{ Type: "X509 CRL", Bytes: nil, }) emptyPEMData = bytes.TrimSpace(emptyPEMData) tests := []struct { name string url string err error statusCode int crlInfo *authority.CertificateRevocationListInfo expectedBody []byte expectedHeaders http.Header expectedErrorJSON string }{ {"ok", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, data, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.crl"`}}, ""}, {"ok/pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, pemData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, {"ok/empty", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, nil, http.Header{"Content-Type": []string{"application/pkix-crl"}, 
"Content-Disposition": []string{`attachment; filename="crl.crl"`}}, ""}, {"ok/empty-pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, emptyPEMData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, {"fail/internal", "http://example.com/crl", errs.Wrap(http.StatusInternalServerError, errors.New("failure"), "authority.GetCertificateRevocationList"), http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info."}`}, {"fail/nil", "http://example.com/crl", nil, http.StatusNotFound, nil, nil, http.Header{}, `{"status":404,"message":"no CRL available"}`}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockMustAuthority(t, &mockAuthority{ret1: tt.crlInfo, err: tt.err}) chiCtx := chi.NewRouteContext() req := httptest.NewRequest("GET", tt.url, http.NoBody) req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) w := httptest.NewRecorder() CRL(w, req) res := w.Result() assert.Equal(t, tt.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() require.NoError(t, err) if tt.statusCode >= 300 { assert.JSONEq(t, tt.expectedErrorJSON, string(bytes.TrimSpace(body))) return } // check expected header values for _, h := range []string{"content-type", "content-disposition"} { v := tt.expectedHeaders.Get(h) require.NotEmpty(t, v) actual := res.Header.Get(h) assert.Equal(t, v, actual) } // check expires header value assert.NotEmpty(t, res.Header.Get("expires")) t1, err := time.Parse(time.RFC1123, res.Header.Get("expires")) if assert.NoError(t, err) { assert.False(t, t1.IsZero()) } // check body contents assert.Equal(t, tt.expectedBody, bytes.TrimSpace(body)) }) } } ================================================ FILE: 
api/log/log.go ================================================ // Package log implements API-related logging helpers. package log import ( "context" "fmt" "net/http" "os" "github.com/pkg/errors" ) type errorLoggerKey struct{} // ErrorLogger is the function type used to log errors. type ErrorLogger func(http.ResponseWriter, *http.Request, error) func (fn ErrorLogger) call(w http.ResponseWriter, r *http.Request, err error) { if fn == nil { return } fn(w, r, err) } // WithErrorLogger returns a new context with the given error logger. func WithErrorLogger(ctx context.Context, fn ErrorLogger) context.Context { return context.WithValue(ctx, errorLoggerKey{}, fn) } // ErrorLoggerFromContext returns an error logger from the context. func ErrorLoggerFromContext(ctx context.Context) (fn ErrorLogger) { fn, _ = ctx.Value(errorLoggerKey{}).(ErrorLogger) return } // StackTracedError is the set of errors implementing the StackTrace function. // // Errors implementing this interface have their stack traces logged when passed // to the Error function of this package. type StackTracedError interface { error StackTrace() errors.StackTrace } type fieldCarrier interface { WithFields(map[string]any) Fields() map[string]any } // Error adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. func Error(w http.ResponseWriter, r *http.Request, err error) { ErrorLoggerFromContext(r.Context()).call(w, r, err) fc, ok := w.(fieldCarrier) if !ok { return } fc.WithFields(map[string]any{ "error": err, }) if os.Getenv("STEPDEBUG") != "1" { return } var st StackTracedError if errors.As(err, &st) { fc.WithFields(map[string]any{ "stack-trace": fmt.Sprintf("%+v", st.StackTrace()), }) } } // EnabledResponse log the response object if it implements the EnableLogger // interface. 
func EnabledResponse(rw http.ResponseWriter, r *http.Request, v any) { type enableLogger interface { ToLog() (any, error) } if el, ok := v.(enableLogger); ok { out, err := el.ToLog() if err != nil { Error(rw, r, err) return } if rl, ok := rw.(fieldCarrier); ok { rl.WithFields(map[string]any{ "response": out, }) } } } ================================================ FILE: api/log/log_test.go ================================================ package log import ( "bytes" "encoding/json" "log/slog" "net/http" "net/http/httptest" "testing" "unsafe" pkgerrors "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/smallstep/certificates/logging" ) type stackTracedError struct{} func (stackTracedError) Error() string { return "a stacktraced error" } func (stackTracedError) StackTrace() pkgerrors.StackTrace { f := struct{}{} return pkgerrors.StackTrace{ // fake stacktrace pkgerrors.Frame(unsafe.Pointer(&f)), pkgerrors.Frame(unsafe.Pointer(&f)), } } func TestError(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{})) req := httptest.NewRequest("GET", "/test", http.NoBody) reqWithLogger := req.WithContext(WithErrorLogger(req.Context(), func(w http.ResponseWriter, r *http.Request, err error) { if err != nil { logger.ErrorContext(r.Context(), "request failed", slog.Any("error", err)) } })) tests := []struct { name string error rw http.ResponseWriter r *http.Request isFieldCarrier bool isSlogLogger bool stepDebug bool expectStackTrace bool }{ {"noLogger", nil, nil, req, false, false, false, false}, {"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false}, {"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false}, {"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false}, {"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, 
true, false}, {"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true}, {"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true}, {"slogWithNoError", nil, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false}, {"slogWithError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.stepDebug { t.Setenv("STEPDEBUG", "1") } else { t.Setenv("STEPDEBUG", "0") } Error(tt.rw, tt.r, tt.error) // return early if test case doesn't use logger if !tt.isFieldCarrier && !tt.isSlogLogger { return } if tt.isFieldCarrier { fields := tt.rw.(logging.ResponseLogger).Fields() // expect the error field to be (not) set and to be the same error that was fed to Error if tt.error == nil { assert.Nil(t, fields["error"]) } else { assert.Same(t, tt.error, fields["error"]) } // check if stack-trace is set when expected if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace { t.Error(`ResponseLogger["stack-trace"] not set`) } else if !tt.expectStackTrace && hasStackTrace { t.Error(`ResponseLogger["stack-trace"] was set`) } } if tt.isSlogLogger { b := buf.Bytes() if tt.error == nil { assert.Empty(t, b) } else if assert.NotEmpty(t, b) { var m map[string]any assert.NoError(t, json.Unmarshal(b, &m)) assert.Equal(t, tt.error.Error(), m["error"]) } buf.Reset() } }) } } ================================================ FILE: api/models/scep.go ================================================ package models import ( "context" "crypto/x509" "errors" "github.com/smallstep/certificates/authority/provisioner" "golang.org/x/crypto/ssh" ) var errDummyImplementation = errors.New("dummy implementation") // SCEP is the SCEP provisioner model used solely in CA API // responses. 
// SCEP is the SCEP provisioner model used solely in CA API
// responses. It implements every method of [provisioner.Interface]: the
// identity getters (GetID, GetIDForToken, GetName, GetType) return values
// derived from the struct, GetEncryptedKey reports that no key is defined, and
// the remaining methods return errDummyImplementation.
// TODO(hs): remove reliance on the interface for the API responses
type SCEP struct {
	// ID is never serialized (json:"-"); when set it takes precedence in GetID.
	ID                     string   `json:"-"`
	Type                   string   `json:"type"`
	Name                   string   `json:"name"`
	ForceCN                bool     `json:"forceCN"`
	ChallengePassword      string   `json:"challenge"`
	Capabilities           []string `json:"capabilities,omitempty"`
	IncludeRoot            bool     `json:"includeRoot"`
	ExcludeIntermediate    bool     `json:"excludeIntermediate"`
	MinimumPublicKeyLength int      `json:"minimumPublicKeyLength"`
	DecrypterCertificate   []byte   `json:"decrypterCertificate"`
	DecrypterKeyPEM        []byte   `json:"decrypterKeyPEM"`
	// DecrypterKeyURI is serialized under the "decrypterKey" JSON key.
	DecrypterKeyURI               string               `json:"decrypterKey"`
	DecrypterKeyPassword          string               `json:"decrypterKeyPassword"`
	EncryptionAlgorithmIdentifier int                  `json:"encryptionAlgorithmIdentifier"`
	Options                       *provisioner.Options `json:"options,omitempty"`
	Claims                        *provisioner.Claims  `json:"claims,omitempty"`
}

// GetID returns the provisioner unique identifier: the ID field when it is
// non-empty, otherwise the value of GetIDForToken.
func (s *SCEP) GetID() string {
	if s.ID != "" {
		return s.ID
	}
	return s.GetIDForToken()
}

// GetIDForToken returns an identifier that will be used to load the provisioner
// from a token. It is the provisioner name prefixed with "scep/".
func (s *SCEP) GetIDForToken() string {
	return "scep/" + s.Name
}

// GetName returns the name of the provisioner.
func (s *SCEP) GetName() string {
	return s.Name
}

// GetType returns the type of provisioner, which is always
// provisioner.TypeSCEP.
func (s *SCEP) GetType() provisioner.Type {
	return provisioner.TypeSCEP
}

// GetEncryptedKey always reports that no encrypted key is defined: it returns
// two empty strings and false.
func (s *SCEP) GetEncryptedKey() (string, string, bool) {
	return "", "", false
}

// GetTokenID is a stub. It ignores its argument and always returns an empty
// identifier together with errDummyImplementation.
func (s *SCEP) GetTokenID(string) (string, error) {
	return "", errDummyImplementation
}

// Init is a stub. It performs no initialization or validation and always
// returns errDummyImplementation.
func (s *SCEP) Init(_ provisioner.Config) (err error) {
	return errDummyImplementation
}
// AuthorizeSign is a stub required by [provisioner.Interface]. It ignores its
// arguments and always returns a nil option list and errDummyImplementation.
func (s *SCEP) AuthorizeSign(context.Context, string) ([]provisioner.SignOption, error) {
	return nil, errDummyImplementation
}

// AuthorizeRevoke is a stub required by [provisioner.Interface]. It ignores
// its arguments and always returns errDummyImplementation.
func (s *SCEP) AuthorizeRevoke(context.Context, string) error {
	return errDummyImplementation
}

// AuthorizeRenew is a stub required by [provisioner.Interface]. It ignores
// its arguments and always returns errDummyImplementation.
func (s *SCEP) AuthorizeRenew(context.Context, *x509.Certificate) error {
	return errDummyImplementation
}

// AuthorizeSSHSign is a stub required by [provisioner.Interface]. It ignores
// its arguments and always returns a nil option list and
// errDummyImplementation.
func (s *SCEP) AuthorizeSSHSign(context.Context, string) ([]provisioner.SignOption, error) {
	return nil, errDummyImplementation
}

// AuthorizeSSHRevoke is a stub required by [provisioner.Interface]. It ignores
// its arguments and always returns errDummyImplementation.
func (s *SCEP) AuthorizeSSHRevoke(context.Context, string) error {
	return errDummyImplementation
}

// AuthorizeSSHRenew is a stub required by [provisioner.Interface]. It ignores
// its arguments and always returns a nil certificate and
// errDummyImplementation.
func (s *SCEP) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) {
	return nil, errDummyImplementation
}
func (s *SCEP) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []provisioner.SignOption, error) { return nil, nil, errDummyImplementation } var _ provisioner.Interface = (*SCEP)(nil) ================================================ FILE: api/read/read.go ================================================ // Package read implements request object readers. package read import ( "encoding/json" "errors" "io" "net/http" "strings" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) // JSON reads JSON from the request body and stores it in the value // pointed to by v. func JSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { return errs.BadRequestErr(err, "error decoding json") } return nil } // ProtoJSON reads JSON from the request body and stores it in the value // pointed to by m. func ProtoJSON(r io.Reader, m proto.Message) error { data, err := io.ReadAll(r) if err != nil { return errs.BadRequestErr(err, "error reading request body") } switch err := protojson.Unmarshal(data, m); { case errors.Is(err, proto.Error): return badProtoJSONError(err.Error()) default: return err } } // badProtoJSONError is an error type that is returned by ProtoJSON // when a proto message cannot be unmarshaled. Usually this is caused // by an error in the request body. 
type badProtoJSONError string // Error implements error for badProtoJSONError func (e badProtoJSONError) Error() string { return string(e) } // Render implements render.RenderableError for badProtoJSONError func (e badProtoJSONError) Render(w http.ResponseWriter, r *http.Request) { v := struct { Type string `json:"type"` Detail string `json:"detail"` Message string `json:"message"` }{ Type: "badRequest", Detail: "bad request", // trim the proto prefix for the message Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")), } render.JSONStatus(w, r, v, http.StatusBadRequest) } ================================================ FILE: api/read/read_test.go ================================================ package read import ( "encoding/json" "errors" "io" "net/http" "net/http/httptest" "reflect" "strings" "testing" "testing/iotest" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/errs" ) func TestJSON(t *testing.T) { type args struct { r io.Reader v interface{} } tests := []struct { name string args args wantErr bool }{ {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := JSON(tt.args.r, &tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr) } if tt.wantErr { var e *errs.Error if errors.As(err, &e) { if code := e.StatusCode(); code != 400 { t.Errorf("error.StatusCode() = %v, wants 400", code) } } else { t.Errorf("error type = %T, wants *Error", err) } } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) } }) } } func TestProtoJSON(t *testing.T) { p := new(linkedca.Policy) // 
TODO(hs): can we use something different, so we don't need the import? type args struct { r io.Reader m proto.Message } tests := []struct { name string args args wantErr bool }{ { name: "fail/io.ReadAll", args: args{ r: iotest.ErrReader(errors.New("read error")), m: p, }, wantErr: true, }, { name: "fail/proto", args: args{ r: strings.NewReader(`{?}`), m: p, }, wantErr: true, }, { name: "ok", args: args{ r: strings.NewReader(`{"x509":{}}`), m: p, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := ProtoJSON(tt.args.r, tt.args.m) if (err != nil) != tt.wantErr { t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr) } if tt.wantErr { var ( ee *errs.Error bpe badProtoJSONError ) switch { case errors.As(err, &bpe): assert.Contains(t, err.Error(), "syntax error") case errors.As(err, &ee): assert.Equal(t, http.StatusBadRequest, ee.Status) } return } assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m)) assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m)) }) } } func Test_badProtoJSONError_Render(t *testing.T) { tests := []struct { name string e badProtoJSONError expected string }{ { name: "bad proto normal space", e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), expected: "syntax error (line 1:2): invalid value ?", }, { name: "bad proto non breaking space", e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"), expected: "syntax error (line 1:2): invalid value ?", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/test", http.NoBody) tt.e.Render(w, r) res := w.Result() defer res.Body.Close() data, err := io.ReadAll(res.Body) assert.NoError(t, err) v := struct { Type string `json:"type"` Detail string `json:"detail"` Message string `json:"message"` }{} assert.NoError(t, json.Unmarshal(data, &v)) assert.Equal(t, "badRequest", v.Type) 
assert.Equal(t, "bad request", v.Detail) assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message) }) } } ================================================ FILE: api/rekey.go ================================================ package api import ( "net/http" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) // RekeyRequest is the request body for a certificate rekey request. type RekeyRequest struct { CsrPEM CertificateRequest `json:"csr"` } // Validate checks the fields of the RekeyRequest and returns nil if they are ok // or an error if something is wrong. func (s *RekeyRequest) Validate() error { if s.CsrPEM.CertificateRequest == nil { return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { return errs.BadRequestErr(err, "invalid csr") } return nil } // Rekey is similar to renew except that the certificate will be renewed with new key from csr. 
func Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { render.Error(w, r, errs.BadRequest("missing client certificate")) return } var body RekeyRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } a := mustAuthority(r.Context()) certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return } certChainPEM := certChainToPEM(certChain) var caPEM Certificate if len(certChainPEM) > 1 { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) render.JSONStatus(w, r, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } ================================================ FILE: api/render/render.go ================================================ // Package render implements functionality related to response rendering. package render import ( "encoding/json" "errors" "net/http" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "github.com/smallstep/certificates/api/log" ) // JSON is shorthand for JSONStatus(w, v, http.StatusOK). func JSON(w http.ResponseWriter, r *http.Request, v interface{}) { JSONStatus(w, r, v, http.StatusOK) } // JSONStatus marshals v into w. It additionally sets the status code of // w to the given one. // // JSONStatus sets the Content-Type of w to application/json unless one is // specified. 
func JSONStatus(w http.ResponseWriter, r *http.Request, v interface{}, status int) { setContentTypeUnlessPresent(w, "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(v); err != nil { var errUnsupportedType *json.UnsupportedTypeError if errors.As(err, &errUnsupportedType) { panic(err) } var errUnsupportedValue *json.UnsupportedValueError if errors.As(err, &errUnsupportedValue) { panic(err) } var errMarshalError *json.MarshalerError if errors.As(err, &errMarshalError) { panic(err) } } log.EnabledResponse(w, r, v) } // ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK). func ProtoJSON(w http.ResponseWriter, m proto.Message) { ProtoJSONStatus(w, m, http.StatusOK) } // ProtoJSONStatus writes the given value into the http.ResponseWriter and the // given status is written as the status code of the response. func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { b, err := protojson.Marshal(m) if err != nil { panic(err) } setContentTypeUnlessPresent(w, "application/json") w.WriteHeader(status) _, _ = w.Write(b) } func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) { const header = "Content-Type" h := w.Header() if _, ok := h[header]; !ok { h.Set(header, contentType) } } // RenderableError is the set of errors that implement the basic Render method. // // Errors that implement this interface will use their own Render method when // being rendered into responses. type RenderableError interface { error Render(http.ResponseWriter, *http.Request) } // Error marshals the JSON representation of err to w. In case err implements // RenderableError its own Render method will be called instead. func Error(rw http.ResponseWriter, r *http.Request, err error) { log.Error(rw, r, err) var re RenderableError if errors.As(err, &re) { re.Render(rw, r) return } JSONStatus(rw, r, err, statusCodeFromError(err)) } // StatusCodedError is the set of errors that implement the basic StatusCode // function. 
// // Errors that implement this interface will use the code reported by StatusCode // as the HTTP response code when being rendered by this package. type StatusCodedError interface { error StatusCode() int } func statusCodeFromError(err error) (code int) { code = http.StatusInternalServerError type causer interface { Cause() error } for err != nil { var sc StatusCodedError if errors.As(err, &sc) { code = sc.StatusCode() break } var c causer if !errors.As(err, &c) { break } err = c.Cause() } return } ================================================ FILE: api/render/render_test.go ================================================ package render import ( "encoding/json" "fmt" "io" "math" "net/http" "net/http/httptest" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/smallstep/certificates/logging" ) func TestJSON(t *testing.T) { rec := httptest.NewRecorder() rw := logging.NewResponseLogger(rec) r := httptest.NewRequest("POST", "/test", http.NoBody) JSON(rw, r, map[string]interface{}{"foo": "bar"}) assert.Equal(t, http.StatusOK, rec.Result().StatusCode) assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String()) assert.Empty(t, rw.Fields()) } func TestJSONPanicsOnUnsupportedType(t *testing.T) { jsonPanicTest[json.UnsupportedTypeError](t, make(chan struct{})) } func TestJSONPanicsOnUnsupportedValue(t *testing.T) { jsonPanicTest[json.UnsupportedValueError](t, math.NaN()) } func TestJSONPanicsOnMarshalerError(t *testing.T) { var v erroneousJSONMarshaler jsonPanicTest[json.MarshalerError](t, v) } type erroneousJSONMarshaler struct{} func (erroneousJSONMarshaler) MarshalJSON() ([]byte, error) { return nil, assert.AnError } func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | json.MarshalerError](t *testing.T, v any) { t.Helper() defer func() { var err error if r := recover(); r == nil { t.Fatal("expected panic") } else if e, ok := r.(error); !ok { t.Fatalf("did 
not panic with an error (%T)", r) } else { err = e } var e *T assert.ErrorAs(t, err, &e) }() r := httptest.NewRequest("POST", "/test", http.NoBody) JSON(httptest.NewRecorder(), r, v) } type renderableError struct { Code int `json:"-"` Message string `json:"message"` } func (err renderableError) Error() string { return err.Message } func (err renderableError) Render(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "something/custom") JSONStatus(w, r, err, err.Code) } type statusedError struct { Contents string } func (err statusedError) Error() string { return err.Contents } func (statusedError) StatusCode() int { return 432 } func TestError(t *testing.T) { cases := []struct { err error code int body string header string }{ 0: { err: renderableError{532, "some string"}, code: 532, body: "{\"message\":\"some string\"}\n", header: "something/custom", }, 1: { err: statusedError{"123"}, code: 432, body: "{\"Contents\":\"123\"}\n", header: "application/json", }, } for caseIndex := range cases { kase := cases[caseIndex] t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { rec := httptest.NewRecorder() r := httptest.NewRequest("POST", "/test", http.NoBody) Error(rec, r, kase.err) assert.Equal(t, kase.code, rec.Result().StatusCode) assert.Equal(t, kase.body, rec.Body.String()) assert.Equal(t, kase.header, rec.Header().Get("Content-Type")) }) } } type causedError struct { cause error } func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) } func (err causedError) Cause() error { return err.cause } func TestStatusCodeFromError(t *testing.T) { cases := []struct { err error exp int }{ 0: {nil, http.StatusInternalServerError}, 1: {io.EOF, http.StatusInternalServerError}, 2: {statusedError{"123"}, 432}, 3: {causedError{statusedError{"432"}}, 432}, } for caseIndex, kase := range cases { assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex) } } ================================================ FILE: 
api/renew.go ================================================ package api import ( "crypto/x509" "net/http" "strings" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) const ( authorizationHeader = "Authorization" bearerScheme = "Bearer" ) // Renew uses the information of certificate in the TLS connection to create a // new one. func Renew(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // Get the leaf certificate from the peer or the token. cert, token, err := getPeerCertificate(r) if err != nil { render.Error(w, r, err) return } // The token can be used by RAs to renew a certificate. if token != "" { ctx = authority.NewTokenContext(ctx, token) logOtt(w, token) } a := mustAuthority(ctx) certChain, err := a.RenewContext(ctx, cert, nil) if err != nil { render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return } certChainPEM := certChainToPEM(certChain) var caPEM Certificate if len(certChainPEM) > 1 { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) render.JSONStatus(w, r, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { return r.TLS.PeerCertificates[0], "", nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RenewMethod) peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) return peer, parts[1], err } } return nil, "", errs.BadRequest("missing client certificate") } ================================================ FILE: api/revoke.go ================================================ package api 
import ( "math/big" "net/http" "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) // RevokeResponse is the response object that returns the health of the server. type RevokeResponse struct { Status string `json:"status"` } // RevokeRequest is the request body for a revocation request. type RevokeRequest struct { Serial string `json:"serial"` OTT string `json:"ott"` ReasonCode int `json:"reasonCode"` Reason string `json:"reason"` Passive bool `json:"passive"` } // Validate checks the fields of the RevokeRequest and returns nil if they are ok // or an error if something is wrong. func (r *RevokeRequest) Validate() (err error) { if r.Serial == "" { return errs.BadRequest("missing serial") } sn, ok := new(big.Int).SetString(r.Serial, 0) if !ok { return errs.BadRequest("'%s' is not a valid serial number - use a base 10 representation or a base 16 representation with '0x' prefix", r.Serial) } r.Serial = sn.String() if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { return errs.BadRequest("reasonCode out of bounds") } if !r.Passive { return errs.NotImplemented("non-passive revocation not implemented") } return } // Revoke supports handful of different methods that revoke a Certificate. // // NOTE: currently only Passive revocation is supported. // // TODO: Add CRL and OCSP support. 
func Revoke(w http.ResponseWriter, r *http.Request) {
	var body RevokeRequest
	if err := read.JSON(r.Body, &body); err != nil {
		render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
		return
	}

	if err := body.Validate(); err != nil {
		render.Error(w, r, err)
		return
	}

	opts := &authority.RevokeOptions{
		Serial:      body.Serial,
		Reason:      body.Reason,
		ReasonCode:  body.ReasonCode,
		PassiveOnly: body.Passive,
	}

	ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
	a := mustAuthority(ctx)

	// A token indicates that we are using the api via a provisioner token,
	// otherwise it is assumed that the certificate is revoking itself over mTLS.
	if body.OTT != "" {
		logOtt(w, body.OTT)
		// Only the error of Authorize is used here; the returned sign options
		// are discarded.
		if _, err := a.Authorize(ctx, body.OTT); err != nil {
			render.Error(w, r, errs.UnauthorizedErr(err))
			return
		}
		opts.OTT = body.OTT
	} else {
		// If no token is present, then the request must be made over mTLS and
		// the client certificate Serial Number must match the serial number
		// being revoked.
		if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
			render.Error(w, r, errs.BadRequest("missing ott or client certificate"))
			return
		}
		opts.Crt = r.TLS.PeerCertificates[0]
		// Both sides of the comparison are base 10 strings: Validate rewrote
		// body.Serial with big.Int.String().
		if serialNumber := opts.Crt.SerialNumber.String(); opts.Serial != serialNumber {
			render.Error(w, r, errs.Forbidden(
				"request serial number %q and certificate serial number %q do not match",
				opts.Serial, serialNumber))
			return
		}
		// TODO: should probably be checking if the certificate was revoked here.
		// Will need to thread that request down to the authority, so will need
		// to add API for that.
		LogCertificate(w, opts.Crt)
		opts.MTLS = true
	}

	if err := a.Revoke(ctx, opts); err != nil {
		render.Error(w, r, errs.ForbiddenErr(err, "error revoking certificate"))
		return
	}

	logRevoke(w, opts)
	render.JSON(w, r, &RevokeResponse{Status: "ok"})
}

// logRevoke adds the revocation options to the response log fields when w is
// a logging.ResponseLogger; otherwise it does nothing.
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
	if rl, ok := w.(logging.ResponseLogger); ok {
		rl.WithFields(map[string]interface{}{
			"serial":      ri.Serial,
			"reasonCode":  ri.ReasonCode,
			"reason":      ri.Reason,
			"passiveOnly": ri.PassiveOnly,
			"mTLS":        ri.MTLS,
		})
	}
}

================================================
FILE: api/revoke_test.go
================================================
package api

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/pkg/errors"

	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/errs"
	"github.com/smallstep/certificates/logging"
)

// TestRevokeRequestValidate checks the error message prefix and status code
// returned by RevokeRequest.Validate for each invalid field, plus one valid
// request.
func TestRevokeRequestValidate(t *testing.T) {
	type test struct {
		rr  *RevokeRequest
		err *errs.Error
	}
	tests := map[string]test{
		"error/missing serial": {
			rr:  &RevokeRequest{},
			err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
		},
		"error/bad sn": {
			rr:  &RevokeRequest{Serial: "sn"},
			err: &errs.Error{Err: errors.New("'sn' is not a valid serial number - use a base 10 representation or a base 16 representation with '0x' prefix"), Status: http.StatusBadRequest},
		},
		"error/bad reasonCode": {
			rr: &RevokeRequest{
				Serial:     "10",
				ReasonCode: 15,
				Passive:    true,
			},
			err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
		},
		"error/non-passive not implemented": {
			rr: &RevokeRequest{
				Serial:     "10",
				ReasonCode: 8,
				Passive:    false,
			},
			err: &errs.Error{Err: errors.New("non-passive revocation not implemented"), Status: http.StatusNotImplemented},
		},
		"ok": {
			rr: &RevokeRequest{
				Serial:     "10",
				ReasonCode: 9,
				Passive:    true,
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			if err := tc.rr.Validate(); err != nil {
				var ee *errs.Error
				if errors.As(err, &ee) {
					assert.HasPrefix(t, ee.Error(), tc.err.Error())
					assert.Equals(t, ee.StatusCode(), tc.err.Status)
				} else {
					t.Errorf("unexpected error type: %T", err)
				}
			} else {
				// No error returned: the case must not expect one.
				assert.Nil(t, tc.err)
			}
		})
	}
}

// Test_caHandler_Revoke drives the Revoke handler with a mocked authority and
// checks the response status code and, for successful responses, the body.
func Test_caHandler_Revoke(t *testing.T) {
	type test struct {
		input      string
		auth       Authority
		tls        *tls.ConnectionState
		statusCode int
		expected   []byte
	}
	tests := map[string]func(*testing.T) test{
		"400/json read error": func(t *testing.T) test {
			return test{
				input:      "{",
				statusCode: http.StatusBadRequest,
			}
		},
		"400/invalid request body": func(t *testing.T) test {
			input, err := json.Marshal(RevokeRequest{})
			assert.FatalError(t, err)
			return test{
				input:      string(input),
				statusCode: http.StatusBadRequest,
			}
		},
		"200/ott": func(t *testing.T) test {
			input, err := json.Marshal(RevokeRequest{
				Serial:     "10",
				ReasonCode: 4,
				Reason:     "foo",
				OTT:        "valid",
				Passive:    true,
			})
			assert.FatalError(t, err)
			return test{
				input:      string(input),
				statusCode: http.StatusOK,
				auth: &mockAuthority{
					authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
						return nil, nil
					},
					revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
						assert.True(t, opts.PassiveOnly)
						assert.False(t, opts.MTLS)
						assert.Equals(t, opts.Serial, "10")
						assert.Equals(t, opts.ReasonCode, 4)
						assert.Equals(t, opts.Reason, "foo")
						return nil
					},
				},
				expected: []byte(`{"status":"ok"}`),
			}
		},
		"400/no OTT and no peer certificate": func(t *testing.T) test {
			input, err := json.Marshal(RevokeRequest{
				Serial:     "10",
				ReasonCode: 4,
				Passive:    true,
			})
			assert.FatalError(t, err)
			return test{
				input:      string(input),
				statusCode: http.StatusBadRequest,
			}
		},
		"200/no ott": func(t *testing.T) test {
			// mTLS path: the request serial has to match the serial number of
			// the peer certificate.
			cs := &tls.ConnectionState{
				PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
			}
			input, err := json.Marshal(RevokeRequest{
				Serial:     "1404354960355712309",
				ReasonCode: 4,
				Reason:     "foo",
				Passive:    true,
			})
			assert.FatalError(t, err)

			return test{
				input:      string(input),
				statusCode: http.StatusOK,
				tls:        cs,
				auth: &mockAuthority{
					authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
						return nil, nil
					},
					revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
						assert.True(t, ri.PassiveOnly)
						assert.True(t, ri.MTLS)
						assert.Equals(t, ri.Serial, "1404354960355712309")
						assert.Equals(t, ri.ReasonCode, 4)
						assert.Equals(t, ri.Reason, "foo")
						return nil
					},
					loadProvisionerByCertificate: func(crt *x509.Certificate) (provisioner.Interface, error) {
						// err is the json.Marshal error captured above, which
						// is nil once assert.FatalError has passed.
						return &mockProvisioner{
							getID: func() string {
								return "mock-provisioner-id"
							},
						}, err
					},
				},
				expected: []byte(`{"status":"ok"}`),
			}
		},
		"500/ott authority.Revoke": func(t *testing.T) test {
			// NOTE(review): the handler wraps the authority error with
			// errs.ForbiddenErr, yet 500 is expected here; presumably the
			// status of an inner errs error is preserved. Confirm in errs.
			input, err := json.Marshal(RevokeRequest{
				Serial:     "10",
				ReasonCode: 4,
				Reason:     "foo",
				OTT:        "valid",
				Passive:    true,
			})
			assert.FatalError(t, err)
			return test{
				input:      string(input),
				statusCode: http.StatusInternalServerError,
				auth: &mockAuthority{
					authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
						return nil, nil
					},
					revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
						return errs.InternalServer("force")
					},
				},
			}
		},
		"403/ott authority.Revoke": func(t *testing.T) test {
			input, err := json.Marshal(RevokeRequest{
				Serial:     "10",
				ReasonCode: 4,
				Reason:     "foo",
				OTT:        "valid",
				Passive:    true,
			})
			assert.FatalError(t, err)
			return test{
				input:      string(input),
				statusCode: http.StatusForbidden,
				auth: &mockAuthority{
					authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
						return nil, nil
					},
					revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
						return errors.New("force")
					},
				},
			}
		},
	}

	for name, _tc := range tests {
		tc := _tc(t)
		t.Run(name, func(t *testing.T) {
			mockMustAuthority(t, tc.auth)
			req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
			if tc.tls != nil {
				req.TLS = tc.tls
			}
			w := httptest.NewRecorder()
			Revoke(logging.NewResponseLogger(w), req)
			res := w.Result()

			assert.Equals(t, tc.statusCode, res.StatusCode)

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			assert.FatalError(t, err)

			// The body is only compared for non-error responses.
			if tc.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tc.expected) {
					t.Errorf("caHandler.Root Body = %s, wants %s", body, tc.expected)
				}
			}
		})
	}
}

================================================
FILE: api/sign.go
================================================
package api

import (
	"crypto/tls"
	"encoding/json"
	"net/http"

	"github.com/smallstep/certificates/api/read"
	"github.com/smallstep/certificates/api/render"
	"github.com/smallstep/certificates/authority/config"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/errs"
)

// SignRequest is the request body for a certificate signature request.
type SignRequest struct {
	CsrPEM       CertificateRequest `json:"csr"`
	OTT          string             `json:"ott"`
	NotAfter     TimeDuration       `json:"notAfter,omitempty"`
	NotBefore    TimeDuration       `json:"notBefore,omitempty"`
	TemplateData json.RawMessage    `json:"templateData,omitempty"`
}

// Validate checks the fields of the SignRequest and returns nil if they are ok
// or an error if something is wrong.
//
// The CSR must be present and carry a valid signature, and the one-time-token
// must not be empty; each failure is reported as a bad request error.
func (s *SignRequest) Validate() error {
	if s.CsrPEM.CertificateRequest == nil {
		return errs.BadRequest("missing csr")
	}
	if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
		return errs.BadRequestErr(err, "invalid csr")
	}
	if s.OTT == "" {
		return errs.BadRequest("missing ott")
	}

	return nil
}

// SignResponse is the response object of the certificate signature request.
type SignResponse struct { ServerPEM Certificate `json:"crt"` CaPEM Certificate `json:"ca"` CertChainPEM []Certificate `json:"certChain"` TLSOptions *config.TLSOptions `json:"tlsOptions,omitempty"` TLS *tls.ConnectionState `json:"-"` } // Sign is an HTTP handler that reads a certificate request and an // one-time-token (ott) from the body and creates a new certificate with the // information in the certificate request. func Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { render.Error(w, r, err) return } opts := provisioner.SignOptions{ NotBefore: body.NotBefore, NotAfter: body.NotAfter, TemplateData: body.TemplateData, } ctx := r.Context() a := mustAuthority(ctx) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, r, errs.UnauthorizedErr(err)) return } certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...) 
if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) var caPEM Certificate if len(certChainPEM) > 1 { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) render.JSONStatus(w, r, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } ================================================ FILE: api/ssh.go ================================================ package api import ( "context" "crypto/x509" "encoding/base64" "encoding/json" "net/http" "net/url" "strings" "time" "github.com/google/uuid" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" "github.com/smallstep/certificates/templates" ) // SSHAuthority is the interface implemented by a SSH CA authority. 
type SSHAuthority interface {
	// SignSSH is called by the SSHSign handler with the parsed public key, the
	// options taken from the request and the sign options returned by
	// Authorize.
	SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
	RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
	RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
	// SignSSHAddUser is called by the SSHSign handler with the add-user public
	// key and the certificate that was just signed.
	SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
	// GetSSHRoots backs the SSHRoots handler.
	GetSSHRoots(ctx context.Context) (*config.SSHKeys, error)
	// GetSSHFederation backs the SSHFederation handler.
	GetSSHFederation(ctx context.Context) (*config.SSHKeys, error)
	// GetSSHConfig backs the SSHConfig handler; typ and data come from the
	// SSHConfigRequest.
	GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
	// CheckSSHHost backs the SSHCheckHost handler.
	CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
	// GetSSHHosts backs the SSHGetHosts handler; cert is the TLS peer
	// certificate of the request, or nil when there is none.
	GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
	// GetSSHBastion backs the SSHBastion handler.
	GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error)
}

// SSHSignRequest is the request body of an SSH certificate request.
//
// PublicKey and OTT are required by Validate; CertType, when set, must be one
// of provisioner.SSHUserCert or provisioner.SSHHostCert.
type SSHSignRequest struct {
	PublicKey        []byte             `json:"publicKey"` // base64 encoded
	OTT              string             `json:"ott"`
	CertType         string             `json:"certType,omitempty"`
	KeyID            string             `json:"keyID,omitempty"`
	Principals       []string           `json:"principals,omitempty"`
	ValidAfter       TimeDuration       `json:"validAfter,omitempty"`
	ValidBefore      TimeDuration       `json:"validBefore,omitempty"`
	AddUserPublicKey []byte             `json:"addUserPublicKey,omitempty"`
	IdentityCSR      CertificateRequest `json:"identityCSR,omitempty"`
	TemplateData     json.RawMessage    `json:"templateData,omitempty"`
}

// Validate validates the SSHSignRequest.
func (s *SSHSignRequest) Validate() error { switch { case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: return errs.BadRequest("invalid certType '%s'", s.CertType) case len(s.PublicKey) == 0: return errs.BadRequest("missing or empty publicKey") case s.OTT == "": return errs.BadRequest("missing or empty ott") default: // Validate identity signature if provided if s.IdentityCSR.CertificateRequest != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { return errs.BadRequestErr(err, "invalid identityCSR") } } return nil } } // SSHSignResponse is the response object that returns the SSH certificate. type SSHSignResponse struct { Certificate SSHCertificate `json:"crt"` AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"` IdentityCertificate []Certificate `json:"identityCrt,omitempty"` } // SSHRootsResponse represents the response object that returns the SSH user and // host keys. type SSHRootsResponse struct { UserKeys []SSHPublicKey `json:"userKey,omitempty"` HostKeys []SSHPublicKey `json:"hostKey,omitempty"` } // SSHCertificate represents the response SSH certificate. type SSHCertificate struct { *ssh.Certificate `json:"omitempty"` } // SSHGetHostsResponse is the response object that returns the list of valid // hosts for SSH. type SSHGetHostsResponse struct { Hosts []config.Host `json:"hosts"` } // MarshalJSON implements the json.Marshaler interface. Returns a quoted, // base64 encoded, openssh wire format version of the certificate. func (c SSHCertificate) MarshalJSON() ([]byte, error) { if c.Certificate == nil { return []byte("null"), nil } s := base64.StdEncoding.EncodeToString(c.Certificate.Marshal()) return []byte(`"` + s + `"`), nil } // UnmarshalJSON implements the json.Unmarshaler interface. The certificate is // expected to be a quoted, base64 encoded, openssh wire formatted block of bytes. 
func (c *SSHCertificate) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { return errors.Wrap(err, "error decoding certificate") } if s == "" { c.Certificate = nil return nil } certData, err := base64.StdEncoding.DecodeString(s) if err != nil { return errors.Wrap(err, "error decoding ssh certificate") } pub, err := ssh.ParsePublicKey(certData) if err != nil { return errors.Wrap(err, "error parsing ssh certificate") } cert, ok := pub.(*ssh.Certificate) if !ok { return errors.Errorf("error decoding ssh certificate: %T is not an *ssh.Certificate", pub) } c.Certificate = cert return nil } // SSHPublicKey represents a public key in a response object. type SSHPublicKey struct { ssh.PublicKey } // MarshalJSON implements the json.Marshaler interface. Returns a quoted, // base64 encoded, openssh wire format version of the public key. func (p *SSHPublicKey) MarshalJSON() ([]byte, error) { if p == nil || p.PublicKey == nil { return []byte("null"), nil } s := base64.StdEncoding.EncodeToString(p.PublicKey.Marshal()) return []byte(`"` + s + `"`), nil } // UnmarshalJSON implements the json.Unmarshaler interface. The public key is // expected to be a quoted, base64 encoded, openssh wire formatted block of // bytes. func (p *SSHPublicKey) UnmarshalJSON(data []byte) error { var s string if err := json.Unmarshal(data, &s); err != nil { return errors.Wrap(err, "error decoding ssh public key") } if s == "" { p.PublicKey = nil return nil } data, err := base64.StdEncoding.DecodeString(s) if err != nil { return errors.Wrap(err, "error decoding ssh public key") } pub, err := ssh.ParsePublicKey(data) if err != nil { return errors.Wrap(err, "error parsing ssh public key") } p.PublicKey = pub return nil } // Template represents the output of a template. type Template = templates.Output // SSHConfigRequest is the request body used to get the SSH configuration // templates. 
type SSHConfigRequest struct { Type string `json:"type"` Data map[string]string `json:"data"` } // Validate checks the values of the SSHConfigurationRequest. func (r *SSHConfigRequest) Validate() error { switch r.Type { case "": r.Type = provisioner.SSHUserCert return nil case provisioner.SSHUserCert, provisioner.SSHHostCert: return nil default: return errs.BadRequest("invalid type '%s'", r.Type) } } // SSHConfigResponse is the response that returns the rendered templates. type SSHConfigResponse struct { UserTemplates []Template `json:"userTemplates,omitempty"` HostTemplates []Template `json:"hostTemplates,omitempty"` } // SSHCheckPrincipalRequest is the request body used to check if a principal // certificate has been created. Right now it only supported for hosts // certificates. type SSHCheckPrincipalRequest struct { Type string `json:"type"` Principal string `json:"principal"` Token string `json:"token,omitempty"` } // Validate checks the check principal request. func (r *SSHCheckPrincipalRequest) Validate() error { switch { case r.Type != provisioner.SSHHostCert: return errs.BadRequest("unsupported type '%s'", r.Type) case r.Principal == "": return errs.BadRequest("missing or empty principal") default: return nil } } // SSHCheckPrincipalResponse is the response body used to check if a principal // exists. type SSHCheckPrincipalResponse struct { Exists bool `json:"exists"` } // SSHBastionRequest is the request body used to get the bastion for a given // host. type SSHBastionRequest struct { User string `json:"user"` Hostname string `json:"hostname"` } // Validate checks the values of the SSHBastionRequest. func (r *SSHBastionRequest) Validate() error { if r.Hostname == "" { return errs.BadRequest("missing or empty hostname") } return nil } // SSHBastionResponse is the response body used to return the bastion for a // given host. 
type SSHBastionResponse struct {
	Hostname string          `json:"hostname"`
	Bastion  *config.Bastion `json:"bastion,omitempty"`
}

// SSHSign is an HTTP handler that reads an SSHSignRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in
// the request.
//
// Besides the SSH certificate, the response can carry two optional items: an
// add-user certificate, when an add-user public key was sent and the signed
// certificate passes authority.IsValidForAddUser, and an identity X.509
// certificate chain, when an identity CSR was sent.
func SSHSign(w http.ResponseWriter, r *http.Request) {
	var body SSHSignRequest
	if err := read.JSON(r.Body, &body); err != nil {
		render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
		return
	}

	// The token is logged before validation.
	logOtt(w, body.OTT)
	if err := body.Validate(); err != nil {
		render.Error(w, r, err)
		return
	}

	publicKey, err := ssh.ParsePublicKey(body.PublicKey)
	if err != nil {
		render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
		return
	}

	// The add-user key is optional and stays nil when it was not sent.
	var addUserPublicKey ssh.PublicKey
	if body.AddUserPublicKey != nil {
		addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
		if err != nil {
			render.Error(w, r, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
			return
		}
	}

	opts := provisioner.SignSSHOptions{
		CertType:     body.CertType,
		KeyID:        body.KeyID,
		Principals:   body.Principals,
		ValidBefore:  body.ValidBefore,
		ValidAfter:   body.ValidAfter,
		TemplateData: body.TemplateData,
	}

	// The context carries the method, the token and the requested certificate
	// type for the authorization below.
	ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
	ctx = provisioner.NewContextWithToken(ctx, body.OTT)
	ctx = provisioner.NewContextWithCertType(ctx, opts.CertType)

	a := mustAuthority(ctx)
	signOpts, err := a.Authorize(ctx, body.OTT)
	if err != nil {
		render.Error(w, r, errs.UnauthorizedErr(err))
		return
	}

	cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
	if err != nil {
		render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
		return
	}

	// When the signed certificate is not valid for add-user, the add-user key
	// is silently ignored rather than reported as an error.
	var addUserCertificate *SSHCertificate
	if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
		addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
		if err != nil {
			render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
			return
		}
		addUserCertificate = &SSHCertificate{addUserCert}
	}

	// Sign identity certificate if available.
	var identityCertificate []Certificate
	if cr := body.IdentityCSR.CertificateRequest; cr != nil {
		// A fresh context is derived from the request (shadowing the outer
		// ctx) so the same token can be authorized a second time, now for the
		// identity signing method.
		ctx := authority.NewContextWithSkipTokenReuse(r.Context())
		ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod)
		signOpts, err := a.Authorize(ctx, body.OTT)
		if err != nil {
			render.Error(w, r, errs.UnauthorizedErr(err))
			return
		}

		// Enforce the same duration as ssh certificate.
		signOpts = append(signOpts, &identityModifier{
			Identity:  getIdentityURI(cr),
			NotBefore: time.Unix(cast.Int64(cert.ValidAfter), 0),
			NotAfter:  time.Unix(cast.Int64(cert.ValidBefore), 0),
		})

		certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
		if err != nil {
			render.Error(w, r, errs.ForbiddenErr(err, "error signing identity certificate"))
			return
		}
		identityCertificate = certChainToPEM(certChain)
	}

	LogSSHCertificate(w, cert)
	render.JSONStatus(w, r, &SSHSignResponse{
		Certificate:         SSHCertificate{cert},
		AddUserCertificate:  addUserCertificate,
		IdentityCertificate: identityCertificate,
	}, http.StatusCreated)
}

// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
// certificates.
func SSHRoots(w http.ResponseWriter, r *http.Request) { ctx := r.Context() keys, err := mustAuthority(ctx).GetSSHRoots(ctx) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { render.Error(w, r, errs.NotFound("no keys found")) return } resp := new(SSHRootsResponse) for _, k := range keys.HostKeys { resp.HostKeys = append(resp.HostKeys, SSHPublicKey{PublicKey: k}) } for _, k := range keys.UserKeys { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } render.JSON(w, r, resp) } // SSHFederation is an HTTP handler that returns the federated SSH public keys // for user and host certificates. func SSHFederation(w http.ResponseWriter, r *http.Request) { ctx := r.Context() keys, err := mustAuthority(ctx).GetSSHFederation(ctx) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { render.Error(w, r, errs.NotFound("no keys found")) return } resp := new(SSHRootsResponse) for _, k := range keys.HostKeys { resp.HostKeys = append(resp.HostKeys, SSHPublicKey{PublicKey: k}) } for _, k := range keys.UserKeys { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } render.JSON(w, r, resp) } // SSHConfig is an HTTP handler that returns rendered templates for ssh clients // and servers. 
func SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } ctx := r.Context() ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } var cfg SSHConfigResponse switch body.Type { case provisioner.SSHUserCert: cfg.UserTemplates = ts case provisioner.SSHHostCert: cfg.HostTemplates = ts default: render.Error(w, r, errs.InternalServer("it should hot get here")) return } render.JSON(w, r, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } ctx := r.Context() exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } render.JSON(w, r, &SSHCheckPrincipalResponse{ Exists: exists, }) } // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. func SSHGetHosts(w http.ResponseWriter, r *http.Request) { var cert *x509.Certificate if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { cert = r.TLS.PeerCertificates[0] } ctx := r.Context() hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } render.JSON(w, r, &SSHGetHostsResponse{ Hosts: hosts, }) } // SSHBastion provides returns the bastion configured if any. 
func SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } ctx := r.Context() bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } render.JSON(w, r, &SSHBastionResponse{ Hostname: body.Hostname, Bastion: bastion, }) } // identityModifier is a custom modifier used to force a fixed duration, and set // the identity URI. type identityModifier struct { Identity *url.URL NotBefore time.Time NotAfter time.Time } // Enforce implements the enforcer interface and sets the validity bounds and // the identity uri to the certificate. func (m *identityModifier) Enforce(cert *x509.Certificate) error { cert.NotBefore = m.NotBefore cert.NotAfter = m.NotAfter if m.Identity != nil { var identityURL = m.Identity.String() for _, u := range cert.URIs { if u.String() == identityURL { return nil } } cert.URIs = append(cert.URIs, m.Identity) } return nil } // getIdentityURI returns the first valid UUID URN from the given CSR. 
func getIdentityURI(cr *x509.CertificateRequest) *url.URL { for _, u := range cr.URIs { s := u.String() // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx if len(s) == 9+36 && strings.EqualFold(s[:9], "urn:uuid:") { if _, err := uuid.Parse(s); err == nil { return u } } } return nil } ================================================ FILE: api/sshRekey.go ================================================ package api import ( "net/http" "time" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" ) // SSHRekeyRequest is the request body of an SSH certificate request. type SSHRekeyRequest struct { OTT string `json:"ott"` PublicKey []byte `json:"publicKey"` //base64 encoded } // Validate validates the SSHSignRekey. func (s *SSHRekeyRequest) Validate() error { switch { case s.OTT == "": return errs.BadRequest("missing or empty ott") case len(s.PublicKey) == 0: return errs.BadRequest("missing or empty public key") default: return nil } } // SSHRekeyResponse is the response object that returns the SSH certificate. type SSHRekeyResponse struct { Certificate SSHCertificate `json:"crt"` IdentityCertificate []Certificate `json:"identityCrt,omitempty"` } // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. 
func SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { render.Error(w, r, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey")) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx = provisioner.NewContextWithToken(ctx, body.OTT) a := mustAuthority(ctx) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, r, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } // Match identity cert with the SSH cert notBefore := time.Unix(cast.Int64(oldCert.ValidAfter), 0) notAfter := time.Unix(cast.Int64(oldCert.ValidBefore), 0) identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate")) return } LogSSHCertificate(w, newCert) render.JSONStatus(w, r, &SSHRekeyResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) } ================================================ FILE: api/sshRenew.go ================================================ package api import ( "crypto/x509" "net/http" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" ) // SSHRenewRequest is the request 
body of an SSH certificate request. type SSHRenewRequest struct { OTT string `json:"ott"` } // Validate validates the SSHSignRequest. func (s *SSHRenewRequest) Validate() error { switch s.OTT { case "": return errs.BadRequest("missing or empty ott") default: return nil } } // SSHRenewResponse is the response object that returns the SSH certificate. type SSHRenewResponse struct { Certificate SSHCertificate `json:"crt"` IdentityCertificate []Certificate `json:"identityCrt,omitempty"` } // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. func SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { render.Error(w, r, err) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx = provisioner.NewContextWithToken(ctx, body.OTT) a := mustAuthority(ctx) _, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, r, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } newCert, err := a.RenewSSH(ctx, oldCert) if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } // Match identity cert with the SSH cert notBefore := time.Unix(cast.Int64(oldCert.ValidAfter), 0) notAfter := time.Unix(cast.Int64(oldCert.ValidBefore), 0) identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate")) return } LogSSHCertificate(w, newCert) render.JSONStatus(w, r, &SSHSignResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, 
http.StatusCreated) } // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, nil } // Clone the certificate as we can modify it. cert, err := x509.ParseCertificate(r.TLS.PeerCertificates[0].Raw) if err != nil { return nil, errors.Wrap(err, "error parsing client certificate") } // Enforce the cert to match another certificate, for example an ssh // certificate. if !notBefore.IsZero() { cert.NotBefore = notBefore } if !notAfter.IsZero() { cert.NotAfter = notAfter } certChain, err := mustAuthority(r.Context()).Renew(cert) if err != nil { return nil, err } return certChainToPEM(certChain), nil } ================================================ FILE: api/sshRevoke.go ================================================ package api import ( "net/http" "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) // SSHRevokeResponse is the response object that returns the health of the server. type SSHRevokeResponse struct { Status string `json:"status"` } // SSHRevokeRequest is the request body for a revocation request. type SSHRevokeRequest struct { Serial string `json:"serial"` OTT string `json:"ott"` ReasonCode int `json:"reasonCode"` Reason string `json:"reason"` Passive bool `json:"passive"` } // Validate checks the fields of the RevokeRequest and returns nil if they are ok // or an error if something is wrong. 
func (r *SSHRevokeRequest) Validate() (err error) { if r.Serial == "" { return errs.BadRequest("missing serial") } if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise { return errs.BadRequest("reasonCode out of bounds") } if !r.Passive { return errs.NotImplemented("non-passive revocation not implemented") } if r.OTT == "" { return errs.BadRequest("missing ott") } return } // Revoke supports handful of different methods that revoke a Certificate. // // NOTE: currently only Passive revocation is supported. func SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } opts := &authority.RevokeOptions{ Serial: body.Serial, Reason: body.Reason, ReasonCode: body.ReasonCode, PassiveOnly: body.Passive, } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) a := mustAuthority(ctx) logOtt(w, body.OTT) if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, r, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := a.Revoke(ctx, opts); err != nil { render.Error(w, r, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } logSSHRevoke(w, opts) render.JSON(w, r, &SSHRevokeResponse{Status: "ok"}) } func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { if rl, ok := w.(logging.ResponseLogger); ok { rl.WithFields(map[string]interface{}{ "serial": ri.Serial, "reasonCode": ri.ReasonCode, "reason": ri.Reason, "passiveOnly": ri.PassiveOnly, "mTLS": ri.MTLS, "ssh": true, }) } } ================================================ FILE: api/ssh_test.go ================================================ package api import ( "bytes" "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "encoding/base64" "encoding/json" "fmt" "io" "net/http" 
"net/http/httptest"
	"net/url"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/logging"
	"github.com/smallstep/certificates/templates"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/crypto/ssh"
)

// Keys shared by the tests in this file: one signs the test certificates, the
// other two are the subject keys of the user and host certificates.
var (
	sshSignerKey = mustKey()
	sshUserKey   = mustKey()
	sshHostKey   = mustKey()
)

// mustKey generates a P-256 ECDSA key and panics if generation fails.
func mustKey() *ecdsa.PrivateKey {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	return priv
}

// signSSHCertificate signs cert in place with sshSignerKey, setting both its
// SignatureKey and Signature fields.
func signSSHCertificate(cert *ssh.Certificate) error {
	signerKey, err := ssh.NewPublicKey(sshSignerKey.Public())
	if err != nil {
		return err
	}
	signer, err := ssh.NewSignerFromSigner(sshSignerKey)
	if err != nil {
		return err
	}
	cert.SignatureKey = signerKey
	data := cert.Marshal()
	// The last four bytes are dropped before signing.
	// NOTE(review): presumably the length prefix of the still-empty signature
	// field; confirm against the ssh certificate wire format.
	data = data[:len(data)-4]
	sig, err := signer.Sign(rand.Reader, data)
	if err != nil {
		return err
	}
	cert.Signature = sig
	return nil
}

// getSignedUserCertificate returns a signed user certificate for sshUserKey
// that is valid for one hour starting now.
func getSignedUserCertificate() (*ssh.Certificate, error) {
	key, err := ssh.NewPublicKey(sshUserKey.Public())
	if err != nil {
		return nil, err
	}
	t := time.Now()
	cert := &ssh.Certificate{
		Nonce:           []byte("1234567890"),
		Key:             key,
		Serial:          1234567890,
		CertType:        ssh.UserCert,
		KeyId:           "user@localhost",
		ValidPrincipals: []string{"user"},
		ValidAfter:      uint64(t.Unix()),
		ValidBefore:     uint64(t.Add(time.Hour).Unix()),
		Permissions: ssh.Permissions{
			CriticalOptions: map[string]string{},
			Extensions: map[string]string{
				"permit-X11-forwarding":   "",
				"permit-agent-forwarding": "",
				"permit-port-forwarding":  "",
				"permit-pty":              "",
				"permit-user-rc":          "",
			},
		},
		Reserved: []byte{},
	}
	if err := signSSHCertificate(cert); err != nil {
		return nil, err
	}
	return cert, nil
}

// getSignedHostCertificate returns a signed host certificate for sshHostKey
// that is valid for one hour starting now.
func getSignedHostCertificate() (*ssh.Certificate, error) {
	key, err := ssh.NewPublicKey(sshHostKey.Public())
	if err != nil {
		return nil, err
	}
	t := time.Now()
	cert := &ssh.Certificate{
		Nonce:  []byte("1234567890"),
		Key:    key,
		Serial: 1234567890,
		// A host fixture must carry the host certificate type.
		CertType:        ssh.HostCert,
		KeyId:           "internal.smallstep.com",
		ValidPrincipals: []string{"internal.smallstep.com"},
		ValidAfter:      uint64(t.Unix()),
		ValidBefore:     uint64(t.Add(time.Hour).Unix()),
		Permissions: ssh.Permissions{
			CriticalOptions: map[string]string{},
			Extensions:      map[string]string{},
		},
		Reserved: []byte{},
	}
	if err := signSSHCertificate(cert); err != nil {
		return nil, err
	}
	return cert, nil
}

// TestSSHCertificate_MarshalJSON checks that a certificate marshals to its
// base64-encoded wire form and that a nil certificate marshals to null.
func TestSSHCertificate_MarshalJSON(t *testing.T) {
	user, err := getSignedUserCertificate()
	require.NoError(t, err)
	host, err := getSignedHostCertificate()
	require.NoError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	type fields struct {
		Certificate *ssh.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		want    []byte
		wantErr bool
	}{
		{"nil", fields{Certificate: nil}, []byte("null"), false},
		{"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false},
		{"host", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := SSHCertificate{
				Certificate: tt.fields.Certificate,
			}
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestSSHCertificate_UnmarshalJSON checks decoding of valid certificates and
// the rejection of malformed input.
func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
	user, err := getSignedUserCertificate()
	require.NoError(t, err)
	host, err := getSignedHostCertificate()
	require.NoError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
	keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())

	type args struct {
		data []byte
	}
	tests := []struct {
		name    string
		args    args
		want    *ssh.Certificate
		wantErr bool
	}{
		{"null", args{[]byte(`null`)}, nil, false},
		{"empty", args{[]byte(`""`)}, nil, false},
		{"user", args{[]byte(`"` + userB64 + `"`)}, user, false},
		{"host", args{[]byte(`"` + hostB64 + `"`)}, host, false},
		{"bad-string", args{[]byte(userB64)}, nil, true},
		{"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true},
		{"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true},
		// A plain public key is valid key material but not a certificate.
		{"bad-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &SSHCertificate{}
			if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
				t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			}
			if !reflect.DeepEqual(tt.want, c.Certificate) {
				t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want)
			}
		})
	}
}

// TestSignSSHRequest_Validate exercises SSHSignRequest.Validate with valid
// and invalid field combinations.
func TestSignSSHRequest_Validate(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	badCSR := parseCertificateRequest(csrPEM)
	badCSR.SignatureAlgorithm = x509.SHA1WithRSA

	type fields struct {
		PublicKey        []byte
		OTT              string
		CertType         string
		Principals       []string
		ValidAfter       TimeDuration
		ValidBefore      TimeDuration
		AddUserPublicKey []byte
		KeyID            string
		IdentityCSR      CertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
		{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
		{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
		{"ok-keyID", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{}}, false},
		{"ok-identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: csr}}, false},
		{"key-nil", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"key-empty", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: badCSR}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &SSHSignRequest{
				PublicKey:        tt.fields.PublicKey,
				OTT:              tt.fields.OTT,
				CertType:         tt.fields.CertType,
				Principals:       tt.fields.Principals,
				ValidAfter:       tt.fields.ValidAfter,
				ValidBefore:      tt.fields.ValidBefore,
				AddUserPublicKey: tt.fields.AddUserPublicKey,
				KeyID:            tt.fields.KeyID,
				IdentityCSR:      tt.fields.IdentityCSR,
			}
			if err := s.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// Test_SSHSign drives the SSHSign handler against a mocked authority and
// checks the status code and, for successful cases, the response body.
func Test_SSHSign(t *testing.T) {
	user, err := getSignedUserCertificate()
	require.NoError(t, err)
	host, err := getSignedHostCertificate()
	require.NoError(t, err)

	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	userReq, err := json.Marshal(SSHSignRequest{
		PublicKey: user.Key.Marshal(),
		OTT:       "ott",
	})
	require.NoError(t, err)
	hostReq, err := json.Marshal(SSHSignRequest{
		PublicKey: host.Key.Marshal(),
		OTT:       "ott",
	})
	require.NoError(t, err)
	userAddReq, err := json.Marshal(SSHSignRequest{
		PublicKey:        user.Key.Marshal(),
		OTT:              "ott",
		AddUserPublicKey: user.Key.Marshal(),
	})
	require.NoError(t, err)
	userIdentityReq, err := json.Marshal(SSHSignRequest{
		PublicKey:   user.Key.Marshal(),
		OTT:         "ott",
		IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)},
	})
	require.NoError(t, err)
	identityCerts := []*x509.Certificate{
		parseCertificate(certPEM),
	}
	identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)

	tests := []struct {
		name         string
		req          []byte
		authErr      error
		signCert     *ssh.Certificate
		signErr      error
		addUserCert  *ssh.Certificate
		addUserErr   error
		tlsSignCerts []*x509.Certificate
		tlsSignErr   error
		body         []byte
		statusCode   int
	}{
		{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, userB64)), http.StatusCreated},
		{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, hostB64)), http.StatusCreated},
		{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q,"addUserCrt":%q}`, userB64, userB64)), http.StatusCreated},
		{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":%q,"identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
		{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-addUserPublicKey", []byte(fmt.Sprintf(`{"publicKey":%q,"ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
		{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
		{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
		{"fail-user-identity", userIdentityReq, nil, user, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
					return []provisioner.SignOption{}, tt.authErr
				},
				signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
					return tt.signCert, tt.signErr
				},
				signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
					return tt.addUserCert, tt.addUserErr
				},
				signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
					return tt.tlsSignCerts, tt.tlsSignErr
				},
			})

			req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
			w := httptest.NewRecorder()
			SSHSign(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SignSSH unexpected error = %v", err)
			}
			// Bodies are only compared for successful responses.
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SignSSH Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// Test_SSHRoots drives the SSHRoots handler against a mocked authority.
func Test_SSHRoots(t *testing.T) {
	user, err := ssh.NewPublicKey(sshUserKey.Public())
	require.NoError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())

	host, err := ssh.NewPublicKey(sshHostKey.Public())
	require.NoError(t, err)
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	tests := []struct {
		name       string
		keys       *authority.SSHKeys
		keysErr    error
		body       []byte
		statusCode int
	}{
		{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
		{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
		{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
		{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
		{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
		{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
					return tt.keys, tt.keysErr
				},
			})

			req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
			w := httptest.NewRecorder()
			SSHRoots(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHRoots Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// Test_SSHFederation drives the SSHFederation handler against a mocked
// authority.
func Test_SSHFederation(t *testing.T) {
	user, err := ssh.NewPublicKey(sshUserKey.Public())
	require.NoError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())

	host, err := ssh.NewPublicKey(sshHostKey.Public())
	require.NoError(t, err)
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	tests := []struct {
		name       string
		keys       *authority.SSHKeys
		keysErr    error
		body       []byte
		statusCode int
	}{
		{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
		{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
		{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
		{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
		{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
		{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
					return tt.keys, tt.keysErr
				},
			})

			req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
			w := httptest.NewRecorder()
			SSHFederation(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHFederation Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// Test_SSHConfig drives the SSHConfig handler against a mocked authority.
func Test_SSHConfig(t *testing.T) {
	userOutput := []templates.Output{
		{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
		{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
	}
	hostOutput := []templates.Output{
		{Name: "sshd_config.tpl", Type: templates.Snippet, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("TrustedUserCAKeys /etc/ssh/ca.pub")},
		{Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte("ecdsa-sha2-nistp256 AAAA...=")},
	}
	userJSON, err := json.Marshal(userOutput)
	require.NoError(t, err)
	hostJSON, err := json.Marshal(hostOutput)
	require.NoError(t, err)

	tests := []struct {
		name       string
		req        string
		output     []templates.Output
		err        error
		body       []byte
		statusCode int
	}{
		{"user", `{"type":"user"}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK},
		{"host", `{"type":"host"}`, hostOutput, nil, []byte(fmt.Sprintf(`{"hostTemplates":%s}`, hostJSON)), http.StatusOK},
		{"noType", `{}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK},
		{"badType", `{"type":"bad"}`, userOutput, nil, nil, http.StatusBadRequest},
		{"badData", `{"type":"user","data":{"bad"}}`, userOutput, nil, nil, http.StatusBadRequest},
		{"error", `{"type": "user"}`, nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
					return tt.output, tt.err
				},
			})

			req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
			w := httptest.NewRecorder()
			SSHConfig(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHConfig Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// Test_SSHCheckHost drives the SSHCheckHost handler against a mocked
// authority.
func Test_SSHCheckHost(t *testing.T) {
	tests := []struct {
		name       string
		req        string
		exists     bool
		err        error
		body       []byte
		statusCode int
	}{
		{"true", `{"type":"host","principal":"foo.example.com"}`, true, nil, []byte(`{"exists":true}`), http.StatusOK},
		{"false", `{"type":"host","principal":"bar.example.com"}`, false, nil, []byte(`{"exists":false}`), http.StatusOK},
		{"badType", `{"type":"user","principal":"bar.example.com"}`, false, nil, nil, http.StatusBadRequest},
		{"badPrincipal", `{"type":"host","principal":""}`, false, nil, nil, http.StatusBadRequest},
		{"badRequest", `{"foo"}`, false, nil, nil, http.StatusBadRequest},
		{"error", `{"type":"host","principal":"foo.example.com"}`, false, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
					return tt.exists, tt.err
				},
			})

			req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
			w := httptest.NewRecorder()
			SSHCheckHost(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHCheckHost Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// Test_SSHGetHosts drives the SSHGetHosts handler against a mocked authority.
func Test_SSHGetHosts(t *testing.T) {
	hosts := []authority.Host{
		{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
		{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
	}
	hostsJSON, err := json.Marshal(hosts)
	require.NoError(t, err)

	tests := []struct {
		name       string
		hosts      []authority.Host
		err        error
		body       []byte
		statusCode int
	}{
		{"ok", hosts, nil, []byte(fmt.Sprintf(`{"hosts":%s}`, hostsJSON)), http.StatusOK},
		{"empty (array)", []authority.Host{}, nil, []byte(`{"hosts":[]}`), http.StatusOK},
		{"empty (nil)", nil, nil, []byte(`{"hosts":null}`), http.StatusOK},
		{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
					return tt.hosts, tt.err
				},
			})

			req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
			w := httptest.NewRecorder()
			SSHGetHosts(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHGetHosts Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// Test_SSHBastion drives the SSHBastion handler against a mocked authority.
func Test_SSHBastion(t *testing.T) {
	bastion := &authority.Bastion{
		Hostname: "bastion.local",
	}
	bastionPort := &authority.Bastion{
		Hostname: "bastion.local",
		Port:     "2222",
	}

	tests := []struct {
		name       string
		bastion    *authority.Bastion
		bastionErr error
		req        []byte
		body       []byte
		statusCode int
	}{
		{"ok", bastion, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local"}}`), http.StatusOK},
		{"ok-port", bastionPort, nil, []byte(`{"hostname":"host.local","user":"user"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local","port":"2222"}}`), http.StatusOK},
		{"empty", nil, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local"}`), http.StatusOK},
		{"bad json", bastion, nil, []byte(`bad json`), nil, http.StatusBadRequest},
		{"bad request", bastion, nil, []byte(`{"hostname": ""}`), nil, http.StatusBadRequest},
		{"error", nil, fmt.Errorf("an error"), []byte(`{"hostname":"host.local"}`), nil, http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockMustAuthority(t, &mockAuthority{
				getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
					return tt.bastion, tt.bastionErr
				},
			})

			req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
			w := httptest.NewRecorder()
			SSHBastion(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHBastion unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHBastion Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

// TestSSHPublicKey_MarshalJSON checks that a key marshals to its
// base64-encoded wire form and that a nil receiver or nil key marshals to
// null.
func TestSSHPublicKey_MarshalJSON(t *testing.T) {
	key, err := ssh.NewPublicKey(sshUserKey.Public())
	require.NoError(t, err)
	keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())

	tests := []struct {
		name      string
		publicKey *SSHPublicKey
		want      []byte
		wantErr   bool
	}{
		{"ok", &SSHPublicKey{PublicKey: key}, []byte(`"` + keyB64 + `"`), false},
		{"nil-receiver", nil, []byte("null"), false},
		{"nil-key", &SSHPublicKey{PublicKey: nil}, []byte("null"), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := tt.publicKey.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("SSHPublicKey.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SSHPublicKey.MarshalJSON() = %s, want %s", got, tt.want)
			}
		})
	}
}

func TestSSHPublicKey_UnmarshalJSON(t *testing.T) {
	key, err := ssh.NewPublicKey(sshUserKey.Public())
	require.NoError(t, err)
	keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())

	type args struct {
		data []byte
	}
	tests := []struct {
		name    string
		args    args
		want    *SSHPublicKey
		wantErr bool
}{ {"ok", args{[]byte(`"` + keyB64 + `"`)}, &SSHPublicKey{PublicKey: key}, false}, {"empty", args{[]byte(`""`)}, &SSHPublicKey{}, false}, {"null", args{[]byte(`null`)}, &SSHPublicKey{}, false}, {"noString", args{[]byte("123")}, &SSHPublicKey{}, true}, {"badB64", args{[]byte(`"bad"`)}, &SSHPublicKey{}, true}, {"badKey", args{[]byte(`"Zm9vYmFyCg=="`)}, &SSHPublicKey{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &SSHPublicKey{} if err := p.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { t.Errorf("SSHPublicKey.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } if !reflect.DeepEqual(p, tt.want) { t.Errorf("SSHPublicKey.UnmarshalJSON() = %v, want %v", p, tt.want) } }) } } func Test_identityModifier_Enforce(t *testing.T) { now := time.Now() type fields struct { Identity *url.URL NotBefore time.Time NotAfter time.Time } type args struct { cert *x509.Certificate } tests := []struct { name string fields fields args args want *x509.Certificate assertion assert.ErrorAssertionFunc }{ {"ok", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)}, args{&x509.Certificate{}}, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}}, }, assert.NoError}, {"ok exists", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)}, args{&x509.Certificate{ URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}}, }}, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}}, }, assert.NoError}, {"ok append", fields{&url.URL{Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, now, now.Add(time.Hour)}, args{&x509.Certificate{ URIs: []*url.URL{{Scheme: "urn", Opaque: 
"uuid:27bb66db-e12a-4ff6-9161-aa6b0a98f914"}}, }}, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), URIs: []*url.URL{ {Scheme: "urn", Opaque: "uuid:27bb66db-e12a-4ff6-9161-aa6b0a98f914"}, {Scheme: "urn", Opaque: "uuid:0c4670b2-d9f1-42bb-9045-184836f16733"}, }, }, assert.NoError}, {"ok no identity", fields{nil, now, now.Add(time.Hour)}, args{&x509.Certificate{}}, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }, assert.NoError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := &identityModifier{ Identity: tt.fields.Identity, NotBefore: tt.fields.NotBefore, NotAfter: tt.fields.NotAfter, } tt.assertion(t, m.Enforce(tt.args.cert)) }) } } func Test_getIdentityURI(t *testing.T) { id, err := uuid.Parse("54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1") require.NoError(t, err) u, err := url.Parse(id.URN()) require.NoError(t, err) type args struct { cr *x509.CertificateRequest } tests := []struct { name string args args want *url.URL }{ {"ok", args{&x509.CertificateRequest{ URIs: []*url.URL{u}, }}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, {"ok multiple", args{&x509.CertificateRequest{ URIs: []*url.URL{u, {Scheme: "urn", Opaque: "uuid:f0e74f3a-95fe-4cf6-98e3-68e55b69ba48"}}, }}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, {"ok multiple with invalid", args{&x509.CertificateRequest{ URIs: []*url.URL{{Scheme: "urn", Opaque: "uuid:f0e74f3a+95fe+4cf6+98e3+68e55b69ba48"}, u}, }}, &url.URL{Scheme: "urn", Opaque: "uuid:54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, {"ok missing", args{&x509.CertificateRequest{ URIs: []*url.URL{{Scheme: "https", Host: "example.com", Path: "/54a2ec9d-a7d9-4b53-8f9a-efcd275e35e1"}}, }}, nil}, {"ok empty", args{&x509.CertificateRequest{}}, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.Equal(t, tt.want, getIdentityURI(tt.args.cr)) }) } } ================================================ FILE: 
authority/admin/api/acme.go ================================================ package api import ( "fmt" "net/http" "github.com/smallstep/linkedca" "google.golang.org/protobuf/types/known/timestamppb" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" ) // CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests type CreateExternalAccountKeyRequest struct { Reference string `json:"reference"` } // Validate validates a new ACME EAB Key request body. func (r *CreateExternalAccountKeyRequest) Validate() error { if len(r.Reference) > 256 { // an arbitrary, but sensible (IMO), limit return fmt.Errorf("reference length %d exceeds the maximum (256)", len(r.Reference)) } return nil } // GetExternalAccountKeysResponse is the type for GET /admin/acme/eab responses type GetExternalAccountKeysResponse struct { EAKs []*linkedca.EABKey `json:"eaks"` NextCursor string `json:"nextCursor"` } // requireEABEnabled is a middleware that ensures ACME EAB is enabled // before serving requests that act on ACME EAB credentials. 
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) acmeProvisioner := prov.GetDetails().GetACME() if acmeProvisioner == nil { render.Error(w, r, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName())) return } if !acmeProvisioner.RequireEab { render.Error(w, r, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName())) return } next(w, r) } } // ACMEAdminResponder is responsible for writing ACME admin responses type ACMEAdminResponder interface { GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) } // acmeAdminResponder implements ACMEAdminResponder. type acmeAdminResponder struct{} // NewACMEAdminResponder returns a new ACMEAdminResponder func NewACMEAdminResponder() ACMEAdminResponder { return &acmeAdminResponder{} } // GetExternalAccountKeys writes the response for the EAB keys GET endpoint func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this 
functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey { if k == nil { return nil } eak := &linkedca.EABKey{ Id: k.ID, HmacKey: k.HmacKey, Provisioner: k.ProvisionerID, Reference: k.Reference, Account: k.AccountID, CreatedAt: timestamppb.New(k.CreatedAt), BoundAt: timestamppb.New(k.BoundAt), } if k.Policy != nil { eak.Policy = &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{}, Deny: &linkedca.X509Names{}, }, } eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames } return eak } func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey { if k == nil { return nil } eak := &acme.ExternalAccountKey{ ID: k.Id, ProvisionerID: k.Provisioner, Reference: k.Reference, AccountID: k.Account, HmacKey: k.HmacKey, CreatedAt: k.CreatedAt.AsTime(), BoundAt: k.BoundAt.AsTime(), } if policy := k.GetPolicy(); policy != nil { eak.Policy = &acme.Policy{} if x509 := policy.GetX509(); x509 != nil { eak.Policy.X509 = acme.X509Policy{} if allow := x509.GetAllow(); allow != nil { eak.Policy.X509.Allowed = acme.PolicyNames{} eak.Policy.X509.Allowed.DNSNames = allow.Dns eak.Policy.X509.Allowed.IPRanges = allow.Ips } if deny := x509.GetDeny(); deny != nil { eak.Policy.X509.Denied = acme.PolicyNames{} eak.Policy.X509.Denied.DNSNames = deny.Dns eak.Policy.X509.Denied.IPRanges = deny.Ips } eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames } } return eak } ================================================ FILE: authority/admin/api/acme_test.go ================================================ package api import ( "bytes" "context" "encoding/json" "io" "net/http" "net/http/httptest" "reflect" "strings" 
"testing" "time" "github.com/go-chi/chi/v5" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/smallstep/linkedca" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/admin" ) func readProtoJSON(r io.ReadCloser, m proto.Message) error { defer r.Close() data, err := io.ReadAll(r) if err != nil { return err } return protojson.Unmarshal(data, m) } func mockMustAuthority(t *testing.T, a adminAuthority) { t.Helper() fn := mustAuthority t.Cleanup(func() { mustAuthority = fn }) mustAuthority = func(ctx context.Context) adminAuthority { return a } } func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/prov.GetDetails": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") err.Message = "error getting ACME details for provisioner 'provName'" return test{ ctx: ctx, err: err, statusCode: 500, } }, "fail/prov.GetDetails.GetACME": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", Details: &linkedca.ProvisionerDetails{}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'") err.Message = "error getting ACME details for provisioner 'provName'" return test{ ctx: ctx, err: err, statusCode: 500, } }, "ok/eab-disabled": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{ RequireEab: false, }, }, }, } ctx := 
linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") err.Message = "ACME EAB not enabled for provisioner 'provName'" return test{ ctx: ctx, err: err, statusCode: 400, } }, "ok/eab-enabled": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{ RequireEab: true, }, }, }, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("GET", "/foo", http.NoBody).WithContext(tc.ctx) w := httptest.NewRecorder() requireEABEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } }) } } func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { type fields struct { Reference string } tests := []struct { name string fields fields wantErr bool }{ { name: "fail/reference-too-long", fields: fields{ Reference: strings.Repeat("A", 257), }, wantErr: true, }, { name: "ok/empty-reference", fields: fields{ Reference: "", }, wantErr: false, }, { name: "ok", fields: fields{ Reference: "my-eab-reference", }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t 
*testing.T) { r := &CreateExternalAccountKeyRequest{ Reference: tt.fields.Reference, } if err := r.Validate(); (err != nil) != tt.wantErr { t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestHandler_CreateExternalAccountKey(t *testing.T) { type test struct { ctx context.Context statusCode int err *admin.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) return test{ ctx: ctx, statusCode: 501, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: http.StatusNotImplemented, Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm", Detail: "not implemented", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("POST", "/foo", http.NoBody) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() acmeResponder := NewACMEAdminResponder() acmeResponder.CreateExternalAccountKey(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestHandler_DeleteExternalAccountKey(t *testing.T) { type test struct { ctx context.Context statusCode int err *admin.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") 
chiCtx.URLParams.Add("id", "keyID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) return test{ ctx: ctx, statusCode: 501, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: http.StatusNotImplemented, Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm", Detail: "not implemented", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { req := httptest.NewRequest("DELETE", "/foo", http.NoBody) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() acmeResponder := NewACMEAdminResponder() acmeResponder.DeleteExternalAccountKey(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestHandler_GetExternalAccountKeys(t *testing.T) { type test struct { ctx context.Context statusCode int req *http.Request err *admin.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") req := httptest.NewRequest("GET", "/foo", http.NoBody) ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) return test{ ctx: ctx, statusCode: 501, req: req, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: http.StatusNotImplemented, Message: "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm", Detail: "not implemented", }, } }, } for name, prep := range tests { tc 
:= prep(t) t.Run(name, func(t *testing.T) { req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() acmeResponder := NewACMEAdminResponder() acmeResponder.GetExternalAccountKeys(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func Test_eakToLinked(t *testing.T) { tests := []struct { name string k *acme.ExternalAccountKey want *linkedca.EABKey }{ { name: "no-key", k: nil, want: nil, }, { name: "no-policy", k: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: nil, }, want: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: nil, }, }, { name: "with-policy", k: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.local"}, IPRanges: []string{"10.0.0.0/24"}, }, Denied: acme.PolicyNames{ DNSNames: []string{"badhost.local"}, 
IPRanges: []string{"10.0.0.30"}, }, }, }, }, want: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"10.0.0.0/24"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"10.0.0.30"}, }, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) { t.Errorf("eakToLinked() = %v, want %v", got, tt.want) } }) } } func Test_linkedEAKToCertificates(t *testing.T) { tests := []struct { name string k *linkedca.EABKey want *acme.ExternalAccountKey }{ { name: "no-key", k: nil, want: nil, }, { name: "no-policy", k: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: nil, }, want: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: nil, }, }, { name: "no-x509-policy", k: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: &linkedca.Policy{}, }, want: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", 
Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: &acme.Policy{}, }, }, { name: "with-x509-policy", k: &linkedca.EABKey{ Id: "keyID", Provisioner: "provID", HmacKey: []byte{1, 3, 3, 7}, Reference: "ref", Account: "accID", CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)), BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)), Policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"10.0.0.0/24"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"10.0.0.30"}, }, AllowWildcardNames: true, }, }, }, want: &acme.ExternalAccountKey{ ID: "keyID", ProvisionerID: "provID", Reference: "ref", AccountID: "accID", HmacKey: []byte{1, 3, 3, 7}, CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour), BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC), Policy: &acme.Policy{ X509: acme.X509Policy{ Allowed: acme.PolicyNames{ DNSNames: []string{"*.local"}, IPRanges: []string{"10.0.0.0/24"}, }, Denied: acme.PolicyNames{ DNSNames: []string{"badhost.local"}, IPRanges: []string{"10.0.0.30"}, }, AllowWildcardNames: true, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) { t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/admin/api/admin.go ================================================ package api import ( "context" "net/http" "github.com/go-chi/chi/v5" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" 
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) type adminAuthority interface { LoadProvisionerByName(string) (provisioner.Interface, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) IsAdminAPIEnabled() bool LoadAdminByID(id string) (*linkedca.Admin, bool) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) RemoveAdmin(ctx context.Context, id string) error AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error LoadProvisionerByID(id string) (provisioner.Interface, error) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error RemoveProvisioner(ctx context.Context, id string) error GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) RemoveAuthorityPolicy(ctx context.Context) error } // CreateAdminRequest represents the body for a CreateAdmin request. type CreateAdminRequest struct { Subject string `json:"subject"` Provisioner string `json:"provisioner"` Type linkedca.Admin_Type `json:"type"` } // Validate validates a new-admin request body. 
func (car *CreateAdminRequest) Validate() error { if car.Subject == "" { return admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty") } if car.Provisioner == "" { return admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty") } switch car.Type { case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN: default: return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type") } return nil } // GetAdminsResponse for returning a list of admins. type GetAdminsResponse struct { Admins []*linkedca.Admin `json:"admins"` NextCursor string `json:"nextCursor"` } // UpdateAdminRequest represents the body for a UpdateAdmin request. type UpdateAdminRequest struct { Type linkedca.Admin_Type `json:"type"` } // Validate validates a new-admin request body. func (uar *UpdateAdminRequest) Validate() error { switch uar.Type { case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN: default: return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type") } return nil } // DeleteResponse is the resource for successful DELETE responses. type DeleteResponse struct { Status string `json:"status"` } // GetAdmin returns the requested admin, or an error. func GetAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) if !ok { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) return } render.ProtoJSON(w, adm) } // GetAdmins returns a segment of admins associated with the authority. 
func GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } render.JSON(w, r, &GetAdminsResponse{ Admins: admins, NextCursor: nextCursor, }) } // CreateAdmin creates a new admin. func CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } auth := mustAuthority(r.Context()) p, err := auth.LoadProvisionerByName(body.Provisioner) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return } adm := &linkedca.Admin{ ProvisionerId: p.GetID(), Subject: body.Subject, Type: body.Type, } // Store to authority collection. if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error storing admin")) return } render.ProtoJSONStatus(w, adm, http.StatusCreated) } // DeleteAdmin deletes admin. func DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } render.JSON(w, r, &DeleteResponse{Status: "ok"}) } // UpdateAdmin updates an existing admin. 
func UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { render.Error(w, r, err) return } id := chi.URLParam(r, "id") auth := mustAuthority(r.Context()) adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error updating admin %s", id)) return } render.ProtoJSON(w, adm) } ================================================ FILE: authority/admin/api/admin_test.go ================================================ package api import ( "bytes" "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/go-chi/chi/v5" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/types/known/timestamppb" "github.com/smallstep/linkedca" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) type mockAdminAuthority struct { MockLoadProvisionerByName func(name string) (provisioner.Interface, error) MockGetProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two MockErr error MockIsAdminAPIEnabled func() bool MockLoadAdminByID func(id string) (*linkedca.Admin, bool) MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error) MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) MockRemoveAdmin func(ctx context.Context, id string) error MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error) MockStoreProvisioner func(ctx 
context.Context, prov *linkedca.Provisioner) error MockLoadProvisionerByID func(id string) (provisioner.Interface, error) MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error MockRemoveProvisioner func(ctx context.Context, id string) error MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) MockRemoveAuthorityPolicy func(ctx context.Context) error } func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { if m.MockIsAdminAPIEnabled != nil { return m.MockIsAdminAPIEnabled() } return m.MockRet1.(bool) } func (m *mockAdminAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { if m.MockLoadProvisionerByName != nil { return m.MockLoadProvisionerByName(name) } return m.MockRet1.(provisioner.Interface), m.MockErr } func (m *mockAdminAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { if m.MockGetProvisioners != nil { return m.MockGetProvisioners(nextCursor, limit) } return m.MockRet1.(provisioner.List), m.MockRet2.(string), m.MockErr } func (m *mockAdminAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) { if m.MockLoadAdminByID != nil { return m.MockLoadAdminByID(id) } return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool) } func (m *mockAdminAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { if m.MockGetAdmins != nil { return m.MockGetAdmins(cursor, limit) } return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr } func (m *mockAdminAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { if m.MockStoreAdmin != nil { return m.MockStoreAdmin(ctx, adm, prov) } return m.MockErr } func (m *mockAdminAuthority) 
UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { if m.MockUpdateAdmin != nil { return m.MockUpdateAdmin(ctx, id, nu) } return m.MockRet1.(*linkedca.Admin), m.MockErr } func (m *mockAdminAuthority) RemoveAdmin(ctx context.Context, id string) error { if m.MockRemoveAdmin != nil { return m.MockRemoveAdmin(ctx, id) } return m.MockErr } func (m *mockAdminAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { if m.MockAuthorizeAdminToken != nil { return m.MockAuthorizeAdminToken(r, token) } return m.MockRet1.(*linkedca.Admin), m.MockErr } func (m *mockAdminAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { if m.MockStoreProvisioner != nil { return m.MockStoreProvisioner(ctx, prov) } return m.MockErr } func (m *mockAdminAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) { if m.MockLoadProvisionerByID != nil { return m.MockLoadProvisionerByID(id) } return m.MockRet1.(provisioner.Interface), m.MockErr } func (m *mockAdminAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { if m.MockUpdateProvisioner != nil { return m.MockUpdateProvisioner(ctx, nu) } return m.MockErr } func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) error { if m.MockRemoveProvisioner != nil { return m.MockRemoveProvisioner(ctx, id) } return m.MockErr } func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { if m.MockGetAuthorityPolicy != nil { return m.MockGetAuthorityPolicy(ctx) } return m.MockRet1.(*linkedca.Policy), m.MockErr } func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { if m.MockCreateAuthorityPolicy != nil { return m.MockCreateAuthorityPolicy(ctx, adm, policy) } return m.MockRet1.(*linkedca.Policy), m.MockErr } func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx 
context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) { if m.MockUpdateAuthorityPolicy != nil { return m.MockUpdateAuthorityPolicy(ctx, adm, policy) } return m.MockRet1.(*linkedca.Policy), m.MockErr } func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error { if m.MockRemoveAuthorityPolicy != nil { return m.MockRemoveAuthorityPolicy(ctx) } return m.MockErr } func TestCreateAdminRequest_Validate(t *testing.T) { type fields struct { Subject string Provisioner string Type linkedca.Admin_Type } tests := []struct { name string fields fields err *admin.Error }{ { name: "fail/subject-empty", fields: fields{ Subject: "", Provisioner: "", Type: 0, }, err: admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty"), }, { name: "fail/provisioner-empty", fields: fields{ Subject: "admin", Provisioner: "", Type: 0, }, err: admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty"), }, { name: "fail/invalid-type", fields: fields{ Subject: "admin", Provisioner: "prov", Type: -1, }, err: admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type"), }, { name: "ok", fields: fields{ Subject: "admin", Provisioner: "prov", Type: linkedca.Admin_SUPER_ADMIN, }, err: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { car := &CreateAdminRequest{ Subject: tt.fields.Subject, Provisioner: tt.fields.Provisioner, Type: tt.fields.Type, } err := car.Validate() if (err != nil) != (tt.err != nil) { t.Errorf("CreateAdminRequest.Validate() error = %v, wantErr %v", err, (tt.err != nil)) return } if err != nil { assert.Type(t, &admin.Error{}, err) var adminErr *admin.Error if assert.True(t, errors.As(err, &adminErr)) { assert.Equals(t, tt.err.Type, adminErr.Type) assert.Equals(t, tt.err.Detail, adminErr.Detail) assert.Equals(t, tt.err.Status, adminErr.Status) assert.Equals(t, tt.err.Message, adminErr.Message) } } }) } } func TestUpdateAdminRequest_Validate(t *testing.T) { type fields struct 
{ Type linkedca.Admin_Type } tests := []struct { name string fields fields err *admin.Error }{ { name: "fail/invalid-type", fields: fields{ Type: -1, }, err: admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type"), }, { name: "ok", fields: fields{ Type: linkedca.Admin_SUPER_ADMIN, }, err: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { uar := &UpdateAdminRequest{ Type: tt.fields.Type, } err := uar.Validate() if (err != nil) != (tt.err != nil) { t.Errorf("CreateAdminRequest.Validate() error = %v, wantErr %v", err, (tt.err != nil)) return } if err != nil { assert.Type(t, &admin.Error{}, err) var ae *admin.Error if assert.True(t, errors.As(err, &ae)) { assert.Equals(t, tt.err.Type, ae.Type) assert.Equals(t, tt.err.Detail, ae.Detail) assert.Equals(t, tt.err.Status, ae.Status) assert.Equals(t, tt.err.Message, ae.Message) } } }) } } func TestHandler_GetAdmin(t *testing.T) { type test struct { ctx context.Context auth adminAuthority statusCode int err *admin.Error adm *linkedca.Admin } var tests = map[string]func(t *testing.T) test{ "fail/auth.LoadAdminByID-not-found": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) { assert.Equals(t, "adminID", id) return nil, false }, } return test{ ctx: ctx, auth: auth, statusCode: 404, err: &admin.Error{ Type: admin.ErrorNotFoundType.String(), Status: 404, Detail: "resource not found", Message: "admin adminID not found", }, } }, "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) createdAt := time.Now() var deletedAt time.Time adm := &linkedca.Admin{ Id: "adminID", AuthorityId: "authorityID", Subject: "admin", ProvisionerId: "provID", Type: linkedca.Admin_SUPER_ADMIN, 
CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } auth := &mockAdminAuthority{ MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) { assert.Equals(t, "adminID", id) return adm, true }, } return test{ ctx: ctx, auth: auth, statusCode: 200, err: nil, adm: adm, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", http.NoBody) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() GetAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } adm := &linkedca.Admin{} err := readProtoJSON(res.Body, adm) assert.FatalError(t, err) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} if !cmp.Equal(tc.adm, adm, opts...) 
{ t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.adm, adm, opts...)) } }) } } func TestHandler_GetAdmins(t *testing.T) { type test struct { ctx context.Context auth adminAuthority req *http.Request statusCode int err *admin.Error resp GetAdminsResponse } var tests = map[string]func(t *testing.T) test{ "fail/parse-cursor": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo?limit=A", http.NoBody) return test{ ctx: context.Background(), req: req, statusCode: 400, err: &admin.Error{ Status: 400, Type: admin.ErrorBadRequestType.String(), Detail: "bad request", Message: "error parsing cursor and limit from query params: limit 'A' is not an integer: strconv.Atoi: parsing \"A\": invalid syntax", }, } }, "fail/auth.GetAdmins": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) auth := &mockAdminAuthority{ MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) return nil, "", errors.New("force") }, } return test{ ctx: context.Background(), req: req, auth: auth, statusCode: 500, err: &admin.Error{ Status: 500, Type: admin.ErrorServerInternalType.String(), Detail: "the server experienced an internal error", Message: "error retrieving paginated admins: force", }, } }, "ok": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) createdAt := time.Now() var deletedAt time.Time adm1 := &linkedca.Admin{ Id: "adminID1", AuthorityId: "authorityID1", Subject: "admin1", ProvisionerId: "provID", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } adm2 := &linkedca.Admin{ Id: "adminID2", AuthorityId: "authorityID", Subject: "admin2", ProvisionerId: "provID", Type: linkedca.Admin_ADMIN, CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } auth := &mockAdminAuthority{ MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) 
{ assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) return []*linkedca.Admin{ adm1, adm2, }, "nextCursorValue", nil }, } return test{ ctx: context.Background(), req: req, auth: auth, statusCode: 200, err: nil, resp: GetAdminsResponse{ Admins: []*linkedca.Admin{ adm1, adm2, }, NextCursor: "nextCursorValue", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() GetAdmins(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } response := GetAdminsResponse{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} if !cmp.Equal(tc.resp, response, opts...) { t.Errorf("GetAdmins diff =\n%s", cmp.Diff(tc.resp, response, opts...)) } }) } } func TestHandler_CreateAdmin(t *testing.T) { type test struct { ctx context.Context auth adminAuthority body []byte statusCode int err *admin.Error adm *linkedca.Admin } var tests = map[string]func(t *testing.T) test{ "fail/ReadJSON": func(t *testing.T) test { body := []byte("{!?}") return test{ ctx: context.Background(), body: body, statusCode: 400, err: &admin.Error{ Type: admin.ErrorBadRequestType.String(), Status: 400, Detail: "bad request", Message: "error reading request body: error decoding json: invalid character '!' 
looking for beginning of object key string", }, } }, "fail/validate": func(t *testing.T) test { req := CreateAdminRequest{ Subject: "", Provisioner: "", Type: -1, } body, err := json.Marshal(req) assert.FatalError(t, err) return test{ ctx: context.Background(), body: body, statusCode: 400, err: &admin.Error{ Type: admin.ErrorBadRequestType.String(), Status: 400, Detail: "bad request", Message: "subject cannot be empty", }, } }, "fail/auth.LoadProvisionerByName": func(t *testing.T) test { req := CreateAdminRequest{ Subject: "admin", Provisioner: "prov", Type: linkedca.Admin_SUPER_ADMIN, } body, err := json.Marshal(req) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "prov", name) return nil, errors.New("force") }, } return test{ ctx: context.Background(), body: body, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner prov: force", }, } }, "fail/auth.StoreAdmin": func(t *testing.T) test { req := CreateAdminRequest{ Subject: "admin", Provisioner: "prov", Type: linkedca.Admin_SUPER_ADMIN, } body, err := json.Marshal(req) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "prov", name) return &provisioner.ACME{ ID: "provID", Name: "prov", }, nil }, MockStoreAdmin: func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { assert.Equals(t, "admin", adm.Subject) assert.Equals(t, "provID", prov.GetID()) return errors.New("force") }, } return test{ ctx: context.Background(), body: body, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error storing admin: force", }, } }, "ok": func(t 
*testing.T) test { req := CreateAdminRequest{ Subject: "admin", Provisioner: "prov", Type: linkedca.Admin_SUPER_ADMIN, } body, err := json.Marshal(req) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "prov", name) return &provisioner.ACME{ ID: "provID", Name: "prov", }, nil }, MockStoreAdmin: func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { assert.Equals(t, "admin", adm.Subject) assert.Equals(t, "provID", prov.GetID()) return nil }, } return test{ ctx: context.Background(), body: body, auth: auth, statusCode: 201, err: nil, adm: &linkedca.Admin{ ProvisionerId: "provID", Subject: "admin", Type: linkedca.Admin_SUPER_ADMIN, }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() CreateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } adm := &linkedca.Admin{} err := readProtoJSON(res.Body, adm) assert.FatalError(t, err) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} if !cmp.Equal(tc.adm, adm, opts...) 
{ t.Errorf("h.CreateAdmin diff =\n%s", cmp.Diff(tc.adm, adm, opts...)) } }) } } func TestHandler_DeleteAdmin(t *testing.T) { type test struct { ctx context.Context auth adminAuthority statusCode int err *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/auth.RemoveAdmin": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockRemoveAdmin: func(ctx context.Context, id string) error { assert.Equals(t, "adminID", id) return errors.New("force") }, } return test{ ctx: ctx, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error deleting admin adminID: force", }, } }, "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockRemoveAdmin: func(ctx context.Context, id string) error { assert.Equals(t, "adminID", id) return nil }, } return test{ ctx: ctx, auth: auth, statusCode: 200, err: nil, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := httptest.NewRequest("DELETE", "/foo", http.NoBody) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() DeleteAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, adminErr.Detail) 
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) response := DeleteResponse{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equals(t, "ok", response.Status) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestHandler_UpdateAdmin(t *testing.T) { type test struct { ctx context.Context auth adminAuthority body []byte statusCode int err *admin.Error adm *linkedca.Admin } var tests = map[string]func(t *testing.T) test{ "fail/ReadJSON": func(t *testing.T) test { body := []byte("{!?}") return test{ ctx: context.Background(), body: body, statusCode: 400, err: &admin.Error{ Type: admin.ErrorBadRequestType.String(), Status: 400, Detail: "bad request", Message: "error reading request body: error decoding json: invalid character '!' looking for beginning of object key string", }, } }, "fail/validate": func(t *testing.T) test { req := UpdateAdminRequest{ Type: -1, } body, err := json.Marshal(req) assert.FatalError(t, err) return test{ ctx: context.Background(), body: body, statusCode: 400, err: &admin.Error{ Type: admin.ErrorBadRequestType.String(), Status: 400, Detail: "bad request", Message: "invalid value for admin type", }, } }, "fail/auth.UpdateAdmin": func(t *testing.T) test { req := UpdateAdminRequest{ Type: linkedca.Admin_ADMIN, } body, err := json.Marshal(req) assert.FatalError(t, err) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { assert.Equals(t, "adminID", id) assert.Equals(t, linkedca.Admin_ADMIN, nu.Type) return nil, errors.New("force") }, } return test{ ctx: ctx, body: body, auth: auth, statusCode: 500, err: &admin.Error{ Type: 
admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error updating admin adminID: force", }, } }, "ok": func(t *testing.T) test { req := UpdateAdminRequest{ Type: linkedca.Admin_ADMIN, } body, err := json.Marshal(req) assert.FatalError(t, err) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("id", "adminID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) adm := &linkedca.Admin{ Id: "adminID", ProvisionerId: "provID", Subject: "admin", Type: linkedca.Admin_SUPER_ADMIN, } auth := &mockAdminAuthority{ MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { assert.Equals(t, "adminID", id) assert.Equals(t, linkedca.Admin_ADMIN, nu.Type) return adm, nil }, } return test{ ctx: ctx, body: body, auth: auth, statusCode: 200, err: nil, adm: adm, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() UpdateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } adm := &linkedca.Admin{} err := readProtoJSON(res.Body, adm) assert.FatalError(t, err) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} if !cmp.Equal(tc.adm, adm, opts...) 
{ t.Errorf("h.UpdateAdmin diff =\n%s", cmp.Diff(tc.adm, adm, opts...)) } }) } } ================================================ FILE: authority/admin/api/handler.go ================================================ package api import ( "context" "net/http" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" ) var mustAuthority = func(ctx context.Context) adminAuthority { return authority.MustFromContext(ctx) } type router struct { acmeResponder ACMEAdminResponder policyResponder PolicyAdminResponder webhookResponder WebhookAdminResponder } type RouterOption func(*router) func WithACMEResponder(acmeResponder ACMEAdminResponder) RouterOption { return func(r *router) { r.acmeResponder = acmeResponder } } func WithPolicyResponder(policyResponder PolicyAdminResponder) RouterOption { return func(r *router) { r.policyResponder = policyResponder } } func WithWebhookResponder(webhookResponder WebhookAdminResponder) RouterOption { return func(r *router) { r.webhookResponder = webhookResponder } } // Route traffic and implement the Router interface. 
func Route(r api.Router, options ...RouterOption) { router := &router{} for _, fn := range options { fn(router) } authnz := func(next http.HandlerFunc) http.HandlerFunc { return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { return checkAction(next, true) } disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { return checkAction(next, false) } acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return authnz(loadProvisionerByName(requireEABEnabled(next))) } authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return authnz(enabledInStandalone(next)) } provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return authnz(disabledInStandalone(loadProvisionerByName(next))) } acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next))))) } webhookMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return authnz(loadProvisionerByName(next)) } // Provisioners r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner)) r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners)) r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner)) r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner)) r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner)) // Admins r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin)) r.MethodFunc("GET", "/admins", authnz(GetAdmins)) r.MethodFunc("POST", "/admins", authnz(CreateAdmin)) r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin)) r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin)) // ACME responder if router.acmeResponder != nil { // ACME External Account Binding Keys r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(router.acmeResponder.GetExternalAccountKeys)) 
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(router.acmeResponder.GetExternalAccountKeys)) r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(router.acmeResponder.CreateExternalAccountKey)) r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(router.acmeResponder.DeleteExternalAccountKey)) } // Policy responder if router.policyResponder != nil { // Policy - Authority r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(router.policyResponder.GetAuthorityPolicy)) r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(router.policyResponder.CreateAuthorityPolicy)) r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(router.policyResponder.UpdateAuthorityPolicy)) r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(router.policyResponder.DeleteAuthorityPolicy)) // Policy - Provisioner r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.GetProvisionerPolicy)) r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.CreateProvisionerPolicy)) r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.UpdateProvisionerPolicy)) r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(router.policyResponder.DeleteProvisionerPolicy)) // Policy - ACME Account r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.GetACMEAccountPolicy)) r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.GetACMEAccountPolicy)) r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.CreateACMEAccountPolicy)) r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", 
acmePolicyMiddleware(router.policyResponder.CreateACMEAccountPolicy)) r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.UpdateACMEAccountPolicy)) r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.UpdateACMEAccountPolicy)) r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(router.policyResponder.DeleteACMEAccountPolicy)) r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(router.policyResponder.DeleteACMEAccountPolicy)) } if router.webhookResponder != nil { r.MethodFunc("POST", "/provisioners/{provisionerName}/webhooks", webhookMiddleware(router.webhookResponder.CreateProvisionerWebhook)) r.MethodFunc("PUT", "/provisioners/{provisionerName}/webhooks/{webhookName}", webhookMiddleware(router.webhookResponder.UpdateProvisionerWebhook)) r.MethodFunc("DELETE", "/provisioners/{provisionerName}/webhooks/{webhookName}", webhookMiddleware(router.webhookResponder.DeleteProvisionerWebhook)) } } ================================================ FILE: authority/admin/api/middleware.go ================================================ package api import ( "net/http" "github.com/go-chi/chi/v5" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin/db/nosql" "github.com/smallstep/certificates/authority/provisioner" ) // requireAPIEnabled is a middleware that ensures the Administration API // is enabled before servicing requests. 
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !mustAuthority(r.Context()).IsAdminAPIEnabled() { render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } next(w, r) } } // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { render.Error(w, r, admin.NewError(admin.ErrorUnauthorizedType, "missing authorization header token")) return } ctx := r.Context() adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) if err != nil { render.Error(w, r, err) return } ctx = linkedca.NewContextWithAdmin(ctx, adm) next(w, r.WithContext(ctx)) } } // loadProvisionerByName is a middleware that searches for a provisioner // by name and stores it in the context. func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) ctx := r.Context() auth := mustAuthority(ctx) adminDB := admin.MustFromContext(ctx) name := chi.URLParam(r, "provisionerName") // TODO(hs): distinguish 404 vs. 
500 if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } prov, err := adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) return } ctx = linkedca.NewContextWithProvisioner(ctx, prov) next(w, r.WithContext(ctx)) } } // checkAction checks if an action is supported in standalone or not func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // actions allowed in standalone mode are always supported if supportedInStandalone { next(w, r) return } // when an action is not supported in standalone mode and when // using a nosql.DB backend, actions are not supported if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok { render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")) return } // continue to next http handler next(w, r) } } // loadExternalAccountKey is a middleware that searches for an ACME // External Account Key by reference or keyID and stores it in the context. 
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) acmeDB := acme.MustDatabaseFromContext(ctx) reference := chi.URLParam(r, "reference") keyID := chi.URLParam(r, "keyID") var ( eak *acme.ExternalAccountKey err error ) if keyID != "" { eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID) } else { eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) } if err != nil { if acme.IsErrNotFound(err) { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) return } render.Error(w, r, admin.WrapErrorISE(err, "error retrieving ACME External Account Key")) return } if eak == nil { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")) return } linkedEAK := eakToLinked(eak) ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK) next(w, r.WithContext(ctx)) } } ================================================ FILE: authority/admin/api/middleware_test.go ================================================ package api import ( "bytes" "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/go-chi/chi/v5" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/types/known/timestamppb" "github.com/smallstep/linkedca" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin/db/nosql" "github.com/smallstep/certificates/authority/provisioner" ) func TestHandler_requireAPIEnabled(t *testing.T) { type test struct { ctx context.Context auth adminAuthority next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/auth.IsAdminAPIEnabled": func(t *testing.T) test { return 
test{ ctx: context.Background(), auth: &mockAdminAuthority{ MockIsAdminAPIEnabled: func() bool { return false }, }, err: &admin.Error{ Type: admin.ErrorNotImplementedType.String(), Status: 501, Detail: "not implemented", Message: "administration API not enabled", }, statusCode: 501, } }, "ok": func(t *testing.T) test { auth := &mockAdminAuthority{ MockIsAdminAPIEnabled: func() bool { return true }, } next := func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 } return test{ ctx: context.Background(), auth: auth, next: next, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", http.NoBody) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() requireAPIEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } // nothing to test when the requireAPIEnabled middleware succeeds, currently }) } } func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { type test struct { ctx context.Context auth adminAuthority req *http.Request next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/missing-authorization-token": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) req.Header["Authorization"] = []string{""} return test{ ctx: context.Background(), req: req, statusCode: 401, err: &admin.Error{ Type: 
admin.ErrorUnauthorizedType.String(), Status: 401, Detail: "unauthorized", Message: "missing authorization header token", }, } }, "fail/auth.AuthorizeAdminToken": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) req.Header["Authorization"] = []string{"token"} auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) return nil, admin.NewError( admin.ErrorUnauthorizedType, "not authorized", ) }, } return test{ ctx: context.Background(), auth: auth, req: req, statusCode: 401, err: &admin.Error{ Type: admin.ErrorUnauthorizedType.String(), Status: 401, Detail: "unauthorized", Message: "not authorized", }, } }, "ok": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) req.Header["Authorization"] = []string{"token"} createdAt := time.Now() var deletedAt time.Time adm := &linkedca.Admin{ Id: "adminID", AuthorityId: "authorityID", Subject: "admin", ProvisionerId: "provID", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } auth := &mockAdminAuthority{ MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { assert.Equals(t, "token", token) return adm, nil }, } next := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} if !cmp.Equal(adm, adm, opts...) 
{ t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...)) } w.Write(nil) // mock response with status 200 } return test{ ctx: context.Background(), auth: auth, req: req, next: next, statusCode: 200, err: nil, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() extractAuthorizeTokenAdmin(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } }) } } func TestHandler_loadProvisionerByName(t *testing.T) { type test struct { adminDB admin.DB auth adminAuthority ctx context.Context next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/auth.LoadProvisionerByName": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") }, } err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName") err.Message = "error loading provisioner provName: force" return test{ ctx: ctx, auth: auth, adminDB: &admin.MockDB{}, statusCode: 500, err: err, } }, "fail/db.GetProvisioner": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") ctx := 
context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ MgetID: func() string { return "provID" }, }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return nil, errors.New("force") }, } err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName") err.Message = "error retrieving provisioner provName: force" return test{ ctx: ctx, auth: auth, adminDB: db, statusCode: 500, err: err, } }, "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("provisionerName", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.MockProvisioner{ MgetID: func() string { return "provID" }, }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", }, nil }, } return test{ ctx: ctx, auth: auth, adminDB: db, statusCode: 200, next: func(w http.ResponseWriter, r *http.Request) { prov := linkedca.MustProvisionerFromContext(r.Context()) assert.NotNil(t, prov) assert.Equals(t, "provID", prov.GetId()) assert.Equals(t, "provName", prov.GetName()) w.Write(nil) // mock response with status 200 }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("GET", "/foo", http.NoBody) // chi routing is prepared in test setup req = req.WithContext(ctx) w := httptest.NewRecorder() 
loadProvisionerByName(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } }) } } func TestHandler_checkAction(t *testing.T) { type test struct { adminDB admin.DB next http.HandlerFunc supportedInStandalone bool err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "standalone-nosql-supported": func(t *testing.T) test { return test{ supportedInStandalone: true, adminDB: &nosql.DB{}, next: func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 }, statusCode: 200, } }, "standalone-nosql-not-supported": func(t *testing.T) test { err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode") err.Message = "operation not supported in standalone mode" return test{ supportedInStandalone: false, adminDB: &nosql.DB{}, statusCode: 501, err: err, } }, "standalone-no-nosql-not-supported": func(t *testing.T) test { err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported") err.Message = "operation not supported" return test{ supportedInStandalone: false, adminDB: &admin.MockDB{}, next: func(w http.ResponseWriter, r *http.Request) { w.Write(nil) // mock response with status 200 }, statusCode: 200, err: err, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { ctx := admin.NewContext(context.Background(), tc.adminDB) req := httptest.NewRequest("GET", "/foo", http.NoBody).WithContext(ctx) w := httptest.NewRecorder() checkAction(tc.next, 
tc.supportedInStandalone)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } }) } } func TestHandler_loadExternalAccountKey(t *testing.T) { type test struct { ctx context.Context acmeDB acme.DB next http.HandlerFunc err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/keyID-not-found-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("keyID", "key") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") err.Message = "ACME External Account Key not found" return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "key", keyID) return nil, acme.ErrNotFound }, }, err: err, statusCode: 404, } }, "fail/keyID-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("keyID", "key") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key") err.Message = "error retrieving ACME External Account Key: force" return test{ ctx: ctx, acmeDB: 
&acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "key", keyID) return nil, errors.New("force") }, }, err: err, statusCode: 500, } }, "fail/reference-not-found-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("reference", "ref") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") err.Message = "ACME External Account Key not found" return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "ref", reference) return nil, acme.ErrNotFound }, }, err: err, statusCode: 404, } }, "fail/reference-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("reference", "ref") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key") err.Message = "error retrieving ACME External Account Key: force" return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "ref", reference) return nil, errors.New("force") }, }, err: err, statusCode: 500, } }, "fail/no-key": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("reference", "ref") ctx := 
context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") err.Message = "ACME External Account Key not found" return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "ref", reference) return nil, nil }, }, err: err, statusCode: 404, } }, "ok/keyID": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("keyID", "eakID") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") err.Message = "ACME External Account Key not found" createdAt := time.Now().Add(-1 * time.Hour) var boundAt time.Time eak := &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: "provID", CreatedAt: createdAt, BoundAt: boundAt, } return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "eakID", keyID) return eak, nil }, }, next: func(w http.ResponseWriter, r *http.Request) { contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context()) assert.NotNil(t, eak) exp := &linkedca.EABKey{ Id: "eakID", Provisioner: "provID", CreatedAt: timestamppb.New(createdAt), BoundAt: timestamppb.New(boundAt), } assert.Equals(t, exp, contextEAK) w.Write(nil) // mock response with status 200 }, err: nil, statusCode: 200, } }, "ok/reference": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", } chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("reference", "ref") ctx := 
context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) ctx = linkedca.NewContextWithProvisioner(ctx, prov) err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found") err.Message = "ACME External Account Key not found" createdAt := time.Now().Add(-1 * time.Hour) var boundAt time.Time eak := &acme.ExternalAccountKey{ ID: "eakID", ProvisionerID: "provID", Reference: "ref", CreatedAt: createdAt, BoundAt: boundAt, } return test{ ctx: ctx, acmeDB: &acme.MockDB{ MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { assert.Equals(t, "provID", provisionerID) assert.Equals(t, "ref", reference) return eak, nil }, }, next: func(w http.ResponseWriter, r *http.Request) { contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context()) assert.NotNil(t, eak) exp := &linkedca.EABKey{ Id: "eakID", Provisioner: "provID", Reference: "ref", CreatedAt: timestamppb.New(createdAt), BoundAt: timestamppb.New(boundAt), } assert.Equals(t, exp, contextEAK) w.Write(nil) // mock response with status 200 }, err: nil, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB) req := httptest.NewRequest("GET", "/foo", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() loadExternalAccountKey(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) if res.StatusCode >= 400 { err := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) assert.Equals(t, tc.err.Type, err.Type) assert.Equals(t, tc.err.Message, err.Message) assert.Equals(t, tc.err.StatusCode(), res.StatusCode) assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } }) } } 
================================================ FILE: authority/admin/api/policy.go ================================================ package api import ( "context" "errors" "net/http" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/policy" ) // PolicyAdminResponder is the interface responsible for writing policy admin // responses. type PolicyAdminResponder interface { GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) } // policyAdminResponder implements PolicyAdminResponder. type policyAdminResponder struct{} // NewPolicyAdminResponder returns a new PolicyAdminResponder.
func NewPolicyAdminResponder() PolicyAdminResponder { return &policyAdminResponder{} } // GetAuthorityPolicy handles the GET /admin/authority/policy request func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { render.Error(w, r, err) return } auth := mustAuthority(ctx) authorityPolicy, err := auth.GetAuthorityPolicy(r.Context()) var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy")) return } if authorityPolicy == nil { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) return } render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK) } // CreateAuthorityPolicy handles the POST /admin/authority/policy request func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { render.Error(w, r, err) return } auth := mustAuthority(ctx) authorityPolicy, err := auth.GetAuthorityPolicy(ctx) var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy")) return } if authorityPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy") render.Error(w, r, adminErr) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) return } adm := linkedca.MustAdminFromContext(ctx) var createdPolicy *linkedca.Policy if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { render.Error(w, r, 
admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return } render.Error(w, r, admin.WrapErrorISE(err, "error storing authority policy")) return } render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated) } // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { render.Error(w, r, err) return } auth := mustAuthority(ctx) authorityPolicy, err := auth.GetAuthorityPolicy(ctx) var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy")) return } if authorityPolicy == nil { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy")) return } adm := linkedca.MustAdminFromContext(ctx) var updatedPolicy *linkedca.Policy if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) return } render.Error(w, r, admin.WrapErrorISE(err, "error updating authority policy")) return } render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK) } // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { render.Error(w, r, err) return } auth := mustAuthority(ctx) authorityPolicy, err := 
auth.GetAuthorityPolicy(ctx) var ae *admin.Error if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy")) return } if authorityPolicy == nil { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")) return } if err := auth.RemoveAuthorityPolicy(ctx); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error deleting authority policy")) return } render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) } // GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { render.Error(w, r, err) return } prov := linkedca.MustProvisionerFromContext(ctx) provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) return } render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK) } // CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if err := blockLinkedCA(ctx); err != nil { render.Error(w, r, err) return } prov := linkedca.MustProvisionerFromContext(ctx) provisionerPolicy := prov.GetPolicy() if provisionerPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) render.Error(w, r, adminErr) return } var newPolicy = new(linkedca.Policy) if err := read.ProtoJSON(r.Body, newPolicy); err != nil { render.Error(w, r, err) return } newPolicy.Deduplicate() if err := validatePolicy(newPolicy); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy")) return } prov.Policy = newPolicy auth := 
mustAuthority(ctx)
	if err := auth.UpdateProvisioner(ctx, prov); err != nil {
		if isBadRequest(err) {
			render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
			return
		}

		render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner policy"))
		return
	}

	render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}

// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
//
// It answers with a not-found error when the provisioner taken from the
// request context has no policy. Otherwise the request body is decoded into a
// linkedca.Policy, passed through Deduplicate and validatePolicy, assigned to
// the provisioner, persisted with auth.UpdateProvisioner and written back with
// status 200.
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if err := blockLinkedCA(ctx); err != nil {
		render.Error(w, r, err)
		return
	}

	prov := linkedca.MustProvisionerFromContext(ctx)
	provisionerPolicy := prov.GetPolicy()
	// an update requires an existing policy
	if provisionerPolicy == nil {
		render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
		return
	}

	var newPolicy = new(linkedca.Policy)
	if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
		render.Error(w, r, err)
		return
	}

	newPolicy.Deduplicate()

	if err := validatePolicy(newPolicy); err != nil {
		render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
		return
	}

	prov.Policy = newPolicy
	auth := mustAuthority(ctx)
	if err := auth.UpdateProvisioner(ctx, prov); err != nil {
		// policy errors classified by isBadRequest are rendered as bad
		// requests; everything else goes through WrapErrorISE
		if isBadRequest(err) {
			render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
			return
		}

		render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner policy"))
		return
	}

	render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}

// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
//
// It answers with a not-found error when the provisioner taken from the
// request context has no policy. Otherwise the policy is set to nil, the
// provisioner is persisted with auth.UpdateProvisioner and a DeleteResponse
// with status "ok" is written. Unlike the create and update handlers, a
// failure of auth.UpdateProvisioner is always rendered through WrapErrorISE.
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if err := blockLinkedCA(ctx); err != nil {
		render.Error(w, r, err)
		return
	}

	prov := linkedca.MustProvisionerFromContext(ctx)
	if prov.Policy == nil {
		render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
		return
	}

	// remove the policy
	prov.Policy = nil

	auth := mustAuthority(ctx)
	if err := auth.UpdateProvisioner(ctx, prov); err != nil {
		render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner policy"))
		return
	}

	render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
}

// GetACMEAccountPolicy writes the policy of the ACME external account key
// (EAK) found in the request context with status 200. It answers with a
// not-found error when that key has no policy.
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if err := blockLinkedCA(ctx); err != nil {
		render.Error(w, r, err)
		return
	}

	eak := linkedca.MustExternalAccountKeyFromContext(ctx)
	eakPolicy := eak.GetPolicy()
	if eakPolicy == nil {
		render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
		return
	}

	render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
}

// CreateACMEAccountPolicy attaches a new policy to the ACME external account
// key (EAK) found in the request context.
//
// It answers with a conflict error when the key already has a policy.
// Otherwise the request body is decoded into a linkedca.Policy, passed through
// Deduplicate and validatePolicy, assigned to the key, and the key is stored
// with acmeDB.UpdateExternalAccountKey under the ID of the provisioner found
// in the request context. On success the new policy is written with status 201.
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if err := blockLinkedCA(ctx); err != nil {
		render.Error(w, r, err)
		return
	}

	prov := linkedca.MustProvisionerFromContext(ctx)
	eak := linkedca.MustExternalAccountKeyFromContext(ctx)
	eakPolicy := eak.GetPolicy()
	// creation is refused when a policy is already present
	if eakPolicy != nil {
		adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
		render.Error(w, r, adminErr)
		return
	}

	var newPolicy = new(linkedca.Policy)
	if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
		render.Error(w, r, err)
		return
	}

	newPolicy.Deduplicate()

	if err := validatePolicy(newPolicy); err != nil {
		render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
		return
	}

	eak.Policy = newPolicy

	// the ACME database is given the value produced by linkedEAKToCertificates,
	// not the linkedca key itself
	acmeEAK := linkedEAKToCertificates(eak)
	acmeDB := acme.MustDatabaseFromContext(ctx)
	if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
		render.Error(w, r, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
		return
	}

	render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}

// UpdateACMEAccountPolicy replaces the policy of the ACME external account
// key (EAK) found in the request context.
//
// It answers with a not-found error when the key has no policy. Otherwise the
// request body is decoded into a linkedca.Policy, passed through Deduplicate
// and validatePolicy, assigned to the key, and the key is stored with
// acmeDB.UpdateExternalAccountKey under the ID of the provisioner found in the
// request context. On success the new policy is written with status 200.
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if err := blockLinkedCA(ctx); err != nil {
		render.Error(w, r, err)
		return
	}

	prov := linkedca.MustProvisionerFromContext(ctx)
	eak := linkedca.MustExternalAccountKeyFromContext(ctx)
	eakPolicy := eak.GetPolicy()
	// an update requires an existing policy
	if eakPolicy == nil {
		render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
		return
	}

	var newPolicy = new(linkedca.Policy)
	if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
		render.Error(w, r, err)
		return
	}

	newPolicy.Deduplicate()

	if err := validatePolicy(newPolicy); err != nil {
		render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
		return
	}

	eak.Policy = newPolicy
	acmeEAK := linkedEAKToCertificates(eak)
	acmeDB := acme.MustDatabaseFromContext(ctx)
	if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
		render.Error(w, r, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
		return
	}

	render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}

// DeleteACMEAccountPolicy removes the policy of the ACME external account key
// (EAK) found in the request context.
//
// It answers with a not-found error when the key has no policy. Otherwise the
// policy is set to nil, the key is stored with acmeDB.UpdateExternalAccountKey
// under the ID of the provisioner found in the request context, and a
// DeleteResponse with status "ok" is written.
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	if err := blockLinkedCA(ctx); err != nil {
		render.Error(w, r, err)
		return
	}

	prov := linkedca.MustProvisionerFromContext(ctx)
	eak := linkedca.MustExternalAccountKeyFromContext(ctx)
	eakPolicy := eak.GetPolicy()
	if eakPolicy == nil {
		render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
		return
	}

	// remove the policy
	eak.Policy = nil

	acmeEAK := linkedEAKToCertificates(eak)
	acmeDB := acme.MustDatabaseFromContext(ctx)
	if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
		render.Error(w, r, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
		return
	}

	render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
}

// blockLinkedCA blocks all API operations on linked deployments
//
// The admin DB stored in ctx counts as linked when it implements
// IsLinkedCA() bool and that method returns true; in that case a
// not-implemented admin error is returned, otherwise nil.
func blockLinkedCA(ctx context.Context) error {
	// temporary blocking linked deployments
	adminDB := admin.MustFromContext(ctx)
	// optional-interface check: admin DBs without IsLinkedCA are never blocked
	if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
		return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
	}
	return nil
}

// isBadRequest checks if an error should result in a bad request error
// returned to the client.
//
// That is the case when err is, or wraps, an *authority.PolicyError whose Typ
// is AdminLockOut, EvaluationFailure or ConfigurationFailure.
func isBadRequest(err error) bool {
	var pe *authority.PolicyError
	isPolicyError := errors.As(err, &pe)
	// pe is only dereferenced when errors.As succeeded (short-circuit &&)
	return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure)
}

// validatePolicy reports whether p can be turned into working policy engines.
//
// p is converted with policy.LinkedToCertificates; a nil conversion result is
// accepted as valid. Otherwise an X.509 engine, an SSH host engine and an SSH
// user engine are constructed from the converted options, in that order, and
// the first construction error is returned. The engines are discarded.
func validatePolicy(p *linkedca.Policy) error {
	// convert the policy; return early if nil
	options := policy.LinkedToCertificates(p)
	if options == nil {
		return nil
	}

	var err error

	// Initialize a temporary x509 allow/deny policy engine
	if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
		return err
	}

	// Initialize a temporary SSH allow/deny policy engine for host certificates
	if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
		return err
	}

	// Initialize a temporary SSH allow/deny policy engine for user certificates
	if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
		return err
	}

	return nil
}

================================================
FILE: authority/admin/api/policy_test.go
================================================
package api

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"google.golang.org/protobuf/encoding/protojson"

	"github.com/smallstep/linkedca"

	"github.com/smallstep/certificates/acme"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/admin"
)

// fakeLinkedCA is an admin.MockDB that additionally reports itself as a
// linked CA. Tests pass it as the admin DB to exercise the "fail/linkedca"
// cases.
type fakeLinkedCA struct {
	admin.MockDB
}

// IsLinkedCA always returns true.
func (f *fakeLinkedCA) IsLinkedCA() bool {
	return true
}

// testAdminError is an error type that models the expected
// error body returned.
type testAdminError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
	Detail  string `json:"detail"`
}

// testX509Policy models the "x509" object of a policy response.
type testX509Policy struct {
	Allow              *testX509Names `json:"allow,omitempty"`
	Deny               *testX509Names `json:"deny,omitempty"`
	AllowWildcardNames bool           `json:"allow_wildcard_names,omitempty"`
}

// testX509Names models the allow/deny name lists of an X.509 policy.
type testX509Names struct {
	CommonNames    []string `json:"commonNames,omitempty"`
	DNSDomains     []string `json:"dns,omitempty"`
	IPRanges       []string `json:"ips,omitempty"`
	EmailAddresses []string `json:"emails,omitempty"`
	URIDomains     []string `json:"uris,omitempty"`
}

// testSSHPolicy models the "ssh" object of a policy response.
type testSSHPolicy struct {
	User *testSSHUserPolicy `json:"user,omitempty"`
	Host *testSSHHostPolicy `json:"host,omitempty"`
}

// testSSHHostPolicy models the SSH host part of a policy response.
type testSSHHostPolicy struct {
	Allow *testSSHHostNames `json:"allow,omitempty"`
	Deny  *testSSHHostNames `json:"deny,omitempty"`
}

// testSSHHostNames models the allow/deny name lists of an SSH host policy.
type testSSHHostNames struct {
	DNSDomains []string `json:"dns,omitempty"`
	IPRanges   []string `json:"ips,omitempty"`
	Principals []string `json:"principals,omitempty"`
}

// testSSHUserPolicy models the SSH user part of a policy response.
type testSSHUserPolicy struct {
	Allow *testSSHUserNames `json:"allow,omitempty"`
	Deny  *testSSHUserNames `json:"deny,omitempty"`
}

// testSSHUserNames models the allow/deny name lists of an SSH user policy.
type testSSHUserNames struct {
	EmailAddresses []string `json:"emails,omitempty"`
	Principals     []string `json:"principals,omitempty"`
}

// testPolicyResponse models the Policy API JSON response
type testPolicyResponse struct {
	X509 *testX509Policy `json:"x509,omitempty"`
	SSH  *testSSHPolicy  `json:"ssh,omitempty"`
}

// TestPolicyAdminResponder_GetAuthorityPolicy drives GetAuthorityPolicy with
// table-driven cases: a linked admin DB (501), a failing authority (500), a
// not-found authority policy (404) and a full policy that must come back as
// JSON (200).
func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
	type test struct {
		auth       adminAuthority
		adminDB    admin.DB
		ctx        context.Context
		err        *admin.Error
		response   *testPolicyResponse
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/linkedca": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
			err.Message = "policy operations not yet supported in linked deployments"
			return test{
				ctx:        ctx,
				adminDB:    &fakeLinkedCA{},
				err:        err,
				statusCode: 501,
			}
		},
		"fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
			err.Message = "error retrieving authority policy: force"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorServerInternalType, "force")
					},
				},
				err:        err,
				statusCode: 500,
			}
		},
		"fail/auth.GetAuthorityPolicy-not-found": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")
			err.Message = "authority policy does not exist"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
					},
				},
				err:        err,
				statusCode: 404,
			}
		},
		"ok": func(t *testing.T) test {
			ctx := context.Background()
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns:         []string{"*.local"},
						Ips:         []string{"10.0.0.0/16"},
						Emails:      []string{"@example.com"},
						Uris:        []string{"example.com"},
						CommonNames: []string{"test"},
					},
					Deny: &linkedca.X509Names{
						Dns:         []string{"bad.local"},
						Ips:         []string{"10.0.0.30"},
						Emails:      []string{"bad@example.com"},
						Uris:        []string{"notexample.com"},
						CommonNames: []string{"bad"},
					},
				},
				Ssh: &linkedca.SSHPolicy{
					User: &linkedca.SSHUserPolicy{
						Allow: &linkedca.SSHUserNames{
							Emails:     []string{"@example.com"},
							Principals: []string{"*"},
						},
						Deny: &linkedca.SSHUserNames{
							Emails:     []string{"bad@example.com"},
							Principals: []string{"root"},
						},
					},
					Host: &linkedca.SSHHostPolicy{
						Allow: &linkedca.SSHHostNames{
							Dns:        []string{"*.example.com"},
							Ips:        []string{"10.10.0.0/16"},
							Principals: []string{"good"},
						},
						Deny: &linkedca.SSHHostNames{
							Dns:        []string{"bad@example.com"},
							Ips:        []string{"10.10.0.30"},
							Principals: []string{"bad"},
						},
					},
				},
			}
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
				},
				response: &testPolicyResponse{
					X509: &testX509Policy{
						Allow: &testX509Names{
							DNSDomains:     []string{"*.local"},
							IPRanges:       []string{"10.0.0.0/16"},
							EmailAddresses: []string{"@example.com"},
							URIDomains:     []string{"example.com"},
							CommonNames:    []string{"test"},
						},
						Deny: &testX509Names{
							DNSDomains:     []string{"bad.local"},
							IPRanges:       []string{"10.0.0.30"},
							EmailAddresses: []string{"bad@example.com"},
							URIDomains:     []string{"notexample.com"},
							CommonNames:    []string{"bad"},
						},
					},
					SSH: &testSSHPolicy{
						User: &testSSHUserPolicy{
							Allow: &testSSHUserNames{
								EmailAddresses: []string{"@example.com"},
								Principals:     []string{"*"},
							},
							Deny: &testSSHUserNames{
								EmailAddresses: []string{"bad@example.com"},
								Principals:     []string{"root"},
							},
						},
						Host: &testSSHHostPolicy{
							Allow: &testSSHHostNames{
								DNSDomains: []string{"*.example.com"},
								IPRanges:   []string{"10.10.0.0/16"},
								Principals: []string{"good"},
							},
							Deny: &testSSHHostNames{
								DNSDomains: []string{"bad@example.com"},
								IPRanges:   []string{"10.10.0.30"},
								Principals: []string{"bad"},
							},
						},
					},
				},
				statusCode: 200,
			}
		},
	}
	for name, prep := range tests {
		// the case is built before the subtest starts
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			mockMustAuthority(t, tc.auth)
			ctx := admin.NewContext(tc.ctx, tc.adminDB)
			par := NewPolicyAdminResponder()

			req := httptest.NewRequest("GET", "/foo", http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()

			par.GetAuthorityPolicy(w, req)
			res := w.Result()

			assert.Equal(t, tc.statusCode, res.StatusCode)

			// error responses are decoded into testAdminError and compared
			// field by field with the expected admin error
			if res.StatusCode >= 400 {
				body, err := io.ReadAll(res.Body)
				res.Body.Close()
				assert.NoError(t, err)

				ae := testAdminError{}
				assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equal(t, tc.err.Type, ae.Type)
				assert.Equal(t, tc.err.Message, ae.Message)
				assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
				assert.Equal(t, tc.err.Detail, ae.Detail)
				assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
				return
			}

			p := &testPolicyResponse{}
			body, err := io.ReadAll(res.Body)
			assert.NoError(t, err)
			assert.NoError(t, json.Unmarshal(body, &p))
			assert.Equal(t, tc.response, p)
		})
	}
}

// TestPolicyAdminResponder_CreateAuthorityPolicy drives CreateAuthorityPolicy
// with table-driven cases: a linked admin DB (501), a failing policy lookup
// (500), an already existing policy (409), an unparsable body (400), a policy
// that fails validation (400), an admin lock-out and a store failure reported
// by the authority (400 and 500), and a successful creation (201).
func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
	type test struct {
		auth       adminAuthority
		adminDB    admin.DB
		body       []byte
		ctx        context.Context
		acmeDB     acme.DB
		err        *admin.Error
		response   *testPolicyResponse
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/linkedca": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
			err.Message = "policy operations not yet supported in linked deployments"
			return test{
				ctx:        ctx,
				adminDB:    &fakeLinkedCA{},
				err:        err,
				statusCode: 501,
			}
		},
		"fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
			err.Message = "error retrieving authority policy: force"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorServerInternalType, "force")
					},
				},
				err:        err,
				statusCode: 500,
			}
		},
		"fail/existing-policy": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
			err.Message = "authority already has a policy"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return &linkedca.Policy{}, nil
					},
				},
				err:        err,
				statusCode: 409,
			}
		},
		"fail/read.ProtoJSON": func(t *testing.T) test {
			ctx := context.Background()
			adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
			adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
			body := []byte("{?}")
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 400,
			}
		},
		"fail/validatePolicy": func(t *testing.T) test {
			ctx := context.Background()
			adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)")
			adminErr.Message = "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)"
			body := []byte(` { "x509": { "allow": { "uris": [ "https://example.com" ] } } }`)
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 400,
			}
		},
		"fail/CreateAuthorityPolicy-policy-admin-lockout-error": func(t *testing.T) test {
			adm := &linkedca.Admin{
				Subject: "step",
			}
			ctx := context.Background()
			ctx = linkedca.NewContextWithAdmin(ctx, adm)
			adminErr := admin.NewError(admin.ErrorBadRequestType, "error storing authority policy")
			adminErr.Message = "error storing authority policy: admin lock out"
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			body, err := protojson.Marshal(policy)
			assert.NoError(t, err)
			return test{
				ctx: ctx,
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
					},
					MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
						return nil, &authority.PolicyError{
							Typ: authority.AdminLockOut,
							Err: errors.New("admin lock out"),
						}
					},
				},
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							adm,
							{
								Subject: "anotherAdmin",
							},
						}, nil
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 400,
			}
		},
		"fail/CreateAuthorityPolicy-error": func(t *testing.T) test {
			adm := &linkedca.Admin{
				Subject: "step",
			}
			ctx := context.Background()
			ctx = linkedca.NewContextWithAdmin(ctx, adm)
			adminErr := admin.NewError(admin.ErrorServerInternalType, "error storing authority policy: force")
			adminErr.Message = "error storing authority policy: force"
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			body, err := protojson.Marshal(policy)
			assert.NoError(t, err)
			return test{
				ctx: ctx,
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
					},
					MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
						return nil, &authority.PolicyError{
							Typ: authority.StoreFailure,
							Err: errors.New("force"),
						}
					},
				},
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							adm,
							{
								Subject: "anotherAdmin",
							},
						}, nil
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 500,
			}
		},
		"ok": func(t *testing.T) test {
			adm := &linkedca.Admin{
				Subject: "step",
			}
			ctx := context.Background()
			ctx = linkedca.NewContextWithAdmin(ctx, adm)
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			body, err := protojson.Marshal(policy)
			assert.NoError(t, err)
			return test{
				ctx: ctx,
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
					},
					MockCreateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
						return policy, nil
					},
				},
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							adm,
							{
								Subject: "anotherAdmin",
							},
						}, nil
					},
				},
				body: body,
				response: &testPolicyResponse{
					X509: &testX509Policy{
						Allow: &testX509Names{
							DNSDomains: []string{"*.local"},
						},
					},
				},
				statusCode: 201,
			}
		},
	}
	for name, prep := range tests {
		// the case is built before the subtest starts
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			mockMustAuthority(t, tc.auth)
			ctx := admin.NewContext(tc.ctx, tc.adminDB)
			// tc.acmeDB is left unset by every case in this table
			ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
			par := NewPolicyAdminResponder()

			req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()

			par.CreateAuthorityPolicy(w, req)
			res := w.Result()

			assert.Equal(t, tc.statusCode, res.StatusCode)

			if res.StatusCode >= 400 {
				body, err := io.ReadAll(res.Body)
				res.Body.Close()
				assert.NoError(t, err)

				ae := testAdminError{}
				assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equal(t, tc.err.Type, ae.Type)
				assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
				assert.Equal(t, tc.err.Detail, ae.Detail)
				assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])

				// when the error message starts with "proto", we expect it to have
				// a syntax error (in the tests). If the message doesn't start with "proto",
				// we expect a full string match.
				if strings.HasPrefix(tc.err.Message, "proto:") {
					assert.True(t, strings.Contains(ae.Message, "syntax error"))
				} else {
					assert.Equal(t, tc.err.Message, ae.Message)
				}

				return
			}

			p := &testPolicyResponse{}
			body, err := io.ReadAll(res.Body)
			assert.NoError(t, err)
			assert.NoError(t, json.Unmarshal(body, &p))
			assert.Equal(t, tc.response, p)
		})
	}
}

// TestPolicyAdminResponder_UpdateAuthorityPolicy drives UpdateAuthorityPolicy
// with table-driven cases: a linked admin DB (501), a failing policy lookup
// (500), no existing policy (404), an unparsable body (400), a policy that
// fails validation (400), an admin lock-out and a store failure reported by
// the authority (400 and 500), and a successful update (200).
func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
	type test struct {
		auth       adminAuthority
		adminDB    admin.DB
		body       []byte
		ctx        context.Context
		acmeDB     acme.DB
		err        *admin.Error
		response   *testPolicyResponse
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/linkedca": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
			err.Message = "policy operations not yet supported in linked deployments"
			return test{
				ctx:        ctx,
				adminDB:    &fakeLinkedCA{},
				err:        err,
				statusCode: 501,
			}
		},
		"fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
			err.Message = "error retrieving authority policy: force"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorServerInternalType, "force")
					},
				},
				err:        err,
				statusCode: 500,
			}
		},
		"fail/no-existing-policy": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")
			err.Message = "authority policy does not exist"
			err.Status = http.StatusNotFound
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, nil
					},
				},
				err:        err,
				statusCode: 404,
			}
		},
		"fail/read.ProtoJSON": func(t *testing.T) test {
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			ctx := context.Background()
			adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
			adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
			body := []byte("{?}")
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 400,
			}
		},
		"fail/validatePolicy": func(t *testing.T) test {
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			ctx := context.Background()
			adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)")
			adminErr.Message = "error validating authority policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)"
			body := []byte(` { "x509": { "allow": { "uris": [ "https://example.com" ] } } }`)
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 400,
			}
		},
		"fail/UpdateAuthorityPolicy-policy-admin-lockout-error": func(t *testing.T) test {
			adm := &linkedca.Admin{
				Subject: "step",
			}
			ctx := context.Background()
			ctx = linkedca.NewContextWithAdmin(ctx, adm)
			adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating authority policy: force")
			// the expected message is overwritten on the next line
			adminErr.Message = "error updating authority policy: admin lock out"
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			body, err := protojson.Marshal(policy)
			assert.NoError(t, err)
			return test{
				ctx: ctx,
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
					MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
						return nil, &authority.PolicyError{
							Typ: authority.AdminLockOut,
							Err: errors.New("admin lock out"),
						}
					},
				},
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							adm,
							{
								Subject: "anotherAdmin",
							},
						}, nil
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 400,
			}
		},
		"fail/UpdateAuthorityPolicy-error": func(t *testing.T) test {
			adm := &linkedca.Admin{
				Subject: "step",
			}
			ctx := context.Background()
			ctx = linkedca.NewContextWithAdmin(ctx, adm)
			adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating authority policy: force")
			adminErr.Message = "error updating authority policy: force"
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			body, err := protojson.Marshal(policy)
			assert.NoError(t, err)
			return test{
				ctx: ctx,
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
					MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
						return nil, &authority.PolicyError{
							Typ: authority.StoreFailure,
							Err: errors.New("force"),
						}
					},
				},
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							adm,
							{
								Subject: "anotherAdmin",
							},
						}, nil
					},
				},
				body:       body,
				err:        adminErr,
				statusCode: 500,
			}
		},
		"ok": func(t *testing.T) test {
			adm := &linkedca.Admin{
				Subject: "step",
			}
			ctx := context.Background()
			ctx = linkedca.NewContextWithAdmin(ctx, adm)
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			body, err := protojson.Marshal(policy)
			assert.NoError(t, err)
			return test{
				ctx: ctx,
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
					MockUpdateAuthorityPolicy: func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
						return policy, nil
					},
				},
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							adm,
							{
								Subject: "anotherAdmin",
							},
						}, nil
					},
				},
				body: body,
				response: &testPolicyResponse{
					X509: &testX509Policy{
						Allow: &testX509Names{
							DNSDomains: []string{"*.local"},
						},
					},
				},
				statusCode: 200,
			}
		},
	}
	for name, prep := range tests {
		// the case is built before the subtest starts
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			mockMustAuthority(t, tc.auth)
			ctx := admin.NewContext(tc.ctx, tc.adminDB)
			// tc.acmeDB is left unset by every case in this table
			ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
			par := NewPolicyAdminResponder()

			// the handler is invoked directly, so the "POST" method is not
			// matched against any route here
			req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()

			par.UpdateAuthorityPolicy(w, req)
			res := w.Result()

			assert.Equal(t, tc.statusCode, res.StatusCode)

			if res.StatusCode >= 400 {
				body, err := io.ReadAll(res.Body)
				res.Body.Close()
				assert.NoError(t, err)

				ae := testAdminError{}
				assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equal(t, tc.err.Type, ae.Type)
				assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
				assert.Equal(t, tc.err.Detail, ae.Detail)
				assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])

				// when the error message starts with "proto", we expect it to have
				// a syntax error (in the tests). If the message doesn't start with "proto",
				// we expect a full string match.
				if strings.HasPrefix(tc.err.Message, "proto:") {
					assert.True(t, strings.Contains(ae.Message, "syntax error"))
				} else {
					assert.Equal(t, tc.err.Message, ae.Message)
				}

				return
			}

			p := &testPolicyResponse{}
			body, err := io.ReadAll(res.Body)
			assert.NoError(t, err)
			assert.NoError(t, json.Unmarshal(body, &p))
			assert.Equal(t, tc.response, p)
		})
	}
}

// TestPolicyAdminResponder_DeleteAuthorityPolicy drives DeleteAuthorityPolicy
// with table-driven cases: a linked admin DB (501), a failing policy lookup
// (500), no existing policy (404), a failing removal (500) and a successful
// removal that must answer with a DeleteResponse of status "ok" (200).
func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
	type test struct {
		auth       adminAuthority
		adminDB    admin.DB
		body       []byte
		ctx        context.Context
		acmeDB     acme.DB
		err        *admin.Error
		statusCode int
	}

	var tests = map[string]func(t *testing.T) test{
		"fail/linkedca": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
			err.Message = "policy operations not yet supported in linked deployments"
			return test{
				ctx:        ctx,
				adminDB:    &fakeLinkedCA{},
				err:        err,
				statusCode: 501,
			}
		},
		"fail/auth.GetAuthorityPolicy-error": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
			err.Message = "error retrieving authority policy: force"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, admin.NewError(admin.ErrorServerInternalType, "force")
					},
				},
				err:        err,
				statusCode: 500,
			}
		},
		"fail/no-existing-policy": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")
			err.Message = "authority policy does not exist"
			err.Status = http.StatusNotFound
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return nil, nil
					},
				},
				err:        err,
				statusCode: 404,
			}
		},
		"fail/auth.RemoveAuthorityPolicy-error": func(t *testing.T) test {
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			ctx := context.Background()
			err := admin.NewErrorISE("error deleting authority policy: force")
			err.Message = "error deleting authority policy: force"
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
					MockRemoveAuthorityPolicy: func(ctx context.Context) error {
						return errors.New("force")
					},
				},
				err:        err,
				statusCode: 500,
			}
		},
		"ok": func(t *testing.T) test {
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			ctx := context.Background()
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				auth: &mockAdminAuthority{
					MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
						return policy, nil
					},
					MockRemoveAuthorityPolicy: func(ctx context.Context) error {
						return nil
					},
				},
				statusCode: 200,
			}
		},
	}
	for name, prep := range tests {
		// the case is built before the subtest starts
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			mockMustAuthority(t, tc.auth)
			ctx := admin.NewContext(tc.ctx, tc.adminDB)
			// tc.acmeDB is left unset by every case in this table
			ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
			par := NewPolicyAdminResponder()

			// the handler is invoked directly, so the "POST" method is not
			// matched against any route here
			req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()

			par.DeleteAuthorityPolicy(w, req)
			res := w.Result()

			assert.Equal(t, tc.statusCode, res.StatusCode)

			if res.StatusCode >= 400 {
				body, err := io.ReadAll(res.Body)
				res.Body.Close()
				assert.NoError(t, err)

				ae := testAdminError{}
				assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equal(t, tc.err.Type, ae.Type)
				assert.Equal(t, tc.err.Message, ae.Message)
				assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
				assert.Equal(t, tc.err.Detail, ae.Detail)
				assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
				return
			}

			body, err := io.ReadAll(res.Body)
			assert.NoError(t, err)
			res.Body.Close()

			// a successful delete answers with a DeleteResponse, not a policy
			response := DeleteResponse{}
			assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response))
			assert.Equal(t, "ok", response.Status)
			assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
		})
	}
}

// TestPolicyAdminResponder_GetProvisionerPolicy drives GetProvisionerPolicy
// with table-driven cases: a linked admin DB (501), a provisioner without a
// policy (404) and a provisioner whose full policy must come back as JSON
// (200). The provisioner is supplied through the request context.
func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
	type test struct {
		auth       adminAuthority
		adminDB    admin.DB
		ctx        context.Context
		acmeDB     acme.DB
		err        *admin.Error
		response   *testPolicyResponse
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/linkedca": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
			err.Message = "policy operations not yet supported in linked deployments"
			return test{
				ctx:        ctx,
				adminDB:    &fakeLinkedCA{},
				err:        err,
				statusCode: 501,
			}
		},
		"fail/prov-no-policy": func(t *testing.T) test {
			prov := &linkedca.Provisioner{}
			ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
			err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")
			err.Message = "provisioner policy does not exist"
			return test{
				ctx:        ctx,
				adminDB:    &admin.MockDB{},
				err:        err,
				statusCode: 404,
			}
		},
		"ok": func(t *testing.T) test {
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns:         []string{"*.local"},
						Ips:         []string{"10.0.0.0/16"},
						Emails:      []string{"@example.com"},
						Uris:        []string{"example.com"},
						CommonNames: []string{"test"},
					},
					Deny: &linkedca.X509Names{
						Dns:         []string{"bad.local"},
						Ips:         []string{"10.0.0.30"},
						Emails:      []string{"bad@example.com"},
						Uris:        []string{"notexample.com"},
						CommonNames: []string{"bad"},
					},
				},
				Ssh: &linkedca.SSHPolicy{
					User: &linkedca.SSHUserPolicy{
						Allow: &linkedca.SSHUserNames{
							Emails:     []string{"@example.com"},
							Principals: []string{"*"},
						},
						Deny: &linkedca.SSHUserNames{
							Emails:     []string{"bad@example.com"},
							Principals: []string{"root"},
						},
					},
					Host: &linkedca.SSHHostPolicy{
						Allow: &linkedca.SSHHostNames{
							Dns:        []string{"*.example.com"},
							Ips:        []string{"10.10.0.0/16"},
							Principals: []string{"good"},
						},
						Deny: &linkedca.SSHHostNames{
							Dns:        []string{"bad@example.com"},
							Ips:        []string{"10.10.0.30"},
							Principals: []string{"bad"},
						},
					},
				},
			}
			prov := &linkedca.Provisioner{
				Policy: policy,
			}
			ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
			return test{
				ctx:     ctx,
				adminDB: &admin.MockDB{},
				response: &testPolicyResponse{
					X509: &testX509Policy{
						Allow: &testX509Names{
							DNSDomains:     []string{"*.local"},
							IPRanges:       []string{"10.0.0.0/16"},
							EmailAddresses: []string{"@example.com"},
							URIDomains:     []string{"example.com"},
							CommonNames:    []string{"test"},
						},
						Deny: &testX509Names{
							DNSDomains:     []string{"bad.local"},
							IPRanges:       []string{"10.0.0.30"},
							EmailAddresses: []string{"bad@example.com"},
							URIDomains:     []string{"notexample.com"},
							CommonNames:    []string{"bad"},
						},
					},
					SSH: &testSSHPolicy{
						User: &testSSHUserPolicy{
							Allow: &testSSHUserNames{
								EmailAddresses: []string{"@example.com"},
								Principals:     []string{"*"},
							},
							Deny: &testSSHUserNames{
								EmailAddresses: []string{"bad@example.com"},
								Principals:     []string{"root"},
							},
						},
						Host: &testSSHHostPolicy{
							Allow: &testSSHHostNames{
								DNSDomains: []string{"*.example.com"},
								IPRanges:   []string{"10.10.0.0/16"},
								Principals: []string{"good"},
							},
							Deny: &testSSHHostNames{
								DNSDomains: []string{"bad@example.com"},
								IPRanges:   []string{"10.10.0.30"},
								Principals: []string{"bad"},
							},
						},
					},
				},
				statusCode: 200,
			}
		},
	}
	for name, prep := range tests {
		// the case is built before the subtest starts
		tc := prep(t)
		t.Run(name, func(t *testing.T) {
			// no case in this table sets tc.auth
			mockMustAuthority(t, tc.auth)
			ctx := admin.NewContext(tc.ctx, tc.adminDB)
			ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
			par := NewPolicyAdminResponder()

			req := httptest.NewRequest("GET", "/foo", http.NoBody)
			req = req.WithContext(ctx)
			w := httptest.NewRecorder()

			par.GetProvisionerPolicy(w, req)
			res := w.Result()

			assert.Equal(t, tc.statusCode, res.StatusCode)

			if res.StatusCode >= 400 {
				body, err := io.ReadAll(res.Body)
				res.Body.Close()
				assert.NoError(t, err)

				ae := testAdminError{}
				assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))

				assert.Equal(t, tc.err.Type, ae.Type)
				assert.Equal(t, tc.err.Message, ae.Message)
				assert.Equal(t, tc.err.StatusCode(), res.StatusCode)
				assert.Equal(t, tc.err.Detail, ae.Detail)
				assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"])
				return
			}

			p := &testPolicyResponse{}
			body, err := io.ReadAll(res.Body)
			assert.NoError(t, err)
			assert.NoError(t, json.Unmarshal(body, &p))
			assert.Equal(t, tc.response, p)
		})
	}
}

func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
	type test struct {
		auth       adminAuthority
		adminDB    admin.DB
		body       []byte
		ctx        context.Context
		err        *admin.Error
		response   *testPolicyResponse
		statusCode int
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/linkedca": func(t *testing.T) test {
			ctx := context.Background()
			err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
			err.Message = "policy operations not yet supported in linked deployments"
			return test{
				ctx:        ctx,
				adminDB:    &fakeLinkedCA{},
				err:        err,
				statusCode: 501,
			}
		},
		"fail/existing-policy": func(t *testing.T) test {
			policy := &linkedca.Policy{
				X509: &linkedca.X509Policy{
					Allow: &linkedca.X509Names{
						Dns: []string{"*.local"},
					},
				},
			}
			prov := &linkedca.Provisioner{
				Name:   "provName",
				Policy: policy,
			}
			ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
			err := admin.NewError(admin.ErrorConflictType, "provisioner provName already has a policy")
			err.Message = "provisioner provName already has a policy"
			return test{
				ctx:        ctx,
				adminDB:    &admin.MockDB{},
				err:        err,
				statusCode: 409,
			}
		},
		"fail/read.ProtoJSON": func(t *testing.T) test {
			prov := &linkedca.Provisioner{
				Name: "provName",
			}
			ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
			adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?")
			adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
body := []byte("{?}") return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/validatePolicy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") adminErr.Message = "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" body := []byte(` { "x509": { "allow": { "uris": [ "https://example.com" ] } } }`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") }, }, body: body, err: adminErr, statusCode: 400, } }, "fail/auth.UpdateProvisioner-policy-admin-lockout-error": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "error creating provisioner policy") adminErr.Message = "error creating provisioner policy: admin lock out" policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.AdminLockOut, Err: errors.New("admin lock out"), } }, }, 
body: body, err: adminErr, statusCode: 400, } }, "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating provisioner policy: force") adminErr.Message = "error creating provisioner policy: force" policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.StoreFailure, Err: errors.New("force"), } }, }, body: body, err: adminErr, statusCode: 500, } }, "ok": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil }, }, body: body, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ DNSDomains: []string{"*.local"}, }, }, }, statusCode: 201, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = 
req.WithContext(ctx) w := httptest.NewRecorder() par.CreateProvisionerPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", // we expect a full string match. if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } p := &testPolicyResponse{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, json.Unmarshal(body, &p)) assert.Equal(t, tc.response, p) }) } } func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { type test struct { auth adminAuthority body []byte adminDB admin.DB ctx context.Context err *admin.Error response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/linkedca": func(t *testing.T) test { ctx := context.Background() err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") err.Message = "policy operations not yet supported in linked deployments" return test{ ctx: ctx, adminDB: &fakeLinkedCA{}, err: err, statusCode: 501, } }, "fail/no-existing-policy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") err.Message = "provisioner policy does not exist" return test{ ctx: 
ctx, adminDB: &admin.MockDB{}, err: err, statusCode: 404, } }, "fail/read.ProtoJSON": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" body := []byte("{?}") return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/validatePolicy": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") adminErr.Message = "error validating provisioner policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" body := []byte(` { "x509": { "allow": { "uris": [ "https://example.com" ] } } }`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") }, }, body: body, err: adminErr, statusCode: 400, } }, "fail/auth.UpdateProvisioner-policy-admin-lockout-error": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: 
[]string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Policy: policy, } ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "error updating provisioner policy") adminErr.Message = "error updating provisioner policy: admin lock out" body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.AdminLockOut, Err: errors.New("admin lock out"), } }, }, body: body, err: adminErr, statusCode: 400, } }, "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Policy: policy, } ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating provisioner policy: force") adminErr.Message = "error updating provisioner policy: force" body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.StoreFailure, Err: errors.New("force"), } }, }, body: body, err: adminErr, statusCode: 500, } }, "ok": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Policy: policy, } ctx := linkedca.NewContextWithAdmin(context.Background(), 
adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil }, }, body: body, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ DNSDomains: []string{"*.local"}, }, }, }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateProvisionerPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", // we expect a full string match. 
if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } p := &testPolicyResponse{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, json.Unmarshal(body, &p)) assert.Equal(t, tc.response, p) }) } } func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { type test struct { auth adminAuthority adminDB admin.DB body []byte ctx context.Context acmeDB acme.DB err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/linkedca": func(t *testing.T) test { ctx := context.Background() err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") err.Message = "policy operations not yet supported in linked deployments" return test{ ctx: ctx, adminDB: &fakeLinkedCA{}, err: err, statusCode: 501, } }, "fail/no-existing-policy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist") err.Message = "provisioner policy does not exist" return test{ ctx: ctx, adminDB: &admin.MockDB{}, err: err, statusCode: 404, } }, "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Policy: &linkedca.Policy{}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewErrorISE("error deleting provisioner policy: force") err.Message = "error deleting provisioner policy: force" return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return errors.New("force") }, }, err: err, statusCode: 500, } }, "ok": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Policy: &linkedca.Policy{}, } 
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil }, }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteProvisionerPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.Message, ae.Message) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) assert.NoError(t, err) res.Body.Close() response := DeleteResponse{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equal(t, "ok", response.Status) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { type test struct { ctx context.Context acmeDB acme.DB adminDB admin.DB err *admin.Error response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/linkedca": func(t *testing.T) test { ctx := context.Background() err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") err.Message = "policy operations not yet supported in linked deployments" return test{ 
ctx: ctx, adminDB: &fakeLinkedCA{}, err: err, statusCode: 501, } }, "fail/no-policy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, adminDB: &admin.MockDB{}, err: err, statusCode: 404, } }, "ok": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"10.0.0.0/16"}, Emails: []string{"@example.com"}, Uris: []string{"example.com"}, CommonNames: []string{"test"}, }, Deny: &linkedca.X509Names{ Dns: []string{"bad.local"}, Ips: []string{"10.0.0.30"}, Emails: []string{"bad@example.com"}, Uris: []string{"notexample.com"}, CommonNames: []string{"bad"}, }, }, Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, Principals: []string{"*"}, }, Deny: &linkedca.SSHUserNames{ Emails: []string{"bad@example.com"}, Principals: []string{"root"}, }, }, Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.example.com"}, Ips: []string{"10.10.0.0/16"}, Principals: []string{"good"}, }, Deny: &linkedca.SSHHostNames{ Dns: []string{"bad@example.com"}, Ips: []string{"10.10.0.30"}, Principals: []string{"bad"}, }, }, }, } prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ ctx: ctx, adminDB: &admin.MockDB{}, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ DNSDomains: []string{"*.local"}, IPRanges: []string{"10.0.0.0/16"}, EmailAddresses: 
[]string{"@example.com"}, URIDomains: []string{"example.com"}, CommonNames: []string{"test"}, }, Deny: &testX509Names{ DNSDomains: []string{"bad.local"}, IPRanges: []string{"10.0.0.30"}, EmailAddresses: []string{"bad@example.com"}, URIDomains: []string{"notexample.com"}, CommonNames: []string{"bad"}, }, }, SSH: &testSSHPolicy{ User: &testSSHUserPolicy{ Allow: &testSSHUserNames{ EmailAddresses: []string{"@example.com"}, Principals: []string{"*"}, }, Deny: &testSSHUserNames{ EmailAddresses: []string{"bad@example.com"}, Principals: []string{"root"}, }, }, Host: &testSSHHostPolicy{ Allow: &testSSHHostNames{ DNSDomains: []string{"*.example.com"}, IPRanges: []string{"10.10.0.0/16"}, Principals: []string{"good"}, }, Deny: &testSSHHostNames{ DNSDomains: []string{"bad@example.com"}, IPRanges: []string{"10.10.0.30"}, Principals: []string{"bad"}, }, }, }, }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { ctx := admin.NewContext(tc.ctx, tc.adminDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", http.NoBody) req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetACMEAccountPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.Message, ae.Message) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } p := &testPolicyResponse{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, json.Unmarshal(body, &p)) assert.Equal(t, tc.response, p) }) } } func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { type test struct { 
acmeDB acme.DB adminDB admin.DB body []byte ctx context.Context err *admin.Error response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/linkedca": func(t *testing.T) test { ctx := context.Background() err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") err.Message = "policy operations not yet supported in linked deployments" return test{ ctx: ctx, adminDB: &fakeLinkedCA{}, err: err, statusCode: 501, } }, "fail/existing-policy": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) err := admin.NewError(admin.ErrorConflictType, "ACME EAK eakID already has a policy") err.Message = "ACME EAK eakID already has a policy" return test{ ctx: ctx, adminDB: &admin.MockDB{}, err: err, statusCode: 409, } }, "fail/read.ProtoJSON": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" 
body := []byte("{?}") return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/validatePolicy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") adminErr.Message = "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" body := []byte(` { "x509": { "allow": { "uris": [ "https://example.com" ] } } }`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating ACME EAK policy") adminErr.Message = "error creating ACME EAK policy: force" policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) assert.Equal(t, "eakID", eak.ID) return errors.New("force") }, }, body: body, err: adminErr, statusCode: 500, } }, 
"ok": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) assert.Equal(t, "eakID", eak.ID) return nil }, }, body: body, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ DNSDomains: []string{"*.local"}, }, }, }, statusCode: 201, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { ctx := admin.NewContext(tc.ctx, tc.adminDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateACMEAccountPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", // we expect a full string match. 
if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } p := &testPolicyResponse{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, json.Unmarshal(body, &p)) assert.Equal(t, tc.response, p) }) } } func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { type test struct { acmeDB acme.DB adminDB admin.DB body []byte ctx context.Context err *admin.Error response *testPolicyResponse statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/linkedca": func(t *testing.T) test { ctx := context.Background() err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") err.Message = "policy operations not yet supported in linked deployments" return test{ ctx: ctx, adminDB: &fakeLinkedCA{}, err: err, statusCode: 501, } }, "fail/no-existing-policy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, adminDB: &admin.MockDB{}, err: err, statusCode: 404, } }, "fail/read.ProtoJSON": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr.Message = "proto: syntax error (line 1:2): invalid value 
?" body := []byte("{?}") return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/validatePolicy": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) adminErr := admin.NewError(admin.ErrorBadRequestType, "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)") adminErr.Message = "error validating ACME EAK policy: cannot parse permitted URI domain constraint \"https://example.com\": URI domain constraint \"https://example.com\" contains scheme (not supported yet)" body := []byte(` { "x509": { "allow": { "uris": [ "https://example.com" ] } } }`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Id: "provID", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating ACME EAK policy: force") adminErr.Message = "error updating ACME EAK policy: force" body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error 
{ assert.Equal(t, "provID", provisionerID) assert.Equal(t, "eakID", eak.ID) return errors.New("force") }, }, body: body, err: adminErr, statusCode: 500, } }, "ok": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Id: "provID", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ ctx: ctx, adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) assert.Equal(t, "eakID", eak.ID) return nil }, }, body: body, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ DNSDomains: []string{"*.local"}, }, }, }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { ctx := admin.NewContext(tc.ctx, tc.adminDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateACMEAccountPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). 
If the message doesn't start with "proto", // we expect a full string match. if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } p := &testPolicyResponse{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, json.Unmarshal(body, &p)) assert.Equal(t, tc.response, p) }) } } func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { type test struct { body []byte adminDB admin.DB ctx context.Context acmeDB acme.DB err *admin.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/linkedca": func(t *testing.T) test { ctx := context.Background() err := admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") err.Message = "policy operations not yet supported in linked deployments" return test{ ctx: ctx, adminDB: &fakeLinkedCA{}, err: err, statusCode: 501, } }, "fail/no-existing-policy": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } eak := &linkedca.EABKey{ Id: "eakID", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) err := admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist") err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, adminDB: &admin.MockDB{}, err: err, statusCode: 404, } }, "fail/acmeDB.UpdateExternalAccountKey-error": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Id: "provID", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) err := admin.NewErrorISE("error deleting ACME EAK policy: force") err.Message = "error 
deleting ACME EAK policy: force" return test{ ctx: ctx, adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) assert.Equal(t, "eakID", eak.ID) return errors.New("force") }, }, err: err, statusCode: 500, } }, "ok": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } prov := &linkedca.Provisioner{ Name: "provName", Id: "provID", } eak := &linkedca.EABKey{ Id: "eakID", Policy: policy, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ ctx: ctx, adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) assert.Equal(t, "eakID", eak.ID) return nil }, }, statusCode: 200, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { ctx := admin.NewContext(tc.ctx, tc.adminDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteACMEAccountPolicy(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.Message, ae.Message) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) assert.NoError(t, 
err) res.Body.Close() response := DeleteResponse{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equal(t, "ok", response.Status) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func Test_isBadRequest(t *testing.T) { tests := []struct { name string err error want bool }{ { name: "nil", err: nil, want: false, }, { name: "no-policy-error", err: errors.New("some error"), want: false, }, { name: "no-bad-request", err: &authority.PolicyError{ Typ: authority.InternalFailure, Err: errors.New("error"), }, want: false, }, { name: "bad-request", err: &authority.PolicyError{ Typ: authority.AdminLockOut, Err: errors.New("admin lock out"), }, want: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := isBadRequest(tt.err); got != tt.want { t.Errorf("isBadRequest() = %v, want %v", got, tt.want) } }) } } func Test_validatePolicy(t *testing.T) { type args struct { p *linkedca.Policy } tests := []struct { name string args args wantErr bool }{ { name: "nil", args: args{ p: nil, }, wantErr: false, }, { name: "x509", args: args{ p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"**.local"}, }, }, }, }, wantErr: true, }, { name: "ssh user", args: args{ p: &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@@example.com"}, }, }, }, }, }, wantErr: true, }, { name: "ssh host", args: args{ p: &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"**.local"}, }, }, }, }, }, wantErr: true, }, { name: "ok", args: args{ p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, }, }, Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: 
[]string{"*.local"}, }, }, }, }, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := validatePolicy(tt.args.p); (err != nil) != tt.wantErr { t.Errorf("validatePolicy() error = %v, wantErr %v", err, tt.wantErr) } }) } } ================================================ FILE: authority/admin/api/provisioner.go ================================================ package api import ( "fmt" "net/http" "github.com/go-chi/chi/v5" "github.com/smallstep/linkedca" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) // GetProvisionersResponse is the type for GET /admin/provisioners responses. type GetProvisionersResponse struct { Provisioners provisioner.List `json:"provisioners"` NextCursor string `json:"nextCursor"` } // GetProvisioner returns the requested provisioner, or an error. func GetProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) ctx := r.Context() id := r.URL.Query().Get("id") name := chi.URLParam(r, "name") auth := mustAuthority(ctx) db := admin.MustFromContext(ctx) if id != "" { if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, r, err) return } render.ProtoJSON(w, prov) } // GetProvisioners returns the given segment of provisioners associated with the authority. 
func GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, r, errs.InternalServerErr(err)) return } render.JSON(w, r, &GetProvisionersResponse{ Provisioners: p, NextCursor: next, }) } // CreateProvisioner creates a new prov. func CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { render.Error(w, r, err) return } // TODO: Validate inputs if err := authority.ValidateClaims(prov.Claims); err != nil { render.Error(w, r, err) return } // validate the templates and template data if err := validateTemplates(prov.X509Template, prov.SshTemplate); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) return } if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } render.ProtoJSONStatus(w, prov, http.StatusCreated) } // DeleteProvisioner deletes a provisioner. 
func DeleteProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) id := r.URL.Query().Get("id") name := chi.URLParam(r, "name") auth := mustAuthority(r.Context()) if id != "" { if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } render.JSON(w, r, &DeleteResponse{Status: "ok"}) } // UpdateProvisioner updates an existing prov. func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { render.Error(w, r, err) return } ctx := r.Context() name := chi.URLParam(r, "name") auth := mustAuthority(ctx) db := admin.MustFromContext(ctx) p, err := auth.LoadProvisionerByName(name) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } old, err := db.GetProvisioner(r.Context(), p.GetID()) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) return } if nu.Id != old.Id { render.Error(w, r, admin.NewErrorISE("cannot change provisioner ID")) return } if nu.Type != old.Type { render.Error(w, r, admin.NewErrorISE("cannot change provisioner type")) return } if nu.AuthorityId != old.AuthorityId { render.Error(w, r, admin.NewErrorISE("cannot change provisioner authorityID")) return } if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { render.Error(w, r, admin.NewErrorISE("cannot change provisioner createdAt")) return } if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { render.Error(w, r, 
admin.NewErrorISE("cannot change provisioner deletedAt")) return } // TODO: Validate inputs if err := authority.ValidateClaims(nu.Claims); err != nil { render.Error(w, r, err) return } // validate the templates and template data if err := validateTemplates(nu.X509Template, nu.SshTemplate); err != nil { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template")) return } if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { render.Error(w, r, err) return } render.ProtoJSON(w, nu) } // validateTemplates validates the X.509 and SSH templates and template data if set. func validateTemplates(x509, ssh *linkedca.Template) error { if x509 != nil { if len(x509.Template) > 0 { if err := x509util.ValidateTemplate(x509.Template); err != nil { return fmt.Errorf("invalid X.509 template: %w", err) } } if len(x509.Data) > 0 { if err := x509util.ValidateTemplateData(x509.Data); err != nil { return fmt.Errorf("invalid X.509 template data: %w", err) } } } if ssh != nil { if len(ssh.Template) > 0 { if err := sshutil.ValidateTemplate(ssh.Template); err != nil { return fmt.Errorf("invalid SSH template: %w", err) } } if len(ssh.Data) > 0 { if err := sshutil.ValidateTemplateData(ssh.Data); err != nil { return fmt.Errorf("invalid SSH template data: %w", err) } } } return nil } ================================================ FILE: authority/admin/api/provisioner_test.go ================================================ package api import ( "bytes" "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/go-chi/chi/v5" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/timestamppb" "github.com/smallstep/assert" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) func TestHandler_GetProvisioner(t 
*testing.T) { type test struct { ctx context.Context auth adminAuthority adminDB admin.DB req *http.Request statusCode int err *admin.Error prov *linkedca.Provisioner } var tests = map[string]func(t *testing.T) test{ "fail/auth.LoadProvisionerByID": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo?id=provID", http.NoBody) chiCtx := chi.NewRouteContext() ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { assert.Equals(t, "provID", id) return nil, errors.New("force") }, } return test{ ctx: ctx, req: req, auth: auth, adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner provID: force", }, } }, "fail/auth.LoadProvisionerByName": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") }, } return test{ ctx: ctx, req: req, auth: auth, adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner provName: force", }, } }, "fail/db.GetProvisioner": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, 
"provName", name) return &provisioner.ACME{ ID: "acmeID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "acmeID", id) return nil, admin.NewErrorISE("error loading provisioner provName: force") }, } return test{ ctx: ctx, req: req, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner provName: force", }, } }, "ok": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.ACME{ ID: "acmeID", Name: "provName", }, nil }, } prov := &linkedca.Provisioner{ Id: "acmeID", Type: linkedca.Provisioner_ACME, Name: "provName", // TODO(hs): other fields too? 
} db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "acmeID", id) return prov, nil }, } return test{ ctx: ctx, req: req, auth: auth, adminDB: db, statusCode: 200, err: nil, prov: prov, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) req := tc.req.WithContext(ctx) w := httptest.NewRecorder() GetProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } prov := &linkedca.Provisioner{} err := readProtoJSON(res.Body, prov) assert.FatalError(t, err) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{}, timestamppb.Timestamp{})} if !cmp.Equal(tc.prov, prov, opts...) 
{ t.Errorf("h.GetProvisioner diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) } }) } } func TestHandler_GetProvisioners(t *testing.T) { type test struct { ctx context.Context auth adminAuthority req *http.Request statusCode int err *admin.Error resp GetProvisionersResponse } var tests = map[string]func(t *testing.T) test{ "fail/parse-cursor": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo?limit=X", http.NoBody) return test{ ctx: context.Background(), statusCode: 400, req: req, err: &admin.Error{ Status: 400, Type: admin.ErrorBadRequestType.String(), Detail: "bad request", Message: "error parsing cursor and limit from query params: limit 'X' is not an integer: strconv.Atoi: parsing \"X\": invalid syntax", }, } }, "fail/auth.GetProvisioners": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) auth := &mockAdminAuthority{ MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) return nil, "", errors.New("force") }, } return test{ ctx: context.Background(), req: req, auth: auth, statusCode: 500, err: &admin.Error{ Type: "", Status: 500, Detail: "", Message: "The certificate authority encountered an Internal Server Error. 
Please see the certificate authority logs for more info.", }, } }, "ok": func(t *testing.T) test { req := httptest.NewRequest("GET", "/foo", http.NoBody) provisioners := provisioner.List{ &provisioner.OIDC{ Type: "OIDC", Name: "oidcProv", }, &provisioner.ACME{ Type: "ACME", Name: "provName", ForceCN: false, RequireEAB: false, }, } auth := &mockAdminAuthority{ MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { assert.Equals(t, "", cursor) assert.Equals(t, 0, limit) return provisioners, "nextCursorValue", nil }, } return test{ ctx: context.Background(), req: req, auth: auth, statusCode: 200, err: nil, resp: GetProvisionersResponse{ Provisioners: provisioners, NextCursor: "nextCursorValue", }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() GetProvisioners(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) response := GetProvisionersResponse{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(provisioner.ACME{}, provisioner.OIDC{})} if !cmp.Equal(tc.resp, response, opts...) 
{ t.Errorf("h.GetProvisioners diff =\n%s", cmp.Diff(tc.resp, response, opts...)) } }) } } func TestHandler_CreateProvisioner(t *testing.T) { type test struct { ctx context.Context auth adminAuthority body []byte statusCode int err *admin.Error prov *linkedca.Provisioner } var tests = map[string]func(t *testing.T) test{ "fail/readProtoJSON": func(t *testing.T) test { body := []byte("{!?}") return test{ ctx: context.Background(), body: body, statusCode: 400, err: &admin.Error{ Type: "badRequest", Status: 400, Detail: "bad request", Message: "proto: syntax error (line 1:2): invalid value !", }, } }, // TODO(hs): ValidateClaims can't be mocked atm // "fail/authority.ValidateClaims": func(t *testing.T) test { // return test{} // }, "fail/validateTemplates": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", X509Template: &linkedca.Template{ Template: []byte(`{ {{missingFunction "foo"}} }`), }, } body, err := protojson.Marshal(prov) assert.FatalError(t, err) return test{ ctx: context.Background(), body: body, statusCode: 400, err: &admin.Error{ Type: "badRequest", Status: 400, Detail: "bad request", Message: "invalid template: invalid X.509 template: error parsing template: template: template:1: function \"missingFunction\" not defined", }, } }, "fail/auth.StoreProvisioner": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { assert.Equals(t, "provID", prov.Id) return errors.New("force") }, } return test{ ctx: context.Background(), body: body, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error storing provisioner provName: force", }, } 
}, "ok": func(t *testing.T) test { prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { assert.Equals(t, "provID", prov.Id) return nil }, } return test{ ctx: context.Background(), body: body, auth: auth, statusCode: 201, err: nil, prov: prov, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() CreateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(adminErr.Message, "syntax error")) } else { assert.Equals(t, tc.err.Message, adminErr.Message) } return } prov := &linkedca.Provisioner{} err := readProtoJSON(res.Body, prov) assert.FatalError(t, err) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{}, timestamppb.Timestamp{})} if !cmp.Equal(tc.prov, prov, opts...) 
{ t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) } }) } } func TestHandler_DeleteProvisioner(t *testing.T) { type test struct { ctx context.Context auth adminAuthority req *http.Request statusCode int err *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/auth.LoadProvisionerByID": func(t *testing.T) test { req := httptest.NewRequest("DELETE", "/foo?id=provID", http.NoBody) chiCtx := chi.NewRouteContext() ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { assert.Equals(t, "provID", id) return nil, errors.New("force") }, } return test{ ctx: ctx, req: req, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner provID: force", }, } }, "fail/auth.LoadProvisionerByName": func(t *testing.T) test { req := httptest.NewRequest("DELETE", "/foo", http.NoBody) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") }, } return test{ ctx: ctx, req: req, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner provName: force", }, } }, "fail/auth.RemoveProvisioner": func(t *testing.T) test { req := httptest.NewRequest("DELETE", "/foo", http.NoBody) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) 
(provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", Type: "OIDC", }, nil }, MockRemoveProvisioner: func(ctx context.Context, id string) error { assert.Equals(t, "provID", id) return errors.New("force") }, } return test{ ctx: ctx, req: req, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error removing provisioner provName: force", }, } }, "ok": func(t *testing.T) test { req := httptest.NewRequest("DELETE", "/foo", http.NoBody) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", Type: "OIDC", }, nil }, MockRemoveProvisioner: func(ctx context.Context, id string) error { assert.Equals(t, "provID", id) return nil }, } return test{ ctx: ctx, req: req, auth: auth, statusCode: 200, err: nil, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() DeleteProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Message, adminErr.Message) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) return } body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) response := DeleteResponse{} 
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equals(t, "ok", response.Status) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestHandler_UpdateProvisioner(t *testing.T) { type test struct { ctx context.Context auth adminAuthority body []byte adminDB admin.DB statusCode int err *admin.Error prov *linkedca.Provisioner } var tests = map[string]func(t *testing.T) test{ "fail/readProtoJSON": func(t *testing.T) test { body := []byte("{!?}") return test{ ctx: context.Background(), body: body, adminDB: &admin.MockDB{}, statusCode: 400, err: &admin.Error{ Type: "badRequest", Status: 400, Detail: "bad request", Message: "proto: syntax error (line 1:2): invalid value !", }, } }, "fail/auth.LoadProvisionerByName": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return nil, errors.New("force") }, } return test{ ctx: ctx, body: body, adminDB: &admin.MockDB{}, auth: auth, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner from cached configuration 'provName': force", }, } }, "fail/db.GetProvisioner": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ 
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return nil, errors.New("force") }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "error loading provisioner from db 'provID': force", }, } }, "fail/change-id-error": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) prov := &linkedca.Provisioner{ Id: "differentProvID", Type: linkedca.Provisioner_OIDC, Name: "provName", } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "cannot change provisioner ID", }, } }, "fail/change-type-error": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_JWK, Name: "provName", } body, err := 
protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "cannot change provisioner type", }, } }, "fail/change-authority-id-error": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", AuthorityId: "differentAuthorityID", } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, AuthorityId: "authorityID", }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "cannot change provisioner authorityID", }, } }, "fail/change-createdAt-error": func(t *testing.T) test { chiCtx := 
chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) createdAt := time.Now() prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", AuthorityId: "authorityID", CreatedAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "cannot change provisioner createdAt", }, } }, "fail/change-deletedAt-error": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) createdAt := time.Now() var deletedAt time.Time prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(time.Now()), } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id 
string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), Status: 500, Detail: "the server experienced an internal error", Message: "cannot change provisioner deletedAt", }, } }, // TODO(hs): ValidateClaims can't be mocked atm //"fail/ValidateClaims": func(t *testing.T) test { return test{} }, "fail/validateTemplates": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) createdAt := time.Now() var deletedAt time.Time prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), X509Template: &linkedca.Template{ Template: []byte("{ {{ missingFunction }} }"), }, } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 400, err: &admin.Error{ Type: "badRequest", Status: 400, Detail: "bad request", Message: "invalid 
template: invalid X.509 template: error parsing template: template: template:1: function \"missingFunction\" not defined", }, } }, "fail/auth.UpdateProvisioner": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) createdAt := time.Now() var deletedAt time.Time prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { assert.Equals(t, "provID", nu.Id) assert.Equals(t, "provName", nu.Name) return errors.New("force") }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 500, err: &admin.Error{ Type: "", // TODO(hs): this error can be improved Status: 500, Detail: "", Message: "", }, } }, "ok": func(t *testing.T) test { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("name", "provName") ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) createdAt := time.Now() var deletedAt time.Time prov := &linkedca.Provisioner{ Id: "provID", Type: linkedca.Provisioner_OIDC, Name: "provName", AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: 
timestamppb.New(deletedAt), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_OIDC{ OIDC: &linkedca.OIDCProvisioner{ ClientId: "new-client-id", ClientSecret: "new-client-secret", }, }, }, } body, err := protojson.Marshal(prov) assert.FatalError(t, err) auth := &mockAdminAuthority{ MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { assert.Equals(t, "provName", name) return &provisioner.OIDC{ ID: "provID", Name: "provName", }, nil }, MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { assert.Equals(t, "provID", nu.Id) assert.Equals(t, "provName", nu.Name) return nil }, } db := &admin.MockDB{ MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { assert.Equals(t, "provID", id) return &linkedca.Provisioner{ Id: "provID", Name: "provName", Type: linkedca.Provisioner_OIDC, AuthorityId: "authorityID", CreatedAt: timestamppb.New(createdAt), DeletedAt: timestamppb.New(deletedAt), }, nil }, } return test{ ctx: ctx, body: body, auth: auth, adminDB: db, statusCode: 200, prov: prov, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() UpdateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.FatalError(t, err) adminErr := admin.Error{} assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(adminErr.Message, "syntax error")) } else { 
assert.Equals(t, tc.err.Message, adminErr.Message) } return } prov := &linkedca.Provisioner{} err := readProtoJSON(res.Body, prov) assert.FatalError(t, err) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) opts := []cmp.Option{ cmpopts.IgnoreUnexported( linkedca.Provisioner{}, linkedca.ProvisionerDetails{}, linkedca.ProvisionerDetails_OIDC{}, linkedca.OIDCProvisioner{}, timestamppb.Timestamp{}, ), } if !cmp.Equal(tc.prov, prov, opts...) { t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) } }) } } func Test_validateTemplates(t *testing.T) { type args struct { x509 *linkedca.Template ssh *linkedca.Template } tests := []struct { name string args args err error }{ { name: "ok", args: args{}, err: nil, }, { name: "ok/x509", args: args{ x509: &linkedca.Template{ Template: []byte(`{"x": 1}`), }, }, err: nil, }, { name: "ok/ssh", args: args{ ssh: &linkedca.Template{ Template: []byte(`{"x": 1}`), }, }, err: nil, }, { name: "fail/x509-template-missing-quote", args: args{ x509: &linkedca.Template{ Template: []byte(`{ {{printf "%q" "quoted}} }`), }, }, err: errors.New("invalid X.509 template: error parsing template: template: template:1: unterminated quoted string"), }, { name: "fail/x509-template-data", args: args{ x509: &linkedca.Template{ Data: []byte(`{!?}`), }, }, err: errors.New("invalid X.509 template data: error validating json template data"), }, { name: "fail/ssh-template-unknown-function", args: args{ ssh: &linkedca.Template{ Template: []byte(`{ {{unknownFunction "foo"}} }`), }, }, err: errors.New("invalid SSH template: error parsing template: template: template:1: function \"unknownFunction\" not defined"), }, { name: "fail/ssh-template-data", args: args{ ssh: &linkedca.Template{ Data: []byte(`{!?}`), }, }, err: errors.New("invalid SSH template data: error validating json template data"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := validateTemplates(tt.args.x509, tt.args.ssh) if tt.err 
!= nil { assert.Error(t, err) assert.Equals(t, tt.err.Error(), err.Error()) return } assert.Nil(t, err) }) } } ================================================ FILE: authority/admin/api/webhook.go ================================================ package api import ( "encoding/base64" "fmt" "net/http" "net/url" "github.com/go-chi/chi/v5" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/linkedca" "go.step.sm/crypto/randutil" ) // WebhookAdminResponder is the interface responsible for writing webhook admin // responses. type WebhookAdminResponder interface { CreateProvisionerWebhook(w http.ResponseWriter, r *http.Request) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) DeleteProvisionerWebhook(w http.ResponseWriter, r *http.Request) } // webhoookAdminResponder implements WebhookAdminResponder type webhookAdminResponder struct{} // NewWebhookAdminResponder returns a new WebhookAdminResponder func NewWebhookAdminResponder() WebhookAdminResponder { return &webhookAdminResponder{} } func validateWebhook(webhook *linkedca.Webhook) error { if webhook == nil { return nil } // name if webhook.Name == "" { return admin.NewError(admin.ErrorBadRequestType, "webhook name is required") } // url parsedURL, err := url.Parse(webhook.Url) if err != nil { return admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid") } if parsedURL.Host == "" { return admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid") } if parsedURL.Scheme != "https" { return admin.NewError(admin.ErrorBadRequestType, "webhook url must use https") } if parsedURL.User != nil { return admin.NewError(admin.ErrorBadRequestType, "webhook url may not contain username or password") } // kind if _, ok := linkedca.Webhook_Kind_name[int32(webhook.Kind)]; !ok || webhook.Kind == linkedca.Webhook_NO_KIND { return admin.NewError(admin.ErrorBadRequestType, "webhook kind %q 
is invalid", webhook.Kind) } return nil } func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() auth := mustAuthority(ctx) prov := linkedca.MustProvisionerFromContext(ctx) var newWebhook = new(linkedca.Webhook) if err := read.ProtoJSON(r.Body, newWebhook); err != nil { render.Error(w, r, err) return } if err := validateWebhook(newWebhook); err != nil { render.Error(w, r, err) return } if newWebhook.Secret != "" { err := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set") render.Error(w, r, err) return } if newWebhook.Id != "" { err := admin.NewError(admin.ErrorBadRequestType, "webhook ID must not be set") render.Error(w, r, err) return } id, err := randutil.UUIDv4() if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook id")) return } newWebhook.Id = id // verify the name is unique for _, wh := range prov.Webhooks { if wh.Name == newWebhook.Name { err := admin.NewError(admin.ErrorConflictType, "provisioner %q already has a webhook with the name %q", prov.Name, newWebhook.Name) render.Error(w, r, err) return } } secret, err := randutil.Bytes(64) if err != nil { render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook secret")) return } newWebhook.Secret = base64.StdEncoding.EncodeToString(secret) prov.Webhooks = append(prov.Webhooks, newWebhook) if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook")) return } render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner webhook")) return } render.ProtoJSONStatus(w, newWebhook, http.StatusCreated) } func (war *webhookAdminResponder) DeleteProvisionerWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() auth := mustAuthority(ctx) prov := linkedca.MustProvisionerFromContext(ctx) webhookName := chi.URLParam(r, "webhookName") found := 
false for i, wh := range prov.Webhooks { if wh.Name == webhookName { prov.Webhooks = append(prov.Webhooks[0:i], prov.Webhooks[i+1:]...) found = true break } } if !found { render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) return } if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook")) return } render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner webhook")) return } render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK) } func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() auth := mustAuthority(ctx) prov := linkedca.MustProvisionerFromContext(ctx) var newWebhook = new(linkedca.Webhook) if err := read.ProtoJSON(r.Body, newWebhook); err != nil { render.Error(w, r, err) return } if err := validateWebhook(newWebhook); err != nil { render.Error(w, r, err) return } found := false for i, wh := range prov.Webhooks { if wh.Name != newWebhook.Name { continue } if newWebhook.Secret != "" && newWebhook.Secret != wh.Secret { err := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated") render.Error(w, r, err) return } newWebhook.Secret = wh.Secret if newWebhook.Id != "" && newWebhook.Id != wh.Id { err := admin.NewError(admin.ErrorBadRequestType, "webhook ID cannot be updated") render.Error(w, r, err) return } newWebhook.Id = wh.Id prov.Webhooks[i] = newWebhook found = true break } if !found { msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name) err := admin.NewError(admin.ErrorNotFoundType, "%s", msg) render.Error(w, r, err) return } if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook")) return } render.Error(w, r, admin.WrapErrorISE(err, 
"error updating provisioner webhook")) return } // Return a copy without the signing secret. Include the client-supplied // auth secrets since those may have been updated in this request and we // should show in the response that they changed whResponse := &linkedca.Webhook{ Id: newWebhook.Id, Name: newWebhook.Name, Url: newWebhook.Url, Kind: newWebhook.Kind, CertType: newWebhook.CertType, Auth: newWebhook.Auth, DisableTlsClientAuth: newWebhook.DisableTlsClientAuth, } render.ProtoJSONStatus(w, whResponse, http.StatusCreated) } ================================================ FILE: authority/admin/api/webhook_test.go ================================================ package api import ( "bytes" "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/go-chi/chi/v5" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/linkedca" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/encoding/protojson" ) // ignore secret and id since those are set by the server func assertEqualWebhook(t *testing.T, a, b *linkedca.Webhook) { assert.Equal(t, a.Name, b.Name) assert.Equal(t, a.Url, b.Url) assert.Equal(t, a.Kind, b.Kind) assert.Equal(t, a.CertType, b.CertType) assert.Equal(t, a.DisableTlsClientAuth, b.DisableTlsClientAuth) assert.Equal(t, a.GetAuth(), b.GetAuth()) } func TestWebhookAdminResponder_CreateProvisionerWebhook(t *testing.T) { type test struct { auth adminAuthority body []byte ctx context.Context err *admin.Error response *linkedca.Webhook statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/existing-webhook": func(t *testing.T) test { webhook := &linkedca.Webhook{ Name: "already-exists", Url: "https://example.com", } prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{webhook}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorConflictType, 
`provisioner "provName" already has a webhook with the name "already-exists"`) err.Message = `provisioner "provName" already has a webhook with the name "already-exists"` body := []byte(` { "name": "already-exists", "url": "https://example.com", "kind": "ENRICHING" }`) return test{ ctx: ctx, body: body, err: err, statusCode: 409, } }, "fail/read.ProtoJSON": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" body := []byte("{?}") return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/missing-name": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook name is required") adminErr.Message = "webhook name is required" body := []byte(`{"url": "https://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/missing-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid") adminErr.Message = "webhook url is invalid" body := []byte(`{"name": "metadata", "kind": "ENRICHING"}`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/relative-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid") adminErr.Message = "webhook url is invalid" body := []byte(`{"name": "metadata", "url": "example.com/path", "kind": 
"ENRICHING"}`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/http-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url must use https") adminErr.Message = "webhook url must use https" body := []byte(`{"name": "metadata", "url": "http://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/basic-auth-in-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url may not contain username or password") adminErr.Message = "webhook url may not contain username or password" body := []byte(` { "name": "metadata", "url": "https://user:pass@example.com", "kind": "ENRICHING" }`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/secret-in-request": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set") adminErr.Message = "webhook secret must not be set" body := []byte(` { "name": "metadata", "url": "https://example.com", "kind": "ENRICHING", "secret": "secret" }`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/unsupported-webhook-kind": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, `(line 5:13): invalid value for enum field kind: "UNSUPPORTED"`) adminErr.Message = `(line 5:13): invalid value for enum field kind: "UNSUPPORTED"` body := []byte(` { "name": "metadata", "url": 
"https://example.com", "kind": "UNSUPPORTED", }`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { adm := &linkedca.Admin{ Subject: "step", } prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithAdmin(context.Background(), adm) ctx = linkedca.NewContextWithProvisioner(ctx, prov) adminErr := admin.NewError(admin.ErrorServerInternalType, "error creating provisioner webhook: force") adminErr.Message = "error creating provisioner webhook: force" body := []byte(`{"name": "metadata", "url": "https://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.StoreFailure, Err: errors.New("force"), } }, }, body: body, err: adminErr, statusCode: 500, } }, "ok": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) body := []byte(`{"name": "metadata", "url": "https://example.com", "kind": "ENRICHING", "certType": "X509"}`) return test{ ctx: ctx, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { assert.Equal(t, linkedca.Webhook_X509, nu.Webhooks[0].CertType) return nil }, }, body: body, response: &linkedca.Webhook{ Name: "metadata", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING, CertType: linkedca.Webhook_X509, }, statusCode: 201, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, &admin.MockDB{}) war := NewWebhookAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() war.CreateProvisionerWebhook(w, req) res := w.Result() assert.Equal(t, tc.statusCode, 
res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", // we expect a full string match. if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } resp := &linkedca.Webhook{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, protojson.Unmarshal(body, resp)) assertEqualWebhook(t, tc.response, resp) assert.NotEmpty(t, resp.Secret) assert.NotEmpty(t, resp.Id) }) } } func TestWebhookAdminResponder_DeleteProvisionerWebhook(t *testing.T) { type test struct { auth adminAuthority err *admin.Error statusCode int provisionerWebhooks []*linkedca.Webhook webhookName string } var tests = map[string]func(t *testing.T) test{ "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { adminErr := admin.NewError(admin.ErrorServerInternalType, "error deleting provisioner webhook: force") adminErr.Message = "error deleting provisioner webhook: force" return test{ err: adminErr, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.StoreFailure, Err: errors.New("force"), } }, }, statusCode: 500, webhookName: "my-webhook", provisionerWebhooks: []*linkedca.Webhook{ {Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}, }, } }, "ok/not-found": func(t *testing.T) test { return test{ statusCode: 200, webhookName: "no-exists", 
provisionerWebhooks: nil, } }, "ok": func(t *testing.T) test { return test{ statusCode: 200, webhookName: "exists", auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { assert.Equal(t, nu.Webhooks, []*linkedca.Webhook{ {Name: "my-2nd-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}, }) return nil }, }, provisionerWebhooks: []*linkedca.Webhook{ {Name: "exists", Url: "https.example.com", Kind: linkedca.Webhook_ENRICHING}, {Name: "my-2nd-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}, }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("webhookName", tc.webhookName) ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) prov := &linkedca.Provisioner{ Name: "provName", Webhooks: tc.provisionerWebhooks, } ctx = linkedca.NewContextWithProvisioner(ctx, prov) ctx = admin.NewContext(ctx, &admin.MockDB{}) req := httptest.NewRequest("DELETE", "/foo", http.NoBody).WithContext(ctx) war := NewWebhookAdminResponder() w := httptest.NewRecorder() war.DeleteProvisionerWebhook(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", // we expect a full string match. 
if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } body, err := io.ReadAll(res.Body) assert.NoError(t, err) res.Body.Close() response := DeleteResponse{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) assert.Equal(t, "ok", response.Status) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) }) } } func TestWebhookAdminResponder_UpdateProvisionerWebhook(t *testing.T) { type test struct { auth adminAuthority adminDB admin.DB body []byte ctx context.Context err *admin.Error response *linkedca.Webhook statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "exists", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) err := admin.NewError(admin.ErrorNotFoundType, `provisioner "provName" has no webhook with the name "no-exists"`) err.Message = `provisioner "provName" has no webhook with the name "no-exists"` body := []byte(` { "name": "no-exists", "url": "https://example.com", "kind": "ENRICHING" }`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: err, statusCode: 404, } }, "fail/read.ProtoJSON": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "proto: syntax error (line 1:2): invalid value ?") adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" 
body := []byte("{?}") return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/missing-name": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook name is required") adminErr.Message = "webhook name is required" body := []byte(`{"url": "https://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/missing-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid") adminErr.Message = "webhook url is invalid" body := []byte(`{"name": "metadata", "kind": "ENRICHING"}`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/relative-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url is invalid") adminErr.Message = "webhook url is invalid" body := []byte(`{"name": "metadata", "url": "example.com/path", "kind": "ENRICHING"}`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/http-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", 
Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url must use https") adminErr.Message = "webhook url must use https" body := []byte(`{"name": "metadata", "url": "http://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/basic-auth-in-url": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook url may not contain username or password") adminErr.Message = "webhook url may not contain username or password" body := []byte(` { "name": "my-webhook", "url": "https://user:pass@example.com", "kind": "ENRICHING" }`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, } }, "fail/different-secret-in-request": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING, Secret: "c2VjcmV0"}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated") adminErr.Message = "webhook secret cannot be updated" body := []byte(` { "name": "my-webhook", "url": "https://example.com", "kind": "ENRICHING", "secret": "secret" }`) return test{ ctx: ctx, body: body, err: adminErr, statusCode: 400, } }, "fail/auth.UpdateProvisioner-error": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } 
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) adminErr := admin.NewError(admin.ErrorServerInternalType, "error updating provisioner webhook: force") adminErr.Message = "error updating provisioner webhook: force" body := []byte(`{"name": "my-webhook", "url": "https://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ Typ: authority.StoreFailure, Err: errors.New("force"), } }, }, body: body, err: adminErr, statusCode: 500, } }, "ok": func(t *testing.T) test { prov := &linkedca.Provisioner{ Name: "provName", Webhooks: []*linkedca.Webhook{{Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING}}, } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) body := []byte(`{"name": "my-webhook", "url": "https://example.com", "kind": "ENRICHING"}`) return test{ ctx: ctx, adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil }, }, body: body, response: &linkedca.Webhook{ Name: "my-webhook", Url: "https://example.com", Kind: linkedca.Webhook_ENRICHING, }, statusCode: 201, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { mockMustAuthority(t, tc.auth) ctx := admin.NewContext(tc.ctx, tc.adminDB) war := NewWebhookAdminResponder() req := httptest.NewRequest("PUT", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(ctx) w := httptest.NewRecorder() war.UpdateProvisionerWebhook(w, req) res := w.Result() assert.Equal(t, tc.statusCode, res.StatusCode) if res.StatusCode >= 400 { body, err := io.ReadAll(res.Body) res.Body.Close() assert.NoError(t, err) ae := testAdminError{} assert.NoError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.Equal(t, tc.err.Type, ae.Type) assert.Equal(t, 
tc.err.StatusCode(), res.StatusCode) assert.Equal(t, tc.err.Detail, ae.Detail) assert.Equal(t, []string{"application/json"}, res.Header["Content-Type"]) // when the error message starts with "proto", we expect it to have // a syntax error (in the tests). If the message doesn't start with "proto", // we expect a full string match. if strings.HasPrefix(tc.err.Message, "proto:") { assert.True(t, strings.Contains(ae.Message, "syntax error")) } else { assert.Equal(t, tc.err.Message, ae.Message) } return } resp := &linkedca.Webhook{} body, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.NoError(t, protojson.Unmarshal(body, resp)) assertEqualWebhook(t, tc.response, resp) }) } } ================================================ FILE: authority/admin/db/nosql/admin.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/linkedca" "github.com/smallstep/nosql" "google.golang.org/protobuf/types/known/timestamppb" ) // dbAdmin is the database representation of the Admin type. 
type dbAdmin struct { ID string `json:"id"` AuthorityID string `json:"authorityID"` ProvisionerID string `json:"provisionerID"` Subject string `json:"subject"` Type linkedca.Admin_Type `json:"type"` CreatedAt time.Time `json:"createdAt"` DeletedAt time.Time `json:"deletedAt"` } func (dba *dbAdmin) convert() *linkedca.Admin { return &linkedca.Admin{ Id: dba.ID, AuthorityId: dba.AuthorityID, ProvisionerId: dba.ProvisionerID, Subject: dba.Subject, Type: dba.Type, CreatedAt: timestamppb.New(dba.CreatedAt), DeletedAt: timestamppb.New(dba.DeletedAt), } } func (dba *dbAdmin) clone() *dbAdmin { u := *dba return &u } func (db *DB) getDBAdminBytes(_ context.Context, id string) ([]byte, error) { data, err := db.db.Get(adminsTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading admin %s", id) } return data, nil } func (db *DB) unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) { var dba = new(dbAdmin) if err := json.Unmarshal(data, dba); err != nil { return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", id) } if !dba.DeletedAt.IsZero() { return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", id) } if dba.AuthorityID != db.authorityID { return nil, admin.NewError(admin.ErrorAuthorityMismatchType, "admin %s is not owned by authority %s", dba.ID, db.authorityID) } return dba, nil } func (db *DB) getDBAdmin(ctx context.Context, id string) (*dbAdmin, error) { data, err := db.getDBAdminBytes(ctx, id) if err != nil { return nil, err } dba, err := db.unmarshalDBAdmin(data, id) if err != nil { return nil, err } return dba, nil } func (db *DB) unmarshalAdmin(data []byte, id string) (*linkedca.Admin, error) { dba, err := db.unmarshalDBAdmin(data, id) if err != nil { return nil, err } return dba.convert(), nil } // GetAdmin retrieves and unmarshals a admin from the database. 
func (db *DB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) { data, err := db.getDBAdminBytes(ctx, id) if err != nil { return nil, err } adm, err := db.unmarshalAdmin(data, id) if err != nil { return nil, err } return adm, nil } // GetAdmins retrieves and unmarshals all active (not deleted) admins // from the database. // TODO should we be paginating? func (db *DB) GetAdmins(context.Context) ([]*linkedca.Admin, error) { dbEntries, err := db.db.List(adminsTable) if err != nil { return nil, errors.Wrap(err, "error loading admins") } var admins = []*linkedca.Admin{} for _, entry := range dbEntries { adm, err := db.unmarshalAdmin(entry.Value, string(entry.Key)) if err != nil { var ae *admin.Error if errors.As(err, &ae) { if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) { continue } return nil, err } return nil, err } if adm.AuthorityId != db.authorityID { continue } admins = append(admins, adm) } return admins, nil } // CreateAdmin stores a new admin to the database. func (db *DB) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error { var err error adm.Id, err = randID() if err != nil { return admin.WrapErrorISE(err, "error generating random id for admin") } adm.AuthorityId = db.authorityID dba := &dbAdmin{ ID: adm.Id, AuthorityID: db.authorityID, ProvisionerID: adm.ProvisionerId, Subject: adm.Subject, Type: adm.Type, CreatedAt: clock.Now(), } return db.save(ctx, dba.ID, dba, nil, "admin", adminsTable) } // UpdateAdmin saves an updated admin to the database. func (db *DB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error { old, err := db.getDBAdmin(ctx, adm.Id) if err != nil { return err } nu := old.clone() nu.Type = adm.Type return db.save(ctx, old.ID, nu, old, "admin", adminsTable) } // DeleteAdmin saves an updated admin to the database. 
func (db *DB) DeleteAdmin(ctx context.Context, id string) error { old, err := db.getDBAdmin(ctx, id) if err != nil { return err } nu := old.clone() nu.DeletedAt = clock.Now() return db.save(ctx, old.ID, nu, old, "admin", adminsTable) } ================================================ FILE: authority/admin/db/nosql/admin_test.go ================================================ package nosql import ( "context" "encoding/json" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/db" "github.com/smallstep/linkedca" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" "google.golang.org/protobuf/types/known/timestamppb" ) func TestDB_getDBAdminBytes(t *testing.T) { adminID := "adminID" type test struct { db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, errors.New("force") }, }, err: errors.New("error loading admin adminID: force"), } }, "ok": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return []byte("foo"), nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil { var ae *admin.Error if 
errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) { assert.Equals(t, string(b), "foo") } }) } } func TestDB_getDBAdmin(t *testing.T) { adminID := "adminID" type test struct { db nosql.DB err error adminErr *admin.Error dba *dbAdmin } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, errors.New("force") }, }, err: errors.New("error loading admin adminID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling admin adminID into dbAdmin"), } }, "fail/deleted": func(t *testing.T) test { now := clock.Now() dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: now, DeletedAt: now, } b, err := json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { 
assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return b, nil }, }, adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), } }, "ok": func(t *testing.T) test { now := clock.Now() dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: now, } b, err := json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return b, nil }, }, dba: dba, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dba.ID, adminID) assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) assert.Equals(t, dba.Subject, tc.dba.Subject) assert.Equals(t, dba.Type, tc.dba.Type) assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) assert.Fatal(t, dba.DeletedAt.IsZero()) } }) } } func TestDB_unmarshalDBAdmin(t *testing.T) { adminID := "adminID" type test struct { in []byte err error adminErr *admin.Error dba *dbAdmin } var tests = map[string]func(t *testing.T) test{ "fail/unmarshal-error": func(t *testing.T) test { return test{ in: []byte("foo"), err: errors.New("error unmarshaling admin 
adminID into dbAdmin"), } }, "fail/deleted-error": func(t *testing.T) test { dba := &dbAdmin{ DeletedAt: time.Now(), } data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ in: data, adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), } }, "fail/authority-mismatch-error": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: "foo", } data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ in: data, adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, "admin %s is not owned by authority %s", adminID, admin.DefaultAuthorityID), } }, "ok": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, Subject: "max@smallstep.com", ProvisionerID: "provID", AuthorityID: admin.DefaultAuthorityID, Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ in: data, dba: dba, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{authorityID: admin.DefaultAuthorityID} if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dba.ID, adminID) assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) assert.Equals(t, dba.Subject, tc.dba.Subject) assert.Equals(t, dba.Type, tc.dba.Type) assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) assert.Fatal(t, dba.DeletedAt.IsZero()) } }) } } func TestDB_unmarshalAdmin(t *testing.T) { adminID := "adminID" type 
test struct { in []byte err error adminErr *admin.Error dba *dbAdmin } var tests = map[string]func(t *testing.T) test{ "fail/unmarshal-error": func(t *testing.T) test { return test{ in: []byte("foo"), err: errors.New("error unmarshaling admin adminID into dbAdmin"), } }, "fail/deleted-error": func(t *testing.T) test { dba := &dbAdmin{ DeletedAt: time.Now(), } data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ in: data, adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), } }, "ok": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, Subject: "max@smallstep.com", ProvisionerID: "provID", AuthorityID: admin.DefaultAuthorityID, Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ in: data, dba: dba, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{authorityID: admin.DefaultAuthorityID} if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, adm.Id, adminID) assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) assert.Equals(t, adm.Subject, tc.dba.Subject) assert.Equals(t, adm.Type, tc.dba.Type) assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt)) } }) } } func TestDB_GetAdmin(t *testing.T) { adminID := "adminID" type test struct { db nosql.DB err error adminErr 
*admin.Error dba *dbAdmin } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, errors.New("force") }, }, err: errors.New("error loading admin adminID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling admin adminID into dbAdmin"), } }, "fail/deleted": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), DeletedAt: clock.Now(), } b, err := json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return b, nil }, }, dba: dba, adminErr: admin.NewError(admin.ErrorDeletedType, "admin adminID is deleted"), } }, "fail/authorityID-mismatch": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: "foo", ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } b, err := json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) 
assert.Equals(t, string(key), adminID) return b, nil }, }, dba: dba, adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, "admin %s is not owned by authority %s", dba.ID, admin.DefaultAuthorityID), } }, "ok": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } b, err := json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return b, nil }, }, dba: dba, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if adm, err := d.GetAdmin(context.Background(), adminID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, adm.Id, adminID) assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) assert.Equals(t, adm.Subject, tc.dba.Subject) assert.Equals(t, adm.Type, tc.dba.Type) assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt)) } }) } } func TestDB_DeleteAdmin(t *testing.T) { adminID := "adminID" type test struct { db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: 
func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, errors.New("force") }, }, err: errors.New("error loading admin adminID: force"), } }, "fail/save-error": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return data, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) assert.Equals(t, string(old), string(data)) var _dba = new(dbAdmin) assert.FatalError(t, json.Unmarshal(nu, _dba)) assert.Equals(t, _dba.ID, dba.ID) assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) assert.Equals(t, _dba.Subject, dba.Subject) assert.Equals(t, _dba.Type, dba.Type) assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) assert.True(t, _dba.DeletedAt.Before(time.Now())) assert.True(t, _dba.DeletedAt.After(time.Now().Add(-time.Minute))) return nil, false, errors.New("force") }, }, err: errors.New("error saving authority admin: force"), } }, "ok": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } data, err := 
json.Marshal(dba) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return data, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) assert.Equals(t, string(old), string(data)) var _dba = new(dbAdmin) assert.FatalError(t, json.Unmarshal(nu, _dba)) assert.Equals(t, _dba.ID, dba.ID) assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) assert.Equals(t, _dba.Subject, dba.Subject) assert.Equals(t, _dba.Type, dba.Type) assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) assert.True(t, _dba.DeletedAt.Before(time.Now())) assert.True(t, _dba.DeletedAt.After(time.Now().Add(-time.Minute))) return nu, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if err := d.DeleteAdmin(context.Background(), adminID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } }) } } func TestDB_UpdateAdmin(t *testing.T) { adminID := "adminID" type test struct { db nosql.DB err error adminErr *admin.Error adm *linkedca.Admin } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ adm: &linkedca.Admin{Id: adminID}, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, 
nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ adm: &linkedca.Admin{Id: adminID}, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return nil, errors.New("force") }, }, err: errors.New("error loading admin adminID: force"), } }, "fail/save-error": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } upd := dba.convert() upd.Type = linkedca.Admin_ADMIN data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ adm: upd, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return data, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) assert.Equals(t, string(old), string(data)) var _dba = new(dbAdmin) assert.FatalError(t, json.Unmarshal(nu, _dba)) assert.Equals(t, _dba.ID, dba.ID) assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) assert.Equals(t, _dba.Subject, dba.Subject) assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) return nil, false, errors.New("force") }, }, err: errors.New("error saving authority admin: force"), } }, "ok": func(t *testing.T) test { dba := &dbAdmin{ ID: adminID, AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: clock.Now(), } upd := dba.convert() upd.Type = linkedca.Admin_ADMIN data, err := json.Marshal(dba) assert.FatalError(t, err) return test{ adm: upd, db: 
&db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) return data, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) assert.Equals(t, string(old), string(data)) var _dba = new(dbAdmin) assert.FatalError(t, json.Unmarshal(nu, _dba)) assert.Equals(t, _dba.ID, dba.ID) assert.Equals(t, _dba.AuthorityID, dba.AuthorityID) assert.Equals(t, _dba.ProvisionerID, dba.ProvisionerID) assert.Equals(t, _dba.Subject, dba.Subject) assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) assert.Equals(t, _dba.CreatedAt, dba.CreatedAt) return nu, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } }) } } func TestDB_CreateAdmin(t *testing.T) { type test struct { db nosql.DB err error adminErr *admin.Error adm *linkedca.Admin } var tests = map[string]func(t *testing.T) test{ "fail/save-error": func(t *testing.T) test { adm := &linkedca.Admin{ AuthorityId: admin.DefaultAuthorityID, ProvisionerId: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_ADMIN, } return test{ adm: adm, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, old, nil) var _dba = new(dbAdmin) assert.FatalError(t, json.Unmarshal(nu, _dba)) 
assert.True(t, _dba.ID != "" && _dba.ID == string(key)) assert.Equals(t, _dba.AuthorityID, adm.AuthorityId) assert.Equals(t, _dba.ProvisionerID, adm.ProvisionerId) assert.Equals(t, _dba.Subject, adm.Subject) assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) assert.True(t, _dba.CreatedAt.Before(time.Now())) assert.True(t, _dba.CreatedAt.After(time.Now().Add(-time.Minute))) return nil, false, errors.New("force") }, }, err: errors.New("error saving authority admin: force"), } }, "ok": func(t *testing.T) test { adm := &linkedca.Admin{ AuthorityId: admin.DefaultAuthorityID, ProvisionerId: "provID", Subject: "max@smallstep.com", Type: linkedca.Admin_ADMIN, } return test{ adm: adm, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, old, nil) var _dba = new(dbAdmin) assert.FatalError(t, json.Unmarshal(nu, _dba)) assert.True(t, _dba.ID != "" && _dba.ID == string(key)) assert.Equals(t, _dba.AuthorityID, adm.AuthorityId) assert.Equals(t, _dba.ProvisionerID, adm.ProvisionerId) assert.Equals(t, _dba.Subject, adm.Subject) assert.Equals(t, _dba.Type, linkedca.Admin_ADMIN) assert.True(t, _dba.CreatedAt.Before(time.Now())) assert.True(t, _dba.CreatedAt.After(time.Now().Add(-time.Minute))) return nu, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if err := d.CreateAdmin(context.Background(), tc.adm); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } }) } } func TestDB_GetAdmins(t *testing.T) { now := 
clock.Now() fooAdmin := &dbAdmin{ ID: "foo", AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "foo@smallstep.com", Type: linkedca.Admin_SUPER_ADMIN, CreatedAt: now, } foob, err := json.Marshal(fooAdmin) assert.FatalError(t, err) barAdmin := &dbAdmin{ ID: "bar", AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "bar@smallstep.com", Type: linkedca.Admin_ADMIN, CreatedAt: now, DeletedAt: now, } barb, err := json.Marshal(barAdmin) assert.FatalError(t, err) bazAdmin := &dbAdmin{ ID: "baz", AuthorityID: "bazzer", ProvisionerID: "provID", Subject: "baz@smallstep.com", Type: linkedca.Admin_ADMIN, CreatedAt: now, } bazb, err := json.Marshal(bazAdmin) assert.FatalError(t, err) zapAdmin := &dbAdmin{ ID: "zap", AuthorityID: admin.DefaultAuthorityID, ProvisionerID: "provID", Subject: "zap@smallstep.com", Type: linkedca.Admin_ADMIN, CreatedAt: now, } zapb, err := json.Marshal(zapAdmin) assert.FatalError(t, err) type test struct { db nosql.DB err error adminErr *admin.Error verify func(*testing.T, []*linkedca.Admin) } var tests = map[string]func(t *testing.T) test{ "fail/db.List-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return nil, errors.New("force") }, }, err: errors.New("error loading admins: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { ret := []*nosqldb.Entry{ {Bucket: adminsTable, Key: []byte("foo"), Value: foob}, {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, {Bucket: adminsTable, Key: []byte("zap"), Value: []byte("zap")}, } return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil }, }, err: errors.New("error unmarshaling admin zap into dbAdmin"), } }, "ok/none": func(t *testing.T) test { ret := []*nosqldb.Entry{} return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) 
{ assert.Equals(t, bucket, adminsTable) return ret, nil }, }, verify: func(t *testing.T, admins []*linkedca.Admin) { assert.Equals(t, len(admins), 0) }, } }, "ok/only-invalid": func(t *testing.T) test { ret := []*nosqldb.Entry{ {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, {Bucket: adminsTable, Key: []byte("baz"), Value: bazb}, } return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil }, }, verify: func(t *testing.T, admins []*linkedca.Admin) { assert.Equals(t, len(admins), 0) }, } }, "ok": func(t *testing.T) test { ret := []*nosqldb.Entry{ {Bucket: adminsTable, Key: []byte("foo"), Value: foob}, {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, {Bucket: adminsTable, Key: []byte("baz"), Value: bazb}, {Bucket: adminsTable, Key: []byte("zap"), Value: zapb}, } return test{ db: &db.MockNoSQLDB{ MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil }, }, verify: func(t *testing.T, admins []*linkedca.Admin) { assert.Equals(t, len(admins), 2) assert.Equals(t, admins[0].Id, fooAdmin.ID) assert.Equals(t, admins[0].AuthorityId, fooAdmin.AuthorityID) assert.Equals(t, admins[0].ProvisionerId, fooAdmin.ProvisionerID) assert.Equals(t, admins[0].Subject, fooAdmin.Subject) assert.Equals(t, admins[0].Type, fooAdmin.Type) assert.Equals(t, admins[0].CreatedAt, timestamppb.New(fooAdmin.CreatedAt)) assert.Equals(t, admins[0].DeletedAt, timestamppb.New(fooAdmin.DeletedAt)) assert.Equals(t, admins[1].Id, zapAdmin.ID) assert.Equals(t, admins[1].AuthorityId, zapAdmin.AuthorityID) assert.Equals(t, admins[1].ProvisionerId, zapAdmin.ProvisionerID) assert.Equals(t, admins[1].Subject, zapAdmin.Subject) assert.Equals(t, admins[1].Type, zapAdmin.Type) assert.Equals(t, admins[1].CreatedAt, timestamppb.New(zapAdmin.CreatedAt)) assert.Equals(t, admins[1].DeletedAt, timestamppb.New(zapAdmin.DeletedAt)) }, } }, } for name, run := range tests { tc 
:= run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if admins, err := d.GetAdmins(context.Background()); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { tc.verify(t, admins) } }) } } ================================================ FILE: authority/admin/db/nosql/nosql.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" nosqlDB "github.com/smallstep/nosql/database" "go.step.sm/crypto/randutil" ) var ( adminsTable = []byte("admins") provisionersTable = []byte("provisioners") authorityPoliciesTable = []byte("authority_policies") ) // DB is a struct that implements the AdminDB interface. type DB struct { db nosqlDB.DB authorityID string } // New configures and returns a new Authority DB backend implemented using a nosql DB. func New(db nosqlDB.DB, authorityID string) (*DB, error) { tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable} for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", string(b)) } } return &DB{db, authorityID}, nil } // save writes the new data to the database, overwriting the old data if it // existed. 
func (db *DB) save(_ context.Context, id string, nu, old interface{}, typ string, table []byte) error { var ( err error newB []byte ) if nu == nil { newB = nil } else { newB, err = json.Marshal(nu) if err != nil { return errors.Wrapf(err, "error marshaling authority type: %s, value: %v", typ, nu) } } var oldB []byte if old == nil { oldB = nil } else { oldB, err = json.Marshal(old) if err != nil { return errors.Wrapf(err, "error marshaling admin type: %s, value: %v", typ, old) } } _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) switch { case err != nil: return errors.Wrapf(err, "error saving authority %s", typ) case !swapped: return errors.Errorf("error saving authority %s; changed since last read", typ) default: return nil } } func randID() (val string, err error) { val, err = randutil.UUIDv4() if err != nil { return "", errors.Wrap(err, "error generating random alphanumeric ID") } return val, nil } // Clock that returns time in UTC rounded to seconds. type Clock struct{} // Now returns the UTC time rounded to seconds. func (c *Clock) Now() time.Time { return time.Now().UTC().Truncate(time.Second) } var clock = new(Clock) ================================================ FILE: authority/admin/db/nosql/policy.go ================================================ package nosql import ( "context" "encoding/json" "fmt" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/nosql" ) type dbX509Policy struct { Allow *dbX509Names `json:"allow,omitempty"` Deny *dbX509Names `json:"deny,omitempty"` AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"` } type dbX509Names struct { CommonNames []string `json:"cn,omitempty"` DNSDomains []string `json:"dns,omitempty"` IPRanges []string `json:"ip,omitempty"` EmailAddresses []string `json:"email,omitempty"` URIDomains []string `json:"uri,omitempty"` } type dbSSHPolicy struct { // User contains SSH user certificate options. 
User *dbSSHUserPolicy `json:"user,omitempty"` // Host contains SSH host certificate options. Host *dbSSHHostPolicy `json:"host,omitempty"` } type dbSSHHostPolicy struct { Allow *dbSSHHostNames `json:"allow,omitempty"` Deny *dbSSHHostNames `json:"deny,omitempty"` } type dbSSHHostNames struct { DNSDomains []string `json:"dns,omitempty"` IPRanges []string `json:"ip,omitempty"` Principals []string `json:"principal,omitempty"` } type dbSSHUserPolicy struct { Allow *dbSSHUserNames `json:"allow,omitempty"` Deny *dbSSHUserNames `json:"deny,omitempty"` } type dbSSHUserNames struct { EmailAddresses []string `json:"email,omitempty"` Principals []string `json:"principal,omitempty"` } type dbPolicy struct { X509 *dbX509Policy `json:"x509,omitempty"` SSH *dbSSHPolicy `json:"ssh,omitempty"` } type dbAuthorityPolicy struct { ID string `json:"id"` AuthorityID string `json:"authorityID"` Policy *dbPolicy `json:"policy,omitempty"` } func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy { if dbap == nil { return nil } return dbToLinked(dbap.Policy) } func (db *DB) getDBAuthorityPolicyBytes(_ context.Context, authorityID string) ([]byte, error) { data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID)) if nosql.IsErrNotFound(err) { return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found") } else if err != nil { return nil, fmt.Errorf("error loading authority policy: %w", err) } return data, nil } func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) { if len(data) == 0 { //nolint:nilnil // legacy return nil, nil } var dba = new(dbAuthorityPolicy) if err := json.Unmarshal(data, dba); err != nil { return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err) } return dba, nil } func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) { data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID) if err != nil { return nil, err } dbap, err := 
db.unmarshalDBAuthorityPolicy(data) if err != nil { return nil, err } if dbap == nil { //nolint:nilnil // legacy return nil, nil } if dbap.AuthorityID != authorityID { return nil, admin.NewError(admin.ErrorAuthorityMismatchType, "authority policy is not owned by authority %s", authorityID) } return dbap, nil } func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { dbap := &dbAuthorityPolicy{ ID: db.authorityID, AuthorityID: db.authorityID, Policy: linkedToDB(policy), } if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil { return admin.WrapErrorISE(err, "error creating authority policy") } return nil } func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID) if err != nil { return nil, err } return dbap.convert(), nil } func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) if err != nil { return err } dbap := &dbAuthorityPolicy{ ID: db.authorityID, AuthorityID: db.authorityID, Policy: linkedToDB(policy), } if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil { return admin.WrapErrorISE(err, "error updating authority policy") } return nil } func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error { old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) if err != nil { return err } if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil { return admin.WrapErrorISE(err, "error deleting authority policy") } return nil } func dbToLinked(p *dbPolicy) *linkedca.Policy { if p == nil { return nil } r := &linkedca.Policy{} if x509 := p.X509; x509 != nil { r.X509 = &linkedca.X509Policy{} if allow := x509.Allow; allow != nil { r.X509.Allow = &linkedca.X509Names{} r.X509.Allow.Dns = allow.DNSDomains r.X509.Allow.Emails = 
allow.EmailAddresses r.X509.Allow.Ips = allow.IPRanges r.X509.Allow.Uris = allow.URIDomains r.X509.Allow.CommonNames = allow.CommonNames } if deny := x509.Deny; deny != nil { r.X509.Deny = &linkedca.X509Names{} r.X509.Deny.Dns = deny.DNSDomains r.X509.Deny.Emails = deny.EmailAddresses r.X509.Deny.Ips = deny.IPRanges r.X509.Deny.Uris = deny.URIDomains r.X509.Deny.CommonNames = deny.CommonNames } r.X509.AllowWildcardNames = x509.AllowWildcardNames } if ssh := p.SSH; ssh != nil { r.Ssh = &linkedca.SSHPolicy{} if host := ssh.Host; host != nil { r.Ssh.Host = &linkedca.SSHHostPolicy{} if allow := host.Allow; allow != nil { r.Ssh.Host.Allow = &linkedca.SSHHostNames{} r.Ssh.Host.Allow.Dns = allow.DNSDomains r.Ssh.Host.Allow.Ips = allow.IPRanges r.Ssh.Host.Allow.Principals = allow.Principals } if deny := host.Deny; deny != nil { r.Ssh.Host.Deny = &linkedca.SSHHostNames{} r.Ssh.Host.Deny.Dns = deny.DNSDomains r.Ssh.Host.Deny.Ips = deny.IPRanges r.Ssh.Host.Deny.Principals = deny.Principals } } if user := ssh.User; user != nil { r.Ssh.User = &linkedca.SSHUserPolicy{} if allow := user.Allow; allow != nil { r.Ssh.User.Allow = &linkedca.SSHUserNames{} r.Ssh.User.Allow.Emails = allow.EmailAddresses r.Ssh.User.Allow.Principals = allow.Principals } if deny := user.Deny; deny != nil { r.Ssh.User.Deny = &linkedca.SSHUserNames{} r.Ssh.User.Deny.Emails = deny.EmailAddresses r.Ssh.User.Deny.Principals = deny.Principals } } } return r } func linkedToDB(p *linkedca.Policy) *dbPolicy { if p == nil { return nil } // return early if x509 nor SSH is set if p.GetX509() == nil && p.GetSsh() == nil { return nil } r := &dbPolicy{} // fill x509 policy configuration if x509 := p.GetX509(); x509 != nil { r.X509 = &dbX509Policy{} if allow := x509.GetAllow(); allow != nil { r.X509.Allow = &dbX509Names{} if allow.Dns != nil { r.X509.Allow.DNSDomains = allow.Dns } if allow.Ips != nil { r.X509.Allow.IPRanges = allow.Ips } if allow.Emails != nil { r.X509.Allow.EmailAddresses = allow.Emails } if allow.Uris 
!= nil { r.X509.Allow.URIDomains = allow.Uris } if allow.CommonNames != nil { r.X509.Allow.CommonNames = allow.CommonNames } } if deny := x509.GetDeny(); deny != nil { r.X509.Deny = &dbX509Names{} if deny.Dns != nil { r.X509.Deny.DNSDomains = deny.Dns } if deny.Ips != nil { r.X509.Deny.IPRanges = deny.Ips } if deny.Emails != nil { r.X509.Deny.EmailAddresses = deny.Emails } if deny.Uris != nil { r.X509.Deny.URIDomains = deny.Uris } if deny.CommonNames != nil { r.X509.Deny.CommonNames = deny.CommonNames } } r.X509.AllowWildcardNames = x509.GetAllowWildcardNames() } // fill ssh policy configuration if ssh := p.GetSsh(); ssh != nil { r.SSH = &dbSSHPolicy{} if host := ssh.GetHost(); host != nil { r.SSH.Host = &dbSSHHostPolicy{} if allow := host.GetAllow(); allow != nil { r.SSH.Host.Allow = &dbSSHHostNames{} if allow.Dns != nil { r.SSH.Host.Allow.DNSDomains = allow.Dns } if allow.Ips != nil { r.SSH.Host.Allow.IPRanges = allow.Ips } if allow.Principals != nil { r.SSH.Host.Allow.Principals = allow.Principals } } if deny := host.GetDeny(); deny != nil { r.SSH.Host.Deny = &dbSSHHostNames{} if deny.Dns != nil { r.SSH.Host.Deny.DNSDomains = deny.Dns } if deny.Ips != nil { r.SSH.Host.Deny.IPRanges = deny.Ips } if deny.Principals != nil { r.SSH.Host.Deny.Principals = deny.Principals } } } if user := ssh.GetUser(); user != nil { r.SSH.User = &dbSSHUserPolicy{} if allow := user.GetAllow(); allow != nil { r.SSH.User.Allow = &dbSSHUserNames{} if allow.Emails != nil { r.SSH.User.Allow.EmailAddresses = allow.Emails } if allow.Principals != nil { r.SSH.User.Allow.Principals = allow.Principals } } if deny := user.GetDeny(); deny != nil { r.SSH.User.Deny = &dbSSHUserNames{} if deny.Emails != nil { r.SSH.User.Deny.EmailAddresses = deny.Emails } if deny.Principals != nil { r.SSH.User.Deny.Principals = deny.Principals } } } } return r } ================================================ FILE: authority/admin/db/nosql/policy_test.go ================================================ package 
nosql import ( "context" "encoding/json" "errors" "reflect" "testing" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/db" "github.com/smallstep/linkedca" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" ) func TestDB_getDBAuthorityPolicyBytes(t *testing.T) { authID := "authID" type test struct { ctx context.Context authorityID string db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, errors.New("force") }, }, err: errors.New("error loading authority policy: force"), } }, "ok": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return []byte("foo"), nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if b, err := d.getDBAuthorityPolicyBytes(tc.ctx, tc.authorityID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), 
tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, string(b), "foo") } }) } } func TestDB_getDBAuthorityPolicy(t *testing.T) { authID := "authID" type test struct { ctx context.Context authorityID string db nosql.DB err error adminErr *admin.Error dbap *dbAuthorityPolicy } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling policy bytes into dbAuthorityPolicy"), } }, "fail/authorityID-error": func(t *testing.T) test { dbp := &dbAuthorityPolicy{ ID: "ID", AuthorityID: "diffAuthID", Policy: linkedToDB(&linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, }), } b, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return b, nil }, }, adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, "authority policy is not owned by authority authID"), } }, "ok/empty-bytes": func(t *testing.T) test { return test{ 
ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return []byte{}, nil }, }, } }, "ok": func(t *testing.T) test { dbap := &dbAuthorityPolicy{ ID: "ID", AuthorityID: authID, Policy: linkedToDB(&linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, }), } b, err := json.Marshal(dbap) assert.FatalError(t, err) return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return b, nil }, }, dbap: dbap, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} dbp, err := d.getDBAuthorityPolicy(tc.ctx, tc.authorityID) switch { case err != nil: var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) && tc.dbap == nil: assert.Nil(t, dbp) case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr): assert.Equals(t, dbp.ID, "ID") assert.Equals(t, dbp.AuthorityID, tc.dbap.AuthorityID) assert.Equals(t, dbp.Policy, tc.dbap.Policy) } }) } } func TestDB_CreateAuthorityPolicy(t *testing.T) { authID := "authID" type test struct { ctx context.Context authorityID string policy *linkedca.Policy db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/save-error": func(t *testing.T) test { 
policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, policy: policy, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) var _dbap = new(dbAuthorityPolicy) assert.FatalError(t, json.Unmarshal(nu, _dbap)) assert.Equals(t, _dbap.ID, authID) assert.Equals(t, _dbap.AuthorityID, authID) assert.Equals(t, _dbap.Policy, linkedToDB(policy)) return nil, false, errors.New("force") }, }, adminErr: admin.NewErrorISE("error creating authority policy: error saving authority authority_policy: force"), } }, "ok": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, policy: policy, db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, old, nil) var _dbap = new(dbAuthorityPolicy) assert.FatalError(t, json.Unmarshal(nu, _dbap)) assert.Equals(t, _dbap.ID, authID) assert.Equals(t, _dbap.AuthorityID, authID) assert.Equals(t, _dbap.Policy, linkedToDB(policy)) return nil, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: tc.authorityID} if err := d.CreateAuthorityPolicy(tc.ctx, tc.policy); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) 
} } } }) } } func TestDB_GetAuthorityPolicy(t *testing.T) { authID := "authID" type test struct { ctx context.Context authorityID string policy *linkedca.Policy db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, errors.New("force") }, }, err: errors.New("error loading authority policy: force"), } }, "ok": func(t *testing.T) test { policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, policy: policy, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) dbap := &dbAuthorityPolicy{ ID: authID, AuthorityID: authID, Policy: linkedToDB(policy), } b, err := json.Marshal(dbap) assert.FatalError(t, err) return b, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: tc.authorityID} got, err := d.GetAuthorityPolicy(tc.ctx) if err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), 
tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } return } assert.NotNil(t, got) assert.Equals(t, tc.policy, got) }) } } func TestDB_UpdateAuthorityPolicy(t *testing.T) { authID := "authID" type test struct { ctx context.Context authorityID string policy *linkedca.Policy db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, errors.New("force") }, }, err: errors.New("error loading authority policy: force"), } }, "fail/save-error": func(t *testing.T) test { oldPolicy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.localhost"}, }, }, } policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, policy: policy, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) dbap := &dbAuthorityPolicy{ ID: authID, AuthorityID: authID, Policy: linkedToDB(oldPolicy), } b, err := json.Marshal(dbap) assert.FatalError(t, err) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { 
assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) var _dbap = new(dbAuthorityPolicy) assert.FatalError(t, json.Unmarshal(nu, _dbap)) assert.Equals(t, _dbap.ID, authID) assert.Equals(t, _dbap.AuthorityID, authID) assert.Equals(t, _dbap.Policy, linkedToDB(policy)) return nil, false, errors.New("force") }, }, adminErr: admin.NewErrorISE("error updating authority policy: error saving authority authority_policy: force"), } }, "ok": func(t *testing.T) test { oldPolicy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.localhost"}, }, }, } policy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, policy: policy, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) dbap := &dbAuthorityPolicy{ ID: authID, AuthorityID: authID, Policy: linkedToDB(oldPolicy), } b, err := json.Marshal(dbap) assert.FatalError(t, err) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) var _dbap = new(dbAuthorityPolicy) assert.FatalError(t, json.Unmarshal(nu, _dbap)) assert.Equals(t, _dbap.ID, authID) assert.Equals(t, _dbap.AuthorityID, authID) assert.Equals(t, _dbap.Policy, linkedToDB(policy)) return nil, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: tc.authorityID} if err := d.UpdateAuthorityPolicy(tc.ctx, tc.policy); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), 
tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } return } }) } } func TestDB_DeleteAuthorityPolicy(t *testing.T) { authID := "authID" type test struct { ctx context.Context authorityID string db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) return nil, errors.New("force") }, }, err: errors.New("error loading authority policy: force"), } }, "fail/save-error": func(t *testing.T) test { oldPolicy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.localhost"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) dbap := &dbAuthorityPolicy{ ID: authID, AuthorityID: authID, Policy: linkedToDB(oldPolicy), } b, err := json.Marshal(dbap) assert.FatalError(t, err) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) assert.Equals(t, nil, nu) return nil, false, errors.New("force") }, }, adminErr: admin.NewErrorISE("error deleting authority 
policy: error saving authority authority_policy: force"), } }, "ok": func(t *testing.T) test { oldPolicy := &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.localhost"}, }, }, } return test{ ctx: context.Background(), authorityID: authID, db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) dbap := &dbAuthorityPolicy{ ID: authID, AuthorityID: authID, Policy: linkedToDB(oldPolicy), } b, err := json.Marshal(dbap) assert.FatalError(t, err) return b, nil }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { assert.Equals(t, bucket, authorityPoliciesTable) assert.Equals(t, string(key), authID) assert.Equals(t, nil, nu) return nil, true, nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: tc.authorityID} if err := d.DeleteAuthorityPolicy(tc.ctx); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } return } }) } } func Test_linkedToDB(t *testing.T) { type args struct { p *linkedca.Policy } tests := []struct { name string args args want *dbPolicy }{ { name: "nil policy", args: args{ p: nil, }, want: nil, }, { name: "no x509 nor ssh", args: args{ p: &linkedca.Policy{}, }, want: nil, }, { name: "x509", args: args{ p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Emails: []string{"@example.com"}, Uris: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: 
&linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Emails: []string{"root@example.com"}, Uris: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, }, }, want: &dbPolicy{ X509: &dbX509Policy{ Allow: &dbX509Names{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, EmailAddresses: []string{"@example.com"}, URIDomains: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &dbX509Names{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, EmailAddresses: []string{"root@example.com"}, URIDomains: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, }, }, { name: "ssh user", args: args{ p: &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &linkedca.SSHUserNames{ Emails: []string{"root@example.com"}, Principals: []string{"root"}, }, }, }, }, }, want: &dbPolicy{ SSH: &dbSSHPolicy{ User: &dbSSHUserPolicy{ Allow: &dbSSHUserNames{ EmailAddresses: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &dbSSHUserNames{ EmailAddresses: []string{"root@example.com"}, Principals: []string{"root"}, }, }, }, }, }, { name: "full ssh policy", args: args{ p: &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &linkedca.SSHHostNames{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, want: &dbPolicy{ SSH: &dbSSHPolicy{ Host: &dbSSHHostPolicy{ Allow: &dbSSHHostNames{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &dbSSHHostNames{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, 
Principals: []string{"bad"}, }, }, }, }, }, { name: "full policy", args: args{ p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Emails: []string{"@example.com"}, Uris: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Emails: []string{"root@example.com"}, Uris: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &linkedca.SSHUserNames{ Emails: []string{"root@example.com"}, Principals: []string{"root"}, }, }, Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &linkedca.SSHHostNames{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, want: &dbPolicy{ X509: &dbX509Policy{ Allow: &dbX509Names{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, EmailAddresses: []string{"@example.com"}, URIDomains: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &dbX509Names{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, EmailAddresses: []string{"root@example.com"}, URIDomains: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, SSH: &dbSSHPolicy{ User: &dbSSHUserPolicy{ Allow: &dbSSHUserNames{ EmailAddresses: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &dbSSHUserNames{ EmailAddresses: []string{"root@example.com"}, Principals: []string{"root"}, }, }, Host: &dbSSHHostPolicy{ Allow: &dbSSHHostNames{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, 
Deny: &dbSSHHostNames{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := linkedToDB(tt.args.p); !reflect.DeepEqual(got, tt.want) { t.Errorf("linkedToDB() = %v, want %v", got, tt.want) } }) } } func Test_dbToLinked(t *testing.T) { type args struct { p *dbPolicy } tests := []struct { name string args args want *linkedca.Policy }{ { name: "nil policy", args: args{ p: nil, }, want: nil, }, { name: "x509", args: args{ p: &dbPolicy{ X509: &dbX509Policy{ Allow: &dbX509Names{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, EmailAddresses: []string{"@example.com"}, URIDomains: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &dbX509Names{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, EmailAddresses: []string{"root@example.com"}, URIDomains: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, }, }, want: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Emails: []string{"@example.com"}, Uris: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Emails: []string{"root@example.com"}, Uris: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, }, }, { name: "ssh user", args: args{ p: &dbPolicy{ SSH: &dbSSHPolicy{ User: &dbSSHUserPolicy{ Allow: &dbSSHUserNames{ EmailAddresses: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &dbSSHUserNames{ EmailAddresses: []string{"root@example.com"}, Principals: []string{"root"}, }, }, }, }, }, want: &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, 
Principals: []string{"user"}, }, Deny: &linkedca.SSHUserNames{ Emails: []string{"root@example.com"}, Principals: []string{"root"}, }, }, }, }, }, { name: "ssh host", args: args{ p: &dbPolicy{ SSH: &dbSSHPolicy{ Host: &dbSSHHostPolicy{ Allow: &dbSSHHostNames{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &dbSSHHostNames{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, want: &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &linkedca.SSHHostNames{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, { name: "full policy", args: args{ p: &dbPolicy{ X509: &dbX509Policy{ Allow: &dbX509Names{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, EmailAddresses: []string{"@example.com"}, URIDomains: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &dbX509Names{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, EmailAddresses: []string{"root@example.com"}, URIDomains: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, SSH: &dbSSHPolicy{ User: &dbSSHUserPolicy{ Allow: &dbSSHUserNames{ EmailAddresses: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &dbSSHUserNames{ EmailAddresses: []string{"root@example.com"}, Principals: []string{"root"}, }, }, Host: &dbSSHHostPolicy{ Allow: &dbSSHHostNames{ DNSDomains: []string{"*.local"}, IPRanges: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &dbSSHHostNames{ DNSDomains: []string{"badhost.local"}, IPRanges: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, want: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ 
Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Emails: []string{"@example.com"}, Uris: []string{"*.example.com"}, CommonNames: []string{"some name"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Emails: []string{"root@example.com"}, Uris: []string{"bad.example.com"}, CommonNames: []string{"bad name"}, }, AllowWildcardNames: true, }, Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, Principals: []string{"user"}, }, Deny: &linkedca.SSHUserNames{ Emails: []string{"root@example.com"}, Principals: []string{"root"}, }, }, Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.local"}, Ips: []string{"192.168.0.1/24"}, Principals: []string{"host"}, }, Deny: &linkedca.SSHHostNames{ Dns: []string{"badhost.local"}, Ips: []string{"192.168.0.30"}, Principals: []string{"bad"}, }, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := dbToLinked(tt.args.p); !reflect.DeepEqual(got, tt.want) { t.Errorf("dbToLinked() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/admin/db/nosql/provisioner.go ================================================ package nosql import ( "context" "encoding/json" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/linkedca" "github.com/smallstep/nosql" "google.golang.org/protobuf/types/known/timestamppb" ) // dbProvisioner is the database representation of a Provisioner type. 
type dbProvisioner struct { ID string `json:"id"` AuthorityID string `json:"authorityID"` Type linkedca.Provisioner_Type `json:"type"` Name string `json:"name"` Claims *linkedca.Claims `json:"claims"` Details []byte `json:"details"` X509Template *linkedca.Template `json:"x509Template"` SSHTemplate *linkedca.Template `json:"sshTemplate"` CreatedAt time.Time `json:"createdAt"` DeletedAt time.Time `json:"deletedAt"` Webhooks []dbWebhook `json:"webhooks,omitempty"` } type dbBasicAuth struct { Username string `json:"username"` Password string `json:"password"` } type dbWebhook struct { Name string `json:"name"` ID string `json:"id"` URL string `json:"url"` Kind string `json:"kind"` Secret string `json:"secret"` BearerToken string `json:"bearerToken,omitempty"` BasicAuth *dbBasicAuth `json:"basicAuth,omitempty"` DisableTLSClientAuth bool `json:"disableTLSClientAuth,omitempty"` CertType string `json:"certType,omitempty"` } func (dbp *dbProvisioner) clone() *dbProvisioner { u := *dbp return &u } func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) { details, err := admin.UnmarshalProvisionerDetails(dbp.Type, dbp.Details) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: dbp.ID, AuthorityId: dbp.AuthorityID, Type: dbp.Type, Name: dbp.Name, Claims: dbp.Claims, Details: details, X509Template: dbp.X509Template, SshTemplate: dbp.SSHTemplate, CreatedAt: timestamppb.New(dbp.CreatedAt), DeletedAt: timestamppb.New(dbp.DeletedAt), Webhooks: dbWebhooksToLinkedca(dbp.Webhooks), }, nil } func (db *DB) getDBProvisionerBytes(_ context.Context, id string) ([]byte, error) { data, err := db.db.Get(provisionersTable, []byte(id)) if nosql.IsErrNotFound(err) { return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id) } else if err != nil { return nil, errors.Wrapf(err, "error loading provisioner %s", id) } return data, nil } func (db *DB) unmarshalDBProvisioner(data []byte, id string) (*dbProvisioner, error) { var dbp = 
new(dbProvisioner) if err := json.Unmarshal(data, dbp); err != nil { return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", id) } if !dbp.DeletedAt.IsZero() { return nil, admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", id) } if dbp.AuthorityID != db.authorityID { return nil, admin.NewError(admin.ErrorAuthorityMismatchType, "provisioner %s is not owned by authority %s", id, db.authorityID) } return dbp, nil } func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, error) { data, err := db.getDBProvisionerBytes(ctx, id) if err != nil { return nil, err } dbp, err := db.unmarshalDBProvisioner(data, id) if err != nil { return nil, err } return dbp, nil } func (db *DB) unmarshalProvisioner(data []byte, id string) (*linkedca.Provisioner, error) { dbp, err := db.unmarshalDBProvisioner(data, id) if err != nil { return nil, err } return dbp.convert2linkedca() } // GetProvisioner retrieves and unmarshals a provisioner from the database. func (db *DB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) { data, err := db.getDBProvisionerBytes(ctx, id) if err != nil { return nil, err } prov, err := db.unmarshalProvisioner(data, id) if err != nil { return nil, err } return prov, nil } // GetProvisioners retrieves and unmarshals all active (not deleted) provisioners // from the database. 
func (db *DB) GetProvisioners(_ context.Context) ([]*linkedca.Provisioner, error) { dbEntries, err := db.db.List(provisionersTable) if err != nil { return nil, errors.Wrap(err, "error loading provisioners") } var provs []*linkedca.Provisioner for _, entry := range dbEntries { prov, err := db.unmarshalProvisioner(entry.Value, string(entry.Key)) if err != nil { var ae *admin.Error if errors.As(err, &ae) { if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) { continue } return nil, err } return nil, err } if prov.AuthorityId != db.authorityID { continue } provs = append(provs, prov) } return provs, nil } // CreateProvisioner stores a new provisioner to the database. func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { var err error prov.Id, err = randID() if err != nil { return admin.WrapErrorISE(err, "error generating random id for provisioner") } details, err := json.Marshal(prov.Details.GetData()) if err != nil { return admin.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name) } dbp := &dbProvisioner{ ID: prov.Id, AuthorityID: db.authorityID, Type: prov.Type, Name: prov.Name, Claims: prov.Claims, Details: details, X509Template: prov.X509Template, SSHTemplate: prov.SshTemplate, CreatedAt: clock.Now(), Webhooks: linkedcaWebhooksToDB(prov.Webhooks), } if err := db.save(ctx, prov.Id, dbp, nil, "provisioner", provisionersTable); err != nil { return admin.WrapErrorISE(err, "error creating provisioner %s", prov.Name) } return nil } // UpdateProvisioner saves an updated provisioner to the database. 
func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { old, err := db.getDBProvisioner(ctx, prov.Id) if err != nil { return err } nu := old.clone() if old.Type != prov.Type { return admin.NewError(admin.ErrorBadRequestType, "cannot update provisioner type") } nu.Name = prov.Name nu.Claims = prov.Claims nu.Details, err = json.Marshal(prov.Details.GetData()) if err != nil { return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name) } nu.X509Template = prov.X509Template nu.SSHTemplate = prov.SshTemplate nu.Webhooks = linkedcaWebhooksToDB(prov.Webhooks) return db.save(ctx, prov.Id, nu, old, "provisioner", provisionersTable) } // DeleteProvisioner saves an updated admin to the database. func (db *DB) DeleteProvisioner(ctx context.Context, id string) error { old, err := db.getDBProvisioner(ctx, id) if err != nil { return err } nu := old.clone() nu.DeletedAt = clock.Now() return db.save(ctx, old.ID, nu, old, "provisioner", provisionersTable) } func dbWebhooksToLinkedca(dbwhs []dbWebhook) []*linkedca.Webhook { if len(dbwhs) == 0 { return nil } lwhs := make([]*linkedca.Webhook, len(dbwhs)) for i, dbwh := range dbwhs { lwh := &linkedca.Webhook{ Name: dbwh.Name, Id: dbwh.ID, Url: dbwh.URL, Kind: linkedca.Webhook_Kind(linkedca.Webhook_Kind_value[dbwh.Kind]), Secret: dbwh.Secret, DisableTlsClientAuth: dbwh.DisableTLSClientAuth, CertType: linkedca.Webhook_CertType(linkedca.Webhook_CertType_value[dbwh.CertType]), } if dbwh.BearerToken != "" { lwh.Auth = &linkedca.Webhook_BearerToken{ BearerToken: &linkedca.BearerToken{ BearerToken: dbwh.BearerToken, }, } } else if dbwh.BasicAuth != nil && (dbwh.BasicAuth.Username != "" || dbwh.BasicAuth.Password != "") { lwh.Auth = &linkedca.Webhook_BasicAuth{ BasicAuth: &linkedca.BasicAuth{ Username: dbwh.BasicAuth.Username, Password: dbwh.BasicAuth.Password, }, } } lwhs[i] = lwh } return lwhs } func linkedcaWebhooksToDB(lwhs []*linkedca.Webhook) []dbWebhook { if len(lwhs) 
== 0 { return nil } dbwhs := make([]dbWebhook, len(lwhs)) for i, lwh := range lwhs { dbwh := dbWebhook{ Name: lwh.Name, ID: lwh.Id, URL: lwh.Url, Kind: lwh.Kind.String(), Secret: lwh.Secret, DisableTLSClientAuth: lwh.DisableTlsClientAuth, CertType: lwh.CertType.String(), } switch a := lwh.GetAuth().(type) { case *linkedca.Webhook_BearerToken: dbwh.BearerToken = a.BearerToken.BearerToken case *linkedca.Webhook_BasicAuth: dbwh.BasicAuth = &dbBasicAuth{ Username: a.BasicAuth.Username, Password: a.BasicAuth.Password, } } dbwhs[i] = dbwh } return dbwhs } ================================================ FILE: authority/admin/db/nosql/provisioner_test.go ================================================ package nosql import ( "context" "encoding/json" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/db" "github.com/smallstep/linkedca" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" ) func TestDB_getDBProvisionerBytes(t *testing.T) { provID := "provID" type test struct { db nosql.DB err error adminErr *admin.Error } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return nil, errors.New("force") }, }, err: errors.New("error loading provisioner provID: force"), } }, "ok": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, 
provisionersTable) assert.Equals(t, string(key), provID) return []byte("foo"), nil }, }, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db} if b, err := d.getDBProvisionerBytes(context.Background(), provID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, string(b), "foo") } }) } } func TestDB_getDBProvisioner(t *testing.T) { provID := "provID" type test struct { db nosql.DB err error adminErr *admin.Error dbp *dbProvisioner } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return nil, errors.New("force") }, }, err: errors.New("error loading provisioner provID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), } }, "fail/deleted": func(t 
*testing.T) test { now := clock.Now() dbp := &dbProvisioner{ ID: provID, AuthorityID: admin.DefaultAuthorityID, Type: linkedca.Provisioner_JWK, Name: "provName", CreatedAt: now, DeletedAt: now, } b, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return b, nil }, }, adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner provID is deleted"), } }, "ok": func(t *testing.T) test { now := clock.Now() dbp := &dbProvisioner{ ID: provID, AuthorityID: admin.DefaultAuthorityID, Type: linkedca.Provisioner_JWK, Name: "provName", CreatedAt: now, } b, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return b, nil }, }, dbp: dbp, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} if dbp, err := d.getDBProvisioner(context.Background(), provID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dbp.ID, provID) assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) assert.Equals(t, dbp.Type, tc.dbp.Type) assert.Equals(t, dbp.Name, tc.dbp.Name) assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) assert.Fatal(t, dbp.DeletedAt.IsZero()) assert.Equals(t, dbp.Webhooks, tc.dbp.Webhooks) } }) } } func 
TestDB_unmarshalDBProvisioner(t *testing.T) { provID := "provID" type test struct { in []byte err error adminErr *admin.Error dbp *dbProvisioner } var tests = map[string]func(t *testing.T) test{ "fail/unmarshal-error": func(t *testing.T) test { return test{ in: []byte("foo"), err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), } }, "fail/deleted-error": func(t *testing.T) test { dbp := &dbProvisioner{ DeletedAt: clock.Now(), } data, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ in: data, adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", provID), } }, "fail/authority-mismatch-error": func(t *testing.T) test { dbp := &dbProvisioner{ ID: provID, AuthorityID: "foo", } data, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ in: data, adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, "provisioner %s is not owned by authority %s", provID, admin.DefaultAuthorityID), } }, "ok": func(t *testing.T) test { dbp := &dbProvisioner{ ID: provID, AuthorityID: admin.DefaultAuthorityID, Type: linkedca.Provisioner_JWK, Name: "provName", CreatedAt: clock.Now(), } data, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ in: data, dbp: dbp, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{authorityID: admin.DefaultAuthorityID} if dbp, err := d.unmarshalDBProvisioner(tc.in, provID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dbp.ID, provID) assert.Equals(t, dbp.AuthorityID, 
tc.dbp.AuthorityID) assert.Equals(t, dbp.Type, tc.dbp.Type) assert.Equals(t, dbp.Name, tc.dbp.Name) assert.Equals(t, dbp.Details, tc.dbp.Details) assert.Equals(t, dbp.Claims, tc.dbp.Claims) assert.Equals(t, dbp.X509Template, tc.dbp.X509Template) assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate) assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) assert.Fatal(t, dbp.DeletedAt.IsZero()) assert.Equals(t, dbp.Webhooks, tc.dbp.Webhooks) } }) } } func defaultDBP(t *testing.T) *dbProvisioner { details := &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{ ForceCn: true, }, } detailBytes, err := json.Marshal(details) assert.FatalError(t, err) return &dbProvisioner{ ID: "provID", AuthorityID: admin.DefaultAuthorityID, Type: linkedca.Provisioner_ACME, Name: "provName", Details: detailBytes, Claims: &linkedca.Claims{ DisableRenewal: true, X509: &linkedca.X509Claims{ Enabled: true, Durations: &linkedca.Durations{ Min: "5m", Max: "12h", Default: "6h", }, }, Ssh: &linkedca.SSHClaims{ Enabled: true, UserDurations: &linkedca.Durations{ Min: "5m", Max: "12h", Default: "6h", }, HostDurations: &linkedca.Durations{ Min: "5m", Max: "12h", Default: "6h", }, }, }, X509Template: &linkedca.Template{ Template: []byte("foo"), Data: []byte("bar"), }, SSHTemplate: &linkedca.Template{ Template: []byte("baz"), Data: []byte("zap"), }, CreatedAt: clock.Now(), Webhooks: []dbWebhook{ { Name: "metadata", URL: "https://inventory.smallstep.com", Kind: linkedca.Webhook_ENRICHING.String(), Secret: "secret", BearerToken: "token", }, }, } } func TestDB_unmarshalProvisioner(t *testing.T) { provID := "provID" type test struct { in []byte err error adminErr *admin.Error dbp *dbProvisioner } var tests = map[string]func(t *testing.T) test{ "fail/unmarshal-error": func(t *testing.T) test { return test{ in: []byte("foo"), err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), } }, "fail/deleted-error": func(t *testing.T) test { dbp := &dbProvisioner{ DeletedAt: time.Now(), 
} data, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ in: data, adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner provID is deleted"), } }, "ok": func(t *testing.T) test { dbp := defaultDBP(t) data, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ in: data, dbp: dbp, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { d := DB{authorityID: admin.DefaultAuthorityID} if prov, err := d.unmarshalProvisioner(tc.in, provID); err != nil { var ae *admin.Error if errors.As(err, &ae) { if assert.NotNil(t, tc.adminErr) { assert.Equals(t, ae.Type, tc.adminErr.Type) assert.Equals(t, ae.Detail, tc.adminErr.Detail) assert.Equals(t, ae.Status, tc.adminErr.Status) assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error()) assert.Equals(t, ae.Detail, tc.adminErr.Detail) } } else { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, prov.Id, provID) assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) assert.Equals(t, prov.Type, tc.dbp.Type) assert.Equals(t, prov.Name, tc.dbp.Name) assert.Equals(t, prov.Claims, tc.dbp.Claims) assert.Equals(t, prov.X509Template, tc.dbp.X509Template) assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate) assert.Equals(t, prov.Webhooks, dbWebhooksToLinkedca(tc.dbp.Webhooks)) retDetailsBytes, err := json.Marshal(prov.Details.GetData()) assert.FatalError(t, err) assert.Equals(t, retDetailsBytes, tc.dbp.Details) } }) } } func TestDB_GetProvisioner(t *testing.T) { provID := "provID" type test struct { db nosql.DB err error adminErr *admin.Error dbp *dbProvisioner } var tests = map[string]func(t *testing.T) test{ "fail/not-found": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return nil, nosqldb.ErrNotFound }, }, adminErr: 
admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), } }, "fail/db.Get-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return nil, errors.New("force") }, }, err: errors.New("error loading provisioner provID: force"), } }, "fail/unmarshal-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return []byte("foo"), nil }, }, err: errors.New("error unmarshaling provisioner provID into dbProvisioner"), } }, "fail/deleted": func(t *testing.T) test { dbp := defaultDBP(t) dbp.DeletedAt = clock.Now() b, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return b, nil }, }, dbp: dbp, adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner provID is deleted"), } }, "fail/authorityID-mismatch": func(t *testing.T) test { dbp := defaultDBP(t) dbp.AuthorityID = "foo" b, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return b, nil }, }, dbp: dbp, adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, "provisioner %s is not owned by authority %s", dbp.ID, admin.DefaultAuthorityID), } }, "ok": func(t *testing.T) test { dbp := defaultDBP(t) b, err := json.Marshal(dbp) assert.FatalError(t, err) return test{ db: &db.MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) return b, nil }, }, dbp: dbp, } }, } for name, run := range tests { tc := run(t) 
t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
			if prov, err := d.GetProvisioner(context.Background(), provID); err != nil {
				var ae *admin.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.adminErr) {
						assert.Equals(t, ae.Type, tc.adminErr.Type)
						assert.Equals(t, ae.Detail, tc.adminErr.Detail)
						assert.Equals(t, ae.Status, tc.adminErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
						assert.Equals(t, ae.Detail, tc.adminErr.Detail)
					}
				} else {
					if assert.NotNil(t, tc.err) {
						assert.HasPrefix(t, err.Error(), tc.err.Error())
					}
				}
			} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
				assert.Equals(t, prov.Id, provID)
				assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
				assert.Equals(t, prov.Type, tc.dbp.Type)
				assert.Equals(t, prov.Name, tc.dbp.Name)
				assert.Equals(t, prov.Claims, tc.dbp.Claims)
				assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
				assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
				assert.Equals(t, prov.Webhooks, dbWebhooksToLinkedca(tc.dbp.Webhooks))

				retDetailsBytes, err := json.Marshal(prov.Details.GetData())
				assert.FatalError(t, err)
				assert.Equals(t, retDetailsBytes, tc.dbp.Details)
			}
		})
	}
}

// TestDB_DeleteProvisioner checks that DeleteProvisioner loads the stored
// provisioner, writes it back with DeletedAt set, and surfaces load and save
// failures. Cases that expect a failure now fail if the call returns nil.
func TestDB_DeleteProvisioner(t *testing.T) {
	provID := "provID"
	type test struct {
		db       nosql.DB
		err      error
		adminErr *admin.Error
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/not-found": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return nil, nosqldb.ErrNotFound
					},
				},
				adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
			}
		},
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading provisioner provID: force"),
			}
		},
		"fail/save-error": func(t *testing.T) test {
			dbp := defaultDBP(t)
			data, err := json.Marshal(dbp)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return data, nil
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)
						assert.Equals(t, string(old), string(data))

						var _dbp = new(dbProvisioner)
						assert.FatalError(t, json.Unmarshal(nu, _dbp))

						assert.Equals(t, _dbp.ID, provID)
						assert.Equals(t, _dbp.AuthorityID, dbp.AuthorityID)
						assert.Equals(t, _dbp.Type, dbp.Type)
						assert.Equals(t, _dbp.Name, dbp.Name)
						assert.Equals(t, _dbp.Claims, dbp.Claims)
						assert.Equals(t, _dbp.X509Template, dbp.X509Template)
						assert.Equals(t, _dbp.SSHTemplate, dbp.SSHTemplate)
						assert.Equals(t, _dbp.CreatedAt, dbp.CreatedAt)
						assert.Equals(t, _dbp.Details, dbp.Details)
						assert.Equals(t, _dbp.Webhooks, dbp.Webhooks)

						assert.True(t, _dbp.DeletedAt.Before(time.Now()))
						assert.True(t, _dbp.DeletedAt.After(time.Now().Add(-time.Minute)))

						return nil, false, errors.New("force")
					},
				},
				err: errors.New("error saving authority provisioner: force"),
			}
		},
		"ok": func(t *testing.T) test {
			dbp := defaultDBP(t)
			data, err := json.Marshal(dbp)
			assert.FatalError(t, err)
			return test{
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return data, nil
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)
						assert.Equals(t, string(old), string(data))

						var _dbp = new(dbProvisioner)
						assert.FatalError(t, json.Unmarshal(nu, _dbp))

						assert.Equals(t, _dbp.ID, provID)
						assert.Equals(t, _dbp.AuthorityID, dbp.AuthorityID)
						assert.Equals(t, _dbp.Type, dbp.Type)
						assert.Equals(t, _dbp.Name, dbp.Name)
						assert.Equals(t, _dbp.Claims, dbp.Claims)
						assert.Equals(t, _dbp.X509Template, dbp.X509Template)
						assert.Equals(t, _dbp.SSHTemplate, dbp.SSHTemplate)
						assert.Equals(t, _dbp.CreatedAt, dbp.CreatedAt)
						assert.Equals(t, _dbp.Details, dbp.Details)
						assert.Equals(t, _dbp.Webhooks, dbp.Webhooks)

						assert.True(t, _dbp.DeletedAt.Before(time.Now()))
						assert.True(t, _dbp.DeletedAt.After(time.Now().Add(-time.Minute)))

						return nu, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
			err := d.DeleteProvisioner(context.Background(), provID)
			if err == nil {
				// A nil result is only acceptable when no failure was expected.
				assert.Nil(t, tc.err)
				assert.Nil(t, tc.adminErr)
				return
			}

			var ae *admin.Error
			if errors.As(err, &ae) {
				if assert.NotNil(t, tc.adminErr) {
					assert.Equals(t, ae.Type, tc.adminErr.Type)
					assert.Equals(t, ae.Detail, tc.adminErr.Detail)
					assert.Equals(t, ae.Status, tc.adminErr.Status)
					assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
				}
			} else if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}

// TestDB_GetProvisioners checks that GetProvisioners returns only the
// provisioners that are neither deleted nor owned by another authority, and
// that list and unmarshal failures are reported.
func TestDB_GetProvisioners(t *testing.T) {
	fooProv := defaultDBP(t)
	fooProv.Name = "foo"
	foob, err := json.Marshal(fooProv)
	assert.FatalError(t, err)

	// bar is marked as deleted.
	barProv := defaultDBP(t)
	barProv.Name = "bar"
	barProv.DeletedAt = clock.Now()
	barb, err := json.Marshal(barProv)
	assert.FatalError(t, err)

	// baz belongs to a different authority.
	bazProv := defaultDBP(t)
	bazProv.Name = "baz"
	bazProv.AuthorityID = "baz"
	bazb, err := json.Marshal(bazProv)
	assert.FatalError(t, err)

	zapProv := defaultDBP(t)
	zapProv.Name = "zap"
	zapb, err := json.Marshal(zapProv)
	assert.FatalError(t, err)

	type test struct {
		db       nosql.DB
		err      error
		adminErr *admin.Error
		verify   func(*testing.T, []*linkedca.Provisioner)
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/db.List-error": func(t *testing.T) test {
			return test{
				db: &db.MockNoSQLDB{
					MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
						assert.Equals(t, bucket, provisionersTable)

						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading provisioners"),
			}
		},
		"fail/unmarshal-error": func(t *testing.T) test {
			ret := []*nosqldb.Entry{
				{Bucket: provisionersTable, Key: []byte("foo"), Value: foob},
				{Bucket: provisionersTable, Key: []byte("bar"), Value: barb},
				{Bucket: provisionersTable, Key: []byte("zap"), Value: []byte("zap")},
			}
			return test{
				db: &db.MockNoSQLDB{
					MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
						assert.Equals(t, bucket, provisionersTable)

						return ret, nil
					},
				},
				err: errors.New("error unmarshaling provisioner zap into dbProvisioner"),
			}
		},
		"ok/none": func(t *testing.T) test {
			ret := []*nosqldb.Entry{}
			return test{
				db: &db.MockNoSQLDB{
					MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
						assert.Equals(t, bucket, provisionersTable)

						return ret, nil
					},
				},
				verify: func(t *testing.T, provs []*linkedca.Provisioner) {
					assert.Equals(t, len(provs), 0)
				},
			}
		},
		"ok/only-invalid": func(t *testing.T) test {
			ret := []*nosqldb.Entry{
				{Bucket: provisionersTable, Key: []byte("bar"), Value: barb},
				{Bucket: provisionersTable, Key: []byte("baz"), Value: bazb},
			}
			return test{
				db: &db.MockNoSQLDB{
					MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
						assert.Equals(t, bucket, provisionersTable)

						return ret, nil
					},
				},
				verify: func(t *testing.T, provs []*linkedca.Provisioner) {
					assert.Equals(t, len(provs), 0)
				},
			}
		},
		"ok": func(t *testing.T) test {
			ret := []*nosqldb.Entry{
				{Bucket: provisionersTable, Key: []byte("foo"), Value: foob},
				{Bucket: provisionersTable, Key: []byte("bar"), Value: barb},
				{Bucket: provisionersTable, Key: []byte("baz"), Value: bazb},
				{Bucket: provisionersTable, Key: []byte("zap"), Value: zapb},
			}
			return test{
				db: &db.MockNoSQLDB{
					MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
						assert.Equals(t, bucket, provisionersTable)

						return ret, nil
					},
				},
				verify: func(t *testing.T, provs []*linkedca.Provisioner) {
					assert.Equals(t, len(provs), 2)

					assert.Equals(t, provs[0].Id, fooProv.ID)
					assert.Equals(t, provs[0].AuthorityId, fooProv.AuthorityID)
					assert.Equals(t, provs[0].Type, fooProv.Type)
					assert.Equals(t, provs[0].Name, fooProv.Name)
					assert.Equals(t, provs[0].Claims, fooProv.Claims)
					assert.Equals(t, provs[0].X509Template, fooProv.X509Template)
					assert.Equals(t, provs[0].SshTemplate, fooProv.SSHTemplate)
					assert.Equals(t, provs[0].Webhooks, dbWebhooksToLinkedca(fooProv.Webhooks))

					retDetailsBytes, err := json.Marshal(provs[0].Details.GetData())
					assert.FatalError(t, err)
					assert.Equals(t, retDetailsBytes, fooProv.Details)

					assert.Equals(t, provs[1].Id, zapProv.ID)
					assert.Equals(t, provs[1].AuthorityId, zapProv.AuthorityID)
					assert.Equals(t, provs[1].Type, zapProv.Type)
					assert.Equals(t, provs[1].Name, zapProv.Name)
					assert.Equals(t, provs[1].Claims, zapProv.Claims)
					assert.Equals(t, provs[1].X509Template, zapProv.X509Template)
					assert.Equals(t, provs[1].SshTemplate, zapProv.SSHTemplate)
					assert.Equals(t, provs[1].Webhooks, dbWebhooksToLinkedca(zapProv.Webhooks))

					retDetailsBytes, err = json.Marshal(provs[1].Details.GetData())
					assert.FatalError(t, err)
					assert.Equals(t, retDetailsBytes, zapProv.Details)
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
			if provs, err := d.GetProvisioners(context.Background()); err != nil {
				var ae *admin.Error
				if errors.As(err, &ae) {
					if assert.NotNil(t, tc.adminErr) {
						assert.Equals(t, ae.Type, tc.adminErr.Type)
						assert.Equals(t, ae.Detail, tc.adminErr.Detail)
						assert.Equals(t, ae.Status, tc.adminErr.Status)
						assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
					}
				} else if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
				tc.verify(t, provs)
			}
		})
	}
}

// TestDB_CreateProvisioner checks that CreateProvisioner persists a new
// dbProvisioner built from the linkedca provisioner and wraps save failures.
// Cases that expect a failure now fail if the call returns nil.
func TestDB_CreateProvisioner(t *testing.T) {
	type test struct {
		db       nosql.DB
		err      error
		adminErr *admin.Error
		prov     *linkedca.Provisioner
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/save-error": func(t *testing.T) test {
			dbp := defaultDBP(t)
			prov, err := dbp.convert2linkedca()
			assert.FatalError(t, err)
			return test{
				prov: prov,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, old, nil)

						var _dbp = new(dbProvisioner)
						assert.FatalError(t, json.Unmarshal(nu, _dbp))

						assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
						assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
						assert.Equals(t, _dbp.Type, prov.Type)
						assert.Equals(t, _dbp.Name, prov.Name)
						assert.Equals(t, _dbp.Claims, prov.Claims)
						assert.Equals(t, _dbp.X509Template, prov.X509Template)
						assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate)
						assert.Equals(t, _dbp.Webhooks, linkedcaWebhooksToDB(prov.Webhooks))

						retDetailsBytes, err := json.Marshal(prov.Details.GetData())
						assert.FatalError(t, err)
						assert.Equals(t, retDetailsBytes, _dbp.Details)

						assert.True(t, _dbp.DeletedAt.IsZero())
						assert.True(t, _dbp.CreatedAt.Before(time.Now()))
						assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute)))

						return nil, false, errors.New("force")
					},
				},
				adminErr: admin.NewErrorISE("error creating provisioner provName: error saving authority provisioner: force"),
			}
		},
		"ok": func(t *testing.T) test {
			dbp := defaultDBP(t)
			prov, err := dbp.convert2linkedca()
			assert.FatalError(t, err)
			return test{
				prov: prov,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, old, nil)

						var _dbp = new(dbProvisioner)
						assert.FatalError(t, json.Unmarshal(nu, _dbp))

						assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
						assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
						assert.Equals(t, _dbp.Type, prov.Type)
						assert.Equals(t, _dbp.Name, prov.Name)
						assert.Equals(t, _dbp.Claims, prov.Claims)
						assert.Equals(t, _dbp.X509Template, prov.X509Template)
						assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate)
						assert.Equals(t, _dbp.Webhooks, linkedcaWebhooksToDB(prov.Webhooks))

						retDetailsBytes, err := json.Marshal(prov.Details.GetData())
						assert.FatalError(t, err)
						assert.Equals(t, retDetailsBytes, _dbp.Details)

						assert.True(t, _dbp.DeletedAt.IsZero())
						assert.True(t, _dbp.CreatedAt.Before(time.Now()))
						assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute)))

						return nu, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
			err := d.CreateProvisioner(context.Background(), tc.prov)
			if err == nil {
				// A nil result is only acceptable when no failure was expected.
				assert.Nil(t, tc.err)
				assert.Nil(t, tc.adminErr)
				return
			}

			var ae *admin.Error
			if errors.As(err, &ae) {
				if assert.NotNil(t, tc.adminErr) {
					assert.Equals(t, ae.Type, tc.adminErr.Type)
					assert.Equals(t, ae.Detail, tc.adminErr.Detail)
					assert.Equals(t, ae.Status, tc.adminErr.Status)
					assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
				}
			} else if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}

// TestDB_UpdateProvisioner checks that UpdateProvisioner rejects missing,
// deleted and type-changing updates, reports save failures, and persists the
// updated fields. Cases that expect a failure now fail if the call returns nil.
func TestDB_UpdateProvisioner(t *testing.T) {
	provID := "provID"
	type test struct {
		db       nosql.DB
		err      error
		adminErr *admin.Error
		prov     *linkedca.Provisioner
	}
	var tests = map[string]func(t *testing.T) test{
		"fail/not-found": func(t *testing.T) test {
			return test{
				prov: &linkedca.Provisioner{Id: provID},
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return nil, nosqldb.ErrNotFound
					},
				},
				adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
			}
		},
		"fail/db.Get-error": func(t *testing.T) test {
			return test{
				prov: &linkedca.Provisioner{Id: provID},
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return nil, errors.New("force")
					},
				},
				err: errors.New("error loading provisioner provID: force"),
			}
		},
		"fail/update-deleted": func(t *testing.T) test {
			dbp := defaultDBP(t)
			dbp.DeletedAt = clock.Now()
			data, err := json.Marshal(dbp)
			assert.FatalError(t, err)
			return test{
				prov: &linkedca.Provisioner{Id: provID},
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return data, nil
					},
				},
				adminErr: admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", provID),
			}
		},
		"fail/update-type-error": func(t *testing.T) test {
			dbp := defaultDBP(t)

			upd, err := dbp.convert2linkedca()
			assert.FatalError(t, err)
			upd.Type = linkedca.Provisioner_JWK

			data, err := json.Marshal(dbp)
			assert.FatalError(t, err)
			return test{
				prov: upd,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return data, nil
					},
				},
				adminErr: admin.NewError(admin.ErrorBadRequestType, "cannot update provisioner type"),
			}
		},
		"fail/save-error": func(t *testing.T) test {
			dbp := defaultDBP(t)

			prov, err := dbp.convert2linkedca()
			assert.FatalError(t, err)

			data, err := json.Marshal(dbp)
			assert.FatalError(t, err)
			return test{
				prov: prov,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return data, nil
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)
						assert.Equals(t, string(old), string(data))

						var _dbp = new(dbProvisioner)
						assert.FatalError(t, json.Unmarshal(nu, _dbp))

						assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
						assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
						assert.Equals(t, _dbp.Type, prov.Type)
						assert.Equals(t, _dbp.Name, prov.Name)
						assert.Equals(t, _dbp.Claims, prov.Claims)
						assert.Equals(t, _dbp.X509Template, prov.X509Template)
						assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate)
						assert.Equals(t, _dbp.Webhooks, linkedcaWebhooksToDB(prov.Webhooks))

						retDetailsBytes, err := json.Marshal(prov.Details.GetData())
						assert.FatalError(t, err)
						assert.Equals(t, retDetailsBytes, _dbp.Details)

						assert.True(t, _dbp.DeletedAt.IsZero())
						assert.True(t, _dbp.CreatedAt.Before(time.Now()))
						assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute)))

						return nil, false, errors.New("force")
					},
				},
				err: errors.New("error saving authority provisioner: force"),
			}
		},
		"ok": func(t *testing.T) test {
			dbp := defaultDBP(t)

			prov, err := dbp.convert2linkedca()
			assert.FatalError(t, err)

			prov.Name = "new-name"
			prov.Claims = &linkedca.Claims{
				DisableRenewal: true,
				X509: &linkedca.X509Claims{
					Enabled: true,
					Durations: &linkedca.Durations{
						Min:     "10m",
						Max:     "8h",
						Default: "4h",
					},
				},
				Ssh: &linkedca.SSHClaims{
					Enabled: true,
					UserDurations: &linkedca.Durations{
						Min:     "7m",
						Max:     "11h",
						Default: "5h",
					},
					HostDurations: &linkedca.Durations{
						Min:     "4m",
						Max:     "24h",
						Default: "24h",
					},
				},
			}
			prov.X509Template = &linkedca.Template{
				Template: []byte("x"),
				Data:     []byte("y"),
			}
			prov.SshTemplate = &linkedca.Template{
				Template: []byte("z"),
				Data:     []byte("w"),
			}
			prov.Details = &linkedca.ProvisionerDetails{
				Data: &linkedca.ProvisionerDetails_ACME{
					ACME: &linkedca.ACMEProvisioner{
						ForceCn: false,
					},
				},
			}
			prov.Webhooks = []*linkedca.Webhook{
				{
					Name: "users",
					Url:  "https://example.com/users",
				},
			}

			data, err := json.Marshal(dbp)
			assert.FatalError(t, err)
			return test{
				prov: prov,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)

						return data, nil
					},
					MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
						assert.Equals(t, bucket, provisionersTable)
						assert.Equals(t, string(key), provID)
						assert.Equals(t, string(old), string(data))

						var _dbp = new(dbProvisioner)
						assert.FatalError(t, json.Unmarshal(nu, _dbp))

						assert.True(t, _dbp.ID != "" && _dbp.ID == string(key))
						assert.Equals(t, _dbp.AuthorityID, prov.AuthorityId)
						assert.Equals(t, _dbp.Type, prov.Type)
						assert.Equals(t, _dbp.Name, prov.Name)
						assert.Equals(t, _dbp.Claims, prov.Claims)
						assert.Equals(t, _dbp.X509Template, prov.X509Template)
						assert.Equals(t, _dbp.SSHTemplate, prov.SshTemplate)
						assert.Equals(t, _dbp.Webhooks, linkedcaWebhooksToDB(prov.Webhooks))

						retDetailsBytes, err := json.Marshal(prov.Details.GetData())
						assert.FatalError(t, err)
						assert.Equals(t, retDetailsBytes, _dbp.Details)

						assert.True(t, _dbp.DeletedAt.IsZero())
						assert.True(t, _dbp.CreatedAt.Before(time.Now()))
						assert.True(t, _dbp.CreatedAt.After(time.Now().Add(-time.Minute)))

						return nu, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		tc := run(t)
		t.Run(name, func(t *testing.T) {
			d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
			err := d.UpdateProvisioner(context.Background(), tc.prov)
			if err == nil {
				// A nil result is only acceptable when no failure was expected.
				assert.Nil(t, tc.err)
				assert.Nil(t, tc.adminErr)
				return
			}

			var ae *admin.Error
			if errors.As(err, &ae) {
				if assert.NotNil(t, tc.adminErr) {
					assert.Equals(t, ae.Type, tc.adminErr.Type)
					assert.Equals(t, ae.Detail, tc.adminErr.Detail)
					assert.Equals(t, ae.Status, tc.adminErr.Status)
					assert.Equals(t, ae.Err.Error(), tc.adminErr.Err.Error())
				}
			} else if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}

// Test_linkedcaWebhooksToDB checks the conversion of linkedca webhooks into
// their database representation for nil, empty, bearer-token and basic-auth
// inputs.
func Test_linkedcaWebhooksToDB(t *testing.T) {
	type test struct {
		in   []*linkedca.Webhook
		want []dbWebhook
	}
	var tests = map[string]test{
		"nil": {
			in:   nil,
			want: nil,
		},
		"zero": {
			in:   []*linkedca.Webhook{},
			want: nil,
		},
		"bearer": {
			in: []*linkedca.Webhook{
				{
					Name:   "bearer",
					Url:    "https://example.com",
					Kind:   linkedca.Webhook_ENRICHING,
					Secret: "secret",
					Auth: &linkedca.Webhook_BearerToken{
						BearerToken: &linkedca.BearerToken{
							BearerToken: "token",
						},
					},
					DisableTlsClientAuth: true,
					CertType:             linkedca.Webhook_X509,
				},
			},
			want: []dbWebhook{
				{
					Name:                 "bearer",
					URL:                  "https://example.com",
					Kind:                 "ENRICHING",
					Secret:               "secret",
					BearerToken:          "token",
					DisableTLSClientAuth: true,
					CertType:             linkedca.Webhook_X509.String(),
				},
			},
		},
		"basic": {
			in: []*linkedca.Webhook{
				{
					Name:   "basic",
					Url:    "https://example.com",
					Kind:   linkedca.Webhook_ENRICHING,
					Secret: "secret",
					Auth: &linkedca.Webhook_BasicAuth{
						BasicAuth: &linkedca.BasicAuth{
							Username: "user",
							Password: "pass",
						},
					},
				},
			},
			want: []dbWebhook{
				{
					Name:   "basic",
					URL:    "https://example.com",
					Kind:   "ENRICHING",
					Secret: "secret",
					BasicAuth: &dbBasicAuth{
						Username: "user",
						Password: "pass",
					},
					CertType: linkedca.Webhook_ALL.String(),
				},
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			got := linkedcaWebhooksToDB(tc.in)
			assert.Equals(t, tc.want, got)
		})
	}
}

// Test_dbWebhooksToLinkedca checks the conversion of stored webhooks back
// into linkedca webhooks for nil, empty, bearer-token and basic-auth inputs.
func Test_dbWebhooksToLinkedca(t *testing.T) {
	type test struct {
		in   []dbWebhook
		want []*linkedca.Webhook
	}
	var tests = map[string]test{
		"nil": {
			in:   nil,
			want: nil,
		},
		"zero": {
			in:   []dbWebhook{},
			want: nil,
		},
		"bearer": {
			in: []dbWebhook{
				{
					Name:                 "bearer",
					ID:                   "69350cb6-6c31-4b5e-bf25-affd5053427d",
					URL:                  "https://example.com",
					Kind:                 "ENRICHING",
					Secret:               "secret",
					BearerToken:          "token",
					DisableTLSClientAuth: true,
				},
			},
			want: []*linkedca.Webhook{
				{
					Name:   "bearer",
					Id:     "69350cb6-6c31-4b5e-bf25-affd5053427d",
					Url:    "https://example.com",
					Kind:   linkedca.Webhook_ENRICHING,
					Secret: "secret",
					Auth: &linkedca.Webhook_BearerToken{
						BearerToken: &linkedca.BearerToken{
							BearerToken: "token",
						},
					},
					DisableTlsClientAuth: true,
				},
			},
		},
		"basic": {
			in: []dbWebhook{
				{
					Name:   "basic",
					ID:     "69350cb6-6c31-4b5e-bf25-affd5053427d",
					URL:    "https://example.com",
					Kind:   "ENRICHING",
					Secret: "secret",
					BasicAuth: &dbBasicAuth{
						Username: "user",
						Password: "pass",
					},
				},
			},
			want: []*linkedca.Webhook{
				{
					Name:   "basic",
					Id:     "69350cb6-6c31-4b5e-bf25-affd5053427d",
					Url:    "https://example.com",
					Kind:   linkedca.Webhook_ENRICHING,
					Secret: "secret",
					Auth: &linkedca.Webhook_BasicAuth{
						BasicAuth: &linkedca.BasicAuth{
							Username: "user",
							Password: "pass",
						},
					},
				},
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			got := dbWebhooksToLinkedca(tc.in)
			assert.Equals(t, tc.want, got)
		})
	}
}
================================================ FILE: authority/admin/db.go ================================================ package admin import ( "context" "encoding/json" "fmt" "github.com/pkg/errors" "github.com/smallstep/linkedca" ) const ( // DefaultAuthorityID is the default AuthorityID. This will be the ID // of the first Authority created, as well as the default AuthorityID // if one is not specified in the configuration. DefaultAuthorityID = "00000000-0000-0000-0000-000000000000" ) // ErrNotFound is an error that should be used by the authority.DB interface to // indicate that an entity does not exist. var ErrNotFound = errors.New("not found") // UnmarshalProvisionerDetails unmarshals details type to the specific provisioner details. func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*linkedca.ProvisionerDetails, error) { var v linkedca.ProvisionerDetails switch typ { case linkedca.Provisioner_JWK: v.Data = new(linkedca.ProvisionerDetails_JWK) case linkedca.Provisioner_OIDC: v.Data = new(linkedca.ProvisionerDetails_OIDC) case linkedca.Provisioner_GCP: v.Data = new(linkedca.ProvisionerDetails_GCP) case linkedca.Provisioner_AWS: v.Data = new(linkedca.ProvisionerDetails_AWS) case linkedca.Provisioner_AZURE: v.Data = new(linkedca.ProvisionerDetails_Azure) case linkedca.Provisioner_ACME: v.Data = new(linkedca.ProvisionerDetails_ACME) case linkedca.Provisioner_X5C: v.Data = new(linkedca.ProvisionerDetails_X5C) case linkedca.Provisioner_K8SSA: v.Data = new(linkedca.ProvisionerDetails_K8SSA) case linkedca.Provisioner_SSHPOP: v.Data = new(linkedca.ProvisionerDetails_SSHPOP) case linkedca.Provisioner_SCEP: v.Data = new(linkedca.ProvisionerDetails_SCEP) case linkedca.Provisioner_NEBULA: v.Data = new(linkedca.ProvisionerDetails_Nebula) default: return nil, fmt.Errorf("unsupported provisioner type %s", typ) } if err := json.Unmarshal(data, v.Data); err != nil { return nil, err } return &linkedca.ProvisionerDetails{Data: v.Data}, nil } // DB 
is the DB interface expected by the step-ca Admin API. type DB interface { CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error DeleteProvisioner(ctx context.Context, id string) error CreateAdmin(ctx context.Context, admin *linkedca.Admin) error GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error DeleteAdmin(ctx context.Context, id string) error CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error DeleteAuthorityPolicy(ctx context.Context) error } type dbKey struct{} // NewContext adds the given admin database to the context. func NewContext(ctx context.Context, db DB) context.Context { return context.WithValue(ctx, dbKey{}, db) } // FromContext returns the current admin database from the given context. func FromContext(ctx context.Context) (db DB, ok bool) { db, ok = ctx.Value(dbKey{}).(DB) return } // MustFromContext returns the current admin database from the given context. It // will panic if it's not in the context. func MustFromContext(ctx context.Context) DB { var ( db DB ok bool ) if db, ok = FromContext(ctx); !ok { panic("admin database is not in the context") } return db } // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. 
type MockDB struct { MockCreateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error MockGetProvisioner func(ctx context.Context, id string) (*linkedca.Provisioner, error) MockGetProvisioners func(ctx context.Context) ([]*linkedca.Provisioner, error) MockUpdateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error MockDeleteProvisioner func(ctx context.Context, id string) error MockCreateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockGetAdmin func(ctx context.Context, id string) (*linkedca.Admin, error) MockGetAdmins func(ctx context.Context) ([]*linkedca.Admin, error) MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockDeleteAdmin func(ctx context.Context, id string) error MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error) MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error MockDeleteAuthorityPolicy func(ctx context.Context) error MockError error MockRet1 interface{} } // CreateProvisioner mock. func (m *MockDB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { if m.MockCreateProvisioner != nil { return m.MockCreateProvisioner(ctx, prov) } else if m.MockError != nil { return m.MockError } return m.MockError } // GetProvisioner mock. 
func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) { if m.MockGetProvisioner != nil { return m.MockGetProvisioner(ctx, id) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*linkedca.Provisioner), m.MockError } // GetProvisioners mock func (m *MockDB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { if m.MockGetProvisioners != nil { return m.MockGetProvisioners(ctx) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.([]*linkedca.Provisioner), m.MockError } // UpdateProvisioner mock func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { if m.MockUpdateProvisioner != nil { return m.MockUpdateProvisioner(ctx, prov) } return m.MockError } // DeleteProvisioner mock func (m *MockDB) DeleteProvisioner(ctx context.Context, id string) error { if m.MockDeleteProvisioner != nil { return m.MockDeleteProvisioner(ctx, id) } return m.MockError } // CreateAdmin mock func (m *MockDB) CreateAdmin(ctx context.Context, admin *linkedca.Admin) error { if m.MockCreateAdmin != nil { return m.MockCreateAdmin(ctx, admin) } return m.MockError } // GetAdmin mock. 
func (m *MockDB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) { if m.MockGetAdmin != nil { return m.MockGetAdmin(ctx, id) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.(*linkedca.Admin), m.MockError } // GetAdmins mock func (m *MockDB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { if m.MockGetAdmins != nil { return m.MockGetAdmins(ctx) } else if m.MockError != nil { return nil, m.MockError } return m.MockRet1.([]*linkedca.Admin), m.MockError } // UpdateAdmin mock func (m *MockDB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error { if m.MockUpdateAdmin != nil { return m.MockUpdateAdmin(ctx, adm) } return m.MockError } // DeleteAdmin mock func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error { if m.MockDeleteAdmin != nil { return m.MockDeleteAdmin(ctx, id) } return m.MockError } // CreateAuthorityPolicy mock func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { if m.MockCreateAuthorityPolicy != nil { return m.MockCreateAuthorityPolicy(ctx, policy) } return m.MockError } // GetAuthorityPolicy mock func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { if m.MockGetAuthorityPolicy != nil { return m.MockGetAuthorityPolicy(ctx) } return m.MockRet1.(*linkedca.Policy), m.MockError } // UpdateAuthorityPolicy mock func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error { if m.MockUpdateAuthorityPolicy != nil { return m.MockUpdateAuthorityPolicy(ctx, policy) } return m.MockError } // DeleteAuthorityPolicy mock func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error { if m.MockDeleteAuthorityPolicy != nil { return m.MockDeleteAuthorityPolicy(ctx) } return m.MockError } ================================================ FILE: authority/admin/errors.go ================================================ package admin import ( "encoding/json" "fmt" "net/http" "github.com/pkg/errors" 
"github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the Admin problem. type ProblemType int const ( // ErrorNotFoundType resource not found. ErrorNotFoundType ProblemType = iota // ErrorAuthorityMismatchType resource Authority ID does not match the // context Authority ID. ErrorAuthorityMismatchType // ErrorDeletedType resource has been deleted. ErrorDeletedType // ErrorBadRequestType bad request. ErrorBadRequestType // ErrorNotImplementedType not implemented. ErrorNotImplementedType // ErrorUnauthorizedType unauthorized. ErrorUnauthorizedType // ErrorServerInternalType internal server error. ErrorServerInternalType // ErrorConflictType conflict. ErrorConflictType ) // String returns the string representation of the admin problem type, // fulfilling the Stringer interface. func (ap ProblemType) String() string { switch ap { case ErrorNotFoundType: return "notFound" case ErrorAuthorityMismatchType: return "authorityMismatch" case ErrorDeletedType: return "deleted" case ErrorBadRequestType: return "badRequest" case ErrorNotImplementedType: return "notImplemented" case ErrorUnauthorizedType: return "unauthorized" case ErrorServerInternalType: return "internalServerError" case ErrorConflictType: return "conflict" default: return fmt.Sprintf("unsupported error type '%d'", int(ap)) } } type errorMetadata struct { details string status int typ string String string } var ( errorServerInternalMetadata = errorMetadata{ typ: ErrorServerInternalType.String(), details: "the server experienced an internal error", status: http.StatusInternalServerError, } errorMap = map[ProblemType]errorMetadata{ ErrorNotFoundType: { typ: ErrorNotFoundType.String(), details: "resource not found", status: http.StatusNotFound, }, ErrorAuthorityMismatchType: { typ: ErrorAuthorityMismatchType.String(), details: "resource not owned by authority", status: http.StatusUnauthorized, }, ErrorDeletedType: { typ: ErrorDeletedType.String(), details: "resource is deleted", status: 
http.StatusNotFound, }, ErrorNotImplementedType: { typ: ErrorNotImplementedType.String(), details: "not implemented", status: http.StatusNotImplemented, }, ErrorBadRequestType: { typ: ErrorBadRequestType.String(), details: "bad request", status: http.StatusBadRequest, }, ErrorUnauthorizedType: { typ: ErrorUnauthorizedType.String(), details: "unauthorized", status: http.StatusUnauthorized, }, ErrorServerInternalType: errorServerInternalMetadata, ErrorConflictType: { typ: ErrorConflictType.String(), details: "conflict", status: http.StatusConflict, }, } ) // Error represents an Admin error type Error struct { Type string `json:"type"` Detail string `json:"detail"` Message string `json:"message"` Err error `json:"-"` Status int `json:"-"` } // IsType returns true if the error type matches the input type. func (e *Error) IsType(pt ProblemType) bool { return pt.String() == e.Type } // NewError creates a new Error type. func NewError(pt ProblemType, msg string, args ...interface{}) *Error { return newError(pt, errors.Errorf(msg, args...)) } func newError(pt ProblemType, err error) *Error { meta, ok := errorMap[pt] if !ok { meta = errorServerInternalMetadata return &Error{ Type: meta.typ, Detail: meta.details, Status: meta.status, Err: err, } } return &Error{ Type: meta.typ, Detail: meta.details, Status: meta.status, Err: err, } } // NewErrorISE creates a new ErrorServerInternalType Error. func NewErrorISE(msg string, args ...interface{}) *Error { return NewError(ErrorServerInternalType, msg, args...) } // WrapError attempts to wrap the internal error. func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error { var ee *Error switch { case err == nil: return nil case errors.As(err, &ee): if ee.Err == nil { ee.Err = errors.Errorf(msg+"; "+ee.Detail, args...) } else { ee.Err = errors.Wrapf(ee.Err, msg, args...) 
} return ee default: return newError(typ, errors.Wrapf(err, msg, args...)) } } // WrapErrorISE shortcut to wrap an internal server error type. func WrapErrorISE(err error, msg string, args ...interface{}) *Error { return WrapError(ErrorServerInternalType, err, msg, args...) } // StatusCode returns the status code and implements the StatusCoder interface. func (e *Error) StatusCode() int { return e.Status } // Error allows AError to implement the error interface. func (e *Error) Error() string { return e.Err.Error() } // Cause returns the internal error and implements the Causer interface. func (e *Error) Cause() error { if e.Err == nil { return errors.New(e.Detail) } return e.Err } // ToLog implements the EnableLogger interface. func (e *Error) ToLog() (interface{}, error) { b, err := json.Marshal(e) if err != nil { return nil, WrapErrorISE(err, "error marshaling authority.Error for logging") } return string(b), nil } // Render implements render.RenderableError for Error. func (e *Error) Render(w http.ResponseWriter, r *http.Request) { e.Message = e.Err.Error() render.JSONStatus(w, r, e, e.StatusCode()) } ================================================ FILE: authority/administrator/collection.go ================================================ package administrator import ( "sort" "sync" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/linkedca" ) // DefaultAdminLimit is the default limit for listing provisioners. const DefaultAdminLimit = 20 // DefaultAdminMax is the maximum limit for listing provisioners. const DefaultAdminMax = 100 type adminSlice []*linkedca.Admin func (p adminSlice) Len() int { return len(p) } func (p adminSlice) Less(i, j int) bool { return p[i].Id < p[j].Id } func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // Collection is a memory map of admins. 
type Collection struct { byID *sync.Map bySubProv *sync.Map byProv *sync.Map sorted adminSlice provisioners *provisioner.Collection superCount int superCountByProvisioner map[string]int } // NewCollection initializes a collection of provisioners. The given list of // audiences are the audiences used by the JWT provisioner. func NewCollection(provisioners *provisioner.Collection) *Collection { return &Collection{ byID: new(sync.Map), byProv: new(sync.Map), bySubProv: new(sync.Map), superCountByProvisioner: map[string]int{}, provisioners: provisioners, } } // LoadByID a admin by the ID. func (c *Collection) LoadByID(id string) (*linkedca.Admin, bool) { return loadAdmin(c.byID, id) } type subProv struct { subject string provisioner string } func newSubProv(subject, prov string) subProv { return subProv{subject, prov} } // LoadBySubProv loads an admin by subject and provisioner name. func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) { return loadAdmin(c.bySubProv, newSubProv(sub, provName)) } // LoadByProvisioner loads admins by provisioner name. func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) { val, ok := c.byProv.Load(provName) if !ok { return nil, false } admins, ok := val.([]*linkedca.Admin) if !ok { return nil, false } return admins, true } // Store adds an admin to the collection and enforces the uniqueness of // admin IDs and admin subject <-> provisioner name combos. func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { // Input validation. if adm.ProvisionerId != prov.GetID() { return admin.NewErrorISE("admin.provisionerId does not match provisioner argument") } // Store admin always in byID. ID must be unique. if _, loaded := c.byID.LoadOrStore(adm.Id, adm); loaded { return errors.New("cannot add multiple admins with the same id") } provName := prov.GetName() // Store admin always in bySubProv. Subject <-> ProvisionerName must be unique. 
if _, loaded := c.bySubProv.LoadOrStore(newSubProv(adm.Subject, provName), adm); loaded { c.byID.Delete(adm.Id) return errors.New("cannot add multiple admins with the same subject and provisioner") } var isSuper = (adm.Type == linkedca.Admin_SUPER_ADMIN) if admins, ok := c.LoadByProvisioner(provName); ok { c.byProv.Store(provName, append(admins, adm)) if isSuper { c.superCountByProvisioner[provName]++ } } else { c.byProv.Store(provName, []*linkedca.Admin{adm}) if isSuper { c.superCountByProvisioner[provName] = 1 } } if isSuper { c.superCount++ } c.sorted = append(c.sorted, adm) sort.Sort(c.sorted) return nil } // Remove deletes an admin from all associated collections and lists. func (c *Collection) Remove(id string) error { adm, ok := c.LoadByID(id) if !ok { return admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id) } if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 { return admin.NewError(admin.ErrorBadRequestType, "cannot remove the last super admin") } prov, ok := c.provisioners.Load(adm.ProvisionerId) if !ok { return admin.NewError(admin.ErrorNotFoundType, "provisioner %s for admin %s not found", adm.ProvisionerId, id) } provName := prov.GetName() adminsByProv, ok := c.LoadByProvisioner(provName) if !ok { return admin.NewError(admin.ErrorNotFoundType, "admins not found for provisioner %s", provName) } // Find index in sorted list. sortedIndex := sort.Search(c.sorted.Len(), func(i int) bool { return c.sorted[i].Id >= adm.Id }) if c.sorted[sortedIndex].Id != adm.Id { return admin.NewError(admin.ErrorNotFoundType, "admin %s not found in sorted list", adm.Id) } var found bool for i, a := range adminsByProv { if a.Id == adm.Id { // Remove admin from list. https://stackoverflow.com/questions/37334119/how-to-delete-an-element-from-a-slice-in-golang // Order does not matter. 
adminsByProv[i] = adminsByProv[len(adminsByProv)-1] c.byProv.Store(provName, adminsByProv[:len(adminsByProv)-1]) found = true } } if !found { return admin.NewError(admin.ErrorNotFoundType, "admin %s not found in adminsByProvisioner list", adm.Id) } // Remove index in sorted list copy(c.sorted[sortedIndex:], c.sorted[sortedIndex+1:]) // Shift a[i+1:] left one index. c.sorted[len(c.sorted)-1] = nil // Erase last element (write zero value). c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice. c.byID.Delete(adm.Id) c.bySubProv.Delete(newSubProv(adm.Subject, provName)) if adm.Type == linkedca.Admin_SUPER_ADMIN { c.superCount-- c.superCountByProvisioner[provName]-- } return nil } // Update updates the given admin in all related lists and collections. func (c *Collection) Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error) { adm, ok := c.LoadByID(id) if !ok { return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", adm.Id) } if adm.Type == nu.Type { return adm, nil } if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 { return nil, admin.NewError(admin.ErrorBadRequestType, "cannot change role of last super admin") } adm.Type = nu.Type return adm, nil } // SuperCount returns the total number of admins. func (c *Collection) SuperCount() int { return c.superCount } // SuperCountByProvisioner returns the total number of admins. func (c *Collection) SuperCountByProvisioner(provName string) int { if cnt, ok := c.superCountByProvisioner[provName]; ok { return cnt } return 0 } // Find implements pagination on a list of sorted admins. 
func (c *Collection) Find(cursor string, limit int) ([]*linkedca.Admin, string) { switch { case limit <= 0: limit = DefaultAdminLimit case limit > DefaultAdminMax: limit = DefaultAdminMax } n := c.sorted.Len() i := sort.Search(n, func(i int) bool { return c.sorted[i].Id >= cursor }) slice := []*linkedca.Admin{} for ; i < n && len(slice) < limit; i++ { slice = append(slice, c.sorted[i]) } if i < n { return slice, c.sorted[i].Id } return slice, "" } func loadAdmin(m *sync.Map, key interface{}) (*linkedca.Admin, bool) { val, ok := m.Load(key) if !ok { return nil, false } adm, ok := val.(*linkedca.Admin) if !ok { return nil, false } return adm, true } ================================================ FILE: authority/admins.go ================================================ package authority import ( "context" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/linkedca" ) // LoadAdminByID returns an *linkedca.Admin with the given ID. func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() return a.admins.LoadByID(id) } // LoadAdminBySubProv returns an *linkedca.Admin with the given ID. func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() return a.admins.LoadBySubProv(subject, prov) } // GetAdmins returns a map listing each provisioner and the JWK Key Set // with their public keys. func (a *Authority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() admins, nextCursor := a.admins.Find(cursor, limit) return admins, nextCursor, nil } // StoreAdmin stores an *linkedca.Admin to the authority. 
func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { a.adminMutex.Lock() defer a.adminMutex.Unlock() if adm.ProvisionerId != prov.GetID() { return admin.NewErrorISE("admin.provisionerId does not match provisioner argument") } if _, ok := a.admins.LoadBySubProv(adm.Subject, prov.GetName()); ok { return admin.NewError(admin.ErrorBadRequestType, "admin with subject %s and provisioner %s already exists", adm.Subject, prov.GetName()) } // Store to database -- this will set the ID. if err := a.adminDB.CreateAdmin(ctx, adm); err != nil { return admin.WrapErrorISE(err, "error creating admin") } if err := a.admins.Store(adm, prov); err != nil { if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store") } return admin.WrapErrorISE(err, "error storing admin in authority cache") } return nil } // UpdateAdmin stores an *linkedca.Admin to the authority. func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { a.adminMutex.Lock() defer a.adminMutex.Unlock() adm, err := a.admins.Update(id, nu) if err != nil { return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id) } if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil { if err := a.ReloadAdminResources(ctx); err != nil { return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update") } return nil, admin.WrapErrorISE(err, "error updating admin %s", id) } return adm, nil } // RemoveAdmin removes an *linkedca.Admin from the authority. func (a *Authority) RemoveAdmin(ctx context.Context, id string) error { a.adminMutex.Lock() defer a.adminMutex.Unlock() return a.removeAdmin(ctx, id) } // removeAdmin helper that assumes lock. 
func (a *Authority) removeAdmin(ctx context.Context, id string) error { if err := a.admins.Remove(id); err != nil { return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id) } if err := a.adminDB.DeleteAdmin(ctx, id); err != nil { if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove") } return admin.WrapErrorISE(err, "error deleting admin %s", id) } return nil } ================================================ FILE: authority/authority.go ================================================ package authority import ( "bytes" "context" "crypto" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/hex" "fmt" "log" "strings" "sync" "time" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "github.com/smallstep/linkedca" "go.step.sm/crypto/kms" kmsapi "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/sshagentkms" "go.step.sm/crypto/pemutil" "github.com/smallstep/certificates/authority/admin" adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" "github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/internal/constraints" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" "github.com/smallstep/nosql" ) // Authority implements the Certificate Authority internal interface. 
type Authority struct {
	config        *config.Config
	keyManager    kms.KeyManager
	provisioners  *provisioner.Collection
	admins        *administrator.Collection
	db            db.AuthDB
	adminDB       admin.DB
	templates     *templates.Templates
	linkedCAToken string
	wrapTransport httptransport.Wrapper
	webhookClient provisioner.HTTPClient
	httpClient    provisioner.HTTPClient

	// X509 CA
	password              []byte
	issuerPassword        []byte
	x509CAService         cas.CertificateAuthorityService
	rootX509Certs         []*x509.Certificate
	rootX509CertPool      *x509.CertPool
	federatedX509Certs    []*x509.Certificate
	intermediateX509Certs []*x509.Certificate
	certificates          *sync.Map
	x509Enforcers         []provisioner.CertificateEnforcer

	// SCEP CA
	scepOptions    *scep.Options
	validateSCEP   bool
	scepAuthority  *scep.Authority
	scepKeyManager provisioner.SCEPKeyManager

	// SSH CA
	sshHostPassword         []byte
	sshUserPassword         []byte
	sshCAUserCertSignKey    ssh.Signer
	sshCAHostCertSignKey    ssh.Signer
	sshCAUserCerts          []ssh.PublicKey
	sshCAHostCerts          []ssh.PublicKey
	sshCAUserFederatedCerts []ssh.PublicKey
	sshCAHostFederatedCerts []ssh.PublicKey

	// CRL vars
	crlTicker  *time.Ticker
	crlStopper chan struct{}
	crlMutex   sync.Mutex

	// If true, do not re-initialize
	initOnce  bool
	startTime time.Time

	// Custom functions
	sshBastionFunc        func(ctx context.Context, user, hostname string) (*config.Bastion, error)
	sshCheckHostFunc      func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
	sshGetHostsFunc       func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
	getIdentityFunc       provisioner.GetIdentityFunc
	authorizeRenewFunc    provisioner.AuthorizeRenewFunc
	authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc

	// Constraints and Policy engines
	constraintsEngine *constraints.Engine
	policyEngine      *policy.Engine

	// adminMutex is held by the admin load/store/update/remove methods while
	// they access the admins collection.
	adminMutex sync.RWMutex

	// If true, do not initialize the authority
	skipInit bool

	// If true, do not output initialization logs
	quietInit bool

	// Called whenever applicable, in order to instrument the authority.
	meter Meter
}

// Info contains information about the authority.
type Info struct {
	StartTime          time.Time
	RootX509Certs      []*x509.Certificate
	SSHCAUserPublicKey []byte
	SSHCAHostPublicKey []byte
	DNSNames           []string
}

// New creates and initiates a new Authority type.
//
// The configuration is validated first. The options are then applied in
// order on top of the defaults; an option that returns an error aborts
// construction. Unless an option set skipInit, the authority is initialized
// before it is returned.
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
	err := cfg.Validate()
	if err != nil {
		return nil, err
	}

	var a = &Authority{
		config:        cfg,
		certificates:  new(sync.Map),
		validateSCEP:  true,
		meter:         noopMeter{},
		wrapTransport: httptransport.NoopWrapper(),
	}

	// Apply options.
	for _, fn := range opts {
		if err := fn(a); err != nil {
			return nil, err
		}
	}
	// A key manager supplied through the options is wrapped only after all
	// options have run, so the wrapper receives the final meter.
	if a.keyManager != nil {
		a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter)
	}

	// Initialize system cert pool
	if err := initializeSystemCertPool(); err != nil {
		return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err)
	}

	if !a.skipInit {
		// Initialize authority from options or configuration.
		if err := a.init(); err != nil {
			return nil, err
		}
	}

	return a, nil
}

// NewEmbedded initializes an authority that can be embedded in a different
// project without the limitations of the config.
//
// It starts from an empty configuration and relies on the options to provide
// what is needed: a root certificate (via options or config) and either an
// X.509 CA service or an intermediate certificate and key in the config.
func NewEmbedded(opts ...Option) (*Authority, error) {
	a := &Authority{
		config:        &config.Config{},
		certificates:  new(sync.Map),
		meter:         noopMeter{},
		wrapTransport: httptransport.NoopWrapper(),
	}

	// Apply options.
	for _, fn := range opts {
		if err := fn(a); err != nil {
			return nil, err
		}
	}
	// A key manager supplied through the options is wrapped only after all
	// options have run, so the wrapper receives the final meter.
	if a.keyManager != nil {
		a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter)
	}

	// Initialize system cert pool
	if err := initializeSystemCertPool(); err != nil {
		return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err)
	}

	// Validate required options
	switch {
	case a.config == nil:
		return nil, errors.New("cannot create an authority without a configuration")
	case len(a.rootX509Certs) == 0 && a.config.Root.HasEmpties():
		return nil, errors.New("cannot create an authority without a root certificate")
	case a.x509CAService == nil && a.config.IntermediateCert == "":
		return nil, errors.New("cannot create an authority without an issuer certificate")
	case a.x509CAService == nil && a.config.IntermediateKey == "":
		return nil, errors.New("cannot create an authority without an issuer signer")
	}

	// Initialize config required fields.
	a.config.Init()

	if !a.skipInit {
		// Initialize authority from options or configuration.
		if err := a.init(); err != nil {
			return nil, err
		}
	}

	return a, nil
}

// authorityKey is the unexported context key type under which the authority
// is stored by NewContext.
type authorityKey struct{}

// NewContext adds the given authority to the context.
func NewContext(ctx context.Context, a *Authority) context.Context {
	return context.WithValue(ctx, authorityKey{}, a)
}

// FromContext returns the current authority from the given context. The
// boolean is false when the context holds no *Authority.
func FromContext(ctx context.Context) (a *Authority, ok bool) {
	a, ok = ctx.Value(authorityKey{}).(*Authority)
	return
}

// MustFromContext returns the current authority from the given context. It will
// panic if the authority is not in the context.
func MustFromContext(ctx context.Context) *Authority {
	var (
		a  *Authority
		ok bool
	)
	if a, ok = FromContext(ctx); !ok {
		panic("authority is not in the context")
	}
	return a
}

// ReloadAdminResources reloads admins and provisioners from the DB.
func (a *Authority) ReloadAdminResources(ctx context.Context) error { var ( provList provisioner.List adminList []*linkedca.Admin ) if a.config.AuthorityConfig.EnableAdmin { provs, err := a.adminDB.GetProvisioners(ctx) if err != nil { return admin.WrapErrorISE(err, "error getting provisioners to initialize authority") } provList, err = provisionerListToCertificates(provs) if err != nil { return admin.WrapErrorISE(err, "error converting provisioner list to certificates") } adminList, err = a.adminDB.GetAdmins(ctx) if err != nil { return admin.WrapErrorISE(err, "error getting admins to initialize authority") } } else { provList = a.config.AuthorityConfig.Provisioners adminList = a.config.AuthorityConfig.Admins } provisionerConfig, err := a.generateProvisionerConfig(ctx) if err != nil { return admin.WrapErrorISE(err, "error generating provisioner config") } // Create provisioner collection. provClxn := provisioner.NewCollection(provisionerConfig.Audiences) for _, p := range provList { if err := p.Init(provisionerConfig); err != nil { log.Printf("failed to initialize %s provisioner %q: %v\n", p.GetType(), p.GetName(), err) p = provisioner.Uninitialized{ Interface: p, Reason: err, } } if err := provClxn.Store(p); err != nil { return err } } // Create admin collection. adminClxn := administrator.NewCollection(provClxn) for _, adm := range adminList { p, ok := provClxn.Load(adm.ProvisionerId) if !ok { return admin.NewErrorISE("provisioner %s not found when loading admin %s", adm.ProvisionerId, adm.Id) } if err := adminClxn.Store(adm, p); err != nil { return err } } a.config.AuthorityConfig.Provisioners = provList a.provisioners = provClxn a.config.AuthorityConfig.Admins = adminList a.admins = adminClxn switch { case a.requiresSCEP() && a.GetSCEP() == nil: // TODO(hs): try to initialize SCEP here too? It's a bit // problematic if this method is called as part of an update // via Admin API and a password needs to be provided. 
case a.requiresSCEP() && a.GetSCEP() != nil: // update the SCEP Authority with the currently active SCEP // provisioner names and revalidate the configuration. a.scepAuthority.UpdateProvisioners(a.getSCEPProvisionerNames()) if err := a.scepAuthority.Validate(); err != nil { log.Printf("failed validating SCEP authority: %v\n", err) } case !a.requiresSCEP() && a.GetSCEP() != nil: // TODO(hs): don't remove the authority if we can't also // reload it. //a.scepAuthority = nil } return nil } // init performs validation and initializes the fields of an Authority struct. func (a *Authority) init() error { // Check if handler has already been validated/initialized. if a.initOnce { return nil } var err error ctx := NewContext(context.Background(), a) // Set password if they are not set. var configPassword []byte if a.config.Password != "" { configPassword = []byte(a.config.Password) } if configPassword != nil && a.password == nil { a.password = configPassword } if a.sshHostPassword == nil { a.sshHostPassword = a.password } if a.sshUserPassword == nil { a.sshUserPassword = a.password } // Automatically enable admin for all linked cas. if a.linkedCAToken != "" { a.config.AuthorityConfig.EnableAdmin = true } // Initialize step-ca Database if it's not already initialized with WithDB. // If a.config.DB is nil then a simple, barebones in memory DB will be used. if a.db == nil { if a.db, err = db.New(a.config.DB); err != nil { return err } } // Initialize key manager if it has not been set in the options. if a.keyManager == nil { var options kmsapi.Options if a.config.KMS != nil { options = *a.config.KMS } a.keyManager, err = kms.New(ctx, options) if err != nil { return err } a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter) } // Initialize linkedca client if necessary. On a linked RA, the issuer // configuration might come from majordomo. 
var linkedcaClient *linkedCaClient if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil { linkedcaClient, err = newLinkedCAClient(a.linkedCAToken) if err != nil { return err } // If authorityId is configured make sure it matches the one in the token if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) { return errors.New("error initializing linkedca: token authority and configured authority do not match") } a.config.AuthorityConfig.AuthorityID = linkedcaClient.authorityID linkedcaClient.Run() } // Initialize the X.509 CA Service if it has not been set in the options. if a.x509CAService == nil { var options casapi.Options if a.config.AuthorityConfig.Options != nil { options = *a.config.AuthorityConfig.Options } // AuthorityID might be empty. It's always available linked CAs/RAs. options.AuthorityID = a.config.AuthorityConfig.AuthorityID // Configure linked RA if linkedcaClient != nil && options.CertificateAuthority == "" { conf, err := linkedcaClient.GetConfiguration(ctx) if err != nil { return err } if conf.RaConfig != nil { options.CertificateAuthority = conf.RaConfig.CaUrl options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint options.CertificateIssuer = &casapi.CertificateIssuer{ Type: conf.RaConfig.Provisioner.Type.String(), Provisioner: conf.RaConfig.Provisioner.Name, } // Configure the RA authority type if needed if options.Type == "" { options.Type = casapi.StepCAS } } // Remote configuration is currently only supported on a linked RA if sc := conf.ServerConfig; sc != nil { if a.config.Address == "" { a.config.Address = sc.Address } if len(a.config.DNSNames) == 0 { a.config.DNSNames = sc.DnsNames } } } // Set the issuer password if passed in the flags. if options.CertificateIssuer != nil && a.issuerPassword != nil { options.CertificateIssuer.Password = string(a.issuerPassword) } // Read intermediate and create X509 signer for default CAS. 
if options.Is(casapi.SoftCAS) { options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert) if err != nil { return err } options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.IntermediateKey, Password: a.password, }) if err != nil { return err } // If not defined with an option, add intermediates to the list of // certificates used for name constraints validation at issuance // time. if len(a.intermediateX509Certs) == 0 { a.intermediateX509Certs = append(a.intermediateX509Certs, options.CertificateChain...) } } a.x509CAService, err = cas.New(ctx, options) if err != nil { return err } // Get root certificate from CAS. if srv, ok := a.x509CAService.(casapi.CertificateAuthorityGetter); ok { resp, err := srv.GetCertificateAuthority(&casapi.GetCertificateAuthorityRequest{ Name: options.CertificateAuthority, }) if err != nil { return err } a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate) a.intermediateX509Certs = append(a.intermediateX509Certs, resp.IntermediateCertificates...) } } // Read root certificates and store them in the certificates map. if len(a.rootX509Certs) == 0 { a.rootX509Certs = make([]*x509.Certificate, 0, len(a.config.Root)) for _, path := range a.config.Root { crts, err := pemutil.ReadCertificateBundle(path) if err != nil { return err } a.rootX509Certs = append(a.rootX509Certs, crts...) } } for _, crt := range a.rootX509Certs { sum := sha256.Sum256(crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), crt) } a.rootX509CertPool = x509.NewCertPool() for _, cert := range a.rootX509Certs { a.rootX509CertPool.AddCert(cert) } // Read federated certificates and store them in the certificates map. 
if len(a.federatedX509Certs) == 0 { a.federatedX509Certs = make([]*x509.Certificate, 0, len(a.config.FederatedRoots)) for _, path := range a.config.FederatedRoots { crts, err := pemutil.ReadCertificateBundle(path) if err != nil { return err } a.federatedX509Certs = append(a.federatedX509Certs, crts...) } } for _, crt := range a.federatedX509Certs { sum := sha256.Sum256(crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), crt) } // Initialize HTTPClient with all root certs clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs)) clientRoots = append(clientRoots, a.rootX509Certs...) clientRoots = append(clientRoots, a.federatedX509Certs...) a.httpClient = newHTTPClient(a.wrapTransport, clientRoots...) if err != nil { return err } // Decrypt and load SSH keys var tmplVars templates.Step if a.config.SSH != nil { if a.config.SSH.HostKey != "" { signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.SSH.HostKey, Password: a.sshHostPassword, }) if err != nil { return err } // If our signer is from sshagentkms, just unwrap it instead of // wrapping it in another layer, and this prevents crypto from // erroring out with: ssh: unsupported key type *agent.Key switch s := signer.(type) { case *sshagentkms.WrappedSSHSigner: a.sshCAHostCertSignKey = s.Signer case *instrumentedKMSSigner: switch is := s.Signer.(type) { case *sshagentkms.WrappedSSHSigner: a.sshCAHostCertSignKey = is.Signer default: a.sshCAHostCertSignKey, err = ssh.NewSignerFromSigner(s) } case crypto.Signer: a.sshCAHostCertSignKey, err = ssh.NewSignerFromSigner(s) default: return errors.Errorf("unsupported signer type %T", signer) } if err != nil { return errors.Wrap(err, "error creating ssh signer") } // Append public key to list of host certs a.sshCAHostCerts = append(a.sshCAHostCerts, a.sshCAHostCertSignKey.PublicKey()) a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, a.sshCAHostCertSignKey.PublicKey()) } if 
a.config.SSH.UserKey != "" { signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.SSH.UserKey, Password: a.sshUserPassword, }) if err != nil { return err } // If our signer is from sshagentkms, just unwrap it instead of // wrapping it in another layer, and this prevents crypto from // erroring out with: ssh: unsupported key type *agent.Key switch s := signer.(type) { case *sshagentkms.WrappedSSHSigner: a.sshCAUserCertSignKey = s.Signer case *instrumentedKMSSigner: switch is := s.Signer.(type) { case *sshagentkms.WrappedSSHSigner: a.sshCAUserCertSignKey = is.Signer default: a.sshCAUserCertSignKey, err = ssh.NewSignerFromSigner(s) } case crypto.Signer: a.sshCAUserCertSignKey, err = ssh.NewSignerFromSigner(s) default: return errors.Errorf("unsupported signer type %T", signer) } if err != nil { return errors.Wrap(err, "error creating ssh signer") } // Append public key to list of user certs a.sshCAUserCerts = append(a.sshCAUserCerts, a.sshCAUserCertSignKey.PublicKey()) a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey()) } // Append other public keys and add them to the template variables. for _, key := range a.config.SSH.Keys { publicKey := key.PublicKey() switch key.Type { case provisioner.SSHHostCert: if key.Federated { a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, publicKey) } else { a.sshCAHostCerts = append(a.sshCAHostCerts, publicKey) } case provisioner.SSHUserCert: if key.Federated { a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, publicKey) } else { a.sshCAUserCerts = append(a.sshCAUserCerts, publicKey) } default: return errors.Errorf("unsupported type %s", key.Type) } } } // Configure template variables. On the template variables HostFederatedKeys // and UserFederatedKeys we will skip the actual CA that will be available // in HostKey and UserKey. 
// // We cannot do it in the previous blocks because this configuration can be // injected using options. if a.sshCAHostCertSignKey != nil { tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey() tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...) } else { tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts...) } if a.sshCAUserCertSignKey != nil { tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey() tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...) } else { tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...) } if a.config.AuthorityConfig.EnableAdmin { // Initialize step-ca Admin Database if it's not already initialized using // WithAdminDB. if a.adminDB == nil { if linkedcaClient != nil { a.adminDB = linkedcaClient } else { a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) if err != nil { return err } } } provs, err := a.adminDB.GetProvisioners(ctx) if err != nil { return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") } if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { // Migration will currently only be kicked off once, because either one or more provisioners // are migrated or a default JWK provisioner will be created in the DB. It won't run for // linked or hosted deployments. Not for linked, because that case is explicitly checked // for above. Not for hosted, because there'll be at least an existing OIDC provisioner. var firstJWKProvisioner *linkedca.Provisioner if len(a.config.AuthorityConfig.Provisioners) > 0 { // Existing provisioners detected; try migrating them to DB storage. 
a.initLogf("Starting migration of provisioners") for _, p := range a.config.AuthorityConfig.Provisioners { lp, err := ProvisionerToLinkedca(p) if err != nil { return admin.WrapErrorISE(err, "error transforming provisioner %q while migrating", p.GetName()) } // Store the provisioner to be migrated if err := a.adminDB.CreateProvisioner(ctx, lp); err != nil { return admin.WrapErrorISE(err, "error creating provisioner %q while migrating", p.GetName()) } // Mark the first JWK provisioner, so that it can be used for administration purposes if firstJWKProvisioner == nil && lp.Type == linkedca.Provisioner_JWK { firstJWKProvisioner = lp a.initLogf("Migrated JWK provisioner %q with admin permissions", p.GetName()) } else { a.initLogf("Migrated %s provisioner %q", p.GetType(), p.GetName()) } } c := a.config if c.WasLoadedFromFile() { // The provisioners in the configuration file can be deleted from // the file by editing it. Automatic rewriting of the file was considered // to be too surprising for users and not the right solution for all // use cases, so we leave it up to users to this themselves. a.initLogf("Provisioners that were migrated can now be removed from `ca.json` by editing it") } a.initLogf("Finished migrating provisioners") } // Create first JWK provisioner for remote administration purposes if none exists yet if firstJWKProvisioner == nil { firstJWKProvisioner, err = CreateFirstProvisioner(ctx, a.adminDB, string(a.password)) if err != nil { return admin.WrapErrorISE(err, "error creating first provisioner") } a.initLogf("Created JWK provisioner %q with admin permissions", firstJWKProvisioner.GetName()) } // Create first super admin, belonging to the first JWK provisioner // TODO(hs): pass a user-provided first super admin subject to here. With `ca init` it's // added to the DB immediately if using remote management. But when migrating from // ca.json to the DB, this option doesn't exist. Adding a flag just to do it during // migration isn't nice. 
We could opt for a user to change it afterwards. There exist // cases in which creation of `step` could lock out a user from API access. This is the // case if `step` isn't allowed to be signed by Name Constraints or the X.509 policy. // We have protection for that when creating and updating a policy, but if a policy or // Name Constraints are in use at the time of migration, that could lock the user out. superAdminSubject := "step" if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{ ProvisionerId: firstJWKProvisioner.Id, Subject: superAdminSubject, Type: linkedca.Admin_SUPER_ADMIN, }); err != nil { return admin.WrapErrorISE(err, "error creating first admin") } a.initLogf("Created super admin %q for JWK provisioner %q", superAdminSubject, firstJWKProvisioner.GetName()) } } // Load Provisioners and Admins if err := a.ReloadAdminResources(ctx); err != nil { return err } // The SCEP functionality is provided through an instance of // scep.Authority. It is initialized when the CA is started and // if it doesn't exist yet. It gets refreshed if it already // exists. If the SCEP authority is no longer required on reload, // it gets removed. // TODO(hs): reloading through SIGHUP doesn't hit these cases. This // is because an entirely new authority.Authority is created, including // a new scep.Authority. Look into this to see if we want this to // keep working like that, or want to reuse a single instance and // update that. switch { case a.requiresSCEP() && a.GetSCEP() == nil: if a.scepOptions == nil { options := &scep.Options{ Roots: a.rootX509Certs, Intermediates: a.intermediateX509Certs, } // intermediate certificates can be empty in RA mode if len(a.intermediateX509Certs) > 0 { options.SignerCert = a.intermediateX509Certs[0] } // attempt to create the (default) SCEP signer if the intermediate // key is configured. 
if a.config.IntermediateKey != "" { if options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: a.config.IntermediateKey, Password: a.password, }); err != nil { return err } // TODO(hs): instead of creating the decrypter here, pass the // intermediate key + chain down to the SCEP authority, // and only instantiate it when required there. Is that possible? // Also with entering passwords? // TODO(hs): if moving the logic, try improving the logic for the // decrypter password too? Right now it needs to be entered multiple // times; I've observed it to be three times maximum, every time // the intermediate key is read. _, isRSAKey := options.Signer.Public().(*rsa.PublicKey) if km, ok := a.keyManager.(kmsapi.Decrypter); ok && isRSAKey { if decrypter, err := km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{ DecryptionKey: a.config.IntermediateKey, Password: a.password, }); err == nil { // only pass the decrypter down when it was successfully created, // meaning it's an RSA key, and `CreateDecrypter` did not fail. options.Decrypter = decrypter // intermediate certificates can be empty in RA mode if len(options.Intermediates) > 0 { options.DecrypterCert = options.Intermediates[0] } } } } a.scepOptions = options } // provide the current SCEP provisioner names, so that the provisioners // can be validated when the CA is started. 
a.scepOptions.SCEPProvisionerNames = a.getSCEPProvisionerNames() // create a new SCEP authority scepAuthority, err := scep.New(a, *a.scepOptions) if err != nil { return err } if a.validateSCEP { // validate the SCEP authority if err := scepAuthority.Validate(); err != nil { a.initLogf("failed validating SCEP authority: %v", err) } } // set the SCEP authority a.scepAuthority = scepAuthority case !a.requiresSCEP() && a.GetSCEP() != nil: // clear the SCEP authority if it's no longer required a.scepAuthority = nil case a.requiresSCEP() && a.GetSCEP() != nil: // update the SCEP Authority with the currently active SCEP // provisioner names and revalidate the configuration. a.scepAuthority.UpdateProvisioners(a.getSCEPProvisionerNames()) if err := a.scepAuthority.Validate(); err != nil { log.Printf("failed validating SCEP authority: %v\n", err) } } // Load X509 constraints engine. // // This is currently only available in CA mode. if size := len(a.intermediateX509Certs); size > 0 { last := a.intermediateX509Certs[size-1] constraintCerts := make([]*x509.Certificate, 0, size+1) constraintCerts = append(constraintCerts, a.intermediateX509Certs...) for _, root := range a.rootX509Certs { if bytes.Equal(last.RawIssuer, root.RawSubject) && bytes.Equal(last.AuthorityKeyId, root.SubjectKeyId) { constraintCerts = append(constraintCerts, root) } } a.constraintsEngine = constraints.New(constraintCerts...) } // Load x509 and SSH Policy Engines if err := a.reloadPolicyEngines(ctx); err != nil { return err } // Configure templates, currently only ssh templates are supported. if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil { a.templates = a.config.Templates if a.templates == nil { a.templates = templates.DefaultTemplates() } if a.templates.Data == nil { a.templates.Data = make(map[string]interface{}) } a.templates.Data["Step"] = tmplVars } // Start the CRL generator, we can assume the configuration is validated. 
if a.config.CRL.IsEnabled() {
		// Default cache duration to the default one
		if v := a.config.CRL.CacheDuration; v == nil || v.Duration <= 0 {
			a.config.CRL.CacheDuration = config.DefaultCRLCacheDuration
		}
		// Start CRL generator
		if err := a.startCRLGenerator(); err != nil {
			return err
		}
	}

	// JWT numeric dates are seconds.
	a.startTime = time.Now().Truncate(time.Second)
	// Set flag indicating that initialization has been completed, and should
	// not be repeated.
	a.initOnce = true

	return nil
}

// initLogf logs an initialization message using log.Printf. Nothing is
// written when the authority was created with quiet initialization enabled
// (the `--quiet` flag of the CA).
func (a *Authority) initLogf(format string, v ...any) {
	if a.quietInit {
		return
	}
	log.Printf(format, v...)
}

// GetID returns the authority id defined in the configuration, or the zero
// UUID when no id has been configured.
func (a *Authority) GetID() string {
	id := a.config.AuthorityConfig.AuthorityID
	if id == "" {
		return "00000000-0000-0000-0000-000000000000"
	}
	return id
}

// GetDatabase returns the authority database. If the configuration does not
// define a database, GetDatabase will return a db.SimpleDB instance.
func (a *Authority) GetDatabase() db.AuthDB {
	return a.db
}

// GetAdminDatabase returns the admin database held by the authority; the
// result is nil when no admin database has been set.
func (a *Authority) GetAdminDatabase() admin.DB {
	return a.adminDB
}

// GetConfig returns the configuration the authority is using.
func (a *Authority) GetConfig() *config.Config {
	return a.config
}

// GetBackdate returns the [time.Duration] that is subtracted from the current
// time when issuing a new certificate. It returns nil when the configuration,
// the authority configuration, or the backdate itself is not set.
func (a *Authority) GetBackdate() *time.Duration {
	cfg := a.config
	if cfg == nil {
		return nil
	}
	authCfg := cfg.AuthorityConfig
	if authCfg == nil || authCfg.Backdate == nil {
		return nil
	}
	return &authCfg.Backdate.Duration
}

// GetInfo returns information about the authority.
func (a *Authority) GetInfo() Info { ai := Info{ StartTime: a.startTime, RootX509Certs: a.rootX509Certs, DNSNames: a.config.DNSNames, } if a.sshCAUserCertSignKey != nil { ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey()) } if a.sshCAHostCertSignKey != nil { ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey()) } return ai } // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has // been enabled. func (a *Authority) IsAdminAPIEnabled() bool { return a.config.AuthorityConfig.EnableAdmin } // Shutdown safely shuts down any clients, databases, etc. held by the Authority. func (a *Authority) Shutdown() error { if a.crlTicker != nil { a.crlTicker.Stop() close(a.crlStopper) } if err := a.keyManager.Close(); err != nil { log.Printf("error closing the key manager: %v", err) } return a.db.Shutdown() } // CloseForReload closes internal services, to allow a safe reload. func (a *Authority) CloseForReload() { if a.crlTicker != nil { a.crlTicker.Stop() close(a.crlStopper) } if err := a.keyManager.Close(); err != nil { log.Printf("error closing the key manager: %v", err) } if client, ok := a.adminDB.(*linkedCaClient); ok { client.Stop() } } // IsRevoked returns whether or not a certificate has been // revoked before. func (a *Authority) IsRevoked(sn string) (bool, error) { // Check the passive revocation table. if lca, ok := a.adminDB.(interface { IsRevoked(string) (bool, error) }); ok { return lca.IsRevoked(sn) } return a.db.IsRevoked(sn) } // requiresSCEP iterates over the configured provisioners // and determines if at least one of them is a SCEP provisioner. func (a *Authority) requiresSCEP() bool { for _, p := range a.config.AuthorityConfig.Provisioners { if p.GetType() == provisioner.TypeSCEP { return true } } return false } // getSCEPProvisionerNames returns the names of the SCEP provisioners // that are currently available in the CA. 
func (a *Authority) getSCEPProvisionerNames() (names []string) { for _, p := range a.config.AuthorityConfig.Provisioners { if p.GetType() == provisioner.TypeSCEP { names = append(names, p.GetName()) } } return } // GetSCEP returns the configured SCEP Authority func (a *Authority) GetSCEP() *scep.Authority { return a.scepAuthority } // HasACMEProvisioner returns true if at least one ACME provisioner is configured. func (a *Authority) HasACMEProvisioner() bool { for _, p := range a.config.AuthorityConfig.Provisioners { if p.GetType() == provisioner.TypeACME { return true } } return false } func (a *Authority) startCRLGenerator() error { if !a.config.CRL.IsEnabled() { return nil } // Check that there is a valid CRL in the DB right now. If it doesn't exist // or is expired, generate one now _, ok := a.db.(db.CertificateRevocationListDB) if !ok { return errors.Errorf("CRL Generation requested, but database does not support CRL generation") } // Always create a new CRL on startup in case the CA has been down and the // time to next expected CRL update is less than the cache duration. 
if err := a.GenerateCertificateRevocationList(); err != nil { return errors.Wrap(err, "could not generate a CRL") } a.crlStopper = make(chan struct{}, 1) a.crlTicker = time.NewTicker(a.config.CRL.TickerDuration()) go func() { for { select { case <-a.crlTicker.C: log.Println("Regenerating CRL") if err := a.GenerateCertificateRevocationList(); err != nil { log.Printf("error regenerating the CRL: %v", err) } case <-a.crlStopper: return } } }() return nil } ================================================ FILE: authority/authority_test.go ================================================ package authority import ( "context" "crypto" "crypto/rand" "crypto/sha256" "crypto/x509" "encoding/hex" "encoding/pem" "fmt" "net" "os" "path/filepath" "reflect" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" ) func TestMain(m *testing.M) { if err := initializeSystemCertPool(); err != nil { fmt.Fprintf(os.Stderr, "failed to initialize system cert pool: %v\n", err) fmt.Fprintln(os.Stderr, "See https://pkg.go.dev/crypto/x509#SystemCertPool") os.Exit(2) } os.Exit(m.Run()) } func testAuthority(t *testing.T, opts ...Option) *Authority { maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk") assert.FatalError(t, err) clijwk, err := jose.ReadKey("testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) disableRenewal := true enableSSHCA := true p := provisioner.List{ &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, Claims: &provisioner.Claims{ EnableSSHCA: &enableSSHCA, }, }, &provisioner.JWK{ Name: "dev", Type: "JWK", Key: maxjwk, Claims: &provisioner.Claims{ DisableRenewal: &disableRenewal, }, }, &provisioner.JWK{ Name: "renew_disabled", Type: "JWK", Key: 
maxjwk, Claims: &provisioner.Claims{ DisableRenewal: &disableRenewal, }, }, &provisioner.SSHPOP{ Name: "sshpop", Type: "SSHPOP", Claims: &provisioner.Claims{ EnableSSHCA: &enableSSHCA, }, }, &provisioner.ACME{ Name: "acme", Type: "ACME", }, &provisioner.JWK{ Name: "uninitialized", Type: "JWK", Key: clijwk, Claims: &provisioner.Claims{ MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, MaxTLSDur: &provisioner.Duration{Duration: time.Minute}, }, }, } c := &Config{ Address: "127.0.0.1:443", Root: []string{"testdata/certs/root_ca.crt"}, IntermediateCert: "testdata/certs/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", SSH: &SSHConfig{ HostKey: "testdata/secrets/ssh_host_ca_key", UserKey: "testdata/secrets/ssh_user_ca_key", }, DNSNames: []string{"example.com"}, Password: "pass", AuthorityConfig: &AuthConfig{ Provisioners: p, }, } a, err := New(c, opts...) assert.FatalError(t, err) // Avoid errors when test tokens are created before the test authority. This // happens in some tests where we re-create the same authority to test // special cases without re-creating the token. 
a.startTime = a.startTime.Add(-1 * time.Minute) return a } func TestAuthorityNew(t *testing.T) { type newTest struct { config *Config err error } tests := map[string]func(t *testing.T) *newTest{ "ok": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) return &newTest{ config: c, } }, "fail bad root": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) c.Root = []string{"foo"} return &newTest{ config: c, err: errors.New(`error reading "foo": no such file or directory`), } }, "fail bad password": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) c.Password = "wrong" return &newTest{ config: c, err: errors.New("error decrypting ../ca/testdata/secrets/intermediate_ca_key: x509: decryption password incorrect"), } }, "fail loading CA cert": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) c.IntermediateCert = "wrong" return &newTest{ config: c, err: errors.New(`error reading "wrong": no such file or directory`), } }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) auth, err := New(tc.config) if err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { sum := sha256.Sum256(auth.rootX509Certs[0].Raw) root, ok := auth.certificates.Load(hex.EncodeToString(sum[:])) assert.Fatal(t, ok) assert.Equals(t, auth.rootX509Certs[0], root) assert.True(t, auth.initOnce) assert.NotNil(t, auth.x509CAService) for _, p := range tc.config.AuthorityConfig.Provisioners { var _p provisioner.Interface _p, ok = auth.provisioners.Load(p.GetID()) assert.True(t, ok) assert.Equals(t, p, _p) var kid, encryptedKey string if kid, encryptedKey, ok = p.GetEncryptedKey(); ok { var key string key, ok = auth.provisioners.LoadEncryptedKey(kid) assert.True(t, 
ok) assert.Equals(t, encryptedKey, key) } } // sanity check _, ok = auth.provisioners.Load("fooo") assert.False(t, ok) } } }) } } func TestAuthorityNew_bundles(t *testing.T) { ca0, err := minica.New() if err != nil { t.Fatal(err) } ca1, err := minica.New() if err != nil { t.Fatal(err) } ca2, err := minica.New() if err != nil { t.Fatal(err) } rootPath := t.TempDir() writeCert := func(fn string, certs ...*x509.Certificate) error { var b []byte for _, crt := range certs { b = append(b, pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, })...) } return os.WriteFile(filepath.Join(rootPath, fn), b, 0600) } writeKey := func(fn string, signer crypto.Signer) error { _, err := pemutil.Serialize(signer, pemutil.ToFile(filepath.Join(rootPath, fn), 0600)) return err } if err := writeCert("root0.crt", ca0.Root); err != nil { t.Fatal(err) } if err := writeCert("int0.crt", ca0.Intermediate); err != nil { t.Fatal(err) } if err := writeKey("int0.key", ca0.Signer); err != nil { t.Fatal(err) } if err := writeCert("root1.crt", ca1.Root); err != nil { t.Fatal(err) } if err := writeCert("int1.crt", ca1.Intermediate); err != nil { t.Fatal(err) } if err := writeKey("int1.key", ca1.Signer); err != nil { t.Fatal(err) } if err := writeCert("bundle0.crt", ca0.Root, ca1.Root); err != nil { t.Fatal(err) } if err := writeCert("bundle1.crt", ca1.Root, ca2.Root); err != nil { t.Fatal(err) } tests := []struct { name string config *config.Config wantErr bool }{ {"ok ca0", &config.Config{ Address: "127.0.0.1:443", Root: []string{filepath.Join(rootPath, "root0.crt")}, IntermediateCert: filepath.Join(rootPath, "int0.crt"), IntermediateKey: filepath.Join(rootPath, "int0.key"), DNSNames: []string{"127.0.0.1"}, AuthorityConfig: &AuthConfig{}, }, false}, {"ok bundle", &config.Config{ Address: "127.0.0.1:443", Root: []string{filepath.Join(rootPath, "bundle0.crt")}, IntermediateCert: filepath.Join(rootPath, "int0.crt"), IntermediateKey: filepath.Join(rootPath, "int0.key"), DNSNames: 
[]string{"127.0.0.1"}, AuthorityConfig: &AuthConfig{}, }, false}, {"ok federated ca1", &config.Config{ Address: "127.0.0.1:443", Root: []string{filepath.Join(rootPath, "root0.crt")}, FederatedRoots: []string{filepath.Join(rootPath, "root1.crt")}, IntermediateCert: filepath.Join(rootPath, "int0.crt"), IntermediateKey: filepath.Join(rootPath, "int0.key"), DNSNames: []string{"127.0.0.1"}, AuthorityConfig: &AuthConfig{}, }, false}, {"ok federated bundle", &config.Config{ Address: "127.0.0.1:443", Root: []string{filepath.Join(rootPath, "root0.crt")}, FederatedRoots: []string{filepath.Join(rootPath, "bundle1.crt")}, IntermediateCert: filepath.Join(rootPath, "int0.crt"), IntermediateKey: filepath.Join(rootPath, "int0.key"), DNSNames: []string{"127.0.0.1"}, AuthorityConfig: &AuthConfig{}, }, false}, {"fail root", &config.Config{ Address: "127.0.0.1:443", Root: []string{filepath.Join(rootPath, "missing.crt")}, IntermediateCert: filepath.Join(rootPath, "int0.crt"), IntermediateKey: filepath.Join(rootPath, "int0.key"), DNSNames: []string{"127.0.0.1"}, AuthorityConfig: &AuthConfig{}, }, true}, {"fail federated", &config.Config{ Address: "127.0.0.1:443", Root: []string{filepath.Join(rootPath, "root0.crt")}, FederatedRoots: []string{filepath.Join(rootPath, "missing.crt")}, IntermediateCert: filepath.Join(rootPath, "int0.crt"), IntermediateKey: filepath.Join(rootPath, "int0.key"), DNSNames: []string{"127.0.0.1"}, AuthorityConfig: &AuthConfig{}, }, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := New(tt.config) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } }) } } func TestAuthority_GetDatabase(t *testing.T) { auth := testAuthority(t) authWithDatabase, err := New(auth.config, WithDatabase(auth.db)) assert.FatalError(t, err) tests := []struct { name string auth *Authority want db.AuthDB }{ {"ok", auth, auth.db}, {"ok WithDatabase", authWithDatabase, auth.db}, } for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { if got := tt.auth.GetDatabase(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetDatabase() = %v, want %v", got, tt.want) } }) } } func TestNewEmbedded(t *testing.T) { caPEM, err := os.ReadFile("testdata/certs/root_ca.crt") assert.FatalError(t, err) crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) key, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("pass"))) assert.FatalError(t, err) type args struct { opts []Option } tests := []struct { name string args args wantErr bool }{ {"ok", args{[]Option{WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))}}, false}, {"ok empty config", args{[]Option{WithConfig(&Config{}), WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))}}, false}, {"ok config file", args{[]Option{WithConfigFile("../ca/testdata/ca.json")}}, false}, {"ok config", args{[]Option{WithConfig(&Config{ Root: []string{"testdata/certs/root_ca.crt"}, IntermediateCert: "testdata/certs/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "pass", AuthorityConfig: &AuthConfig{}, })}}, false}, {"fail options", args{[]Option{WithX509RootBundle([]byte("bad data"))}}, true}, {"fail missing config", args{[]Option{WithConfig(nil), WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))}}, true}, {"fail missing root", args{[]Option{WithX509Signer(crt, key.(crypto.Signer))}}, true}, {"fail missing signer", args{[]Option{WithX509RootBundle(caPEM)}}, true}, {"fail missing root file", args{[]Option{WithConfig(&Config{ IntermediateCert: "testdata/certs/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "pass", AuthorityConfig: &AuthConfig{}, })}}, true}, {"fail missing issuer", args{[]Option{WithConfig(&Config{ Root: []string{"testdata/certs/root_ca.crt"}, IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "pass", 
AuthorityConfig: &AuthConfig{}, })}}, true}, {"fail missing signer", args{[]Option{WithConfig(&Config{ Root: []string{"testdata/certs/root_ca.crt"}, IntermediateCert: "testdata/certs/intermediate_ca.crt", Password: "pass", AuthorityConfig: &AuthConfig{}, })}}, true}, {"fail bad password", args{[]Option{WithConfig(&Config{ Root: []string{"testdata/certs/root_ca.crt"}, IntermediateCert: "testdata/certs/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "bad", AuthorityConfig: &AuthConfig{}, })}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewEmbedded(tt.args.opts...) if (err != nil) != tt.wantErr { t.Errorf("NewEmbedded() error = %v, wantErr %v", err, tt.wantErr) return } if err == nil { assert.True(t, got.initOnce) assert.NotNil(t, got.rootX509Certs) assert.NotNil(t, got.x509CAService) } }) } } func TestNewEmbedded_Sign(t *testing.T) { caPEM, err := os.ReadFile("testdata/certs/root_ca.crt") assert.FatalError(t, err) crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") assert.FatalError(t, err) key, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("pass"))) assert.FatalError(t, err) a, err := NewEmbedded(WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))) assert.FatalError(t, err) // Sign cr, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{ DNSNames: []string{"foo.bar.zar"}, }, key) assert.FatalError(t, err) csr, err := x509.ParseCertificateRequest(cr) assert.FatalError(t, err) cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{}) assert.FatalError(t, err) assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) assert.Equals(t, crt, cert[1]) } func TestNewEmbedded_GetTLSCertificate(t *testing.T) { caPEM, err := os.ReadFile("testdata/certs/root_ca.crt") assert.FatalError(t, err) crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") 
assert.FatalError(t, err) key, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("pass"))) assert.FatalError(t, err) a, err := NewEmbedded(WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))) assert.FatalError(t, err) // GetTLSCertificate cert, err := a.GetTLSCertificate() assert.FatalError(t, err) assert.Equals(t, []string{"localhost"}, cert.Leaf.DNSNames) assert.True(t, cert.Leaf.IPAddresses[0].Equal(net.ParseIP("127.0.0.1"))) assert.True(t, cert.Leaf.IPAddresses[1].Equal(net.ParseIP("::1"))) } func TestAuthority_CloseForReload(t *testing.T) { tests := []struct { name string auth *Authority }{ {"ok", testAuthority(t)}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.auth.CloseForReload() }) } } func testScepAuthority(t *testing.T, opts ...Option) *Authority { p := provisioner.List{ &provisioner.SCEP{ Name: "scep1", Type: "SCEP", }, } c := &Config{ Address: "127.0.0.1:8443", InsecureAddress: "127.0.0.1:8080", Root: []string{"testdata/scep/root.crt"}, IntermediateCert: "testdata/scep/intermediate.crt", IntermediateKey: "testdata/scep/intermediate.key", DNSNames: []string{"example.com"}, Password: "pass", AuthorityConfig: &AuthConfig{ Provisioners: p, }, } a, err := New(c, opts...) 
assert.FatalError(t, err) return a } func TestAuthority_GetSCEP(t *testing.T) { _ = testScepAuthority(t) p := provisioner.List{ &provisioner.SCEP{ Name: "scep1", Type: "SCEP", }, } type fields struct { config *Config } tests := []struct { name string fields fields wantService bool wantErr bool }{ { name: "ok", fields: fields{ config: &Config{ Address: "127.0.0.1:8443", InsecureAddress: "127.0.0.1:8080", Root: []string{"testdata/scep/root.crt"}, IntermediateCert: "testdata/scep/intermediate.crt", IntermediateKey: "testdata/scep/intermediate.key", DNSNames: []string{"example.com"}, Password: "pass", AuthorityConfig: &AuthConfig{ Provisioners: p, }, }, }, wantService: true, wantErr: false, }, { name: "wrong password", fields: fields{ config: &Config{ Address: "127.0.0.1:8443", InsecureAddress: "127.0.0.1:8080", Root: []string{"testdata/scep/root.crt"}, IntermediateCert: "testdata/scep/intermediate.crt", IntermediateKey: "testdata/scep/intermediate.key", DNSNames: []string{"example.com"}, Password: "wrongpass", AuthorityConfig: &AuthConfig{ Provisioners: p, }, }, }, wantService: false, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a, err := New(tt.fields.config) if (err != nil) != tt.wantErr { t.Errorf("Authority.New(), error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantService { if got := a.GetSCEP(); (got != nil) != tt.wantService { t.Errorf("Authority.GetSCEPService() = %v, wantService %v", got, tt.wantService) } } }) } } func TestAuthority_GetID(t *testing.T) { type fields struct { authorityID string } tests := []struct { name string fields fields want string }{ {"ok", fields{""}, "00000000-0000-0000-0000-000000000000"}, {"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ AuthorityID: tt.fields.authorityID, }, }, } if got := 
a.GetID(); got != tt.want { t.Errorf("Authority.GetID() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/authorize.go ================================================ package authority import ( "context" "crypto/sha256" "crypto/x509" "encoding/hex" "fmt" "net/http" "net/url" "strconv" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) // Claims extends jose.Claims with step attributes. type Claims struct { jose.Claims SANs []string `json:"sans,omitempty"` Email string `json:"email,omitempty"` Nonce string `json:"nonce,omitempty"` } type skipTokenReuseKey struct{} // NewContextWithSkipTokenReuse creates a new context from ctx and attaches a // value to skip the token reuse. func NewContextWithSkipTokenReuse(ctx context.Context) context.Context { return context.WithValue(ctx, skipTokenReuseKey{}, true) } // SkipTokenReuseFromContext returns if the token reuse needs to be ignored. func SkipTokenReuseFromContext(ctx context.Context) bool { m, _ := ctx.Value(skipTokenReuseKey{}).(bool) return m } // getProvisionerFromToken extracts a provisioner from the given token without // doing any token validation. func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) { tok, err := jose.ParseSigned(token) if err != nil { return nil, nil, fmt.Errorf("error parsing token: %w", err) } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. 
var claims Claims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, nil, fmt.Errorf("error unmarshaling token: %w", err) } // This method will also validate the audiences for JWK provisioners. p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) if !ok { return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) } // If the provisioner is disabled, send an appropriate message to the client if _, ok := p.(provisioner.Uninitialized); ok { return nil, nil, errs.New(http.StatusUnauthorized, "provisioner %q is disabled due to an initialization error", p.GetName()) } return p, &claims, nil } // authorizeToken parses the token and returns the provisioner used to generate // the token. This method enforces the One-Time use policy (tokens can only be // used once). func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { p, claims, err := a.getProvisionerFromToken(token) if err != nil { return nil, errs.UnauthorizedErr(err) } // TODO: use new persistence layer abstraction. // Do not accept tokens issued before the start of the ca. // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority") } } // Store the token to protect against reuse unless it's skipped. // If we cannot get a token id from the provisioner, just hash the token. if !SkipTokenReuseFromContext(ctx) { if err := a.UseToken(ctx, token, p); err != nil { return nil, err } } return p, nil } // AuthorizeAdminToken authorize an Admin token. 
func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c token") } verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{ Roots: a.rootX509CertPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, }) if err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error verifying x5c certificate chain in token") } leaf := verifiedChains[0][0] if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { return nil, admin.NewError(admin.ErrorUnauthorizedType, "adminHandler.authorizeToken; certificate used to sign x5c token cannot be used for digital signature") } // Using the leaf certificates key to validate the claims accomplishes two // things: // 1. Asserts that the private key used to sign the token corresponds // to the public certificate in the `x5c` header of the token. // 2. Asserts that the claims are valid - have not been tampered with. var claims jose.Claims if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims") } prov, err := a.LoadProvisionerByCertificate(leaf) if err != nil { return nil, err } // Check that the token has not been used. if err := a.UseToken(r.Context(), token, prov); err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token") } // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. 
if err := claims.ValidateWithLeeway(jose.Expected{ Time: time.Now().UTC(), }, time.Minute); err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims") } // validate audience: path matches the current path if !matchesAudience(claims.Audience, a.config.Audience(r.URL.Path)) { return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid audience claim (aud)") } // validate issuer: old versions used the provisioner name, new version uses // 'step-admin-client/1.0' if claims.Issuer != "step-admin-client/1.0" && claims.Issuer != prov.GetName() { return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid issuer claim (iss)") } if claims.Subject == "" { return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token subject cannot be empty") } var ( ok bool adm *linkedca.Admin ) adminFound := false adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...) adminSANs = append(adminSANs, leaf.EmailAddresses...) for _, san := range adminSANs { if adm, ok = a.LoadAdminBySubProv(san, prov.GetName()); ok { adminFound = true break } } if !adminFound { return nil, admin.NewError(admin.ErrorUnauthorizedType, "adminHandler.authorizeToken; unable to load admin with subject(s) %s and provisioner '%s'", adminSANs, prov.GetName()) } if strings.HasPrefix(r.URL.Path, "/admin/admins") && (r.Method != "GET") && adm.Type != linkedca.Admin_SUPER_ADMIN { return nil, admin.NewError(admin.ErrorUnauthorizedType, "must have super admin access to make this request") } return adm, nil } // UseToken stores the token to protect against reuse. // // This method currently ignores most errors coming from the GetTokenID because // the token is already validated. 
But it should specifically ignore the errors // provisioner.ErrAllowTokenReuse, provisioner.ErrNotImplemented, and // provisioner.ErrTokenFlowNotSupported unless this latter one used in a renewal // flow without mTLS. func (a *Authority) UseToken(ctx context.Context, token string, prov provisioner.Interface) error { reuseKey, err := prov.GetTokenID(token) if err != nil { // Fail on ErrTokenFlowNotSupported but allow x5cInsecure renew token if errors.Is(err, provisioner.ErrTokenFlowNotSupported) && provisioner.RenewMethod != provisioner.MethodFromContext(ctx) { return errs.BadRequest("token flow is not supported") } return nil } if reuseKey == "" { sum := sha256.Sum256([]byte(token)) reuseKey = strings.ToLower(hex.EncodeToString(sum[:])) } ok, err := a.db.UseToken(reuseKey, token) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token") } if !ok { return errs.Unauthorized("token already used") } return nil } // Authorize grabs the method from the context and authorizes the request by // validating the one-time-token. func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.SignOption, error) { var opts = []interface{}{errs.WithKeyVal("token", token)} switch m := provisioner.MethodFromContext(ctx); m { case provisioner.SignMethod, provisioner.SignIdentityMethod: signOpts, err := a.authorizeSign(ctx, token) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) case provisioner.RevokeMethod: return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeRevoke(ctx, token), "authority.Authorize", opts...) case provisioner.SSHSignMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...) } signOpts, err := a.authorizeSSHSign(ctx, token) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) 
case provisioner.SSHRenewMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...) } _, err := a.authorizeSSHRenew(ctx, token) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) case provisioner.SSHRevokeMethod: return nil, errs.Wrap(http.StatusInternalServerError, a.authorizeSSHRevoke(ctx, token), "authority.Authorize", opts...) case provisioner.SSHRekeyMethod: if a.sshCAHostCertSignKey == nil && a.sshCAUserCertSignKey == nil { return nil, errs.NotImplemented("authority.Authorize; ssh certificate flows are not enabled", opts...) } _, signOpts, err := a.authorizeSSHRekey(ctx, token) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) default: return nil, errs.InternalServer("authority.Authorize; method %d is not supported", append([]interface{}{m}, opts...)...) } } // authorizeSign loads the provisioner from the token and calls the provisioner // AuthorizeSign method. Returns a list of methods to apply to the signing flow. func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { p, err := a.authorizeToken(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign") } signOpts, err := p.AuthorizeSign(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSign") } return signOpts, nil } // AuthorizeSign authorizes a signature request by validating and authenticating // a token that must be sent w/ the request. // // Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error). 
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { ctx := NewContext(context.Background(), a) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) return a.Authorize(ctx, token) } // authorizeRevoke locates the provisioner used to generate the authenticating // token and then performs the token validation flow. func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { p, err := a.authorizeToken(ctx, token) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") } if err := p.AuthorizeRevoke(ctx, token); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") } return nil } // authorizeRenew locates the provisioner (using the provisioner extension in the cert), and checks // if for the configured provisioner, the renewal is enabled or not. If the // extra extension cannot be found, authorize the renewal by default. // // TODO(mariano): should we authorize by default? func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) (provisioner.Interface, error) { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} isRevoked, err := a.IsRevoked(serial) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } if isRevoked { return nil, errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } p, err := a.LoadProvisionerByCertificate(cert) if err != nil { var ok bool // For backward compatibility this method will also succeed if the // certificate does not have a provisioner extension. LoadByCertificate // returns the noop provisioner if this happens, and it allows // certificate renewals. if p, ok = a.provisioners.LoadByCertificate(cert); !ok { return nil, errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) 
} } if err := p.AuthorizeRenew(ctx, cert); err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } return p, nil } // authorizeSSHCertificate returns an error if the given certificate is revoked. func (a *Authority) authorizeSSHCertificate(_ context.Context, cert *ssh.Certificate) error { var err error var isRevoked bool serial := strconv.FormatUint(cert.Serial, 10) if lca, ok := a.adminDB.(interface { IsSSHRevoked(string) (bool, error) }); ok { isRevoked, err = lca.IsSSHRevoked(serial) } else { isRevoked, err = a.db.IsSSHRevoked(serial) } if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHCertificate", errs.WithKeyVal("serialNumber", serial)) } if isRevoked { return errs.Unauthorized("authority.authorizeSSHCertificate: certificate has been revoked", errs.WithKeyVal("serialNumber", serial)) } return nil } // authorizeSSHSign loads the provisioner from the token, checks that it has not // been used again and calls the provisioner AuthorizeSSHSign method. Returns a // list of methods to apply to the signing flow. func (a *Authority) authorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { p, err := a.authorizeToken(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign") } signOpts, err := p.AuthorizeSSHSign(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeSSHSign") } return signOpts, nil } // authorizeSSHRenew authorizes an SSH certificate renewal request, by // validating the contents of an SSHPOP token. 
func (a *Authority) authorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { p, err := a.authorizeToken(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew") } cert, err := p.AuthorizeSSHRenew(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRenew") } return cert, nil } // authorizeSSHRekey authorizes an SSH certificate rekey request, by // validating the contents of an SSHPOP token. func (a *Authority) authorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) { p, err := a.authorizeToken(ctx, token) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey") } cert, signOpts, err := p.AuthorizeSSHRekey(ctx, token) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRekey") } return cert, signOpts, nil } // authorizeSSHRevoke authorizes an SSH certificate revoke request, by // validating the contents of an SSHPOP token. func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error { p, err := a.authorizeToken(ctx, token) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke") } if err = p.AuthorizeSSHRevoke(ctx, token); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHRevoke") } return nil } // AuthorizeRenewToken validates the renew token and returns the leaf // certificate in the x5cInsecure header. 
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { var claims jose.Claims jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs) if err != nil { return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) } leaf := chain[0][0] if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token")) } p, err := a.LoadProvisionerByCertificate(leaf) if err != nil { return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") } if err := a.UseToken(ctx, ott, p); err != nil { return nil, err } if err := claims.ValidateWithLeeway(jose.Expected{ Subject: leaf.Subject.CommonName, Time: time.Now().UTC(), }, time.Minute); err != nil { switch { case errors.Is(err, jose.ErrInvalidIssuer): return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)")) case errors.Is(err, jose.ErrInvalidSubject): return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)")) case errors.Is(err, jose.ErrNotValidYet): return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)")) case errors.Is(err, jose.ErrExpired): return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)")) case errors.Is(err, jose.ErrIssuedInTheFuture): return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)")) default: return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) } } audiences := a.config.GetAudiences().Renew if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) { return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim 
(aud)")) } // validate issuer: old versions used the provisioner name, new version uses // 'step-ca-client/1.0' if claims.Issuer != "step-ca-client/1.0" && claims.Issuer != p.GetName() { return nil, admin.NewError(admin.ErrorUnauthorizedType, "error validating renew token: invalid issuer claim (iss)") } return leaf, nil } // matchesAudience returns true if A and B share at least one element. func matchesAudience(as, bs []string) bool { if len(bs) == 0 || len(as) == 0 { return false } for _, b := range bs { for _, a := range as { if b == a || stripPort(a) == stripPort(b) { return true } } } return false } // stripPort attempts to strip the port from the given url. If parsing the url // produces errors it will just return the passed argument. func stripPort(rawurl string) string { u, err := url.Parse(rawurl) if err != nil { return rawurl } u.Host = u.Hostname() return u.String() } ================================================ FILE: authority/authorize_test.go ================================================ package authority import ( "context" "crypto" "crypto/ed25519" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/base64" "errors" "fmt" "net/http" "reflect" "strconv" "testing" "time" "golang.org/x/crypto/ssh" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" "go.step.sm/crypto/x509util" "github.com/google/uuid" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" ) var testAudiences = provisioner.Audiences{ Sign: []string{"https://example.com/1.0/sign", "https://example.com/sign"}, Revoke: []string{"https://example.com/1.0/revoke", "https://example.com/revoke"}, SSHSign: []string{"https://example.com/1.0/ssh/sign"}, SSHRevoke: []string{"https://example.com/1.0/ssh/revoke"}, SSHRenew: []string{"https://example.com/1.0/ssh/renew"}, SSHRekey: 
[]string{"https://example.com/1.0/ssh/rekey"}, } type tokOption func(*jose.SignerOptions) error func withSSHPOPFile(cert *ssh.Certificate) tokOption { return func(so *jose.SignerOptions) error { so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal())) return nil } } func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) for _, o := range tokOpts { if err := o(so); err != nil { return "", err } } sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } id, err := randutil.ASCII(64) if err != nil { return "", err } claims := struct { jose.Claims SANS []string `json:"sans"` }{ Claims: jose.Claims{ ID: id, Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, }, SANS: sans, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) for k, v := range extraHeaders { so.WithHeader(jose.HeaderKey(k), v) } sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } id, err := randutil.ASCII(64) if err != nil { return "", err } iat := time.Now() claims := jose.Claims{ ID: id, Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, } return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize() } func TestAuthority_authorizeToken(t *testing.T) { a := testAuthority(t) jwk, err := 
jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	// Signer that sets the "kid" header to the id of the key read above.
	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	now := time.Now().UTC()

	validIssuer := "step-cli"
	validAudience := []string{"https://example.com/revoke"}

	// authorizeTest describes one case: the authority to call, the token to
	// authorize, and the expected error prefix and status code (nil/zero when
	// the call must succeed).
	type authorizeTest struct {
		auth  *Authority
		token string
		err   error
		code  int
	}
	// Each case is a generator so that tokens and per-case authorities are
	// only built when the subtest runs.
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("error parsing token"),
				code:  http.StatusUnauthorized,
			}
		},
		// IssuedAt is one hour in the past.
		"fail/prehistoric-token": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				IssuedAt:  jose.NewNumericDate(now.Add(-time.Hour)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("token issued before the bootstrap of certificate authority"),
				code:  http.StatusUnauthorized,
			}
		},
		// Signed with a "kid" header of "foo" instead of the key id.
		"fail/provisioner-not-found": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "44",
			}

			_sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
				(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "foo"))
			assert.FatalError(t, err)

			raw, err := jose.Signed(_sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("provisioner not found or invalid audience (https://example.com/revoke)"),
				code:  http.StatusUnauthorized,
			}
		},
		// Uses the audience "acme/acme" instead of validAudience.
		"fail/token-flow-not-supported": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				IssuedAt:  jose.NewNumericDate(now),
				Audience:  []string{"acme/acme"},
				ID:        "45",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("token flow is not supported"),
				code:  http.StatusBadRequest,
			}
		},
		"ok/simpledb": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
			}
		},
		// The token is authorized once on a fresh authority while the case is
		// built, so the call under test is its second use.
		"fail/simpledb/token-already-used": func(t *testing.T) *authorizeTest {
			_a := testAuthority(t)
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			_, err = _a.authorizeToken(context.Background(), raw)
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  _a,
				token: raw,
				err:   errors.New("token already used"),
				code:  http.StatusUnauthorized,
			}
		},
		// The claims carry no ID.
		"ok/sha256": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
			}
		},
		// Same as above, but the token is used once while the case is built.
		"fail/sha256/token-already-used": func(t *testing.T) *authorizeTest {
			_a := testAuthority(t)
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			_, err = _a.authorizeToken(context.Background(), raw)
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  _a,
				token: raw,
				err:   errors.New("token already used"),
				code:  http.StatusUnauthorized,
			}
		},
		// The mock database accepts the token.
		"ok/mockNoSQLDB": func(t *testing.T) *authorizeTest {
			_a := testAuthority(t)
			_a.db = &db.MockAuthDB{
				MUseToken: func(id, tok string) (bool, error) {
					return true, nil
				},
			}

			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  _a,
				token: raw,
			}
		},
		// The mock database returns an error.
		"fail/mockNoSQLDB/error": func(t *testing.T) *authorizeTest {
			_a := testAuthority(t)
			_a.db = &db.MockAuthDB{
				MUseToken: func(id, tok string) (bool, error) {
					return false, errors.New("force")
				},
			}

			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  _a,
				token: raw,
				err:   errors.New("failed when attempting to store token: force"),
				code:  http.StatusInternalServerError,
			}
		},
		// The mock database reports the token as not stored, without error.
		"fail/mockNoSQLDB/token-already-used": func(t *testing.T) *authorizeTest {
			_a := testAuthority(t)
			_a.db = &db.MockAuthDB{
				MUseToken: func(id, tok string) (bool, error) {
					return false, nil
				},
			}

			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  _a,
				token: raw,
				err:   errors.New("token already used"),
				code:  http.StatusUnauthorized,
			}
		},
		// Uses the issuer "uninitialized" and a fresh UUID as token id.
		"fail/uninitialized": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    "uninitialized",
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        uuid.NewString(),
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New(`provisioner "uninitialized" is disabled due to an initialization error`),
				code:  http.StatusUnauthorized,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			p, err := tc.auth.authorizeToken(context.Background(), tc.token)
			if err != nil {
				// On failure the error must carry the expected status code
				// and start with the expected message.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, p.GetID(), "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
				}
			}
		})
	}
}

func TestAuthority_authorizeRevoke(t *testing.T) {
	a := testAuthority(t)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	now := time.Now().UTC()

	validIssuer := "step-cli"
	validAudience := []string{"https://example.com/revoke"}

	type authorizeTest struct {
		auth  *Authority
		token string
		err   error
		code  int
	}
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/token/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("authority.authorizeRevoke: error parsing token"),
				code:
http.StatusUnauthorized,
			}
		},
		// The subject claim is empty.
		"fail/token/invalid-subject": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("authority.authorizeRevoke: jwk.AuthorizeRevoke: jwk.authorizeToken; jwk token subject cannot be empty"),
				code:  http.StatusUnauthorized,
			}
		},
		"ok/token": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "44",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
				// On failure the error must carry the expected status code
				// and start with the expected message.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestAuthority_authorizeSign exercises authorizeSign with an unparsable
// token, a token with an empty subject, and a valid token.
func TestAuthority_authorizeSign(t *testing.T) {
	a := testAuthority(t)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	now := time.Now().UTC()

	validIssuer := "step-cli"
	validAudience := []string{"https://example.com/sign"}

	// authorizeTest describes one case: the authority to call, the token to
	// authorize, and the expected error prefix and status code (nil/zero when
	// the call must succeed).
	type authorizeTest struct {
		auth  *Authority
		token string
		err   error
		code  int
	}
	tests := map[string]func(t *testing.T)
*authorizeTest{
		"fail/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("authority.authorizeSign: error parsing token"),
				code:  http.StatusUnauthorized,
			}
		},
		// The subject claim is empty.
		"fail/invalid-subject": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("authority.authorizeSign: jwk.AuthorizeSign: jwk.authorizeToken; jwk token subject cannot be empty"),
				code:  http.StatusUnauthorized,
			}
		},
		"ok": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "test.smallstep.com",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "44",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			got, err := tc.auth.authorizeSign(context.Background(), tc.token)
			if err != nil {
				// On failure the error must carry the expected status code
				// and start with the expected message.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, 11, len(got)) // number of provisioner.SignOptions returned
				}
			}
		})
	}
}

func TestAuthority_Authorize(t *testing.T) {
	a := testAuthority(t)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid",
jwk.KeyID)) assert.FatalError(t, err) now := time.Now().UTC() validIssuer := "step-cli" type authorizeTest struct { auth *Authority token string ctx context.Context err error code int } tests := map[string]func(t *testing.T) *authorizeTest{ "default-to-signMethod": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: context.Background(), err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, "fail/sign/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, "ok/sign": func(t *testing.T) *authorizeTest { cl := jose.Claims{ Subject: "test.smallstep.com", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: testAudiences.Sign, ID: "1", } token, err := jose.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ auth: a, token: token, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), } }, "fail/revoke/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, "ok/revoke": func(t *testing.T) *authorizeTest { cl := jose.Claims{ Subject: "test.smallstep.com", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: testAudiences.Revoke, ID: "2", } token, err := jose.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ auth: a, token: token, ctx: 
provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), } }, "fail/sshSign/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, "fail/sshSign/disabled": func(t *testing.T) *authorizeTest { _a := testAuthority(t) _a.sshCAHostCertSignKey = nil _a.sshCAUserCertSignKey = nil return &authorizeTest{ auth: _a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), err: errors.New("authority.Authorize; ssh certificate flows are not enabled"), code: http.StatusNotImplemented, } }, "ok/sshSign": func(t *testing.T) *authorizeTest { raw, err := generateSimpleSSHUserToken(validIssuer, testAudiences.SSHSign[0], jwk) assert.FatalError(t, err) return &authorizeTest{ auth: a, token: raw, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), } }, "fail/sshRenew/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, "fail/sshRenew/disabled": func(t *testing.T) *authorizeTest { _a := testAuthority(t) _a.sshCAHostCertSignKey = nil _a.sshCAUserCertSignKey = nil return &authorizeTest{ auth: _a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), err: errors.New("authority.Authorize; ssh certificate flows are not enabled"), code: http.StatusNotImplemented, } }, "ok/sshRenew": func(t *testing.T) *authorizeTest { key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") assert.FatalError(t, err) signer, ok := 
key.(crypto.Signer) assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") sshSigner, err := ssh.NewSignerFromSigner(signer) assert.FatalError(t, err) cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) assert.FatalError(t, err) p, ok := a.provisioners.Load("sshpop/sshpop") assert.Fatal(t, ok, "sshpop provisioner not found in test authority") tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) assert.FatalError(t, err) return &authorizeTest{ auth: a, token: tok, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), } }, "fail/sshRevoke/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, "ok/sshRevoke": func(t *testing.T) *authorizeTest { cl := jose.Claims{ Subject: "test.smallstep.com", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: testAudiences.SSHRevoke, ID: "3", } token, err := jose.Signed(sig).Claims(cl).CompactSerialize() assert.FatalError(t, err) return &authorizeTest{ auth: a, token: token, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), } }, "fail/sshRekey/invalid-token": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, "fail/sshRekey/disabled": func(t *testing.T) *authorizeTest { _a := testAuthority(t) _a.sshCAHostCertSignKey = nil _a.sshCAUserCertSignKey = 
nil return &authorizeTest{ auth: _a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), err: errors.New("authority.Authorize; ssh certificate flows are not enabled"), code: http.StatusNotImplemented, } }, "ok/sshRekey": func(t *testing.T) *authorizeTest { key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") assert.FatalError(t, err) signer, ok := key.(crypto.Signer) assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") sshSigner, err := ssh.NewSignerFromSigner(signer) assert.FatalError(t, err) cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) assert.FatalError(t, err) p, ok := a.provisioners.Load("sshpop/sshpop") assert.Fatal(t, ok, "sshpop provisioner not found in test authority") tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) assert.FatalError(t, err) return &authorizeTest{ auth: a, token: tok, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), } }, "fail/unexpected-method": func(t *testing.T) *authorizeTest { return &authorizeTest{ auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), 15), err: errors.New("authority.Authorize; method 15 is not supported"), code: http.StatusInternalServerError, } }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) got, err := tc.auth.Authorize(tc.ctx, tc.token) if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, got) var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") assert.Equals(t, 
ctxErr.Details["token"], tc.token)
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestAuthority_authorizeRenew drives Authority.authorizeRenew with
// certificates loaded from testdata and a mocked database, checking the status
// code, the error prefix and the "serialNumber" detail attached to the
// returned error.
func TestAuthority_authorizeRenew(t *testing.T) {
	fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
	// The read error must be checked before fooCrt is touched; otherwise a
	// failed read dereferences a nil certificate instead of reporting the
	// actual error.
	assert.FatalError(t, err)
	fooCrt.NotAfter = time.Now().Add(time.Hour)

	renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
	assert.FatalError(t, err)

	otherCrt, err := pemutil.ReadCertificate("testdata/certs/provisioner-not-found.crt")
	assert.FatalError(t, err)

	type authorizeTest struct {
		auth *Authority
		cert *x509.Certificate
		err  error
		code int
	}
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/db.IsRevoked-error": func(t *testing.T) *authorizeTest {
			a := testAuthority(t)
			a.db = &db.MockAuthDB{
				MIsRevoked: func(key string) (bool, error) {
					return false, errors.New("force")
				},
			}

			return &authorizeTest{
				auth: a,
				cert: fooCrt,
				err:  errors.New("authority.authorizeRenew: force"),
				code: http.StatusInternalServerError,
			}
		},
		"fail/revoked": func(t *testing.T) *authorizeTest {
			a := testAuthority(t)
			a.db = &db.MockAuthDB{
				MIsRevoked: func(key string) (bool, error) {
					return true, nil
				},
			}
			return &authorizeTest{
				auth: a,
				cert: fooCrt,
				err:  errors.New("authority.authorizeRenew: certificate has been revoked"),
				code: http.StatusUnauthorized,
			}
		},
		"fail/load-provisioner": func(t *testing.T) *authorizeTest {
			a := testAuthority(t)
			a.db = &db.MockAuthDB{
				MIsRevoked: func(key string) (bool, error) {
					return false, nil
				},
			}
			return &authorizeTest{
				auth: a,
				cert: otherCrt,
				err:  errors.New("authority.authorizeRenew: provisioner not found"),
				code: http.StatusUnauthorized,
			}
		},
		"fail/provisioner-authorize-renewal-fail": func(t *testing.T) *authorizeTest {
			a := testAuthority(t)
			a.db = &db.MockAuthDB{
				MIsRevoked: func(key string) (bool, error) {
					return false, nil
				},
			}
			return &authorizeTest{
				auth: a,
				cert: renewDisabledCrt,
				err:  errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
				code: http.StatusUnauthorized,
			}
		},
		"ok": func(t *testing.T) *authorizeTest {
			a := testAuthority(t)
			a.db = &db.MockAuthDB{
				MIsRevoked: func(key string) (bool, error) {
					return false, nil
				},
			}
			return &authorizeTest{
				auth: a,
				cert: fooCrt,
			}
		},
		"ok/from db": func(t *testing.T) *authorizeTest {
			a := testAuthority(t)
			a.db = &db.MockAuthDB{
				MIsRevoked: func(key string) (bool, error) {
					return false, nil
				},
				// Certificate data is served from the mocked database and
				// points at the step-cli provisioner.
				MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
					p, ok := a.provisioners.LoadByName("step-cli")
					if !ok {
						t.Fatal("provisioner step-cli not found")
					}
					return &db.CertificateData{
						Provisioner: &db.ProvisionerData{
							ID: p.GetID(),
						},
					}, nil
				},
			}
			return &authorizeTest{
				auth: a,
				cert: fooCrt,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			_, err := tc.auth.authorizeRenew(context.Background(), tc.cert)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())

					var ctxErr *errs.Error
					assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error")
					assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// generateSimpleSSHUserToken returns a token for subject "subject@localhost"
// requesting a "user" SSH certificate with the single principal "name",
// issued at the current time and signed with jwk.
func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
	return generateSSHToken("subject@localhost", iss, aud, time.Now(), &provisioner.SignSSHOptions{
		CertType:   "user",
		Principals: []string{"name"},
	}, jwk)
}

// stepPayload is the value serialized under the "step" claim of the tokens
// built by generateSSHToken.
type stepPayload struct {
	SSH *provisioner.SignSSHOptions `json:"ssh,omitempty"`
}

// generateSSHToken builds and serializes an ES256-signed JWT carrying sshOpts
// under the "step" claim. The token uses a random 64-character ID, is issued
// and valid from iat, and expires five minutes after iat.
func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.SignSSHOptions, jwk *jose.JSONWebKey) (string, error) {
	sig, err := jose.NewSigner(
		jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
	)
	if err != nil {
		return "", err
	}

	id, err := randutil.ASCII(64)
	if err != nil {
		return "", err
	}

	claims := struct {
		jose.Claims
		Step *stepPayload `json:"step,omitempty"`
	}{
		Claims: jose.Claims{
			ID:        id,
			Subject:   sub,
			Issuer:    iss,
			IssuedAt:  jose.NewNumericDate(iat),
			NotBefore: jose.NewNumericDate(iat),
			Expiry:    jose.NewNumericDate(iat.Add(5 * time.Minute)),
			Audience:  []string{aud},
		},
		Step: &stepPayload{
			SSH: sshOpts,
		},
	}
	return jose.Signed(sig).Claims(claims).CompactSerialize()
}

// createSSHCert generates a new P-256 JWK, sets its public key on cert,
// defaults an unset validity window to [now, now+1h] and signs cert with
// signer. It returns the signed certificate together with the generated JWK.
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
	now := time.Now()
	jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
	if err != nil {
		return nil, nil, err
	}
	cert.Key, err = ssh.NewPublicKey(jwk.Public().Key)
	if err != nil {
		return nil, nil, err
	}
	// Only fill in validity bounds the caller left at zero.
	if cert.ValidAfter == 0 {
		cert.ValidAfter = uint64(now.Unix())
	}
	if cert.ValidBefore == 0 {
		cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
	}
	if err := cert.SignCert(rand.Reader, signer); err != nil {
		return nil, nil, err
	}
	return cert, jwk, nil
}

func TestAuthority_authorizeSSHSign(t *testing.T) {
	a := testAuthority(t)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	now := time.Now().UTC()

	validIssuer := "step-cli"
	validAudience := []string{"https://example.com/ssh/sign"}

	type authorizeTest struct {
		auth  *Authority
		token string
		err   error
		code  int
	}
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("authority.authorizeSSHSign: error parsing token"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/invalid-subject": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  validAudience,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("authority.authorizeSSHSign: jwk.AuthorizeSSHSign: jwk.authorizeToken; jwk token subject cannot be empty"),
				code:  http.StatusUnauthorized,
			}
		},
		"ok": func(t *testing.T) *authorizeTest {
			raw, err := generateSimpleSSHUserToken(validIssuer, validAudience[0], jwk)
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Len(t, 10, got) // number of provisioner.SignOptions returned
				}
			}
		})
	}
}

// TestAuthority_authorizeSSHRenew drives Authority.authorizeSSHRenew with an
// unparsable token, a JWK-provisioner token, and sshpop tokens against
// authorities with and without a custom WithAuthorizeSSHRenewFunc. On success
// it compares the serial of the returned certificate with the one embedded in
// the token.
func TestAuthority_authorizeSSHRenew(t *testing.T) {
	now := time.Now().UTC()

	// sshpop signs a fresh SSH host certificate and returns it together with a
	// token for the "sshpop/sshpop" provisioner of authority a. It takes the
	// *testing.T of the caller so that fatal assertions are reported on the
	// subtest that is actually running rather than on the parent test.
	sshpop := func(t *testing.T, a *Authority) (*ssh.Certificate, string) {
		p, ok := a.provisioners.Load("sshpop/sshpop")
		assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
		key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
		assert.FatalError(t, err)
		signer, ok := key.(crypto.Signer)
		assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
		sshSigner, err := ssh.NewSignerFromSigner(signer)
		assert.FatalError(t, err)
		cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
		assert.FatalError(t, err)
		token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
		assert.FatalError(t, err)
		return cert, token
	}

	a := testAuthority(t)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	validIssuer := "step-cli"

	type authorizeTest struct {
		auth  *Authority
		token string
		cert  *ssh.Certificate
		err   error
		code  int
	}
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("authority.authorizeSSHRenew: error parsing token"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/sshRenew-unimplemented-jwk-provisioner": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  testAudiences.SSHRenew,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("authority.authorizeSSHRenew: provisioner.AuthorizeSSHRenew not implemented"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
			aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
				return errs.Forbidden("forbidden")
			}))
			_, token := sshpop(t, aa)
			return &authorizeTest{
				auth:  aa,
				token: token,
				err:   errors.New("authority.authorizeSSHRenew: forbidden"),
				code:  http.StatusForbidden,
			}
		},
		"ok": func(t *testing.T) *authorizeTest {
			cert, token := sshpop(t, a)
			return &authorizeTest{
				auth:  a,
				token: token,
				cert:  cert,
			}
		},
		"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
			aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
				return nil
			}))
			cert, token := sshpop(t, aa)
			return &authorizeTest{
				auth:  aa,
				token: token,
				cert:  cert,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, tc.cert.Serial, got.Serial)
				}
			}
		})
	}
}

// TestAuthority_authorizeSSHRevoke drives Authority.authorizeSSHRevoke against
// an authority backed by a mocked database that reports no revoked SSH
// certificates and accepts every token.
func TestAuthority_authorizeSSHRevoke(t *testing.T) {
	a := testAuthority(t, []Option{WithDatabase(&db.MockAuthDB{
		MIsSSHRevoked: func(serial string) (bool, error) {
			return false, nil
		},
		MUseToken: func(id, tok string) (bool, error) {
			return true, nil
		},
	})}...)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	now := time.Now().UTC()

	validIssuer := "step-cli"

	type authorizeTest struct {
		auth  *Authority
		token string
		cert  *ssh.Certificate
		err   error
		code  int
	}
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("authority.authorizeSSHRevoke: error parsing token"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/invalid-subject": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  testAudiences.SSHRevoke,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("authority.authorizeSSHRevoke: jwk.AuthorizeSSHRevoke: jwk.authorizeToken; jwk token subject cannot be empty"),
				code:  http.StatusUnauthorized,
			}
		},
		"ok": func(t *testing.T) *authorizeTest {
			key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
			assert.FatalError(t, err)
			signer, ok := key.(crypto.Signer)
			assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
			sshSigner, err := ssh.NewSignerFromSigner(signer)
			assert.FatalError(t, err)

			cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
			assert.FatalError(t, err)

			p, ok := a.provisioners.Load("sshpop/sshpop")
			assert.Fatal(t, ok, "sshpop provisioner not found in test authority")

			// The token subject is the decimal serial of the certificate.
			tok, err := generateToken(strconv.FormatUint(cert.Serial, 10), p.GetName(), testAudiences.SSHRevoke[0]+"#sshpop/sshpop",
				[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)

			return &authorizeTest{
				auth:  a,
				token: tok,
				cert:  cert,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

func TestAuthority_authorizeSSHRekey(t *testing.T) {
	a := testAuthority(t)

	jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
	assert.FatalError(t, err)

	now := time.Now().UTC()

	validIssuer := "step-cli"

	type authorizeTest struct {
		auth  *Authority
		token string
		cert  *ssh.Certificate
		err   error
		code  int
	}
	tests := map[string]func(t *testing.T) *authorizeTest{
		"fail/invalid-token": func(t *testing.T) *authorizeTest {
			return &authorizeTest{
				auth:  a,
				token: "foo",
				err:   errors.New("authority.authorizeSSHRekey: error parsing token"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/sshRekey-unimplemented-jwk-provisioner": func(t *testing.T) *authorizeTest {
			cl := jose.Claims{
				Subject:   "",
				Issuer:    validIssuer,
				NotBefore: jose.NewNumericDate(now),
				Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
				Audience:  testAudiences.SSHRekey,
				ID:        "43",
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			return &authorizeTest{
				auth:  a,
				token: raw,
				err:   errors.New("authority.authorizeSSHRekey: provisioner.AuthorizeSSHRekey not implemented"),
				code:  http.StatusUnauthorized,
			}
		},
		"ok": func(t *testing.T) *authorizeTest {
			key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
			assert.FatalError(t, err)
			signer, ok := key.(crypto.Signer)
			assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
			sshSigner, err := ssh.NewSignerFromSigner(signer)
			assert.FatalError(t, err)

			cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
			assert.FatalError(t, err)

			p, ok := a.provisioners.Load("sshpop/sshpop")
			assert.Fatal(t, ok, "sshpop provisioner not found in test authority")

			tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0]+"#sshpop/sshpop",
				[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)

			return &authorizeTest{
				auth:  a,
				token: tok,
				cert:  cert,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Serial, cert.Serial) assert.Len(t, 4, signOpts) } } }) } } func TestAuthority_AuthorizeRenewToken(t *testing.T) { ctx := context.Background() type stepProvisionerASN1 struct { Type int Name []byte CredentialID []byte KeyValuePairs []string `asn1:"optional,omitempty"` } _, signer, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer) if err != nil { t.Fatal(err) } _, otherSigner, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) if err != nil { t.Fatal(err) } var x5c []string for _, c := range chain { x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw)) } so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("x5cInsecure", x5c) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so) if err != nil { t.Fatal(err) } s, err := jose.Signed(sig).Claims(claims).CompactSerialize() if err != nil { t.Fatal(err) } return s, chain[0] } now := time.Now() a1 := testAuthority(t) t1, c1 := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) t2, c2 := 
generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), IssuedAt: jose.NewNumericDate(now), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now.Add(-time.Hour) cert.NotAfter = now.Add(-time.Minute) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) t3, c3 := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-cli", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) a4 := testAuthority(t) a4.db = &db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { return &db.CertificateData{ Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"}, RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, }, nil }, } t4, c4 := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://ra.example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: 
jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "test.example.com", Issuer: 
"bad-issuer", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badSubject, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/renew"}, Subject: "bad-subject", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: 
[]string{"https://example.com/1.0/sign"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) badAudience, _ := generateX5cToken(a1, signer, jose.Claims{ Audience: []string{"https://example.com/1.0/sign"}, Subject: "test.example.com", Issuer: "step-ca-client/1.0", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { cert.NotBefore = now cert.NotAfter = now.Add(time.Hour) b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) if err != nil { return err } cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ Id: asn1.ObjectIdentifier{1, 3, 
6, 1, 4, 1, 37476, 9000, 64, 1}, Value: b, }) return nil })) type args struct { ctx context.Context ott string } tests := []struct { name string authority *Authority args args want *x509.Certificate wantErr bool }{ {"ok", a1, args{ctx, t1}, c1, false}, {"ok expired cert", a1, args{ctx, t2}, c2, false}, {"ok provisioner issuer", a1, args{ctx, t3}, c3, false}, {"ok ra provisioner", a4, args{ctx, t4}, c4, false}, {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, {"fail token reuse", a1, args{ctx, t1}, nil, true}, {"fail token signature", a1, args{ctx, badSigner}, nil, true}, {"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true}, {"fail token iss", a1, args{ctx, badIssuer}, nil, true}, {"fail token sub", a1, args{ctx, badSubject}, nil, true}, {"fail token iat", a1, args{ctx, badNotBefore}, nil, true}, {"fail token iat", a1, args{ctx, badExpiry}, nil, true}, {"fail token iat", a1, args{ctx, badIssuedAt}, nil, true}, {"fail token aud", a1, args{ctx, badAudience}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott) if (err != nil) != tt.wantErr { t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/config/config.go ================================================ package config import ( "bytes" "encoding/json" "fmt" "net" "os" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" kms "go.step.sm/crypto/kms/apiv1" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" cas "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/templates" ) const ( legacyAuthority = "step-certificate-authority" ) var ( // 
DefaultBackdate length of time to backdate certificates to avoid // clock skew validation issues. DefaultBackdate = time.Minute // DefaultDisableRenewal disables renewals per provisioner. DefaultDisableRenewal = false // DefaultAllowRenewalAfterExpiry allows renewals even if the certificate is // expired. DefaultAllowRenewalAfterExpiry = false // DefaultEnableSSHCA enables SSH CA features per provisioner or globally // for all provisioners. DefaultEnableSSHCA = false // DefaultDisableSmallstepExtensions is the default value for the // DisableSmallstepExtensions provisioner claim. DefaultDisableSmallstepExtensions = false // DefaultCRLCacheDuration is the default cache duration for the CRL. DefaultCRLCacheDuration = &provisioner.Duration{Duration: 24 * time.Hour} // DefaultCRLExpiredDuration is the default duration in which expired // certificates will remain in the CRL after expiration. DefaultCRLExpiredDuration = time.Hour // GlobalProvisionerClaims holds the default provisioner claims: the minimum, // maximum and default durations for TLS and SSH certificates, plus the // default values of the boolean claims.
GlobalProvisionerClaims = provisioner.Claims{
		MinTLSDur:                  &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
		MaxTLSDur:                  &provisioner.Duration{Duration: 24 * time.Hour},
		DefaultTLSDur:              &provisioner.Duration{Duration: 24 * time.Hour},
		MinUserSSHDur:              &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
		MaxUserSSHDur:              &provisioner.Duration{Duration: 24 * time.Hour},
		DefaultUserSSHDur:          &provisioner.Duration{Duration: 16 * time.Hour},
		MinHostSSHDur:              &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
		MaxHostSSHDur:              &provisioner.Duration{Duration: 30 * 24 * time.Hour},
		DefaultHostSSHDur:          &provisioner.Duration{Duration: 30 * 24 * time.Hour},
		EnableSSHCA:                &DefaultEnableSSHCA,
		DisableRenewal:             &DefaultDisableRenewal,
		AllowRenewalAfterExpiry:    &DefaultAllowRenewalAfterExpiry,
		DisableSmallstepExtensions: &DefaultDisableSmallstepExtensions,
	}
)

// Config represents the CA configuration and it's mapped to a JSON object.
type Config struct {
	Root             multiString          `json:"root"`
	FederatedRoots   []string             `json:"federatedRoots"`
	IntermediateCert string               `json:"crt"`
	IntermediateKey  string               `json:"key"`
	Address          string               `json:"address"`
	InsecureAddress  string               `json:"insecureAddress"`
	DNSNames         []string             `json:"dnsNames"`
	KMS              *kms.Options         `json:"kms,omitempty"`
	SSH              *SSHConfig           `json:"ssh,omitempty"`
	Logger           json.RawMessage      `json:"logger,omitempty"`
	DB               *db.Config           `json:"db,omitempty"`
	Monitoring       json.RawMessage      `json:"monitoring,omitempty"`
	AuthorityConfig  *AuthConfig          `json:"authority,omitempty"`
	TLS              *TLSOptions          `json:"tls,omitempty"`
	Password         string               `json:"password,omitempty"`
	Templates        *templates.Templates `json:"templates,omitempty"`
	CommonName       string               `json:"commonName,omitempty"`
	CRL              *CRLConfig           `json:"crl,omitempty"`
	MetricsAddress   string               `json:"metricsAddress,omitempty"`
	// SkipValidation is never read from or written to JSON; when set,
	// Validate returns nil without checking anything.
	SkipValidation bool `json:"-"`

	// Keeps record of the filename the Config is read from
	loadedFromFilepath string
}

// CRLConfig represents config options for CRL generation
type CRLConfig struct {
	Enabled          bool                  `json:"enabled"`
	GenerateOnRevoke bool                  `json:"generateOnRevoke,omitempty"`
	CacheDuration    *provisioner.Duration `json:"cacheDuration,omitempty"`
	RenewPeriod      *provisioner.Duration `json:"renewPeriod,omitempty"`
	IDPurl           string                `json:"idpURL,omitempty"`
}

// IsEnabled returns if the CRL is enabled. It is safe to call on a nil
// receiver, in which case it returns false.
func (c *CRLConfig) IsEnabled() bool {
	return c != nil && c.Enabled
}

// Validate validates the CRL configuration. A nil configuration is valid.
// When set, cacheDuration and renewPeriod must not be negative, and
// renewPeriod must not exceed cacheDuration when both are present.
func (c *CRLConfig) Validate() error {
	if c == nil {
		return nil
	}

	if c.CacheDuration != nil && c.CacheDuration.Duration < 0 {
		return errors.New("crl.cacheDuration must be greater than or equal to 0")
	}

	if c.RenewPeriod != nil && c.RenewPeriod.Duration < 0 {
		return errors.New("crl.renewPeriod must be greater than or equal to 0")
	}

	if c.RenewPeriod != nil && c.CacheDuration != nil &&
		c.RenewPeriod.Duration > c.CacheDuration.Duration {
		return errors.New("crl.cacheDuration must be greater than or equal to crl.renewPeriod")
	}

	return nil
}

// TickerDuration returns the renewal ticker duration. This is set by
// renewPeriod, or if it is not set it is ~2/3 of cacheDuration. It returns 0
// when the CRL is not enabled.
func (c *CRLConfig) TickerDuration() time.Duration {
	if !c.IsEnabled() {
		return 0
	}

	if c.RenewPeriod != nil && c.RenewPeriod.Duration > 0 {
		return c.RenewPeriod.Duration
	}

	// Config.Init fills in the default cache duration for enabled CRLs, but a
	// CRLConfig that never went through Init may still have it unset. Use the
	// same default here instead of dereferencing a nil pointer.
	cache := c.CacheDuration
	if cache == nil {
		cache = DefaultCRLCacheDuration
	}

	return (cache.Duration / 3) * 2
}

// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer
// x509 Certificate blocks.
type ASN1DN struct {
	Country            string `json:"country,omitempty"`
	Organization       string `json:"organization,omitempty"`
	OrganizationalUnit string `json:"organizationalUnit,omitempty"`
	Locality           string `json:"locality,omitempty"`
	Province           string `json:"province,omitempty"`
	StreetAddress      string `json:"streetAddress,omitempty"`
	SerialNumber       string `json:"serialNumber,omitempty"`
	CommonName         string `json:"commonName,omitempty"`
}

// AuthConfig represents the configuration options for the authority. An
// underlying registration authority can also be configured using the
// cas.Options.
type AuthConfig struct { *cas.Options AuthorityID string `json:"authorityId,omitempty"` DeploymentType string `json:"deploymentType,omitempty"` Provisioners provisioner.List `json:"provisioners,omitempty"` Admins []*linkedca.Admin `json:"-"` Template *ASN1DN `json:"template,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"` Policy *policy.Options `json:"policy,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` Backdate *provisioner.Duration `json:"backdate,omitempty"` EnableAdmin bool `json:"enableAdmin,omitempty"` DisableGetSSHHosts bool `json:"disableGetSSHHosts,omitempty"` } // init initializes the required fields in the AuthConfig if they are not // provided. func (c *AuthConfig) init() { if c.Provisioners == nil { c.Provisioners = provisioner.List{} } if c.Template == nil { c.Template = &ASN1DN{} } if c.Backdate == nil { c.Backdate = &provisioner.Duration{ Duration: DefaultBackdate, } } } // Validate validates the authority configuration. func (c *AuthConfig) Validate(provisioner.Audiences) error { if c == nil { return errors.New("authority cannot be undefined") } // Initialize required fields. c.init() // Check that only one K8sSA is enabled var k8sCount int for _, p := range c.Provisioners { if p.GetType() == provisioner.TypeK8sSA { k8sCount++ } } if k8sCount > 1 { return errors.New("cannot have more than one kubernetes service account provisioner") } if c.Backdate.Duration < 0 { return errors.New("authority.backdate cannot be less than 0") } return nil } // LoadConfiguration parses the given filename in JSON format and returns the // configuration struct. 
func LoadConfiguration(filename string) (*Config, error) { f, err := os.Open(filename) if err != nil { return nil, errors.Wrapf(err, "error opening %s", filename) } defer f.Close() var c Config if err := json.NewDecoder(f).Decode(&c); err != nil { return nil, errors.Wrapf(err, "error parsing %s", filename) } // store filename that was read to populate Config c.loadedFromFilepath = filename // initialize the Config c.Init() return &c, nil } // Init initializes the minimal configuration required to create an authority. This // is mainly used on embedded authorities. func (c *Config) Init() { if c.DNSNames == nil { c.DNSNames = []string{"localhost", "127.0.0.1", "::1"} } if c.TLS == nil { c.TLS = &DefaultTLSOptions } if c.AuthorityConfig == nil { c.AuthorityConfig = &AuthConfig{} } if c.CommonName == "" { c.CommonName = "Step Online CA" } if c.CRL != nil && c.CRL.Enabled && c.CRL.CacheDuration == nil { c.CRL.CacheDuration = DefaultCRLCacheDuration } c.AuthorityConfig.init() } // Save saves the configuration to the given filename. func (c *Config) Save(filename string) error { var b bytes.Buffer enc := json.NewEncoder(&b) enc.SetIndent("", "\t") if err := enc.Encode(c); err != nil { //nolint:gosec // config struct contains password field by design return fmt.Errorf("error encoding configuration: %w", err) } if err := os.WriteFile(filename, b.Bytes(), 0600); err != nil { return fmt.Errorf("error writing %q: %w", filename, err) } return nil } // Commit saves the current configuration to the same // file it was initially loaded from. // // TODO(hs): rename Save() to WriteTo() and replace this // with Save()? Or is Commit clear enough. func (c *Config) Commit() error { if !c.WasLoadedFromFile() { return errors.New("cannot commit configuration if not loaded from file") } return c.Save(c.loadedFromFilepath) } // WasLoadedFromFile returns whether or not the Config was // loaded from a file. 
func (c *Config) WasLoadedFromFile() bool { return c.loadedFromFilepath != "" } // Filepath returns the path to the file the Config was // loaded from. func (c *Config) Filepath() string { return c.loadedFromFilepath } // Validate validates the configuration. func (c *Config) Validate() error { switch { case c.SkipValidation: return nil case c.Address == "": return errors.New("address cannot be empty") case len(c.DNSNames) == 0: return errors.New("dnsNames cannot be empty") case c.AuthorityConfig == nil: return errors.New("authority cannot be nil") } // Options holds the RA/CAS configuration. ra := c.AuthorityConfig.Options // The default RA/CAS requires root, crt and key. if ra.Is(cas.SoftCAS) { switch { case c.Root.HasEmpties(): return errors.New("root cannot be empty") case c.IntermediateCert == "": return errors.New("crt cannot be empty") case c.IntermediateKey == "": return errors.New("key cannot be empty") } } // Validate address (a port is required) if _, _, err := net.SplitHostPort(c.Address); err != nil { return errors.Errorf("invalid address %s", c.Address) } if addr := c.MetricsAddress; addr != "" { if _, _, err := net.SplitHostPort(addr); err != nil { return errors.Errorf("invalid metrics address %q", c.Address) } } if c.TLS == nil { c.TLS = &DefaultTLSOptions } else { if len(c.TLS.CipherSuites) == 0 { c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites } if c.TLS.MaxVersion == 0 { c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion } if c.TLS.MinVersion == 0 { c.TLS.MinVersion = DefaultTLSOptions.MinVersion } if c.TLS.MinVersion > c.TLS.MaxVersion { return errors.New("tls minVersion cannot exceed tls maxVersion") } c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation } // Validate KMS options, nil is ok. if err := c.KMS.Validate(); err != nil { return err } // Validate RA/CAS options, nil is ok. 
if err := ra.Validate(); err != nil { return err } // Validate ssh: nil is ok if err := c.SSH.Validate(); err != nil { return err } // Validate templates: nil is ok if err := c.Templates.Validate(); err != nil { return err } // Validate crl config: nil is ok if err := c.CRL.Validate(); err != nil { return err } return c.AuthorityConfig.Validate(c.GetAudiences()) } // GetAudiences returns the legacy and possible urls without the ports that will // be used as the default provisioner audiences. The CA might have proxies in // front so we cannot rely on the port. func (c *Config) GetAudiences() provisioner.Audiences { audiences := provisioner.Audiences{ Sign: []string{legacyAuthority}, Revoke: []string{legacyAuthority}, SSHSign: []string{}, SSHRevoke: []string{}, SSHRenew: []string{}, } for _, name := range c.DNSNames { hostname := toHostname(name) audiences.Sign = append(audiences.Sign, fmt.Sprintf("https://%s/1.0/sign", hostname), fmt.Sprintf("https://%s/sign", hostname), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), fmt.Sprintf("https://%s/ssh/sign", hostname)) audiences.Renew = append(audiences.Renew, fmt.Sprintf("https://%s/1.0/renew", hostname), fmt.Sprintf("https://%s/renew", hostname)) audiences.Revoke = append(audiences.Revoke, fmt.Sprintf("https://%s/1.0/revoke", hostname), fmt.Sprintf("https://%s/revoke", hostname)) audiences.SSHSign = append(audiences.SSHSign, fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), fmt.Sprintf("https://%s/ssh/sign", hostname), fmt.Sprintf("https://%s/1.0/sign", hostname), fmt.Sprintf("https://%s/sign", hostname)) audiences.SSHRevoke = append(audiences.SSHRevoke, fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname), fmt.Sprintf("https://%s/ssh/revoke", hostname)) audiences.SSHRenew = append(audiences.SSHRenew, fmt.Sprintf("https://%s/1.0/ssh/renew", hostname), fmt.Sprintf("https://%s/ssh/renew", hostname)) audiences.SSHRekey = append(audiences.SSHRekey, fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname), 
fmt.Sprintf("https://%s/ssh/rekey", hostname)) } return audiences } // Audience returns the list of audiences for a given path. func (c *Config) Audience(path string) []string { audiences := make([]string, len(c.DNSNames)+1) for i, name := range c.DNSNames { hostname := toHostname(name) audiences[i] = "https://" + hostname + path } // For backward compatibility audiences[len(c.DNSNames)] = path return audiences } func toHostname(name string) string { // ensure an IPv6 address is represented with square brackets when used as hostname if ip := net.ParseIP(name); ip != nil && ip.To4() == nil { name = "[" + name + "]" } return name } ================================================ FILE: authority/config/config_test.go ================================================ package config import ( "fmt" "reflect" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" _ "github.com/smallstep/certificates/cas" "go.step.sm/crypto/jose" ) func TestConfigValidate(t *testing.T) { maxjwk, err := jose.ReadKey("../testdata/secrets/max_pub.jwk") assert.FatalError(t, err) clijwk, err := jose.ReadKey("../testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) ac := &AuthConfig{ Provisioners: provisioner.List{ &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, }, }, } type ConfigValidateTest struct { config *Config err error tls *TLSOptions } tests := map[string]func(*testing.T) ConfigValidateTest{ "skip-validation": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ SkipValidation: true, }, } }, "empty-address": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, 
Password: "pass", AuthorityConfig: ac, }, err: errors.New("address cannot be empty"), } }, "invalid-address": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, }, err: errors.New("invalid address 127.0.0.1"), } }, "empty-root": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, }, err: errors.New("root cannot be empty"), } }, "empty-intermediate-cert": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, }, err: errors.New("crt cannot be empty"), } }, "empty-intermediate-key": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, }, err: errors.New("key cannot be empty"), } }, "empty-dnsNames": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", Password: "pass", AuthorityConfig: ac, }, err: 
errors.New("dnsNames cannot be empty"), } }, "empty-TLS": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, }, tls: &DefaultTLSOptions, } }, "empty-TLS-values": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, TLS: &TLSOptions{}, }, tls: &DefaultTLSOptions, } }, "custom-tls-values": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, TLS: &TLSOptions{ CipherSuites: CipherSuites{ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", }, MinVersion: 1.0, MaxVersion: 1.1, Renegotiation: true, }, }, tls: &TLSOptions{ CipherSuites: CipherSuites{ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", }, MinVersion: 1.0, MaxVersion: 1.1, Renegotiation: true, }, } }, "tls-min>max": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", Root: []string{"../testdata/secrets/root_ca.crt"}, IntermediateCert: "../testdata/secrets/intermediate_ca.crt", IntermediateKey: "../testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", AuthorityConfig: ac, TLS: &TLSOptions{ CipherSuites: 
CipherSuites{ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", }, MinVersion: 1.2, MaxVersion: 1.1, Renegotiation: true, }, }, err: errors.New("tls minVersion cannot exceed tls maxVersion"), } }, } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) err := tc.config.Validate() if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) } } else { if assert.Nil(t, tc.err) { fmt.Printf("tc.tls = %v\n", tc.tls) fmt.Printf("*tc.config.TLS = %v\n", tc.config.TLS) assert.Equals(t, tc.config.TLS, tc.tls) } } }) } } func TestAuthConfigValidate(t *testing.T) { asn1dn := ASN1DN{ Country: "Tazmania", Organization: "Acme Co", Locality: "Landscapes", Province: "Sudden Cliffs", StreetAddress: "TNT", CommonName: "test", } maxjwk, err := jose.ReadKey("../testdata/secrets/max_pub.jwk") assert.FatalError(t, err) clijwk, err := jose.ReadKey("../testdata/secrets/step_cli_key_pub.jwk") assert.FatalError(t, err) p := provisioner.List{ &provisioner.JWK{ Name: "Max", Type: "JWK", Key: maxjwk, }, &provisioner.JWK{ Name: "step-cli", Type: "JWK", Key: clijwk, }, } type AuthConfigValidateTest struct { ac *AuthConfig asn1dn ASN1DN err error } tests := map[string]func(*testing.T) AuthConfigValidateTest{ "fail-nil-authconfig": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: nil, err: errors.New("authority cannot be undefined"), } }, "ok-empty-provisioners": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: &AuthConfig{}, asn1dn: ASN1DN{}, } }, "ok-empty-asn1dn-template": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: &AuthConfig{ Provisioners: p, }, asn1dn: ASN1DN{}, } }, "ok-custom-asn1dn": func(t *testing.T) AuthConfigValidateTest { return AuthConfigValidateTest{ ac: &AuthConfig{ Provisioners: p, Template: &asn1dn, }, asn1dn: asn1dn, } }, } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) err := 
tc.ac.Validate(provisioner.Audiences{}) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) } } else { if assert.Nil(t, tc.err, fmt.Sprintf("expected error: %s, but got ", tc.err)) { assert.Equals(t, *tc.ac.Template, tc.asn1dn) } } }) } } func Test_toHostname(t *testing.T) { tests := []struct { name string want string }{ {name: "localhost", want: "localhost"}, {name: "ca.smallstep.com", want: "ca.smallstep.com"}, {name: "127.0.0.1", want: "127.0.0.1"}, {name: "::1", want: "[::1]"}, {name: "[::1]", want: "[::1]"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := toHostname(tt.name); got != tt.want { t.Errorf("toHostname() = %v, want %v", got, tt.want) } }) } } func TestConfig_Audience(t *testing.T) { type fields struct { DNSNames []string } type args struct { path string } tests := []struct { name string fields fields args args want []string }{ {"ok", fields{[]string{ "ca", "ca.example.com", "127.0.0.1", "::1", }}, args{"/path"}, []string{ "https://ca/path", "https://ca.example.com/path", "https://127.0.0.1/path", "https://[::1]/path", "/path", }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Config{ DNSNames: tt.fields.DNSNames, } if got := c.Audience(tt.args.path); !reflect.DeepEqual(got, tt.want) { t.Errorf("Config.Audience() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/config/ssh.go ================================================ package config import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) // SSHConfig contains the user and host keys. 
type SSHConfig struct { HostKey string `json:"hostKey"` UserKey string `json:"userKey"` Keys []*SSHPublicKey `json:"keys,omitempty"` AddUserPrincipal string `json:"addUserPrincipal,omitempty"` AddUserCommand string `json:"addUserCommand,omitempty"` Bastion *Bastion `json:"bastion,omitempty"` } // Bastion contains the custom properties used on bastion. type Bastion struct { Hostname string `json:"hostname"` User string `json:"user,omitempty"` Port string `json:"port,omitempty"` Command string `json:"cmd,omitempty"` Flags string `json:"flags,omitempty"` } // HostTag are tagged with k,v pairs. These tags are how a user is ultimately // associated with a host. type HostTag struct { ID string Name string Value string } // Host defines expected attributes for an ssh host. type Host struct { HostID string `json:"hid"` HostTags []HostTag `json:"host_tags"` Hostname string `json:"hostname"` } // Validate checks the fields in SSHConfig. func (c *SSHConfig) Validate() error { if c == nil { return nil } for _, k := range c.Keys { if err := k.Validate(); err != nil { return err } } return nil } // SSHPublicKey contains a public key used by federated CAs to keep old signing // keys for this ca. type SSHPublicKey struct { Type string `json:"type"` Federated bool `json:"federated"` Key jose.JSONWebKey `json:"key"` publicKey ssh.PublicKey } // Validate checks the fields in SSHPublicKey. func (k *SSHPublicKey) Validate() error { switch { case k.Type == "": return errors.New("type cannot be empty") case k.Type != provisioner.SSHHostCert && k.Type != provisioner.SSHUserCert: return errors.Errorf("invalid type %s, it must be user or host", k.Type) case !k.Key.IsPublic(): return errors.New("invalid key type, it must be a public key") } key, err := ssh.NewPublicKey(k.Key.Key) if err != nil { return errors.Wrap(err, "error creating ssh key") } k.publicKey = key return nil } // PublicKey returns the ssh public key. 
func (k *SSHPublicKey) PublicKey() ssh.PublicKey { return k.publicKey } // SSHKeys represents the SSH User and Host public keys. type SSHKeys struct { UserKeys []ssh.PublicKey HostKeys []ssh.PublicKey } ================================================ FILE: authority/config/ssh_test.go ================================================ package config import ( "reflect" "testing" "github.com/smallstep/assert" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) func TestSSHPublicKey_Validate(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) assert.FatalError(t, err) type fields struct { Type string Federated bool Key jose.JSONWebKey } tests := []struct { name string fields fields wantErr bool }{ {"user", fields{"user", true, key.Public()}, false}, {"host", fields{"host", false, key.Public()}, false}, {"empty", fields{"", true, key.Public()}, true}, {"badType", fields{"bad", false, key.Public()}, true}, {"badKey", fields{"user", false, *key}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &SSHPublicKey{ Type: tt.fields.Type, Federated: tt.fields.Federated, Key: tt.fields.Key, } if err := k.Validate(); (err != nil) != tt.wantErr { t.Errorf("SSHPublicKey.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestSSHPublicKey_PublicKey(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public().Key) assert.FatalError(t, err) type fields struct { publicKey ssh.PublicKey } tests := []struct { name string fields fields want ssh.PublicKey }{ {"ok", fields{pub}, pub}, {"nil", fields{nil}, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &SSHPublicKey{ publicKey: tt.fields.publicKey, } if got := k.PublicKey(); !reflect.DeepEqual(got, tt.want) { t.Errorf("SSHPublicKey.PublicKey() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: 
authority/config/tls_options.go ================================================ package config import ( "crypto/tls" "fmt" "github.com/pkg/errors" ) var ( // DefaultTLSMinVersion default minimum version of TLS. DefaultTLSMinVersion = TLSVersion(1.2) // DefaultTLSMaxVersion default maximum version of TLS. DefaultTLSMaxVersion = TLSVersion(1.3) // DefaultTLSRenegotiation default TLS connection renegotiation policy. DefaultTLSRenegotiation = false // Never regnegotiate. // DefaultTLSCipherSuites specifies default step ciphersuite(s). // These are TLS 1.0 - 1.2 cipher suites. DefaultTLSCipherSuites = CipherSuites{ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", } // ApprovedTLSCipherSuites smallstep approved ciphersuites. ApprovedTLSCipherSuites = CipherSuites{ // AEADs w/ ECDHE "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", // CBC w/ ECDHE "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", } // DefaultTLSOptions represents the default TLS version as well as the cipher // suites used in the TLS certificates. DefaultTLSOptions = TLSOptions{ CipherSuites: DefaultTLSCipherSuites, MinVersion: DefaultTLSMinVersion, MaxVersion: DefaultTLSMaxVersion, Renegotiation: DefaultTLSRenegotiation, } ) // TLSVersion represents a TLS version number. type TLSVersion float64 // Validate implements models.Validator and checks that a cipher suite is // valid. func (v TLSVersion) Validate() error { if _, ok := tlsVersions[v]; ok { return nil } return errors.Errorf("%f is not a valid tls version", v) } // Value returns the Go constant for the TLSVersion. 
func (v TLSVersion) Value() uint16 {
	return tlsVersions[v]
}

// String returns the dotted textual form ("1.0", "1.1", "1.2" or "1.3") of the
// TLSVersion. The zero value prints as "1.3" because tlsVersions maps 0 to TLS
// 1.3; a version that is not in the table prints as "unexpected value: <v>".
func (v TLSVersion) String() string {
	// Switch on the resolved crypto/tls constant rather than on the float so
	// that the zero-value default is handled by the lookup table.
	k := v.Value()
	switch k {
	case tls.VersionTLS10:
		return "1.0"
	case tls.VersionTLS11:
		return "1.1"
	case tls.VersionTLS12:
		return "1.2"
	case tls.VersionTLS13:
		return "1.3"
	default:
		return fmt.Sprintf("unexpected value: %f", v)
	}
}

// tlsVersions has the list of supported tls version.
var tlsVersions = map[TLSVersion]uint16{
	// Defaults to TLS 1.3
	0: tls.VersionTLS13,
	// Options
	1.0: tls.VersionTLS10,
	1.1: tls.VersionTLS11,
	1.2: tls.VersionTLS12,
	1.3: tls.VersionTLS13,
}

// CipherSuites represents an array of string codes representing the cipher
// suites.
type CipherSuites []string

// Validate implements models.Validator and checks that every cipher suite
// name is a key of the cipherSuites table. It stops at, and reports, the first
// unknown name.
func (c CipherSuites) Validate() error {
	for _, s := range c {
		if _, ok := cipherSuites[s]; !ok {
			return errors.Errorf("%s is not a valid cipher suite", s)
		}
	}
	return nil
}

// Value returns an []uint16 for the cipher suites, in the same order as the
// names. Names missing from the cipherSuites table are not rejected here; they
// produce a 0 entry, so call Validate first.
func (c CipherSuites) Value() []uint16 {
	values := make([]uint16, len(c))
	for i, s := range c {
		values[i] = cipherSuites[s]
	}
	return values
}

// cipherSuites has the list of supported cipher suites.
var cipherSuites = map[string]uint16{
	// TLS 1.0 - 1.2 cipher suites.
	"TLS_RSA_WITH_RC4_128_SHA":                      tls.TLS_RSA_WITH_RC4_128_SHA, // lgtm[go/insecure-tls]
	"TLS_RSA_WITH_3DES_EDE_CBC_SHA":                 tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
	"TLS_RSA_WITH_AES_128_CBC_SHA":                  tls.TLS_RSA_WITH_AES_128_CBC_SHA,
	"TLS_RSA_WITH_AES_256_CBC_SHA":                  tls.TLS_RSA_WITH_AES_256_CBC_SHA,
	"TLS_RSA_WITH_AES_128_CBC_SHA256":               tls.TLS_RSA_WITH_AES_128_CBC_SHA256, // lgtm[go/insecure-tls]
	"TLS_RSA_WITH_AES_128_GCM_SHA256":               tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
	"TLS_RSA_WITH_AES_256_GCM_SHA384":               tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
	"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA":              tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, // lgtm[go/insecure-tls]
	"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA":          tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
	"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA":          tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
	"TLS_ECDHE_RSA_WITH_RC4_128_SHA":                tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, // lgtm[go/insecure-tls]
	"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA":           tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
	"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA":            tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
	"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA":            tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
	"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256":       tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, // lgtm[go/insecure-tls]
	"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256":         tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,   // lgtm[go/insecure-tls]
	"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256":         tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
	"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256":       tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
	"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384":         tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
	"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384":       tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
	"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256":   tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,

	// TLS 1.3 cipher suites.
	"TLS_AES_128_GCM_SHA256":       tls.TLS_AES_128_GCM_SHA256,
	"TLS_AES_256_GCM_SHA384":       tls.TLS_AES_256_GCM_SHA384,
	"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,

	// Legacy names. These map the old names (without the _SHA256 suffix) to
	// the same constants as the suffixed entries above.
	"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305":   tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
}

// TLSOptions represents the TLS options that can be specified on *tls.Config
// types to configure HTTPS servers and clients.
type TLSOptions struct {
	CipherSuites  CipherSuites `json:"cipherSuites"`
	MinVersion    TLSVersion   `json:"minVersion"`
	MaxVersion    TLSVersion   `json:"maxVersion"`
	Renegotiation bool         `json:"renegotiation"`
}

// TLSConfig returns the tls.Config equivalent of the TLSOptions. Only the
// cipher suites, the minimum and maximum versions and the renegotiation policy
// are populated; every other tls.Config field is left at its zero value.
func (t *TLSOptions) TLSConfig() *tls.Config {
	// Renegotiation is a boolean in the options: true selects
	// tls.RenegotiateFreelyAsClient, false selects tls.RenegotiateNever.
	var rs tls.RenegotiationSupport
	if t.Renegotiation {
		rs = tls.RenegotiateFreelyAsClient
	} else {
		rs = tls.RenegotiateNever
	}

	//nolint:gosec // default MinVersion 1.2, if defined but empty 1.3 is used
	return &tls.Config{
		CipherSuites:  t.CipherSuites.Value(),
		MinVersion:    t.MinVersion.Value(),
		MaxVersion:    t.MaxVersion.Value(),
		Renegotiation: rs,
	}
}

================================================
FILE: authority/config/tls_options_test.go
================================================
package config

import (
	"crypto/tls"
	"reflect"
	"testing"
)

// TestTLSVersion_Validate checks that the zero value and 1.0 through 1.3 are
// accepted and that an unknown version (0.99) is rejected.
func TestTLSVersion_Validate(t *testing.T) {
	tests := []struct {
		name    string
		v       TLSVersion
		wantErr bool
	}{
		{"default", TLSVersion(0), false},
		{"1.0", TLSVersion(1.0), false},
		{"1.1", TLSVersion(1.1), false},
		{"1.2", TLSVersion(1.2), false},
		{"1.3", TLSVersion(1.3), false},
		{"0.99", TLSVersion(0.99), true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.v.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("TLSVersion.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestTLSVersion_String checks the textual form of every supported version,
// of the zero value (printed as "1.3") and of an unknown version.
func TestTLSVersion_String(t *testing.T) {
	tests := []struct {
		name string
		v    TLSVersion
		want string
	}{
		{"default", TLSVersion(0), "1.3"},
		{"1.0", TLSVersion(1.0), "1.0"},
		{"1.1", TLSVersion(1.1), "1.1"},
		{"1.2", TLSVersion(1.2), "1.2"},
		{"1.3", TLSVersion(1.3), "1.3"},
		{"0.99", TLSVersion(0.99), "unexpected value: 0.990000"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.v.String(); got != tt.want {
				t.Errorf("TLSVersion.String() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestCipherSuites_Validate checks that each listed suite name validates on
// its own, that a list of valid names validates, and that a list containing an
// unknown name fails.
func TestCipherSuites_Validate(t *testing.T) {
	tests := []struct {
		name    string
		c       CipherSuites
		wantErr bool
	}{
		{"TLS_RSA_WITH_RC4_128_SHA", CipherSuites{"TLS_RSA_WITH_RC4_128_SHA"}, false},
		{"TLS_RSA_WITH_3DES_EDE_CBC_SHA", CipherSuites{"TLS_RSA_WITH_3DES_EDE_CBC_SHA"}, false},
		{"TLS_RSA_WITH_AES_128_CBC_SHA", CipherSuites{"TLS_RSA_WITH_AES_128_CBC_SHA"}, false},
		{"TLS_RSA_WITH_AES_256_CBC_SHA", CipherSuites{"TLS_RSA_WITH_AES_256_CBC_SHA"}, false},
		{"TLS_RSA_WITH_AES_128_CBC_SHA256", CipherSuites{"TLS_RSA_WITH_AES_128_CBC_SHA256"}, false},
		{"TLS_RSA_WITH_AES_128_GCM_SHA256", CipherSuites{"TLS_RSA_WITH_AES_128_GCM_SHA256"}, false},
		{"TLS_RSA_WITH_AES_256_GCM_SHA384", CipherSuites{"TLS_RSA_WITH_AES_256_GCM_SHA384"}, false},
		{"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", CipherSuites{"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA"}, false},
		{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"}, false},
		{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256"}, false},
		{"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"}, false},
		{"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"}, false},
		{"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"}, false},
		{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", CipherSuites{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305"}, false},
		{"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", CipherSuites{"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA"}, false},
		{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"}, false},
		{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, false},
		{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"}, false},
		{"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"}, false},
		{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"}, false},
		{"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", CipherSuites{"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305"}, false},
		{"multiple", CipherSuites{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"}, false},
		{"fail", CipherSuites{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_BAD_CIPHERSUITE"}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.c.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("CipherSuites.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestCipherSuites_Value checks that each suite name resolves to the matching
// crypto/tls constant and that an unknown name resolves to 0.
func TestCipherSuites_Value(t *testing.T) {
	tests := []struct {
		name string
		c    CipherSuites
		want []uint16
	}{
		{"TLS_RSA_WITH_RC4_128_SHA", CipherSuites{"TLS_RSA_WITH_RC4_128_SHA"}, []uint16{tls.TLS_RSA_WITH_RC4_128_SHA}},
		{"TLS_RSA_WITH_3DES_EDE_CBC_SHA", CipherSuites{"TLS_RSA_WITH_3DES_EDE_CBC_SHA"}, []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}},
		{"TLS_RSA_WITH_AES_128_CBC_SHA", CipherSuites{"TLS_RSA_WITH_AES_128_CBC_SHA"}, []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}},
		{"TLS_RSA_WITH_AES_256_CBC_SHA", CipherSuites{"TLS_RSA_WITH_AES_256_CBC_SHA"}, []uint16{tls.TLS_RSA_WITH_AES_256_CBC_SHA}},
		{"TLS_RSA_WITH_AES_128_CBC_SHA256", CipherSuites{"TLS_RSA_WITH_AES_128_CBC_SHA256"}, []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA256}},
		{"TLS_RSA_WITH_AES_128_GCM_SHA256", CipherSuites{"TLS_RSA_WITH_AES_128_GCM_SHA256"}, []uint16{tls.TLS_RSA_WITH_AES_128_GCM_SHA256}},
		{"TLS_RSA_WITH_AES_256_GCM_SHA384", CipherSuites{"TLS_RSA_WITH_AES_256_GCM_SHA384"}, []uint16{tls.TLS_RSA_WITH_AES_256_GCM_SHA384}},
		{"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", CipherSuites{"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA}},
		{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA}},
		{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256}},
		{"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}},
		{"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}},
		{"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", CipherSuites{"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384}},
		{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", CipherSuites{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}},
		{"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", CipherSuites{"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA"}, []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA}},
		{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"}, []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}},
		{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}},
		{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"}, []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256}},
		{"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"}, []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}},
		{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", CipherSuites{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"}, []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384}},
		{"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", CipherSuites{"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305"}, []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}},
		{"multiple", CipherSuites{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}},
		{"fail", CipherSuites{"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_BAD_CIPHERSUITE"}, []uint16{tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 0}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.c.Value(); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("CipherSuites.Value() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestTLSOptions_TLSConfig checks the tls.Config built from the package
// defaults, with renegotiation disabled and enabled.
func TestTLSOptions_TLSConfig(t *testing.T) {
	type fields struct {
		CipherSuites  CipherSuites
		MinVersion    TLSVersion
		MaxVersion    TLSVersion
		Renegotiation bool
	}
	tests := []struct {
		name   string
		fields fields
		want   *tls.Config
	}{
		{"default", fields{DefaultTLSCipherSuites, DefaultTLSMinVersion, DefaultTLSMaxVersion, DefaultTLSRenegotiation}, &tls.Config{
			CipherSuites:  []uint16{tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			MinVersion:    tls.VersionTLS12,
			MaxVersion:    tls.VersionTLS13,
			Renegotiation: tls.RenegotiateNever,
		}},
		{"renegotation", fields{DefaultTLSCipherSuites, DefaultTLSMinVersion, DefaultTLSMaxVersion, true}, &tls.Config{
			CipherSuites:  []uint16{tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			MinVersion:    tls.VersionTLS12,
			MaxVersion:    tls.VersionTLS13,
			Renegotiation: tls.RenegotiateFreelyAsClient,
		}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			o := &TLSOptions{
				CipherSuites:  tt.fields.CipherSuites,
				MinVersion:    tt.fields.MinVersion,
				MaxVersion:    tt.fields.MaxVersion,
				Renegotiation: tt.fields.Renegotiation,
			}
			if got := o.TLSConfig(); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("TLSOptions.TLSConfig() = %v, want %v", got, tt.want)
			}
		})
	}
}

================================================
FILE: 
authority/config/types.go ================================================ package config import ( "encoding/json" "github.com/pkg/errors" ) // multiString represents a type that can be encoded/decoded in JSON as a single // string or an array of strings. type multiString []string // First returns the first element of a multiString. It will return an empty // string if the multistring is empty. func (s multiString) First() string { if len(s) > 0 { return s[0] } return "" } // HasEmpties returns `true` if any string in the array is empty. func (s multiString) HasEmpties() bool { if len(s) == 0 { return true } for _, ss := range s { if ss == "" { return true } } return false } // MarshalJSON marshals the multistring as a string or a slice of strings . With // 0 elements it will return the empty string, with 1 element a regular string, // otherwise a slice of strings. func (s multiString) MarshalJSON() ([]byte, error) { switch len(s) { case 0: return []byte(`""`), nil case 1: return json.Marshal(s[0]) default: return json.Marshal([]string(s)) } } // UnmarshalJSON parses a string or a slice and sets it to the multiString. 
func (s *multiString) UnmarshalJSON(data []byte) error {
	// A nil receiver has nowhere to store the result.
	if s == nil {
		return errors.New("multiString cannot be nil")
	}
	// Empty input resets the value instead of failing.
	if len(data) == 0 {
		*s = nil
		return nil
	}
	// Parse string: a leading double quote means a single JSON string, which
	// becomes a one-element multiString (including the empty string).
	if data[0] == '"' {
		var str string
		if err := json.Unmarshal(data, &str); err != nil {
			return errors.Wrapf(err, "error unmarshalling %s", data)
		}
		*s = []string{str}
		return nil
	}
	// Parse array: anything else must decode as a JSON array of strings. On
	// failure the receiver is left untouched.
	var ss []string
	if err := json.Unmarshal(data, &ss); err != nil {
		return errors.Wrapf(err, "error unmarshalling %s", data)
	}
	*s = ss
	return nil
}

================================================
FILE: authority/config/types_test.go
================================================
package config

import (
	"reflect"
	"testing"
)

// Test_multiString_First checks First on an empty, a one-element and a
// two-element multiString.
func Test_multiString_First(t *testing.T) {
	tests := []struct {
		name string
		s    multiString
		want string
	}{
		{"empty", multiString{}, ""},
		{"string", multiString{"one"}, "one"},
		{"slice", multiString{"one", "two"}, "one"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.s.First(); got != tt.want {
				t.Errorf("multiString.First() = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_multiString_Empties checks HasEmpties: an empty multiString and any
// multiString containing an empty string report true.
func Test_multiString_Empties(t *testing.T) {
	tests := []struct {
		name string
		s    multiString
		want bool
	}{
		{"empty", multiString{}, true},
		{"string", multiString{"one"}, false},
		{"empty string", multiString{""}, true},
		{"slice", multiString{"one", "two"}, false},
		{"empty slice", multiString{"one", ""}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.s.HasEmpties(); got != tt.want {
				t.Errorf("multiString.Empties() = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_multiString_MarshalJSON checks the three encodings: empty string,
// single string and array of strings.
func Test_multiString_MarshalJSON(t *testing.T) {
	tests := []struct {
		name    string
		s       multiString
		want    []byte
		wantErr bool
	}{
		{"empty", []string{}, []byte(`""`), false},
		{"string", []string{"a string"}, []byte(`"a string"`), false},
		{"slice", []string{"string one", "string two"}, []byte(`["string one","string two"]`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := tt.s.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("multiString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("multiString.MarshalJSON() = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_multiString_UnmarshalJSON checks decoding of empty input, a string, an
// array of strings, an array with a non-string element (error) and a nil
// receiver (error).
func Test_multiString_UnmarshalJSON(t *testing.T) {
	type args struct {
		data []byte
	}
	tests := []struct {
		name    string
		s       *multiString
		args    args
		want    *multiString
		wantErr bool
	}{
		{"empty", new(multiString), args{[]byte{}}, new(multiString), false},
		{"empty string", new(multiString), args{[]byte(`""`)}, &multiString{""}, false},
		{"string", new(multiString), args{[]byte(`"a string"`)}, &multiString{"a string"}, false},
		{"slice", new(multiString), args{[]byte(`["string one","string two"]`)}, &multiString{"string one", "string two"}, false},
		{"error", new(multiString), args{[]byte(`["123",123]`)}, new(multiString), true},
		{"nil", nil, args{nil}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.s.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
				t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(tt.s, tt.want) {
				t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.s, tt.want)
			}
		})
	}
}

================================================
FILE: authority/config.go
================================================
package authority

import "github.com/smallstep/certificates/authority/config"

// Config is an alias to support older APIs.
type Config = config.Config

// LoadConfiguration is an alias to support older APIs.
var LoadConfiguration = config.LoadConfiguration

// AuthConfig is an alias to support older APIs.
type AuthConfig = config.AuthConfig

// TLS

// ASN1DN is an alias to support older APIs.
type ASN1DN = config.ASN1DN

// DefaultTLSOptions is an alias to support older APIs.
var DefaultTLSOptions = config.DefaultTLSOptions

// TLSOptions is an alias to support older APIs.
type TLSOptions = config.TLSOptions // CipherSuites is an alias to support older APIs. type CipherSuites = config.CipherSuites // SSH // SSHConfig is an alias to support older APIs. type SSHConfig = config.SSHConfig // Bastion is an alias to support older APIs. type Bastion = config.Bastion // HostTag is an alias to support older APIs. type HostTag = config.HostTag // Host is an alias to support older APIs. type Host = config.Host // SSHPublicKey is an alias to support older APIs. type SSHPublicKey = config.SSHPublicKey // SSHKeys is an alias to support older APIs. type SSHKeys = config.SSHKeys ================================================ FILE: authority/export.go ================================================ package authority import ( "encoding/json" "net/url" "os" "path/filepath" "strings" "github.com/pkg/errors" "google.golang.org/protobuf/types/known/structpb" "github.com/smallstep/cli-utils/step" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/authority/provisioner" ) // Export creates a linkedca configuration form the current ca.json and loaded // authorities. // // Note that export will not export neither the pki password nor the certificate // issuer password. func (a *Authority) Export() (c *linkedca.Configuration, err error) { // Recover from panics defer func() { if r := recover(); r != nil { err = r.(error) } }() files := make(map[string][]byte) // The exported configuration should not include the password in it. 
c = &linkedca.Configuration{ Version: "1.0", Root: mustReadFilesOrURIs(a.config.Root, files), FederatedRoots: mustReadFilesOrURIs(a.config.FederatedRoots, files), Intermediate: mustReadFileOrURI(a.config.IntermediateCert, files), IntermediateKey: mustReadFileOrURI(a.config.IntermediateKey, files), Address: a.config.Address, InsecureAddress: a.config.InsecureAddress, DnsNames: a.config.DNSNames, Db: mustMarshalToStruct(a.config.DB), Logger: mustMarshalToStruct(a.config.Logger), Monitoring: mustMarshalToStruct(a.config.Monitoring), Authority: &linkedca.Authority{ Id: a.config.AuthorityConfig.AuthorityID, EnableAdmin: a.config.AuthorityConfig.EnableAdmin, DisableIssuedAtCheck: a.config.AuthorityConfig.DisableIssuedAtCheck, Backdate: mustDuration(a.config.AuthorityConfig.Backdate), DeploymentType: a.config.AuthorityConfig.DeploymentType, }, Files: files, } // SSH if v := a.config.SSH; v != nil { c.Ssh = &linkedca.SSH{ HostKey: mustReadFileOrURI(v.HostKey, files), UserKey: mustReadFileOrURI(v.UserKey, files), AddUserPrincipal: v.AddUserPrincipal, AddUserCommand: v.AddUserCommand, } for _, k := range v.Keys { typ, ok := linkedca.SSHPublicKey_Type_value[strings.ToUpper(k.Type)] if !ok { return nil, errors.Errorf("unsupported ssh key type %s", k.Type) } c.Ssh.Keys = append(c.Ssh.Keys, &linkedca.SSHPublicKey{ Type: linkedca.SSHPublicKey_Type(typ), Federated: k.Federated, Key: mustMarshalToStruct(k), }) } if b := v.Bastion; b != nil { c.Ssh.Bastion = &linkedca.Bastion{ Hostname: b.Hostname, User: b.User, Port: b.Port, Command: b.Command, Flags: b.Flags, } } } // KMS if v := a.config.KMS; v != nil { var typ int32 var ok bool if v.Type == "" { typ = int32(linkedca.KMS_SOFTKMS) } else { typ, ok = linkedca.KMS_Type_value[strings.ToUpper(string(v.Type))] if !ok { return nil, errors.Errorf("unsupported kms type %s", v.Type) } } c.Kms = &linkedca.KMS{ Type: linkedca.KMS_Type(typ), CredentialsFile: v.CredentialsFile, Uri: v.URI, Pin: v.Pin, ManagementKey: v.ManagementKey, Region: 
v.Region, Profile: v.Profile, } } // Authority // cas options if v := a.config.AuthorityConfig.Options; v != nil { c.Authority.Type = 0 c.Authority.CertificateAuthority = v.CertificateAuthority c.Authority.CertificateAuthorityFingerprint = v.CertificateAuthorityFingerprint c.Authority.CredentialsFile = v.CredentialsFile if iss := v.CertificateIssuer; iss != nil { typ, ok := linkedca.CertificateIssuer_Type_value[strings.ToUpper(iss.Type)] if !ok { return nil, errors.Errorf("unknown certificate issuer type %s", iss.Type) } // The exported certificate issuer should not include the password. c.Authority.CertificateIssuer = &linkedca.CertificateIssuer{ Type: linkedca.CertificateIssuer_Type(typ), Provisioner: iss.Provisioner, Certificate: mustReadFileOrURI(iss.Certificate, files), Key: mustReadFileOrURI(iss.Key, files), } } } // admins for { list, cursor := a.admins.Find("", 100) c.Authority.Admins = append(c.Authority.Admins, list...) if cursor == "" { break } } // provisioners for { list, cursor := a.provisioners.Find("", 100) for _, p := range list { lp, err := ProvisionerToLinkedca(p) if err != nil { return nil, err } c.Authority.Provisioners = append(c.Authority.Provisioners, lp) } if cursor == "" { break } } // global claims c.Authority.Claims = claimsToLinkedca(a.config.AuthorityConfig.Claims) // Distinguished names template if v := a.config.AuthorityConfig.Template; v != nil { c.Authority.Template = &linkedca.DistinguishedName{ Country: v.Country, Organization: v.Organization, OrganizationalUnit: v.OrganizationalUnit, Locality: v.Locality, Province: v.Province, StreetAddress: v.StreetAddress, SerialNumber: v.SerialNumber, CommonName: v.CommonName, } } // TLS if v := a.config.TLS; v != nil { c.Tls = &linkedca.TLS{ MinVersion: v.MinVersion.String(), MaxVersion: v.MaxVersion.String(), Renegotiation: v.Renegotiation, } for _, cs := range v.CipherSuites.Value() { c.Tls.CipherSuites = append(c.Tls.CipherSuites, linkedca.TLS_CiperSuite(cs)) } } // Templates if v := 
a.config.Templates; v != nil { c.Templates = &linkedca.ConfigTemplates{ Ssh: &linkedca.SSHConfigTemplate{}, Data: mustMarshalToStruct(v.Data), } // Remove automatically loaded vars if c.Templates.Data != nil && c.Templates.Data.Fields != nil { delete(c.Templates.Data.Fields, "Step") } for _, t := range v.SSH.Host { typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))] if !ok { return nil, errors.Errorf("unsupported template type %s", t.Type) } c.Templates.Ssh.Hosts = append(c.Templates.Ssh.Hosts, &linkedca.ConfigTemplate{ Type: linkedca.ConfigTemplate_Type(typ), Name: t.Name, Template: mustReadFileOrURI(t.TemplatePath, files), Path: t.Path, Comment: t.Comment, Requires: t.RequiredData, Content: t.Content, }) } for _, t := range v.SSH.User { typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))] if !ok { return nil, errors.Errorf("unsupported template type %s", t.Type) } c.Templates.Ssh.Users = append(c.Templates.Ssh.Users, &linkedca.ConfigTemplate{ Type: linkedca.ConfigTemplate_Type(typ), Name: t.Name, Template: mustReadFileOrURI(t.TemplatePath, files), Path: t.Path, Comment: t.Comment, Requires: t.RequiredData, Content: t.Content, }) } } return c, nil } func mustDuration(d *provisioner.Duration) string { if d == nil || d.Duration == 0 { return "" } return d.String() } func mustMarshalToStruct(v interface{}) *structpb.Struct { b, err := json.Marshal(v) if err != nil { panic(errors.Wrapf(err, "error marshaling %T", v)) } var r *structpb.Struct if err := json.Unmarshal(b, &r); err != nil { panic(errors.Wrapf(err, "error unmarshaling %T", v)) } return r } func mustReadFileOrURI(fn string, m map[string][]byte) string { if fn == "" { return "" } stepPath := filepath.ToSlash(step.Path()) if !strings.HasSuffix(stepPath, "/") { stepPath += "/" } fn = strings.TrimPrefix(filepath.ToSlash(fn), stepPath) ok, err := isFilename(fn) if err != nil { panic(err) } if ok { b, err := os.ReadFile(step.Abs(fn)) if err != nil { 
panic(errors.Wrapf(err, "error reading %s", fn)) } m[fn] = b return fn } return fn } func mustReadFilesOrURIs(fns []string, m map[string][]byte) []string { var result []string for _, fn := range fns { result = append(result, mustReadFileOrURI(fn, m)) } return result } func isFilename(fn string) (bool, error) { u, err := url.Parse(fn) if err != nil { return false, errors.Wrapf(err, "error parsing %s", fn) } return u.Scheme == "" || u.Scheme == "file", nil } ================================================ FILE: authority/http_client.go ================================================ package authority import ( "crypto/tls" "crypto/x509" "net/http" "sync/atomic" "github.com/smallstep/certificates/authority/poolhttp" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/internal/httptransport" ) // systemCertPool holds a copy of the system cert pool. This cert pool must be // initialized when the authority is created and we should always get a clone of // this pool. var systemCertPool atomic.Pointer[x509.CertPool] // initializeSystemCertPool initializes the system cert pool if necessary. func initializeSystemCertPool() error { if systemCertPool.Load() == nil { pool, err := x509.SystemCertPool() if err != nil { return err } systemCertPool.Store(pool) } return nil } // newHTTPClient will return an HTTP client that trusts the system cert pool and // the given roots. 
func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) provisioner.HTTPClient { return poolhttp.New(func() *http.Client { pool := systemCertPool.Load().Clone() for _, crt := range roots { pool.AddCert(crt) } tr, ok := http.DefaultTransport.(*http.Transport) if !ok { tr = httptransport.New() } else { tr = tr.Clone() } tr.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: pool, } rr := wt(tr) return &http.Client{Transport: rr} }) } ================================================ FILE: authority/http_client_test.go ================================================ package authority import ( "context" "crypto/tls" "crypto/x509" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/internal/httptransport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" ) func mustCertificate(t *testing.T, a *Authority, csr *x509.CertificateRequest) []*x509.Certificate { t.Helper() ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) now := time.Now() signOpts := provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(now), NotAfter: provisioner.NewTimeDuration(now.Add(5 * time.Minute)), Backdate: 1 * time.Minute, } sans := []string{} sans = append(sans, csr.DNSNames...) sans = append(sans, csr.EmailAddresses...) 
for _, s := range csr.IPAddresses { sans = append(sans, s.String()) } for _, s := range csr.URIs { sans = append(sans, s.String()) } key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) require.NoError(t, err) token, err := generateToken(csr.Subject.CommonName, "step-cli", testAudiences.Sign[0], sans, now, key) require.NoError(t, err) extraOpts, err := a.Authorize(ctx, token) require.NoError(t, err) chain, err := a.SignWithContext(ctx, csr, signOpts, extraOpts...) require.NoError(t, err) return chain } func Test_newHTTPClient(t *testing.T) { signer, err := keyutil.GenerateDefaultSigner() require.NoError(t, err) csr, err := x509util.CreateCertificateRequest("test", []string{"localhost", "127.0.0.1", "[::1]"}, signer) require.NoError(t, err) auth := testAuthority(t) chain := mustCertificate(t, auth, csr) t.Run("SystemCertPool", func(t *testing.T) { resp, err := auth.httpClient.Get("https://smallstep.com") require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.NotEmpty(t, b) assert.NoError(t, resp.Body.Close()) }) t.Run("LocalCertPool", func(t *testing.T) { srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "ok") })) srv.TLS = &tls.Config{ Certificates: []tls.Certificate{ {Certificate: [][]byte{chain[0].Raw, chain[1].Raw}, PrivateKey: signer, Leaf: chain[0]}, }, } srv.StartTLS() defer srv.Close() resp, err := auth.httpClient.Get(srv.URL) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, []byte("ok"), b) assert.NoError(t, resp.Body.Close()) t.Run("DefaultClient", func(t *testing.T) { client := &http.Client{} _, err := client.Get(srv.URL) assert.Error(t, err) }) }) t.Run("custom transport", func(t *testing.T) { tmp := http.DefaultTransport t.Cleanup(func() { http.DefaultTransport = tmp }) 
transport := struct { http.RoundTripper }{http.DefaultTransport} http.DefaultTransport = transport client := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...) assert.NotNil(t, client) }) } ================================================ FILE: authority/internal/constraints/constraints.go ================================================ package constraints import ( "crypto/x509" "fmt" "net" "net/http" "net/url" "github.com/smallstep/certificates/errs" ) // ConstraintError is the typed error that will be returned if a constraint // error is found. type ConstraintError struct { Type string Name string Detail string } // Error implements the error interface. func (e ConstraintError) Error() string { return e.Detail } // As implements the As(any) bool interface and allows to use "errors.As()" to // convert the ConstraintError to an errs.Error. func (e ConstraintError) As(v any) bool { if err, ok := v.(**errs.Error); ok { *err = &errs.Error{ Status: http.StatusForbidden, Msg: e.Detail, Err: e, } return true } return false } // Engine implements a constraint validator for DNS names, IP addresses, Email // addresses and URIs. type Engine struct { hasNameConstraints bool permittedDNSDomains []string excludedDNSDomains []string permittedIPRanges []*net.IPNet excludedIPRanges []*net.IPNet permittedEmailAddresses []string excludedEmailAddresses []string permittedURIDomains []string excludedURIDomains []string } // New creates a constraint validation engine that contains the given chain of // certificates. func New(chain ...*x509.Certificate) *Engine { e := new(Engine) for _, crt := range chain { e.permittedDNSDomains = append(e.permittedDNSDomains, crt.PermittedDNSDomains...) e.excludedDNSDomains = append(e.excludedDNSDomains, crt.ExcludedDNSDomains...) e.permittedIPRanges = append(e.permittedIPRanges, crt.PermittedIPRanges...) e.excludedIPRanges = append(e.excludedIPRanges, crt.ExcludedIPRanges...) 
e.permittedEmailAddresses = append(e.permittedEmailAddresses, crt.PermittedEmailAddresses...) e.excludedEmailAddresses = append(e.excludedEmailAddresses, crt.ExcludedEmailAddresses...) e.permittedURIDomains = append(e.permittedURIDomains, crt.PermittedURIDomains...) e.excludedURIDomains = append(e.excludedURIDomains, crt.ExcludedURIDomains...) } e.hasNameConstraints = len(e.permittedDNSDomains) > 0 || len(e.excludedDNSDomains) > 0 || len(e.permittedIPRanges) > 0 || len(e.excludedIPRanges) > 0 || len(e.permittedEmailAddresses) > 0 || len(e.excludedEmailAddresses) > 0 || len(e.permittedURIDomains) > 0 || len(e.excludedURIDomains) > 0 return e } // Validate checks the given names with the name constraints defined in the // service. func (e *Engine) Validate(dnsNames []string, ipAddresses []net.IP, emailAddresses []string, uris []*url.URL) error { if e == nil || !e.hasNameConstraints { return nil } for _, name := range dnsNames { if err := checkNameConstraints("DNS name", name, name, e.permittedDNSDomains, e.excludedDNSDomains, func(parsedName, constraint any) (bool, error) { return matchDomainConstraint(parsedName.(string), constraint.(string)) }, ); err != nil { return err } } for _, ip := range ipAddresses { if err := checkNameConstraints("IP address", ip.String(), ip, e.permittedIPRanges, e.excludedIPRanges, func(parsedName, constraint any) (bool, error) { return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet)) }, ); err != nil { return err } } for _, email := range emailAddresses { mailbox, ok := parseRFC2821Mailbox(email) if !ok { return fmt.Errorf("cannot parse rfc822Name %q", email) } if err := checkNameConstraints("Email address", email, mailbox, e.permittedEmailAddresses, e.excludedEmailAddresses, func(parsedName, constraint any) (bool, error) { return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string)) }, ); err != nil { return err } } for _, uri := range uris { if err := checkNameConstraints("URI", uri.String(), uri, 
e.permittedURIDomains, e.excludedURIDomains, func(parsedName, constraint any) (bool, error) { return matchURIConstraint(parsedName.(*url.URL), constraint.(string)) }, ); err != nil { return err } } return nil } // ValidateCertificate validates the DNS names, IP addresses, Email addresses // and URIs present in the given certificate. func (e *Engine) ValidateCertificate(cert *x509.Certificate) error { return e.Validate(cert.DNSNames, cert.IPAddresses, cert.EmailAddresses, cert.URIs) } ================================================ FILE: authority/internal/constraints/constraints_test.go ================================================ package constraints import ( "crypto/x509" "net" "net/url" "reflect" "testing" "go.step.sm/crypto/minica" ) func TestNew(t *testing.T) { ca1, err := minica.New() if err != nil { t.Fatal(err) } ca2, err := minica.New( minica.WithIntermediateTemplate(`{ "subject": {{ toJson .Subject }}, "keyUsage": ["certSign", "crlSign"], "basicConstraints": { "isCA": true, "maxPathLen": 0 }, "nameConstraints": { "critical": true, "permittedDNSDomains": ["internal.example.org"], "excludedDNSDomains": ["internal.example.com"], "permittedIPRanges": ["192.168.1.0/24", "192.168.2.1/32"], "excludedIPRanges": ["192.168.3.0/24", "192.168.4.0/28"], "permittedEmailAddresses": ["root@example.org", "example.org", ".acme.org"], "excludedEmailAddresses": ["root@example.com", "example.com", ".acme.com"], "permittedURIDomains": ["host.example.org", ".acme.org"], "excludedURIDomains": ["host.example.com", ".acme.com"] } }`), ) if err != nil { t.Fatal(err) } type args struct { chain []*x509.Certificate } tests := []struct { name string args args want *Engine }{ {"ok", args{[]*x509.Certificate{ca1.Intermediate, ca1.Root}}, &Engine{ hasNameConstraints: false, }}, {"ok with constraints", args{[]*x509.Certificate{ca2.Intermediate, ca2.Root}}, &Engine{ hasNameConstraints: true, permittedDNSDomains: []string{"internal.example.org"}, excludedDNSDomains: 
[]string{"internal.example.com"}, permittedIPRanges: []*net.IPNet{ {IP: net.ParseIP("192.168.1.0").To4(), Mask: net.IPMask{255, 255, 255, 0}}, {IP: net.ParseIP("192.168.2.1").To4(), Mask: net.IPMask{255, 255, 255, 255}}, }, excludedIPRanges: []*net.IPNet{ {IP: net.ParseIP("192.168.3.0").To4(), Mask: net.IPMask{255, 255, 255, 0}}, {IP: net.ParseIP("192.168.4.0").To4(), Mask: net.IPMask{255, 255, 255, 240}}, }, permittedEmailAddresses: []string{"root@example.org", "example.org", ".acme.org"}, excludedEmailAddresses: []string{"root@example.com", "example.com", ".acme.com"}, permittedURIDomains: []string{"host.example.org", ".acme.org"}, excludedURIDomains: []string{"host.example.com", ".acme.com"}, }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := New(tt.args.chain...); !reflect.DeepEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) } } func TestNew_hasNameConstraints(t *testing.T) { tests := []struct { name string fn func(c *x509.Certificate) want bool }{ {"no constraints", func(c *x509.Certificate) {}, false}, {"permittedDNSDomains", func(c *x509.Certificate) { c.PermittedDNSDomains = []string{"constraint"} }, true}, {"excludedDNSDomains", func(c *x509.Certificate) { c.ExcludedDNSDomains = []string{"constraint"} }, true}, {"permittedIPRanges", func(c *x509.Certificate) { c.PermittedIPRanges = []*net.IPNet{{IP: net.ParseIP("192.168.3.0").To4(), Mask: net.IPMask{255, 255, 255, 0}}} }, true}, {"excludedIPRanges", func(c *x509.Certificate) { c.ExcludedIPRanges = []*net.IPNet{{IP: net.ParseIP("192.168.3.0").To4(), Mask: net.IPMask{255, 255, 255, 0}}} }, true}, {"permittedEmailAddresses", func(c *x509.Certificate) { c.PermittedEmailAddresses = []string{"constraint"} }, true}, {"excludedEmailAddresses", func(c *x509.Certificate) { c.ExcludedEmailAddresses = []string{"constraint"} }, true}, {"permittedURIDomains", func(c *x509.Certificate) { c.PermittedURIDomains = []string{"constraint"} }, true}, {"excludedURIDomains", 
func(c *x509.Certificate) { c.ExcludedURIDomains = []string{"constraint"} }, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cert := &x509.Certificate{} tt.fn(cert) if e := New(cert); e.hasNameConstraints != tt.want { t.Errorf("Engine.hasNameConstraints = %v, want %v", e.hasNameConstraints, tt.want) } }) } } func TestEngine_Validate(t *testing.T) { type fields struct { hasNameConstraints bool permittedDNSDomains []string excludedDNSDomains []string permittedIPRanges []*net.IPNet excludedIPRanges []*net.IPNet permittedEmailAddresses []string excludedEmailAddresses []string permittedURIDomains []string excludedURIDomains []string } type args struct { dnsNames []string ipAddresses []net.IP emailAddresses []string uris []*url.URL } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{hasNameConstraints: false}, args{ dnsNames: []string{"example.com", "host.example.com"}, ipAddresses: []net.IP{{192, 168, 1, 1}, {0x26, 0x00, 0x1f, 0x1c, 0x47, 0x01, 0x9d, 0x00, 0xc3, 0xa7, 0x66, 0x94, 0x87, 0x0f, 0x20, 0x72}}, emailAddresses: []string{"root@example.com"}, uris: []*url.URL{{Scheme: "https", Host: "example.com", Path: "/uuid/c6d1a755-0c12-431e-9136-b64cb3173ec7"}}, }, false}, {"ok permitted dns", fields{ hasNameConstraints: true, permittedDNSDomains: []string{"example.com"}, }, args{dnsNames: []string{"example.com", "www.example.com"}}, false}, {"ok not excluded dns", fields{ hasNameConstraints: true, excludedDNSDomains: []string{"example.org"}, }, args{dnsNames: []string{"example.com", "www.example.com"}}, false}, {"ok permitted ip", fields{ hasNameConstraints: true, permittedIPRanges: []*net.IPNet{ {IP: net.ParseIP("192.168.1.0"), Mask: net.IPMask{255, 255, 255, 0}}, {IP: net.ParseIP("192.168.2.1").To4(), Mask: net.IPMask{255, 255, 255, 255}}, {IP: net.ParseIP("2600:1700:22f8:2600:e559:bd88:350a:34d6"), Mask: net.IPMask{255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, }, }, args{ipAddresses: []net.IP{{192, 
168, 1, 10}, {192, 168, 2, 1}, {0x26, 0x0, 0x17, 0x00, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc}}}, false}, {"ok not excluded ip", fields{ hasNameConstraints: true, excludedIPRanges: []*net.IPNet{ {IP: net.ParseIP("192.168.1.0"), Mask: net.IPMask{255, 255, 255, 0}}, {IP: net.ParseIP("192.168.2.1").To4(), Mask: net.IPMask{255, 255, 255, 255}}, }, }, args{ipAddresses: []net.IP{{192, 168, 2, 2}, {192, 168, 3, 1}}}, false}, {"ok permitted emails", fields{ hasNameConstraints: true, permittedEmailAddresses: []string{"root@example.com", "acme.org", ".acme.com"}, }, args{emailAddresses: []string{"root@example.com", "name@acme.org", "name@coyote.acme.com", `"(quoted)"@www.acme.com`}}, false}, {"ok not excluded emails", fields{ hasNameConstraints: true, excludedEmailAddresses: []string{"root@example.com", "acme.org", ".acme.com"}, }, args{emailAddresses: []string{"name@example.com", "root@acme.com", "root@other.com"}}, false}, {"ok permitted uris", fields{ hasNameConstraints: true, permittedURIDomains: []string{"example.com", ".acme.com"}, }, args{uris: []*url.URL{{Scheme: "https", Host: "example.com", Path: "/path"}, {Scheme: "https", Host: "www.acme.com", Path: "/path"}}}, false}, {"ok not excluded uris", fields{ hasNameConstraints: true, excludedURIDomains: []string{"example.com", ".acme.com"}, }, args{uris: []*url.URL{{Scheme: "https", Host: "example.org", Path: "/path"}, {Scheme: "https", Host: "acme.com", Path: "/path"}}}, false}, {"fail permitted dns", fields{ hasNameConstraints: true, permittedDNSDomains: []string{"example.com"}, }, args{dnsNames: []string{"www.example.com", "www.example.org"}}, true}, {"fail not excluded dns", fields{ hasNameConstraints: true, excludedDNSDomains: []string{"example.org"}, }, args{dnsNames: []string{"example.com", "www.example.org"}}, true}, {"fail permitted ip", fields{ hasNameConstraints: true, permittedIPRanges: []*net.IPNet{ {IP: net.ParseIP("192.168.1.0").To4(), Mask: net.IPMask{255, 255, 255, 0}}, {IP: 
net.ParseIP("192.168.2.1").To4(), Mask: net.IPMask{255, 255, 255, 255}}, }, }, args{ipAddresses: []net.IP{{192, 168, 1, 10}, {192, 168, 2, 10}}}, true}, {"fail not excluded ip", fields{ hasNameConstraints: true, excludedIPRanges: []*net.IPNet{ {IP: net.ParseIP("192.168.1.0").To4(), Mask: net.IPMask{255, 255, 255, 0}}, {IP: net.ParseIP("192.168.2.1").To4(), Mask: net.IPMask{255, 255, 255, 255}}, }, }, args{ipAddresses: []net.IP{{192, 168, 2, 2}, {192, 168, 1, 1}}}, true}, {"fail permitted emails", fields{ hasNameConstraints: true, permittedEmailAddresses: []string{"root@example.com", "acme.org", ".acme.com"}, }, args{emailAddresses: []string{"root@example.com", "name@acme.org", "name@acme.com"}}, true}, {"fail not excluded emails", fields{ hasNameConstraints: true, excludedEmailAddresses: []string{"root@example.com", "acme.org", ".acme.com"}, }, args{emailAddresses: []string{"name@example.com", "root@example.com"}}, true}, {"fail permitted uris", fields{ hasNameConstraints: true, permittedURIDomains: []string{"example.com", ".acme.com"}, }, args{uris: []*url.URL{{Scheme: "https", Host: "example.com", Path: "/path"}, {Scheme: "https", Host: "acme.com", Path: "/path"}}}, true}, {"fail not excluded uris", fields{ hasNameConstraints: true, excludedURIDomains: []string{"example.com", ".acme.com"}, }, args{uris: []*url.URL{{Scheme: "https", Host: "www.example.com", Path: "/path"}, {Scheme: "https", Host: "acme.com", Path: "/path"}}}, true}, {"fail parse emails", fields{ hasNameConstraints: true, permittedEmailAddresses: []string{"example.com"}, }, args{emailAddresses: []string{`(notquoted)@example.com`}}, true}, {"fail match dns", fields{ hasNameConstraints: true, permittedDNSDomains: []string{"example.com"}, }, args{dnsNames: []string{`www.example.com.`}}, true}, {"fail match email", fields{ hasNameConstraints: true, excludedEmailAddresses: []string{`(notquoted)@example.com`}, }, args{emailAddresses: []string{`ok@example.com`}}, true}, {"fail match uri", fields{ 
hasNameConstraints: true, permittedURIDomains: []string{"example.com"}, }, args{uris: []*url.URL{{Scheme: "urn", Opaque: "uuid:36efb1ae-6617-4b23-b799-874a37aaea1c"}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := &Engine{ hasNameConstraints: tt.fields.hasNameConstraints, permittedDNSDomains: tt.fields.permittedDNSDomains, excludedDNSDomains: tt.fields.excludedDNSDomains, permittedIPRanges: tt.fields.permittedIPRanges, excludedIPRanges: tt.fields.excludedIPRanges, permittedEmailAddresses: tt.fields.permittedEmailAddresses, excludedEmailAddresses: tt.fields.excludedEmailAddresses, permittedURIDomains: tt.fields.permittedURIDomains, excludedURIDomains: tt.fields.excludedURIDomains, } if err := e.Validate(tt.args.dnsNames, tt.args.ipAddresses, tt.args.emailAddresses, tt.args.uris); (err != nil) != tt.wantErr { t.Errorf("service.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestEngine_Validate_nil(t *testing.T) { var e *Engine if err := e.Validate([]string{"www.example.com"}, nil, nil, nil); err != nil { t.Errorf("service.Validate() error = %v, wantErr false", err) } } func TestEngine_ValidateCertificate(t *testing.T) { type fields struct { hasNameConstraints bool permittedDNSDomains []string excludedDNSDomains []string permittedIPRanges []*net.IPNet excludedIPRanges []*net.IPNet permittedEmailAddresses []string excludedEmailAddresses []string permittedURIDomains []string excludedURIDomains []string } type args struct { cert *x509.Certificate } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{hasNameConstraints: false}, args{&x509.Certificate{ DNSNames: []string{"example.com"}, IPAddresses: []net.IP{{127, 0, 0, 1}}, EmailAddresses: []string{"info@example.com"}, URIs: []*url.URL{{Scheme: "https", Host: "uuid.example.com", Path: "/dc4c76b5-5262-4551-a881-48094a604d63"}}, }}, false}, {"ok with constraints", fields{ hasNameConstraints: true, permittedDNSDomains: 
[]string{"example.com"}, permittedIPRanges: []*net.IPNet{ {IP: net.ParseIP("127.0.0.1").To4(), Mask: net.IPMask{255, 255, 255, 255}}, {IP: net.ParseIP("10.3.0.0").To4(), Mask: net.IPMask{255, 255, 0, 0}}, }, permittedEmailAddresses: []string{"example.com"}, permittedURIDomains: []string{".example.com"}, }, args{&x509.Certificate{ DNSNames: []string{"www.example.com"}, IPAddresses: []net.IP{{127, 0, 0, 1}, {10, 3, 1, 1}}, EmailAddresses: []string{"info@example.com"}, URIs: []*url.URL{{Scheme: "https", Host: "uuid.example.com", Path: "/dc4c76b5-5262-4551-a881-48094a604d63"}}, }}, false}, {"fail", fields{ hasNameConstraints: true, permittedURIDomains: []string{".example.com"}, }, args{&x509.Certificate{ DNSNames: []string{"example.com"}, IPAddresses: []net.IP{{127, 0, 0, 1}}, EmailAddresses: []string{"info@example.com"}, URIs: []*url.URL{{Scheme: "https", Host: "uuid.example.org", Path: "/dc4c76b5-5262-4551-a881-48094a604d63"}}, }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := &Engine{ hasNameConstraints: tt.fields.hasNameConstraints, permittedDNSDomains: tt.fields.permittedDNSDomains, excludedDNSDomains: tt.fields.excludedDNSDomains, permittedIPRanges: tt.fields.permittedIPRanges, excludedIPRanges: tt.fields.excludedIPRanges, permittedEmailAddresses: tt.fields.permittedEmailAddresses, excludedEmailAddresses: tt.fields.excludedEmailAddresses, permittedURIDomains: tt.fields.permittedURIDomains, excludedURIDomains: tt.fields.excludedURIDomains, } if err := e.ValidateCertificate(tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Engine.ValidateCertificate() error = %v, wantErr %v", err, tt.wantErr) } }) } } ================================================ FILE: authority/internal/constraints/verify.go ================================================ // Copyright (c) 2009 The Go Authors. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
package constraints import ( "bytes" "fmt" "net" "net/url" "reflect" "strings" ) func checkNameConstraints(nameType, name string, parsedName, permitted, excluded any, match func(name, constraint any) (bool, error)) error { excludedValue := reflect.ValueOf(excluded) for i := 0; i < excludedValue.Len(); i++ { constraint := excludedValue.Index(i).Interface() match, err := match(parsedName, constraint) if err != nil { return ConstraintError{ Type: nameType, Name: name, Detail: err.Error(), } } if match { return ConstraintError{ Type: nameType, Name: name, Detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint), } } } var ( err error ok = true ) permittedValue := reflect.ValueOf(permitted) for i := 0; i < permittedValue.Len(); i++ { constraint := permittedValue.Index(i).Interface() if ok, err = match(parsedName, constraint); err != nil { return ConstraintError{ Type: nameType, Name: name, Detail: err.Error(), } } if ok { break } } if !ok { return ConstraintError{ Type: nameType, Name: name, Detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name), } } return nil } func matchDomainConstraint(domain, constraint string) (bool, error) { // The meaning of zero length constraints is not specified, but this // code follows NSS and accepts them as matching everything. if constraint == "" { return true, nil } domainLabels, ok := domainToReverseLabels(domain) if !ok { return false, fmt.Errorf("internal error: cannot parse domain %q", domain) } // RFC 5280 says that a leading period in a domain name means that at least // one label must be prepended, but only for URI and email constraints, not // DNS constraints. The code also supports that behavior for DNS // constraints. mustHaveSubdomains := false if constraint[0] == '.' 
// normalizeIP returns the 4-byte representation of ip when it is an IPv4
// address (including an IPv4-mapped IPv6 address) and ip unchanged otherwise,
// so addresses of the same family always compare with the same length.
func normalizeIP(ip net.IP) net.IP {
	if ip4 := ip.To4(); ip4 != nil {
		return ip4
	}
	return ip
}

// matchIPConstraint reports whether ip falls inside the network described by
// constraint.
//
// Both the address and the constraint address are normalized with normalizeIP
// first; addresses of different families never match. The mask is normalized
// the same way net.IPNet.Contains does it: a 16-byte mask paired with an IPv4
// constraint address is reduced to its last four bytes. A mask whose length
// cannot be reconciled with the constraint address yields an error instead of
// being indexed out of range, so the caller rejects the name rather than
// silently treating a malformed constraint as a non-match.
func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) {
	ip = normalizeIP(ip)
	constraintIP := normalizeIP(constraint.IP)
	if len(ip) != len(constraintIP) {
		return false, nil
	}

	mask := constraint.Mask
	switch {
	case len(mask) == len(constraintIP):
		// Mask and address already agree.
	case len(mask) == net.IPv6len && len(constraintIP) == net.IPv4len:
		// IPv4 (or IPv4-mapped) address with an IPv6-sized mask: only the
		// trailing four bytes apply to the normalized address.
		mask = mask[net.IPv6len-net.IPv4len:]
	default:
		return false, fmt.Errorf("internal error: IP constraint %q has a mask of %d bytes", constraintIP.String(), len(mask))
	}

	for i := range ip {
		if ip[i]&mask[i] != constraintIP[i]&mask[i] {
			return false, nil
		}
	}

	return true, nil
}
// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
//
// It returns (nil, false) when the name ends with a dot (an absolute name),
// contains an empty label, or contains a character outside the printable,
// non-space ASCII range. A single leading dot is consumed while splitting and
// does not produce a label. The empty string yields no labels and ok == true.
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
	// Peel labels off the right-hand side until nothing is left.
	for rest := domain; rest != ""; {
		dot := strings.LastIndexByte(rest, '.')
		if dot < 0 {
			reverseLabels = append(reverseLabels, rest)
			break
		}
		reverseLabels = append(reverseLabels, rest[dot+1:])
		rest = rest[:dot]
	}

	for _, label := range reverseLabels {
		// An empty first label comes from a trailing dot (an absolute name);
		// empty labels anywhere else are equally invalid.
		if label == "" {
			return nil, false
		}
		for _, r := range label {
			if r < 33 || r > 126 {
				// Invalid character.
				return nil, false
			}
		}
	}

	return reverseLabels, true
}

// rfc2821Mailbox represents a “mailbox” (which is an email address to most
// people) by breaking it into the “local” (i.e. before the '@') and “domain”
// parts.
type rfc2821Mailbox struct {
	local, domain string
}

// isMailboxQText reports whether c may appear unescaped inside a quoted local
// part: any byte in 1-127 except tab, LF, CR, the double quote and the
// backslash. Space (32) is not allowed by the BNF, but RFC 3696 gives an
// example that assumes it is, so it is accepted.
func isMailboxQText(c byte) bool {
	switch c {
	case 0, '\t', '\n', '\r', '"', '\\':
		return false
	}
	return c <= 127
}

// isMailboxQuotedPairText reports whether c may follow a backslash inside a
// quoted local part: any byte in 1-127 except LF and CR. The obsolete
// “obs-” forms from RFC 2822, Section 4 are not accepted.
func isMailboxQuotedPairText(c byte) bool {
	return c >= 1 && c <= 127 && c != '\n' && c != '\r'
}

// isMailboxAText reports whether c is an atext character from RFC 2822,
// Section 3.2.4, or a period.
func isMailboxAText(c byte) bool {
	switch {
	case '0' <= c && c <= '9', 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z':
		return true
	default:
		return strings.IndexByte("!#$%&'*+-/=?^_`{|}~.", c) >= 0
	}
}

// scanQuotedLocalPart reads a quoted-string local part. in must start with the
// opening double quote. It returns the unescaped local part and the input that
// follows the closing quote; ok is false when the string is unterminated or
// contains a byte that is not allowed.
func scanQuotedLocalPart(in string) (local []byte, rest string, ok bool) {
	local = make([]byte, 0, len(in)/2)
	for pos := 1; pos < len(in); pos++ {
		switch c := in[pos]; {
		case c == '"':
			return local, in[pos+1:], true
		case c == '\\':
			// quoted-pair: the next byte is taken verbatim if it is text.
			pos++
			if pos == len(in) || !isMailboxQuotedPairText(in[pos]) {
				return nil, "", false
			}
			local = append(local, in[pos])
		case isMailboxQText(c):
			local = append(local, c)
		default:
			return nil, "", false
		}
	}
	// Ran out of input before the closing quote.
	return nil, "", false
}

// scanAtomLocalPart reads a dot-atom local part (Atom ("." Atom)*). Scanning
// stops at the first byte that cannot be part of it, and that remainder is
// returned as rest. Examples given in RFC 3696 suggest that escaped characters
// can appear outside of a quoted string, so a backslash makes the next byte
// part of the local part verbatim. ok is false when the local part is empty,
// ends in a dangling backslash, starts or ends with a period, or contains two
// consecutive periods (RFC 3696, Section 3).
func scanAtomLocalPart(in string) (local []byte, rest string, ok bool) {
	local = make([]byte, 0, len(in)/2)
	pos := 0
	for pos < len(in) {
		c := in[pos]
		if c == '\\' {
			pos++
			if pos == len(in) {
				return nil, "", false
			}
		} else if !isMailboxAText(c) {
			break
		}
		local = append(local, in[pos])
		pos++
	}

	if len(local) == 0 {
		return nil, "", false
	}
	if local[0] == '.' || local[len(local)-1] == '.' || bytes.Contains(local, []byte("..")) {
		return nil, "", false
	}
	return local, in[pos:], true
}

// parseRFC2821Mailbox parses an email address into local and domain parts,
// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280,
// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The
// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”.
//
// ok is false, and mailbox is the zero value, when in cannot be parsed.
func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) {
	if in == "" {
		return mailbox, false
	}

	var (
		local []byte
		rest  string
		valid bool
	)
	if in[0] == '"' {
		local, rest, valid = scanQuotedLocalPart(in)
	} else {
		local, rest, valid = scanAtomLocalPart(in)
	}
	if !valid || !strings.HasPrefix(rest, "@") {
		return mailbox, false
	}

	// The RFC specifies a format for domains, but that's known to be
	// violated in practice so we accept that anything after an '@' is the
	// domain part.
	domain := rest[1:]
	if _, domainOK := domainToReverseLabels(domain); !domainOK {
		return mailbox, false
	}

	return rfc2821Mailbox{local: string(local), domain: domain}, true
}
u, err := url.Parse(claims.Audience[0]) if err != nil { return nil, errors.New("error parsing token: invalid aud claim") } // Get authority from SANs authority, err := getAuthority(claims.SANs) if err != nil { return nil, err } // Create csr to login with signer, err := keyutil.GenerateDefaultSigner() if err != nil { return nil, err } csr, err := x509util.CreateCertificateRequest(claims.Subject, claims.SANs, signer) if err != nil { return nil, err } // Get and verify root certificate root, err := getRootCertificate(u.Host, claims.SHA) if err != nil { return nil, err } pool := x509.NewCertPool() pool.AddCert(root) // Login with majordomo and get certificates cert, tlsConfig, err := login(authority, token, csr, signer, u.Host, pool) if err != nil { return nil, err } // Start TLS renewer and set the GetClientCertificate callback to it. renewer, err := tlsutil.NewRenewer(cert, tlsConfig, func() (*tls.Certificate, *tls.Config, error) { return login(authority, token, csr, signer, u.Host, pool) }) if err != nil { return nil, err } tlsConfig.GetClientCertificate = renewer.GetClientCertificate // Start mTLS client conn, err := grpc.NewClient(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) if err != nil { return nil, errors.Wrapf(err, "error connecting %s", u.Host) } return &linkedCaClient{ renewer: renewer, client: linkedca.NewMajordomoClient(conn), authorityID: authority, }, nil } // IsLinkedCA is a sentinel function that can be used to // check if a linkedCaClient is the underlying type of an // admin.DB interface. 
func (c *linkedCaClient) IsLinkedCA() bool { return true } func (c *linkedCaClient) Run() { c.renewer.Run() } func (c *linkedCaClient) Stop() { c.renewer.Stop() } func (c *linkedCaClient) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { resp, err := c.client.CreateProvisioner(ctx, &linkedca.CreateProvisionerRequest{ Type: prov.Type, Name: prov.Name, Details: prov.Details, Claims: prov.Claims, X509Template: prov.X509Template, SshTemplate: prov.SshTemplate, }) if err != nil { return errors.Wrap(err, "error creating provisioner") } prov.Id = resp.Id prov.AuthorityId = resp.AuthorityId return nil } func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) { resp, err := c.client.GetProvisioner(ctx, &linkedca.GetProvisionerRequest{ Id: id, }) if err != nil { return nil, errors.Wrap(err, "error getting provisioners") } return resp, nil } func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { resp, err := c.GetConfiguration(ctx) if err != nil { return nil, err } return resp.Provisioners, nil } func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) { resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ AuthorityId: c.authorityID, }) if err != nil { return nil, errors.Wrap(err, "error getting configuration") } return resp, nil } func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { _, err := c.client.UpdateProvisioner(ctx, &linkedca.UpdateProvisionerRequest{ Id: prov.Id, Name: prov.Name, Details: prov.Details, Claims: prov.Claims, X509Template: prov.X509Template, SshTemplate: prov.SshTemplate, }) return errors.Wrap(err, "error updating provisioner") } func (c *linkedCaClient) DeleteProvisioner(ctx context.Context, id string) error { _, err := c.client.DeleteProvisioner(ctx, &linkedca.DeleteProvisionerRequest{ Id: id, }) return errors.Wrap(err, 
"error deleting provisioner") } func (c *linkedCaClient) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error { resp, err := c.client.CreateAdmin(ctx, &linkedca.CreateAdminRequest{ Subject: adm.Subject, ProvisionerId: adm.ProvisionerId, Type: adm.Type, }) if err != nil { return errors.Wrap(err, "error creating admin") } adm.Id = resp.Id adm.AuthorityId = resp.AuthorityId return nil } func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) { resp, err := c.client.GetAdmin(ctx, &linkedca.GetAdminRequest{ Id: id, }) if err != nil { return nil, errors.Wrap(err, "error getting admins") } return resp, nil } func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { resp, err := c.GetConfiguration(ctx) if err != nil { return nil, err } return resp.Admins, nil } func (c *linkedCaClient) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error { _, err := c.client.UpdateAdmin(ctx, &linkedca.UpdateAdminRequest{ Id: adm.Id, Type: adm.Type, }) return errors.Wrap(err, "error updating admin") } func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error { _, err := c.client.DeleteAdmin(ctx, &linkedca.DeleteAdminRequest{ Id: id, }) return errors.Wrap(err, "error deleting admin") } func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, error) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() resp, err := c.client.GetCertificate(ctx, &linkedca.GetCertificateRequest{ Serial: serial, }) if err != nil { return nil, err } var pd *db.ProvisionerData if p := resp.Provisioner; p != nil { pd = &db.ProvisionerData{ ID: p.Id, Name: p.Name, Type: p.Type.String(), } } var raInfo *provisioner.RAInfo if p := resp.RaProvisioner; p != nil && p.Provisioner != nil { raInfo = &provisioner.RAInfo{ AuthorityID: p.AuthorityId, ProvisionerID: p.Provisioner.Id, ProvisionerType: p.Provisioner.Type.String(), ProvisionerName: p.Provisioner.Name, } } return 
&db.CertificateData{ Provisioner: pd, RaInfo: raInfo, }, nil } func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() raProvisioner, endpointID := createRegistrationAuthorityProvisioner(p) _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificateChain: serializeCertificateChain(fullchain[1:]...), Provisioner: createProvisionerIdentity(p), AttestationData: createAttestationData(p), RaProvisioner: raProvisioner, EndpointId: endpointID, }) return errors.Wrap(err, "error posting certificate") } func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullchain ...*x509.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificateChain: serializeCertificateChain(fullchain[1:]...), PemParentCertificate: serializeCertificateChain(parent), }) return errors.Wrap(err, "error posting renewed certificate") } func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ Certificate: string(ssh.MarshalAuthorizedKey(crt)), Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting ssh certificate") } func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ Certificate: string(ssh.MarshalAuthorizedKey(crt)), 
ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)), Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting renewed ssh certificate") } func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.RevokeCertificate(ctx, &linkedca.RevokeCertificateRequest{ Serial: rci.Serial, PemCertificate: serializeCertificate(crt), Reason: rci.Reason, ReasonCode: linkedca.RevocationReasonCode(cast.Int32(rci.ReasonCode)), Passive: true, }) return errors.Wrap(err, "error revoking certificate") } func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{ Serial: rci.Serial, Certificate: serializeSSHCertificate(cert), Reason: rci.Reason, ReasonCode: linkedca.RevocationReasonCode(cast.Int32(rci.ReasonCode)), Passive: true, }) return errors.Wrap(err, "error revoking ssh certificate") } func (c *linkedCaClient) IsRevoked(serial string) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{ Serial: serial, }) if err != nil { return false, errors.Wrap(err, "error getting certificate status") } return resp.Status != linkedca.RevocationStatus_ACTIVE, nil } func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{ Serial: serial, }) if err != nil { return false, errors.Wrap(err, "error getting certificate status") } return resp.Status != linkedca.RevocationStatus_ACTIVE, nil } func (c 
*linkedCaClient) CreateAuthorityPolicy(_ context.Context, _ *linkedca.Policy) error { return errors.New("not implemented yet") } func (c *linkedCaClient) GetAuthorityPolicy(context.Context) (*linkedca.Policy, error) { return nil, errors.New("not implemented yet") } func (c *linkedCaClient) UpdateAuthorityPolicy(_ context.Context, _ *linkedca.Policy) error { return errors.New("not implemented yet") } func (c *linkedCaClient) DeleteAuthorityPolicy(context.Context) error { return errors.New("not implemented yet") } func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity { if p == nil { return nil } return &linkedca.ProvisionerIdentity{ Id: p.GetID(), Type: linkedca.Provisioner_Type(cast.Int32(int(p.GetType()))), Name: p.GetName(), } } func createRegistrationAuthorityProvisioner(p provisioner.Interface) (*linkedca.RegistrationAuthorityProvisioner, string) { if rap, ok := p.(raProvisioner); ok { if info := rap.RAInfo(); info != nil { typ := linkedca.Provisioner_Type_value[strings.ToUpper(info.ProvisionerType)] return &linkedca.RegistrationAuthorityProvisioner{ AuthorityId: info.AuthorityID, Provisioner: &linkedca.ProvisionerIdentity{ Id: info.ProvisionerID, Type: linkedca.Provisioner_Type(typ), Name: info.ProvisionerName, }, }, info.EndpointID } } return nil, "" } func createAttestationData(p provisioner.Interface) *linkedca.AttestationData { if ap, ok := p.(attProvisioner); ok { if data := ap.AttestationData(); data != nil { return &linkedca.AttestationData{ PermanentIdentifier: data.PermanentIdentifier, } } } return nil } func serializeCertificate(crt *x509.Certificate) string { if crt == nil { return "" } return string(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, })) } func serializeCertificateChain(fullchain ...*x509.Certificate) string { var chain string for _, crt := range fullchain { chain += string(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, })) } return chain } func 
serializeSSHCertificate(crt *ssh.Certificate) string { if crt == nil { return "" } return string(ssh.MarshalAuthorizedKey(crt)) } func getAuthority(sans []string) (string, error) { for _, s := range sans { if strings.HasPrefix(s, "urn:smallstep:authority:") { if regexp.MustCompile(uuidPattern).MatchString(s[24:]) { return s[24:], nil } } } return "", fmt.Errorf("error parsing token: invalid sans claim") } // getRootCertificate creates an insecure majordomo client and returns the // verified root certificate. func getRootCertificate(endpoint, fingerprint string) (*x509.Certificate, error) { conn, err := grpc.NewClient(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ //nolint:gosec // used in bootstrap protocol InsecureSkipVerify: true, // lgtm[go/disabled-certificate-check] }))) if err != nil { return nil, errors.Wrapf(err, "error connecting %s", endpoint) } ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() client := linkedca.NewMajordomoClient(conn) resp, err := client.GetRootCertificate(ctx, &linkedca.GetRootCertificateRequest{ Fingerprint: fingerprint, }) if err != nil { return nil, fmt.Errorf("error getting root certificate: %w", err) } var block *pem.Block b := []byte(resp.PemCertificate) for len(b) > 0 { block, b = pem.Decode(b) if block == nil { break } if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { continue } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, fmt.Errorf("error parsing certificate: %w", err) } // verify the sha256 sum := sha256.Sum256(cert.Raw) if !strings.EqualFold(fingerprint, hex.EncodeToString(sum[:])) { return nil, fmt.Errorf("error verifying certificate: SHA256 fingerprint does not match") } return cert, nil } return nil, fmt.Errorf("error getting root certificate: certificate not found") } // login creates a new majordomo client with just the root ca pool and returns // the signed certificate and tls configuration. 
func login(authority, token string, csr *x509.CertificateRequest, signer crypto.PrivateKey, endpoint string, rootCAs *x509.CertPool) (*tls.Certificate, *tls.Config, error) { conn, err := grpc.NewClient(endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: rootCAs, }))) if err != nil { return nil, nil, errors.Wrapf(err, "error connecting %s", endpoint) } // Login to get the signed certificate ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() client := linkedca.NewMajordomoClient(conn) resp, err := client.Login(ctx, &linkedca.LoginRequest{ AuthorityId: authority, Token: token, PemCertificateRequest: string(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE REQUEST", Bytes: csr.Raw, })), }) if err != nil { return nil, nil, errors.Wrapf(err, "error logging in %s", endpoint) } // Parse login response var block *pem.Block var bundle []*x509.Certificate rest := []byte(resp.PemCertificateChain) for { block, rest = pem.Decode(rest) if block == nil { break } if block.Type != "CERTIFICATE" { return nil, nil, errors.New("error decoding login response: pemCertificateChain is not a certificate bundle") } crt, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, nil, errors.Wrap(err, "error parsing login response") } bundle = append(bundle, crt) } if len(bundle) == 0 { return nil, nil, errors.New("error decoding login response: pemCertificateChain should not be empty") } // Build tls.Certificate with PemCertificate and intermediates in the // PemCertificateChain cert := &tls.Certificate{ PrivateKey: signer, } rest = []byte(resp.PemCertificate) for { block, rest = pem.Decode(rest) if block == nil { break } if block.Type == "CERTIFICATE" { leaf, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, nil, errors.Wrap(err, "error parsing pemCertificate") } cert.Certificate = append(cert.Certificate, block.Bytes) cert.Leaf = leaf } } // Add intermediates to the 
tls.Certificate last := len(bundle) - 1 for i := 0; i < last; i++ { cert.Certificate = append(cert.Certificate, bundle[i].Raw) } // Add root to the pool if it's not there yet rootCAs.AddCert(bundle[last]) return cert, &tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: rootCAs, }, nil } ================================================ FILE: authority/meter.go ================================================ package authority import ( "crypto" "crypto/x509" "io" "go.step.sm/crypto/kms" kmsapi "go.step.sm/crypto/kms/apiv1" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/authority/provisioner" ) // Meter wraps the set of defined callbacks for metrics gatherers. type Meter interface { // X509Signed is called whenever an X509 certificate is signed. X509Signed([]*x509.Certificate, provisioner.Interface, error) // X509Renewed is called whenever an X509 certificate is renewed. X509Renewed([]*x509.Certificate, provisioner.Interface, error) // X509Rekeyed is called whenever an X509 certificate is rekeyed. X509Rekeyed([]*x509.Certificate, provisioner.Interface, error) // X509WebhookAuthorized is called whenever an X509 authoring webhook is called. X509WebhookAuthorized(provisioner.Interface, error) // X509WebhookEnriched is called whenever an X509 enriching webhook is called. X509WebhookEnriched(provisioner.Interface, error) // SSHSigned is called whenever an SSH certificate is signed. SSHSigned(*ssh.Certificate, provisioner.Interface, error) // SSHRenewed is called whenever an SSH certificate is renewed. SSHRenewed(*ssh.Certificate, provisioner.Interface, error) // SSHRekeyed is called whenever an SSH certificate is rekeyed. SSHRekeyed(*ssh.Certificate, provisioner.Interface, error) // SSHWebhookAuthorized is called whenever an SSH authoring webhook is called. SSHWebhookAuthorized(provisioner.Interface, error) // SSHWebhookEnriched is called whenever an SSH enriching webhook is called. 
SSHWebhookEnriched(provisioner.Interface, error) // KMSSigned is called per KMS signer signature. KMSSigned(error) } // noopMeter implements a noop [Meter]. type noopMeter struct{} func (noopMeter) SSHRekeyed(*ssh.Certificate, provisioner.Interface, error) {} func (noopMeter) SSHRenewed(*ssh.Certificate, provisioner.Interface, error) {} func (noopMeter) SSHSigned(*ssh.Certificate, provisioner.Interface, error) {} func (noopMeter) SSHWebhookAuthorized(provisioner.Interface, error) {} func (noopMeter) SSHWebhookEnriched(provisioner.Interface, error) {} func (noopMeter) X509Rekeyed([]*x509.Certificate, provisioner.Interface, error) {} func (noopMeter) X509Renewed([]*x509.Certificate, provisioner.Interface, error) {} func (noopMeter) X509Signed([]*x509.Certificate, provisioner.Interface, error) {} func (noopMeter) X509WebhookAuthorized(provisioner.Interface, error) {} func (noopMeter) X509WebhookEnriched(provisioner.Interface, error) {} func (noopMeter) KMSSigned(error) {} type instrumentedKeyManager struct { kms.KeyManager meter Meter } type instrumentedKeyAndDecrypterManager struct { kms.KeyManager decrypter kmsapi.Decrypter meter Meter } func newInstrumentedKeyManager(k kms.KeyManager, m Meter) kms.KeyManager { decrypter, isDecrypter := k.(kmsapi.Decrypter) switch { case isDecrypter: return &instrumentedKeyAndDecrypterManager{&instrumentedKeyManager{k, m}, decrypter, m} default: return &instrumentedKeyManager{k, m} } } func (i *instrumentedKeyManager) CreateSigner(req *kmsapi.CreateSignerRequest) (s crypto.Signer, err error) { if s, err = i.KeyManager.CreateSigner(req); err == nil { s = &instrumentedKMSSigner{s, i.meter} } return } func (i *instrumentedKeyAndDecrypterManager) CreateDecrypter(req *kmsapi.CreateDecrypterRequest) (s crypto.Decrypter, err error) { return i.decrypter.CreateDecrypter(req) } type instrumentedKMSSigner struct { crypto.Signer meter Meter } func (i *instrumentedKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature 
[]byte, err error) { signature, err = i.Signer.Sign(rand, digest, opts) i.meter.KMSSigned(err) return } var _ kms.KeyManager = (*instrumentedKeyManager)(nil) var _ kms.KeyManager = (*instrumentedKeyAndDecrypterManager)(nil) var _ kmsapi.Decrypter = (*instrumentedKeyAndDecrypterManager)(nil) ================================================ FILE: authority/options.go ================================================ package authority import ( "context" "crypto" "crypto/x509" "encoding/pem" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "go.step.sm/crypto/kms" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/scep" ) // Option sets options to the Authority. type Option func(*Authority) error // WithConfig replaces the current config with the given one. No validation is // performed in the given value. func WithConfig(cfg *config.Config) Option { return func(a *Authority) error { a.config = cfg return nil } } // WithConfigFile reads the given filename as a configuration file and replaces // the current one. No validation is performed in the given configuration. func WithConfigFile(filename string) Option { return func(a *Authority) (err error) { a.config, err = config.LoadConfiguration(filename) return } } // WithPassword set the password to decrypt the intermediate key as well as the // ssh host and user keys if they are not overridden by other options. func WithPassword(password []byte) Option { return func(a *Authority) (err error) { a.password = password return } } // WithSSHHostPassword set the password to decrypt the key used to sign SSH host // certificates. 
func WithSSHHostPassword(password []byte) Option { return func(a *Authority) (err error) { a.sshHostPassword = password return } } // WithSSHUserPassword set the password to decrypt the key used to sign SSH user // certificates. func WithSSHUserPassword(password []byte) Option { return func(a *Authority) (err error) { a.sshUserPassword = password return } } // WithIssuerPassword set the password to decrypt the certificate issuer private // key used in RA mode. func WithIssuerPassword(password []byte) Option { return func(a *Authority) (err error) { a.issuerPassword = password return } } // WithDatabase sets an already initialized authority database to a new // authority. This option is intended to be use on graceful reloads. func WithDatabase(d db.AuthDB) Option { return func(a *Authority) error { a.db = d return nil } } // WithQuietInit disables log output when the authority is initialized. func WithQuietInit() Option { return func(a *Authority) error { a.quietInit = true return nil } } // WithWebhookClient sets the http.Client to be used for outbound requests. func WithWebhookClient(c provisioner.HTTPClient) Option { return func(a *Authority) error { a.webhookClient = c return nil } } // Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper]. type TransportWrapper = httptransport.Wrapper // WithTransportWrapper sets the transport wrapper of the authority to the provided one or, in case // that one is nil, to a noop one. func WithTransportWrapper(tw httptransport.Wrapper) Option { if tw == nil { tw = httptransport.NoopWrapper() } return func(a *Authority) error { a.wrapTransport = tw return nil } } // WithGetIdentityFunc sets a custom function to retrieve the identity from // an external resource. 
func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { return func(a *Authority) error { a.getIdentityFunc = fn return nil } } // WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of // an X.509 certificate. func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option { return func(a *Authority) error { a.authorizeRenewFunc = fn return nil } } // WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal // of a SSH certificate. func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option { return func(a *Authority) error { a.authorizeSSHRenewFunc = fn return nil } } // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { return func(a *Authority) error { a.sshBastionFunc = fn return nil } } // WithSSHGetHosts sets a custom function to return a list of step ssh enabled // hosts. func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)) Option { return func(a *Authority) error { a.sshGetHostsFunc = fn return nil } } // WithSSHCheckHost sets a custom function to check whether a given host is // step ssh enabled. The token is used to validate the request, while the roots // are used to validate the token. func WithSSHCheckHost(fn func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)) Option { return func(a *Authority) error { a.sshCheckHostFunc = fn return nil } } // WithKeyManager defines the key manager used to get and create keys, and sign // certificates. 
func WithKeyManager(k kms.KeyManager) Option { return func(a *Authority) error { a.keyManager = k return nil } } // WithX509CAService allows the consumer to provide an externally implemented // API implementation of apiv1.CertificateAuthorityService func WithX509CAService(svc casapi.CertificateAuthorityService) Option { return func(a *Authority) error { a.x509CAService = svc return nil } } // WithX509Signer defines the signer used to sign X509 certificates. func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option { return WithX509SignerChain([]*x509.Certificate{crt}, s) } // WithX509SignerChain defines the signer used to sign X509 certificates. This // option is similar to WithX509Signer but it supports a chain of intermediates. func WithX509SignerChain(issuerChain []*x509.Certificate, s crypto.Signer) Option { return func(a *Authority) error { srv, err := cas.New(context.Background(), casapi.Options{ Type: casapi.SoftCAS, Signer: s, CertificateChain: issuerChain, }) if err != nil { return err } a.x509CAService = srv a.intermediateX509Certs = append(a.intermediateX509Certs, issuerChain...) return nil } } // WithX509SignerFunc defines the function used to get the chain of certificates // and signer used when we sign X.509 certificates. func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option { return func(a *Authority) error { srv, err := cas.New(context.Background(), casapi.Options{ Type: casapi.SoftCAS, CertificateSigner: fn, }) if err != nil { return err } a.x509CAService = srv return nil } } // WithFullSCEPOptions defines the options used for SCEP support. // // This feature is EXPERIMENTAL and might change at any time. func WithFullSCEPOptions(options *scep.Options) Option { return func(a *Authority) error { a.scepOptions = options a.validateSCEP = false return nil } } // WithSCEPKeyManager defines the key manager used on SCEP provisioners. // // This feature is EXPERIMENTAL and might change at any time. 
func WithSCEPKeyManager(skm provisioner.SCEPKeyManager) Option { return func(a *Authority) error { a.scepKeyManager = skm return nil } } // WithSSHUserSigner defines the signer used to sign SSH user certificates. func WithSSHUserSigner(s crypto.Signer) Option { return func(a *Authority) error { signer, err := ssh.NewSignerFromSigner(s) if err != nil { return errors.Wrap(err, "error creating ssh user signer") } a.sshCAUserCertSignKey = signer // Append public key to list of user certs pub := signer.PublicKey() a.sshCAUserCerts = append(a.sshCAUserCerts, pub) a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, pub) return nil } } // WithSSHHostSigner defines the signer used to sign SSH host certificates. func WithSSHHostSigner(s crypto.Signer) Option { return func(a *Authority) error { signer, err := ssh.NewSignerFromSigner(s) if err != nil { return errors.Wrap(err, "error creating ssh host signer") } a.sshCAHostCertSignKey = signer // Append public key to list of host certs pub := signer.PublicKey() a.sshCAHostCerts = append(a.sshCAHostCerts, pub) a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, pub) return nil } } // WithX509RootCerts is an option that allows to define the list of root // certificates to use. This option will replace any root certificate defined // before. func WithX509RootCerts(rootCerts ...*x509.Certificate) Option { return func(a *Authority) error { a.rootX509Certs = rootCerts return nil } } // WithX509FederatedCerts is an option that allows to define the list of // federated certificates. This option will replace any federated certificate // defined before. func WithX509FederatedCerts(certs ...*x509.Certificate) Option { return func(a *Authority) error { a.federatedX509Certs = certs return nil } } // WithX509IntermediateCerts is an option that allows to define the list of // intermediate certificates that the CA will be using. This option will replace // any intermediate certificate defined before. 
// // Note that these certificates will not be bundled with the certificates signed // by the CA, because the CAS service will take care of that. They should match, // but that's not guaranteed. These certificates will be mainly used for name // constraint validation before a certificate is issued. // // This option should only be used on specific configurations, for example when // WithX509SignerFunc is used, as we don't know the list of intermediates in // advance. func WithX509IntermediateCerts(intermediateCerts ...*x509.Certificate) Option { return func(a *Authority) error { a.intermediateX509Certs = intermediateCerts return nil } } // WithX509RootBundle is an option that allows to define the list of root // certificates. This option will replace any root certificate defined before. func WithX509RootBundle(pemCerts []byte) Option { return func(a *Authority) error { certs, err := readCertificateBundle(pemCerts) if err != nil { return err } a.rootX509Certs = certs return nil } } // WithX509FederatedBundle is an option that allows to define the list of // federated certificates. This option will replace any federated certificate // defined before. func WithX509FederatedBundle(pemCerts []byte) Option { return func(a *Authority) error { certs, err := readCertificateBundle(pemCerts) if err != nil { return err } a.federatedX509Certs = certs return nil } } // WithAdminDB is an option to set the database backing the admin APIs. func WithAdminDB(d admin.DB) Option { return func(a *Authority) error { a.adminDB = d return nil } } // WithProvisioners is an option to set the provisioner collection. // // Deprecated: provisioner collections will likely change func WithProvisioners(ps *provisioner.Collection) Option { return func(a *Authority) error { a.provisioners = ps return nil } } // WithLinkedCAToken is an option to set the authentication token used to enable // linked ca. 
func WithLinkedCAToken(token string) Option { return func(a *Authority) error { a.linkedCAToken = token return nil } } // WithX509Enforcers is an option that allows to define custom certificate // modifiers that will be processed just before the signing of the certificate. func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option { return func(a *Authority) error { a.x509Enforcers = ces return nil } } // WithSkipInit is an option that allows the constructor to skip initializtion // of the authority. func WithSkipInit() Option { return func(a *Authority) error { a.skipInit = true return nil } } func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) { var block *pem.Block var certs []*x509.Certificate for len(pemCerts) > 0 { block, pemCerts = pem.Decode(pemCerts) if block == nil { break } if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { continue } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, err } certs = append(certs, cert) } return certs, nil } // WithMeter is an option that sets the authority's [Meter] to the provided one. func WithMeter(m Meter) Option { if m == nil { m = noopMeter{} } return func(a *Authority) (_ error) { a.meter = m return } } ================================================ FILE: authority/policy/engine.go ================================================ package policy import ( "crypto/x509" "errors" "fmt" "golang.org/x/crypto/ssh" ) // Engine is a container for multiple policies. type Engine struct { x509Policy X509Policy sshUserPolicy UserPolicy sshHostPolicy HostPolicy } // New returns a new Engine using Options. 
func New(options *Options) (*Engine, error) { // if no options provided, return early if options == nil { //nolint:nilnil // legacy return nil, nil } var ( x509Policy X509Policy sshHostPolicy HostPolicy sshUserPolicy UserPolicy err error ) // initialize the x509 allow/deny policy engine if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil { return nil, err } // initialize the SSH allow/deny policy engine for host certificates if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { return nil, err } // initialize the SSH allow/deny policy engine for user certificates if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { return nil, err } return &Engine{ x509Policy: x509Policy, sshHostPolicy: sshHostPolicy, sshUserPolicy: sshUserPolicy, }, nil } // IsX509CertificateAllowed evaluates an X.509 certificate against // the X.509 policy (if available) and returns an error if one of the // names in the certificate is not allowed. func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error { // return early if there's no policy to evaluate if e == nil || e.x509Policy == nil { return nil } // return result of X.509 policy evaluation return e.x509Policy.IsX509CertificateAllowed(cert) } // AreSANsAllowed evaluates the slice of SANs against the X.509 policy // (if available) and returns an error if one of the SANs is not allowed. func (e *Engine) AreSANsAllowed(sans []string) error { // return early if there's no policy to evaluate if e == nil || e.x509Policy == nil { return nil } // return result of X.509 policy evaluation return e.x509Policy.AreSANsAllowed(sans) } // IsSSHCertificateAllowed evaluates an SSH certificate against the // user or host policy (if configured) and returns an error if one of the // principals in the certificate is not allowed. 
func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error { // return early if there's no policy to evaluate if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) { return nil } switch cert.CertType { case ssh.HostCert: // when no host policy engine is configured, but a user policy engine is // configured, the host certificate is denied. if e.sshHostPolicy == nil && e.sshUserPolicy != nil { return errors.New("authority not allowed to sign SSH host certificates when SSH user certificate policy is active") } // return result of SSH host policy evaluation return e.sshHostPolicy.IsSSHCertificateAllowed(cert) case ssh.UserCert: // when no user policy engine is configured, but a host policy engine is // configured, the user certificate is denied. if e.sshUserPolicy == nil && e.sshHostPolicy != nil { return errors.New("authority not allowed to sign SSH user certificates when SSH host certificate policy is active") } // return result of SSH user policy evaluation return e.sshUserPolicy.IsSSHCertificateAllowed(cert) default: return fmt.Errorf("unexpected SSH certificate type %q", cert.CertType) } } ================================================ FILE: authority/policy/options.go ================================================ package policy // Options is a container for authority level x509 and SSH // policy configuration. type Options struct { X509 *X509PolicyOptions `json:"x509,omitempty"` SSH *SSHPolicyOptions `json:"ssh,omitempty"` } // GetX509Options returns the x509 authority level policy // configuration func (o *Options) GetX509Options() *X509PolicyOptions { if o == nil { return nil } return o.X509 } // GetSSHOptions returns the SSH authority level policy // configuration func (o *Options) GetSSHOptions() *SSHPolicyOptions { if o == nil { return nil } return o.SSH } // X509PolicyOptionsInterface is an interface for providers // of x509 allowed and denied names. 
type X509PolicyOptionsInterface interface {
	GetAllowedNameOptions() *X509NameOptions
	GetDeniedNameOptions() *X509NameOptions
	AreWildcardNamesAllowed() bool
}

// X509PolicyOptions is a container for x509 allowed and denied
// names.
type X509PolicyOptions struct {
	// AllowedNames contains the x509 allowed names
	AllowedNames *X509NameOptions `json:"allow,omitempty"`

	// DeniedNames contains the x509 denied names
	DeniedNames *X509NameOptions `json:"deny,omitempty"`

	// AllowWildcardNames indicates if literal wildcard names
	// like *.example.com are allowed. Defaults to false.
	AllowWildcardNames bool `json:"allowWildcardNames,omitempty"`
}

// X509NameOptions models the X509 name policy configuration.
type X509NameOptions struct {
	CommonNames    []string `json:"cn,omitempty"`
	DNSDomains     []string `json:"dns,omitempty"`
	IPRanges       []string `json:"ip,omitempty"`
	EmailAddresses []string `json:"email,omitempty"`
	URIDomains     []string `json:"uri,omitempty"`
}

// HasNames reports whether the X509NameOptions has one or more names
// configured. It is safe to call on a nil receiver, which reports false,
// in line with the nil-safe accessors on X509PolicyOptions.
func (o *X509NameOptions) HasNames() bool {
	if o == nil {
		return false
	}
	return len(o.CommonNames) > 0 ||
		len(o.DNSDomains) > 0 ||
		len(o.IPRanges) > 0 ||
		len(o.EmailAddresses) > 0 ||
		len(o.URIDomains) > 0
}

// GetAllowedNameOptions returns x509 allowed name policy configuration.
// A nil receiver yields nil.
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
	if o == nil {
		return nil
	}
	return o.AllowedNames
}

// GetDeniedNameOptions returns the x509 denied name policy configuration.
// A nil receiver yields nil.
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
	if o == nil {
		return nil
	}
	return o.DeniedNames
}

// AreWildcardNamesAllowed returns whether the authority allows
// literal wildcard names to be signed. A nil receiver (no x509 policy
// options at all) reports true; otherwise the AllowWildcardNames field
// decides.
func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool {
	if o == nil {
		return true
	}
	return o.AllowWildcardNames
}

// SSHPolicyOptionsInterface is an interface for providers of
// SSH user and host name policy configuration.
type SSHPolicyOptionsInterface interface { GetAllowedUserNameOptions() *SSHNameOptions GetDeniedUserNameOptions() *SSHNameOptions GetAllowedHostNameOptions() *SSHNameOptions GetDeniedHostNameOptions() *SSHNameOptions } // SSHPolicyOptions is a container for SSH user and host policy // configuration type SSHPolicyOptions struct { // User contains SSH user certificate options. User *SSHUserCertificateOptions `json:"user,omitempty"` // Host contains SSH host certificate options. Host *SSHHostCertificateOptions `json:"host,omitempty"` } // GetAllowedUserNameOptions returns the SSH allowed user name policy // configuration. func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions { if o == nil || o.User == nil { return nil } return o.User.AllowedNames } // GetDeniedUserNameOptions returns the SSH denied user name policy // configuration. func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions { if o == nil || o.User == nil { return nil } return o.User.DeniedNames } // GetAllowedHostNameOptions returns the SSH allowed host name policy // configuration. func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions { if o == nil || o.Host == nil { return nil } return o.Host.AllowedNames } // GetDeniedHostNameOptions returns the SSH denied host name policy // configuration. func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions { if o == nil || o.Host == nil { return nil } return o.Host.DeniedNames } // SSHUserCertificateOptions is a collection of SSH user certificate options. type SSHUserCertificateOptions struct { // AllowedNames contains the names the provisioner is authorized to sign AllowedNames *SSHNameOptions `json:"allow,omitempty"` // DeniedNames contains the names the provisioner is not authorized to sign DeniedNames *SSHNameOptions `json:"deny,omitempty"` } // SSHHostCertificateOptions is a collection of SSH host certificate options. 
// It's an alias of SSHUserCertificateOptions, as the options are the same // for both types of certificates. type SSHHostCertificateOptions SSHUserCertificateOptions // SSHNameOptions models the SSH name policy configuration. type SSHNameOptions struct { DNSDomains []string `json:"dns,omitempty"` IPRanges []string `json:"ip,omitempty"` EmailAddresses []string `json:"email,omitempty"` Principals []string `json:"principal,omitempty"` } // GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the // names that a provisioner is authorized to sign SSH certificates for. func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions { if o == nil { return nil } return o.AllowedNames } // GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the // names that a provisioner is NOT authorized to sign SSH certificates for. func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions { if o == nil { return nil } return o.DeniedNames } // HasNames checks if the SSHNameOptions has one or more // names configured. 
func (o *SSHNameOptions) HasNames() bool {
	return len(o.DNSDomains) > 0 ||
		len(o.IPRanges) > 0 ||
		len(o.EmailAddresses) > 0 ||
		len(o.Principals) > 0
}

================================================
FILE: authority/policy/options_test.go
================================================
package policy

import (
	"testing"
)

// TestX509PolicyOptions_IsWildcardLiteralAllowed verifies the result of
// AreWildcardNamesAllowed for nil, zero-value and explicitly set options.
func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) {
	tests := []struct {
		name    string
		options *X509PolicyOptions
		want    bool
	}{
		{
			name:    "nil-options",
			options: nil,
			want:    true,
		},
		{
			name:    "not-set",
			options: &X509PolicyOptions{},
			want:    false,
		},
		{
			name: "set-true",
			options: &X509PolicyOptions{
				AllowWildcardNames: true,
			},
			want: true,
		},
		{
			name: "set-false",
			options: &X509PolicyOptions{
				AllowWildcardNames: false,
			},
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.options.AreWildcardNamesAllowed(); got != tt.want {
				t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want)
			}
		})
	}
}

================================================
FILE: authority/policy/policy.go
================================================
package policy

import (
	"fmt"

	"github.com/smallstep/linkedca"

	"github.com/smallstep/certificates/policy"
)

// X509Policy is an alias for policy.X509NamePolicyEngine
type X509Policy policy.X509NamePolicyEngine

// UserPolicy is an alias for policy.SSHNamePolicyEngine
type UserPolicy policy.SSHNamePolicyEngine

// HostPolicy is an alias for policy.SSHNamePolicyEngine
type HostPolicy policy.SSHNamePolicyEngine

// NewX509PolicyEngine creates a new x509 name policy engine from the allowed
// and denied names in policyOptions. It returns a nil engine and a nil error
// when policyOptions is nil or when neither the allowed nor the denied names
// contain any entries.
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
	// return early if no policy engine options to configure
	if policyOptions == nil {
		//nolint:nilnil,nolintlint // expected values
		return nil, nil
	}

	options := []policy.NamePolicyOption{}

	allowed := policyOptions.GetAllowedNameOptions()
	if allowed != nil && allowed.HasNames() {
		options = append(options,
			policy.WithPermittedCommonNames(allowed.CommonNames...),
			policy.WithPermittedDNSDomains(allowed.DNSDomains...),
			policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
			policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
			policy.WithPermittedURIDomains(allowed.URIDomains...),
		)
	}

	denied := policyOptions.GetDeniedNameOptions()
	if denied != nil && denied.HasNames() {
		options = append(options,
			policy.WithExcludedCommonNames(denied.CommonNames...),
			policy.WithExcludedDNSDomains(denied.DNSDomains...),
			policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
			policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
			policy.WithExcludedURIDomains(denied.URIDomains...),
		)
	}

	// ensure no policy engine is returned when no name options were provided
	if len(options) == 0 {
		//nolint:nilnil,nolintlint // expected values
		return nil, nil
	}

	// check if configuration specifies that wildcard names are allowed
	if policyOptions.AreWildcardNamesAllowed() {
		options = append(options, policy.WithAllowLiteralWildcardNames())
	}

	// enable subject common name verification by default
	options = append(options, policy.WithSubjectCommonNameVerification())

	return policy.New(options...)
}

// sshPolicyEngineType selects which part of the SSH policy configuration
// (user or host) is used to build an SSH policy engine.
type sshPolicyEngineType string

const (
	UserPolicyEngineType sshPolicyEngineType = "user"
	HostPolicyEngineType sshPolicyEngineType = "host"
)

// NewSSHUserPolicyEngine creates a new SSH user certificate policy engine
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
	policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
	if err != nil {
		return nil, err
	}
	return policyEngine, nil
}

// NewSSHHostPolicyEngine creates a new SSH host certificate policy engine
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
	policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
	if err != nil {
		return nil, err
	}
	return policyEngine, nil
}

// newSSHPolicyEngine creates a new SSH name policy engine from the user or
// host names in policyOptions, depending on typ. It returns a nil engine and
// a nil error when policyOptions is nil or when no allowed or denied names
// are configured for the requested type.
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
	// return early if no policy engine options to configure
	if policyOptions == nil {
		//nolint:nilnil,nolintlint // expected values
		return nil, nil
	}

	var (
		allowed *SSHNameOptions
		denied  *SSHNameOptions
	)

	switch typ {
	case UserPolicyEngineType:
		allowed = policyOptions.GetAllowedUserNameOptions()
		denied = policyOptions.GetDeniedUserNameOptions()
	case HostPolicyEngineType:
		allowed = policyOptions.GetAllowedHostNameOptions()
		denied = policyOptions.GetDeniedHostNameOptions()
	default:
		return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
	}

	options := []policy.NamePolicyOption{}

	if allowed != nil && allowed.HasNames() {
		options = append(options,
			policy.WithPermittedDNSDomains(allowed.DNSDomains...),
			policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
			policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
			policy.WithPermittedPrincipals(allowed.Principals...),
		)
	}

	if denied != nil && denied.HasNames() {
		options = append(options,
			policy.WithExcludedDNSDomains(denied.DNSDomains...),
			policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
			policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
			policy.WithExcludedPrincipals(denied.Principals...),
		)
	}

	// ensure no policy engine is returned when no name options were provided
	if len(options) == 0 {
		//nolint:nilnil,nolintlint // expected values
		return nil, nil
	}

	return policy.New(options...)
}

// LinkedToCertificates converts a linkedca policy into the Options used by
// this package. It returns nil when p is nil or when p has neither an X509
// nor an SSH policy. Only non-nil name slices are copied over.
func LinkedToCertificates(p *linkedca.Policy) *Options {
	// return early
	if p == nil {
		return nil
	}

	// return early if x509 nor SSH is set
	if p.GetX509() == nil && p.GetSsh() == nil {
		return nil
	}

	opts := &Options{}

	// fill x509 policy configuration
	if x509 := p.GetX509(); x509 != nil {
		opts.X509 = &X509PolicyOptions{}
		if allow := x509.GetAllow(); allow != nil {
			opts.X509.AllowedNames = &X509NameOptions{}
			if allow.Dns != nil {
				opts.X509.AllowedNames.DNSDomains = allow.Dns
			}
			if allow.Ips != nil {
				opts.X509.AllowedNames.IPRanges = allow.Ips
			}
			if allow.Emails != nil {
				opts.X509.AllowedNames.EmailAddresses = allow.Emails
			}
			if allow.Uris != nil {
				opts.X509.AllowedNames.URIDomains = allow.Uris
			}
			if allow.CommonNames != nil {
				opts.X509.AllowedNames.CommonNames = allow.CommonNames
			}
		}
		if deny := x509.GetDeny(); deny != nil {
			opts.X509.DeniedNames = &X509NameOptions{}
			if deny.Dns != nil {
				opts.X509.DeniedNames.DNSDomains = deny.Dns
			}
			if deny.Ips != nil {
				opts.X509.DeniedNames.IPRanges = deny.Ips
			}
			if deny.Emails != nil {
				opts.X509.DeniedNames.EmailAddresses = deny.Emails
			}
			if deny.Uris != nil {
				opts.X509.DeniedNames.URIDomains = deny.Uris
			}
			if deny.CommonNames != nil {
				opts.X509.DeniedNames.CommonNames = deny.CommonNames
			}
		}

		opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
	}

	// fill ssh policy configuration
	if ssh := p.GetSsh(); ssh != nil {
		opts.SSH = &SSHPolicyOptions{}
		if host := ssh.GetHost(); host != nil {
			opts.SSH.Host = &SSHHostCertificateOptions{}
			if allow := host.GetAllow(); allow != nil {
				opts.SSH.Host.AllowedNames = &SSHNameOptions{}
				if allow.Dns != nil {
					opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns
				}
				if allow.Ips != nil {
					opts.SSH.Host.AllowedNames.IPRanges = allow.Ips
				}
				if allow.Principals != nil {
					opts.SSH.Host.AllowedNames.Principals = allow.Principals
				}
			}
			if deny := host.GetDeny(); deny != nil {
				opts.SSH.Host.DeniedNames = &SSHNameOptions{}
				if deny.Dns != nil {
					opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns
				}
				if deny.Ips != nil {
					opts.SSH.Host.DeniedNames.IPRanges = deny.Ips
				}
				if deny.Principals != nil {
					opts.SSH.Host.DeniedNames.Principals = deny.Principals
				}
			}
		}
		if user := ssh.GetUser(); user != nil {
			opts.SSH.User = &SSHUserCertificateOptions{}
			if allow := user.GetAllow(); allow != nil {
				opts.SSH.User.AllowedNames = &SSHNameOptions{}
				if allow.Emails != nil {
					opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails
				}
				if allow.Principals != nil {
					opts.SSH.User.AllowedNames.Principals = allow.Principals
				}
			}
			if deny := user.GetDeny(); deny != nil {
				opts.SSH.User.DeniedNames = &SSHNameOptions{}
				if deny.Emails != nil {
					opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails
				}
				if deny.Principals != nil {
					opts.SSH.User.DeniedNames.Principals = deny.Principals
				}
			}
		}
	}

	return opts
}

================================================
FILE: authority/policy/policy_test.go
================================================
package policy

import (
	"testing"

	"github.com/google/go-cmp/cmp"

	"github.com/smallstep/linkedca"
)

// TestPolicyToCertificates verifies the conversion performed by
// LinkedToCertificates for nil, empty, partial and full policies.
func TestPolicyToCertificates(t *testing.T) {
	type args struct {
		policy *linkedca.Policy
	}
	tests := []struct {
		name string
		args args
		want *Options
	}{
		{
			name: "nil",
			args: args{
				policy: nil,
			},
			want: nil,
		},
		{
			name: "no-policy",
			args: args{
				&linkedca.Policy{},
			},
			want: nil,
		},
		{
			name: "partial-policy",
			args: args{
				&linkedca.Policy{
					X509: &linkedca.X509Policy{
						Allow: &linkedca.X509Names{
							Dns: []string{"*.local"},
						},
						AllowWildcardNames: false,
					},
				},
			},
			want: &Options{
				X509: &X509PolicyOptions{
					AllowedNames: &X509NameOptions{
						DNSDomains: []string{"*.local"},
					},
					AllowWildcardNames: false,
				},
			},
		},
		{
			name: "full-policy",
			args: args{
				&linkedca.Policy{
					X509: &linkedca.X509Policy{
						Allow: &linkedca.X509Names{
							Dns:         []string{"step"},
							Ips:         []string{"127.0.0.1/24"},
							Emails:      []string{"*.example.com"},
							Uris:        []string{"https://*.local"},
							CommonNames: []string{"some name"},
						},
						Deny: &linkedca.X509Names{
							Dns:         []string{"bad"},
							Ips:         []string{"127.0.0.30"},
							Emails:      []string{"badhost.example.com"},
							Uris:        []string{"https://badhost.local"},
							CommonNames: []string{"another name"},
						},
						AllowWildcardNames: true,
					},
					Ssh: &linkedca.SSHPolicy{
						Host: &linkedca.SSHHostPolicy{
							Allow: &linkedca.SSHHostNames{
								Dns:        []string{"*.localhost"},
								Ips:        []string{"127.0.0.1/24"},
								Principals: []string{"user"},
							},
							Deny: &linkedca.SSHHostNames{
								Dns:        []string{"badhost.localhost"},
								Ips:        []string{"127.0.0.40"},
								Principals: []string{"root"},
							},
						},
						User: &linkedca.SSHUserPolicy{
							Allow: &linkedca.SSHUserNames{
								Emails:     []string{"@work"},
								Principals: []string{"user"},
							},
							Deny: &linkedca.SSHUserNames{
								Emails:     []string{"root@work"},
								Principals: []string{"root"},
							},
						},
					},
				},
			},
			want: &Options{
				X509: &X509PolicyOptions{
					AllowedNames: &X509NameOptions{
						DNSDomains:     []string{"step"},
						IPRanges:       []string{"127.0.0.1/24"},
						EmailAddresses: []string{"*.example.com"},
						URIDomains:     []string{"https://*.local"},
						CommonNames:    []string{"some name"},
					},
					DeniedNames: &X509NameOptions{
						DNSDomains:     []string{"bad"},
						IPRanges:       []string{"127.0.0.30"},
						EmailAddresses: []string{"badhost.example.com"},
						URIDomains:     []string{"https://badhost.local"},
						CommonNames:    []string{"another name"},
					},
					AllowWildcardNames: true,
				},
				SSH: &SSHPolicyOptions{
					Host: &SSHHostCertificateOptions{
						AllowedNames: &SSHNameOptions{
							DNSDomains: []string{"*.localhost"},
							IPRanges:   []string{"127.0.0.1/24"},
							Principals: []string{"user"},
						},
						DeniedNames: &SSHNameOptions{
							DNSDomains: []string{"badhost.localhost"},
							IPRanges:   []string{"127.0.0.40"},
							Principals: []string{"root"},
						},
					},
					User: &SSHUserCertificateOptions{
						AllowedNames: &SSHNameOptions{
							EmailAddresses: []string{"@work"},
							Principals:     []string{"user"},
						},
						DeniedNames: &SSHNameOptions{
							EmailAddresses: []string{"root@work"},
							Principals:     []string{"root"},
						},
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := LinkedToCertificates(tt.args.policy)
			if !cmp.Equal(tt.want, got) {
				t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got))
			}
		})
	}
}

================================================
FILE: authority/policy.go
================================================
package authority

import (
	"context"
	"errors"
	"fmt"

	"github.com/smallstep/linkedca"

	"github.com/smallstep/certificates/authority/admin"
	authPolicy "github.com/smallstep/certificates/authority/policy"
	policy "github.com/smallstep/certificates/policy"
)

// policyErrorType classifies the failure carried by a PolicyError.
type policyErrorType int

const (
	AdminLockOut policyErrorType = iota + 1
	StoreFailure
	ReloadFailure
	ConfigurationFailure
	EvaluationFailure
	InternalFailure
)

// PolicyError is the error type returned by the authority policy
// operations. Typ classifies the failure and Err holds the underlying error.
type PolicyError struct {
	Typ policyErrorType
	Err error
}

// Error returns the message of the underlying error.
func (p *PolicyError) Error() string {
	return p.Err.Error()
}

// GetAuthorityPolicy returns the authority policy from the admin database
// while holding the admin mutex. A database error is returned as a
// PolicyError of type InternalFailure.
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
	a.adminMutex.Lock()
	defer a.adminMutex.Unlock()

	p, err := a.adminDB.GetAuthorityPolicy(ctx)
	if err != nil {
		return nil, &PolicyError{
			Typ: InternalFailure,
			Err: err,
		}
	}

	return p, nil
}

// CreateAuthorityPolicy checks the policy p, stores it in the admin database
// and reloads the policy engines, all while holding the admin mutex. A failed
// check returns its error as-is; a failed store or reload is returned as a
// PolicyError of type StoreFailure or ReloadFailure respectively.
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
	a.adminMutex.Lock()
	defer a.adminMutex.Unlock()

	if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
		return nil, err
	}

	if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
		return nil, &PolicyError{
			Typ: StoreFailure,
			Err: err,
		}
	}

	if err := a.reloadPolicyEngines(ctx); err != nil {
		return nil, &PolicyError{
			Typ: ReloadFailure,
			Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err),
		}
	}

	return p, nil
}

// UpdateAuthorityPolicy checks the policy p, updates it in the admin database
// and reloads the policy engines, all while holding the admin mutex. A failed
// check returns its error as-is; a failed store or reload is returned as a
// PolicyError of type StoreFailure or ReloadFailure respectively.
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
	a.adminMutex.Lock()
	defer a.adminMutex.Unlock()

	if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
		return nil, err
	}

	if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
		return nil, &PolicyError{
			Typ: StoreFailure,
			Err: err,
		}
	}

	if err := a.reloadPolicyEngines(ctx); err != nil {
		return nil, &PolicyError{
			Typ: ReloadFailure,
			Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err),
		}
	}

	return p, nil
}

// RemoveAuthorityPolicy deletes the authority policy from the admin database
// and reloads the policy engines while holding the admin mutex. A failed
// delete or reload is returned as a PolicyError of type StoreFailure or
// ReloadFailure respectively.
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
	a.adminMutex.Lock()
	defer a.adminMutex.Unlock()

	if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
		return &PolicyError{
			Typ: StoreFailure,
			Err: err,
		}
	}

	if err := a.reloadPolicyEngines(ctx); err != nil {
		return &PolicyError{
			Typ: ReloadFailure,
			Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err),
		}
	}

	return nil
}

// checkAuthorityPolicy checks the policy p against the current admin and all
// admins retrieved from the admin database. A nil policy is accepted without
// any check; failure to retrieve the admins is returned as a PolicyError of
// type InternalFailure.
func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error {
	// no policy and thus nothing to evaluate; return early
	if p == nil {
		return nil
	}

	// get all current admins from the database
	allAdmins, err := a.adminDB.GetAdmins(ctx)
	if err != nil {
		return &PolicyError{
			Typ: InternalFailure,
			Err: fmt.Errorf("error retrieving admins: %w", err),
		}
	}

	return a.checkPolicy(ctx, currentAdmin, allAdmins, p)
}

// checkProvisionerPolicy checks the policy p against the admins loaded for
// the provisioner named provName. A nil policy is accepted without any check.
func (a *Authority) checkProvisionerPolicy(ctx context.Context, provName string, p *linkedca.Policy) error {
	// no policy and thus nothing to evaluate; return early
	if p == nil {
		return nil
	}

	// get all admins for the provisioner; ignoring case in which they're not found
	allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName)

	// check the policy; pass in nil as the current admin, as all admins for the
	// provisioner will be checked by looping through allProvisionerAdmins. Also,
	// the current admin may be a super admin not belonging to the provisioner, so
	// can't be blocked, but is not required to be in the policy, either.
	return a.checkPolicy(ctx, nil, allProvisionerAdmins, p)
}

// checkPolicy checks if a new or updated policy configuration results in the user
// locking themselves or other admins out of the CA.
func (a *Authority) checkPolicy(_ context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error {
	// convert the policy; return early if nil
	policyOptions := authPolicy.LinkedToCertificates(p)
	if policyOptions == nil {
		return nil
	}

	engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
	if err != nil {
		return &PolicyError{
			Typ: ConfigurationFailure,
			Err: err,
		}
	}

	// when an empty X.509 policy is provided, the resulting engine is nil
	// and there's no policy to evaluate.
	if engine == nil {
		return nil
	}

	// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?

	// check if the admin user that instructed the authority policy to be
	// created or updated, would still be allowed when the provided policy
	// would be applied. This case is skipped when current admin is nil, which
	// is the case when a provisioner policy is checked.
	if currentAdmin != nil {
		sans := []string{currentAdmin.GetSubject()}
		if err := isAllowed(engine, sans); err != nil {
			return err
		}
	}

	// loop through admins to verify that none of them would be
	// locked out when the new policy were to be applied. Returns
	// an error with a message that includes the admin subject that
	// would be locked out.
	for _, adm := range otherAdmins {
		sans := []string{adm.GetSubject()}
		if err := isAllowed(engine, sans); err != nil {
			return err
		}
	}

	// TODO(hs): mask the error message for non-super admins?

	return nil
}

// reloadPolicyEngines reloads x509 and SSH policy engines using
// configuration stored in the DB or from the configuration file.
func (a *Authority) reloadPolicyEngines(ctx context.Context) error { var ( err error policyOptions *authPolicy.Options ) if a.config.AuthorityConfig.EnableAdmin { // temporarily disable policy loading when LinkedCA is in use if _, ok := a.adminDB.(*linkedCaClient); ok { return nil } linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx) if err != nil { var ae *admin.Error if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError { return fmt.Errorf("error getting policy to (re)load policy engines: %w", err) } } policyOptions = authPolicy.LinkedToCertificates(linkedPolicy) } else { policyOptions = a.config.AuthorityConfig.Policy } engine, err := authPolicy.New(policyOptions) if err != nil { return err } // only update the policy engine when no error was returned a.policyEngine = engine return nil } func isAllowed(engine authPolicy.X509Policy, sans []string) error { if err := engine.AreSANsAllowed(sans); err != nil { var policyErr *policy.NamePolicyError isNamePolicyError := errors.As(err, &policyErr) if isNamePolicyError && policyErr.Reason == policy.NotAllowed { return &PolicyError{ Typ: AdminLockOut, Err: fmt.Errorf("the provided policy would lock out %s from the CA. 
Please create an x509 policy to include %s as an allowed DNS name", sans, sans), } } return &PolicyError{ Typ: EvaluationFailure, Err: err, } } return nil } ================================================ FILE: authority/policy_test.go ================================================ package authority import ( "context" "errors" "reflect" "testing" "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/assert" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" ) func TestAuthority_checkPolicy(t *testing.T) { type test struct { ctx context.Context currentAdmin *linkedca.Admin otherAdmins []*linkedca.Admin policy *linkedca.Policy err *PolicyError } tests := map[string]func(t *testing.T) test{ "fail/NewX509PolicyEngine-error": func(t *testing.T) test { return test{ ctx: context.Background(), policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"**.local"}, }, }, }, err: &PolicyError{ Typ: ConfigurationFailure, Err: errors.New("cannot parse permitted domain constraint \"**.local\": domain constraint \"**.local\" can only have wildcard as starting character"), }, } }, "fail/currentAdmin-evaluation-error": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "*"}, otherAdmins: []*linkedca.Admin{}, policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, }, err: &PolicyError{ Typ: EvaluationFailure, Err: errors.New("cannot parse dns domain \"*\""), }, } }, "fail/currentAdmin-lockout": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "step"}, otherAdmins: 
[]*linkedca.Admin{ { Subject: "otherAdmin", }, }, policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, }, err: &PolicyError{ Typ: AdminLockOut, Err: errors.New("the provided policy would lock out [step] from the CA. Please create an x509 policy to include [step] as an allowed DNS name"), }, } }, "fail/otherAdmins-evaluation-error": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "step"}, otherAdmins: []*linkedca.Admin{ { Subject: "other", }, { Subject: "**", }, }, policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "other", "*.local"}, }, }, }, err: &PolicyError{ Typ: EvaluationFailure, Err: errors.New("cannot parse dns domain \"**\""), }, } }, "fail/otherAdmins-lockout": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "step"}, otherAdmins: []*linkedca.Admin{ { Subject: "otherAdmin", }, }, policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step"}, }, }, }, err: &PolicyError{ Typ: AdminLockOut, Err: errors.New("the provided policy would lock out [otherAdmin] from the CA. 
Please create an x509 policy to include [otherAdmin] as an allowed DNS name"), }, } }, "ok/no-policy": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "step"}, otherAdmins: []*linkedca.Admin{}, policy: nil, } }, "ok/empty-policy": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "step"}, otherAdmins: []*linkedca.Admin{}, policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{}, }, }, }, } }, "ok/policy": func(t *testing.T) test { return test{ ctx: context.Background(), currentAdmin: &linkedca.Admin{Subject: "step"}, otherAdmins: []*linkedca.Admin{ { Subject: "otherAdmin", }, }, policy: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, } }, } for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { a := &Authority{} err := a.checkPolicy(tc.ctx, tc.currentAdmin, tc.otherAdmins, tc.policy) if tc.err == nil { assert.Nil(t, err) } else { assert.IsType(t, &PolicyError{}, err) var pe *PolicyError if assert.True(t, errors.As(err, &pe)) { assert.Equal(t, tc.err.Typ, pe.Typ) assert.Equal(t, tc.err.Error(), pe.Error()) } } }) } } func mustPolicyEngine(t *testing.T, options *policy.Options) *policy.Engine { engine, err := policy.New(options) if err != nil { t.Fatal(err) } return engine } func TestAuthority_reloadPolicyEngines(t *testing.T) { existingPolicyEngine, err := policy.New(&policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.hosts.example.com"}, }, }, SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.hosts.example.com"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ EmailAddresses: []string{"@mails.example.com"}, }, }, }, }) assert.NoError(t, err) 
newX509Options := &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.X509NameOptions{ DNSDomains: []string{"badhost.local"}, }, AllowWildcardNames: true, }, } newSSHHostOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, }, } newSSHUserOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, } newSSHOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, } newOptions := &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.X509NameOptions{ DNSDomains: []string{"badhost.local"}, }, AllowWildcardNames: true, }, SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, } newAdminX509Options := &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: 
&policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, }, } newAdminSSHHostOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, }, }, } newAdminSSHUserOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ EmailAddresses: []string{"@example.com"}, }, }, }, } newAdminOptions := &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.X509NameOptions{ DNSDomains: []string{"badhost.local"}, }, AllowWildcardNames: true, }, SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ EmailAddresses: []string{"@example.com"}, }, DeniedNames: &policy.SSHNameOptions{ EmailAddresses: []string{"baduser@example.com"}, }, }, }, } tests := []struct { name string config *config.Config adminDB admin.DB ctx context.Context expected *policy.Engine wantErr bool }{ { name: "fail/standalone-x509-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"**.local"}, }, }, }, }, }, ctx: context.Background(), wantErr: true, expected: existingPolicyEngine, }, { name: "fail/standalone-ssh-host-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"**.local"}, }, }, }, }, }, }, ctx: context.Background(), wantErr: true, 
expected: existingPolicyEngine, }, { name: "fail/standalone-ssh-user-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ SSH: &policy.SSHPolicyOptions{ User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ EmailAddresses: []string{"**example.com"}, }, }, }, }, }, }, ctx: context.Background(), wantErr: true, expected: existingPolicyEngine, }, { name: "fail/adminDB.GetAuthorityPolicy-error", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, errors.New("force") }, }, ctx: context.Background(), wantErr: true, expected: existingPolicyEngine, }, { name: "fail/admin-x509-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"**.local"}, }, }, }, nil }, }, ctx: context.Background(), wantErr: true, expected: existingPolicyEngine, }, { name: "fail/admin-ssh-host-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"**.local"}, }, }, }, }, nil }, }, ctx: context.Background(), wantErr: true, expected: existingPolicyEngine, }, { name: "fail/admin-ssh-user-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: 
&linkedca.SSHUserNames{ Emails: []string{"@@example.com"}, }, }, }, }, nil }, }, ctx: context.Background(), wantErr: true, expected: existingPolicyEngine, }, { name: "ok/linkedca-unsupported", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &linkedCaClient{}, ctx: context.Background(), wantErr: false, expected: existingPolicyEngine, }, { name: "ok/standalone-no-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: nil, }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, nil), }, { name: "ok/standalone-x509-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.X509NameOptions{ DNSDomains: []string{"badhost.local"}, }, AllowWildcardNames: true, }, }, }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newX509Options), }, { name: "ok/standalone-ssh-host-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, }, }, }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newSSHHostOptions), }, { name: "ok/standalone-ssh-user-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ SSH: &policy.SSHPolicyOptions{ User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, }, }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, 
newSSHUserOptions), }, { name: "ok/standalone-ssh-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, }, }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newSSHOptions), }, { name: "ok/standalone-full-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: false, Policy: &policy.Options{ X509: &policy.X509PolicyOptions{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.X509NameOptions{ DNSDomains: []string{"badhost.local"}, }, AllowWildcardNames: true, }, SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, }, }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newOptions), }, { name: "ok/admin-x509-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, }, }, nil }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newAdminX509Options), }, { 
name: "ok/admin-ssh-host-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.local"}, }, }, }, }, nil }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newAdminSSHHostOptions), }, { name: "ok/admin-ssh-user-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ Ssh: &linkedca.SSHPolicy{ User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, }, }, }, }, nil }, }, ctx: context.Background(), wantErr: false, expected: mustPolicyEngine(t, newAdminSSHUserOptions), }, { name: "ok/admin-full-policy", config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, ctx: context.Background(), adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, }, AllowWildcardNames: true, }, Ssh: &linkedca.SSHPolicy{ Host: &linkedca.SSHHostPolicy{ Allow: &linkedca.SSHHostNames{ Dns: []string{"*.local"}, }, Deny: &linkedca.SSHHostNames{ Dns: []string{"badhost.local"}, }, }, User: &linkedca.SSHUserPolicy{ Allow: &linkedca.SSHUserNames{ Emails: []string{"@example.com"}, }, Deny: &linkedca.SSHUserNames{ Emails: []string{"baduser@example.com"}, }, }, }, }, nil }, }, wantErr: false, expected: mustPolicyEngine(t, newAdminOptions), }, { // both DB and JSON config; DB config is taken if Admin API is enabled name: "ok/admin-over-standalone", config: 
&config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, Policy: &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, DeniedNames: &policy.SSHNameOptions{ DNSDomains: []string{"badhost.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"*"}, }, DeniedNames: &policy.SSHNameOptions{ Principals: []string{"root"}, }, }, }, }, }, }, ctx: context.Background(), adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"*.local"}, }, Deny: &linkedca.X509Names{ Dns: []string{"badhost.local"}, }, AllowWildcardNames: true, }, }, nil }, }, wantErr: false, expected: mustPolicyEngine(t, newX509Options), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: tt.config, adminDB: tt.adminDB, policyEngine: existingPolicyEngine, } if err := a.reloadPolicyEngines(tt.ctx); (err != nil) != tt.wantErr { t.Errorf("Authority.reloadPolicyEngines() error = %v, wantErr %v", err, tt.wantErr) } assert.Equal(t, tt.expected, a.policyEngine) }) } } func TestAuthority_checkAuthorityPolicy(t *testing.T) { type fields struct { provisioners *provisioner.Collection admins *administrator.Collection db db.AuthDB adminDB admin.DB } type args struct { ctx context.Context currentAdmin *linkedca.Admin provName string p *linkedca.Policy } tests := []struct { name string fields fields args args wantErr bool }{ { name: "no policy", fields: fields{}, args: args{ currentAdmin: nil, provName: "prov", p: nil, }, wantErr: false, }, { name: "fail/adminDB.GetAdmins-error", fields: fields{ admins: administrator.NewCollection(nil), adminDB: &admin.MockDB{ MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return nil, errors.New("force") }, }, 
},
			args: args{
				currentAdmin: &linkedca.Admin{Subject: "step"},
				provName:     "prov",
				p: &linkedca.Policy{
					X509: &linkedca.X509Policy{
						Allow: &linkedca.X509Names{
							Dns: []string{"step", "otherAdmin"},
						},
					},
				},
			},
			wantErr: true,
		},
		{
			name: "fail/policy",
			fields: fields{
				admins: administrator.NewCollection(nil),
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							{
								Id:      "adminID1",
								Subject: "anotherAdmin",
							},
							{
								Id:      "adminID2",
								Subject: "step",
							},
							{
								Id:      "adminID3",
								Subject: "otherAdmin",
							},
						}, nil
					},
				},
			},
			args: args{
				currentAdmin: &linkedca.Admin{Subject: "step"},
				provName:     "prov",
				p: &linkedca.Policy{
					X509: &linkedca.X509Policy{
						Allow: &linkedca.X509Names{
							Dns: []string{"step", "otherAdmin"},
						},
					},
				},
			},
			wantErr: true,
		},
		{
			name: "ok",
			fields: fields{
				admins: administrator.NewCollection(nil),
				adminDB: &admin.MockDB{
					MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) {
						return []*linkedca.Admin{
							{
								Id:      "adminID2",
								Subject: "step",
							},
							{
								Id:      "adminID3",
								Subject: "otherAdmin",
							},
						}, nil
					},
				},
			},
			args: args{
				currentAdmin: &linkedca.Admin{Subject: "step"},
				provName:     "prov",
				p: &linkedca.Policy{
					X509: &linkedca.X509Policy{
						Allow: &linkedca.X509Names{
							Dns: []string{"step", "otherAdmin"},
						},
					},
				},
			},
			wantErr: false,
		},
	}
	// Each case builds a bare Authority from the table fields and only checks
	// whether checkAuthorityPolicy returned an error or not.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &Authority{
				provisioners: tt.fields.provisioners,
				admins:       tt.fields.admins,
				db:           tt.fields.db,
				adminDB:      tt.fields.adminDB,
			}
			if err := a.checkAuthorityPolicy(tt.args.ctx, tt.args.currentAdmin, tt.args.p); (err != nil) != tt.wantErr {
				// Report the method actually under test so a failure is not
				// attributed to checkProvisionerPolicy.
				t.Errorf("Authority.checkAuthorityPolicy() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestAuthority_checkProvisionerPolicy exercises Authority.checkProvisionerPolicy
// with a JWK provisioner and an admin ("step") linked to that provisioner.
func TestAuthority_checkProvisionerPolicy(t *testing.T) {
	jwkProvisioner := &provisioner.JWK{
		ID:   "jwkID",
		Type: "JWK",
		Name: "jwkProv",
		Key:  &jose.JSONWebKey{KeyID: "jwkKeyID"},
	}
	provisioners := provisioner.NewCollection(testAudiences)
	provisioners.Store(jwkProvisioner)
	admins :=
administrator.NewCollection(provisioners) admins.Store(&linkedca.Admin{ Id: "adminID", Subject: "step", ProvisionerId: "jwkID", }, jwkProvisioner) type fields struct { provisioners *provisioner.Collection admins *administrator.Collection db db.AuthDB adminDB admin.DB } type args struct { ctx context.Context provName string p *linkedca.Policy } tests := []struct { name string fields fields args args wantErr bool }{ { name: "no policy", fields: fields{}, args: args{ provName: "prov", p: nil, }, wantErr: false, }, { name: "fail/policy", fields: fields{ provisioners: provisioners, admins: admins, }, args: args{ provName: "jwkProv", p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"otherAdmin"}, // step not in policy }, }, }, }, wantErr: true, }, { name: "ok", fields: fields{ provisioners: provisioners, admins: admins, }, args: args{ provName: "jwkProv", p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ provisioners: tt.fields.provisioners, admins: tt.fields.admins, db: tt.fields.db, adminDB: tt.fields.adminDB, } if err := a.checkProvisionerPolicy(tt.args.ctx, tt.args.provName, tt.args.p); (err != nil) != tt.wantErr { t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestAuthority_RemoveAuthorityPolicy(t *testing.T) { type fields struct { config *config.Config db db.AuthDB adminDB admin.DB } type args struct { ctx context.Context } tests := []struct { name string fields fields args args wantErr *PolicyError }{ { name: "fail/adminDB.DeleteAuthorityPolicy", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockDeleteAuthorityPolicy: func(ctx context.Context) error { return errors.New("force") }, }, }, wantErr: &PolicyError{ Typ: 
StoreFailure, Err: errors.New("force"), }, }, { name: "fail/a.reloadPolicyEngines", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockDeleteAuthorityPolicy: func(ctx context.Context) error { return nil }, MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, errors.New("force") }, }, }, wantErr: &PolicyError{ Typ: ReloadFailure, Err: errors.New("error reloading policy engines when deleting authority policy: error getting policy to (re)load policy engines: force"), }, }, { name: "ok", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockDeleteAuthorityPolicy: func(ctx context.Context) error { return nil }, MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, nil }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: tt.fields.config, db: tt.fields.db, adminDB: tt.fields.adminDB, } err := a.RemoveAuthorityPolicy(tt.args.ctx) if err != nil { var pe *PolicyError if assert.True(t, errors.As(err, &pe)) { assert.Equal(t, tt.wantErr.Typ, pe.Typ) assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) } return } }) } } func TestAuthority_GetAuthorityPolicy(t *testing.T) { type fields struct { config *config.Config db db.AuthDB adminDB admin.DB } type args struct { ctx context.Context } tests := []struct { name string fields fields args args want *linkedca.Policy wantErr *PolicyError }{ { name: "fail/adminDB.GetAuthorityPolicy", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, errors.New("force") }, }, }, wantErr: &PolicyError{ Typ: InternalFailure, Err: errors.New("force"), }, }, { name: "ok", fields: fields{ config: &config.Config{ 
AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{}, nil }, }, }, want: &linkedca.Policy{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: tt.fields.config, db: tt.fields.db, adminDB: tt.fields.adminDB, } got, err := a.GetAuthorityPolicy(tt.args.ctx) if err != nil { var pe *PolicyError if assert.True(t, errors.As(err, &pe)) { assert.Equal(t, tt.wantErr.Typ, pe.Typ) assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) } return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetAuthorityPolicy() = %v, want %v", got, tt.want) } }) } } func TestAuthority_CreateAuthorityPolicy(t *testing.T) { type fields struct { config *config.Config db db.AuthDB adminDB admin.DB } type args struct { ctx context.Context adm *linkedca.Admin p *linkedca.Policy } tests := []struct { name string fields fields args args want *linkedca.Policy wantErr *PolicyError }{ { name: "fail/a.checkAuthorityPolicy", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return nil, errors.New("force") }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: &PolicyError{ Typ: InternalFailure, Err: errors.New("error retrieving admins: force"), }, }, { name: "fail/adminDB.CreateAuthorityPolicy", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return []*linkedca.Admin{}, nil }, MockCreateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { 
return errors.New("force") }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: &PolicyError{ Typ: StoreFailure, Err: errors.New("force"), }, }, { name: "fail/a.reloadPolicyEngines", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, errors.New("force") }, MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return []*linkedca.Admin{}, nil }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: &PolicyError{ Typ: ReloadFailure, Err: errors.New("error reloading policy engines when creating authority policy: error getting policy to (re)load policy engines: force"), }, }, { name: "ok", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, nil }, MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return []*linkedca.Admin{}, nil }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, want: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: tt.fields.config, 
db: tt.fields.db, adminDB: tt.fields.adminDB, } got, err := a.CreateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) if err != nil { var pe *PolicyError if assert.True(t, errors.As(err, &pe)) { assert.Equal(t, tt.wantErr.Typ, pe.Typ) assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) } return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.CreateAuthorityPolicy() = %v, want %v", got, tt.want) } }) } } func TestAuthority_UpdateAuthorityPolicy(t *testing.T) { type fields struct { config *config.Config db db.AuthDB adminDB admin.DB } type args struct { ctx context.Context adm *linkedca.Admin p *linkedca.Policy } tests := []struct { name string fields fields args args want *linkedca.Policy wantErr *PolicyError }{ { name: "fail/a.checkAuthorityPolicy", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return nil, errors.New("force") }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: &PolicyError{ Typ: InternalFailure, Err: errors.New("error retrieving admins: force"), }, }, { name: "fail/adminDB.UpdateAuthorityPolicy", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return []*linkedca.Admin{}, nil }, MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { return errors.New("force") }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: &PolicyError{ Typ: StoreFailure, Err: errors.New("force"), }, }, 
{ name: "fail/a.reloadPolicyEngines", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, errors.New("force") }, MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return []*linkedca.Admin{}, nil }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, wantErr: &PolicyError{ Typ: ReloadFailure, Err: errors.New("error reloading policy engines when updating authority policy: error getting policy to (re)load policy engines: force"), }, }, { name: "ok", fields: fields{ config: &config.Config{ AuthorityConfig: &config.AuthConfig{ EnableAdmin: true, }, }, adminDB: &admin.MockDB{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, nil }, MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { return nil }, MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { return []*linkedca.Admin{}, nil }, }, }, args: args{ ctx: context.Background(), adm: &linkedca.Admin{Subject: "step"}, p: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, want: &linkedca.Policy{ X509: &linkedca.X509Policy{ Allow: &linkedca.X509Names{ Dns: []string{"step", "otherAdmin"}, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ config: tt.fields.config, db: tt.fields.db, adminDB: tt.fields.adminDB, } got, err := a.UpdateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) if err != nil { var pe *PolicyError if assert.True(t, errors.As(err, &pe)) { assert.Equal(t, 
tt.wantErr.Typ, pe.Typ) assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) } return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.UpdateAuthorityPolicy() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/poolhttp/poolhttp.go ================================================ package poolhttp import ( "net/http" "sync" "github.com/smallstep/certificates/internal/httptransport" ) // Transporter is implemented by custom HTTP clients with a method that // returns an [*http.Transport]. type Transporter interface { Transport() *http.Transport } // Client is an HTTP client that uses a [sync.Pool] to create new and reuse HTTP // clients. It implements the [provisioner.HTTPClient] and [Transporter] // interfaces. This is the HTTP client used by the provisioners. type Client struct { rw sync.RWMutex pool sync.Pool } // New creates a new poolhttp [Client], the [sync.Pool] will initialize a new // [*http.Client] with the given function. func New(fn func() *http.Client) *Client { return &Client{ pool: sync.Pool{ New: func() any { return fn() }, }, } } // SetNew replaces the inner pool with a new [sync.Pool] with the given New // function. This method can be use concurrently with other methods of this // package. func (c *Client) SetNew(fn func() *http.Client) { c.rw.Lock() c.pool = sync.Pool{ New: func() any { return fn() }, } c.rw.Unlock() } // getClient gets a client from the pool. func (c *Client) getClient() *http.Client { c.rw.RLock() defer c.rw.RUnlock() if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { return hc } return nil } // Get issues a GET request to the specified URL. 
If the response is one of the // following redirect codes, Get follows the redirect after calling the // [Client.CheckRedirect] function: func (c *Client) Get(u string) (resp *http.Response, err error) { if hc := c.getClient(); hc != nil { resp, err = hc.Get(u) c.pool.Put(hc) } else { resp, err = http.DefaultClient.Get(u) } return } // Do sends an HTTP request and returns an HTTP response, following policy (such // as redirects, cookies, auth) as configured on the client. func (c *Client) Do(req *http.Request) (resp *http.Response, err error) { if hc := c.getClient(); hc != nil { resp, err = hc.Do(req) //nolint:gosec // intentional HTTP request to configured endpoint c.pool.Put(hc) } else { resp, err = http.DefaultClient.Do(req) //nolint:gosec // intentional HTTP request to configured endpoint } return } // Transport() returns a clone of the http.Client Transport or returns the // default transport. func (c *Client) Transport() *http.Transport { if hc := c.getClient(); hc != nil { tr, ok := hc.Transport.(*http.Transport) c.pool.Put(hc) if ok { return tr.Clone() } } return httptransport.New() } ================================================ FILE: authority/poolhttp/poolhttp_test.go ================================================ package poolhttp import ( "fmt" "io" "net/http" "net/http/httptest" "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func requireBody(t *testing.T, want string, r io.ReadCloser) { t.Helper() t.Cleanup(func() { require.NoError(t, r.Close()) }) b, err := io.ReadAll(r) require.NoError(t, err) require.Equal(t, want, string(b)) } func TestClient(t *testing.T) { httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") })) t.Cleanup(httpSrv.Close) tlsSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") })) t.Cleanup(tlsSrv.Close) tests := []struct { name string 
client *Client srv *httptest.Server }{ {"http", New(func() *http.Client { return httpSrv.Client() }), httpSrv}, {"tls", New(func() *http.Client { return tlsSrv.Client() }), tlsSrv}, {"nil", New(func() *http.Client { return nil }), httpSrv}, {"empty", &Client{}, httpSrv}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { resp, err := tc.client.Get(tc.srv.URL) require.NoError(t, err) requireBody(t, "Hello World\n", resp.Body) req, err := http.NewRequest("GET", tc.srv.URL, http.NoBody) require.NoError(t, err) resp, err = tc.client.Do(req) require.NoError(t, err) requireBody(t, "Hello World\n", resp.Body) client := &http.Client{ Transport: tc.client.Transport(), } resp, err = client.Get(tc.srv.URL) require.NoError(t, err) requireBody(t, "Hello World\n", resp.Body) }) } } func TestClient_SetNew(t *testing.T) { srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") })) t.Cleanup(srv.Close) c := New(func() *http.Client { return srv.Client() }) tests := []struct { name string client *http.Client assertion assert.ErrorAssertionFunc }{ {"ok", srv.Client(), assert.NoError}, {"fail", http.DefaultClient, assert.Error}, {"ok again", srv.Client(), assert.NoError}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { c.SetNew(func() *http.Client { return tc.client }) _, err := c.Get(srv.URL) tc.assertion(t, err) }) } } func TestClient_parallel(t *testing.T) { t.Parallel() srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") })) t.Cleanup(srv.Close) c := New(func() *http.Client { return srv.Client() }) req, err := http.NewRequest("GET", srv.URL, http.NoBody) require.NoError(t, err) for i := range 10 { t.Run(strconv.Itoa(i), func(t *testing.T) { t.Parallel() resp, err := c.Get(srv.URL) require.NoError(t, err) requireBody(t, "Hello World\n", resp.Body) resp, err = c.Do(req) require.NoError(t, err) requireBody(t, "Hello 
World\n", resp.Body) }) } } ================================================ FILE: authority/provisioner/acme.go ================================================ package provisioner import ( "context" "crypto/x509" "encoding/pem" "fmt" "net" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/acme/wire" "github.com/smallstep/linkedca" ) // ACMEChallenge represents the supported acme challenges. type ACMEChallenge string //nolint:staticcheck,revive // better names const ( // HTTP_01 is the http-01 ACME challenge. HTTP_01 ACMEChallenge = "http-01" // DNS_01 is the dns-01 ACME challenge. DNS_01 ACMEChallenge = "dns-01" // TLS_ALPN_01 is the tls-alpn-01 ACME challenge. TLS_ALPN_01 ACMEChallenge = "tls-alpn-01" // DEVICE_ATTEST_01 is the device-attest-01 ACME challenge. DEVICE_ATTEST_01 ACMEChallenge = "device-attest-01" // WIREOIDC_01 is the Wire OIDC challenge. WIREOIDC_01 ACMEChallenge = "wire-oidc-01" // WIREDPOP_01 is the Wire DPoP challenge. WIREDPOP_01 ACMEChallenge = "wire-dpop-01" ) // String returns a normalized version of the challenge. func (c ACMEChallenge) String() string { return strings.ToLower(string(c)) } // Validate returns an error if the acme challenge is not a valid one. func (c ACMEChallenge) Validate() error { switch ACMEChallenge(c.String()) { case HTTP_01, DNS_01, TLS_ALPN_01, DEVICE_ATTEST_01, WIREOIDC_01, WIREDPOP_01: return nil default: return fmt.Errorf("acme challenge %q is not supported", c) } } // ACMEAttestationFormat represents the format used on a device-attest-01 // challenge. type ACMEAttestationFormat string const ( // APPLE is the format used to enable device-attest-01 on Apple devices. APPLE ACMEAttestationFormat = "apple" // STEP is the format used to enable device-attest-01 on devices that // provide attestation certificates like the PIV interface on YubiKeys. // // TODO(mariano): should we rename this to something else. 
STEP ACMEAttestationFormat = "step" // TPM is the format used to enable device-attest-01 with TPMs. TPM ACMEAttestationFormat = "tpm" ) // String returns a normalized version of the attestation format. func (f ACMEAttestationFormat) String() string { return strings.ToLower(string(f)) } // Validate returns an error if the attestation format is not a valid one. func (f ACMEAttestationFormat) Validate() error { switch ACMEAttestationFormat(f.String()) { case APPLE, STEP, TPM: return nil default: return fmt.Errorf("acme attestation format %q is not supported", f) } } // ACME is the acme provisioner type, an entity that can authorize the ACME // provisioning flow. type ACME struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` // TermsOfService contains a URL pointing to the ACME server's // terms of service. Defaults to empty. TermsOfService string `json:"termsOfService,omitempty"` // Website contains an URL pointing to more information about // the ACME server. Defaults to empty. Website string `json:"website,omitempty"` // CaaIdentities is an array of hostnames that the ACME server // identifies itself with. These hostnames can be used by ACME // clients to determine the correct issuer domain name to use // when configuring CAA records. Defaults to empty array. CaaIdentities []string `json:"caaIdentities,omitempty"` // RequireEAB makes the provisioner require ACME EAB to be provided // by clients when creating a new Account. If set to true, the provided // EAB will be verified. If set to false and an EAB is provided, it is // not verified. Defaults to false. RequireEAB bool `json:"requireEAB,omitempty"` // Challenges contains the enabled challenges for this provisioner. If this // value is not set the default http-01, dns-01 and tls-alpn-01 challenges // will be enabled, device-attest-01, wire-oidc-01 and wire-dpop-01 will be // disabled. 
Challenges []ACMEChallenge `json:"challenges,omitempty"` // AttestationFormats contains the enabled attestation formats for this // provisioner. If this value is not set the default apple, step and tpm // will be used. AttestationFormats []ACMEAttestationFormat `json:"attestationFormats,omitempty"` // AttestationRoots contains a bundle of root certificates in PEM format // that will be used to verify the attestation certificates. If provided, // this bundle will be used even for well-known CAs like Apple and Yubico. AttestationRoots []byte `json:"attestationRoots,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` attestationRootPool *x509.CertPool ctl *Controller } // GetID returns the provisioner unique identifier. func (p ACME) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *ACME) GetIDForToken() string { return "acme/" + p.Name } // GetTokenID returns the identifier of the token. This provisioner will always // return [ErrTokenFlowNotSupported]. func (p *ACME) GetTokenID(string) (string, error) { return "", ErrTokenFlowNotSupported } // GetName returns the name of the provisioner. func (p *ACME) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *ACME) GetType() Type { return TypeACME } // GetEncryptedKey returns the base provisioner encrypted key if it's defined. func (p *ACME) GetEncryptedKey() (string, string, bool) { return "", "", false } // GetOptions returns the configured provisioner options. func (p *ACME) GetOptions() *Options { return p.Options } // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (p *ACME) DefaultTLSCertDuration() time.Duration { return p.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of an ACME type. 
func (p *ACME) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") } for _, c := range p.Challenges { if err := c.Validate(); err != nil { return err } } for _, f := range p.AttestationFormats { if err := f.Validate(); err != nil { return err } } // Parse attestation roots. // The pool will be nil if there are no roots. if rest := p.AttestationRoots; len(rest) > 0 { var block *pem.Block var hasCert bool p.attestationRootPool = x509.NewCertPool() for rest != nil { block, rest = pem.Decode(rest) if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return errors.New("error parsing attestationRoots: malformed certificate") } p.attestationRootPool.AddCert(cert) hasCert = true } if !hasCert { return errors.New("error parsing attestationRoots: no certificates found") } } if err := p.initializeWireOptions(); err != nil { return fmt.Errorf("failed initializing Wire options: %w", err) } p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // initializeWireOptions initializes the options for the ACME Wire // integration. It'll return early if no Wire challenge types are // enabled. func (p *ACME) initializeWireOptions() error { hasWireChallenges := false for _, c := range p.Challenges { if c == WIREOIDC_01 || c == WIREDPOP_01 { hasWireChallenges = true break } } if !hasWireChallenges { return nil } w, err := p.GetOptions().GetWireOptions() if err != nil { return fmt.Errorf("failed getting Wire options: %w", err) } if err := w.Validate(); err != nil { return fmt.Errorf("failed validating Wire options: %w", err) } // at this point the Wire options have been validated, and (mostly) // initialized. Remote keys will be loaded upon the first verification, // currently. // TODO(hs): can/should we "prime" the underlying remote keyset, to verify // auto discovery works as expected? 
Because of the current way provisioners // are initialized, doing that as part of the initialization isn't the best // time to do it, because it could result in operations not resulting in the // expected result in all cases. return nil } // ACMEIdentifierType encodes ACME Identifier types type ACMEIdentifierType string const ( // IP is the ACME ip identifier type IP ACMEIdentifierType = "ip" // DNS is the ACME dns identifier type DNS ACMEIdentifierType = "dns" // WireUser is the Wire user identifier type WireUser ACMEIdentifierType = "wireapp-user" // WireDevice is the Wire device identifier type WireDevice ACMEIdentifierType = "wireapp-device" ) // ACMEIdentifier encodes ACME Order Identifiers type ACMEIdentifier struct { Type ACMEIdentifierType Value string } // AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a // certificate for an ACME Order Identifier. func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdentifier) error { x509Policy := p.ctl.getPolicy().getX509() // identifier is allowed if no policy is configured if x509Policy == nil { return nil } // assuming only valid identifiers (IP or DNS) are provided var err error switch identifier.Type { case IP: err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value)) case DNS: err = x509Policy.IsDNSAllowed(identifier.Value) case WireUser: var wireID wire.UserID if wireID, err = wire.ParseUserID(identifier.Value); err != nil { return fmt.Errorf("failed parsing Wire SANs: %w", err) } err = x509Policy.AreSANsAllowed([]string{wireID.Handle}) case WireDevice: var wireID wire.DeviceID if wireID, err = wire.ParseDeviceID(identifier.Value); err != nil { return fmt.Errorf("failed parsing Wire SANs: %w", err) } err = x509Policy.AreSANsAllowed([]string{wireID.ClientID}) default: err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type) } return err } // AuthorizeSign does not do any validation, because all validation is handled // in the ACME protocol. 
This method returns a list of modifiers / constraints // on the resulting certificate. func (p *ACME) AuthorizeSign(context.Context, string) ([]SignOption, error) { opts := []SignOption{ p, // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, "").WithControllerOptions(p.ctl), newForceCNOption(p.ForceCN), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController(nil, linkedca.Webhook_X509), } return opts, nil } // AuthorizeRevoke is called just before the certificate is to be revoked by // the CA. It can be used to authorize revocation of a certificate. With the // ACME protocol, revocation authorization is specified and performed as part // of the client/server interaction, so this is a no-op. func (p *ACME) AuthorizeRevoke(context.Context, string) error { return nil } // AuthorizeRenew returns an error if the renewal is disabled. // NOTE: This method does not actually validate the certificate or check its // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return p.ctl.AuthorizeRenew(ctx, cert) } // IsChallengeEnabled checks if the given challenge is enabled. By default // http-01, dns-01 and tls-alpn-01 are enabled, to disable any of them the // Challenge provisioner property should have at least one element. 
func (p *ACME) IsChallengeEnabled(_ context.Context, challenge ACMEChallenge) bool { enabledChallenges := []ACMEChallenge{ HTTP_01, DNS_01, TLS_ALPN_01, } if len(p.Challenges) > 0 { enabledChallenges = p.Challenges } for _, ch := range enabledChallenges { if strings.EqualFold(string(ch), string(challenge)) { return true } } return false } // IsAttestationFormatEnabled checks if the given attestation format is enabled. // By default apple, step and tpm are enabled, to disable any of them the // AttestationFormat provisioner property should have at least one element. func (p *ACME) IsAttestationFormatEnabled(_ context.Context, format ACMEAttestationFormat) bool { enabledFormats := []ACMEAttestationFormat{ APPLE, STEP, TPM, } if len(p.AttestationFormats) > 0 { enabledFormats = p.AttestationFormats } for _, f := range enabledFormats { if strings.EqualFold(string(f), string(format)) { return true } } return false } // GetAttestationRoots returns certificate pool with the configured attestation // roots and reports if the pool contains at least one certificate. // // TODO(hs): we may not want to expose the root pool like this; call into an // interface function instead to authorize? 
func (p *ACME) GetAttestationRoots() (*x509.CertPool, bool) { return p.attestationRootPool, p.attestationRootPool != nil } ================================================ FILE: authority/provisioner/acme_118_test.go ================================================ //go:build go1.18 package provisioner import ( "bytes" "crypto/x509" "os" "testing" ) func TestACME_GetAttestationRoots(t *testing.T) { appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt") if err != nil { t.Fatal(err) } yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt") if err != nil { t.Fatal(err) } pool := x509.NewCertPool() pool.AppendCertsFromPEM(appleCA) pool.AppendCertsFromPEM(yubicoCA) type fields struct { Type string Name string AttestationRoots []byte } tests := []struct { name string fields fields want *x509.CertPool want1 bool }{ {"ok", fields{"ACME", "acme", bytes.Join([][]byte{appleCA, yubicoCA}, []byte("\n"))}, pool, true}, {"nil", fields{"ACME", "acme", nil}, nil, false}, {"empty", fields{"ACME", "acme", []byte{}}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &ACME{ Type: tt.fields.Type, Name: tt.fields.Name, AttestationRoots: tt.fields.AttestationRoots, } if err := p.Init(Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, }); err != nil { t.Fatal(err) } got, got1 := p.GetAttestationRoots() switch { case tt.want == nil && got == nil: break case tt.want == nil && got != nil, tt.want != nil && got == nil: t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want) default: //nolint:staticcheck // this file only runs in go1.18 gotSubjects := got.Subjects() //nolint:staticcheck // this file only runs in go1.18 wantSubjects := tt.want.Subjects() if len(gotSubjects) != len(wantSubjects) { t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want) } else { for i, gotSub := range gotSubjects { if !bytes.Equal(gotSub, wantSubjects[i]) { t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, 
tt.want) break } } } } if got1 != tt.want1 { t.Errorf("ACME.GetAttestationRoots() got1 = %v, want %v", got1, tt.want1) } }) } } ================================================ FILE: authority/provisioner/acme_119_test.go ================================================ //go:build !go1.18 package provisioner import ( "bytes" "crypto/x509" "os" "testing" ) func TestACME_GetAttestationRoots(t *testing.T) { appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt") if err != nil { t.Fatal(err) } yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt") if err != nil { t.Fatal(err) } pool := x509.NewCertPool() pool.AppendCertsFromPEM(appleCA) pool.AppendCertsFromPEM(yubicoCA) type fields struct { Type string Name string AttestationRoots []byte } tests := []struct { name string fields fields want *x509.CertPool want1 bool }{ {"ok", fields{"ACME", "acme", bytes.Join([][]byte{appleCA, yubicoCA}, []byte("\n"))}, pool, true}, {"nil", fields{"ACME", "acme", nil}, nil, false}, {"empty", fields{"ACME", "acme", []byte{}}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &ACME{ Type: tt.fields.Type, Name: tt.fields.Name, AttestationRoots: tt.fields.AttestationRoots, } if err := p.Init(Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, }); err != nil { t.Fatal(err) } got, got1 := p.GetAttestationRoots() if tt.want == nil && got != nil { t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want) } else if !tt.want.Equal(got) { t.Errorf("ACME.GetAttestationRoots() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("ACME.GetAttestationRoots() got1 = %v, want %v", got1, tt.want1) } }) } } ================================================ FILE: authority/provisioner/acme_test.go ================================================ package provisioner import ( "bytes" "context" "crypto/x509" "errors" "fmt" "net/http" "os" "testing" "time" "github.com/smallstep/certificates/api/render" 
"github.com/smallstep/certificates/authority/provisioner/wire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestACMEChallenge_Validate(t *testing.T) { tests := []struct { name string c ACMEChallenge wantErr bool }{ {"http-01", HTTP_01, false}, {"dns-01", DNS_01, false}, {"tls-alpn-01", TLS_ALPN_01, false}, {"device-attest-01", DEVICE_ATTEST_01, false}, {"wire-oidc-01", DEVICE_ATTEST_01, false}, {"wire-dpop-01", DEVICE_ATTEST_01, false}, {"uppercase", "HTTP-01", false}, {"fail", "http-02", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.c.Validate() if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) }) } } func TestACMEAttestationFormat_Validate(t *testing.T) { tests := []struct { name string f ACMEAttestationFormat wantErr bool }{ {"apple", APPLE, false}, {"step", STEP, false}, {"tpm", TPM, false}, {"uppercase", "APPLE", false}, {"fail", "FOO", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.f.Validate() if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) }) } } func TestACME_Getters(t *testing.T) { p, err := generateACME() require.NoError(t, err) id := "acme/test@acme-provisioner.com" assert.Equal(t, id, p.GetID()) assert.Equal(t, "test@acme-provisioner.com", p.GetName()) assert.Equal(t, TypeACME, p.GetType()) kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } tokenID, err := p.GetTokenID("token") assert.Empty(t, tokenID) assert.Equal(t, ErrTokenFlowNotSupported, err) } func TestACME_Init(t *testing.T) { appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt") require.NoError(t, err) yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt") require.NoError(t, err) fakeWireDPoPKey := []byte(`-----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END 
PUBLIC KEY-----`) type ProvisionerValidateTest struct { p *ACME err error } tests := map[string]func(*testing.T) ProvisionerValidateTest{ "fail/empty": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{}, err: errors.New("provisioner type cannot be empty"), } }, "fail/empty-name": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Type: "ACME", }, err: errors.New("provisioner name cannot be empty"), } }, "fail/empty-type": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo"}, err: errors.New("provisioner type cannot be empty"), } }, "fail/bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "ACME", Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "fail/bad-challenge": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{HTTP_01, "zar"}}, err: errors.New("acme challenge \"zar\" is not supported"), } }, "fail/bad-attestation-format": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "ACME", AttestationFormats: []ACMEAttestationFormat{APPLE, "zar"}}, err: errors.New("acme attestation format \"zar\" is not supported"), } }, "fail/parse-attestation-roots": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "ACME", AttestationRoots: []byte("-----BEGIN CERTIFICATE-----\nZm9v\n-----END CERTIFICATE-----")}, err: errors.New("error parsing attestationRoots: malformed certificate"), } }, "fail/empty-attestation-roots": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "ACME", AttestationRoots: []byte("\n")}, err: errors.New("error parsing attestationRoots: no certificates 
found"), } }, "fail/wire-missing-options": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, }, err: errors.New("failed initializing Wire options: failed getting Wire options: no options available"), } }, "fail/wire-missing-wire-options": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, Options: &Options{}, }, err: errors.New("failed initializing Wire options: failed getting Wire options: no Wire options available"), } }, "fail/wire-validate-options": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, Options: &Options{ Wire: &wire.Options{ OIDC: &wire.OIDCOptions{}, DPOP: &wire.DPOPOptions{ SigningKey: fakeWireDPoPKey, }, }, }, }, err: errors.New("failed initializing Wire options: failed validating Wire options: failed initializing OIDC options: provider not set"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo", Type: "ACME"}, } }, "ok/attestation": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{DNS_01, DEVICE_ATTEST_01}, AttestationFormats: []ACMEAttestationFormat{APPLE, STEP}, AttestationRoots: bytes.Join([][]byte{appleCA, yubicoCA}, []byte("\n")), }, } }, "ok/wire": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, Options: &Options{ Wire: &wire.Options{ OIDC: &wire.OIDCOptions{ Provider: &wire.Provider{ IssuerURL: "https://issuer.example.com", }, }, DPOP: &wire.DPOPOptions{ SigningKey: fakeWireDPoPKey, }, }, }, }, } }, } config := Config{ 
Claims: globalProvisionerClaims, Audiences: testAudiences, } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) t.Log(string(tc.p.AttestationRoots)) err := tc.p.Init(config) if tc.err != nil { assert.EqualError(t, err, tc.err.Error()) return } assert.NoError(t, err) }) } } func TestACME_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) type test struct { p *ACME cert *x509.Certificate err error code int } tests := map[string]func(*testing.T) test{ "fail/renew-disabled": func(t *testing.T) test { p, err := generateACME() require.NoError(t, err) // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) require.NoError(t, err) return test{ p: p, cert: &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }, code: http.StatusUnauthorized, err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateACME() require.NoError(t, err) return test{ p: p, cert: &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) err := tc.p.AuthorizeRenew(context.Background(), tc.cert) if tc.err != nil { if assert.Implements(t, (*render.StatusCodedError)(nil), err) { var sc render.StatusCodedError if errors.As(err, &sc) { assert.Equal(t, tc.code, sc.StatusCode()) } } assert.EqualError(t, err, tc.err.Error()) return } assert.NoError(t, err) }) } } func TestACME_AuthorizeSign(t *testing.T) { type test struct { p *ACME token string code int err error } tests := map[string]func(*testing.T) test{ "ok": func(t *testing.T) test { p, err := generateACME() require.NoError(t, err) return test{ p: p, token: "foo", } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) opts, err := tc.p.AuthorizeSign(context.Background(), tc.token) if tc.err != nil { if assert.Implements(t, 
(*render.StatusCodedError)(nil), err) { var sc render.StatusCodedError if errors.As(err, &sc) { assert.Equal(t, tc.code, sc.StatusCode()) } } assert.EqualError(t, err, tc.err.Error()) return } assert.NoError(t, err) if assert.NotNil(t, opts) { assert.Len(t, opts, 8) // number of SignOptions returned for _, o := range opts { switch v := o.(type) { case *ACME: case *provisionerExtensionOption: assert.Equal(t, v.Type, TypeACME) assert.Equal(t, v.Name, tc.p.GetName()) assert.Equal(t, v.CredentialID, "") assert.Len(t, v.KeyValuePairs, 0) case *forceCNOption: assert.Equal(t, v.ForceCN, tc.p.ForceCN) case profileDefaultDuration: assert.Equal(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: assert.Equal(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equal(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) case *x509NamePolicyValidator: assert.Equal(t, nil, v.policyEngine) case *WebhookController: assert.Len(t, v.webhooks, 0) default: require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } }) } } func TestACME_IsChallengeEnabled(t *testing.T) { ctx := context.Background() type fields struct { Challenges []ACMEChallenge } type args struct { ctx context.Context challenge ACMEChallenge } tests := []struct { name string fields fields args args want bool }{ {"ok http-01", fields{nil}, args{ctx, HTTP_01}, true}, {"ok dns-01", fields{nil}, args{ctx, DNS_01}, true}, {"ok tls-alpn-01", fields{[]ACMEChallenge{}}, args{ctx, TLS_ALPN_01}, true}, {"fail device-attest-01", fields{[]ACMEChallenge{}}, args{ctx, "device-attest-01"}, false}, {"ok http-01 enabled", fields{[]ACMEChallenge{"http-01"}}, args{ctx, "HTTP-01"}, true}, {"ok dns-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, DNS_01}, true}, {"ok tls-alpn-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01"}}, args{ctx, TLS_ALPN_01}, true}, {"ok device-attest-01 enabled", 
fields{[]ACMEChallenge{"device-attest-01", "dns-01"}}, args{ctx, DEVICE_ATTEST_01}, true}, {"ok wire-oidc-01 enabled", fields{[]ACMEChallenge{"wire-oidc-01"}}, args{ctx, WIREOIDC_01}, true}, {"ok wire-dpop-01 enabled", fields{[]ACMEChallenge{"wire-dpop-01"}}, args{ctx, WIREDPOP_01}, true}, {"fail http-01", fields{[]ACMEChallenge{"dns-01"}}, args{ctx, "http-01"}, false}, {"fail dns-01", fields{[]ACMEChallenge{"http-01", "tls-alpn-01"}}, args{ctx, "dns-01"}, false}, {"fail tls-alpn-01", fields{[]ACMEChallenge{"http-01", "dns-01", "device-attest-01"}}, args{ctx, "tls-alpn-01"}, false}, {"fail device-attest-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "device-attest-01"}, false}, {"fail wire-oidc-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-oidc-01"}, false}, {"fail wire-dpop-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-dpop-01"}, false}, {"fail unknown", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01", "device-attest-01"}}, args{ctx, "unknown"}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &ACME{ Challenges: tt.fields.Challenges, } got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge) assert.Equal(t, tt.want, got) }) } } func TestACME_IsAttestationFormatEnabled(t *testing.T) { ctx := context.Background() type fields struct { AttestationFormats []ACMEAttestationFormat } type args struct { ctx context.Context format ACMEAttestationFormat } tests := []struct { name string fields fields args args want bool }{ {"ok", fields{[]ACMEAttestationFormat{APPLE, STEP, TPM}}, args{ctx, TPM}, true}, {"ok empty apple", fields{nil}, args{ctx, APPLE}, true}, {"ok empty step", fields{nil}, args{ctx, STEP}, true}, {"ok empty tpm", fields{[]ACMEAttestationFormat{}}, args{ctx, "tpm"}, true}, {"ok uppercase", fields{[]ACMEAttestationFormat{APPLE, STEP, TPM}}, args{ctx, "STEP"}, true}, {"fail apple", fields{[]ACMEAttestationFormat{STEP, TPM}}, args{ctx, APPLE}, false}, {"fail step", 
fields{[]ACMEAttestationFormat{APPLE, TPM}}, args{ctx, STEP}, false}, {"fail step", fields{[]ACMEAttestationFormat{APPLE, STEP}}, args{ctx, TPM}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &ACME{ AttestationFormats: tt.fields.AttestationFormats, } got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format) assert.Equal(t, tt.want, got) }) } } ================================================ FILE: authority/provisioner/aws.go ================================================ package provisioner import ( "context" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/hex" "encoding/json" "encoding/pem" "fmt" "io" "net" "net/http" "os" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/webhook" _ "embed" ) // awsIssuer is the string used as issuer in the generated tokens. const awsIssuer = "ec2.amazonaws.com" // awsIdentityURL is the url used to retrieve the instance identity document. const awsIdentityURL = "http://169.254.169.254/latest/dynamic/instance-identity/document" // awsSignatureURL is the url used to retrieve the instance identity signature. 
const awsSignatureURL = "http://169.254.169.254/latest/dynamic/instance-identity/signature" // awsAPITokenURL is the url used to get the IMDSv2 API token const awsAPITokenURL = "http://169.254.169.254/latest/api/token" //nolint:gosec // no credentials here // awsAPITokenTTL is the default TTL to use when requesting IMDSv2 API tokens // -- we keep this short-lived since we get a new token with every call to readURL() const awsAPITokenTTL = "30" // awsMetadataTokenHeader is the header that must be passed with every IMDSv2 request const awsMetadataTokenHeader = "X-aws-ec2-metadata-token" //nolint:gosec // no credentials here // awsMetadataTokenTTLHeader is the header used to indicate the token TTL requested const awsMetadataTokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds" //nolint:gosec // no credentials here // awsCertificate is the certificate used to validate the instance identity // signature. It is embedded in the binary at compile time. // //go:embed aws_certificates.pem var awsCertificate string // awsSignatureAlgorithm is the signature algorithm used to verify the identity // document signature. const awsSignatureAlgorithm = x509.SHA256WithRSA type awsConfig struct { identityURL string signatureURL string tokenURL string tokenTTL string certificates []*x509.Certificate signatureAlgorithm x509.SignatureAlgorithm } func newAWSConfig(certPath string) (*awsConfig, error) { var certBytes []byte if certPath == "" { certBytes = []byte(awsCertificate) } else { if b, err := os.ReadFile(certPath); err == nil { certBytes = b } else { return nil, errors.Wrapf(err, "error reading %s", certPath) } } // Read all the certificates. 
var certs []*x509.Certificate for len(certBytes) > 0 { var block *pem.Block block, certBytes = pem.Decode(certBytes) if block == nil { break } if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { continue } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, errors.Wrap(err, "error parsing AWS IID certificate") } certs = append(certs, cert) } if len(certs) == 0 { return nil, errors.New("error parsing AWS IID certificate: no certificates found") } return &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, tokenURL: awsAPITokenURL, tokenTTL: awsAPITokenTTL, certificates: certs, signatureAlgorithm: awsSignatureAlgorithm, }, nil } type awsPayload struct { jose.Claims Amazon awsAmazonPayload `json:"amazon"` SANs []string `json:"sans"` document awsInstanceIdentityDocument } type awsAmazonPayload struct { Document []byte `json:"document"` Signature []byte `json:"signature"` } type awsInstanceIdentityDocument struct { AccountID string `json:"accountId"` Architecture string `json:"architecture"` AvailabilityZone string `json:"availabilityZone"` BillingProducts []string `json:"billingProducts"` DevpayProductCodes []string `json:"devpayProductCodes"` ImageID string `json:"imageId"` InstanceID string `json:"instanceId"` InstanceType string `json:"instanceType"` KernelID string `json:"kernelId"` PendingTime time.Time `json:"pendingTime"` PrivateIP string `json:"privateIp"` RamdiskID string `json:"ramdiskId"` Region string `json:"region"` Version string `json:"version"` } // AWS is the provisioner that supports identity tokens created from the Amazon // Web Services Instance Identity Documents. // // If DisableCustomSANs is true, only the internal DNS and IP will be added as a // SAN. By default it will accept any SAN in the CSR. // // If DisableTrustOnFirstUse is true, multiple sign request for this provisioner // with the same instance will be accepted. By default only the first request // will be accepted. 
// // If InstanceAge is set, only the instances with a pendingTime within the given // period will be accepted. // // IIDRoots can be used to specify a path to the certificates used to verify the // identity certificate signature. // // Amazon Identity docs are available at // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html type AWS struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Accounts []string `json:"accounts"` DisableCustomSANs bool `json:"disableCustomSANs"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` IMDSVersions []string `json:"imdsVersions"` InstanceAge Duration `json:"instanceAge,omitempty"` IIDRoots string `json:"iidRoots,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` config *awsConfig ctl *Controller } // GetID returns the provisioner unique identifier. func (p *AWS) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *AWS) GetIDForToken() string { return "aws/" + p.Name } // GetTokenID returns the identifier of the token. func (p *AWS) GetTokenID(token string) (string, error) { payload, err := p.authorizeToken(token) if err != nil { return "", err } // If TOFU is disabled create an ID for the token, so it cannot be reused. // The timestamps, document and signatures should be mostly unique. if p.DisableTrustOnFirstUse { sum := sha256.Sum256([]byte(token)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } // Use provisioner + instance-id as the identifier. unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), payload.document.InstanceID) sum := sha256.Sum256([]byte(unique)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } // GetName returns the name of the provisioner. func (p *AWS) GetName() string { return p.Name } // GetType returns the type of provisioner. 
func (p *AWS) GetType() Type { return TypeAWS } // GetEncryptedKey is not available in an AWS provisioner. func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } // GetIdentityToken retrieves the identity document and it's signature and // generates a token with them. func (p *AWS) GetIdentityToken(subject, caURL string) (string, error) { // Initialize the config if this method is used from the cli. if err := p.assertConfig(); err != nil { return "", err } var idoc awsInstanceIdentityDocument doc, err := p.readURL(p.config.identityURL) if err != nil { return "", errors.Wrap(err, "error retrieving identity document:\n Are you in an AWS VM?\n Is the metadata service enabled?\n Are you using the proper metadata service version?") } if err := json.Unmarshal(doc, &idoc); err != nil { return "", errors.Wrap(err, "error unmarshaling identity document") } sig, err := p.readURL(p.config.signatureURL) if err != nil { return "", errors.Wrap(err, "error retrieving identity document:\n Are you in an AWS VM?\n Is the metadata service enabled?\n Are you using the proper metadata service version?") } signature, err := base64.StdEncoding.DecodeString(string(sig)) if err != nil { return "", errors.Wrap(err, "error decoding identity document signature") } if err := p.checkSignature(doc, signature); err != nil { return "", err } audience, err := generateSignAudience(caURL, p.GetIDForToken()) if err != nil { return "", err } // Create unique ID for Trust On First Use (TOFU). Only the first instance // per provisioner is allowed as we don't have a way to trust the given // sans. 
unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), idoc.InstanceID) sum := sha256.Sum256([]byte(unique)) // Create a JWT from the identity document signer, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.HS256, Key: signature}, new(jose.SignerOptions).WithType("JWT"), ) if err != nil { return "", errors.Wrap(err, "error creating signer") } now := time.Now() payload := awsPayload{ Claims: jose.Claims{ Issuer: awsIssuer, Subject: subject, Audience: []string{audience}, Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), NotBefore: jose.NewNumericDate(now), IssuedAt: jose.NewNumericDate(now), ID: strings.ToLower(hex.EncodeToString(sum[:])), }, Amazon: awsAmazonPayload{ Document: doc, Signature: signature, }, } tok, err := jose.Signed(signer).Claims(payload).CompactSerialize() if err != nil { return "", errors.Wrap(err, "error serializing token") } return tok, nil } // Init validates and initializes the AWS provisioner. func (p *AWS) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } // Add default config if p.config, err = newAWSConfig(p.IIDRoots); err != nil { return err } // validate IMDS versions if len(p.IMDSVersions) == 0 { p.IMDSVersions = []string{"v2", "v1"} } for _, v := range p.IMDSVersions { switch v { case "v1": // valid case "v2": // valid default: return errors.Errorf("%s: not a supported AWS Instance Metadata Service version", v) } } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. 
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { payload, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") } doc := payload.document // Template options data := x509util.NewTemplateData() data.SetCommonName(payload.Claims.Subject) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // Enforce known CN and default DNS and IP if configured. // By default we'll accept the CN and SANs in the CSR. // There's no way to trust them other than TOFU. var so []SignOption if p.DisableCustomSANs { dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) so = append(so, dnsNamesSubsetValidator([]string{dnsName}), ipAddressesValidator([]net.IP{ net.ParseIP(doc.PrivateIP), }), emailAddressesValidator(nil), newURIsValidator(ctx, nil), ) // Template options data.SetSANs([]string{dnsName, doc.PrivateIP}) } templateOptions, err := CustomTemplateOptions(p.Options, data, x509util.DefaultIIDLeafTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") } return append(so, p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID).WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController( data, linkedca.Webhook_X509, webhook.WithAuthorizationPrincipal(doc.InstanceID), ), ), nil } // AuthorizeRenew returns an error if the renewal is disabled. // NOTE: This method does not actually validate the certificate or check it's // revocation status. 
Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized func (p *AWS) assertConfig() (err error) { if p.config != nil { return } p.config, err = newAWSConfig(p.IIDRoots) return err } // checkSignature returns an error if the signature is not valid. func (p *AWS) checkSignature(signed, signature []byte) error { for _, crt := range p.config.certificates { if err := crt.CheckSignature(p.config.signatureAlgorithm, signed, signature); err == nil { return nil } } return errors.New("error validating identity document signature") } // readURL does a GET request to the given url and returns the body. It's not // using pkg/errors to avoid verbose errors, the caller should use it and write // the appropriate error. func (p *AWS) readURL(url string) ([]byte, error) { var resp *http.Response var err error // Initialize IMDS versions when this is called from the cli. 
if len(p.IMDSVersions) == 0 { p.IMDSVersions = []string{"v2", "v1"} } for _, v := range p.IMDSVersions { switch v { case "v1": resp, err = p.readURLv1(url) if err == nil && resp.StatusCode < 400 { return p.readResponseBody(resp) } case "v2": resp, err = p.readURLv2(url) if err == nil && resp.StatusCode < 400 { return p.readResponseBody(resp) } default: return nil, fmt.Errorf("%s: not a supported AWS Instance Metadata Service version", v) } if resp != nil { resp.Body.Close() } } // all versions have been exhausted and we haven't returned successfully yet so pass // the error on to the caller if err != nil { return nil, err } return nil, fmt.Errorf("request for metadata returned non-successful status code %d", resp.StatusCode) } func (p *AWS) readURLv1(url string) (*http.Response, error) { client := http.Client{} req, err := http.NewRequest(http.MethodGet, url, http.NoBody) if err != nil { return nil, err } resp, err := client.Do(req) if err != nil { return nil, err } return resp, nil } func (p *AWS) readURLv2(url string) (*http.Response, error) { client := http.Client{} // first get the token req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, http.NoBody) if err != nil { return nil, err } req.Header.Set(awsMetadataTokenTTLHeader, p.config.tokenTTL) resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode) } token, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // now make the request req, err = http.NewRequest(http.MethodGet, url, http.NoBody) if err != nil { return nil, err } req.Header.Set(awsMetadataTokenHeader, string(token)) resp, err = client.Do(req) if err != nil { return nil, err } return resp, nil } func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) { defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { return nil, err } return b, nil } 
// authorizeToken performs common jwt authorization actions and returns the
// claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request.
//
// The checks run in this order: the token is parsed, its claims are verified
// using the embedded document signature as the key, the identity document
// signature is checked against the configured certificates, the required
// document fields are verified to be present, and finally the standard claims,
// audience, subject (only with DisableCustomSANs), account and instance age
// are validated. On success the returned payload has its document field set
// to the parsed identity document.
func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
	jwt, err := jose.ParseSigned(token)
	if err != nil {
		return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; error parsing aws token")
	}
	if len(jwt.Headers) == 0 {
		return nil, errs.InternalServer("aws.authorizeToken; error parsing token, header is missing")
	}

	// The verification key is carried inside the token, so the claims have to
	// be read once without verification to extract it.
	var unsafeClaims awsPayload
	if err := jwt.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling claims")
	}

	// Decode the claims again, this time verifying the token with the
	// document signature as the key.
	var payload awsPayload
	if err := jwt.Claims(unsafeClaims.Amazon.Signature, &payload); err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error verifying claims")
	}

	// Validate identity document signature
	if err := p.checkSignature(payload.Amazon.Document, payload.Amazon.Signature); err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token signature")
	}

	var doc awsInstanceIdentityDocument
	if err := json.Unmarshal(payload.Amazon.Document, &doc); err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "aws.authorizeToken; error unmarshaling aws identity document")
	}

	// These fields are used below and by the callers, so they must be present.
	switch {
	case doc.AccountID == "":
		return nil, errs.Unauthorized("aws.authorizeToken; aws identity document accountId cannot be empty")
	case doc.InstanceID == "":
		return nil, errs.Unauthorized("aws.authorizeToken; aws identity document instanceId cannot be empty")
	case doc.PrivateIP == "":
		return nil, errs.Unauthorized("aws.authorizeToken; aws identity document privateIp cannot be empty")
	case doc.Region == "":
		return nil, errs.Unauthorized("aws.authorizeToken; aws identity document region cannot be empty")
	}

	// According to "rfc7519 JSON Web Token" acceptable skew should be no
	// more than a few minutes.
	now := time.Now().UTC()
	if err = payload.ValidateWithLeeway(jose.Expected{
		Issuer: awsIssuer,
		Time:   now,
	}, time.Minute); err != nil {
		return nil, errs.Wrapf(http.StatusUnauthorized, err, "aws.authorizeToken; invalid aws token")
	}

	// validate audiences with the defaults
	if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
		return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
	}

	// Validate subject, it has to be known if disableCustomSANs is enabled.
	// The accepted values are the instance ID, the private IP, and the
	// internal DNS name derived from the private IP and the region.
	if p.DisableCustomSANs {
		if payload.Subject != doc.InstanceID &&
			payload.Subject != doc.PrivateIP &&
			payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) {
			return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
		}
	}

	// validate accounts: an empty Accounts list accepts every account.
	if len(p.Accounts) > 0 {
		var found bool
		for _, sa := range p.Accounts {
			if sa == doc.AccountID {
				found = true
				break
			}
		}
		if !found {
			return nil, errs.Unauthorized("aws.authorizeToken; invalid aws identity document - accountId is not valid")
		}
	}

	// validate instance age: only enforced when InstanceAge is positive.
	if d := p.InstanceAge.Value(); d > 0 {
		if now.Sub(doc.PendingTime) > d {
			return nil, errs.Unauthorized("aws.authorizeToken; aws identity document pendingTime is too old")
		}
	}

	payload.document = doc
	return &payload, nil
}

// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
//
// It returns an unauthorized error when the SSH CA is disabled for this
// provisioner, and otherwise authorizes the token with authorizeToken. The
// defaults enforce a host certificate.
func (p *AWS) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) {
	if !p.ctl.Claimer.IsSSHCAEnabled() {
		return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
	}
	claims, err := p.authorizeToken(token)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSSHSign")
	}

	doc := claims.document
	signOptions := []SignOption{}

	// Enforce host certificate.
	defaults := SignSSHOptions{
		CertType: SSHHostCert,
	}

	// Validated principals: the private IP of the instance and the internal
	// DNS name derived from it.
	principals := []string{
		doc.PrivateIP,
		fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region),
	}

	// Only enforce known principals if disable custom sans is true.
	if p.DisableCustomSANs {
		defaults.Principals = principals
	} else {
		// Check that at least one principal is sent in the request.
		signOptions = append(signOptions, &sshCertOptionsRequireValidator{
			Principals: true,
		})
	}

	// Certificate templates. The token is attached to the template data only
	// when it can be parsed.
	data := sshutil.CreateTemplateData(sshutil.HostCert, doc.InstanceID, principals)
	if v, err := unsafeParseSigned(token); err == nil {
		data.SetToken(v)
	}

	templateOptions, err := CustomSSHTemplateOptions(p.Options, data, sshutil.DefaultIIDTemplate)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSSHSign")
	}
	signOptions = append(signOptions, templateOptions)

	return append(signOptions,
		p,
		// Validate user SignSSHOptions.
		sshCertOptionsValidator(defaults),
		// Set the validity bounds if not set.
		&sshDefaultDuration{p.ctl.Claimer},
		// Validate public key
		&sshDefaultPublicKeyValidator{},
		// Validate the validity period.
&sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, // Ensure that all principal names are allowed newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), // Call webhooks p.ctl.newWebhookController( data, linkedca.Webhook_SSH, webhook.WithAuthorizationPrincipal(doc.InstanceID), ), ), nil } ================================================ FILE: authority/provisioner/aws_certificates.pem ================================================ # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/verify-signature.html # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/regions-certs.html use RSA format # certificate for us-east-2 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUVJTc+hOU+8Gk3JlqsX438Dk5c58wDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE3MTE0OVoXDTI5MDQyODE3MTE0OVowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUVJTc+hOU +8Gk3JlqsX438Dk5c58wEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQAywJQaVNWJqW0R0T0xVOSoN1GLk9x9kKEuN67RN9CLin4dA97qa7Mr5W4P FZ6vnh5CjOhQBRXV9xJUeYSdqVItNAUFK/fEzDdjf1nUfPlQ3OJ49u6CV01NoJ9m usvY9kWcV46dqn2bk2MyfTTgvmeqP8fiMRPxxnVRkSzlldP5Fg== -----END CERTIFICATE----- # certificate for us-east-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUE1y2NIKCU+Rg4uu4u32koG9QEYIwDQYJKoZIhvcNAQEL 
BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE3MzQwMVoXDTI5MDQyODE3MzQwMVowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUE1y2NIKC U+Rg4uu4u32koG9QEYIwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQAlxSmwcWnhT4uAeSinJuz+1BTcKhVSWb5jT8pYjQb8ZoZkXXRGb09mvYeU NeqOBr27rvRAnaQ/9LUQf72+SahDFuS4CMI8nwowytqbmwquqFr4dxA/SDADyRiF ea1UoMuNHTY49J/1vPomqsVn7mugTp+TbjqCfOJTpu0temHcFA== -----END CERTIFICATE----- # certificate for us-west-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUK2zmY9PUSTR7rc1k2OwPYu4+g7wwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE3MDI0M1oXDTI5MDQyODE3MDI0M1owXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUK2zmY9PU 
STR7rc1k2OwPYu4+g7wwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQA1Ng4QmN4n7iPh5CnadSOc0ZfM7by0dBePwZJyGvOHdaw6P6E/vEk76KsC Q8p+akuzVzVPkU4kBK/TRqLp19wEWoVwhhTaxHjQ1tTRHqXIVlrkw4JrtFbeNM21 GlkSLonuzmNZdivn9WuQYeGe7nUD4w3q9GgiF3CPorJe+UxtbA== -----END CERTIFICATE----- # certificate for us-west-2 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUFx8PxCkbHwpD31bOyCtyz3GclbgwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE3MjM1OVoXDTI5MDQyODE3MjM1OVowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUFx8PxCkb HwpD31bOyCtyz3GclbgwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBzOl+9Xy1+UsbUBI95HO9mbbdnuX+aMJXgG9uFZNjgNEbMcvx+h8P9IMko z7PzFdheQQ1NLjsHH9mSR1SyC4m9ja6BsejH5nLBWyCdjfdP3muZM4O5+r7vUa1O dWU+hP/T7DUrPAIVMOE7mpYa+WPWJrN6BlRwQkKQ7twm9kDalA== -----END CERTIFICATE----- # certificate for eu-south-1 -----BEGIN CERTIFICATE----- MIICNjCCAZ+gAwIBAgIJAOZ3GEIaDcugMA0GCSqGSIb3DQEBCwUAMFwxCzAJBgNV BAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0 dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0xOTEwMjQx NTE5MDlaGA8yMTk5MDMyOTE1MTkwOVowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgT EFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0Ft YXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB gQCjiPgW3vsXRj4JoA16WQDyoPc/eh3QBARaApJEc4nPIGoUolpAXcjFhWplo2O+ 
ivgfCsc4AU9OpYdAPha3spLey/bhHPRi1JZHRNqScKP0hzsCNmKhfnZTIEQCFvsp DRp4zr91/WS06/flJFBYJ6JHhp0KwM81XQG59lV6kkoW7QIDAQABMA0GCSqGSIb3 DQEBCwUAA4GBAGLLrY3P+HH6C57dYgtJkuGZGT2+rMkk2n81/abzTJvsqRqGRrWv XRKRXlKdM/dfiuYGokDGxiC0Mg6TYy6wvsR2qRhtXW1OtZkiHWcQCnOttz+8vpew wx8JGMvowtuKB1iMsbwyRqZkFYLcvH+Opfb/Aayi20/ChQLdI6M2R5VU -----END CERTIFICATE----- # certificate for ap-east-1 -----BEGIN CERTIFICATE----- MIICSzCCAbQCCQDtQvkVxRvK9TANBgkqhkiG9w0BAQsFADBqMQswCQYDVQQGEwJV UzETMBEGA1UECBMKV2FzaGluZ3RvbjEQMA4GA1UEBxMHU2VhdHRsZTEYMBYGA1UE ChMPQW1hem9uLmNvbSBJbmMuMRowGAYDVQQDExFlYzIuYW1hem9uYXdzLmNvbTAe Fw0xOTAyMDMwMzAwMDZaFw0yOTAyMDIwMzAwMDZaMGoxCzAJBgNVBAYTAlVTMRMw EQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgwFgYDVQQKEw9B bWF6b24uY29tIEluYy4xGjAYBgNVBAMTEWVjMi5hbWF6b25hd3MuY29tMIGfMA0G CSqGSIb3DQEBAQUAA4GNADCBiQKBgQC1kkHXYTfc7gY5Q55JJhjTieHAgacaQkiR Pity9QPDE3b+NXDh4UdP1xdIw73JcIIG3sG9RhWiXVCHh6KkuCTqJfPUknIKk8vs M3RXflUpBe8Pf+P92pxqPMCz1Fr2NehS3JhhpkCZVGxxwLC5gaG0Lr4rFORubjYY Rh84dK98VwIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAA6xV9f0HMqXjPHuGILDyaNN dKcvplNFwDTydVg32MNubAGnecoEBtUPtxBsLoVYXCOb+b5/ZMDubPF9tU/vSXuo TpYM5Bq57gJzDRaBOntQbX9bgHiUxw6XZWaTS/6xjRJDT5p3S1E0mPI3lP/eJv4o Ezk5zb3eIf10/sqt4756 -----END CERTIFICATE----- # certificate for af-south-1 -----BEGIN CERTIFICATE----- MIICNjCCAZ+gAwIBAgIJAKumfZiRrNvHMA0GCSqGSIb3DQEBCwUAMFwxCzAJBgNV BAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0 dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0xOTExMjcw NzE0MDVaGA8yMTk5MDUwMjA3MTQwNVowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgT EFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0Ft YXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB gQDFd571nUzVtke3rPyRkYfvs3jh0C0EMzzG72boyUNjnfw1+m0TeFraTLKb9T6F 7TuB/ZEN+vmlYqr2+5Va8U8qLbPF0bRH+FdaKjhgWZdYXxGzQzU3ioy5W5ZM1VyB 7iUsxEAlxsybC3ziPYaHI42UiTkQNahmoroNeqVyHNnBpQIDAQABMA0GCSqGSIb3 DQEBCwUAA4GBAAJLylWyElEgOpW4B1XPyRVD4pAds8Guw2+krgqkY0HxLCdjosuH RytGDGN+q75aAoXzW5a7SGpxLxk6Hfv0xp3RjDHsoeP0i1d8MD3hAC5ezxS4oukK 
s5gbPOnokhKTMPXbTdRn5ZifCbWlx+bYN/mTYKvxho7b5SVg2o1La9aK -----END CERTIFICATE----- # certificate for me-south-1 -----BEGIN CERTIFICATE----- MIIDPDCCAqWgAwIBAgIJAMl6uIV/zqJFMA0GCSqGSIb3DQEBCwUAMHIxCzAJBgNV BAYTAlVTMRMwEQYDVQQIDApXYXNoaW5ndG9uMRAwDgYDVQQHDAdTZWF0dGxlMSAw HgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzEaMBgGA1UEAwwRZWMyLmFt YXpvbmF3cy5jb20wIBcNMTkwNDI2MTQzMjQ3WhgPMjE5ODA5MjkxNDMyNDdaMHIx CzAJBgNVBAYTAlVTMRMwEQYDVQQIDApXYXNoaW5ndG9uMRAwDgYDVQQHDAdTZWF0 dGxlMSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzEaMBgGA1UEAwwR ZWMyLmFtYXpvbmF3cy5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALVN CDTZEnIeoX1SEYqq6k1BV0ZlpY5y3KnoOreCAE589TwS4MX5+8Fzd6AmACmugeBP Qk7Hm6b2+g/d4tWycyxLaQlcq81DB1GmXehRkZRgGeRge1ePWd1TUA0I8P/QBT7S gUePm/kANSFU+P7s7u1NNl+vynyi0wUUrw7/wIZTAgMBAAGjgdcwgdQwHQYDVR0O BBYEFILtMd+T4YgH1cgc+hVsVOV+480FMIGkBgNVHSMEgZwwgZmAFILtMd+T4YgH 1cgc+hVsVOV+480FoXakdDByMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGlu Z3RvbjEQMA4GA1UEBwwHU2VhdHRsZTEgMB4GA1UECgwXQW1hem9uIFdlYiBTZXJ2 aWNlcyBMTEMxGjAYBgNVBAMMEWVjMi5hbWF6b25hd3MuY29tggkAyXq4hX/OokUw DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQBhkNTBIFgWFd+ZhC/LhRUY 4OjEiykmbEp6hlzQ79T0Tfbn5A4NYDI2icBP0+hmf6qSnIhwJF6typyd1yPK5Fqt NTpxxcXmUKquX+pHmIkK1LKDO8rNE84jqxrxRsfDi6by82fjVYf2pgjJW8R1FAw+ mL5WQRFexbfB5aXhcMo0AA== -----END CERTIFICATE----- # certificate for cn-north-1 -----BEGIN CERTIFICATE----- MIIDCzCCAnSgAwIBAgIJALSOMbOoU2svMA0GCSqGSIb3DQEBCwUAMFwxCzAJBgNV BAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0 dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAeFw0yMzA3MDQw ODM1MzlaFw0yODA3MDIwODM1MzlaMFwxCzAJBgNVBAYTAlVTMRkwFwYDVQQIExBX YXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYDVQQKExdBbWF6 b24gV2ViIFNlcnZpY2VzIExMQzCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA uhhUNlqAZdcWWB/OSDVDGk3OA99EFzOn/mJlmciQ/Xwu2dFJWmSCqEAE6gjufCjQ q3voxAhC2CF+elKtJW/C0Sz/LYo60PUqd6iXF4h+upB9HkOOGuWHXsHBTsvgkgGA 1CGgel4U0Cdq+23eANr8N8m28UzljjSnTlrYCHtzN4sCAwEAAaOB1DCB0TALBgNV 
HQ8EBAMCB4AwHQYDVR0OBBYEFBkZu3wT27NnYgrfH+xJz4HJaNJoMIGOBgNVHSME gYYwgYOAFBkZu3wT27NnYgrfH+xJz4HJaNJooWCkXjBcMQswCQYDVQQGEwJVUzEZ MBcGA1UECBMQV2FzaGluZ3RvbiBTdGF0ZTEQMA4GA1UEBxMHU2VhdHRsZTEgMB4G A1UEChMXQW1hem9uIFdlYiBTZXJ2aWNlcyBMTEOCCQC0jjGzqFNrLzASBgNVHRMB Af8ECDAGAQH/AgEAMA0GCSqGSIb3DQEBCwUAA4GBAECji43p+oPkYqmzll7e8Hgb oADS0ph+YUz5P/bUCm61wFjlxaTfwKcuTR3ytj7bFLoW5Bm7Sa+TCl3lOGb2taon 2h+9NirRK6JYk87LMNvbS40HGPFumJL2NzEsGUeK+MRiWu+Oh5/lJGii3qw4YByx SUDlRyNy1jJFstEZjOhs -----END CERTIFICATE----- # certificate for cn-northwest-1 -----BEGIN CERTIFICATE----- MIIDCzCCAnSgAwIBAgIJALSOMbOoU2svMA0GCSqGSIb3DQEBCwUAMFwxCzAJBgNV BAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0 dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAeFw0yMzA3MDQw ODM1MzlaFw0yODA3MDIwODM1MzlaMFwxCzAJBgNVBAYTAlVTMRkwFwYDVQQIExBX YXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdTZWF0dGxlMSAwHgYDVQQKExdBbWF6 b24gV2ViIFNlcnZpY2VzIExMQzCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA uhhUNlqAZdcWWB/OSDVDGk3OA99EFzOn/mJlmciQ/Xwu2dFJWmSCqEAE6gjufCjQ q3voxAhC2CF+elKtJW/C0Sz/LYo60PUqd6iXF4h+upB9HkOOGuWHXsHBTsvgkgGA 1CGgel4U0Cdq+23eANr8N8m28UzljjSnTlrYCHtzN4sCAwEAAaOB1DCB0TALBgNV HQ8EBAMCB4AwHQYDVR0OBBYEFBkZu3wT27NnYgrfH+xJz4HJaNJoMIGOBgNVHSME gYYwgYOAFBkZu3wT27NnYgrfH+xJz4HJaNJooWCkXjBcMQswCQYDVQQGEwJVUzEZ MBcGA1UECBMQV2FzaGluZ3RvbiBTdGF0ZTEQMA4GA1UEBxMHU2VhdHRsZTEgMB4G A1UEChMXQW1hem9uIFdlYiBTZXJ2aWNlcyBMTEOCCQC0jjGzqFNrLzASBgNVHRMB Af8ECDAGAQH/AgEAMA0GCSqGSIb3DQEBCwUAA4GBAECji43p+oPkYqmzll7e8Hgb oADS0ph+YUz5P/bUCm61wFjlxaTfwKcuTR3ytj7bFLoW5Bm7Sa+TCl3lOGb2taon 2h+9NirRK6JYk87LMNvbS40HGPFumJL2NzEsGUeK+MRiWu+Oh5/lJGii3qw4YByx SUDlRyNy1jJFstEZjOhs -----END CERTIFICATE----- # certificate for eu-central-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUFD5GsmkxRuecttwsCG763m3u63UwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE1NTUyOVoXDTI5MDQyODE1NTUyOVowXDELMAkGA1UEBhMCVVMx 
GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUFD5Gsmkx RuecttwsCG763m3u63UwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBBh0WaXlBsW56Hqk588MmJxsOrvcKfDjF57RgEDgnGnQaJcStCVWDO9UYO JX2tdsPw+E7AjDqjsuxYaotLn3Mr3mK0sNOXq9BljBnWD4pARg89KZnZI8FN35HQ O/LYOVHCknuPL123VmVRNs51qQA9hkPjvw21UzpDLxaUxt9Z/w== -----END CERTIFICATE----- # certificate for eu-central-2 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAXjSGFGiMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTA0MTQyMDM1 MTJaGA8yMjAwMDQxNDIwMzUxMlowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC2 mdGdps5Rz2jzYcGNsgETTGUthJRrVqSnUWJXTlVaIbkGPLKO6Or7AfWKFp2sgRJ8 vLsjoBVR5cESVK7cuK1wItjvJyi/opKZAUusJx2hpgU3pUHhlp9ATh/VeVD582jT d9IY+8t5MDa6Z3fGliByEiXz0LEHdi8MBacLREu1TwIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAILlpoE3k9o7KdALAxsFJNitVS+g3RMzdbiFM+7MA63Nv5fsf+0xgcjS NBElvPCDKFvTJl4QQhToy056llO5GvdS9RK+H8xrP2mrqngApoKTApv93vHBixgF Sn5KrczRO0YSm3OjkqbydU7DFlmkXXR7GYE+5jbHvQHYiT1J5sMu -----END CERTIFICATE----- # certificate for ap-south-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUDLA+x6tTAP3LRTr0z6nOxfsozdMwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD 
MB4XDTI0MDQyOTE0MTMwMVoXDTI5MDQyODE0MTMwMVowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUDLA+x6tT AP3LRTr0z6nOxfsozdMwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQAZ7rYKoAwwiiH1M5GJbrT/BEk3OO2VrEPw8ZxgpqQ/EKlzMlOs/0Cyrmp7 UYyUgYFQe5nq37Z94rOUSeMgv/WRxaMwrLlLqD78cuF9DSkXaZIX/kECtVaUnjk8 BZx0QhoIHOpQocJUSlm/dLeMuE0+0A3HNR6JVktGsUdv9ulmKw== -----END CERTIFICATE----- # certificate for ap-south-2 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAXjwLj9CMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTA0MjAxNjQ3 NDVaGA8yMjAwMDQyMDE2NDc0NVowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDT wHu0ND+sFcobrjvcAYm0PNRD8f4R1jAzvoLt2+qGeOTAyO1Httj6cmsYN3AP1hN5 iYuppFiYsl2eNPa/CD0Vg0BAfDFlV5rzjpA0j7TJabVh4kj7JvtD+xYMi6wEQA4x 6SPONY4OeZ2+8o/HS8nucpWDVdPRO6ciWUlMhjmDmwIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAAy6sgTdRkTqELHBeWj69q60xHyUmsWqHAQNXKVc9ApWGG4onzuqlMbG ETwUZ9mTq2vxlV0KvuetCDNS5u4cJsxe/TGGbYP0yP2qfMl0cCImzRI5W0gn8gog dervfeT7nH5ih0TWEy/QDWfkQ601L4erm4yh4YQq8vcqAPSkf04N -----END CERTIFICATE----- # certificate for ap-southeast-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUSqP6ih+++5KF07NXngrWf26mhSUwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO 
BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE0MzAxNFoXDTI5MDQyODE0MzAxNFowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUSqP6ih++ +5KF07NXngrWf26mhSUwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQAw13BxW11U/JL58j//Fmk7qqtrZTqXmaz1qm2WlIpJpW750MOcP4ux1uPy eM0RdVZ4jHSMv5gtLAv/PjExBfw9n6vNCk+5GZG4Xec5DoapBZHXmfMo93sjxBFP 4x9rWn0GuwAVO9ukjYPevq2Rerilrq5VvppHtbATVNY2qecXDA== -----END CERTIFICATE----- # certificate for ap-southeast-2 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUFxWyAdk4oiXIOC9PxcgjYYh71mwwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE1MjE0M1oXDTI5MDQyODE1MjE0M1owXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUFxWyAdk4 oiXIOC9PxcgjYYh71mwwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF 
AAOBgQByjeQe6lr7fiIhoGdjBXYzDfkX0lGGvMIhRh57G1bbceQfaYdZd7Ptc0jl bpycKGaTvhUdkpMOiV2Hi9dOOYawkdhyJDstmDNKu6P9+b6Kak8He5z3NU1tUR2Y uTwcz7Ye8Nldx//ws3raErfTI7D6s9m63OX8cAJ/f8bNgikwpw== -----END CERTIFICATE----- # certificate for ap-southeast-3 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAXbVDG2yMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTAxMDYwMDE1 MzBaGA8yMjAwMDEwNjAwMTUzMFowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn CS/Vbt0gQ1ebWcur2hSO7PnJifE4OPxQ7RgSAlc4/spJp1sDP+ZrS0LO1ZJfKhXf 1R9S3AUwLnsc7b+IuVXdY5LK9RKqu64nyXP5dx170zoL8loEyCSuRR2fs+04i2Qs WBVP+KFNAn7P5L1EHRjkgTO8kjNKviwRV+OkP9ab5wIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAI4WUy6+DKh0JDSzQEZNyBgNlSoSuC2owtMxCwGB6nBfzzfcekWvs6eo fLTSGovrReX7MtVgrcJBZjmPIentw5dWUs+87w/g9lNwUnUt0ZHYyh2tuBG6hVJu UEwDJ/z3wDd6wQviLOTF3MITawt9P8siR1hXqLJNxpjRQFZrgHqi -----END CERTIFICATE----- # certificate for ap-southeast-4 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAXjSh40SMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTA0MTQyMjM2 NDJaGA8yMjAwMDQxNDIyMzY0MlowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDH ezwQr2VQpQSTW5TXNefiQrP+qWTGAbGsPeMX4hBMjAJUKys2NIRcRZaLM/BCew2F IPVjNtlaj6Gwn9ipU4Mlz3zIwAMWi1AvGMSreppt+wV6MRtfOjh0Dvj/veJe88aE ZJMozNgkJFRS+WFWsckQeL56tf6kY6QTlNo8V/0CsQIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAF7vpPghH0FRo5gu49EArRNPrIvW1egMdZHrzJNqbztLCtV/wcgkqIww uXYj+1rhlL+/iMpQWjdVGEqIZSeXn5fLmdx50eegFCwND837r9e8XYTiQS143Sxt 9+Yi6BZ7U7YD8kK9NBWoJxFqUeHdpRCs0O7COjT3gwm7ZxvAmssh -----END CERTIFICATE----- # certificate for eu-south-2 -----BEGIN CERTIFICATE----- 
MIICMzCCAZygAwIBAgIGAXjwLkiaMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTA0MjAxNjQ3 NDhaGA8yMjAwMDQyMDE2NDc0OFowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDB /VvR1+45Aey5zn3vPk6xBm5o9grSDL6D2iAuprQnfVXn8CIbSDbWFhA3fi5ippjK kh3sl8VyCvCOUXKdOaNrYBrPRkrdHdBuL2Tc84RO+3m/rxIUZ2IK1fDlC6sWAjdd f6sBrV2w2a78H0H8EwuwiSgttURBjwJ7KPPJCqaqrQIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAKR+FzqQDzun/iMMzcFucmLMl5BxEblrFXOz7IIuOeiGkndmrqUeDCyk ztLku45s7hxdNy4ltTuVAaE5aNBdw5J8U1mRvsKvHLy2ThH6hAWKwTqtPAJp7M21 GDwgDDOkPSz6XVOehg+hBgiphYp84DUbWVYeP8YqLEJSqscKscWC -----END CERTIFICATE----- # certificate for il-central-1 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAX0QQGVLMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTExMTExODI2 MzVaGA8yMjAwMTExMTE4MjYzNVowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDr c24u3AgFxnoPgzxR6yFXOamcPuxYXhYKWmapb+S8vOy5hpLoRe4RkOrY0cM3bN07 GdEMlin5mU0y1t8y3ct4YewvmkgT42kTyMM+t1K4S0xsqjXxxS716uGYh7eWtkxr Cihj8AbXN/6pa095h+7TZyl2n83keiNUzM2KoqQVMwIDAQABMA0GCSqGSIb3DQEB BQUAA4GBADwA6VVEIIZD2YL00F12po40xDLzIc9XvqFPS9iFaWi2ho8wLio7wA49 VYEFZSI9CR3SGB9tL8DUib97mlxmd1AcGShMmMlhSB29vhuhrUNB/FmU7H8s62/j D6cOR1A1cClIyZUe1yT1ZbPySCs43J+Thr8i8FSRxzDBSZZi5foW -----END CERTIFICATE----- # certificate for me-central-1 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAXjRrnDjMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTA0MTQxODM5 MzNaGA8yMjAwMDQxNDE4MzkzM1owXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh 
c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDc aTgW/KyA6zyruJQrYy00a6wqLA7eeUzk3bMiTkLsTeDQfrkaZMfBAjGaaOymRo1C 3qzE4rIenmahvUplu9ZmLwL1idWXMRX2RlSvIt+d2SeoKOKQWoc2UOFZMHYxDue7 zkyk1CIRaBukTeY13/RIrlc6X61zJ5BBtZXlHwayjQIDAQABMA0GCSqGSIb3DQEB BQUAA4GBABTqTy3R6RXKPW45FA+cgo7YZEj/Cnz5YaoUivRRdX2A83BHuBTvJE2+ WX00FTEj4hRVjameE1nENoO8Z7fUVloAFDlDo69fhkJeSvn51D1WRrPnoWGgEfr1 +OfK1bAcKTtfkkkP9r4RdwSjKzO5Zu/B+Wqm3kVEz/QNcz6npmA6 -----END CERTIFICATE----- # certificate for us-gov-east-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIULVyrqjjwZ461qelPCiShB1KCCj4wDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDUwNzE1MjIzNloXDTI5MDUwNjE1MjIzNlowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCpohwYUVPH9I7Vbkb3WMe/JB0Y/bmfVj3VpcK445YBRO9K80al esjgBc2tAX4KYg4Lht4EBKccLHTzaNi51YEGX1aLNrSmxhz1+WtzNLNUsyY3zD9z vwX/3k1+JB2dRA+m+Cpwx4mjzZyAeQtHtegVaAytkmqtxQrSCexBxvqRqQIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQU1ZXneBYnPvYXkHVlVjg7918V gE8wgZkGA1UdIwSBkTCBjoAU1ZXneBYnPvYXkHVlVjg7918VgE+hYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IULVyrqjjw Z461qelPCiShB1KCCj4wEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBfAL/YZv0y3zmVbXjyxQCsDloeDCJjFKIu3ameEckeIWJbST9LMto0zViZ puIAf05x6GQiEqfBMk+YMxJfcTmJB4Ebaj4egFlslJPSHyC2xuydHlr3B04INOH5 Z2oCM68u6GGbj0jZjg7GJonkReG9N72kDva/ukwZKgq8zErQVQ== -----END CERTIFICATE----- # certificate for us-gov-west-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUe5wGF3jfb7lUHzvDxmM/ktGCLwwwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD 
MB4XDTI0MDUwNzE3MzAzMloXDTI5MDUwNjE3MzAzMlowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCpohwYUVPH9I7Vbkb3WMe/JB0Y/bmfVj3VpcK445YBRO9K80al esjgBc2tAX4KYg4Lht4EBKccLHTzaNi51YEGX1aLNrSmxhz1+WtzNLNUsyY3zD9z vwX/3k1+JB2dRA+m+Cpwx4mjzZyAeQtHtegVaAytkmqtxQrSCexBxvqRqQIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQU1ZXneBYnPvYXkHVlVjg7918V gE8wgZkGA1UdIwSBkTCBjoAU1ZXneBYnPvYXkHVlVjg7918VgE+hYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUe5wGF3jf b7lUHzvDxmM/ktGCLwwwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQCbTdpx1Iob9SwUReY4exMnlwQlmkTLyA8tYGWzchCJOJJEPfsW0ryy1A0H YIuvyUty3rJdp9ib8h3GZR71BkZnNddHhy06kPs4p8ewF8+d8OWtOJQcI+ZnFfG4 KyM4rUsBrljpG2aOCm12iACEyrvgJJrS8VZwUDZS6mZEnn/lhA== -----END CERTIFICATE----- # certificate for ca-west-1 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAYPou9weMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMjEwMTgwMTM2 MDlaGA8yMjAxMTAxODAxMzYwOVowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDK 1kIcG5Q6adBXQM75GldfTSiXl7tn54p10TnspI0ErDdb2B6q2Ji/v4XBVH13ZCMg qlRHMqV8AWI5iO6gFn2A9sN3AZXTMqwtZeiDdebq3k6Wt7ieYvpXTg0qvgsjQIov RZWaBDBJy9x8C2hW+w9lMQjFHkJ7Jy/PHCJ69EzebQIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAGe9Snkz1A6rHBH6/5kDtYvtPYwhx2sXNxztbhkXErFk40Nw5l459NZx EeudxJBLoCkkSgYjhRcOZ/gvDVtWG7qyb6fAqgoisyAbk8K9LzxSim2S1nmT9vD8 4B/t/VvwQBylc+ej8kRxMH7fquZLp7IXfmtBzyUqu6Dpbne+chG2 -----END CERTIFICATE----- # certificate for ap-northeast-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIULgwDh7TiDrPPBJwscqDwiBHkEFQwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO 
BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTEyMjMxMFoXDTI5MDQyODEyMjMxMFowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IULgwDh7Ti DrPPBJwscqDwiBHkEFQwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBtjAglBde1t4F9EHCZOj4qnY6Gigy07Ou54i+lR77MhbpzE8V28Li9l+YT QMIn6SzJqU3/fIycIro1OVY1lHmaKYgPGSEZxBenSBHfzwDLRmC9oRp4QMe0BjOC gepj1lUoiN7OA6PtA+ycNlsP0oJvdBjhvayLiuM3tUfLTrgHbw== -----END CERTIFICATE----- # certificate for ap-northeast-2 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUbBSn2UIO6vYk4iNWV0RPxJJtHlgwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTEzMzg0NloXDTI5MDQyODEzMzg0NlowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUbBSn2UIO 6vYk4iNWV0RPxJJtHlgwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF 
AAOBgQAmjTjalG8MGLqWTC2uYqEM8nzI3px1eo0ArvFRsyqQ3fgmWcQpxExqUqRy l3+2134Kv8dFab04Gut5wlfRtc2OwPKKicmv/IXGN+9bKFnQFjTqif08NIzrDZch aFT/uvxrIiM+oN2YsHq66GUhO2+xVRXDXVxM/VObFgPERbJpyA== -----END CERTIFICATE----- # certificate for ap-northeast-3 -----BEGIN CERTIFICATE----- MIICMzCCAZygAwIBAgIGAYPou9weMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMjEwMTgwMTM2 MDlaGA8yMjAxMTAxODAxMzYwOVowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDK 1kIcG5Q6adBXQM75GldfTSiXl7tn54p10TnspI0ErDdb2B6q2Ji/v4XBVH13ZCMg qlRHMqV8AWI5iO6gFn2A9sN3AZXTMqwtZeiDdebq3k6Wt7ieYvpXTg0qvgsjQIov RZWaBDBJy9x8C2hW+w9lMQjFHkJ7Jy/PHCJ69EzebQIDAQABMA0GCSqGSIb3DQEB BQUAA4GBAGe9Snkz1A6rHBH6/5kDtYvtPYwhx2sXNxztbhkXErFk40Nw5l459NZx EeudxJBLoCkkSgYjhRcOZ/gvDVtWG7qyb6fAqgoisyAbk8K9LzxSim2S1nmT9vD8 4B/t/VvwQBylc+ej8kRxMH7fquZLp7IXfmtBzyUqu6Dpbne+chG2 -----END CERTIFICATE----- # certificate for ca-central-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUIrLgixJJB5C4G8z6pZ5rB0JU2aQwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE1MzU0M1oXDTI5MDQyODE1MzU0M1owXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUIrLgixJJ 
B5C4G8z6pZ5rB0JU2aQwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBHiQJmzyFAaSYs8SpiRijIDZW2RIo7qBKb/pI3rqK6yOWDlPuMr6yNI81D IrKGGftg4Z+2KETYU4x76HSf0s//vfH3QA57qFaAwddhKYy4BhteFQl/Wex3xTlX LiwI07kwJvJy3mS6UfQ4HcvZy219tY+0iyOWrz/jVxwq7TOkCw== -----END CERTIFICATE----- # certificate for eu-west-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUakDaQ1Zqy87Hy9ESXA1pFC116HkwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE2MTgxMFoXDTI5MDQyODE2MTgxMFowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUakDaQ1Zq y87Hy9ESXA1pFC116HkwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQADIkn/MqaLGPuK5+prZZ5Ox4bBZLPtreO2C7r0pqU2kPM2lVPyYYydkvP0 lgSmmsErGu/oL9JNztDe2oCA+kNy17ehcsf8cw0uP861czNFKCeU8b7FgBbL+sIm qi33rAq6owWGi/5uEcfCR+JP7W+oSYVir5r/yDmWzx+BVH5S/g== -----END CERTIFICATE----- # certificate for eu-west-2 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUCgCV/DPxYNND/swDgEKGiC5I+EwwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE2MjkxNFoXDTI5MDQyODE2MjkxNFowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB 
UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUCgCV/DPx YNND/swDgEKGiC5I+EwwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQATPu/sOE2esNa4+XPEGKlEJSgqzyBSQLQc+VWo6FAJhGG9fp7D97jhHeLC 5vwfmtTAfnGBxadfAOT3ASkxnOZhXtnRna460LtnNHm7ArCVgXKJo7uBn6ViXtFh uEEw4y6p9YaLQna+VC8Xtgw6WKq2JXuKzuhuNKSFaGGw9vRcHg== -----END CERTIFICATE----- # certificate for eu-west-3 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUaC9fX57UDr6u1vBvsCsECKBZQyIwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE2MzczOFoXDTI5MDQyODE2MzczOFowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUaC9fX57U Dr6u1vBvsCsECKBZQyIwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQCARv1bQEDaMEzYI0nPlu8GHcMXgmgA94HyrXhMMcaIlQwocGBs6VILGVhM TXP2r3JFaPEpmXSQNQHvGA13clKwAZbni8wtzv6qXb4L4muF34iQRHF0nYrEDoK7 mMPR8+oXKKuPO/mv/XKo6XAV5DDERdSYHX5kkA2R9wtvyZjPnQ== -----END CERTIFICATE----- # certificate for eu-north-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUN1c9U6U/xiVDFgJcYKZB4NkH1QEwDQYJKoZIhvcNAQEL 
BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE2MDYwM1oXDTI5MDQyODE2MDYwM1owXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUN1c9U6U/ xiVDFgJcYKZB4NkH1QEwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBTIQdoFSDRHkpqNPUbZ9WXR2O5v/9bpmHojMYZb3Hw46wsaRso7STiGGX/ tRqjIkPUIXsdhZ3+7S/RmhFznmZc8e0bjU4n5vi9CJtQSt+1u4E17+V2bF+D3h/7 wcfE0l3414Q8JaTDtfEf/aF3F0uyBvr4MDMd7mFvAMmDmBPSlA== -----END CERTIFICATE----- # certificate for sa-east-1 -----BEGIN CERTIFICATE----- MIIDITCCAoqgAwIBAgIUX4Bh4MQ86Roh37VDRRX1MNOB3TcwDQYJKoZIhvcNAQEL BQAwXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAO BgNVBAcTB1NlYXR0bGUxIDAeBgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExD MB4XDTI0MDQyOTE2NDYwOVoXDTI5MDQyODE2NDYwOVowXDELMAkGA1UEBhMCVVMx GTAXBgNVBAgTEFdhc2hpbmd0b24gU3RhdGUxEDAOBgNVBAcTB1NlYXR0bGUxIDAe BgNVBAoTF0FtYXpvbiBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUA A4GNADCBiQKBgQCHvRjf/0kStpJ248khtIaN8qkDN3tkw4VjvA9nvPl2anJO+eIB UqPfQG09kZlwpWpmyO8bGB2RWqWxCwuB/dcnIob6w420k9WY5C0IIGtDRNauN3ku vGXkw3HEnF0EjYr0pcyWUvByWY4KswZV42X7Y7XSS13hOIcL6NLA+H94/QIDAQAB o4HfMIHcMAsGA1UdDwQEAwIHgDAdBgNVHQ4EFgQUJdbMCBXKtvCcWdwUUizvtUF2 UTgwgZkGA1UdIwSBkTCBjoAUJdbMCBXKtvCcWdwUUizvtUF2UTihYKReMFwxCzAJ BgNVBAYTAlVTMRkwFwYDVQQIExBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHEwdT ZWF0dGxlMSAwHgYDVQQKExdBbWF6b24gV2ViIFNlcnZpY2VzIExMQ4IUX4Bh4MQ8 
6Roh37VDRRX1MNOB3TcwEgYDVR0TAQH/BAgwBgEB/wIBADANBgkqhkiG9w0BAQsF AAOBgQBnhocfH6ZIX6F5K9+Y9V4HFk8vSaaKL5ytw/P5td1h9ej94KF3xkZ5fyjN URvGQv3kNmNJBoNarcP9I7JIMjsNPmVzqWawyCEGCZImoARxSS3Fc5EAs2PyBfcD 9nCtzMTaKO09Xyq0wqXVYn1xJsE5d5yBDsGrzaTHKjxo61+ezQ== -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/aws_test.go ================================================ package provisioner import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/hex" "encoding/pem" "errors" "fmt" "net" "net/http" "net/url" "strings" "testing" "time" "go.step.sm/crypto/jose" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" ) func TestAWS_Getters(t *testing.T) { p, err := generateAWS() assert.FatalError(t, err) aud := "aws/" + p.Name if got := p.GetID(); got != aud { t.Errorf("AWS.GetID() = %v, want %v", got, aud) } if got := p.GetName(); got != p.Name { t.Errorf("AWS.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeAWS { t.Errorf("AWS.GetType() = %v, want %v", got, TypeAWS) } kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("AWS.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } } func TestAWS_GetTokenID(t *testing.T) { p1, srv, err := generateAWSWithServer() assert.FatalError(t, err) defer srv.Close() p2, err := generateAWS() assert.FatalError(t, err) p2.Accounts = p1.Accounts p2.config = p1.config p2.DisableTrustOnFirstUse = true t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com") assert.FatalError(t, err) _, claims, err := parseAWSToken(t1) assert.FatalError(t, err) sum := sha256.Sum256([]byte(fmt.Sprintf("%s.%s", p1.GetID(), claims.document.InstanceID))) w1 := strings.ToLower(hex.EncodeToString(sum[:])) t2, err := p2.GetIdentityToken("foo.local", "https://ca.smallstep.com") assert.FatalError(t, err) sum = sha256.Sum256([]byte(t2)) w2 := 
strings.ToLower(hex.EncodeToString(sum[:])) type args struct { token string } tests := []struct { name string aws *AWS args args want string wantErr bool }{ {"ok", p1, args{t1}, w1, false}, {"ok no TOFU", p2, args{t2}, w2, false}, {"fail", p1, args{"bad-token"}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.aws.GetTokenID(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("AWS.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("AWS.GetTokenID() = %v, want %v", got, tt.want) } }) } } func TestAWS_GetIdentityToken(t *testing.T) { p1, srv, err := generateAWSWithServer() assert.FatalError(t, err) defer srv.Close() p2, err := generateAWS() assert.FatalError(t, err) p2.Accounts = p1.Accounts p2.config.identityURL = srv.URL + "/bad-document" p2.config.signatureURL = p1.config.signatureURL p2.config.tokenURL = p1.config.tokenURL p3, err := generateAWS() assert.FatalError(t, err) p3.Accounts = p1.Accounts p3.config.signatureURL = srv.URL p3.config.identityURL = p1.config.identityURL p3.config.tokenURL = p1.config.tokenURL p4, err := generateAWS() assert.FatalError(t, err) p4.Accounts = p1.Accounts p4.config.signatureURL = srv.URL + "/bad-signature" p4.config.identityURL = p1.config.identityURL p4.config.tokenURL = p1.config.tokenURL p5, err := generateAWS() assert.FatalError(t, err) p5.Accounts = p1.Accounts p5.config.identityURL = "https://1234.1234.1234.1234" p5.config.signatureURL = p1.config.signatureURL p5.config.tokenURL = p1.config.tokenURL p6, err := generateAWS() assert.FatalError(t, err) p6.Accounts = p1.Accounts p6.config.identityURL = p1.config.identityURL p6.config.signatureURL = "https://1234.1234.1234.1234" p6.config.tokenURL = p1.config.tokenURL p7, err := generateAWS() assert.FatalError(t, err) p7.Accounts = p1.Accounts p7.config.identityURL = srv.URL + "/bad-json" p7.config.signatureURL = p1.config.signatureURL p7.config.tokenURL = p1.config.tokenURL p8, err := 
generateAWS() assert.FatalError(t, err) p8.IMDSVersions = nil p8.Accounts = p1.Accounts p8.config = p1.config caURL := "https://ca.smallstep.com" u, err := url.Parse(caURL) assert.FatalError(t, err) type args struct { subject string caURL string } tests := []struct { name string aws *AWS args args wantErr bool }{ {"ok", p1, args{"foo.local", caURL}, false}, {"ok no imds", p8, args{"foo.local", caURL}, false}, {"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true}, {"fail identityURL", p2, args{"foo.local", caURL}, true}, {"fail signatureURL", p3, args{"foo.local", caURL}, true}, {"fail signature", p4, args{"foo.local", caURL}, true}, {"fail read identityURL", p5, args{"foo.local", caURL}, true}, {"fail read signatureURL", p6, args{"foo.local", caURL}, true}, {"fail unmarshal identityURL", p7, args{"foo.local", caURL}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.aws.GetIdentityToken(tt.args.subject, tt.args.caURL) if (err != nil) != tt.wantErr { t.Errorf("AWS.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr == false { _, c, err := parseAWSToken(got) if assert.NoError(t, err) { assert.Equals(t, awsIssuer, c.Issuer) assert.Equals(t, tt.args.subject, c.Subject) assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: tt.aws.GetID()}).String()}, c.Audience) assert.Equals(t, tt.aws.Accounts[0], c.document.AccountID) for _, crt := range tt.aws.config.certificates { err = crt.CheckSignature(tt.aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature) if err == nil { break } } assert.NoError(t, err) } } }) } } func TestAWS_GetIdentityToken_V1Only(t *testing.T) { aws, srv, err := generateAWSWithServerV1Only() assert.FatalError(t, err) defer srv.Close() subject := "foo.local" caURL := "https://ca.smallstep.com" u, err := url.Parse(caURL) assert.Nil(t, err) token, err := aws.GetIdentityToken(subject, caURL) assert.Nil(t, err) _, c, err := 
parseAWSToken(token) if assert.NoError(t, err) { assert.Equals(t, awsIssuer, c.Issuer) assert.Equals(t, subject, c.Subject) assert.Equals(t, jose.Audience{u.ResolveReference(&url.URL{Path: "/1.0/sign", Fragment: aws.GetID()}).String()}, c.Audience) assert.Equals(t, aws.Accounts[0], c.document.AccountID) for _, crt := range aws.config.certificates { err = crt.CheckSignature(aws.config.signatureAlgorithm, c.Amazon.Document, c.Amazon.Signature) if err == nil { break } } assert.NoError(t, err) } } func TestAWS_GetIdentityToken_BadIDMS(t *testing.T) { aws, srv, err := generateAWSWithServer() aws.IMDSVersions = []string{"bad"} assert.FatalError(t, err) defer srv.Close() subject := "foo.local" caURL := "https://ca.smallstep.com" token, err := aws.GetIdentityToken(subject, caURL) assert.Equals(t, token, "") badIDMS := errors.New("bad: not a supported AWS Instance Metadata Service version") assert.HasSuffix(t, err.Error(), badIDMS.Error()) } func TestAWS_Init(t *testing.T) { config := Config{ Claims: globalProvisionerClaims, } badClaims := &Claims{ DefaultTLSDur: &Duration{0}, } zero := Duration{Duration: 0} type fields struct { Type string Name string Accounts []string DisableCustomSANs bool DisableTrustOnFirstUse bool InstanceAge Duration IMDSVersions []string IIDRoots string Claims *Claims } type args struct { config Config } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, "", nil}, args{config}, false}, {"ok/v1", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1"}, "", nil}, args{config}, false}, {"ok/v2", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v2"}, "", nil}, args{config}, false}, {"ok/empty", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{}, "", nil}, args{config}, false}, {"ok/duration", fields{"AWS", "name", []string{"account"}, true, true, Duration{Duration: 1 * 
time.Minute}, []string{"v1", "v2"}, "", nil}, args{config}, false},
		{"ok/cert", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, "testdata/certs/aws.crt", nil}, args{config}, false},
		{"fail type ", fields{"", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, "", nil}, args{config}, true},
		{"fail name", fields{"AWS", "", []string{"account"}, false, false, zero, []string{"v1", "v2"}, "", nil}, args{config}, true},
		{"bad instance age", fields{"AWS", "name", []string{"account"}, false, false, Duration{Duration: -1 * time.Minute}, []string{"v1", "v2"}, "", nil}, args{config}, true},
		{"fail/imds", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, "", nil}, args{config}, true},
		{"fail/missing", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, "testdata/missing.crt", nil}, args{config}, true},
		{"fail/cert", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"bad"}, "testdata/certs/rsa.csr", nil}, args{config}, true},
		{"fail claims", fields{"AWS", "name", []string{"account"}, false, false, zero, []string{"v1", "v2"}, "", badClaims}, args{config}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &AWS{
				Type:                   tt.fields.Type,
				Name:                   tt.fields.Name,
				Accounts:               tt.fields.Accounts,
				DisableCustomSANs:      tt.fields.DisableCustomSANs,
				DisableTrustOnFirstUse: tt.fields.DisableTrustOnFirstUse,
				InstanceAge:            tt.fields.InstanceAge,
				IMDSVersions:           tt.fields.IMDSVersions,
				IIDRoots:               tt.fields.IIDRoots,
				Claims:                 tt.fields.Claims,
			}
			if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
				t.Errorf("AWS.Init() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestAWS_authorizeToken feeds AWS.authorizeToken tokens built with
// generateAWSToken. Failing cases must return an error implementing
// render.StatusCodedError with the expected status code and message prefix;
// passing cases must return claims with the expected subject, issuer and
// audience.
func TestAWS_authorizeToken(t *testing.T) {
	block, _ := pem.Decode([]byte(awsTestKey))
	if block == nil || block.Type != "RSA PRIVATE KEY" {
		t.Fatal("error decoding AWS key")
	}
	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	assert.FatalError(t, err)
	// badKey is a freshly generated key used to produce a token whose
	// signature is expected to be rejected.
	badKey, err := rsa.GenerateKey(rand.Reader, 2048)
	assert.FatalError(t, err)

	type test struct {
		p     *AWS
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; error parsing aws token"),
			}
		},
		"fail/cannot-validate-sig": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), badKey)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; invalid aws token signature"),
			}
		},
		"fail/empty-account-id": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), "",
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; aws identity document accountId cannot be empty"),
			}
		},
		"fail/empty-instance-id": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; aws identity document instanceId cannot be empty"),
			}
		},
		"fail/empty-private-ip": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; aws identity document privateIp cannot be empty"),
			}
		},
		"fail/empty-region": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; aws identity document region cannot be empty"),
			}
		},
		"fail/invalid-token-issuer": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", "bad-issuer", p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; invalid aws token"),
			}
		},
		"fail/invalid-audience": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, "bad-audience", p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; invalid token - invalid audience claim (aud)"),
			}
		},
		"fail/invalid-subject-disabled-custom-SANs": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			p.DisableCustomSANs = true
			tok, err := generateAWSToken(
				p, "foo", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; invalid token - invalid subject claim (sub)"),
			}
		},
		"fail/invalid-account-id": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), "foo",
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; invalid aws identity document - accountId is not valid"),
			}
		},
		"fail/instance-age": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			p.InstanceAge = Duration{1 * time.Minute}
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("aws.authorizeToken; aws identity document pendingTime is too old"),
			}
		},
		"ok": func(t *testing.T) test {
			p, err := generateAWS()
			assert.FatalError(t, err)
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
		"ok/identityCert": func(t *testing.T) test {
			p, err := generateAWS()
			// Check err before writing to p so a generation failure is
			// reported instead of panicking on a nil provisioner.
			assert.FatalError(t, err)
			p.IIDRoots = "testdata/certs/aws-test.crt"
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
		"ok/identityCert2": func(t *testing.T) test {
			p, err := generateAWS()
			// Same ordering fix as above: validate err before using p.
			assert.FatalError(t, err)
			p.IIDRoots = "testdata/certs/aws.crt"
			tok, err := generateAWSToken(
				p, "instance-id", awsIssuer, p.GetID(), p.Accounts[0],
				"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			claims, err := tc.p.authorizeToken(tc.token)
			if err != nil {
				// An error is only acceptable when the case expects one.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
				return
			}
			if assert.Nil(t, tc.err) && assert.NotNil(t, claims) {
				assert.Equals(t, claims.Subject, "instance-id")
				assert.Equals(t, claims.Issuer, awsIssuer)
				assert.NotNil(t, claims.Amazon)

				aud, err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID())
				assert.FatalError(t, err)
				assert.Equals(t, claims.Audience[0], aud)
			}
		})
	}
}

func TestAWS_AuthorizeSign(t *testing.T) {
	p1, srv, err := generateAWSWithServer()
	assert.FatalError(t, err)
	defer srv.Close()

	p2, err := generateAWS()
	assert.FatalError(t, err)
	p2.Accounts = p1.Accounts
	p2.config = p1.config
	p2.DisableCustomSANs = true
	p2.InstanceAge = Duration{1 * time.Minute}

	p3, err := generateAWS()
	assert.FatalError(t, err)
	p3.config = p1.config

	t1, err := p1.GetIdentityToken("foo.local", "https://ca.smallstep.com")
	assert.FatalError(t, err)
	t2, err := p2.GetIdentityToken("instance-id", "https://ca.smallstep.com")
	assert.FatalError(t, err)
	assert.FatalError(t, err)
	t3, err := p3.GetIdentityToken("foo.local", "https://ca.smallstep.com")
	assert.FatalError(t, err)

	// Alternative common names with DisableCustomSANs = true
	t2PrivateIP, err := p2.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
	assert.FatalError(t, err)
	t2Hostname, err := p2.GetIdentityToken("ip-127-0-0-1.us-west-1.compute.internal", "https://ca.smallstep.com")
	assert.FatalError(t, err)

	block, _ := pem.Decode([]byte(awsTestKey))
	if block == nil || block.Type != "RSA PRIVATE KEY" {
		t.Fatal("error decoding AWS key")
	}
	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	assert.FatalError(t, err)

	badKey, err := rsa.GenerateKey(rand.Reader, 2048)
	assert.FatalError(t, err)

	t4, err := generateAWSToken(
		p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0],
		"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
	assert.FatalError(t, err)
	failSubject, err := generateAWSToken(
		p2, "bad-subject", awsIssuer, p2.GetID(), p2.Accounts[0],
		"instance-id", "127.0.0.1", "us-west-1", time.Now(), key)
	assert.FatalError(t,
err) failIssuer, err := generateAWSToken( p1, "instance-id", "bad-issuer", p1.GetID(), p1.Accounts[0], "instance-id", "127.0.0.1", "us-west-1", time.Now(), key) assert.FatalError(t, err) failAudience, err := generateAWSToken( p1, "instance-id", awsIssuer, "bad-audience", p1.Accounts[0], "instance-id", "127.0.0.1", "us-west-1", time.Now(), key) assert.FatalError(t, err) failAccount, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), "", "instance-id", "127.0.0.1", "us-west-1", time.Now(), key) assert.FatalError(t, err) failInstanceID, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "", "127.0.0.1", "us-west-1", time.Now(), key) assert.FatalError(t, err) failPrivateIP, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", "", "us-west-1", time.Now(), key) assert.FatalError(t, err) failRegion, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", "127.0.0.1", "", time.Now(), key) assert.FatalError(t, err) failExp, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", "127.0.0.1", "us-west-1", time.Now().Add(-360*time.Second), key) assert.FatalError(t, err) failNbf, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", "127.0.0.1", "us-west-1", time.Now().Add(360*time.Second), key) assert.FatalError(t, err) failKey, err := generateAWSToken( p1, "instance-id", awsIssuer, p1.GetID(), p1.Accounts[0], "instance-id", "127.0.0.1", "us-west-1", time.Now(), badKey) assert.FatalError(t, err) failInstanceAge, err := generateAWSToken( p2, "instance-id", awsIssuer, p2.GetID(), p2.Accounts[0], "instance-id", "127.0.0.1", "us-west-1", time.Now().Add(-1*time.Minute), key) assert.FatalError(t, err) type args struct { token, cn string } tests := []struct { name string aws *AWS args args wantLen int code int wantErr bool }{ {"ok", p1, args{t1, "foo.local"}, 9, 
http.StatusOK, false}, {"ok", p2, args{t2, "instance-id"}, 13, http.StatusOK, false}, {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 13, http.StatusOK, false}, {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 13, http.StatusOK, false}, {"ok", p1, args{t4, "instance-id"}, 9, http.StatusOK, false}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, {"fail issuer", p1, args{token: failIssuer}, 0, http.StatusUnauthorized, true}, {"fail audience", p1, args{token: failAudience}, 0, http.StatusUnauthorized, true}, {"fail account", p1, args{token: failAccount}, 0, http.StatusUnauthorized, true}, {"fail instanceID", p1, args{token: failInstanceID}, 0, http.StatusUnauthorized, true}, {"fail privateIP", p1, args{token: failPrivateIP}, 0, http.StatusUnauthorized, true}, {"fail region", p1, args{token: failRegion}, 0, http.StatusUnauthorized, true}, {"fail exp", p1, args{token: failExp}, 0, http.StatusUnauthorized, true}, {"fail nbf", p1, args{token: failNbf}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{token: failKey}, 0, http.StatusUnauthorized, true}, {"fail instance age", p2, args{token: failInstanceAge}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); { case (err != nil) != tt.wantErr: t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { case *AWS: case certificateOptionsFunc: case 
*provisionerExtensionOption: assert.Equals(t, v.Type, TypeAWS) assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Len(t, 2, v.KeyValuePairs) case profileDefaultDuration: assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), tt.args.cn) case defaultPublicKeyValidator: case *validityValidator: assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: assert.Equals(t, v, nil) case *urisValidator: assert.Equals(t, v.uris, nil) assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesSubsetValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: assert.Len(t, 0, v.webhooks) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } }) } } func TestAWS_AuthorizeSSHSign(t *testing.T) { tm, fn := mockNow() defer fn() p1, srv, err := generateAWSWithServer() assert.FatalError(t, err) p1.DisableCustomSANs = true defer srv.Close() p2, err := generateAWS() assert.FatalError(t, err) p2.Accounts = p1.Accounts p2.config = p1.config p2.DisableCustomSANs = false p3, err := generateAWS() assert.FatalError(t, err) // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") assert.FatalError(t, err) t2, err := p2.GetIdentityToken("foo.local", "https://ca.smallstep.com") assert.FatalError(t, err) key, err := generateJSONWebKey() assert.FatalError(t, err) signer, err := generateJSONWebKey() assert.FatalError(t, err) 
pub := key.Public().Key rsa2048, err := rsa.GenerateKey(rand.Reader, 2048) assert.FatalError(t, err) //nolint:gosec // tests minimum size of the key rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedHostOptionsIP := &SignSSHOptions{ CertType: "host", Principals: []string{"127.0.0.1"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedHostOptionsHostname := &SignSSHOptions{ CertType: "host", Principals: []string{"ip-127-0-0-1.us-west-1.compute.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedCustomOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"foo.local"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } type args struct { token string sshOpts SignSSHOptions key interface{} } tests := []struct { name string aws *AWS args args expected *SignSSHOptions code int wantErr bool wantSignErr bool }{ {"ok", p1, args{t1, SignSSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-rsa2048", p1, args{t1, SignSSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false}, {"ok-type", p1, args{t1, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-principals", p1, args{t1, SignSSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-principal-ip", p1, args{t1, SignSSHOptions{Principals: []string{"127.0.0.1"}}, pub}, expectedHostOptionsIP, http.StatusOK, false, false}, {"ok-principal-hostname", p1, args{t1, SignSSHOptions{Principals: 
[]string{"ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptionsHostname, http.StatusOK, false, false}, {"ok-options", p1, args{t1, SignSSHOptions{CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-custom", p2, args{t2, SignSSHOptions{Principals: []string{"foo.local"}}, pub}, expectedCustomOptions, http.StatusOK, false, false}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true}, {"fail-type", p1, args{t1, SignSSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true}, {"fail-principal", p1, args{t1, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, {"fail-extra-principal", p1, args{t1, SignSSHOptions{Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, {"fail-sshCA-disabled", p3, args{"foo", SignSSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, {"fail-invalid-token", p1, args{"foo", SignSSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.aws.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { assert.Nil(t, cert) } else { assert.NoError(t, 
validateSSHCertificate(cert, tt.expected)) } } } }) } } func TestAWS_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) p1, err := generateAWS() assert.FatalError(t, err) p2, err := generateAWS() assert.FatalError(t, err) // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { cert *x509.Certificate } tests := []struct { name string aws *AWS args args code int wantErr bool }{ {"ok", p1, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusOK, false}, {"fail/renew-disabled", p2, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } } func TestAWS_HardcodedCertificates(t *testing.T) { certBytes := []byte(awsCertificate) var certs []*x509.Certificate for len(certBytes) > 0 { var block *pem.Block block, certBytes = pem.Decode(certBytes) if block == nil { break } if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { continue } cert, err := x509.ParseCertificate(block.Bytes) assert.FatalError(t, err) // check that the certificate is not expired assert.True(t, cert.NotAfter.After(time.Now())) certs = append(certs, cert) } assert.Len(t, 33, certs, "expected 33 certificates in aws_certificates.pem, but got %d", len(certs)) } ================================================ FILE: authority/provisioner/azure.go ================================================ package provisioner import ( "context" 
"crypto/sha256" "crypto/x509" "encoding/hex" "encoding/json" "io" "net/http" "regexp" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/webhook" ) // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. const azureOIDCBaseURL = "https://login.microsoftonline.com" //nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance. const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token" const azureIdentityTokenAPIVersion = "2018-02-01" // azureInstanceComputeURL is the URL to get the instance compute metadata. const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment" // azureDefaultAudience is the default audience used. const azureDefaultAudience = "https://management.azure.com/" // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // Using case insensitive as resourceGroups appears as resourcegroups. var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`) // azureEnvironments is the list of all Azure environments. 
var azureEnvironments = map[string]string{ "AzurePublicCloud": "https://management.azure.com/", "AzureCloud": "https://management.azure.com/", "AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/", "AzureUSGovernment": "https://management.usgovcloudapi.net/", "AzureChinaCloud": "https://management.chinacloudapi.cn/", "AzureGermanCloud": "https://management.microsoftazure.de/", } type azureConfig struct { oidcDiscoveryURL string identityTokenURL string instanceComputeURL string } func newAzureConfig(tenantID string) *azureConfig { return &azureConfig{ oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration", identityTokenURL: azureIdentityTokenURL, instanceComputeURL: azureInstanceComputeURL, } } type azureIdentityToken struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ClientID string `json:"client_id"` ExpiresIn int64 `json:"expires_in,string"` ExpiresOn int64 `json:"expires_on,string"` ExtExpiresIn int64 `json:"ext_expires_in,string"` NotBefore int64 `json:"not_before,string"` Resource string `json:"resource"` TokenType string `json:"token_type"` } type azurePayload struct { jose.Claims AppID string `json:"appid"` AppIDAcr string `json:"appidacr"` IdentityProvider string `json:"idp"` ObjectID string `json:"oid"` TenantID string `json:"tid"` Version string `json:"ver"` XMSMirID string `json:"xms_mirid"` } // Azure is the provisioner that supports identity tokens created from the // Microsoft Azure Instance Metadata service. // // The default audience is "https://management.azure.com/". // // If DisableCustomSANs is true, only the internal DNS and IP will be added as a // SAN. By default it will accept any SAN in the CSR. // // If DisableTrustOnFirstUse is true, multiple sign request for this provisioner // with the same instance will be accepted. By default only the first request // will be accepted. 
// // Microsoft Azure identity docs are available at // https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token // and https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service type Azure struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` TenantID string `json:"tenantID"` ResourceGroups []string `json:"resourceGroups"` SubscriptionIDs []string `json:"subscriptionIDs"` ObjectIDs []string `json:"objectIDs"` Audience string `json:"audience,omitempty"` DisableCustomSANs bool `json:"disableCustomSANs"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` config *azureConfig oidcConfig openIDConfiguration keyStore *keyStore ctl *Controller environment string } // GetID returns the provisioner unique identifier. func (p *Azure) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *Azure) GetIDForToken() string { return p.TenantID } // GetTokenID returns the identifier of the token. The default value for Azure // the SHA256 of "xms_mirid", but if DisableTrustOnFirstUse is set to true, then // it will be the token kid. func (p *Azure) GetTokenID(token string) (string, error) { jwt, err := jose.ParseSigned(token) if err != nil { return "", errors.Wrap(err, "error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims azurePayload if err = jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } // If TOFU is disabled then allow token re-use. 
Azure caches the token for // 24h and without allowing the re-use we cannot use it twice. if p.DisableTrustOnFirstUse { return "", ErrAllowTokenReuse } sum := sha256.Sum256([]byte(claims.XMSMirID)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } // GetName returns the name of the provisioner. func (p *Azure) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *Azure) GetType() Type { return TypeAzure } // GetEncryptedKey is not available in an Azure provisioner. func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } // GetIdentityToken retrieves from the metadata service the identity token and // returns it. func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) { _, _ = subject, caURL // unused input // Initialize the config if this method is used from the cli. p.assertConfig() // default to AzurePublicCloud to keep existing behavior identityTokenResource := azureEnvironments["AzurePublicCloud"] var err error p.environment, err = p.getAzureEnvironment() if err != nil { return "", errors.Wrap(err, "error getting azure environment") } if resource, ok := azureEnvironments[p.environment]; ok { identityTokenResource = resource } req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody) if err != nil { return "", errors.Wrap(err, "error creating request") } req.Header.Set("Metadata", "true") query := req.URL.Query() query.Add("resource", identityTokenResource) query.Add("api-version", azureIdentityTokenAPIVersion) req.URL.RawQuery = query.Encode() resp, err := http.DefaultClient.Do(req) if err != nil { return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?") } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { return "", errors.Wrap(err, "error reading identity token response") } if resp.StatusCode >= 400 { return "", errors.Errorf("error getting identity token: status=%d, response=%s", resp.StatusCode, b) } var 
identityToken azureIdentityToken if err := json.Unmarshal(b, &identityToken); err != nil { return "", errors.Wrap(err, "error unmarshaling identity token response") } return identityToken.AccessToken, nil } // Init validates and initializes the Azure provisioner. func (p *Azure) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") case p.TenantID == "": return errors.New("provisioner tenantId cannot be empty") case p.Audience == "": // use default audience p.Audience = azureDefaultAudience } // Initialize config p.assertConfig() // Decode and validate openid-configuration endpoint if err = getAndDecode(http.DefaultClient, p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { return } if err := p.oidcConfig.Validate(); err != nil { return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) } // Get JWK key set if p.keyStore, err = newKeyStore(http.DefaultClient, p.oidcConfig.JWKSetURI); err != nil { return } p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // authorizeToken returns the claims, name, group, subscription, identityObjectID, error. 
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, string, string, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") } if len(jwt.Headers) == 0 { return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header") } var found bool var claims azurePayload keys := p.keyStore.Get(jwt.Headers[0].KeyID) for _, key := range keys { if err := jwt.Claims(key.Public(), &claims); err == nil { found = true break } } if !found { return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token") } if err := claims.ValidateWithLeeway(jose.Expected{ Audience: []string{p.Audience}, Issuer: p.oidcConfig.Issuer, Time: time.Now(), }, 1*time.Minute); err != nil { return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload") } // Validate TenantID if claims.TenantID != p.TenantID { return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)") } re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) if len(re) != 5 { return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) } var subscription, group, name string identityObjectID := claims.ObjectID subscription, group, name = re[1], re[2], re[4] return &claims, name, group, subscription, identityObjectID, nil } // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. 
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { _, name, group, subscription, identityObjectID, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") } // Filter by resource group if len(p.ResourceGroups) > 0 { var found bool for _, g := range p.ResourceGroups { if g == group { found = true break } } if !found { return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid resource group") } } // Filter by subscription id if len(p.SubscriptionIDs) > 0 { var found bool for _, s := range p.SubscriptionIDs { if s == subscription { found = true break } } if !found { return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid subscription id") } } // Filter by Azure AD identity object id if len(p.ObjectIDs) > 0 { var found bool for _, i := range p.ObjectIDs { if i == identityObjectID { found = true break } } if !found { return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid identity object id") } } // Template options data := x509util.NewTemplateData() data.SetCommonName(name) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // Enforce known common name and default DNS if configured. // By default we'll accept the CN and SANs in the CSR. // There's no way to trust them other than TOFU. var so []SignOption if p.DisableCustomSANs { // name will work only inside the virtual network so = append(so, commonNameValidator(name), dnsNamesSubsetValidator([]string{name}), ipAddressesValidator(nil), emailAddressesValidator(nil), newURIsValidator(ctx, nil), ) // Enforce SANs in the template. 
data.SetSANs([]string{name}) } templateOptions, err := CustomTemplateOptions(p.Options, data, x509util.DefaultIIDLeafTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") } return append(so, p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID).WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController( data, linkedca.Webhook_X509, webhook.WithAuthorizationPrincipal(identityObjectID), ), ), nil } // AuthorizeRenew returns an error if the renewal is disabled. // NOTE: This method does not actually validate the certificate or check it's // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) } _, name, _, _, identityObjectID, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign") } signOptions := []SignOption{} // Enforce host certificate. defaults := SignSSHOptions{ CertType: SSHHostCert, } // Validated principals. principals := []string{name} // Only enforce known principals if disable custom sans is true. if p.DisableCustomSANs { defaults.Principals = principals } else { // Check that at least one principal is sent in the request. 
signOptions = append(signOptions, &sshCertOptionsRequireValidator{ Principals: true, }) } // Certificate templates. data := sshutil.CreateTemplateData(sshutil.HostCert, name, principals) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } templateOptions, err := CustomSSHTemplateOptions(p.Options, data, sshutil.DefaultIIDTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign") } signOptions = append(signOptions, templateOptions) return append(signOptions, p, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, // Ensure that all principal names are allowed newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), // Call webhooks p.ctl.newWebhookController( data, linkedca.Webhook_SSH, webhook.WithAuthorizationPrincipal(identityObjectID), ), ), nil } // assertConfig initializes the config if it has not been initialized func (p *Azure) assertConfig() { if p.config == nil { p.config = newAzureConfig(p.TenantID) } } // getAzureEnvironment returns the Azure environment for the current instance func (p *Azure) getAzureEnvironment() (string, error) { if p.environment != "" { return p.environment, nil } req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody) if err != nil { return "", errors.Wrap(err, "error creating request") } req.Header.Add("Metadata", "True") query := req.URL.Query() query.Add("format", "text") query.Add("api-version", "2021-02-01") req.URL.RawQuery = query.Encode() resp, err := http.DefaultClient.Do(req) if err != nil { return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?") } defer resp.Body.Close() b, err := 
io.ReadAll(resp.Body) if err != nil { return "", errors.Wrap(err, "error reading azure environment response") } if resp.StatusCode >= 400 { return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b) } return string(b), nil } ================================================ FILE: authority/provisioner/azure_test.go ================================================ package provisioner import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/hex" "errors" "fmt" "net/http" "net/http/httptest" "strings" "testing" "time" "go.step.sm/crypto/jose" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" ) func TestAzure_Getters(t *testing.T) { p, err := generateAzure() assert.FatalError(t, err) if got := p.GetID(); got != p.TenantID { t.Errorf("Azure.GetID() = %v, want %v", got, p.TenantID) } if got := p.GetName(); got != p.Name { t.Errorf("Azure.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeAzure { t.Errorf("Azure.GetType() = %v, want %v", got, TypeAzure) } kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("Azure.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } } func TestAzure_GetTokenID(t *testing.T) { p1, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() p2, err := generateAzure() assert.FatalError(t, err) p2.TenantID = p1.TenantID p2.config = p1.config p2.oidcConfig = p1.oidcConfig p2.keyStore = p1.keyStore p2.DisableTrustOnFirstUse = true t1, err := p1.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t2, err := p2.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) sum := sha256.Sum256([]byte("/subscriptions/subscriptionID/resourceGroups/resourceGroup/providers/Microsoft.Compute/virtualMachines/virtualMachine")) w1 := strings.ToLower(hex.EncodeToString(sum[:])) type args struct { token string } tests := 
[]struct { name string azure *Azure args args want string wantErr bool }{ {"ok", p1, args{t1}, w1, false}, {"ok no TOFU", p2, args{t2}, "", true}, {"fail token", p1, args{"bad-token"}, "", true}, {"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.azure.GetTokenID(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Azure.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Azure.GetTokenID() = %v, want %v", got, tt.want) } }) } } func TestAzure_GetIdentityToken(t *testing.T) { p1, err := generateAzure() assert.FatalError(t, err) t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wantResource := r.URL.Query().Get("want_resource") resource := r.URL.Query().Get("resource") if wantResource == "" || resource != wantResource { http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest) return } switch r.URL.Path { case "/bad-request": http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) case "/bad-json": w.Write([]byte(t1)) default: w.Header().Add("Content-Type", "application/json") fmt.Fprintf(w, `{"access_token":"%s"}`, t1) } })) defer srvIdentity.Close() srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/bad-request": http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) case "/AzureChinaCloud": w.Header().Add("Content-Type", "text/plain") w.Write([]byte("AzureChinaCloud")) case "/AzureGermanCloud": w.Header().Add("Content-Type", "text/plain") 
w.Write([]byte("AzureGermanCloud")) case "/AzureUSGovernmentCloud": w.Header().Add("Content-Type", "text/plain") w.Write([]byte("AzureUSGovernmentCloud")) default: w.Header().Add("Content-Type", "text/plain") w.Write([]byte("AzurePublicCloud")) } })) defer srvInstance.Close() type args struct { subject string caURL string } tests := []struct { name string azure *Azure args args identityTokenURL string instanceComputeURL string wantEnvironment string want string wantErr bool }{ {"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false}, {"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false}, {"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false}, {"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false}, {"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true}, {"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true}, {"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true}, {"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true}, {"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // reset environment between tests to avoid caching issues p1.environment = "" tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment] tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL) if (err != 
nil) != tt.wantErr { t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Azure.GetIdentityToken() = %v, want %v", got, tt.want) } }) } } func TestAzure_Init(t *testing.T) { p1, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() config := Config{ Claims: globalProvisionerClaims, } badClaims := &Claims{ DefaultTLSDur: &Duration{0}, } badDiscoveryURL := &azureConfig{ oidcDiscoveryURL: srv.URL + "/error", identityTokenURL: p1.config.identityTokenURL, } badJWKURL := &azureConfig{ oidcDiscoveryURL: srv.URL + "/openid-configuration-fail-jwk", identityTokenURL: p1.config.identityTokenURL, } badAzureConfig := &azureConfig{ oidcDiscoveryURL: srv.URL + "/openid-configuration-no-issuer", identityTokenURL: p1.config.identityTokenURL, } type fields struct { Type string Name string TenantID string Claims *Claims config *azureConfig } type args struct { config Config } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{p1.Type, p1.Name, p1.TenantID, nil, p1.config}, args{config}, false}, {"ok with config", fields{p1.Type, p1.Name, p1.TenantID, nil, p1.config}, args{config}, false}, {"fail type", fields{"", p1.Name, p1.TenantID, nil, p1.config}, args{config}, true}, {"fail name", fields{p1.Type, "", p1.TenantID, nil, p1.config}, args{config}, true}, {"fail tenant id", fields{p1.Type, p1.Name, "", nil, p1.config}, args{config}, true}, {"fail claims", fields{p1.Type, p1.Name, p1.TenantID, badClaims, p1.config}, args{config}, true}, {"fail discovery URL", fields{p1.Type, p1.Name, p1.TenantID, nil, badDiscoveryURL}, args{config}, true}, {"fail JWK URL", fields{p1.Type, p1.Name, p1.TenantID, nil, badJWKURL}, args{config}, true}, {"fail config Validate", fields{p1.Type, p1.Name, p1.TenantID, nil, badAzureConfig}, args{config}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Azure{ Type: tt.fields.Type, Name: tt.fields.Name, 
TenantID: tt.fields.TenantID, Claims: tt.fields.Claims, config: tt.fields.config, } if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { t.Errorf("Azure.Init() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestAzure_authorizeToken(t *testing.T) { type test struct { p *Azure token string err error code int } tests := map[string]func(*testing.T) test{ "fail/bad-token": func(t *testing.T) test { p, err := generateAzure() assert.FatalError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: errors.New("azure.authorizeToken; error parsing azure token"), } }, "fail/cannot-validate-sig": func(t *testing.T) test { p, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), jwk) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("azure.authorizeToken; cannot validate azure token"), } }, "fail/invalid-token-issuer": func(t *testing.T) test { p, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("azure.authorizeToken; failed to validate azure token payload"), } }, "fail/invalid-tenant-id": func(t *testing.T) test { p, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), 
&p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)"), } }, "fail/invalid-xms-mir-id": func(t *testing.T) test { p, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() jwk := &p.keyStore.keySet.Keys[0] sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), ) assert.FatalError(t, err) now := time.Now() claims := azurePayload{ Claims: jose.Claims{ Subject: "subject", Issuer: p.oidcConfig.Issuer, IssuedAt: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Audience: []string{azureDefaultAudience}, ID: "the-jti", }, AppID: "the-appid", AppIDAcr: "the-appidacr", IdentityProvider: "the-idp", ObjectID: "the-oid", TenantID: p.TenantID, Version: "the-version", XMSMirID: "foo", } tok, err := jose.Signed(sig).Claims(claims).CompactSerialize() assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("azure.authorizeToken; error parsing xms_mirid claim - foo"), } }, "ok": func(t *testing.T) test { p, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, 
sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, claims.Subject, "subject") assert.Equals(t, claims.Issuer, tc.p.oidcConfig.Issuer) assert.Equals(t, claims.Audience[0], azureDefaultAudience) assert.Equals(t, name, "virtualMachine") assert.Equals(t, group, "resourceGroup") assert.Equals(t, subscriptionID, "subscriptionID") assert.Equals(t, objectID, "the-oid") } } }) } } func TestAzure_AuthorizeSign(t *testing.T) { p1, srv, err := generateAzureWithServer() assert.FatalError(t, err) defer srv.Close() p2, err := generateAzure() assert.FatalError(t, err) p2.TenantID = p1.TenantID p2.ResourceGroups = []string{"resourceGroup"} p2.config = p1.config p2.oidcConfig = p1.oidcConfig p2.keyStore = p1.keyStore p2.DisableCustomSANs = true p3, err := generateAzure() assert.FatalError(t, err) p3.config = p1.config p3.oidcConfig = p1.oidcConfig p3.keyStore = p1.keyStore p4, err := generateAzure() assert.FatalError(t, err) p4.TenantID = p1.TenantID p4.ResourceGroups = []string{"foobarzar"} p4.config = p1.config p4.oidcConfig = p1.oidcConfig p4.keyStore = p1.keyStore p5, err := generateAzure() assert.FatalError(t, err) p5.TenantID = p1.TenantID p5.SubscriptionIDs = []string{"subscriptionID"} p5.config = p1.config p5.oidcConfig = p1.oidcConfig p5.keyStore = p1.keyStore p6, err := generateAzure() assert.FatalError(t, err) p6.TenantID = p1.TenantID p6.SubscriptionIDs = []string{"foobarzar"} p6.config = p1.config p6.oidcConfig = p1.oidcConfig p6.keyStore = p1.keyStore p7, err := generateAzure() assert.FatalError(t, err) p7.TenantID = p1.TenantID p7.ObjectIDs = []string{"the-oid"} p7.config = p1.config p7.oidcConfig = p1.oidcConfig p7.keyStore = p1.keyStore p8, err := generateAzure() assert.FatalError(t, err) p8.TenantID = p1.TenantID p8.ObjectIDs = []string{"foobarzar"} p8.config = p1.config p8.oidcConfig = p1.oidcConfig p8.keyStore = p1.keyStore badKey, err := generateJSONWebKey() assert.FatalError(t, err) 
t1, err := p1.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t2, err := p2.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t3, err := p3.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t4, err := p4.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t5, err := p5.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t6, err := p6.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t7, err := p6.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t8, err := p6.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), badKey) assert.FatalError(t, 
err) type args struct { token string } tests := []struct { name string azure *Azure args args wantLen int code int wantErr bool }{ {"ok", p1, args{t1}, 8, http.StatusOK, false}, {"ok", p2, args{t2}, 13, http.StatusOK, false}, {"ok", p1, args{t11}, 8, http.StatusOK, false}, {"ok", p5, args{t5}, 8, http.StatusOK, false}, {"ok", p7, args{t7}, 8, http.StatusOK, false}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, {"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true}, {"fail object id", p8, args{t8}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true}, {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true}, {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true}, {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); { case (err != nil) != tt.wantErr: t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { case *Azure: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeAzure) assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: assert.Equals(t, time.Duration(v), 
tt.azure.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "virtualMachine") case defaultPublicKeyValidator: case *validityValidator: assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: assert.Equals(t, v, nil) case *urisValidator: assert.Equals(t, v.uris, nil) assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesSubsetValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: assert.Len(t, 0, v.webhooks) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } }) } } func TestAzure_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() assert.FatalError(t, err) // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { cert *x509.Certificate } tests := []struct { name string azure *Azure args args code int wantErr bool }{ {"ok", p1, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusOK, false}, {"fail/renew-disabled", p2, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } 
} func TestAzure_AuthorizeSSHSign(t *testing.T) { tm, fn := mockNow() defer fn() p1, srv, err := generateAzureWithServer() assert.FatalError(t, err) p1.DisableCustomSANs = true defer srv.Close() p2, err := generateAzure() assert.FatalError(t, err) p2.TenantID = p1.TenantID p2.config = p1.config p2.oidcConfig = p1.oidcConfig p2.keyStore = p1.keyStore p2.DisableCustomSANs = false p3, err := generateAzure() assert.FatalError(t, err) // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) t2, err := p2.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) key, err := generateJSONWebKey() assert.FatalError(t, err) signer, err := generateJSONWebKey() assert.FatalError(t, err) pub := key.Public().Key rsa2048, err := rsa.GenerateKey(rand.Reader, 2048) assert.FatalError(t, err) //nolint:gosec // tests minimum size of the key rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"virtualMachine"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedCustomOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"foo.bar"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } type args struct { token string sshOpts SignSSHOptions key interface{} } tests := []struct { name string azure *Azure args args expected *SignSSHOptions code int wantErr bool wantSignErr bool }{ {"ok", p1, args{t1, SignSSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-rsa2048", p1, args{t1, SignSSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false}, {"ok-type", p1, args{t1, SignSSHOptions{CertType: "host"}, 
pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-principals", p1, args{t1, SignSSHOptions{Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-options", p1, args{t1, SignSSHOptions{CertType: "host", Principals: []string{"virtualMachine"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-custom", p2, args{t2, SignSSHOptions{Principals: []string{"foo.bar"}}, pub}, expectedCustomOptions, http.StatusOK, false, false}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true}, {"fail-type", p1, args{t1, SignSSHOptions{CertType: "user"}, pub}, nil, http.StatusOK, false, true}, {"fail-principal", p1, args{t1, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, {"fail-extra-principal", p1, args{t1, SignSSHOptions{Principals: []string{"virtualMachine", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, {"fail-sshCA-disabled", p3, args{"foo", SignSSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, {"fail-invalid-token", p1, args{"foo", SignSSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.azure.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { assert.Nil(t, cert) } else { 
assert.NoError(t, validateSSHCertificate(cert, tt.expected)) } } } }) } } func TestAzure_assertConfig(t *testing.T) { p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() assert.FatalError(t, err) p2.config = nil tests := []struct { name string azure *Azure }{ {"ok with config", p1}, {"ok no config", p2}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.azure.assertConfig() }) } } ================================================ FILE: authority/provisioner/claims.go ================================================ package provisioner import ( "time" "github.com/pkg/errors" "golang.org/x/crypto/ssh" ) // Claims so that individual provisioners can override global claims. type Claims struct { // TLS CA properties MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` // SSH CA properties MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` DefaultUserSSHDur *Duration `json:"defaultUserSSHCertDuration,omitempty"` MinHostSSHDur *Duration `json:"minHostSSHCertDuration,omitempty"` MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"` // Renewal properties DisableRenewal *bool `json:"disableRenewal,omitempty"` AllowRenewalAfterExpiry *bool `json:"allowRenewalAfterExpiry,omitempty"` // Other properties DisableSmallstepExtensions *bool `json:"disableSmallstepExtensions,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the // current claim and the global one. type Claimer struct { global Claims claims *Claims } // NewClaimer initializes a new claimer with the given claims. 
func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { c := &Claimer{global: global, claims: claims} err := c.Validate() return c, err } // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() allowRenewalAfterExpiry := c.AllowRenewalAfterExpiry() enableSSHCA := c.IsSSHCAEnabled() disableSmallstepExtensions := c.IsDisableSmallstepExtensions() return Claims{ MinTLSDur: &Duration{c.MinTLSCertDuration()}, MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, EnableSSHCA: &enableSSHCA, DisableRenewal: &disableRenewal, AllowRenewalAfterExpiry: &allowRenewalAfterExpiry, DisableSmallstepExtensions: &disableSmallstepExtensions, } } // DefaultTLSCertDuration returns the default TLS cert duration for the // provisioner. If the default is not set within the provisioner, then the global // default from the authority configuration will be used. func (c *Claimer) DefaultTLSCertDuration() time.Duration { if c.claims == nil || c.claims.DefaultTLSDur == nil { return c.global.DefaultTLSDur.Duration } return c.claims.DefaultTLSDur.Duration } // MinTLSCertDuration returns the minimum TLS cert duration for the provisioner. // If the minimum is not set within the provisioner, then the global // minimum from the authority configuration will be used. 
func (c *Claimer) MinTLSCertDuration() time.Duration { if c.claims == nil || c.claims.MinTLSDur == nil { if c.claims != nil && c.claims.DefaultTLSDur != nil && c.claims.DefaultTLSDur.Duration < c.global.MinTLSDur.Duration { return c.claims.DefaultTLSDur.Duration } return c.global.MinTLSDur.Duration } return c.claims.MinTLSDur.Duration } // MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner. // If the maximum is not set within the provisioner, then the global // maximum from the authority configuration will be used. func (c *Claimer) MaxTLSCertDuration() time.Duration { if c.claims == nil || c.claims.MaxTLSDur == nil { if c.claims != nil && c.claims.DefaultTLSDur != nil && c.claims.DefaultTLSDur.Duration > c.global.MaxTLSDur.Duration { return c.claims.DefaultTLSDur.Duration } return c.global.MaxTLSDur.Duration } return c.claims.MaxTLSDur.Duration } // IsDisableRenewal returns if the renewal flow is disabled for the // provisioner. If the property is not set within the provisioner, then the // global value from the authority configuration will be used. func (c *Claimer) IsDisableRenewal() bool { if c.claims == nil || c.claims.DisableRenewal == nil { return *c.global.DisableRenewal } return *c.claims.DisableRenewal } // IsDisableSmallstepExtensions returns whether Smallstep extensions, such as // the provisioner extension, should be excluded from the certificate. func (c *Claimer) IsDisableSmallstepExtensions() bool { if c.claims == nil || c.claims.DisableSmallstepExtensions == nil { return *c.global.DisableSmallstepExtensions } return *c.claims.DisableSmallstepExtensions } // AllowRenewalAfterExpiry returns if the renewal flow is authorized if the // certificate is expired. If the property is not set within the provisioner // then the global value from the authority configuration will be used. 
func (c *Claimer) AllowRenewalAfterExpiry() bool { if c.claims == nil || c.claims.AllowRenewalAfterExpiry == nil { return *c.global.AllowRenewalAfterExpiry } return *c.claims.AllowRenewalAfterExpiry } // DefaultSSHCertDuration returns the default SSH certificate duration for the // given certificate type. func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { switch certType { case ssh.UserCert: return c.DefaultUserSSHCertDuration(), nil case ssh.HostCert: return c.DefaultHostSSHCertDuration(), nil case 0: return 0, errors.New("ssh certificate type has not been set") default: return 0, errors.Errorf("ssh certificate has an unknown type: %d", certType) } } // DefaultUserSSHCertDuration returns the default SSH user cert duration for the // provisioner. If the default is not set within the provisioner, then the // global default from the authority configuration will be used. func (c *Claimer) DefaultUserSSHCertDuration() time.Duration { if c.claims == nil || c.claims.DefaultUserSSHDur == nil { return c.global.DefaultUserSSHDur.Duration } return c.claims.DefaultUserSSHDur.Duration } // MinUserSSHCertDuration returns the minimum SSH user cert duration for the // provisioner. If the minimum is not set within the provisioner, then the // global minimum from the authority configuration will be used. func (c *Claimer) MinUserSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MinUserSSHDur == nil { if c.claims != nil && c.claims.DefaultUserSSHDur != nil && c.claims.DefaultUserSSHDur.Duration < c.global.MinUserSSHDur.Duration { return c.claims.DefaultUserSSHDur.Duration } return c.global.MinUserSSHDur.Duration } return c.claims.MinUserSSHDur.Duration } // MaxUserSSHCertDuration returns the maximum SSH user cert duration for the // provisioner. If the maximum is not set within the provisioner, then the // global maximum from the authority configuration will be used. 
func (c *Claimer) MaxUserSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MaxUserSSHDur == nil { if c.claims != nil && c.claims.DefaultUserSSHDur != nil && c.claims.DefaultUserSSHDur.Duration > c.global.MaxUserSSHDur.Duration { return c.claims.DefaultUserSSHDur.Duration } return c.global.MaxUserSSHDur.Duration } return c.claims.MaxUserSSHDur.Duration } // DefaultHostSSHCertDuration returns the default SSH host cert duration for the // provisioner. If the default is not set within the provisioner, then the // global default from the authority configuration will be used. func (c *Claimer) DefaultHostSSHCertDuration() time.Duration { if c.claims == nil || c.claims.DefaultHostSSHDur == nil { return c.global.DefaultHostSSHDur.Duration } return c.claims.DefaultHostSSHDur.Duration } // MinHostSSHCertDuration returns the minimum SSH host cert duration for the // provisioner. If the minimum is not set within the provisioner, then the // global minimum from the authority configuration will be used. func (c *Claimer) MinHostSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MinHostSSHDur == nil { if c.claims != nil && c.claims.DefaultHostSSHDur != nil && c.claims.DefaultHostSSHDur.Duration < c.global.MinHostSSHDur.Duration { return c.claims.DefaultHostSSHDur.Duration } return c.global.MinHostSSHDur.Duration } return c.claims.MinHostSSHDur.Duration } // MaxHostSSHCertDuration returns the maximum SSH Host cert duration for the // provisioner. If the maximum is not set within the provisioner, then the // global maximum from the authority configuration will be used. 
func (c *Claimer) MaxHostSSHCertDuration() time.Duration { if c.claims == nil || c.claims.MaxHostSSHDur == nil { if c.claims != nil && c.claims.DefaultHostSSHDur != nil && c.claims.DefaultHostSSHDur.Duration > c.global.MaxHostSSHDur.Duration { return c.claims.DefaultHostSSHDur.Duration } return c.global.MaxHostSSHDur.Duration } return c.claims.MaxHostSSHDur.Duration } // IsSSHCAEnabled returns if the SSH CA is enabled for the provisioner. If the // property is not set within the provisioner, then the global value from the // authority configuration will be used. func (c *Claimer) IsSSHCAEnabled() bool { if c.claims == nil || c.claims.EnableSSHCA == nil { return *c.global.EnableSSHCA } return *c.claims.EnableSSHCA } // Validate validates and modifies the Claims with default values. func (c *Claimer) Validate() error { var ( minDur = c.MinTLSCertDuration() maxDur = c.MaxTLSCertDuration() defDur = c.DefaultTLSCertDuration() ) switch { case minDur <= 0: return errors.Errorf("claims: MinTLSCertDuration must be greater than 0") case maxDur <= 0: return errors.Errorf("claims: MaxTLSCertDuration must be greater than 0") case defDur <= 0: return errors.Errorf("claims: DefaultTLSCertDuration must be greater than 0") case maxDur < minDur: return errors.Errorf("claims: MaxCertDuration cannot be less "+ "than MinCertDuration: MaxCertDuration - %v, MinCertDuration - %v", maxDur, minDur) case defDur < minDur: return errors.Errorf("claims: DefaultCertDuration cannot be less than MinCertDuration: DefaultCertDuration - %v, MinCertDuration - %v", defDur, minDur) case maxDur < defDur: return errors.Errorf("claims: MaxCertDuration cannot be less than DefaultCertDuration: MaxCertDuration - %v, DefaultCertDuration - %v", maxDur, defDur) default: return nil } } ================================================ FILE: authority/provisioner/claims_test.go ================================================ package provisioner import ( "testing" "time" "golang.org/x/crypto/ssh" ) func 
TestClaimer_DefaultSSHCertDuration(t *testing.T) { duration := Duration{ Duration: time.Hour, } type fields struct { global Claims claims *Claims } type args struct { certType uint32 } tests := []struct { name string fields fields args args want time.Duration wantErr bool }{ {"user", fields{globalProvisionerClaims, &Claims{DefaultUserSSHDur: &duration}}, args{1}, time.Hour, false}, {"user global", fields{globalProvisionerClaims, nil}, args{ssh.UserCert}, 16 * time.Hour, false}, {"host global", fields{globalProvisionerClaims, &Claims{DefaultHostSSHDur: &duration}}, args{2}, time.Hour, false}, {"host global", fields{globalProvisionerClaims, nil}, args{ssh.HostCert}, 30 * 24 * time.Hour, false}, {"invalid", fields{globalProvisionerClaims, nil}, args{0}, 0, true}, {"invalid global", fields{globalProvisionerClaims, nil}, args{3}, 0, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Claimer{ global: tt.fields.global, claims: tt.fields.claims, } got, err := c.DefaultSSHCertDuration(tt.args.certType) if (err != nil) != tt.wantErr { t.Errorf("Claimer.DefaultSSHCertDuration() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Claimer.DefaultSSHCertDuration() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/provisioner/collection.go ================================================ package provisioner import ( "crypto/sha1" //nolint:gosec // not used for cryptographic security "crypto/x509" "encoding/asn1" "encoding/binary" "encoding/hex" "fmt" "net/url" "sort" "strings" "sync" "go.step.sm/crypto/jose" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/internal/cast" ) // DefaultProvisionersLimit is the default limit for listing provisioners. const DefaultProvisionersLimit = 20 // DefaultProvisionersMax is the maximum limit for listing provisioners. 
const DefaultProvisionersMax = 100 type uidProvisioner struct { provisioner Interface uid string } type provisionerSlice []uidProvisioner func (p provisionerSlice) Len() int { return len(p) } func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // loadByTokenPayload is a payload used to extract the id used to load the // provisioner. type loadByTokenPayload struct { jose.Claims Email string `json:"email"` // OIDC email AuthorizedParty string `json:"azp"` // OIDC client id TenantID string `json:"tid"` // Microsoft Azure tenant id } // Collection is a memory map of provisioners. type Collection struct { byID *sync.Map byKey *sync.Map byName *sync.Map byTokenID *sync.Map sorted provisionerSlice audiences Audiences } // NewCollection initializes a collection of provisioners. The given list of // audiences are the audiences used by the JWT provisioner. func NewCollection(audiences Audiences) *Collection { return &Collection{ byID: new(sync.Map), byKey: new(sync.Map), byName: new(sync.Map), byTokenID: new(sync.Map), audiences: audiences, } } // Load a provisioner by the ID. func (c *Collection) Load(id string) (Interface, bool) { return loadProvisioner(c.byID, id) } // LoadByName a provisioner by name. func (c *Collection) LoadByName(name string) (Interface, bool) { return loadProvisioner(c.byName, name) } // LoadByTokenID a provisioner by identifier found in token. // For different provisioner types this identifier may be found in different // attributes of the token. func (c *Collection) LoadByTokenID(tokenProvisionerID string) (Interface, bool) { return loadProvisioner(c.byTokenID, tokenProvisionerID) } // LoadByToken parses the token claims and loads the provisioner associated. 
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { var audiences []string // Get all audiences with the given fragment fragment := extractFragment(claims.Audience) if fragment == "" { audiences = c.audiences.All() } else { audiences = c.audiences.WithFragment(fragment).All() } // match with server audiences if matchesAudience(claims.Audience, audiences) { // Use fragment to get provisioner name (GCP, AWS, SSHPOP) if fragment != "" { return c.LoadByTokenID(fragment) } // If matches with stored audiences it will be a JWT token (default), and // the id would be :. // TODO: is this ok? return c.LoadByTokenID(claims.Issuer + ":" + token.Headers[0].KeyID) } // The ID will be just the clientID stored in azp, aud or tid. var payload loadByTokenPayload if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil { return nil, false } // Kubernetes Service Account tokens. if payload.Issuer == k8sSAIssuer { if p, ok := c.LoadByTokenID(K8sSAID); ok { return p, ok } // Kubernetes service account provisioner not found return nil, false } // Audience is required for non k8sSA tokens. if len(payload.Audience) == 0 { return nil, false } // Try with azp (OIDC) if payload.AuthorizedParty != "" { if p, ok := c.LoadByTokenID(payload.AuthorizedParty); ok { return p, ok } } // Try with tid (Azure, Azure OIDC) if payload.TenantID != "" { // Try to load an OIDC provisioner first. if payload.Email != "" { if p, ok := c.LoadByTokenID(payload.Audience[0]); ok { return p, ok } } // Try to load an Azure provisioner. if p, ok := c.LoadByTokenID(payload.TenantID); ok { return p, ok } } // Fallback to aud return c.LoadByTokenID(payload.Audience[0]) } // LoadByCertificate looks for the provisioner extension and extracts the // proper id to load the provisioner. 
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { for _, e := range cert.Extensions { if e.Id.Equal(StepOIDProvisioner) { var provisioner extensionASN1 if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { return nil, false } return c.LoadByName(string(provisioner.Name)) } } // Default to noop provisioner if an extension is not found. This allows to // accept a renewal of a cert without the provisioner extension. return &noop{}, true } // LoadEncryptedKey returns an encrypted key by indexed by KeyID. At this moment // only JWK encrypted keys are indexed by KeyID. func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) { p, ok := loadProvisioner(c.byKey, keyID) if !ok { return "", false } _, key, ok := p.GetEncryptedKey() return key, ok } // Store adds a provisioner to the collection and enforces the uniqueness of // provisioner IDs. func (c *Collection) Store(p Interface) error { // Store provisioner always in byID. ID must be unique. if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded { return admin.NewError(admin.ErrorBadRequestType, "cannot add multiple provisioners with the same id") } // Store provisioner always by name. if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded { c.byID.Delete(p.GetID()) return admin.NewError(admin.ErrorBadRequestType, "cannot add multiple provisioners with the same name") } // Store provisioner always by ID presented in token. if _, loaded := c.byTokenID.LoadOrStore(p.GetIDForToken(), p); loaded { c.byID.Delete(p.GetID()) c.byName.Delete(p.GetName()) return admin.NewError(admin.ErrorBadRequestType, "cannot add multiple provisioners with the same token identifier") } // Store provisioner in byKey if EncryptedKey is defined. if kid, _, ok := p.GetEncryptedKey(); ok { c.byKey.Store(kid, p) } // Store sorted provisioners. 
// Use the first 4 bytes (32bit) of the sum to insert the order // Using big endian format to get the strings sorted: // 0x00000000, 0x00000001, 0x00000002, ... bi := make([]byte, 4) sum := provisionerSum(p) binary.BigEndian.PutUint32(bi, cast.Uint32(c.sorted.Len())) sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] c.sorted = append(c.sorted, uidProvisioner{ provisioner: p, uid: hex.EncodeToString(sum), }) sort.Sort(c.sorted) return nil } // Remove deletes an provisioner from all associated collections and lists. func (c *Collection) Remove(id string) error { prov, ok := c.Load(id) if !ok { return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id) } var found bool for i, elem := range c.sorted { if elem.provisioner.GetID() != id { continue } // Remove index in sorted list copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index. c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value). c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice. found = true break } if !found { return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName()) } c.byID.Delete(id) c.byName.Delete(prov.GetName()) c.byTokenID.Delete(prov.GetIDForToken()) if kid, _, ok := prov.GetEncryptedKey(); ok { c.byKey.Delete(kid) } return nil } // Update updates the given provisioner in all related lists and collections. 
func (c *Collection) Update(nu Interface) error { old, ok := c.Load(nu.GetID()) if !ok { return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", nu.GetID()) } if old.GetName() != nu.GetName() { if _, ok := c.LoadByName(nu.GetName()); ok { return admin.NewError(admin.ErrorBadRequestType, "provisioner with name %s already exists", nu.GetName()) } } if old.GetIDForToken() != nu.GetIDForToken() { if _, ok := c.LoadByTokenID(nu.GetIDForToken()); ok { return admin.NewError(admin.ErrorBadRequestType, "provisioner with Token ID %s already exists", nu.GetIDForToken()) } } if err := c.Remove(old.GetID()); err != nil { return err } return c.Store(nu) } // Find implements pagination on a list of sorted provisioners. func (c *Collection) Find(cursor string, limit int) (List, string) { switch { case limit <= 0: limit = DefaultProvisionersLimit case limit > DefaultProvisionersMax: limit = DefaultProvisionersMax } n := c.sorted.Len() cursor = fmt.Sprintf("%040s", cursor) i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor }) slice := List{} for ; i < n && len(slice) < limit; i++ { slice = append(slice, c.sorted[i].provisioner) } if i < n { return slice, strings.TrimLeft(c.sorted[i].uid, "0") } return slice, "" } func loadProvisioner(m *sync.Map, key string) (Interface, bool) { i, ok := m.Load(key) if !ok { return nil, false } p, ok := i.(Interface) if !ok { return nil, false } return p, true } // provisionerSum returns the SHA1 of the provisioners ID. From this we will // create the unique and sorted id. func provisionerSum(p Interface) []byte { //nolint:gosec // not used for cryptographic security sum := sha1.Sum([]byte(p.GetID())) return sum[:] } // matchesAudience returns true if A and B share at least one element. 
func matchesAudience(as, bs []string) bool { if len(bs) == 0 || len(as) == 0 { return false } for _, b := range bs { for _, a := range as { if b == a || stripPort(a) == stripPort(b) { return true } } } return false } // stripPort attempts to strip the port from the given url. If parsing the url // produces errors it will just return the passed argument. func stripPort(rawurl string) string { u, err := url.Parse(rawurl) if err != nil { return rawurl } u.Host = u.Hostname() return u.String() } // extractFragment extracts the first fragment of an audience url. func extractFragment(audience []string) string { for _, s := range audience { if u, err := url.Parse(s); err == nil && u.Fragment != "" { return u.Fragment } } return "" } ================================================ FILE: authority/provisioner/collection_test.go ================================================ package provisioner import ( "crypto/x509" "crypto/x509/pkix" "reflect" "strings" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" ) func TestCollection_Load(t *testing.T) { p, err := generateJWK() require.NoError(t, err) byID := new(sync.Map) byID.Store(p.GetID(), p) byID.Store("string", "a-string") type fields struct { byID *sync.Map } type args struct { id string } tests := []struct { name string fields fields args args want Interface want1 bool }{ {"ok", fields{byID}, args{p.GetID()}, p, true}, {"fail", fields{byID}, args{"fail"}, nil, false}, {"invalid", fields{byID}, args{"string"}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Collection{ byID: tt.fields.byID, } got, got1 := c.Load(tt.args.id) if !reflect.DeepEqual(got, tt.want) { t.Errorf("Collection.Load() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1) } }) } } func TestCollection_LoadByTokenID(t *testing.T) { p1, err := generateJWK() require.NoError(t, err) p2, 
err := generateACME() require.NoError(t, err) byTokenID := new(sync.Map) byTokenID.Store(p1.GetIDForToken(), p1) byTokenID.Store(p2.GetIDForToken(), p2) byTokenID.Store("string", "a-string") type fields struct { byTokenID *sync.Map } type args struct { id string } tests := []struct { name string fields fields args args want Interface want1 bool }{ {"ok jwk", fields{byTokenID}, args{p1.GetIDForToken()}, p1, true}, {"ok acme", fields{byTokenID}, args{p2.GetIDForToken()}, p2, true}, {"fail missing", fields{byTokenID}, args{"missing"}, nil, false}, {"invalid", fields{byTokenID}, args{"string"}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Collection{ byTokenID: tt.fields.byTokenID, } got, got1 := c.LoadByTokenID(tt.args.id) if !reflect.DeepEqual(got, tt.want) { t.Errorf("Collection.Load() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1) } }) } } func TestCollection_LoadByToken(t *testing.T) { p1, err := generateJWK() require.NoError(t, err) p2, err := generateJWK() require.NoError(t, err) p3, err := generateOIDC() require.NoError(t, err) p4, err := generateK8sSA(nil) require.NoError(t, err) byID := new(sync.Map) byID.Store(p1.GetID(), p1) byID.Store(p2.GetID(), p2) byID.Store(p3.GetID(), p3) byID.Store(p4.GetID(), p4) byID.Store("string", "a-string") byID2 := new(sync.Map) byID2.Store(p1.GetID(), p1) byID2.Store(p2.GetID(), p2) byID2.Store(p3.GetID(), p3) jwk, err := decryptJSONWebKey(p1.EncryptedKey) require.NoError(t, err) token, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], jwk) require.NoError(t, err) t1, c1, err := parseToken(token) require.NoError(t, err) jwk, err = decryptJSONWebKey(p2.EncryptedKey) require.NoError(t, err) token, err = generateSimpleToken(p2.Name, testAudiences.Sign[1], jwk) require.NoError(t, err) t2, c2, err := parseToken(token) require.NoError(t, err) token, err = generateSimpleToken(p3.configuration.Issuer, 
p3.ClientID, &p3.keyStore.keySet.Keys[0]) require.NoError(t, err) t3, c3, err := parseToken(token) require.NoError(t, err) token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0]) require.NoError(t, err) t4, c4, err := parseToken(token) require.NoError(t, err) jwk, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) token, err = generateK8sSAToken(jwk, nil) require.NoError(t, err) t5, c5, err := parseToken(token) require.NoError(t, err) type fields struct { byID *sync.Map audiences Audiences } type args struct { token *jose.JSONWebToken claims *jose.Claims } tests := []struct { name string fields fields args args want Interface want1 bool }{ {"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true}, {"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true}, {"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true}, {"ok4", fields{byID, testAudiences}, args{t5, c5}, p4, true}, {"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false}, {"fail", fields{byID, Audiences{Sign: []string{"https://foo"}}}, args{t1, c1}, nil, false}, {"fail-no-k8sSa-provisioner", fields{byID2, testAudiences}, args{t5, c5}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Collection{ byID: tt.fields.byID, byTokenID: tt.fields.byID, audiences: tt.fields.audiences, } got, got1 := c.LoadByToken(tt.args.token, tt.args.claims) if !reflect.DeepEqual(got, tt.want) { t.Errorf("Collection.LoadByToken() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("Collection.LoadByToken() got1 = %v, want %v", got1, tt.want1) } }) } } func TestCollection_LoadByCertificate(t *testing.T) { mustExtension := func(typ Type, name, credentialID string) pkix.Extension { e := Extension{ Type: typ, Name: name, CredentialID: credentialID, } ext, err := e.ToExtension() if err != nil { t.Fatal(err) } return ext } p1, err := generateJWK() require.NoError(t, err) p2, err := generateOIDC() 
require.NoError(t, err) p3, err := generateACME() require.NoError(t, err) byName := new(sync.Map) byName.Store(p1.GetName(), p1) byName.Store(p2.GetName(), p2) byName.Store(p3.GetName(), p3) ok1Cert := &x509.Certificate{ Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)}, } ok2Cert := &x509.Certificate{ Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)}, } ok3Cert := &x509.Certificate{ Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")}, } notFoundCert := &x509.Certificate{ Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")}, } badCert := &x509.Certificate{ Extensions: []pkix.Extension{ {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")}, }, } type fields struct { byName *sync.Map audiences Audiences } type args struct { cert *x509.Certificate } tests := []struct { name string fields fields args args want Interface want1 bool }{ {"ok1", fields{byName, testAudiences}, args{ok1Cert}, p1, true}, {"ok2", fields{byName, testAudiences}, args{ok2Cert}, p2, true}, {"ok3", fields{byName, testAudiences}, args{ok3Cert}, p3, true}, {"noExtension", fields{byName, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, {"notFound", fields{byName, testAudiences}, args{notFoundCert}, nil, false}, {"badCert", fields{byName, testAudiences}, args{badCert}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Collection{ byName: tt.fields.byName, audiences: tt.fields.audiences, } got, got1 := c.LoadByCertificate(tt.args.cert) if !reflect.DeepEqual(got, tt.want) { t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1) } }) } } func TestCollection_LoadEncryptedKey(t *testing.T) { c := NewCollection(testAudiences) p1, err := generateJWK() require.NoError(t, err) require.NoError(t, c.Store(p1)) p2, err := generateOIDC() require.NoError(t, err) 
require.NoError(t, c.Store(p2)) // Add oidc in byKey. // It should not happen. p2KeyID := p2.keyStore.keySet.Keys[0].KeyID c.byKey.Store(p2KeyID, p2) type args struct { keyID string } tests := []struct { name string args args want string want1 bool }{ {"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true}, {"oidc", args{p2KeyID}, "", false}, {"notFound", args{"not-found"}, "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1 := c.LoadEncryptedKey(tt.args.keyID) if got != tt.want { t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1) } }) } } func TestCollection_Store(t *testing.T) { c := NewCollection(testAudiences) p1, err := generateJWK() require.NoError(t, err) p2, err := generateOIDC() require.NoError(t, err) type args struct { p Interface } tests := []struct { name string args args wantErr bool }{ {"ok1", args{p1}, false}, {"ok2", args{p2}, false}, {"fail1", args{p1}, true}, {"fail2", args{p2}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := c.Store(tt.args.p); (err != nil) != tt.wantErr { t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestCollection_Find(t *testing.T) { c, err := generateCollection(10, 10) require.NoError(t, err) trim := func(s string) string { return strings.TrimLeft(s, "0") } toList := func(ps provisionerSlice) List { l := List{} for _, p := range ps { l = append(l, p.provisioner) } return l } type args struct { cursor string limit int } tests := []struct { name string args args want List want1 string }{ {"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""}, {"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""}, {"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)}, {"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""}, {"1", args{trim(c.sorted[1].uid), 1}, 
toList(c.sorted[1:2]), trim(c.sorted[2].uid)}, {"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)}, {"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""}, {"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1 := c.Find(tt.args.cursor, tt.args.limit) if !reflect.DeepEqual(got, tt.want) { t.Errorf("Collection.Find() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1) } }) } } func Test_matchesAudience(t *testing.T) { type matchesTest struct { a, b []string exp bool } tests := map[string]matchesTest{ "false arg1 empty": { a: []string{}, b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, exp: false, }, "false arg2 empty": { a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, b: []string{}, exp: false, }, "false arg1,arg2 empty": { a: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, b: []string{"step-gateway", "step-cli"}, exp: false, }, "false": { a: []string{"step-gateway", "step-cli"}, b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, exp: false, }, "true": { a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"}, b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, exp: true, }, "true,portsA": { a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"}, b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com/sign"}, exp: true, }, "true,portsB": { a: []string{"step-gateway", "https://test.ca.smallstep.com/sign"}, b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:9000/sign"}, exp: true, }, "true,portsAB": { a: []string{"step-gateway", "https://test.ca.smallstep.com:9000/sign"}, b: []string{"https://127.0.0.1:0/sign", "https://test.ca.smallstep.com:8000/sign"}, 
exp: true, }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { assert.Equal(t, tc.exp, matchesAudience(tc.a, tc.b)) }) } } func Test_stripPort(t *testing.T) { type args struct { rawurl string } tests := []struct { name string args args want string }{ {"with port", args{"https://ca.smallstep.com:9000/sign"}, "https://ca.smallstep.com/sign"}, {"with no port", args{"https://ca.smallstep.com/sign/"}, "https://ca.smallstep.com/sign/"}, {"bad url", args{"https://a bad url:9000"}, "https://a bad url:9000"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := stripPort(tt.args.rawurl); got != tt.want { t.Errorf("stripPort() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/provisioner/controller.go ================================================ package provisioner import ( "context" "crypto/x509" "net/http" "strings" "time" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/webhook" ) // Controller wraps a provisioner with other attributes useful in callback // functions. type Controller struct { Interface Audiences *Audiences Claimer *Claimer IdentityFunc GetIdentityFunc AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc policy *policyEngine httpClient HTTPClient webhookClient HTTPClient webhooks []*Webhook wrapTransport httptransport.Wrapper } // NewController initializes a new provisioner controller. 
func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) { claimer, err := NewClaimer(claims, config.Claims) if err != nil { return nil, err } policy, err := newPolicyEngine(options) if err != nil { return nil, err } wt := config.WrapTransport if wt == nil { wt = httptransport.NoopWrapper() } for _, wh := range options.GetWebhooks() { if err := wh.Validate(); err != nil { return nil, err } } return &Controller{ Interface: p, Audiences: &config.Audiences, Claimer: claimer, IdentityFunc: config.GetIdentityFunc, AuthorizeRenewFunc: config.AuthorizeRenewFunc, AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, policy: policy, webhookClient: config.WebhookClient, webhooks: options.GetWebhooks(), httpClient: config.HTTPClient, wrapTransport: wt, }, nil } // GetHTTPClient returns the configured HTTP client or the default one if none // is configured. func (c *Controller) GetHTTPClient() HTTPClient { if c.httpClient != nil { return c.httpClient } return &http.Client{} } // GetIdentity returns the identity for a given email. func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) { if c.IdentityFunc != nil { return c.IdentityFunc(ctx, c.Interface, email) } return DefaultIdentityFunc(ctx, c.Interface, email) } // AuthorizeRenew returns nil if the given cert can be renewed, returns an error // otherwise. func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { if c.AuthorizeRenewFunc != nil { return c.AuthorizeRenewFunc(ctx, c, cert) } return DefaultAuthorizeRenew(ctx, c, cert) } // AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an // error otherwise. 
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error { if c.AuthorizeSSHRenewFunc != nil { return c.AuthorizeSSHRenewFunc(ctx, c, cert) } return DefaultAuthorizeSSHRenew(ctx, c, cert) } func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType, opts ...webhook.RequestBodyOption) *WebhookController { client := c.webhookClient if client == nil { client = &http.Client{ Transport: c.wrapTransport(httptransport.New()), } } return &WebhookController{ TemplateData: templateData, client: client, wrapTransport: c.wrapTransport, webhooks: c.webhooks, certType: certType, options: opts, } } // Identity is the type representing an externally supplied identity that is used // by provisioners to populate certificate fields. type Identity struct { Usernames []string `json:"usernames"` Permissions `json:"permissions"` } // GetIdentityFunc is a function that returns an identity. type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) // AuthorizeRenewFunc is a function that returns nil if the renewal of a // certificate is enabled. type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error // AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the // given SSH certificate is enabled. type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error // DefaultIdentityFunc return a default identity depending on the provisioner // type. For OIDC email is always present and the usernames might // contain empty strings. func DefaultIdentityFunc(_ context.Context, p Interface, email string) (*Identity, error) { switch k := p.(type) { case *OIDC: // OIDC principals would be: // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled // 2. Sanitized local. // 3. Raw local (if different). // 4. Email address. 
name := SanitizeSSHUserPrincipal(email) usernames := []string{name} if i := strings.LastIndex(email, "@"); i >= 0 { usernames = append(usernames, email[:i]) } usernames = append(usernames, email) return &Identity{ // Remove duplicated and empty usernames. Usernames: SanitizeStringSlices(usernames), }, nil default: return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) } } // DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It // will return an error if the provisioner has the renewal disabled, if the // certificate is not yet valid or if the certificate is expired and renew after // expiry is disabled. func DefaultAuthorizeRenew(_ context.Context, p *Controller, cert *x509.Certificate) error { if p.Claimer.IsDisableRenewal() { return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) } now := time.Now().Truncate(time.Second) if now.Before(cert.NotBefore) { return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) } if now.After(cert.NotAfter) && !p.Claimer.AllowRenewalAfterExpiry() { // return a custom 401 Unauthorized error with a clearer message for the client // TODO(hs): these errors likely need to be refactored as a whole; HTTP status codes shouldn't be in this layer. return errs.New(http.StatusUnauthorized, "The request lacked necessary authorization to be completed: certificate expired on %s", cert.NotAfter) } return nil } // DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It // will return an error if the provisioner has the renewal disabled, if the // certificate is not yet valid or if the certificate is expired and renew after // expiry is disabled. 
func DefaultAuthorizeSSHRenew(_ context.Context, p *Controller, cert *ssh.Certificate) error { if p.Claimer.IsDisableRenewal() { return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) } unixNow := time.Now().Unix() if after := cast.Int64(cert.ValidAfter); after < 0 || unixNow < cast.Int64(cert.ValidAfter) { return errs.Unauthorized("certificate is not yet valid") } if before := cast.Int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() { return errs.Unauthorized("certificate has expired") } return nil } // SanitizeStringSlices removes duplicated an empty strings. func SanitizeStringSlices(original []string) []string { output := []string{} seen := make(map[string]struct{}) for _, entry := range original { if entry == "" { continue } if _, value := seen[entry]; !value { seen[entry] = struct{}{} output = append(output, entry) } } return output } // SanitizeSSHUserPrincipal grabs an email or a string with the format // local@domain and returns a sanitized version of the local, valid to be used // as a user name. If the email starts with a letter between a and z, the // resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. 
func SanitizeSSHUserPrincipal(email string) string { if i := strings.LastIndex(email, "@"); i >= 0 { email = email[:i] } return strings.Map(func(r rune) rune { switch { case r >= 'a' && r <= 'z': return r case r >= '0' && r <= '9': return r case r == '-': return '-' case r == '.': // drop dots return -1 default: return '_' } }, strings.ToLower(email)) } func (c *Controller) getPolicy() *policyEngine { if c == nil { return nil } return c.policy } ================================================ FILE: authority/provisioner/controller_test.go ================================================ package provisioner import ( "context" "crypto/x509" "fmt" "net/http" "reflect" "testing" "time" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/webhook" "github.com/smallstep/linkedca" "github.com/stretchr/testify/assert" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) var trueValue = true func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer { t.Helper() c, err := NewClaimer(claims, global) if err != nil { t.Fatal(err) } return c } func mustDuration(t *testing.T, s string) *Duration { t.Helper() d, err := NewDuration(s) if err != nil { t.Fatal(err) } return d } func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine { t.Helper() c, err := newPolicyEngine(options) if err != nil { t.Fatal(err) } return c } func TestNewController(t *testing.T) { options := &Options{ X509: &X509Options{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"*.local"}, }, }, SSH: &SSHOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.local"}, }, }, User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ EmailAddresses: []string{"@example.com"}, }, }, }, } type args struct { p Interface claims *Claims config Config options *Options } tests := 
[]struct { name string args args want *Controller wantErr bool }{ {"ok", args{&JWK{}, nil, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, HTTPClient: &http.Client{}, WrapTransport: httptransport.NoopWrapper(), }, nil}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, nil, globalProvisionerClaims), httpClient: &http.Client{}, wrapTransport: httptransport.NoopWrapper(), }, false}, {"ok with claims", args{&JWK{}, &Claims{ DisableRenewal: &defaultDisableRenewal, }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, }, nil}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, &Claims{ DisableRenewal: &defaultDisableRenewal, }, globalProvisionerClaims), wrapTransport: httptransport.NoopWrapper(), }, false}, {"ok with claims and options", args{&JWK{}, &Claims{ DisableRenewal: &defaultDisableRenewal, }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, }, options}, &Controller{ Interface: &JWK{}, Audiences: &testAudiences, Claimer: mustClaimer(t, &Claims{ DisableRenewal: &defaultDisableRenewal, }, globalProvisionerClaims), policy: mustNewPolicyEngine(t, options), wrapTransport: httptransport.NoopWrapper(), }, false}, {"fail claimer", args{&JWK{}, &Claims{ MinTLSDur: mustDuration(t, "24h"), MaxTLSDur: mustDuration(t, "2h"), }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, }, nil}, nil, true}, {"fail options", args{&JWK{}, &Claims{ DisableRenewal: &defaultDisableRenewal, }, Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, }, &Options{ X509: &X509Options{ AllowedNames: &policy.X509NameOptions{ DNSDomains: []string{"**.local"}, }, }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options) if (err != nil) != tt.wantErr { t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) return } // A function can 
only be compared to nil if tt.want != nil && got != nil { assert.NotNil(t, got.wrapTransport) tt.want.wrapTransport = nil got.wrapTransport = nil } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewController() = %v, want %v", got, tt.want) } }) } } func TestController_GetHTTPClient(t *testing.T) { srv := generateTLSJWKServer(2) defer srv.Close() type fields struct { httpClient *http.Client } tests := []struct { name string fields fields want *http.Client }{ {"ok custom", fields{srv.Client()}, srv.Client()}, {"ok default", fields{http.DefaultClient}, http.DefaultClient}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Controller{ httpClient: tt.fields.httpClient, } assert.Equal(t, tt.want, c.GetHTTPClient()) }) } } func TestController_GetIdentity(t *testing.T) { ctx := context.Background() type fields struct { Interface Interface IdentityFunc GetIdentityFunc } type args struct { ctx context.Context email string } tests := []struct { name string fields fields args args want *Identity wantErr bool }{ {"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{ Usernames: []string{"jane", "jane@doe.org"}, }, false}, {"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"jane"}}, nil }}, args{ctx, "jane@doe.org"}, &Identity{ Usernames: []string{"jane"}, }, false}, {"ok badname", fields{&OIDC{}, nil}, args{ctx, "1000@doe.org"}, &Identity{ Usernames: []string{"1000", "1000@doe.org"}, }, false}, {"ok sanitized badname", fields{&OIDC{}, nil}, args{ctx, "1000+10@doe.org"}, &Identity{ Usernames: []string{"1000_10", "1000+10", "1000+10@doe.org"}, }, false}, {"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true}, {"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, fmt.Errorf("an error") }}, args{ctx, "jane@doe.org"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, 
func(t *testing.T) { c := &Controller{ Interface: tt.fields.Interface, IdentityFunc: tt.fields.IdentityFunc, } got, err := c.GetIdentity(tt.args.ctx, tt.args.email) if (err != nil) != tt.wantErr { t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want) } }) } } func TestController_AuthorizeRenew(t *testing.T) { ctx := context.Background() now := time.Now().Truncate(time.Second) type fields struct { Interface Interface Claimer *Claimer AuthorizeRenewFunc AuthorizeRenewFunc } type args struct { ctx context.Context cert *x509.Certificate } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { return nil }}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { return nil }}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), }}, false}, {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, true}, {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, 
args{ctx, &x509.Certificate{ NotBefore: now.Add(time.Hour), NotAfter: now.Add(2 * time.Hour), }}, true}, {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), }}, true}, {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { return fmt.Errorf("an error") }}, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Controller{ Interface: tt.fields.Interface, Claimer: tt.fields.Claimer, AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc, } if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestController_AuthorizeSSHRenew(t *testing.T) { ctx := context.Background() now := time.Now() type fields struct { Interface Interface Claimer *Claimer AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc } type args struct { ctx context.Context cert *ssh.Certificate } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { return nil }}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { return nil }}, args{ctx, &ssh.Certificate{ ValidAfter: 
uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), }}, false}, {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, true}, {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Add(time.Hour).Unix()), ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), }}, true}, {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), }}, true}, {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { return fmt.Errorf("an error") }}, args{ctx, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Controller{ Interface: tt.fields.Interface, Claimer: tt.fields.Claimer, AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc, } if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestDefaultAuthorizeRenew(t *testing.T) { ctx := context.Background() now := time.Now().Truncate(time.Second) type args struct { ctx context.Context p *Controller cert *x509.Certificate } tests := []struct { name string args args wantErr bool }{ {"ok", args{ctx, 
&Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, nil, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, {"ok renew after expiry", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), }}, false}, {"fail disabled", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, true}, {"fail not yet valid", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now.Add(time.Hour), NotAfter: now.Add(2 * time.Hour), }}, true}, {"fail expired", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), }, &x509.Certificate{ NotBefore: now.Add(-time.Hour), NotAfter: now.Add(-time.Minute), }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestDefaultAuthorizeSSHRenew(t *testing.T) { ctx := context.Background() now := time.Now() type args struct { ctx context.Context p *Controller cert *ssh.Certificate } tests := []struct { name string args args wantErr bool }{ {"ok", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, nil, globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, false}, {"ok renew after expiry", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, 
globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), }}, false}, {"fail disabled", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), }}, true}, {"fail not yet valid", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Add(time.Hour).Unix()), ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), }}, true}, {"fail expired", args{ctx, &Controller{ Interface: &JWK{}, Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), }, &ssh.Certificate{ ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()), }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_newWebhookController(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) if err != nil { t.Fatal(err) } opts := []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)} type args struct { templateData WebhookSetter certType linkedca.Webhook_CertType opts []webhook.RequestBodyOption } tests := []struct { name string args args want *WebhookController }{ {"ok", args{x509util.TemplateData{"foo": "bar"}, linkedca.Webhook_X509, nil}, &WebhookController{ TemplateData: x509util.TemplateData{"foo": "bar"}, certType: linkedca.Webhook_X509, client: http.DefaultClient, }}, {"ok with options", args{x509util.TemplateData{"foo": "bar"}, linkedca.Webhook_SSH, opts}, &WebhookController{ TemplateData: 
x509util.TemplateData{"foo": "bar"}, certType: linkedca.Webhook_SSH, client: http.DefaultClient, options: opts, }}, } for _, tt := range tests { c := Controller{ webhookClient: new(http.Client), wrapTransport: httptransport.NoopWrapper(), } got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...) assert.Equal(t, tt.args.templateData, got.TemplateData) assert.Same(t, c.webhookClient, got.client) assert.Equal(t, c.webhooks, got.webhooks) assert.Equal(t, tt.args.opts, got.options) assert.Equal(t, tt.args.certType, got.certType) } } ================================================ FILE: authority/provisioner/duration.go ================================================ package provisioner import ( "encoding/json" "time" "github.com/pkg/errors" ) // Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. type Duration struct { time.Duration } // NewDuration parses a duration string and returns a Duration type or an error // if the given string is not a duration. func NewDuration(s string) (*Duration, error) { d, err := time.ParseDuration(s) if err != nil { return nil, errors.Wrapf(err, "error parsing %s as duration", s) } return &Duration{Duration: d}, nil } // MarshalJSON parses a duration string and sets it to the duration. // // A duration string is a possibly signed sequence of decimal numbers, each with // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". func (d *Duration) MarshalJSON() ([]byte, error) { return json.Marshal(d.Duration.String()) } // UnmarshalJSON parses a duration string and sets it to the duration. // // A duration string is a possibly signed sequence of decimal numbers, each with // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". 
func (d *Duration) UnmarshalJSON(data []byte) (err error) { var ( s string dd time.Duration ) if d == nil { return errors.New("duration cannot be nil") } if err = json.Unmarshal(data, &s); err != nil { return errors.Wrapf(err, "error unmarshaling %s", data) } if dd, err = time.ParseDuration(s); err != nil { return errors.Wrapf(err, "error parsing %s as duration", s) } d.Duration = dd return } // Value returns 0 if the duration is null, the inner duration otherwise. func (d *Duration) Value() time.Duration { if d == nil { return 0 } return d.Duration } ================================================ FILE: authority/provisioner/duration_test.go ================================================ package provisioner import ( "reflect" "testing" "time" ) func TestNewDuration(t *testing.T) { type args struct { s string } tests := []struct { name string args args want *Duration wantErr bool }{ {"ok", args{"1h2m3s"}, &Duration{Duration: 3723 * time.Second}, false}, {"fail empty", args{""}, nil, true}, {"fail number", args{"123"}, nil, true}, {"fail string", args{"1hour"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewDuration(tt.args.s) if (err != nil) != tt.wantErr { t.Errorf("NewDuration() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewDuration() = %v, want %v", got, tt.want) } }) } } func TestDuration_UnmarshalJSON(t *testing.T) { type args struct { data []byte } tests := []struct { name string d *Duration args args want *Duration wantErr bool }{ {"empty", new(Duration), args{[]byte{}}, new(Duration), true}, {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true}, {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true}, {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true}, {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false}, {"nil", nil, args{nil}, nil, true}, } for _, 
tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(tt.d, tt.want) { t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want) } }) } } func TestDuration_MarshalJSON(t *testing.T) { tests := []struct { name string d *Duration want []byte wantErr bool }{ {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.d.MarshalJSON() if (err != nil) != tt.wantErr { t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want) } }) } } func TestDuration_Value(t *testing.T) { var dur *Duration tests := []struct { name string duration *Duration want time.Duration }{ {"ok", &Duration{Duration: 1 * time.Minute}, 1 * time.Minute}, {"ok new", new(Duration), 0}, {"ok nil", nil, 0}, {"ok nil var", dur, 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.duration.Value(); got != tt.want { t.Errorf("Duration.Value() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/provisioner/extension.go ================================================ package provisioner import ( "crypto/x509" "crypto/x509/pkix" "encoding/asn1" ) var ( // StepOIDRoot is the root OID for smallstep. StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} // StepOIDProvisioner is the OID for the provisioner extension. StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...) ) // Extension is the Go representation of the provisioner extension. 
type Extension struct { Type Type Name string CredentialID string KeyValuePairs []string } type extensionASN1 struct { Type int Name []byte CredentialID []byte KeyValuePairs []string `asn1:"optional,omitempty"` } // Marshal marshals the extension using encoding/asn1. func (e *Extension) Marshal() ([]byte, error) { return asn1.Marshal(extensionASN1{ Type: int(e.Type), Name: []byte(e.Name), CredentialID: []byte(e.CredentialID), KeyValuePairs: e.KeyValuePairs, }) } // ToExtension returns the pkix.Extension representation of the provisioner // extension. func (e *Extension) ToExtension() (pkix.Extension, error) { b, err := e.Marshal() if err != nil { return pkix.Extension{}, err } return pkix.Extension{ Id: StepOIDProvisioner, Value: b, }, nil } // GetProvisionerExtension goes through all the certificate extensions and // returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1). func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) { for _, e := range cert.Extensions { if e.Id.Equal(StepOIDProvisioner) { var provisioner extensionASN1 if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { return nil, false } return &Extension{ Type: Type(provisioner.Type), Name: string(provisioner.Name), CredentialID: string(provisioner.CredentialID), KeyValuePairs: provisioner.KeyValuePairs, }, true } } return nil, false } ================================================ FILE: authority/provisioner/extension_test.go ================================================ package provisioner import ( "crypto/x509" "crypto/x509/pkix" "reflect" "testing" "go.step.sm/crypto/pemutil" ) func TestExtension_Marshal(t *testing.T) { type fields struct { Type Type Name string CredentialID string KeyValuePairs []string } tests := []struct { name string fields fields want []byte wantErr bool }{ {"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{ 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 
0x74, 0x69, 0x61, 0x6c, 0x49, 0x44, }, false}, {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{ 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x13, 0x03, 0x62, 0x61, 0x72, }, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := &Extension{ Type: tt.fields.Type, Name: tt.fields.Name, CredentialID: tt.fields.CredentialID, KeyValuePairs: tt.fields.KeyValuePairs, } got, err := e.Marshal() if (err != nil) != tt.wantErr { t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want) } }) } } func TestExtension_ToExtension(t *testing.T) { type fields struct { Type Type Name string CredentialID string KeyValuePairs []string } tests := []struct { name string fields fields want pkix.Extension wantErr bool }{ {"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{ Id: StepOIDProvisioner, Value: []byte{ 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, 0x44, }, }, false}, {"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{ Id: StepOIDProvisioner, Value: []byte{ 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, 0x44, }, }, false}, {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{ Id: StepOIDProvisioner, Value: []byte{ 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x13, 0x03, 0x62, 0x61, 0x72, }, }, false}, } for _, tt := range 
tests { t.Run(tt.name, func(t *testing.T) { e := &Extension{ Type: tt.fields.Type, Name: tt.fields.Name, CredentialID: tt.fields.CredentialID, KeyValuePairs: tt.fields.KeyValuePairs, } got, err := e.ToExtension() if (err != nil) != tt.wantErr { t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want) } }) } } func TestGetProvisionerExtension(t *testing.T) { mustCertificate := func(fn string) *x509.Certificate { cert, err := pemutil.ReadCertificate(fn) if err != nil { t.Fatal(err) } return cert } type args struct { cert *x509.Certificate } tests := []struct { name string args args want *Extension want1 bool }{ {"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{ Type: TypeJWK, Name: "mariano@smallstep.com", CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ", }, true}, {"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false}, {"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1 := GetProvisionerExtension(tt.args.cert) if !reflect.DeepEqual(got, tt.want) { t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1) } }) } } ================================================ FILE: authority/provisioner/gcp/projectvalidator.go ================================================ package gcp import ( "context" "net/http" "google.golang.org/api/cloudresourcemanager/v1" "github.com/smallstep/certificates/errs" ) type ProjectValidator struct { ProjectIDs []string } func (p *ProjectValidator) ValidateProject(_ context.Context, projectID string) error { if len(p.ProjectIDs) == 0 { return nil } for _, pi := range p.ProjectIDs { if pi == projectID { return nil } } return 
errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid project id") } type OrganizationValidator struct { *ProjectValidator OrganizationID string projectsService *cloudresourcemanager.ProjectsService } func NewOrganizationValidator(projectIDs []string, organizationID string) (*OrganizationValidator, error) { var svc *cloudresourcemanager.ProjectsService if organizationID != "" { crm, err := cloudresourcemanager.NewService(context.Background()) if err != nil { return nil, err } svc = crm.Projects } return &OrganizationValidator{ ProjectValidator: &ProjectValidator{projectIDs}, OrganizationID: organizationID, projectsService: svc, }, nil } func (p *OrganizationValidator) ValidateProject(ctx context.Context, projectID string) error { if err := p.ProjectValidator.ValidateProject(ctx, projectID); err != nil { return err } if p.OrganizationID == "" { return nil } ancestry, err := p.projectsService. GetAncestry(projectID, &cloudresourcemanager.GetAncestryRequest{}). Context(ctx). Do() if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "gcp.authorizeToken") } if len(ancestry.Ancestor) < 1 { return errs.InternalServer("gcp.authorizeToken; getAncestry response malformed") } progenitor := ancestry.Ancestor[len(ancestry.Ancestor)-1] if progenitor.ResourceId.Type != "organization" || progenitor.ResourceId.Id != p.OrganizationID { return errs.Unauthorized("gcp.authorizeToken; invalid gcp token - project does not belong to organization") } return nil } ================================================ FILE: authority/provisioner/gcp/projectvalidator_test.go ================================================ package gcp import ( "context" "testing" "github.com/stretchr/testify/assert" "google.golang.org/api/cloudresourcemanager/v1" ) func TestProjectValidator_ValidateProject(t *testing.T) { ctx := context.Background() type fields struct { ProjectIDs []string } type args struct { in0 context.Context projectID string } tests := []struct { name string 
fields fields args args assertion assert.ErrorAssertionFunc }{ {"allowed-1", fields{[]string{"allowed-1", "allowed-2"}}, args{ctx, "allowed-1"}, assert.NoError}, {"allowed-2", fields{[]string{"allowed-1", "allowed-2"}}, args{ctx, "allowed-2"}, assert.NoError}, {"empty", fields{nil}, args{ctx, "allowed-1"}, assert.NoError}, {"not allowed", fields{[]string{"allowed-1", "allowed-2"}}, args{ctx, "not-allowed"}, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &ProjectValidator{ ProjectIDs: tt.fields.ProjectIDs, } tt.assertion(t, p.ValidateProject(tt.args.in0, tt.args.projectID)) }) } } func TestNewOrganizationValidator(t *testing.T) { ctx := context.Background() _, err := cloudresourcemanager.NewService(ctx) skip := (err != nil) type args struct { projectIDs []string organizationID string } tests := []struct { name string skip bool args args want *OrganizationValidator assertion assert.ErrorAssertionFunc }{ {"ok projects", false, args{[]string{"project-1", "project-2"}, ""}, &OrganizationValidator{ ProjectValidator: &ProjectValidator{[]string{"project-1", "project-2"}}, }, assert.NoError}, {"ok organization", skip, args{[]string{}, "organization"}, &OrganizationValidator{ ProjectValidator: &ProjectValidator{[]string{}}, OrganizationID: "organization", projectsService: &cloudresourcemanager.ProjectsService{}, }, assert.NoError}, {"ok projects organization", skip, args{[]string{"project-1"}, "organization"}, &OrganizationValidator{ ProjectValidator: &ProjectValidator{[]string{"project-1"}}, OrganizationID: "organization", projectsService: &cloudresourcemanager.ProjectsService{}, }, assert.NoError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.skip { t.SkipNow() return } got, err := NewOrganizationValidator(tt.args.projectIDs, tt.args.organizationID) tt.assertion(t, err) assert.EqualExportedValues(t, tt.want, got) }) } } func TestOrganizationValidator_ValidateProject(t *testing.T) { ctx := context.Background() 
svc, err := cloudresourcemanager.NewService(ctx) skip := (err != nil) var projectsService *cloudresourcemanager.ProjectsService if !skip { projectsService = svc.Projects } type fields struct { ProjectValidator *ProjectValidator OrganizationID string projectsService *cloudresourcemanager.ProjectsService } type args struct { ctx context.Context projectID string } tests := []struct { name string skip bool fields fields args args assertion assert.ErrorAssertionFunc }{ {"ok projects", false, fields{&ProjectValidator{ProjectIDs: []string{"allowed"}}, "", projectsService}, args{ctx, "allowed"}, assert.NoError}, {"fail projects", false, fields{&ProjectValidator{ProjectIDs: []string{"allowed"}}, "organization", projectsService}, args{ctx, "not-allowed"}, assert.Error}, {"fail organization", skip, fields{&ProjectValidator{ProjectIDs: []string{"allowed"}}, "fake-organization", projectsService}, args{ctx, "allowed"}, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &OrganizationValidator{ ProjectValidator: tt.fields.ProjectValidator, OrganizationID: tt.fields.OrganizationID, projectsService: tt.fields.projectsService, } if tt.skip { t.SkipNow() return } tt.assertion(t, p.ValidateProject(tt.args.ctx, tt.args.projectID)) }) } } ================================================ FILE: authority/provisioner/gcp.go ================================================ package provisioner import ( "bytes" "context" "crypto/sha256" "crypto/x509" "encoding/hex" "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/authority/provisioner/gcp" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/webhook" ) // gcpCertsURL is the url that serves Google OAuth2 public keys. 
const gcpCertsURL = "https://www.googleapis.com/oauth2/v3/certs" // gcpIdentityURL is the base url for the identity document in GCP. const gcpIdentityURL = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity" // DefaultDisableSSHCAHost is the default value for SSH Host CA used when DisableSSHCAHost is not set var DefaultDisableSSHCAHost = false // DefaultDisableSSHCAUser is the default value for SSH User CA used when DisableSSHCAUser is not set var DefaultDisableSSHCAUser = true // gcpPayload extends jwt.Claims with custom GCP attributes. type gcpPayload struct { jose.Claims AuthorizedParty string `json:"azp"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` Google gcpGooglePayload `json:"google"` } type gcpGooglePayload struct { ComputeEngine gcpComputeEnginePayload `json:"compute_engine"` } type gcpComputeEnginePayload struct { InstanceID string `json:"instance_id"` InstanceName string `json:"instance_name"` InstanceCreationTimestamp *jose.NumericDate `json:"instance_creation_timestamp"` ProjectID string `json:"project_id"` ProjectNumber int64 `json:"project_number"` Zone string `json:"zone"` LicenseID []string `json:"license_id"` } type gcpConfig struct { CertsURL string IdentityURL string } func newGCPConfig() *gcpConfig { return &gcpConfig{ CertsURL: gcpCertsURL, IdentityURL: gcpIdentityURL, } } // projectValidator is an interface to enable testing without using // gcp.OrganizationProjectValidator. type projectValidator interface { ValidateProject(ctx context.Context, projectID string) error } // GCP is the provisioner that supports identity tokens created by the Google // Cloud Platform metadata API. // // If DisableCustomSANs is true, only the internal DNS and IP will be added as a // SAN. By default it will accept any SAN in the CSR. // // If DisableTrustOnFirstUse is true, multiple sign request for this provisioner // with the same instance will be accepted. 
By default only the first request // will be accepted. // // If InstanceAge is set, only the instances with an instance_creation_timestamp // within the given period will be accepted. // // Google Identity docs are available at // https://cloud.google.com/compute/docs/instances/verifying-instance-identity type GCP struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ServiceAccounts []string `json:"serviceAccounts,omitempty"` ProjectIDs []string `json:"projectIDs,omitempty"` OrganizationID string `json:"organizationID,omitempty"` DisableCustomSANs bool `json:"disableCustomSANs"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` DisableSSHCAUser *bool `json:"disableSSHCAUser,omitempty"` DisableSSHCAHost *bool `json:"disableSSHCAHost,omitempty"` InstanceAge Duration `json:"instanceAge,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` config *gcpConfig keyStore *keyStore ctl *Controller projectValidator projectValidator } // GetID returns the provisioner unique identifier. The name should uniquely // identify any GCP provisioner. func (p *GCP) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *GCP) GetIDForToken() string { return "gcp/" + p.Name } // GetTokenID returns the identifier of the token. The default value for GCP the // SHA256 of "provisioner_id.instance_id", but if DisableTrustOnFirstUse is set // to true, then it will be the SHA256 of the token. func (p *GCP) GetTokenID(token string) (string, error) { jwt, err := jose.ParseSigned(token) if err != nil { return "", errors.Wrap(err, "error parsing token") } // If TOFU is disabled create an ID for the token, so it cannot be reused. if p.DisableTrustOnFirstUse { sum := sha256.Sum256([]byte(token)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } // Get claims w/out verification. 
var claims gcpPayload if err = jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } // Create unique ID for Trust On First Use (TOFU). Only the first instance // per provisioner is allowed as we don't have a way to trust the given // sans. unique := fmt.Sprintf("%s.%s", p.GetIDForToken(), claims.Google.ComputeEngine.InstanceID) sum := sha256.Sum256([]byte(unique)) return strings.ToLower(hex.EncodeToString(sum[:])), nil } // GetName returns the name of the provisioner. func (p *GCP) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *GCP) GetType() Type { return TypeGCP } // GetEncryptedKey is not available in a GCP provisioner. func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } // GetIdentityURL returns the url that generates the GCP token. func (p *GCP) GetIdentityURL(audience string) string { // Initialize config if required p.assertConfig() q := url.Values{} q.Add("audience", audience) q.Add("format", "full") q.Add("licenses", "FALSE") return fmt.Sprintf("%s?%s", p.config.IdentityURL, q.Encode()) } // GetIdentityToken does an HTTP request to the identity url. 
func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { _ = subject // unused input audience, err := generateSignAudience(caURL, p.GetIDForToken()) if err != nil { return "", err } req, err := http.NewRequest("GET", p.GetIdentityURL(audience), http.NoBody) if err != nil { return "", errors.Wrap(err, "error creating identity request") } req.Header.Set("Metadata-Flavor", "Google") resp, err := http.DefaultClient.Do(req) if err != nil { return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?") } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { return "", errors.Wrap(err, "error on identity request") } if resp.StatusCode >= 400 { return "", errors.Errorf("error on identity request: status=%d, response=%s", resp.StatusCode, b) } return string(bytes.TrimSpace(b)), nil } // Init validates and initializes the GCP provisioner. func (p *GCP) Init(config Config) (err error) { if p.DisableSSHCAHost == nil { p.DisableSSHCAHost = &DefaultDisableSSHCAHost } if p.DisableSSHCAUser == nil { p.DisableSSHCAUser = &DefaultDisableSSHCAUser } switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } if len(p.ProjectIDs) > 0 && p.OrganizationID != "" { return errors.New("provisioner cannot have both `projectIDs` and `organizationID` set") } // Initialize config p.assertConfig() // Initialize key store if p.keyStore, err = newKeyStore(http.DefaultClient, p.config.CertsURL); err != nil { return } // Initialize the project validator if p.projectValidator, err = gcp.NewOrganizationValidator(p.ProjectIDs, p.OrganizationID); err != nil { return } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) if p.ctl, err = NewController(p, p.Claims, config, p.Options); err != nil { return } return } // AuthorizeSign validates the given 
token and returns the sign options that // will be used on certificate creation. func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign") } ce := claims.Google.ComputeEngine // Template options data := x509util.NewTemplateData() data.SetCommonName(ce.InstanceName) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // Enforce known common name and default DNS if configured. // By default we we'll accept the CN and SANs in the CSR. // There's no way to trust them other than TOFU. var so []SignOption if p.DisableCustomSANs { dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID) dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID) so = append(so, commonNameSliceValidator([]string{ ce.InstanceName, ce.InstanceID, dnsName1, dnsName2, }), dnsNamesSubsetValidator([]string{ dnsName1, dnsName2, }), ipAddressesValidator(nil), emailAddressesValidator(nil), newURIsValidator(ctx, nil), ) // Template SANs data.SetSANs([]string{dnsName1, dnsName2}) } templateOptions, err := CustomTemplateOptions(p.Options, data, x509util.DefaultIIDLeafTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign") } return append(so, p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName).WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController( data, linkedca.Webhook_X509, webhook.WithAuthorizationPrincipal(ce.InstanceID), ), ), nil } // 
AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized. func (p *GCP) assertConfig() { if p.config == nil { p.config = newGCPConfig() } } // authorizeToken performs common jwt authorization actions and returns the // claims for case specific downstream parsing. // e.g. a Sign request will auth/validate different fields than a Revoke request. func (p *GCP) authorizeToken(ctx context.Context, token string) (*gcpPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; error parsing gcp token") } if len(jwt.Headers) == 0 { return nil, errs.Unauthorized("gcp.authorizeToken; error parsing gcp token - header is missing") } var found bool var claims gcpPayload kid := jwt.Headers[0].KeyID keys := p.keyStore.Get(kid) for _, key := range keys { if err := jwt.Claims(key.Public(), &claims); err == nil { found = true break } } if !found { return nil, errs.Unauthorized("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid %s", kid) } // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. 
now := time.Now().UTC() if err = claims.ValidateWithLeeway(jose.Expected{ Issuer: "https://accounts.google.com", Time: now, }, time.Minute); err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "gcp.authorizeToken; invalid gcp token payload") } // validate audiences with the defaults if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") } // validate subject (service account) if len(p.ServiceAccounts) > 0 { var found bool for _, sa := range p.ServiceAccounts { if sa == claims.Subject || sa == claims.Email { found = true break } } if !found { return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid subject claim") } } // validate projects if err := p.projectValidator.ValidateProject(ctx, claims.Google.ComputeEngine.ProjectID); err != nil { return nil, err } // validate instance age if d := p.InstanceAge.Value(); d > 0 { if now.Sub(claims.Google.ComputeEngine.InstanceCreationTimestamp.Time()) > d { return nil, errs.Unauthorized("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old") } } switch { case claims.Google.ComputeEngine.InstanceID == "": return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty") case claims.Google.ComputeEngine.InstanceName == "": return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty") case claims.Google.ComputeEngine.ProjectID == "": return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty") case claims.Google.ComputeEngine.Zone == "": return nil, errs.Unauthorized("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty") } return &claims, nil } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. 
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { certType, hasCertType := CertTypeFromContext(ctx) if !hasCertType { certType = SSHHostCert } err := p.isUnauthorizedToIssueSSHCert(certType) if err != nil { return nil, err } claims, err := p.authorizeToken(ctx, token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSSHSign") } var principals []string var keyID string var defaults SignSSHOptions var ct sshutil.CertType var template string switch certType { case SSHHostCert: defaults, keyID, principals, ct, template = p.genHostOptions(ctx, claims) case SSHUserCert: defaults, keyID, principals, ct, template = p.genUserOptions(ctx, claims) default: return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; invalid requested certType") } signOptions := []SignOption{} // Only enforce known principals if disable custom sans is true, or it is a user cert request if p.DisableCustomSANs || certType == SSHUserCert { defaults.Principals = principals } else { // Check that at least one principal is sent in the request. signOptions = append(signOptions, &sshCertOptionsRequireValidator{ Principals: true, }) } // Certificate templates. data := sshutil.CreateTemplateData(ct, keyID, principals) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } templateOptions, err := CustomSSHTemplateOptions(p.Options, data, template) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSSHSign") } signOptions = append(signOptions, templateOptions) return append(signOptions, p, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. 
&sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, // Ensure that all principal names are allowed newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), // Call webhooks p.ctl.newWebhookController( data, linkedca.Webhook_SSH, webhook.WithAuthorizationPrincipal(keyID), ), ), nil } func (p *GCP) genHostOptions(_ context.Context, claims *gcpPayload) (SignSSHOptions, string, []string, sshutil.CertType, string) { ce := claims.Google.ComputeEngine keyID := ce.InstanceName principals := []string{ fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID), fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID), } return SignSSHOptions{CertType: SSHHostCert}, keyID, principals, sshutil.HostCert, sshutil.DefaultIIDTemplate } func FormatServiceAccountUsername(serviceAccountID string) string { return fmt.Sprintf("sa_%v", serviceAccountID) } func (p *GCP) genUserOptions(_ context.Context, claims *gcpPayload) (SignSSHOptions, string, []string, sshutil.CertType, string) { keyID := claims.Email principals := []string{ FormatServiceAccountUsername(claims.Subject), claims.Email, } return SignSSHOptions{CertType: SSHUserCert}, keyID, principals, sshutil.UserCert, sshutil.DefaultTemplate } func (p *GCP) isUnauthorizedToIssueSSHCert(certType string) error { if !p.ctl.Claimer.IsSSHCAEnabled() { return errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) } if certType == SSHHostCert && *p.DisableSSHCAHost { return errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA for Hosts is disabled for gcp provisioner '%s'", p.GetName()) } if certType == SSHUserCert && *p.DisableSSHCAUser { return errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA for Users is disabled for gcp provisioner '%s'", p.GetName()) } return nil } ================================================ FILE: authority/provisioner/gcp_test.go 
================================================ package provisioner import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/hex" "errors" "fmt" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" "go.step.sm/crypto/jose" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner/gcp" ) func TestGCP_Getters(t *testing.T) { p, err := generateGCP() assert.FatalError(t, err) id := "gcp/" + p.Name if got := p.GetID(); got != id { t.Errorf("GCP.GetID() = %v, want %v", got, id) } if got := p.GetName(); got != p.Name { t.Errorf("GCP.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeGCP { t.Errorf("GCP.GetType() = %v, want %v", got, TypeGCP) } kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("GCP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } aud := "https://ca.smallstep.com/1.0/sign#" + url.QueryEscape(id) expected := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s&format=full&licenses=FALSE", url.QueryEscape(aud)) if got := p.GetIdentityURL(aud); got != expected { t.Errorf("GCP.GetIdentityURL() = %v, want %v", got, expected) } } func TestGCP_GetTokenID(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) p1.Name = "name" p2, err := generateGCP() assert.FatalError(t, err) p2.DisableTrustOnFirstUse = true now := time.Now() t1, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", "gcp/name", "instance-id", "instance-name", "project-id", "zone", now, &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) t2, err := generateGCPToken(p2.ServiceAccounts[0], "https://accounts.google.com", p2.GetID(), "instance-id", "instance-name", "project-id", "zone", now, &p2.keyStore.keySet.Keys[0]) assert.FatalError(t, err) sum := 
sha256.Sum256([]byte("gcp/name.instance-id")) want1 := strings.ToLower(hex.EncodeToString(sum[:])) sum = sha256.Sum256([]byte(t2)) want2 := strings.ToLower(hex.EncodeToString(sum[:])) type args struct { token string } tests := []struct { name string gcp *GCP args args want string wantErr bool }{ {"ok", p1, args{t1}, want1, false}, {"ok", p2, args{t2}, want2, false}, {"fail token", p1, args{"token"}, "", true}, {"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.gcp.GetTokenID(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("GCP.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("GCP.GetTokenID() = %v, want %v", got, tt.want) } }) } } func TestGCP_GetIdentityToken(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) t1, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/bad-request": http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) default: w.Write([]byte(t1)) } })) defer srv.Close() type args struct { subject string caURL string } tests := []struct { name string gcp *GCP args args identityURL string want string wantErr bool }{ {"ok", p1, args{"subject", "https://ca"}, srv.URL, t1, false}, {"fail ca url", p1, args{"subject", "://ca"}, srv.URL, "", true}, {"fail request", p1, args{"subject", "https://ca"}, srv.URL + "/bad-request", "", true}, {"fail url", p1, args{"subject", "https://ca"}, "://ca.smallstep.com", "", true}, {"fail connect", p1, args{"subject", "https://ca"}, "foobarzar", "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.gcp.config.IdentityURL = 
tt.identityURL got, err := tt.gcp.GetIdentityToken(tt.args.subject, tt.args.caURL) t.Log(err) if (err != nil) != tt.wantErr { t.Errorf("GCP.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("GCP.GetIdentityToken() = %v, want %v", got, tt.want) } }) } } func TestGCP_Init(t *testing.T) { srv := generateJWKServer(2) defer srv.Close() config := Config{ Claims: globalProvisionerClaims, } badClaims := &Claims{ DefaultTLSDur: &Duration{0}, } zero := Duration{Duration: 0} type fields struct { Type string Name string ServiceAccounts []string InstanceAge Duration Claims *Claims } type args struct { config Config certsURL string } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{"GCP", "name", nil, zero, nil}, args{config, srv.URL}, false}, {"ok", fields{"GCP", "name", nil, zero, nil}, args{config, srv.URL}, false}, {"ok", fields{"GCP", "name", []string{"service-account"}, zero, nil}, args{config, srv.URL}, false}, {"ok", fields{"GCP", "name", []string{"service-account"}, Duration{Duration: 1 * time.Minute}, nil}, args{config, srv.URL}, false}, {"bad type", fields{"", "name", nil, zero, nil}, args{config, srv.URL}, true}, {"bad name", fields{"GCP", "", nil, zero, nil}, args{config, srv.URL}, true}, {"bad duration", fields{"GCP", "name", nil, Duration{Duration: -1 * time.Minute}, nil}, args{config, srv.URL}, true}, {"bad claims", fields{"GCP", "name", nil, zero, badClaims}, args{config, srv.URL}, true}, {"bad certs", fields{"GCP", "name", nil, zero, nil}, args{config, srv.URL + "/error"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &GCP{ Type: tt.fields.Type, Name: tt.fields.Name, ServiceAccounts: tt.fields.ServiceAccounts, InstanceAge: tt.fields.InstanceAge, Claims: tt.fields.Claims, config: &gcpConfig{ CertsURL: tt.args.certsURL, IdentityURL: gcpIdentityURL, }, } if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { t.Errorf("GCP.Init() error = %v, 
wantErr %v", err, tt.wantErr) } if *p.DisableSSHCAUser != true { t.Errorf("By default DisableSSHCAUser should be true") } if *p.DisableSSHCAHost != false { t.Errorf("By default DisableSSHCAHost should be false") } }) } } func TestGCP_authorizeToken(t *testing.T) { type test struct { p *GCP token string err error code int } tests := map[string]func(*testing.T) test{ "fail/bad-token": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; error parsing gcp token"), } }, "fail/cannot-validate-sig": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), jwk) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; failed to validate gcp token payload - cannot find key for kid "), } }, "fail/invalid-issuer": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://foo.bar.zap", p.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; invalid gcp token payload"), } }, "fail/invalid-serviceAccount": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken("foo", "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; 
invalid gcp token - invalid subject claim"), } }, "fail/invalid-projectID": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) p.projectValidator = &gcp.ProjectValidator{ProjectIDs: []string{"foo", "bar"}} tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; invalid gcp token - invalid project id"), } }, "fail/instance-age": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) p.InstanceAge = Duration{1 * time.Minute} tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now().Add(-1*time.Minute), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; token google.compute_engine.instance_creation_timestamp is too old"), } }, "fail/empty-instance-id": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "", "instance-name", "project-id", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_id cannot be empty"), } }, "fail/empty-instance-name": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "", "project-id", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: 
errors.New("gcp.authorizeToken; gcp token google.compute_engine.instance_name cannot be empty"), } }, "fail/empty-project-id": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.project_id cannot be empty"), } }, "fail/empty-zone": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "project-id", "", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("gcp.authorizeToken; gcp token google.compute_engine.zone cannot be empty"), } }, "ok": func(t *testing.T) test { p, err := generateGCP() assert.FatalError(t, err) tok, err := generateGCPToken(p.ServiceAccounts[0], "https://accounts.google.com", p.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ p: p, token: tok, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) && assert.NotNil(t, claims) { assert.Equals(t, claims.Subject, tc.p.ServiceAccounts[0]) assert.Equals(t, claims.Issuer, "https://accounts.google.com") assert.NotNil(t, claims.Google) aud, 
err := generateSignAudience("https://ca.smallstep.com", tc.p.GetID()) assert.FatalError(t, err) assert.Equals(t, claims.Audience[0], aud) } } }) } } func TestGCP_AuthorizeSign(t *testing.T) { p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() assert.FatalError(t, err) p2.DisableCustomSANs = true p3, err := generateGCP() assert.FatalError(t, err) p3.projectValidator = &gcp.ProjectValidator{ProjectIDs: []string{"other-project-id"}} p3.ServiceAccounts = []string{"foo@developer.gserviceaccount.com"} p3.InstanceAge = Duration{1 * time.Minute} aKey, err := generateJSONWebKey() assert.FatalError(t, err) t1, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) t2, err := generateGCPToken(p2.ServiceAccounts[0], "https://accounts.google.com", p2.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p2.keyStore.keySet.Keys[0]) assert.FatalError(t, err) t3, err := generateGCPToken(p3.ServiceAccounts[0], "https://accounts.google.com", p3.GetID(), "instance-id", "instance-name", "other-project-id", "zone", time.Now(), &p3.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failKey, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), aKey) assert.FatalError(t, err) failIss, err := generateGCPToken(p1.ServiceAccounts[0], "https://foo.bar.zar", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failAud, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", "gcp:foo", "instance-id", "instance-name", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failExp, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", 
p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failNbf, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failServiceAccount, err := generateGCPToken("foo", "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failInvalidProjectID, err := generateGCPToken(p3.ServiceAccounts[0], "https://accounts.google.com", p3.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p3.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failInvalidInstanceAge, err := generateGCPToken(p3.ServiceAccounts[0], "https://accounts.google.com", p3.GetID(), "instance-id", "instance-name", "other-project-id", "zone", time.Now().Add(-1*time.Minute), &p3.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failInstanceID, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "", "instance-name", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failInstanceName, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "", "project-id", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failProjectID, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "", "zone", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failZone, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) type args struct { token 
string } tests := []struct { name string gcp *GCP args args wantLen int code int wantErr bool }{ {"ok", p1, args{t1}, 8, http.StatusOK, false}, {"ok", p2, args{t2}, 13, http.StatusOK, false}, {"ok", p3, args{t3}, 8, http.StatusOK, false}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, {"fail aud", p1, args{failAud}, 0, http.StatusUnauthorized, true}, {"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true}, {"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true}, {"fail service account", p1, args{failServiceAccount}, 0, http.StatusUnauthorized, true}, {"fail invalid project id", p3, args{failInvalidProjectID}, 0, http.StatusUnauthorized, true}, {"fail invalid instance age", p3, args{failInvalidInstanceAge}, 0, http.StatusUnauthorized, true}, {"fail instance id", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true}, {"fail instance name", p1, args{failInstanceName}, 0, http.StatusUnauthorized, true}, {"fail project id", p1, args{failProjectID}, 0, http.StatusUnauthorized, true}, {"fail zone", p1, args{failZone}, 0, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); { case (err != nil) != tt.wantErr: t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Equals(t, tt.wantLen, len(got)) for _, o := range got { switch v := o.(type) { case *GCP: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeGCP) assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, 
v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Len(t, 4, v.KeyValuePairs) case profileDefaultDuration: assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration()) case commonNameSliceValidator: assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case defaultPublicKeyValidator: case *validityValidator: assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: assert.Equals(t, v, nil) case *urisValidator: assert.Equals(t, v.uris, nil) assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesSubsetValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: assert.Len(t, 0, v.webhooks) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } }) } } func TestGCP_AuthorizeSSHSign(t *testing.T) { tm, fn := mockNow() defer fn() p1, err := generateGCP() assert.FatalError(t, err) p1.DisableCustomSANs = true // enable ssh user CA disableSSCAUser := false p1.DisableSSHCAUser = &disableSSCAUser p2, err := generateGCP() assert.FatalError(t, err) p2.DisableCustomSANs = false p3, err := generateGCP() assert.FatalError(t, err) // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) p4, err := generateGCP() assert.FatalError(t, err) // disable ssh host CA disableSSCAHost := true p4.DisableSSHCAHost = &disableSSCAHost t1, err := generateGCPToken(p1.ServiceAccounts[0], "https://accounts.google.com", p1.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), 
&p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) t2, err := generateGCPToken(p2.ServiceAccounts[0], "https://accounts.google.com", p2.GetID(), "instance-id", "instance-name", "project-id", "zone", time.Now(), &p2.keyStore.keySet.Keys[0]) assert.FatalError(t, err) key, err := generateJSONWebKey() assert.FatalError(t, err) signer, err := generateJSONWebKey() assert.FatalError(t, err) pub := key.Public().Key rsa2048, err := rsa.GenerateKey(rand.Reader, 2048) assert.FatalError(t, err) //nolint:gosec // tests minimum size of the key rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedHostOptionsPrincipal1 := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedHostOptionsPrincipal2 := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.zone.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedCustomOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"foo.bar", "bar.foo"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{FormatServiceAccountUsername(p1.ServiceAccounts[0]), "foo@developer.gserviceaccount.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(p1.ctl.Claimer.DefaultUserSSHCertDuration())), } type args struct { token string sshOpts SignSSHOptions key interface{} } tests := []struct { name string gcp *GCP args args expected *SignSSHOptions code 
int wantErr bool wantSignErr bool }{ {"ok", p1, args{t1, SignSSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-rsa2048", p1, args{t1, SignSSHOptions{}, rsa2048.Public()}, expectedHostOptions, http.StatusOK, false, false}, {"ok-type-host", p1, args{t1, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-type-user", p1, args{t1, SignSSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"ok-principals", p1, args{t1, SignSSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-principal1", p1, args{t1, SignSSHOptions{Principals: []string{"instance-name.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal1, http.StatusOK, false, false}, {"ok-principal2", p1, args{t1, SignSSHOptions{Principals: []string{"instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptionsPrincipal2, http.StatusOK, false, false}, {"ok-options", p1, args{t1, SignSSHOptions{CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-custom", p2, args{t2, SignSSHOptions{Principals: []string{"foo.bar", "bar.foo"}}, pub}, expectedCustomOptions, http.StatusOK, false, false}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedHostOptions, http.StatusOK, false, true}, {"fail-principal", p1, args{t1, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, {"fail-extra-principal", p1, args{t1, SignSSHOptions{Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal", "smallstep.com"}}, pub}, nil, http.StatusOK, false, true}, {"fail-sshCA-disabled", p3, args{"foo", SignSSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, 
false}, {"fail-type-host", p4, args{"foo", SignSSHOptions{CertType: "host"}, pub}, nil, http.StatusUnauthorized, true, false}, {"fail-type-user", p4, args{"foo", SignSSHOptions{CertType: "host"}, pub}, nil, http.StatusUnauthorized, true, false}, {"fail-invalid-token", p1, args{"foo", SignSSHOptions{}, pub}, expectedHostOptions, http.StatusUnauthorized, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() if tt.args.sshOpts.CertType == SSHUserCert { ctx = NewContextWithCertType(ctx, SSHUserCert) } got, err := tt.gcp.AuthorizeSSHSign(ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { assert.Nil(t, cert) } else { assert.NoError(t, validateSSHCertificate(cert, tt.expected)) } } } }) } } func TestGCP_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() assert.FatalError(t, err) // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { cert *x509.Certificate } tests := []struct { name string prov *GCP args args code int wantErr bool }{ {"ok", p1, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusOK, false}, {"fail/renewal-disabled", p2, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, 
http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } } ================================================ FILE: authority/provisioner/jwk.go ================================================ package provisioner import ( "context" "crypto/x509" "net/http" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" ) // jwtPayload extends jwt.Claims with step attributes. type jwtPayload struct { jose.Claims SANs []string `json:"sans,omitempty"` Step *stepPayload `json:"step,omitempty"` Confirmation *cnfPayload `json:"cnf,omitempty"` } type stepPayload struct { SSH *SignSSHOptions `json:"ssh,omitempty"` RA *RAInfo `json:"ra,omitempty"` } type cnfPayload struct { Fingerprint string `json:"x5rt#S256,omitempty"` } // JWK is the default provisioner, an entity that can sign tokens necessary for // signature requests. type JWK struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Key *jose.JSONWebKey `json:"key"` EncryptedKey string `json:"encryptedKey,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any JWK provisioner. 
func (p *JWK) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *JWK) GetIDForToken() string { return p.Name + ":" + p.Key.KeyID } // GetTokenID returns the identifier of the token. func (p *JWK) GetTokenID(ott string) (string, error) { // Validate payload token, err := jose.ParseSigned(ott) if err != nil { return "", errors.Wrap(err, "error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims jose.Claims if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } return claims.ID, nil } // GetName returns the name of the provisioner. func (p *JWK) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *JWK) GetType() Type { return TypeJWK } // GetEncryptedKey returns the base provisioner encrypted key if it's defined. func (p *JWK) GetEncryptedKey() (string, string, bool) { return p.Key.KeyID, p.EncryptedKey, p.EncryptedKey != "" } // Init initializes and validates the fields of a JWK type. func (p *JWK) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") case p.Key == nil: return errors.New("provisioner key cannot be empty") } p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // authorizeToken performs common jwt authorization actions and returns the // claims for case specific downstream parsing. // e.g. a Sign request will auth/validate different fields than a Revoke request. 
func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk token") } var claims jwtPayload if err = jwt.Claims(p.Key, &claims); err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "jwk.authorizeToken; error parsing jwk claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. if err = claims.ValidateWithLeeway(jose.Expected{ Issuer: p.Name, Time: time.Now().UTC(), }, time.Minute); err != nil { return nil, errs.Wrapf(http.StatusUnauthorized, err, "jwk.authorizeToken; invalid jwk claims") } // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { return nil, errs.Unauthorized("jwk.authorizeToken; invalid jwk token audience claim (aud); want %s, but got %s", audiences, claims.Audience) } if claims.Subject == "" { return nil, errs.Unauthorized("jwk.authorizeToken; jwk token subject cannot be empty") } return &claims, nil } // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(_ context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } // NOTE: This is for backwards compatibility with older versions of cli // and certificates. 
Older versions added the token subject as the only SAN // in a CSR by default. if len(claims.SANs) == 0 { claims.SANs = []string{claims.Subject} } // Certificate templates data := x509util.CreateTemplateData(claims.Subject, claims.SANs) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } templateOptions, err := TemplateOptions(p.Options, data) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } // Wrap provisioner if the token is an RA token. var self Interface = p if claims.Step != nil && claims.Step.RA != nil { self = &raProvisioner{ Interface: p, raInfo: claims.Step.RA, } } // Check the fingerprint of the certificate request if given. var fingerprint string if claims.Confirmation != nil { fingerprint = claims.Confirmation.Fingerprint } return []SignOption{ self, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID).WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators csrFingerprintValidator(fingerprint), commonNameSliceValidator(append([]string{claims.Subject}, claims.SANs...)), defaultPublicKeyValidator{}, newDefaultSANsValidator(ctx, claims.SANs), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController(data, linkedca.Webhook_X509), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. // NOTE: This method does not actually validate the certificate or check it's // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. 
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew) return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } if claims.Step == nil || claims.Step.SSH == nil { return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; jwk token must be an SSH provisioning token") } opts := claims.Step.SSH signOptions := []SignOption{ // validates user's SignSSHOptions with the ones in the token sshCertOptionsValidator(*opts), // validate users's KeyID is the token subject. sshCertOptionsValidator(SignSSHOptions{KeyID: claims.Subject}), } // Default template attributes. certType := sshutil.UserCert keyID := claims.Subject principals := []string{claims.Subject} // Use options in the token. if opts.CertType != "" { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { return nil, errs.BadRequestErr(err, "%s", err.Error()) } } if opts.KeyID != "" { keyID = opts.KeyID } if len(opts.Principals) > 0 { principals = opts.Principals } // Certificate templates. 
data := sshutil.CreateTemplateData(certType, keyID, principals) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } templateOptions, err := TemplateSSHOptions(p.Options, data) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } signOptions = append(signOptions, templateOptions) // Add modifiers from custom claims t := now() if !opts.ValidAfter.IsZero() { signOptions = append(signOptions, sshCertValidAfterModifier(cast.Uint64(opts.ValidAfter.RelativeTime(t).Unix()))) } if !opts.ValidBefore.IsZero() { signOptions = append(signOptions, sshCertValidBeforeModifier(cast.Uint64(opts.ValidBefore.RelativeTime(t).Unix()))) } return append(signOptions, p, // Set the validity bounds if not set. &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, // Ensure that all principal names are allowed newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), // Call webhooks p.ctl.newWebhookController(data, linkedca.Webhook_SSH), ), nil } // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. 
func (p *JWK) AuthorizeSSHRevoke(_ context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) // TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } ================================================ FILE: authority/provisioner/jwk_test.go ================================================ package provisioner import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "errors" "fmt" "net/http" "strings" "testing" "time" "go.step.sm/crypto/fingerprint" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" ) func TestJWK_Getters(t *testing.T) { p, err := generateJWK() assert.FatalError(t, err) if got := p.GetID(); got != p.Name+":"+p.Key.KeyID { t.Errorf("JWK.GetID() = %v, want %v:%v", got, p.Name, p.Key.KeyID) } if got := p.GetName(); got != p.Name { t.Errorf("JWK.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeJWK { t.Errorf("JWK.GetType() = %v, want %v", got, TypeJWK) } kid, key, ok := p.GetEncryptedKey() if kid != p.Key.KeyID || key != p.EncryptedKey || ok == false { t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, p.Key.KeyID, p.EncryptedKey, true) } p.EncryptedKey = "" kid, key, ok = p.GetEncryptedKey() if kid != p.Key.KeyID || key != "" || ok == true { t.Errorf("JWK.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, p.Key.KeyID, "", false) } } func TestJWK_Init(t *testing.T) { type ProvisionerValidateTest struct { p *JWK err error } tests := map[string]func(*testing.T) ProvisionerValidateTest{ "fail-empty": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &JWK{}, err: errors.New("provisioner type cannot be empty"), } }, "fail-empty-name": func(t *testing.T) ProvisionerValidateTest 
{ return ProvisionerValidateTest{ p: &JWK{ Type: "JWK", }, err: errors.New("provisioner name cannot be empty"), } }, "fail-empty-type": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &JWK{Name: "foo"}, err: errors.New("provisioner type cannot be empty"), } }, "fail-empty-key": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &JWK{Name: "foo", Type: "bar"}, err: errors.New("provisioner key cannot be empty"), } }, "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, } }, } config := Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) err := tc.p.Init(config) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) } } else { assert.Nil(t, tc.err) } }) } } func TestJWK_authorizeToken(t *testing.T) { p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() assert.FatalError(t, err) key1, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) key2, err := decryptJSONWebKey(p2.EncryptedKey) assert.FatalError(t, err) t1, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], key1) assert.FatalError(t, err) t2, err := generateSimpleToken(p2.Name, testAudiences.Sign[1], key2) assert.FatalError(t, err) t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences.Sign[0], "", []string{}, time.Now(), key1) assert.FatalError(t, err) // Invalid tokens parts := strings.Split(t1, ".") key3, err := generateJSONWebKey() assert.FatalError(t, err) // missing key failKey, err := 
generateSimpleToken(p1.Name, testAudiences.Sign[0], key3) assert.FatalError(t, err) // invalid token failTok := "foo." + parts[1] + "." + parts[2] // invalid claims failClaims := parts[0] + ".foo." + parts[1] // invalid issuer failIss, err := generateSimpleToken("foobar", testAudiences.Sign[0], key1) assert.FatalError(t, err) // invalid audience failAud, err := generateSimpleToken(p1.Name, "foobar", key1) assert.FatalError(t, err) // invalid signature failSig := t1[0 : len(t1)-2] // no subject failSub, err := generateToken("", p1.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), key1) assert.FatalError(t, err) // expired failExp, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1) assert.FatalError(t, err) // not before failNbf, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1) assert.FatalError(t, err) // Remove encrypted key for p2 p2.EncryptedKey = "" type args struct { token string } tests := []struct { name string prov *JWK args args code int err error }{ {"fail-token", p1, args{failTok}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk token")}, {"fail-key", p1, args{failKey}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")}, {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims")}, {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; error parsing jwk claims: go-jose/go-jose: error in cryptographic primitive")}, {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: go-jose/go-jose/jwt: validation failed, invalid issuer claim (iss)")}, {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid 
jwk claims: go-jose/go-jose/jwt: validation failed, token is expired (exp)")}, {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk claims: go-jose/go-jose/jwt: validation failed, token not valid yet (nbf)")}, {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; invalid jwk token audience claim (aud)")}, {"fail-subject", p1, args{failSub}, http.StatusUnauthorized, errors.New("jwk.authorizeToken; jwk token subject cannot be empty")}, {"ok", p1, args{t1}, http.StatusOK, nil}, {"ok-no-encrypted-key", p2, args{t2}, http.StatusOK, nil}, {"ok-no-sans", p1, args{t3}, http.StatusOK, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tt.err) { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) assert.NotNil(t, got) } }) } } func TestJWK_AuthorizeRevoke(t *testing.T) { p1, err := generateJWK() assert.FatalError(t, err) key1, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) t1, err := generateSimpleToken(p1.Name, testAudiences.Revoke[0], key1) assert.FatalError(t, err) // invalid signature failSig := t1[0 : len(t1)-2] type args struct { token string } tests := []struct { name string prov *JWK args args code int err error }{ {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, errors.New("jwk.AuthorizeRevoke: jwk.authorizeToken; error parsing jwk claims: go-jose/go-jose: error in cryptographic primitive")}, {"ok", p1, args{t1}, http.StatusOK, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if assert.NotNil(t, tt.err) { var sc 
render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } }) } } func TestJWK_AuthorizeSign(t *testing.T) { p1, err := generateJWK() assert.FatalError(t, err) key1, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) t1, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{"127.0.0.1", "max@smallstep.com", "foo"}, time.Now(), key1) assert.FatalError(t, err) t2, err := generateToken("subject", p1.Name, testAudiences.Sign[0], "name@smallstep.com", []string{}, time.Now(), key1) assert.FatalError(t, err) t3, err := generateCustomToken("subject", p1.Name, testAudiences.Sign[0], key1, nil, map[string]any{"cnf": map[string]any{"x5rt#S256": "fingerprint"}}) assert.FatalError(t, err) // invalid signature failSig := t1[0 : len(t1)-2] type args struct { token string } tests := []struct { name string prov *JWK args args code int err error sans []string fingerprint string }{ { name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: go-jose/go-jose: error in cryptographic primitive"), }, { name: "ok-sans", prov: p1, args: args{t1}, code: http.StatusOK, err: nil, sans: []string{"127.0.0.1", "max@smallstep.com", "foo"}, }, { name: "ok-no-sans", prov: p1, args: args{t2}, code: http.StatusOK, err: nil, sans: []string{"subject"}, }, { name: "ok-cnf", prov: p1, args: args{t3}, code: http.StatusOK, err: nil, sans: []string{"subject"}, fingerprint: "fingerprint", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if assert.NotNil(t, tt.err) { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does 
not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { if assert.NotNil(t, got) { assert.Equals(t, 11, len(got)) for _, o := range got { switch v := o.(type) { case *JWK: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeJWK) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case commonNameSliceValidator: assert.Equals(t, []string(v), append([]string{"subject"}, tt.sans...)) case defaultPublicKeyValidator: case *validityValidator: assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case *defaultSANsValidator: assert.Equals(t, v.sans, tt.sans) assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: case csrFingerprintValidator: assert.Equals(t, tt.fingerprint, string(v)) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } } }) } } func TestJWK_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() assert.FatalError(t, err) // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { cert *x509.Certificate } tests := []struct { name string prov *JWK args args code int wantErr bool }{ {"ok", p1, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusOK, false}, {"fail/renew-disabled", p2, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusUnauthorized, true}, } 
for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } } func TestJWK_AuthorizeSSHSign(t *testing.T) { tm, fn := mockNow() defer fn() p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() assert.FatalError(t, err) // disable sshCA disable := false p2.Claims = &Claims{EnableSSHCA: &disable} p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) key, err := generateJSONWebKey() assert.FatalError(t, err) signer, err := generateJSONWebKey() assert.FatalError(t, err) pub := key.Public().Key rsa2048, err := rsa.GenerateKey(rand.Reader, 2048) assert.FatalError(t, err) //nolint:gosec // tests minimum size of the key rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) // Calculate fingerprint sshPub, err := ssh.NewPublicKey(pub) assert.FatalError(t, err) fp, err := fingerprint.New(sshPub.Marshal(), crypto.SHA256, fingerprint.Base64RawURLFingerprint) assert.FatalError(t, err) iss, aud := p1.Name, testAudiences.SSHSign[0] t1, err := generateSimpleSSHUserToken(iss, aud, jwk) assert.FatalError(t, err) t2, err := generateSimpleSSHHostToken(iss, aud, jwk) assert.FatalError(t, err) t3, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{ "step": map[string]any{ "ssh": map[string]any{"certType": "host", "principals": []string{"smallstep.com"}}, }, "cnf": map[string]any{"kid": fp}, }) assert.FatalError(t, err) t4, err := generateCustomToken("sub", iss, aud, jwk, nil, map[string]any{ "step": map[string]any{ "ssh": map[string]any{"certType": 
"host", "principals": []string{"smallstep.com"}}, }, "cnf": map[string]any{"kid": "bad-fingerprint"}, }) assert.FatalError(t, err) // invalid signature failSig := t1[0 : len(t1)-2] userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), } expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } type args struct { token string sshOpts SignSSHOptions key interface{} } tests := []struct { name string prov *JWK args args expected *SignSSHOptions code int wantErr bool wantSignErr bool }{ {"user", p1, args{t1, SignSSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"user-rsa2048", p1, args{t1, SignSSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false}, {"user-type", p1, args{t1, SignSSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"user-principals", p1, args{t1, SignSSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"user-options", p1, args{t1, SignSSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"host", p1, args{t2, SignSSHOptions{}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-type", p1, args{t2, SignSSHOptions{CertType: "host"}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-principals", p1, args{t2, SignSSHOptions{Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"host-options", p1, args{t2, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, 
{"host-cnf", p1, args{t3, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ignore-bad-cnf", p1, args{t4, SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"fail-sshCA-disabled", p2, args{"foo", SignSSHOptions{}, pub}, expectedUserOptions, http.StatusUnauthorized, true, false}, {"fail-signature", p1, args{failSig, SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { assert.Nil(t, cert) } else { assert.NoError(t, validateSSHCertificate(cert, tt.expected)) } } } }) } } func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { tm, fn := mockNow() defer fn() p1, err := generateJWK() assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p1.EncryptedKey) assert.FatalError(t, err) sub, iss, aud, iat := "subject@smallstep.com", p1.Name, testAudiences.SSHSign[0], time.Now() key, err := generateJSONWebKey() assert.FatalError(t, err) signer, err := generateJSONWebKey() assert.FatalError(t, err) userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() hostDuration := 
p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), } expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } type args struct { sub, iss, aud string iat time.Time tokSSHOpts *SignSSHOptions userSSHOpts *SignSSHOptions jwk *jose.JSONWebKey } tests := []struct { name string prov *JWK args args expected *SignSSHOptions wantErr bool wantSignErr bool }{ {"ok-user", p1, args{sub, iss, aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{}, jwk}, expectedUserOptions, false, false}, {"ok-host", p1, args{sub, iss, aud, iat, &SignSSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, &SignSSHOptions{}, jwk}, expectedHostOptions, false, false}, {"ok-user-validAfter", p1, args{sub, iss, aud, iat, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, }, &SignSSHOptions{ ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), }, jwk}, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(-time.Hour)), ValidBefore: NewTimeDuration(tm.Add(userDuration - time.Hour)), }, false, false}, {"ok-user-validBefore", p1, args{sub, iss, aud, iat, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, }, &SignSSHOptions{ ValidBefore: NewTimeDuration(tm.Add(time.Hour)), }, jwk}, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), }, false, false}, {"ok-user-validAfter-validBefore", p1, args{sub, iss, aud, iat, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, }, &SignSSHOptions{ ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), }, 
jwk}, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm.Add(10 * time.Minute)), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), }, false, false}, {"ok-user-match", p1, args{sub, iss, aud, iat, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)), }, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(1 * time.Hour)), }, jwk}, &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(time.Hour)), }, false, false}, {"fail-certType", p1, args{sub, iss, aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{CertType: "host"}, jwk}, nil, false, true}, {"fail-principals", p1, args{sub, iss, aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{Principals: []string{"root"}}, jwk}, nil, false, true}, {"fail-validAfter", p1, args{sub, iss, aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm)}, &SignSSHOptions{ValidAfter: NewTimeDuration(tm.Add(time.Hour))}, jwk}, nil, false, true}, {"fail-validBefore", p1, args{sub, iss, aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}, ValidBefore: NewTimeDuration(tm.Add(time.Hour))}, &SignSSHOptions{ValidBefore: NewTimeDuration(tm.Add(10 * time.Hour))}, jwk}, nil, false, true}, {"fail-subject", p1, args{"", iss, aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{}, jwk}, nil, true, false}, {"fail-issuer", p1, args{sub, "invalid", aud, iat, &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{}, jwk}, nil, true, false}, {"fail-audience", p1, args{sub, iss, "invalid", iat, &SignSSHOptions{CertType: "user", Principals: 
[]string{"name"}}, &SignSSHOptions{}, jwk}, nil, true, false}, {"fail-expired", p1, args{sub, iss, aud, iat.Add(-6 * time.Minute), &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{}, jwk}, nil, true, false}, {"fail-notBefore", p1, args{sub, iss, aud, iat.Add(5 * time.Minute), &SignSSHOptions{CertType: "user", Principals: []string{"name"}}, &SignSSHOptions{}, jwk}, nil, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, err := generateSSHToken(tt.args.sub, tt.args.iss, tt.args.aud, tt.args.iat, tt.args.tokSSHOpts, tt.args.jwk) assert.FatalError(t, err) if got, err := tt.prov.AuthorizeSSHSign(context.Background(), token); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) } else if !tt.wantErr && assert.NotNil(t, got) { var opts SignSSHOptions if tt.args.userSSHOpts != nil { opts = *tt.args.userSSHOpts } cert, err := signSSHCertificate(key.Public().Key, opts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { assert.Nil(t, cert) } else { assert.NoError(t, validateSSHCertificate(cert, tt.expected)) } } } }) } } func TestJWK_AuthorizeSSHRevoke(t *testing.T) { type test struct { p *JWK token string code int err error } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { p, err := generateJWK() assert.FatalError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSSHRevoke: jwk.authorizeToken; error parsing jwk token"), } }, "ok": func(t *testing.T) test { p, err := generateJWK() assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p.EncryptedKey) assert.FatalError(t, err) tok, err := generateToken("subject", p.Name, testAudiences.SSHRevoke[0], "name@smallstep.com", []string{"127.0.0.1", "max@smallstep.com", "foo"}, time.Now(), jwk) 
assert.FatalError(t, err) return test{ p: p, token: tok, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) } }) } } ================================================ FILE: authority/provisioner/k8sSA.go ================================================ package provisioner import ( "context" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/x509" "encoding/pem" "net/http" "github.com/pkg/errors" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/sshutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/errs" ) // NOTE: There can be at most one kubernetes service account provisioner configured // per instance of step-ca. This is due to a lack of distinguishing information // contained in kubernetes service account tokens. const ( // K8sSAName is the default name used for kubernetes service account provisioners. K8sSAName = "k8sSA-default" // K8sSAID is the default ID for kubernetes service account provisioners. K8sSAID = "k8ssa/" + K8sSAName k8sSAIssuer = "kubernetes/serviceaccount" ) // jwtPayload extends jwt.Claims with step attributes. 
type k8sSAPayload struct { jose.Claims Namespace string `json:"kubernetes.io/serviceaccount/namespace,omitempty"` SecretName string `json:"kubernetes.io/serviceaccount/secret.name,omitempty"` ServiceAccountName string `json:"kubernetes.io/serviceaccount/service-account.name,omitempty"` ServiceAccountUID string `json:"kubernetes.io/serviceaccount/service-account.uid,omitempty"` } // K8sSA represents a Kubernetes ServiceAccount provisioner; an // entity trusted to make signature requests. type K8sSA struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` PubKeys []byte `json:"publicKeys,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` //kauthn kauthn.AuthenticationV1Interface pubKeys []interface{} ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any K8sSA provisioner. func (p *K8sSA) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *K8sSA) GetIDForToken() string { return K8sSAID } // GetTokenID returns an unimplemented error and does not use the input ott. func (p *K8sSA) GetTokenID(string) (string, error) { return "", ErrNotImplemented } // GetName returns the name of the provisioner. func (p *K8sSA) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *K8sSA) GetType() Type { return TypeK8sSA } // GetEncryptedKey returns false, because the kubernetes provisioner does not // have access to the private key. func (p *K8sSA) GetEncryptedKey() (string, string, bool) { return "", "", false } // Init initializes and validates the fields of a K8sSA type. 
func (p *K8sSA) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") } if p.PubKeys != nil { var ( block *pem.Block rest = p.PubKeys ) for rest != nil { block, rest = pem.Decode(rest) if block == nil { break } key, err := pemutil.ParseKey(pem.EncodeToMemory(block)) if err != nil { return errors.Wrapf(err, "error parsing public key in provisioner '%s'", p.GetName()) } switch q := key.(type) { case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: default: return errors.Errorf("Unexpected public key type %T in provisioner '%s'", q, p.GetName()) } p.pubKeys = append(p.pubKeys, key) } } else { // TODO: Use the TokenReview API if no pub keys provided. This will need to // be configured with additional attributes in the K8sSA struct for // connecting to the kubernetes API server. return errors.New("K8s Service Account provisioner cannot be initialized without pub keys") } /* // NOTE: Not sure if we should be doing this initialization here ... // If you have a k8sSA provisioner defined in your config, but you're not // in a kubernetes pod then your CA will fail to startup. Maybe we just postpone // creating the authn until token validation time? if err := checkAccess(k8s.AuthorizationV1()); err != nil { return errors.Wrapf(err, "error verifying access to kubernetes authz service for provisioner %s", p.GetID()) } p.kauthn = k8s.AuthenticationV1() */ p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // authorizeToken performs common jwt authorization actions and returns the // claims for case specific downstream parsing. // e.g. a Sign request will auth/validate different fields than a Revoke request. 
func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) { _ = audiences // unused input jwt, err := jose.ParseSigned(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "k8ssa.authorizeToken; error parsing k8sSA token") } var ( valid bool claims k8sSAPayload ) if p.pubKeys == nil { return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented") /* NOTE: We plan to support the TokenReview API in a future release. Below is some code that should be useful when we prioritize this integration. tr := kauthnApi.TokenReview{Spec: kauthnApi.TokenReviewSpec{Token: string(token)}} rvw, err := p.kauthn.TokenReviews().Create(&tr) if err != nil { return nil, errors.Wrap(err, "error using kubernetes TokenReview API") } if rvw.Status.Error != "" { return nil, errors.Errorf("error from kubernetes TokenReviewAPI: %s", rvw.Status.Error) } if !rvw.Status.Authenticated { return nil, errors.New("error from kubernetes TokenReviewAPI: token could not be authenticated") } if err = jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, errors.Wrap(err, "error parsing claims") } */ } for _, pk := range p.pubKeys { if err = jwt.Claims(pk, &claims); err == nil { valid = true break } } if !valid { return nil, errs.Unauthorized("k8ssa.authorizeToken; error validating k8sSA token and extracting claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. if err = claims.Validate(jose.Expected{ Issuer: k8sSAIssuer, }); err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "k8ssa.authorizeToken; invalid k8sSA token claims") } if claims.Subject == "" { return nil, errs.Unauthorized("k8ssa.authorizeToken; k8sSA token subject cannot be empty") } return &claims, nil } // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. 
func (p *K8sSA) AuthorizeRevoke(_ context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *K8sSA) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } // Add some values to use in custom templates. data := x509util.NewTemplateData() data.SetCommonName(claims.ServiceAccountName) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // Certificate templates: on K8sSA the default template is the certificate // request. templateOptions, err := CustomTemplateOptions(p.Options, data, x509util.DefaultAdminLeafTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } return []SignOption{ p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeK8sSA, p.Name, "").WithControllerOptions(p.ctl), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController(data, linkedca.Webhook_X509), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign validates an request for an SSH certificate. 
func (p *K8sSA) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) {
	// SSH certificates are only issued when the SSH CA claim is enabled.
	if !p.ctl.Claimer.IsSSHCAEnabled() {
		return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
	}
	claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
	}

	// Certificate templates.
	// Set some default variables to be used in the templates.
	data := sshutil.CreateTemplateData(sshutil.HostCert, claims.ServiceAccountName, []string{claims.ServiceAccountName})
	// The token payload is attached on a best-effort basis; a parse failure
	// here is ignored.
	if v, err := unsafeParseSigned(token); err == nil {
		data.SetToken(v)
	}

	templateOptions, err := CustomSSHTemplateOptions(p.Options, data, sshutil.CertificateRequestTemplate)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
	}
	signOptions := []SignOption{templateOptions}

	return append(signOptions,
		p,
		// Require type, key-id and principals in the SignSSHOptions.
		&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
		// Set the validity bounds if not set.
		&sshDefaultDuration{p.ctl.Claimer},
		// Validate public key
		&sshDefaultPublicKeyValidator{},
		// Validate the validity period.
		&sshCertValidityValidator{p.ctl.Claimer},
		// Require and validate all the default fields in the SSH certificate.
		&sshCertDefaultValidator{},
		// Ensure that all principal names are allowed
		newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
		// Call webhooks
		p.ctl.newWebhookController(data, linkedca.Webhook_SSH),
	), nil
}

/*
func checkAccess(authz kauthz.AuthorizationV1Interface) error {
	r := &kauthzApi.SelfSubjectAccessReview{
		Spec: kauthzApi.SelfSubjectAccessReviewSpec{
			ResourceAttributes: &kauthzApi.ResourceAttributes{
				Group: "authentication.k8s.io",
				Version: "v1",
				Resource: "tokenreviews",
				Verb: "create",
			},
		},
	}
	rvw, err := authz.SelfSubjectAccessReviews().Create(r)
	if err != nil {
		return err
	}
	if !rvw.Status.Allowed {
		return fmt.Errorf("Unable to create kubernetes token reviews: %s", rvw.Status.Reason)
	}

	return nil
}
*/

================================================
FILE: authority/provisioner/k8sSA_test.go
================================================
package provisioner

import (
	"context"
	"crypto/x509"
	"errors"
	"fmt"
	"net/http"
	"testing"
	"time"

	"go.step.sm/crypto/jose"

	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/api/render"
)

// TestK8sSA_Getters checks GetID, GetName, GetType and GetEncryptedKey on a
// generated provisioner.
func TestK8sSA_Getters(t *testing.T) {
	p, err := generateK8sSA(nil)
	assert.FatalError(t, err)
	id := "k8ssa/" + p.Name
	if got := p.GetID(); got != id {
		t.Errorf("K8sSA.GetID() = %v, want %v", got, id)
	}
	if got := p.GetName(); got != p.Name {
		t.Errorf("K8sSA.GetName() = %v, want %v", got, p.Name)
	}
	if got := p.GetType(); got != TypeK8sSA {
		t.Errorf("K8sSA.GetType() = %v, want %v", got, TypeK8sSA)
	}
	kid, key, ok := p.GetEncryptedKey()
	if kid != "" || key != "" || ok == true {
		t.Errorf("K8sSA.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
			kid, key, ok, "", "", false)
	}
}

// TestK8sSA_authorizeToken covers the failure paths of authorizeToken
// (unparsable token, no public keys, signature mismatch, wrong issuer) and the
// success path. Failing cases assert the status code and the error prefix.
func TestK8sSA_authorizeToken(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"fail/not-implemented": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk)
			// Dropping the keys forces the TokenReview (not implemented) branch.
			p.pubKeys = nil
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				err:   errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/error-validating-token": func(t *testing.T) test {
			// The token is signed with a freshly generated key while the
			// provisioner is built without it (generateK8sSA(nil)).
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				err:   errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/invalid-issuer": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			claims := getK8sSAPayload()
			claims.Claims.Issuer = "invalid"
			tok, err := generateK8sSAToken(jwk, claims)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.authorizeToken; invalid k8sSA token claims: go-jose/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.NotNil(t, claims)
				}
			}
		})
	}
}

// TestK8sSA_AuthorizeRevoke checks that AuthorizeRevoke rejects an unparsable
// token with StatusUnauthorized and accepts a token signed by a known key.
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		"fail/invalid-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.AuthorizeRevoke: k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
				var sc render.StatusCodedError
				assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
				assert.Equals(t, sc.StatusCode(), tc.code)
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestK8sSA_AuthorizeRenew checks that AuthorizeRenew fails with
// StatusUnauthorized when the DisableRenewal claim is set and succeeds
// otherwise.
func TestK8sSA_AuthorizeRenew(t *testing.T) {
	now := time.Now().Truncate(time.Second)
	type test struct {
		p    *K8sSA
		cert *x509.Certificate
		err  error
		code int
	}
	tests := map[string]func(*testing.T) test{
		"fail/renew-disabled": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			// disable renewal
			disable := true
			p.Claims = &Claims{DisableRenewal: &disable}
			p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
			assert.FatalError(t, err)
			return test{
				p: p,
				cert: &x509.Certificate{
					NotBefore: now,
					NotAfter:  now.Add(time.Hour),
				},
				code: http.StatusUnauthorized,
				err:  fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
			}
		},
		"ok": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p: p,
				cert: &x509.Certificate{
					NotBefore: now,
					NotAfter:  now.Add(time.Hour),
				},
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
				var sc render.StatusCodedError
				assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
				assert.Equals(t, sc.StatusCode(), tc.code)
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestK8sSA_AuthorizeSign checks the error path of AuthorizeSign and, on
// success, the type and contents of each of the 8 returned sign options.
func TestK8sSA_AuthorizeSign(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		code  int
		err   error
	}
	tests := map[string]func(*testing.T) test{
		"fail/invalid-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.AuthorizeSign: k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					if assert.NotNil(t, opts) {
						// Any option type not listed below fails the test.
						for _, o := range opts {
							switch v := o.(type) {
							case *K8sSA:
							case certificateOptionsFunc:
							case *provisionerExtensionOption:
								assert.Equals(t, v.Type, TypeK8sSA)
								assert.Equals(t, v.Name, tc.p.GetName())
								assert.Equals(t, v.CredentialID, "")
								assert.Len(t, 0, v.KeyValuePairs)
							case profileDefaultDuration:
								assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
							case defaultPublicKeyValidator:
							case *validityValidator:
								assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
								assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
							case *x509NamePolicyValidator:
								assert.Equals(t, nil, v.policyEngine)
							case *WebhookController:
								assert.Len(t, 0, v.webhooks)
							default:
								assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
							}
						}
						assert.Equals(t, 8, len(opts))
					}
				}
			}
		})
	}
}

// TestK8sSA_AuthorizeSSHSign checks the error paths of AuthorizeSSHSign (SSH
// CA disabled, unparsable token) and, on success, the type and contents of
// each of the 9 returned sign options.
func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		code  int
		err   error
	}
	tests := map[string]func(*testing.T) test{
		"fail/sshCA-disabled": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			// disable sshCA
			disable := false
			p.Claims = &Claims{EnableSSHCA: &disable}
			p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
			}
		},
		"fail/invalid-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.AuthorizeSSHSign: k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					if assert.NotNil(t, opts) {
						assert.Len(t, 9, opts)
						// Any option type not listed below fails the test.
						for _, o := range opts {
							switch v := o.(type) {
							case Interface:
							case sshCertificateOptionsFunc:
							case *sshCertOptionsRequireValidator:
								assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
							case *sshCertValidityValidator:
								assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
							case *sshDefaultPublicKeyValidator:
							case *sshCertDefaultValidator:
							case *sshDefaultDuration:
								assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
							case *sshNamePolicyValidator:
								assert.Equals(t, nil, v.userPolicyEngine)
								assert.Equals(t, nil, v.hostPolicyEngine)
							case *WebhookController:
								assert.Len(t, 0, v.webhooks)
							default:
								assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
							}
						}
					}
				}
			}
		})
	}
}

================================================
FILE: authority/provisioner/keystore.go
================================================
package provisioner

import (
	"encoding/json"
	"math/rand"
	"regexp"
	"strconv"
	"sync"
	"time"

	"github.com/pkg/errors"
	"go.step.sm/crypto/jose"
)

const (
	// defaultCacheAge is the cache age used when the cache-control header has
	// no usable max-age directive.
	defaultCacheAge = 12 * time.Hour
	// defaultCacheJitter is the jitter used when the cache age is zero or
	// greater than one hour.
	defaultCacheJitter = 1 * time.Hour
)

// maxAgeRegex captures the number of seconds of a max-age directive.
var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)

// keyStore caches the JSON Web Key Set downloaded from uri together with the
// time after which the set is reloaded. The embedded RWMutex guards keySet,
// expiry and jitter.
type keyStore struct {
	sync.RWMutex
	client HTTPClient
	uri    string
	keySet jose.JSONWebKeySet
	expiry time.Time
	jitter time.Duration
}

// newKeyStore downloads the key set from uri using client and returns a
// keyStore whose expiry is derived from the cache age of the response. It
// returns an error if the initial download fails.
func newKeyStore(client HTTPClient, uri string) (*keyStore, error) {
	keys, age, err := getKeysFromJWKsURI(client, uri)
	if err != nil {
		return nil, err
	}
	jitter := getCacheJitter(age)
	return &keyStore{
		client: client,
		uri:    uri,
		keySet: keys,
		expiry: getExpirationTime(age, jitter),
		jitter: jitter,
	}, nil
}

// Get returns the keys of the cached set that match kid, reloading the set
// first if the expiration time has passed.
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
	ks.RLock()
	// Force reload if expiration has passed
	if time.Now().After(ks.expiry) {
		// reload takes the write lock, so the read lock is released around it.
		ks.RUnlock()
		ks.reload()
		ks.RLock()
	}
	keys = ks.keySet.Key(kid)
	ks.RUnlock()
	return
}

// reload downloads the key set again and, on success, replaces the cached
// keys, jitter and expiry under the write lock. On failure the error is
// dropped and the cached state is left untouched.
func (ks *keyStore) reload() {
	if keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri); err == nil {
		ks.Lock()
		ks.keySet = keys
		ks.jitter = getCacheJitter(age)
		ks.expiry = getExpirationTime(age, ks.jitter)
		ks.Unlock()
	}
}

// getKeysFromJWKsURI performs a GET on uri, decodes the response body as a
// JSON Web Key Set and returns it with the cache age computed from the
// cache-control response header.
func getKeysFromJWKsURI(client HTTPClient, uri string) (jose.JSONWebKeySet, time.Duration, error) {
	var keys jose.JSONWebKeySet
	resp, err := client.Get(uri)
	if err != nil {
		return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
	}
	defer resp.Body.Close()
	if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
		return keys, 0, errors.Wrapf(err, "error reading %s", uri)
	}
	return keys, getCacheAge(resp.Header.Get("cache-control")), nil
}

// getCacheAge returns the duration, in seconds, of the first max-age directive
// found in cacheControl. It returns defaultCacheAge if the header is empty,
// has no max-age directive, or the value cannot be parsed as an int64.
func getCacheAge(cacheControl string) time.Duration {
	age := defaultCacheAge
	if cacheControl != "" {
		match := maxAgeRegex.FindAllStringSubmatch(cacheControl, -1)
		if len(match) > 0 {
			// match[0] is the first occurrence: the full match and its one
			// capture group.
			if len(match[0]) == 2 {
				maxAge := match[0][1]
				maxAgeInt, err := strconv.ParseInt(maxAge, 10, 64)
				if err != nil {
					return defaultCacheAge
				}
				age = time.Duration(maxAgeInt) * time.Second
			}
		}
	}
	return age
}

// getCacheJitter returns the jitter to use for the given cache age:
// defaultCacheJitter if age is zero or greater than one hour, a third of age
// otherwise.
func getCacheJitter(age time.Duration) time.Duration {
	switch {
	case age > time.Hour:
		return defaultCacheJitter
	case age == 0:
		// Avoids a 0 jitter. The duration is not important as it will rotate
		// automatically on each Get request.
		return defaultCacheJitter
	default:
		return age / 3
	}
}

// getExpirationTime returns the current time, truncated to the second, plus
// the absolute value of age. When age is positive a random duration in
// [0, jitter) is subtracted from it first.
func getExpirationTime(age, jitter time.Duration) time.Time {
	if age > 0 {
		n := rand.Int63n(int64(jitter)) //nolint:gosec // not used for cryptographic security
		age -= time.Duration(n)
	}
	return time.Now().Truncate(time.Second).Add(abs(age))
}

// abs returns the absolute value of n.
func abs(n time.Duration) time.Duration {
	// Negating the minimum int64 value overflows and yields the same negative
	// value again; Test_abs pins that behavior.
	if n < 0 {
		return -n
	}
	return n
}

================================================
FILE: authority/provisioner/keystore_test.go
================================================
package provisioner

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"

	"github.com/smallstep/assert"
	"go.step.sm/crypto/jose"
)

// Test_newKeyStore checks that newKeyStore loads the key set served by a TLS
// server and fails on an error endpoint or with a client that does not trust
// the server.
func Test_newKeyStore(t *testing.T) {
	srv := generateTLSJWKServer(2)
	srv.Close()
	srv = httptest.NewTLSServer(srv.Config.Handler)
	defer srv.Close()
	ks, err := newKeyStore(srv.Client(), srv.URL)
	assert.FatalError(t, err)

	type args struct {
		client *http.Client
		uri    string
	}
	tests := []struct {
		name    string
		args    args
		want    jose.JSONWebKeySet
		wantErr bool
	}{
		{"ok", args{srv.Client(), srv.URL}, ks.keySet, false},
		{"fail", args{srv.Client(), srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
		{"fail client", args{http.DefaultClient, srv.URL}, jose.JSONWebKeySet{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := newKeyStore(tt.args.client, tt.args.uri)
			if (err != nil) != tt.wantErr {
				t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err == nil {
				if !reflect.DeepEqual(got.keySet, tt.want) {
					t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
				}
			}
		})
	}
}

// Test_keyStore checks that the cached keys are served by Get and that the key
// set is replaced once the cache has expired.
func Test_keyStore(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()

	ks, err := newKeyStore(srv.Client(), srv.URL+"/random")
	assert.FatalError(t, err)
	ks.RLock()
	keySet1 := ks.keySet
	ks.RUnlock()
	// Check contents
	assert.Len(t, 2, keySet1.Keys)
	assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
	assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	// Wait for rotation
	time.Sleep(5 * time.Second)

	assert.Len(t, 0, ks.Get("foobar")) // force refresh
	ks.RLock()
	keySet2 := ks.keySet
	ks.RUnlock()
	if reflect.DeepEqual(keySet1, keySet2) {
		t.Error("keyStore did not rotated")
	}

	// Check contents
	assert.Len(t, 2, keySet2.Keys)
	assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
	assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	// Check hits
	resp, err := srv.Client().Get(srv.URL + "/hits")
	assert.FatalError(t, err)
	hits := struct {
		Hits int `json:"hits"`
	}{}
	defer resp.Body.Close()
	err = json.NewDecoder(resp.Body).Decode(&hits)
	assert.FatalError(t, err)
	assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}

// Test_keyStore_noCache checks that a key set served from the /no-cache
// endpoint is reloaded on every Get.
func Test_keyStore_noCache(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()

	ks, err := newKeyStore(srv.Client(), srv.URL+"/no-cache")
	assert.FatalError(t, err)
	ks.RLock()
	keySet1 := ks.keySet
	ks.RUnlock()

	// The keys will rotate on Get.
	// So we won't be able to find the cached ones
	assert.Len(t, 2, keySet1.Keys)
	assert.Len(t, 0, ks.Get(keySet1.Keys[0].KeyID))
	assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	// Check hits
	resp, err := srv.Client().Get(srv.URL + "/hits")
	assert.FatalError(t, err)
	hits := struct {
		Hits int `json:"hits"`
	}{}
	defer resp.Body.Close()
	err = json.NewDecoder(resp.Body).Decode(&hits)
	assert.FatalError(t, err)
	assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}

// Test_keyStore_Get checks that Get returns the key matching a known key ID
// and nil for an unknown one.
func Test_keyStore_Get(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()
	ks, err := newKeyStore(srv.Client(), srv.URL)
	assert.FatalError(t, err)

	type args struct {
		kid string
	}
	tests := []struct {
		name     string
		ks       *keyStore
		args     args
		wantKeys []jose.JSONWebKey
	}{
		{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
		{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
		{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
				t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
			}
		})
	}
}

// Test_abs checks abs on zero, positive, negative and extreme durations,
// including the overflow on the minimum int64 value.
func Test_abs(t *testing.T) {
	maxInt64 := time.Duration(1<<63 - 1)
	minInt64 := time.Duration(-1 << 63)

	type args struct {
		n time.Duration
	}
	tests := []struct {
		name string
		args args
		want time.Duration
	}{
		{"ok", args{0}, 0},
		{"ok", args{-time.Hour}, time.Hour},
		{"ok", args{time.Hour}, time.Hour},
		{"ok maxInt64", args{maxInt64}, maxInt64},
		{"ok minInt64 + 1", args{minInt64 + 1}, maxInt64},
		{"overflow on minInt64", args{minInt64}, minInt64},
		{"overflow on minInt64", args{minInt64}, -minInt64},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := abs(tt.args.n); got != tt.want {
				t.Errorf("abs() = %v, want %v", got, tt.want)
			}
		})
	}
}

================================================
FILE: authority/provisioner/method.go
================================================
package provisioner

import (
	"context"
)

// Method indicates the action that we will perform; it's used as part of the
// context in the call to authorize. It defaults to SignMethod, the zero value.
type Method int

// The key to save the Method in the context.
type methodKey struct{}

const (
	// SignMethod is the method used to sign X.509 certificates.
	SignMethod Method = iota
	// SignIdentityMethod is the method used to sign X.509 identity certificates.
	SignIdentityMethod
	// RevokeMethod is the method used to revoke X.509 certificates.
	RevokeMethod
	// RenewMethod is the method used to renew X.509 certificates.
	RenewMethod
	// SSHSignMethod is the method used to sign SSH certificates.
	SSHSignMethod
	// SSHRenewMethod is the method used to renew SSH certificates.
	SSHRenewMethod
	// SSHRevokeMethod is the method used to revoke SSH certificates.
	SSHRevokeMethod
	// SSHRekeyMethod is the method used to rekey SSH certificates.
	SSHRekeyMethod
)

// String returns a string representation of the context method.
func (m Method) String() string { switch m { case SignMethod: return "sign-method" case SignIdentityMethod: return "sign-identity-method" case RevokeMethod: return "revoke-method" case RenewMethod: return "renew-method" case SSHSignMethod: return "ssh-sign-method" case SSHRenewMethod: return "ssh-renew-method" case SSHRevokeMethod: return "ssh-revoke-method" case SSHRekeyMethod: return "ssh-rekey-method" default: return "unknown" } } // NewContextWithMethod creates a new context from ctx and attaches method to // it. func NewContextWithMethod(ctx context.Context, method Method) context.Context { return context.WithValue(ctx, methodKey{}, method) } // MethodFromContext returns the Method saved in ctx. func MethodFromContext(ctx context.Context) Method { m, _ := ctx.Value(methodKey{}).(Method) return m } type tokenKey struct{} // NewContextWithToken creates a new context with the given token. func NewContextWithToken(ctx context.Context, token string) context.Context { return context.WithValue(ctx, tokenKey{}, token) } // TokenFromContext returns the token stored in the given context. func TokenFromContext(ctx context.Context) (string, bool) { token, ok := ctx.Value(tokenKey{}).(string) return token, ok } // The key to save the certTypeKey in the context. type certTypeKey struct{} // NewContextWithCertType creates a new context with the given CertType. func NewContextWithCertType(ctx context.Context, certType string) context.Context { return context.WithValue(ctx, certTypeKey{}, certType) } // CertTypeFromContext returns the certType stored in the given context. 
func CertTypeFromContext(ctx context.Context) (string, bool) {
	// The second result is false when no string is stored under certTypeKey.
	certType, ok := ctx.Value(certTypeKey{}).(string)
	return certType, ok
}

================================================
FILE: authority/provisioner/nebula.go
================================================
package provisioner

import (
	"context"
	"crypto/ecdh"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"math/big"
	"net"
	"net/netip"
	"time"

	"github.com/pkg/errors"
	nebula "github.com/slackhq/nebula/cert"
	"golang.org/x/crypto/ssh"

	"github.com/smallstep/linkedca"
	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/sshutil"
	"go.step.sm/crypto/x25519"
	"go.step.sm/crypto/x509util"

	"github.com/smallstep/certificates/errs"
	"github.com/smallstep/certificates/internal/cast"
)

const (
	// NebulaCertHeader is the token header that contains a Nebula certificate.
	NebulaCertHeader jose.HeaderKey = "nebula"
)

// Nebula is a provisioner that verifies tokens signed using Nebula private
// keys. The tokens contain a Nebula certificate in the header, which can be
// used to verify the token signature. The certificates are themselves verified
// using the Nebula CA certificates encoded in Roots. The verification process
// is similar to the process for X5C tokens.
//
// Because Nebula "leaf" certificates use X25519 keys, the tokens are signed
// using XEd25519 defined at
// https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by
// go.step.sm/crypto/x25519.
type Nebula struct {
	// ID is never (un)marshaled; when set, GetID returns it instead of the
	// identifier derived from the name.
	ID   string `json:"-"`
	Type string `json:"type"`
	Name string `json:"name"`
	// Roots holds the PEM-encoded Nebula CA certificates that Init loads into
	// the CA pool.
	Roots   []byte   `json:"roots"`
	Claims  *Claims  `json:"claims,omitempty"`
	Options *Options `json:"options,omitempty"`
	// caPool is the CA pool built from Roots by Init.
	caPool *nebula.CAPool
	// ctl is the controller created by Init.
	ctl *Controller
}

// Init verifies and initializes the Nebula provisioner.
func (p *Nebula) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") case len(p.Roots) == 0: return errors.New("provisioner root(s) cannot be empty") } p.caPool, err = nebula.NewCAPoolFromPEM(p.Roots) if err != nil { return errs.InternalServer("failed to create CA pool: %v", err) } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // GetID returns the provisioner id. func (p *Nebula) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *Nebula) GetIDForToken() string { return "nebula/" + p.Name } // GetTokenID returns the identifier of the token. func (p *Nebula) GetTokenID(token string) (string, error) { // Validate payload t, err := jose.ParseSigned(token) if err != nil { return "", errors.Wrap(err, "error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims jose.Claims if err = t.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } return claims.ID, nil } // GetName returns the name of the provisioner. func (p *Nebula) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *Nebula) GetType() Type { return TypeNebula } // GetEncryptedKey returns the base provisioner encrypted key if it's defined. func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) { return "", "", false } // AuthorizeSign returns the list of SignOption for a Sign request. 
func (p *Nebula) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) {
	crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
	if err != nil {
		return nil, err
	}

	// If the token carries no SANs, default to the certificate name followed by
	// the address of each of its networks.
	sans := claims.SANs
	if len(sans) == 0 {
		networks := crt.Networks()
		sans = make([]string, len(networks)+1)
		sans[0] = crt.Name()
		for i, network := range networks {
			sans[i+1] = network.Addr().String()
		}
	}

	data := x509util.CreateTemplateData(claims.Subject, sans)
	// Best effort: the token is only added to the template data if it parses.
	if v, err := unsafeParseSigned(token); err == nil {
		data.SetToken(v)
	}

	// The Nebula certificate will be available using the template variable
	// AuthorizationCrt. For example {{ .AuthorizationCrt.Details.Groups }} can
	// be used to get all the groups.
	data.SetAuthorizationCertificate(crt)

	templateOptions, err := TemplateOptions(p.Options, data)
	if err != nil {
		return nil, err
	}

	return []SignOption{
		p,
		templateOptions,
		// modifiers / withOptions
		newProvisionerExtensionOption(TypeNebula, p.Name, "").WithControllerOptions(p.ctl),
		// limit the validity using the validity of the Nebula certificate
		profileLimitDuration{
			def:       p.ctl.Claimer.DefaultTLSCertDuration(),
			notBefore: crt.NotBefore(),
			notAfter:  crt.NotAfter(),
		},
		// validators
		commonNameValidator(claims.Subject),
		nebulaSANsValidator{
			Name:     crt.Name(),
			Networks: crt.Networks(),
		},
		defaultPublicKeyValidator{},
		newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
		newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
		p.ctl.newWebhookController(data, linkedca.Webhook_X509),
	}, nil
}

// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
// Currently the Nebula provisioner only grants host SSH certificates.
func (p *Nebula) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) {
	if !p.ctl.Claimer.IsSSHCAEnabled() {
		return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
	}

	crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
	if err != nil {
		return nil, err
	}

	// Default template attributes: the key id is the token subject and the
	// principals are the certificate name followed by the address of each of
	// its networks.
	keyID := claims.Subject
	networks := crt.Networks()
	principals := make([]string, len(networks)+1)
	principals[0] = crt.Name()
	for i, network := range networks {
		principals[i+1] = network.Addr().String()
	}

	var signOptions []SignOption
	// If step ssh options are given, validate them and set key id, principals
	// and validity.
	if claims.Step != nil && claims.Step.SSH != nil {
		opts := claims.Step.SSH

		// Check that the token only contains valid principals.
		v := nebulaPrincipalsValidator{
			Name:     crt.Name(),
			Networks: crt.Networks(),
		}
		if err := v.Valid(*opts); err != nil {
			return nil, err
		}
		// Check that the cert type is a valid one.
		if opts.CertType != "" && opts.CertType != SSHHostCert {
			return nil, errs.Forbidden("ssh certificate type does not match - got %v, want %v", opts.CertType, SSHHostCert)
		}

		signOptions = []SignOption{
			// validates that it is a host certificate and that the user's
			// KeyID is the subject.
			sshCertOptionsValidator(SignSSHOptions{
				CertType: SSHHostCert,
				KeyID:    claims.Subject,
			}),
			// validates user's SSHOptions with the ones in the token
			sshCertOptionsValidator(*opts),
		}

		// Use options in the token.
		if opts.KeyID != "" {
			keyID = opts.KeyID
		}
		if len(opts.Principals) > 0 {
			principals = opts.Principals
		}

		// Add modifiers from custom claims
		t := now()
		if !opts.ValidAfter.IsZero() {
			signOptions = append(signOptions, sshCertValidAfterModifier(cast.Uint64(opts.ValidAfter.RelativeTime(t).Unix())))
		}
		if !opts.ValidBefore.IsZero() {
			signOptions = append(signOptions, sshCertValidBeforeModifier(cast.Uint64(opts.ValidBefore.RelativeTime(t).Unix())))
		}
	}

	// Certificate templates.
	data := sshutil.CreateTemplateData(sshutil.HostCert, keyID, principals)
	// Best effort: the token is only added to the template data if it parses.
	if v, err := unsafeParseSigned(token); err == nil {
		data.SetToken(v)
	}

	// The Nebula certificate will be available using the template variable
	// AuthorizationCrt. For example {{ .AuthorizationCrt.Details.Groups }} can
	// be used to get all the groups.
	data.SetAuthorizationCertificate(crt)

	templateOptions, err := TemplateSSHOptions(p.Options, data)
	if err != nil {
		return nil, err
	}

	return append(signOptions,
		p,
		templateOptions,
		// Checks the validity bounds, and set the validity if has not been set.
		&sshLimitDuration{p.ctl.Claimer, crt.NotAfter()},
		// Validate public key.
		&sshDefaultPublicKeyValidator{},
		// Validate the validity period.
		&sshCertValidityValidator{p.ctl.Claimer},
		// Require all the fields in the SSH certificate
		&sshCertDefaultValidator{},
		// Ensure that all principal names are allowed
		newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
		// Call webhooks
		p.ctl.newWebhookController(data, linkedca.Webhook_SSH),
	), nil
}

// AuthorizeRenew returns an error if the renewal is disabled. The decision is
// delegated to the provisioner controller.
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
	return p.ctl.AuthorizeRenew(ctx, crt)
}

// AuthorizeRevoke returns an unauthorized error.
func (p *Nebula) AuthorizeRevoke(context.Context, string) error {
	return errs.Unauthorized("nebula provisioner does not support revoke")
}

// AuthorizeSSHRevoke returns an unauthorized error.
func (p *Nebula) AuthorizeSSHRevoke(context.Context, string) error {
	return errs.Unauthorized("nebula provisioner does not support SSH revoke")
}

// AuthorizeSSHRenew returns an unauthorized error.
func (p *Nebula) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) {
	return nil, errs.Unauthorized("nebula provisioner does not support SSH renew")
}

// AuthorizeSSHRekey returns an unauthorized error.
func (p *Nebula) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []SignOption, error) {
	return nil, nil, errs.Unauthorized("nebula provisioner does not support SSH rekey")
}

// authorizeToken parses the token, extracts the Nebula certificate from its
// header, verifies that certificate against the configured CA pool, verifies
// the token signature with the certificate public key and finally validates
// the claims (issuer, time, audience and subject). On success it returns the
// Nebula certificate and the token claims; every failure is reported as an
// unauthorized error.
func (p *Nebula) authorizeToken(token string, audiences []string) (nebula.Certificate, *jwtPayload, error) {
	jwt, err := jose.ParseSigned(token)
	if err != nil {
		return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("failed to parse token"))
	}

	// Extract Nebula certificate
	h, ok := jwt.Headers[0].ExtraHeaders[NebulaCertHeader]
	if !ok {
		return nil, nil, errs.Unauthorized("failed to parse token: nebula header is missing")
	}
	s, ok := h.(string)
	if !ok {
		return nil, nil, errs.Unauthorized("failed to parse token: nebula header is not valid")
	}
	// The header value is the base64 (standard encoding) of the raw certificate.
	b, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("failed to parse token: nebula header is not valid"))
	}

	// Wrap raw certificate bytes in PEM for unmarshaling. Try v1 banner
	// first, then fall back to v2 if that fails.
	pemData := pem.EncodeToMemory(&pem.Block{Type: nebula.CertificateBanner, Bytes: b})
	c, _, err := nebula.UnmarshalCertificateFromPEM(pemData)
	if err != nil {
		pemData = pem.EncodeToMemory(&pem.Block{Type: nebula.CertificateV2Banner, Bytes: b})
		c, _, err = nebula.UnmarshalCertificateFromPEM(pemData)
	}
	if err != nil {
		return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("failed to parse nebula certificate: nebula header is not valid"))
	}

	// Validate nebula certificate against CAs
	if _, err := p.caPool.VerifyCertificate(now(), c); err != nil {
		return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("token is not valid: failed to verify certificate against configured CA"))
	}

	// Select the public key type used to verify the token signature.
	var pub any
	switch {
	case c.Curve() == nebula.Curve_P256:
		// When Nebula is used with ECDSA P-256 keys, both CAs and clients use the same type.
		ecdhPub, err := ecdh.P256().NewPublicKey(c.PublicKey())
		if err != nil {
			return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("failed to parse nebula public key"))
		}
		// Skip the first byte and split the remainder into the X and Y
		// coordinates.
		publicKeyBytes := ecdhPub.Bytes()
		pub = &ecdsa.PublicKey{ // convert back to *ecdsa.PublicKey, because neither our jose package nor go-jose supports *ecdh.PublicKey
			Curve: elliptic.P256(),
			X:     big.NewInt(0).SetBytes(publicKeyBytes[1:33]),
			Y:     big.NewInt(0).SetBytes(publicKeyBytes[33:]),
		}
	case c.IsCA():
		// CA certificates on the remaining curve carry an Ed25519 key.
		pub = ed25519.PublicKey(c.PublicKey())
	default:
		// Non-CA certificates on the remaining curve carry an X25519 key.
		pub = x25519.PublicKey(c.PublicKey())
	}

	// Validate token with public key
	var claims jwtPayload
	if err := jose.Verify(jwt, pub, &claims); err != nil {
		return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("token is not valid: signature does not match"))
	}

	// According to "rfc7519 JSON Web Token" acceptable skew should be no
	// more than a few minutes.
	if err = claims.ValidateWithLeeway(jose.Expected{
		Issuer: p.Name,
		Time:   now(),
	}, time.Minute); err != nil {
		return nil, nil, errs.UnauthorizedErr(err, errs.WithMessage("token is not valid: invalid claims"))
	}
	// Validate the audience and subject too.
	if !matchesAudience(claims.Audience, audiences) {
		return nil, nil, errs.Unauthorized("token is not valid: invalid claims")
	}
	if claims.Subject == "" {
		return nil, nil, errs.Unauthorized("token is not valid: subject cannot be empty")
	}

	return c, &claims, nil
}

// nebulaSANsValidator validates the SANs of a certificate request against the
// name and networks of a Nebula certificate.
type nebulaSANsValidator struct {
	Name     string
	Networks []netip.Prefix
}

// Valid verifies that the SANs requested in the x509 certificate request are
// allowed by the name and networks stored in the validator.
func (v nebulaSANsValidator) Valid(req *x509.CertificateRequest) error { dnsNames, ips, emails, uris := x509util.SplitSANs([]string{v.Name}) if len(req.DNSNames) > 0 { if err := dnsNamesValidator(dnsNames).Valid(req); err != nil { return err } } if len(req.EmailAddresses) > 0 { if err := emailAddressesValidator(emails).Valid(req); err != nil { return err } } if len(req.URIs) > 0 { if err := newURIsValidator(context.Background(), uris).Valid(req); err != nil { return err } } if len(req.IPAddresses) > 0 { for _, ip := range req.IPAddresses { var valid bool // Check ip in name for _, ipInName := range ips { if ip.Equal(ipInName) { valid = true break } } // Check ip network if !valid { for _, network := range v.Networks { if ip.Equal(net.IP(network.Addr().AsSlice())) { valid = true break } } } if !valid { for _, network := range v.Networks { ips = append(ips, net.IP(network.Addr().AsSlice())) } return errs.Forbidden("certificate request contains invalid IP addresses - got %v, want %v", req.IPAddresses, ips) } } } return nil } type nebulaPrincipalsValidator struct { Name string Networks []netip.Prefix } // Valid checks that the SignSSHOptions principals contains only names in the // Nebula certificate. 
func (v nebulaPrincipalsValidator) Valid(got SignSSHOptions) error {
	for _, p := range got.Principals {
		var valid bool
		// A principal is valid if it is the certificate name...
		if p == v.Name {
			valid = true
		}
		// ...or an IP equal to the address of one of the networks.
		if !valid {
			if ip := net.ParseIP(p); ip != nil {
				for _, network := range v.Networks {
					if ip.Equal(net.IP(network.Addr().AsSlice())) {
						valid = true
						break
					}
				}
			}
		}

		if !valid {
			// Collect the network addresses for the error message.
			ips := make([]net.IP, len(v.Networks))
			for i, network := range v.Networks {
				ips[i] = net.IP(network.Addr().AsSlice())
			}
			return errs.Forbidden(
				"ssh certificate principals contains invalid name or IP addresses - got %v, want %s or %v",
				got.Principals, v.Name, ips,
			)
		}
	}
	return nil
}

================================================
FILE: authority/provisioner/nebula_test.go
================================================
package provisioner

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"net"
	"net/netip"
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/slackhq/nebula/cert"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/crypto/ssh"

	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/randutil"
	"go.step.sm/crypto/x25519"
	"go.step.sm/crypto/x509util"
)

// mustNebulaPrefix parses s as a network prefix, failing the test on error.
func mustNebulaPrefix(t *testing.T, s string) netip.Prefix {
	t.Helper()
	p, err := netip.ParsePrefix(s)
	require.NoError(t, err)
	return p
}

// mustNebulaCA creates a self-signed Curve25519 Nebula CA certificate valid
// for ten minutes and returns it with its Ed25519 private key.
func mustNebulaCA(t *testing.T) (cert.Certificate, ed25519.PrivateKey) {
	t.Helper()
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	now := time.Now()
	tbs := &cert.TBSCertificate{
		Version:   cert.Version1,
		Curve:     cert.Curve_CURVE25519,
		Name:      "TestCA",
		Groups:    []string{"test"},
		Networks:  []netip.Prefix{netip.MustParsePrefix("10.1.0.0/16")},
		NotBefore: time.Unix(now.Unix(), 0),
		NotAfter:  time.Unix(now.Add(10*time.Minute).Unix(), 0),
		PublicKey: pub,
		IsCA:      true,
	}
	nc, err := tbs.Sign(nil, cert.Curve_CURVE25519, priv)
	require.NoError(t, err)
	return nc, priv
}

// mustExpiredNebulaCA creates a self-signed Curve25519 Nebula CA certificate
// whose validity ended one hour ago.
func mustExpiredNebulaCA(t *testing.T) (cert.Certificate, ed25519.PrivateKey) {
	t.Helper()
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	now := time.Now()
	tbs := &cert.TBSCertificate{
		Version:   cert.Version1,
		Curve:     cert.Curve_CURVE25519,
		Name:      "ExpiredTestCA",
		Groups:    []string{"expired"},
		Networks:  []netip.Prefix{netip.MustParsePrefix("10.2.0.0/16")},
		NotBefore: time.Unix(now.Add(-2*time.Hour).Unix(), 0),
		NotAfter:  time.Unix(now.Add(-1*time.Hour).Unix(), 0),
		PublicKey: pub,
		IsCA:      true,
	}
	nc, err := tbs.Sign(nil, cert.Curve_CURVE25519, priv)
	require.NoError(t, err)
	return nc, priv
}

// mustNebulaP256CA creates a self-signed P-256 Nebula CA certificate valid for
// ten minutes and returns it with its ECDSA private key.
func mustNebulaP256CA(t *testing.T) (cert.Certificate, *ecdsa.PrivateKey) {
	t.Helper()
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	ecdhPriv, err := key.ECDH()
	require.NoError(t, err)

	now := time.Now()
	tbs := &cert.TBSCertificate{
		Version:   cert.Version1,
		Curve:     cert.Curve_P256,
		Name:      "TestCA",
		Groups:    []string{"test"},
		Networks:  []netip.Prefix{netip.MustParsePrefix("10.1.0.0/16")},
		NotBefore: time.Unix(now.Unix(), 0),
		NotAfter:  time.Unix(now.Add(10*time.Minute).Unix(), 0),
		PublicKey: ecdhPriv.PublicKey().Bytes(),
		IsCA:      true,
	}
	// For P256 CAs, Sign expects the raw 32-byte scalar as the key.
	nc, err := tbs.Sign(nil, cert.Curve_P256, key.D.FillBytes(make([]byte, 32)))
	require.NoError(t, err)
	return nc, key
}

// mustNebulaCert creates a Curve25519 leaf certificate, valid for five
// minutes, signed by the given CA. It returns the certificate and the X25519
// private key matching its public key.
func mustNebulaCert(t *testing.T, name string, network netip.Prefix, groups []string, ca cert.Certificate, signer ed25519.PrivateKey) (cert.Certificate, crypto.Signer) {
	t.Helper()
	pub, priv, err := x25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	t1 := time.Now().Truncate(time.Second)
	tbs := &cert.TBSCertificate{
		Version:   cert.Version1,
		Curve:     cert.Curve_CURVE25519,
		Name:      name,
		Networks:  []netip.Prefix{network},
		Groups:    groups,
		NotBefore: t1,
		NotAfter:  t1.Add(5 * time.Minute),
		PublicKey: pub,
		IsCA:      false,
	}
	nc, err := tbs.Sign(ca, cert.Curve_CURVE25519, signer)
	require.NoError(t, err)
	return nc, priv
}

// mustNebulaP256Cert creates a P-256 leaf certificate, valid for five minutes,
// signed by the given CA. It returns the certificate and the ECDSA private key
// matching its public key.
func mustNebulaP256Cert(t *testing.T, name string, network netip.Prefix, groups []string, ca cert.Certificate, signer *ecdsa.PrivateKey) (cert.Certificate, crypto.Signer) {
	t.Helper()
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	ecdhPriv, err := key.ECDH()
	require.NoError(t, err)

	t1 := time.Now().Truncate(time.Second)
	tbs := &cert.TBSCertificate{
		Version:   cert.Version1,
		Curve:     cert.Curve_P256,
		Name:      name,
		Networks:  []netip.Prefix{network},
		Groups:    groups,
		NotBefore: t1,
		NotAfter:  t1.Add(5 * time.Minute),
		PublicKey: ecdhPriv.PublicKey().Bytes(),
		IsCA:      false,
	}

	ecdhSigner, err := signer.ECDH()
	require.NoError(t, err)
	nc, err := tbs.Sign(ca, cert.Curve_P256, ecdhSigner.Bytes())
	require.NoError(t, err)
	return nc, key
}

// mustNebulaProvisioner returns an initialized Nebula provisioner, with the
// SSH CA enabled, trusting a freshly generated Curve25519 CA. The CA
// certificate and its key are returned too.
func mustNebulaProvisioner(t *testing.T) (*Nebula, cert.Certificate, ed25519.PrivateKey) {
	t.Helper()
	nc, signer := mustNebulaCA(t)
	ncPem, err := nc.MarshalPEM()
	require.NoError(t, err)

	bTrue := true
	p := &Nebula{
		Type:  TypeNebula.String(),
		Name:  "nebulous",
		Roots: ncPem,
		Claims: &Claims{
			EnableSSHCA: &bTrue,
		},
	}
	err = p.Init(Config{
		Claims:    globalProvisionerClaims,
		Audiences: testAudiences,
	})
	require.NoError(t, err)
	return p, nc, signer
}

// mustNebulaP256Provisioner returns an initialized Nebula provisioner, with
// the SSH CA enabled, trusting a freshly generated P-256 CA. The CA
// certificate and its key are returned too.
func mustNebulaP256Provisioner(t *testing.T) (*Nebula, cert.Certificate, *ecdsa.PrivateKey) {
	t.Helper()
	nc, signer := mustNebulaP256CA(t)
	ncPem, err := nc.MarshalPEM()
	require.NoError(t, err)

	bTrue := true
	p := &Nebula{
		Type:  TypeNebula.String(),
		Name:  "nebulous",
		Roots: ncPem,
		Claims: &Claims{
			EnableSSHCA: &bTrue,
		},
	}
	err = p.Init(Config{
		Claims:    globalProvisionerClaims,
		Audiences: testAudiences,
	})
	require.NoError(t, err)
	return p, nc, signer
}

// mustNebulaToken creates a compact token, valid for five minutes from iat,
// signed with key using the given algorithm. The marshaled Nebula certificate
// is set in the nebula header and the SANs in the "sans" claim.
func mustNebulaToken(t *testing.T, sub, iss, aud string, iat time.Time, sans []string, nc cert.Certificate, key crypto.Signer, algorithm jose.SignatureAlgorithm) string {
	t.Helper()
	ncDer, err := nc.Marshal()
	require.NoError(t, err)

	so := new(jose.SignerOptions)
	so.WithType("JWT")
	so.WithHeader(NebulaCertHeader, ncDer)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: algorithm, Key: key}, so)
	require.NoError(t, err)

	id, err := randutil.ASCII(64)
	require.NoError(t, err)

	claims := struct {
		jose.Claims
		SANS []string `json:"sans"`
	}{
		Claims: jose.Claims{
			ID:        id,
			Subject:   sub,
			Issuer:    iss,
			IssuedAt:  jose.NewNumericDate(iat),
			NotBefore: jose.NewNumericDate(iat),
			Expiry:    jose.NewNumericDate(iat.Add(5 * time.Minute)),
			Audience:  []string{aud},
		},
		SANS: sans,
	}
	tok, err := jose.Signed(sig).Claims(claims).CompactSerialize()
	require.NoError(t, err)
	return tok
}

// mustNebulaSSHToken creates a compact token like mustNebulaToken but, instead
// of SANs, it carries the given SSH options in the "step" claim when opts is
// not nil.
func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts *SignSSHOptions, nc cert.Certificate, key crypto.Signer, algorithm jose.SignatureAlgorithm) string {
	t.Helper()
	ncDer, err := nc.Marshal()
	require.NoError(t, err)

	so := new(jose.SignerOptions)
	so.WithType("JWT")
	so.WithHeader(NebulaCertHeader, ncDer)

	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: algorithm, Key: key}, so)
	require.NoError(t, err)

	id, err := randutil.ASCII(64)
	require.NoError(t, err)

	claims := struct {
		jose.Claims
		Step *stepPayload `json:"step,omitempty"`
	}{
		Claims: jose.Claims{
			ID:        id,
			Subject:   sub,
			Issuer:    iss,
			IssuedAt:  jose.NewNumericDate(iat),
			NotBefore: jose.NewNumericDate(iat),
			Expiry:    jose.NewNumericDate(iat.Add(5 * time.Minute)),
			Audience:  []string{aud},
		},
	}
	if opts != nil {
		claims.Step = &stepPayload{
			SSH: opts,
		}
	}
	tok, err := jose.Signed(sig).Claims(claims).CompactSerialize()
	require.NoError(t, err)
	return tok
}

// TestNebula_Init checks the validation performed by Nebula.Init on the type,
// name, roots and claims of the provisioner.
func TestNebula_Init(t *testing.T) {
	nc, _ := mustNebulaCA(t)
	ncPem, err := nc.MarshalPEM()
	require.NoError(t, err)

	expiredNC, _ := mustExpiredNebulaCA(t)
	expiredPEM, err := expiredNC.MarshalPEM()
	require.NoError(t, err)
	expiredPEM = append(expiredPEM, ncPem...) // needed so that regular error isn't triggered

	cfg := Config{
		Claims:    globalProvisionerClaims,
		Audiences: testAudiences,
	}

	type fields struct {
		Type    string
		Name    string
		Roots   []byte
		Claims  *Claims
		Options *Options
	}
	type args struct {
		config Config
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{"ok", fields{"Nebula", "Nebulous", ncPem, nil, nil}, args{cfg}, false},
		{"ok with claims", fields{"Nebula", "Nebulous", ncPem, &Claims{DefaultTLSDur: &Duration{Duration: time.Hour}}, nil}, args{cfg}, false},
		{"ok with options", fields{"Nebula", "Nebulous", ncPem, nil, &Options{X509: &X509Options{Template: x509util.DefaultLeafTemplate}}}, args{cfg}, false},
		{"fail type", fields{"", "Nebulous", ncPem, nil, nil}, args{cfg}, true},
		{"fail name", fields{"Nebula", "", ncPem, nil, nil}, args{cfg}, true},
		{"fail root", fields{"Nebula", "Nebulous", nil, nil, nil}, args{cfg}, true},
		{"fail expired root", fields{"Nebula", "Nebulous", expiredPEM, nil, nil}, args{cfg}, true},
		{"fail bad root", fields{"Nebula", "Nebulous", ncPem[:16], nil, nil}, args{cfg}, true},
		{"fail bad claims", fields{"Nebula", "Nebulous", ncPem, &Claims{
			MinTLSDur: &Duration{Duration: 0},
		}, nil}, args{cfg}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Nebula{
				Type:    tt.fields.Type,
				Name:    tt.fields.Name,
				Roots:   tt.fields.Roots,
				Claims:  tt.fields.Claims,
				Options: tt.fields.Options,
			}
			if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
				t.Errorf("Nebula.Init() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestNebula_GetID checks that GetID prefers the explicit ID over the
// identifier derived from the name.
func TestNebula_GetID(t *testing.T) {
	type fields struct {
		ID   string
		Name string
	}
	tests := []struct {
		name   string
		fields fields
		want   string
	}{
		{"ok with id", fields{"1234", "nebulous"}, "1234"},
		{"ok with name", fields{"", "nebulous"}, "nebula/nebulous"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Nebula{
				ID:   tt.fields.ID,
				Name: tt.fields.Name,
			}
			if got := p.GetID(); got != tt.want {
				t.Errorf("Nebula.GetID() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestNebula_GetIDForToken checks that the token identifier is the name
// prefixed with "nebula/".
func TestNebula_GetIDForToken(t *testing.T) {
	type fields struct {
		Name string
	}
	tests := []struct {
		name   string
		fields fields
		want   string
	}{
		{"ok", fields{"nebulous"}, "nebula/nebulous"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Nebula{
				Name: tt.fields.Name,
			}
			if got := p.GetIDForToken(); got != tt.want {
				t.Errorf("Nebula.GetIDForToken() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestNebula_GetTokenID(t *testing.T) {
	p, ca, signer := mustNebulaProvisioner(t)
	c1, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
	t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv, jose.XEdDSA)
	_, claims, err := parseToken(t1)
	require.NoError(t, err)

	type args struct {
		token string
	}
	tests := []struct {
		name    string
		p       *Nebula
		args    args
		want    string
		wantErr bool
	}{
		{"ok", p, args{t1}, claims.ID, false},
		{"fail parse", p, args{"token"}, "", true},
		{"fail claims", p, args{func() string {
			parts := strings.Split(t1, ".")
			return parts[0] + ".eyIifQ."
+ parts[1] }()}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.p.GetTokenID(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Nebula.GetTokenID() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Nebula.GetTokenID() = %v, want %v", got, tt.want) } }) } } func TestNebula_GetName(t *testing.T) { type fields struct { Name string } tests := []struct { name string fields fields want string }{ {"ok", fields{"nebulous"}, "nebulous"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Nebula{ Name: tt.fields.Name, } if got := p.GetName(); got != tt.want { t.Errorf("Nebula.GetName() = %v, want %v", got, tt.want) } }) } } func TestNebula_GetType(t *testing.T) { type fields struct { Type string } tests := []struct { name string fields fields want Type }{ {"ok", fields{"Nebula"}, TypeNebula}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Nebula{ Type: tt.fields.Type, } if got := p.GetType(); got != tt.want { t.Errorf("Nebula.GetType() = %v, want %v", got, tt.want) } }) } } func TestNebula_GetEncryptedKey(t *testing.T) { tests := []struct { name string p *Nebula wantKid string wantKey string wantOk bool }{ {"ok", &Nebula{}, "", "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotKid, gotKey, gotOk := tt.p.GetEncryptedKey() if gotKid != tt.wantKid { t.Errorf("Nebula.GetEncryptedKey() gotKid = %v, want %v", gotKid, tt.wantKid) } if gotKey != tt.wantKey { t.Errorf("Nebula.GetEncryptedKey() gotKey = %v, want %v", gotKey, tt.wantKey) } if gotOk != tt.wantOk { t.Errorf("Nebula.GetEncryptedKey() gotOk = %v, want %v", gotOk, tt.wantOk) } }) } } func TestNebula_AuthorizeSign(t *testing.T) { ctx := context.TODO() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], 
now(), []string{"test.lan", "10.1.0.1"}, crt, priv, jose.XEdDSA) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv, jose.XEdDSA) pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool pBadOptions.Options = &Options{ X509: &X509Options{ TemplateData: []byte(`{""}`), }, } type args struct { ctx context.Context token string } tests := []struct { name string p *Nebula args args wantErr bool }{ {"ok", p, args{ctx, ok}, false}, {"ok no sans", p, args{ctx, okNoSANs}, false}, {"fail token", p, args{ctx, "token"}, true}, {"fail template", pBadOptions, args{ctx, ok}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := tt.p.AuthorizeSign(tt.args.ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return } }) } } func TestNebula_AuthorizeSSHSign(t *testing.T) { ctx := context.TODO() // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1"}, }, crt, priv, jose.XEdDSA) okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv, jose.XEdDSA) okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), }, crt, priv, jose.XEdDSA) failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "user", }, crt, priv, jose.XEdDSA) failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: 
"test.lan", Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, }, crt, priv, jose.XEdDSA) // Provisioner with SSH disabled var bFalse bool pDisabled, _, _ := mustNebulaProvisioner(t) pDisabled.caPool = p.caPool pDisabled.Claims.EnableSSHCA = &bFalse // Provisioner with bad templates pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool pBadOptions.Options = &Options{ SSH: &SSHOptions{ TemplateData: []byte(`{""}`), }, } type args struct { ctx context.Context token string } tests := []struct { name string p *Nebula args args wantErr bool }{ {"ok", p, args{ctx, ok}, false}, {"ok no options", p, args{ctx, okNoOptions}, false}, {"ok with validity", p, args{ctx, okWithValidity}, false}, {"fail token", p, args{ctx, "token"}, true}, {"fail user", p, args{ctx, failUserCert}, true}, {"fail principals", p, args{ctx, failPrincipals}, true}, {"fail disabled", pDisabled, args{ctx, ok}, true}, {"fail template", pBadOptions, args{ctx, ok}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := tt.p.AuthorizeSSHSign(tt.args.ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } }) } } func TestNebula_AuthorizeRenew(t *testing.T) { ctx := context.TODO() now := time.Now().Truncate(time.Second) // Ok provisioner p, _, _ := mustNebulaProvisioner(t) // Provisioner with renewal disabled bTrue := true pDisabled, _, _ := mustNebulaProvisioner(t) pDisabled.Claims.DisableRenewal = &bTrue type args struct { ctx context.Context crt *x509.Certificate } tests := []struct { name string p *Nebula args args wantErr bool }{ {"ok", p, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, false}, {"fail disabled", pDisabled, args{ctx, &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.p.AuthorizeRenew(tt.args.ctx, tt.args.crt); (err != 
nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestNebula_AuthorizeRevoke(t *testing.T) { ctx := context.TODO() // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv, jose.XEdDSA) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, nc, signer) failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv, jose.XEdDSA) type args struct { ctx context.Context token string } tests := []struct { name string p *Nebula args args wantErr bool }{ {"fail unauthorized", p, args{ctx, ok}, true}, {"fail token", p, args{ctx, failToken}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.p.AuthorizeRevoke(tt.args.ctx, tt.args.token); (err != nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestNebula_AuthorizeSSHRevoke(t *testing.T) { ctx := context.TODO() // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv, jose.XEdDSA) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, nc, signer) failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv, jose.XEdDSA) // Provisioner with SSH disabled var bFalse bool pDisabled, _, _ := mustNebulaProvisioner(t) pDisabled.caPool = p.caPool pDisabled.Claims.EnableSSHCA = &bFalse type args struct { 
ctx context.Context token string } tests := []struct { name string p *Nebula args args wantErr bool }{ {"fail unauthorized", p, args{ctx, ok}, true}, {"fail token", p, args{ctx, failToken}, true}, {"fail disabled", pDisabled, args{ctx, ok}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.p.AuthorizeSSHRevoke(tt.args.ctx, tt.args.token); (err != nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestNebula_AuthorizeSSHRenew(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv, jose.XEdDSA) type args struct { ctx context.Context token string } tests := []struct { name string p *Nebula args args want *ssh.Certificate wantErr bool }{ {"fail", p, args{context.TODO(), t1}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.p.AuthorizeSSHRenew(tt.args.ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Nebula.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) return } assert.Equal(t, tt.want, got) }) } } func TestNebula_AuthorizeSSHRekey(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv, jose.XEdDSA) type args struct { ctx context.Context token string } tests := []struct { name string p *Nebula args args want *ssh.Certificate want1 []SignOption wantErr bool }{ {"fail", p, args{context.TODO(), t1}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, err := tt.p.AuthorizeSSHRekey(tt.args.ctx, tt.args.token) if (err != nil) != tt.wantErr { 
t.Errorf("Nebula.AuthorizeSSHRekey() error = %v, wantErr %v", err, tt.wantErr) return } assert.Equal(t, tt.want, got) assert.Equal(t, tt.want1, got1) }) } } func TestNebula_authorizeToken(t *testing.T) { t1 := now() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, signer) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv, jose.XEdDSA) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv, jose.XEdDSA) okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, crt, priv, jose.XEdDSA) okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv, jose.XEdDSA) // Token with errors failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv, jose.XEdDSA) failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv, jose.XEdDSA) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv, jose.XEdDSA) failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv, jose.XEdDSA) // Not a nebula token jwk, err := generateJSONWebKey() require.NoError(t, err) simpleToken, err := generateSimpleToken("iss", "aud", jwk) require.NoError(t, err) // Provisioner with a different CA p2, _, _ := mustNebulaProvisioner(t) x509Claims := jose.Claims{ ID: "[REPLACEME]", Subject: "test.lan", Issuer: p.Name, IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Audience: []string{p.ctl.Audiences.Sign[0]}, } sshClaims := jose.Claims{ ID: "[REPLACEME]", Subject: "test.lan", 
Issuer: p.Name, IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Audience: []string{p.ctl.Audiences.SSHSign[0]}, } type args struct { token string audiences []string } tests := []struct { name string p *Nebula args args wantClaims *jwtPayload wantErr bool }{ {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, &jwtPayload{ Claims: x509Claims, SANs: []string{"10.1.0.1"}, }, false}, {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, &jwtPayload{ Claims: x509Claims, }, false}, {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, &jwtPayload{ Claims: sshClaims, Step: &stepPayload{ SSH: &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, }, }, false}, {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, &jwtPayload{ Claims: sshClaims, }, false}, {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, true}, {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, true}, {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, true}, {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, true}, {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, true}, {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, true}, {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences) if tt.wantErr { assert.Error(t, err) assert.Nil(t, got) assert.Nil(t, got1) return } if got1 != nil && tt.wantClaims != nil { tt.wantClaims.ID = got1.ID } assert.NoError(t, err) assert.NotNil(t, got) assert.Equal(t, tt.wantClaims, got1) }) } } func TestNebula_authorizeToken_P256(t *testing.T) { t1 := now() p, ca, signer := mustNebulaP256Provisioner(t) crt, priv := mustNebulaP256Cert(t, "test.lan", mustNebulaPrefix(t, "10.1.0.1/16"), []string{"test"}, ca, 
signer) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv, jose.ES256) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv, jose.ES256) okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, crt, priv, jose.ES256) okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv, jose.ES256) // Token with errors failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv, jose.ES256) failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv, jose.ES256) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv, jose.ES256) failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv, jose.ES256) // Not a nebula token jwk, err := generateJSONWebKey() require.NoError(t, err) simpleToken, err := generateSimpleToken("iss", "aud", jwk) require.NoError(t, err) // Provisioner with a different CA p2, _, _ := mustNebulaP256Provisioner(t) x509Claims := jose.Claims{ ID: "[REPLACEME]", Subject: "test.lan", Issuer: p.Name, IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Audience: []string{p.ctl.Audiences.Sign[0]}, } sshClaims := jose.Claims{ ID: "[REPLACEME]", Subject: "test.lan", Issuer: p.Name, IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Audience: []string{p.ctl.Audiences.SSHSign[0]}, } type args struct { token string audiences []string } tests := []struct { name string p *Nebula args args wantClaims *jwtPayload wantErr bool }{ {"ok x509", p, args{ok, 
p.ctl.Audiences.Sign}, &jwtPayload{ Claims: x509Claims, SANs: []string{"10.1.0.1"}, }, false}, {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, &jwtPayload{ Claims: x509Claims, }, false}, {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, &jwtPayload{ Claims: sshClaims, Step: &stepPayload{ SSH: &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, }, }, false}, {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, &jwtPayload{ Claims: sshClaims, }, false}, {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, true}, {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, true}, {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, true}, {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, true}, {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, true}, {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, true}, {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences) if tt.wantErr { assert.Error(t, err) assert.Nil(t, got) assert.Nil(t, got1) return } if got1 != nil && tt.wantClaims != nil { tt.wantClaims.ID = got1.ID } assert.NoError(t, err) assert.NotNil(t, got) assert.Equal(t, tt.wantClaims, got1) }) } } func Test_nebulaSANsValidator_Valid(t *testing.T) { prefix := mustNebulaPrefix(t, "10.1.2.3/16") type fields struct { Name string Networks []netip.Prefix } type args struct { req *x509.CertificateRequest } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{"dns.name", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ DNSNames: []string{"dns.name"}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, false}, {"ok name only", fields{"dns.name", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ DNSNames: []string{"dns.name"}, 
}}, false}, {"ok ip only", fields{"dns.name", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, false}, {"ok email name", fields{"jane@doe.org", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ EmailAddresses: []string{"jane@doe.org"}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, false}, {"ok uri name", fields{"urn:foobar", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ URIs: []*url.URL{{Scheme: "urn", Opaque: "foobar"}}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, false}, {"ok ip name", fields{"127.0.0.1", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(10, 1, 2, 3)}, }}, false}, {"ok multiple ips", fields{"dns.name", []netip.Prefix{prefix, mustNebulaPrefix(t, "10.2.2.3/8")}}, args{&x509.CertificateRequest{ DNSNames: []string{"dns.name"}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3), net.IPv4(10, 2, 2, 3)}, }}, false}, {"fail dns", fields{"fail.name", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ DNSNames: []string{"dns.name"}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, true}, {"fail email", fields{"fail@doe.org", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ EmailAddresses: []string{"jane@doe.org"}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, true}, {"fail uri", fields{"urn:barfoo", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ URIs: []*url.URL{{Scheme: "urn", Opaque: "foobar"}}, IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, }}, true}, {"fail ip", fields{"127.0.0.1", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ IPAddresses: []net.IP{net.IPv4(10, 1, 2, 1), net.IPv4(10, 1, 2, 3)}, }}, true}, {"fail nebula ip", fields{"dns.name", []netip.Prefix{prefix}}, args{&x509.CertificateRequest{ DNSNames: []string{"dns.name"}, IPAddresses: []net.IP{net.IPv4(10, 2, 2, 3)}, }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { v := nebulaSANsValidator{ Name: 
tt.fields.Name, Networks: tt.fields.Networks, } if err := v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("nebulaSANsValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_nebulaPrincipalsValidator_Valid(t *testing.T) { prefix := mustNebulaPrefix(t, "10.1.2.3/16") type fields struct { Name string Networks []netip.Prefix } type args struct { got SignSSHOptions } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{"dns.name", []netip.Prefix{prefix}}, args{SignSSHOptions{ Principals: []string{"dns.name", "10.1.2.3"}, }}, false}, {"ok name", fields{"dns.name", []netip.Prefix{prefix}}, args{SignSSHOptions{ Principals: []string{"dns.name"}, }}, false}, {"ok ip", fields{"dns.name", []netip.Prefix{prefix}}, args{SignSSHOptions{ Principals: []string{"10.1.2.3"}, }}, false}, {"fail name", fields{"dns.name", []netip.Prefix{prefix}}, args{SignSSHOptions{ Principals: []string{"foo.name", "10.1.2.3"}, }}, true}, {"fail ip", fields{"dns.name", []netip.Prefix{prefix}}, args{SignSSHOptions{ Principals: []string{"dns.name", "10.2.2.3"}, }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { v := nebulaPrincipalsValidator{ Name: tt.fields.Name, Networks: tt.fields.Networks, } if err := v.Valid(tt.args.got); (err != nil) != tt.wantErr { t.Errorf("nebulaPrincipalsValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } ================================================ FILE: authority/provisioner/noop.go ================================================ package provisioner import ( "context" "crypto/x509" "golang.org/x/crypto/ssh" ) // noop provisioners is a provisioner that accepts anything. 
type noop struct{}

// GetID returns the fixed identifier "noop".
func (p *noop) GetID() string {
	return "noop"
}

// GetIDForToken returns the fixed identifier "noop".
func (p *noop) GetIDForToken() string {
	return "noop"
}

// GetTokenID ignores the token and returns an empty ID with no error.
func (p *noop) GetTokenID(string) (string, error) {
	return "", nil
}

// GetName returns the fixed name "noop".
func (p *noop) GetName() string {
	return "noop"
}

// GetType returns noopType.
func (p *noop) GetType() Type {
	return noopType
}

// GetEncryptedKey always reports that no encrypted key is available
// (empty kid and key, ok == false).
func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
	return "", "", false
}

// Init ignores the configuration and always succeeds.
func (p *noop) Init(Config) error {
	return nil
}

// AuthorizeSign accepts any token and returns the noop provisioner itself as
// the only SignOption.
func (p *noop) AuthorizeSign(context.Context, string) ([]SignOption, error) {
	return []SignOption{p}, nil
}

// AuthorizeRenew accepts any certificate.
func (p *noop) AuthorizeRenew(context.Context, *x509.Certificate) error {
	return nil
}

// AuthorizeRevoke accepts any token.
func (p *noop) AuthorizeRevoke(context.Context, string) error {
	return nil
}

// AuthorizeSSHSign accepts any token and returns the noop provisioner itself
// as the only SignOption.
func (p *noop) AuthorizeSSHSign(context.Context, string) ([]SignOption, error) {
	return []SignOption{p}, nil
}

// AuthorizeSSHRenew accepts any token; it returns a nil certificate together
// with a nil error.
func (p *noop) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) {
	//nolint:nilnil // fine for noop
	return nil, nil
}

// AuthorizeSSHRevoke accepts any token.
func (p *noop) AuthorizeSSHRevoke(context.Context, string) error {
	return nil
}

// AuthorizeSSHRekey accepts any token; it returns a nil certificate, an empty
// (non-nil) option slice and a nil error.
func (p *noop) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []SignOption, error) {
	return nil, []SignOption{}, nil
}


================================================
FILE: authority/provisioner/noop_test.go
================================================
package provisioner

import (
	"context"
	"crypto/x509"
	"testing"

	"github.com/smallstep/assert"
)

// Test_noop checks the fixed values returned by the noop provisioner: its
// identifier, name and type, the nil errors from Init, AuthorizeRenew and
// AuthorizeRevoke, the empty encrypted key, and that AuthorizeSign returns the
// provisioner itself as its single sign option.
func Test_noop(t *testing.T) {
	p := noop{}
	assert.Equals(t, "noop", p.GetID())
	assert.Equals(t, "noop", p.GetName())
	assert.Equals(t, noopType, p.GetType())
	assert.Equals(t, nil, p.Init(Config{}))
	assert.Equals(t, nil, p.AuthorizeRenew(context.Background(), &x509.Certificate{}))
	assert.Equals(t, nil, p.AuthorizeRevoke(context.Background(), "foo"))

	kid, key, ok := p.GetEncryptedKey()
	assert.Equals(t, "", kid)
	assert.Equals(t, "", key)
	assert.Equals(t, false, ok)

	ctx := NewContextWithMethod(context.Background(), SignMethod)
	sigOptions, err := p.AuthorizeSign(ctx, "foo")
	// The methods use pointer receivers, so the returned option is &p.
	assert.Equals(t, []SignOption{&p}, sigOptions)
	assert.Equals(t, nil, err)
}


================================================
FILE: authority/provisioner/oidc.go
================================================
package provisioner

import (
	"context"
	"crypto/x509"
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"path"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/linkedca"
	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/sshutil"
	"go.step.sm/crypto/x509util"

	"github.com/smallstep/certificates/errs"
)

// openIDConfiguration contains the necessary properties in the
// `/.well-known/openid-configuration` document.
type openIDConfiguration struct {
	Issuer    string `json:"issuer"`
	JWKSetURI string `json:"jwks_uri"`
}

// Validate validates the values in a well-known OpenID configuration endpoint.
// It returns an error when either the issuer or the jwks_uri is empty; the
// issuer is checked first.
func (c openIDConfiguration) Validate() error {
	switch {
	case c.Issuer == "":
		return errors.New("issuer cannot be empty")
	case c.JWKSetURI == "":
		return errors.New("jwks_uri cannot be empty")
	default:
		return nil
	}
}

// openIDPayload represents the fields on the id_token JWT payload.
type openIDPayload struct {
	jose.Claims
	AtHash          string   `json:"at_hash"`
	AuthorizedParty string   `json:"azp"`
	Email           string   `json:"email"`
	EmailVerified   bool     `json:"email_verified"`
	Hd              string   `json:"hd"`
	Nonce           string   `json:"nonce"`
	Groups          []string `json:"groups"`
}

// IsAdmin reports whether the payload identifies one of the given admins.
//
// A non-empty email matches when it equals an admin entry after both sides
// have been passed through sanitizeEmail. Independently of the email, any
// entry of the groups claim that is exactly equal to an admin entry also
// matches.
func (o *openIDPayload) IsAdmin(admins []string) bool {
	if o.Email != "" {
		email := sanitizeEmail(o.Email)
		for _, e := range admins {
			if email == sanitizeEmail(e) {
				return true
			}
		}
	}

	// The groups and emails can be in the same array for now, but consider
	// making a specialized option later.
	for _, name := range o.Groups {
		for _, admin := range admins {
			if name == admin {
				return true
			}
		}
	}

	return false
}

// OIDC represents an OAuth 2.0 OpenID Connect provider.
//
// ClientSecret is mandatory, but it can be an empty string.
type OIDC struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ClientID string `json:"clientID"` ClientSecret string `json:"clientSecret"` ConfigurationEndpoint string `json:"configurationEndpoint"` TenantID string `json:"tenantID,omitempty"` Admins []string `json:"admins,omitempty"` Domains []string `json:"domains,omitempty"` Groups []string `json:"groups,omitempty"` ListenAddress string `json:"listenAddress,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` Scopes []string `json:"scopes,omitempty"` AuthParams []string `json:"authParams,omitempty"` configuration openIDConfiguration keyStore *keyStore ctl *Controller } func sanitizeEmail(email string) string { if i := strings.LastIndex(email, "@"); i >= 0 { email = email[:i] + strings.ToLower(email[i:]) } return email } // GetID returns the provisioner unique identifier, the OIDC provisioner the // uses the clientID for this. func (o *OIDC) GetID() string { if o.ID != "" { return o.ID } return o.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (o *OIDC) GetIDForToken() string { return o.ClientID } // GetTokenID returns the provisioner unique identifier, the OIDC provisioner the // uses the clientID for this. func (o *OIDC) GetTokenID(ott string) (string, error) { // Validate payload token, err := jose.ParseSigned(ott) if err != nil { return "", errors.Wrap(err, "error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims openIDPayload if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } return claims.Nonce, nil } // GetName returns the name of the provisioner. 
func (o *OIDC) GetName() string {
	return o.Name
}

// GetType returns the provisioner type, which is always TypeOIDC.
func (o *OIDC) GetType() Type {
	return TypeOIDC
}

// GetEncryptedKey always reports that no encrypted key exists; OIDC
// provisioners do not have one.
func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) {
	return "", "", false
}

// Init validates the provisioner settings and initializes the OIDC provider.
//
// It checks the required fields (type, name, clientID and
// configurationEndpoint, in that order), validates listenAddress when set,
// creates the provisioner controller, downloads and validates the OpenID
// configuration document, and finally creates the JWK key store from the
// advertised jwks_uri. The first failure is returned.
func (o *OIDC) Init(config Config) (err error) {
	// Required fields; the first empty one is reported.
	if o.Type == "" {
		return errors.New("type cannot be empty")
	}
	if o.Name == "" {
		return errors.New("name cannot be empty")
	}
	if o.ClientID == "" {
		return errors.New("clientID cannot be empty")
	}
	if o.ConfigurationEndpoint == "" {
		return errors.New("configurationEndpoint cannot be empty")
	}

	// listenAddress is optional, but when present it must be host:port.
	if addr := o.ListenAddress; addr != "" {
		if _, _, splitErr := net.SplitHostPort(addr); splitErr != nil {
			return errors.Wrap(splitErr, "error parsing listenAddress")
		}
	}

	// Build the discovery URL, appending the well-known path unless the
	// configured endpoint already contains it.
	const wellKnownPath = "/.well-known/openid-configuration"
	endpoint, err := url.Parse(o.ConfigurationEndpoint)
	if err != nil {
		return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
	}
	if !strings.Contains(endpoint.Path, wellKnownPath) {
		endpoint.Path = path.Join(endpoint.Path, wellKnownPath)
	}

	// The controller must exist before any HTTP request because it supplies
	// the HTTP client.
	if o.ctl, err = NewController(o, o.Claims, config, o.Options); err != nil {
		return err
	}
	client := o.ctl.GetHTTPClient()

	// Fetch the OpenID configuration and make sure it is usable.
	if fetchErr := getAndDecode(client, endpoint.String(), &o.configuration); fetchErr != nil {
		return fetchErr
	}
	if cfgErr := o.configuration.Validate(); cfgErr != nil {
		return errors.Wrapf(cfgErr, "error parsing %s", o.ConfigurationEndpoint)
	}

	// Substitute the {tenantid} placeholder in the issuer when a tenant is
	// configured.
	if o.TenantID != "" {
		o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
	}

	// Load the JWK key set used to verify tokens.
	o.keyStore, err = newKeyStore(client, o.configuration.JWKSetURI)
	return err
}

// ValidatePayload validates the given token payload.
func (o *OIDC) ValidatePayload(p openIDPayload) error { // According to "rfc7519 JSON Web Token" acceptable skew should be no more // than a few minutes. if err := p.ValidateWithLeeway(jose.Expected{ Issuer: o.configuration.Issuer, Audience: jose.Audience{o.ClientID}, Time: time.Now().UTC(), }, time.Minute); err != nil { return errs.Wrap(http.StatusUnauthorized, err, "validatePayload: failed to validate oidc token payload") } // Validate azp if present if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { return errs.Unauthorized("validatePayload: failed to validate oidc token payload: invalid azp") } // Validate domains (case-insensitive) if p.Email != "" && len(o.Domains) > 0 && !p.IsAdmin(o.Admins) { email := sanitizeEmail(p.Email) var found bool for _, d := range o.Domains { if strings.HasSuffix(email, "@"+strings.ToLower(d)) { found = true break } } if !found { return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email %q is not allowed", p.Email) } } // Filter by oidc group claim if len(o.Groups) > 0 { var found bool for _, group := range o.Groups { for _, g := range p.Groups { if g == group { found = true break } } } if !found { return errs.Unauthorized("validatePayload: oidc token payload validation failed: invalid group") } } return nil } // authorizeToken applies the most common provisioner authorization claims, // leaving the rest to context specific methods. 
func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "oidc.AuthorizeToken; error parsing oidc token") } // Parse claims to get the kid var claims openIDPayload if err := jwt.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "oidc.AuthorizeToken; error parsing oidc token claims") } found := false kid := jwt.Headers[0].KeyID keys := o.keyStore.Get(kid) for _, key := range keys { if err := jwt.Claims(key, &claims); err == nil { found = true break } } if !found { return nil, errs.Unauthorized("oidc.AuthorizeToken; cannot validate oidc token") } if err := o.ValidatePayload(claims); err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeToken") } return &claims, nil } // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. // Only tokens generated by an admin have the right to revoke a certificate. func (o *OIDC) AuthorizeRevoke(_ context.Context, token string) error { claims, err := o.authorizeToken(token) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke") } // Only admins can revoke certificates. if claims.IsAdmin(o.Admins) { return nil } return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token") } // AuthorizeSign validates the given token. func (o *OIDC) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { claims, err := o.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign") } // Certificate templates sans := []string{} if claims.Email != "" { sans = append(sans, claims.Email) } // Add uri SAN with iss#sub if issuer is a URL with schema. 
// // According to https://openid.net/specs/openid-connect-core-1_0.html the // iss value is a case sensitive URL using the https scheme that contains // scheme, host, and optionally, port number and path components and no // query or fragment components. if iss, err := url.Parse(claims.Issuer); err == nil && iss.Scheme != "" { iss.Fragment = claims.Subject sans = append(sans, iss.String()) } data := x509util.CreateTemplateData(claims.Subject, sans) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // Use the default template unless no-templates are configured and email is // an admin, in that case we will use the CR template. defaultTemplate := x509util.DefaultLeafTemplate if !o.Options.GetX509Options().HasTemplate() && claims.IsAdmin(o.Admins) { defaultTemplate = x509util.DefaultAdminLeafTemplate } templateOptions, err := CustomTemplateOptions(o.Options, data, defaultTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign") } return []SignOption{ o, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID).WithControllerOptions(o.ctl), profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(o.ctl.getPolicy().getX509()), // webhooks o.ctl.newWebhookController(data, linkedca.Webhook_X509), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. // NOTE: This method does not actually validate the certificate or check it's // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return o.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. 
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { if !o.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) } claims, err := o.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } if claims.Subject == "" { return nil, errs.Unauthorized("oidc.AuthorizeSSHSign: failed to validate oidc token payload: subject not found") } var data sshutil.TemplateData if claims.Email == "" { // If email is empty, use the Subject claim instead to create minimal // data for the template to use. data = sshutil.CreateTemplateData(sshutil.UserCert, claims.Subject, nil) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } } else { // Get the identity using either the default identityFunc or one injected // externally. Note that the PreferredUsername might be empty. // TBD: Would preferred_username present a safety issue here? iden, err := o.ctl.GetIdentity(ctx, claims.Email) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } // Certificate templates. data = sshutil.CreateTemplateData(sshutil.UserCert, claims.Email, iden.Usernames) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // Add custom extensions added in the identity function. for k, v := range iden.Permissions.Extensions { data.AddExtension(k, v) } // Add custom critical options added in the identity function. for k, v := range iden.Permissions.CriticalOptions { data.AddCriticalOption(k, v) } } // Use the default template unless no-templates are configured and email is // an admin, in that case we will use the parameters in the request. 
isAdmin := claims.IsAdmin(o.Admins) defaultTemplate := sshutil.DefaultTemplate if isAdmin && !o.Options.GetSSHOptions().HasTemplate() { defaultTemplate = sshutil.DefaultAdminTemplate } templateOptions, err := CustomSSHTemplateOptions(o.Options, data, defaultTemplate) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } signOptions := []SignOption{templateOptions} // Admin users can use any principal, and can sign user and host certificates. // Non-admin users can only use principals returned by the identityFunc, and // can only sign user certificates. if isAdmin { signOptions = append(signOptions, &sshCertOptionsRequireValidator{ CertType: true, KeyID: true, Principals: true, }) } else { signOptions = append(signOptions, sshCertOptionsValidator(SignSSHOptions{ CertType: SSHUserCert, })) } return append(signOptions, o, // Set the validity bounds if not set. &sshDefaultDuration{o.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, // Ensure that all principal names are allowed newSSHNamePolicyValidator(o.ctl.getPolicy().getSSHHost(), o.ctl.getPolicy().getSSHUser()), // Call webhooks o.ctl.newWebhookController(data, linkedca.Webhook_SSH), ), nil } // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (o *OIDC) AuthorizeSSHRevoke(_ context.Context, token string) error { claims, err := o.authorizeToken(token) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke") } // Only admins can revoke certificates. 
if claims.IsAdmin(o.Admins) { return nil } return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") } func getAndDecode(client HTTPClient, uri string, v interface{}) error { resp, err := client.Get(uri) if err != nil { return errors.Wrapf(err, "failed to connect to %s", uri) } defer resp.Body.Close() if err := json.NewDecoder(resp.Body).Decode(v); err != nil { return errors.Wrapf(err, "error reading %s", uri) } return nil } ================================================ FILE: authority/provisioner/oidc_test.go ================================================ package provisioner import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "errors" "fmt" "net/http" "net/url" "strings" "testing" "time" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" ) func Test_openIDConfiguration_Validate(t *testing.T) { type fields struct { Issuer string JWKSetURI string } tests := []struct { name string fields fields wantErr bool }{ {"ok", fields{"the-issuer", "the-jwks-uri"}, false}, {"no-issuer", fields{"", "the-jwks-uri"}, true}, {"no-jwks-uri", fields{"the-issuer", ""}, true}, {"empty", fields{"", ""}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := openIDConfiguration{ Issuer: tt.fields.Issuer, JWKSetURI: tt.fields.JWKSetURI, } if err := c.Validate(); (err != nil) != tt.wantErr { t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestOIDC_Getters(t *testing.T) { p, err := generateOIDC() assert.FatalError(t, err) if got := p.GetID(); got != p.ClientID { t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID) } if got := p.GetName(); got != p.Name { t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeOIDC { t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC) } kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok 
== true { t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } } func TestOIDC_Init(t *testing.T) { srv := generateJWKServer(2) defer srv.Close() tlsSrv := generateTLSJWKServer(2) defer tlsSrv.Close() config := Config{ Claims: globalProvisionerClaims, HTTPClient: tlsSrv.Client(), } badHTTPClientConfig := Config{ Claims: globalProvisionerClaims, HTTPClient: http.DefaultClient, } badClaims := &Claims{ DefaultTLSDur: &Duration{0}, } type fields struct { Type string Name string ClientID string ClientSecret string ConfigurationEndpoint string Claims *Claims Admins []string Domains []string ListenAddress string } type args struct { config Config } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, false}, {"ok tls", fields{"oidc", "name", "client-id", "client-secret", tlsSrv.URL, nil, nil, nil, ""}, args{config}, false}, {"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", nil, []string{"foo@smallstep.com"}, nil, ""}, args{config}, false}, {"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, []string{"smallstep.com"}, ""}, args{config}, false}, {"ok-listen-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ":10000"}, args{config}, false}, {"ok-listen-host-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1:10000"}, args{config}, false}, {"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL, nil, nil, nil, ""}, args{config}, false}, {"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true}, {"no-type", fields{"", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, true}, {"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL, nil, 
nil, nil, ""}, args{config}, true}, {"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil, ""}, args{config}, true}, {"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/random", nil, nil, nil, ""}, args{config}, true}, {"bad-claims", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", badClaims, nil, nil, ""}, args{config}, true}, {"bad-parse-url", fields{"oidc", "name", "client-id", "client-secret", ":", nil, nil, nil, ""}, args{config}, true}, {"bad-get-url", fields{"oidc", "name", "client-id", "client-secret", "https://", nil, nil, nil, ""}, args{config}, true}, {"bad-listen-address", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1"}, args{config}, true}, {"bad-http-client", fields{"oidc", "name", "client-id", "client-secret", tlsSrv.URL, nil, nil, nil, ""}, args{badHTTPClientConfig}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &OIDC{ Type: tt.fields.Type, Name: tt.fields.Name, ClientID: tt.fields.ClientID, ConfigurationEndpoint: tt.fields.ConfigurationEndpoint, Claims: tt.fields.Claims, Admins: tt.fields.Admins, Domains: tt.fields.Domains, ListenAddress: tt.fields.ListenAddress, } if err := p.Init(tt.args.config); (err != nil) != tt.wantErr { t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr == false { assert.Len(t, 2, p.keyStore.keySet.Keys) u, err := url.Parse(tt.fields.ConfigurationEndpoint) require.NoError(t, err) assert.Equals(t, openIDConfiguration{ Issuer: "the-issuer", JWKSetURI: u.ResolveReference(&url.URL{Path: "/jwks_uri"}).String(), }, p.configuration) } }) } } func TestOIDC_authorizeToken(t *testing.T) { srv := generateJWKServer(3) defer srv.Close() var keys jose.JSONWebKeySet assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys)) issuer := "the-issuer" tenantID := 
"ab800f7d-2c87-45fb-b1d0-f90d0bc5ec25" tenantIssuer := "https://login.microsoftonline.com/" + tenantID + "/v2.0" // Create test provisioners p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() assert.FatalError(t, err) p3, err := generateOIDC() assert.FatalError(t, err) // TenantID p2.TenantID = tenantID // Admin + Domains p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Domains = []string{"smallstep.com"} // Update configuration endpoints and initialize config := Config{Claims: globalProvisionerClaims} p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p2.ConfigurationEndpoint = srv.URL + "/common/.well-known/openid-configuration" p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p2.Init(config)) assert.FatalError(t, p3.Init(config)) t1, err := generateSimpleToken(issuer, p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) t2, err := generateSimpleToken(tenantIssuer, p2.ClientID, &keys.Keys[1]) assert.FatalError(t, err) t3, err := generateToken("subject", issuer, p3.ClientID, "name@smallstep.com", []string{}, time.Now(), &keys.Keys[2]) assert.FatalError(t, err) t4, err := generateToken("subject", issuer, p3.ClientID, "foo@smallstep.com", []string{}, time.Now(), &keys.Keys[2]) assert.FatalError(t, err) t5, err := generateToken("subject", issuer, p3.ClientID, "", []string{}, time.Now(), &keys.Keys[2]) assert.FatalError(t, err) // Invalid email failDomain, err := generateToken("subject", issuer, p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[2]) assert.FatalError(t, err) // Invalid tokens parts := strings.Split(t1, ".") key, err := generateJSONWebKey() assert.FatalError(t, err) // missing key failKey, err := generateSimpleToken(issuer, p1.ClientID, key) assert.FatalError(t, err) // invalid token failTok := "foo." + parts[1] + "." + parts[2] // invalid claims failClaims := parts[0] + ".foo." 
+ parts[1] // invalid issuer failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) // invalid audience failAud, err := generateSimpleToken(issuer, "foobar", &keys.Keys[0]) assert.FatalError(t, err) // invalid signature failSig := t1[0 : len(t1)-2] // expired failExp, err := generateToken("subject", issuer, p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0]) assert.FatalError(t, err) // not before failNbf, err := generateToken("subject", issuer, p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0]) assert.FatalError(t, err) type args struct { token string } tests := []struct { name string prov *OIDC args args code int wantIssuer string expErr error }{ {"ok1", p1, args{t1}, http.StatusOK, issuer, nil}, {"ok tenantid", p2, args{t2}, http.StatusOK, tenantIssuer, nil}, {"ok admin", p3, args{t3}, http.StatusOK, issuer, nil}, {"ok domain", p3, args{t4}, http.StatusOK, issuer, nil}, {"ok no email", p3, args{t5}, http.StatusOK, issuer, nil}, {"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: email "name@example.com" is not allowed`)}, {"fail-key", p1, args{failKey}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; cannot validate oidc token`)}, {"fail-token", p1, args{failTok}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; error parsing oidc token: invalid character '~' looking for beginning of value`)}, {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; error parsing oidc token claims: invalid character '~' looking for beginning of value`)}, {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: go-jose/go-jose/jwt: validation failed, invalid issuer claim (iss)`)}, 
{"fail-audience", p1, args{failAud}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: go-jose/go-jose/jwt: validation failed, invalid audience claim (aud)`)}, {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; cannot validate oidc token`)}, {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: go-jose/go-jose/jwt: validation failed, token is expired (exp)`)}, {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: go-jose/go-jose/jwt: validation failed, token not valid yet (nbf)`)}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.prov.authorizeToken(tt.args.token) if tt.expErr != nil { require.Error(t, err) require.EqualError(t, err, tt.expErr.Error()) var sc render.StatusCodedError require.ErrorAs(t, err, &sc, "error does not implement StatusCodedError interface") require.Equal(t, tt.code, sc.StatusCode()) require.Nil(t, got) } else { require.NotNil(t, got) require.Equal(t, tt.wantIssuer, got.Issuer) } }) } } func TestOIDC_AuthorizeSign(t *testing.T) { srv := generateJWKServer(2) defer srv.Close() var keys jose.JSONWebKeySet assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys)) // Create test provisioners p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() assert.FatalError(t, err) p3, err := generateOIDC() assert.FatalError(t, err) // Admin + Domains p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Domains = []string{"smallstep.com"} // Update configuration endpoints and initialize config := Config{Claims: globalProvisionerClaims} p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p2.ConfigurationEndpoint = srv.URL + 
"/.well-known/openid-configuration" p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p2.Init(config)) assert.FatalError(t, p3.Init(config)) t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) // Admin email not in domains okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) // No email noEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) type args struct { token string } tests := []struct { name string prov *OIDC args args code int wantErr bool }{ {"ok1", p1, args{t1}, http.StatusOK, false}, {"admin", p3, args{okAdmin}, http.StatusOK, false}, {"no-email", p3, args{noEmail}, http.StatusOK, false}, {"bad-token", p3, args{"foobar"}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.prov.AuthorizeSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { assert.Equals(t, 8, len(got)) for _, o := range got { switch v := o.(type) { case *OIDC: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equals(t, v.Type, TypeOIDC) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: assert.Equals(t, v.min, 
tt.prov.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: assert.Len(t, 0, v.webhooks) default: assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } }) } } func TestOIDC_AuthorizeRevoke(t *testing.T) { srv := generateJWKServer(2) defer srv.Close() var keys jose.JSONWebKeySet assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys)) // Create test provisioners p1, err := generateOIDC() assert.FatalError(t, err) p3, err := generateOIDC() assert.FatalError(t, err) // Admin + Domains p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Domains = []string{"smallstep.com"} // Update configuration endpoints and initialize config := Config{Claims: globalProvisionerClaims} p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p3.Init(config)) t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) // Admin email not in domains okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) // Invalid email failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) type args struct { token string } tests := []struct { name string prov *OIDC args args code int wantErr bool }{ {"ok1", p1, args{t1}, http.StatusUnauthorized, true}, {"admin", p3, args{okAdmin}, http.StatusOK, false}, {"fail-email", p3, args{failEmail}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token) if (err != nil) != 
tt.wantErr { fmt.Println(tt) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } } func TestOIDC_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() assert.FatalError(t, err) // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { cert *x509.Certificate } tests := []struct { name string prov *OIDC args args code int wantErr bool }{ {"ok", p1, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusOK, false}, {"fail/renew-disabled", p2, args{&x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert) if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } } func TestOIDC_AuthorizeSSHSign(t *testing.T) { tm, fn := mockNow() defer fn() srv := generateJWKServer(2) defer srv.Close() var keys jose.JSONWebKeySet assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys)) // Create test provisioners p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() assert.FatalError(t, err) p3, err := generateOIDC() assert.FatalError(t, err) p4, err := generateOIDC() assert.FatalError(t, err) p5, err := generateOIDC() assert.FatalError(t, err) p6, err := 
generateOIDC() assert.FatalError(t, err) // Admin + Domains p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Domains = []string{"smallstep.com"} // disable sshCA disable := false p6.Claims = &Claims{EnableSSHCA: &disable} p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) assert.FatalError(t, err) // Update configuration endpoints and initialize config := Config{Claims: globalProvisionerClaims} p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p4.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p5.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p2.Init(config)) assert.FatalError(t, p3.Init(config)) assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } // Additional test needed for empty usernames and duplicate email and usernames t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) okGetIdentityToken, err := generateSimpleToken("the-issuer", p4.ClientID, &keys.Keys[0]) assert.FatalError(t, err) failGetIdentityToken, err := generateSimpleToken("the-issuer", p5.ClientID, &keys.Keys[0]) assert.FatalError(t, err) // Admin email not in domains okAdmin, err := generateOIDCToken("subject", "the-issuer", p3.ClientID, "root@example.com", "", time.Now(), &keys.Keys[0]) assert.FatalError(t, err) // Empty email emptyEmail, err := generateToken("subject", "the-issuer", p1.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) 
expectemptyEmailOptions := &SignSSHOptions{ CertType: "user", Principals: []string{}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(p1.ctl.Claimer.DefaultUserSSHCertDuration())), } assert.FatalError(t, err) key, err := generateJSONWebKey() assert.FatalError(t, err) signer, err := generateJSONWebKey() assert.FatalError(t, err) pub := key.Public().Key rsa2048, err := rsa.GenerateKey(rand.Reader, 2048) assert.FatalError(t, err) //nolint:gosec // tests minimum size of the key rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), } expectedAdminOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"root", "root@example.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), } expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), } type args struct { token string sshOpts SignSSHOptions key interface{} } tests := []struct { name string prov *OIDC args args expected *SignSSHOptions code int wantErr bool wantSignErr bool }{ {"ok", p1, args{t1, SignSSHOptions{}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"ok-rsa2048", p1, args{t1, SignSSHOptions{}, rsa2048.Public()}, expectedUserOptions, http.StatusOK, false, false}, {"ok-user", p1, args{t1, SignSSHOptions{CertType: "user"}, pub}, expectedUserOptions, http.StatusOK, false, false}, {"ok-empty-email", p1, args{emptyEmail, SignSSHOptions{CertType: "user"}, pub}, expectemptyEmailOptions, http.StatusOK, false, false}, {"ok-principals", p1, args{t1, SignSSHOptions{Principals: 
[]string{"name"}}, pub}, &SignSSHOptions{CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-principals-ignore-passed", p1, args{t1, SignSSHOptions{Principals: []string{"root"}}, pub}, &SignSSHOptions{CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-principals-getIdentity", p4, args{okGetIdentityToken, SignSSHOptions{Principals: []string{"mariano"}}, pub}, &SignSSHOptions{CertType: "user", Principals: []string{"max", "mariano"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SignSSHOptions{}, pub}, &SignSSHOptions{CertType: "user", Principals: []string{"max", "mariano"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-options", p1, args{t1, SignSSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, &SignSSHOptions{CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"ok-admin-user", p3, args{okAdmin, SignSSHOptions{CertType: "user", KeyID: "root@example.com", Principals: []string{"root", "root@example.com"}}, pub}, expectedAdminOptions, http.StatusOK, false, false}, {"ok-admin-host", p3, args{okAdmin, SignSSHOptions{CertType: "host", KeyID: "smallstep.com", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, http.StatusOK, false, false}, {"ok-admin-options", p3, args{okAdmin, SignSSHOptions{CertType: "user", KeyID: "name", Principals: []string{"name"}}, pub}, &SignSSHOptions{CertType: "user", Principals: []string{"name"}, 
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, {"fail-user-host", p1, args{t1, SignSSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true}, {"fail-getIdentity", p5, args{failGetIdentityToken, SignSSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false}, {"fail-sshCA-disabled", p6, args{"foo", SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, // Missing parametrs {"fail-admin-type", p3, args{okAdmin, SignSSHOptions{KeyID: "root@example.com", Principals: []string{"root@example.com"}}, pub}, nil, http.StatusUnauthorized, false, true}, {"fail-admin-key-id", p3, args{okAdmin, SignSSHOptions{CertType: "user", Principals: []string{"root@example.com"}}, pub}, nil, http.StatusUnauthorized, false, true}, {"fail-admin-principals", p3, args{okAdmin, SignSSHOptions{CertType: "user", KeyID: "root@example.com"}, pub}, nil, http.StatusUnauthorized, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.prov.AuthorizeSSHSign(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeSSHSign() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) if (err != nil) != tt.wantSignErr { t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) } else { if tt.wantSignErr { assert.Nil(t, cert) } else { assert.NoError(t, validateSSHCertificate(cert, tt.expected)) } } } }) } } func TestOIDC_AuthorizeSSHRevoke(t *testing.T) { p1, err := generateOIDC() 
assert.FatalError(t, err) p2, err := generateOIDC() assert.FatalError(t, err) p2.Admins = []string{"root@example.com"} srv := generateJWKServer(2) defer srv.Close() var keys jose.JSONWebKeySet assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys)) config := Config{Claims: globalProvisionerClaims} p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p2.Init(config)) // Invalid email failEmail, err := generateToken("subject", "the-issuer", p1.ClientID, "", []string{}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) // Admin email not in domains noAdmin, err := generateToken("subject", "the-issuer", p1.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) // Admin email in domains okAdmin, err := generateToken("subject", "the-issuer", p2.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) type args struct { token string } tests := []struct { name string prov *OIDC args args code int wantErr bool }{ {"ok", p2, args{okAdmin}, http.StatusOK, false}, {"fail/invalid-token", p1, args{failEmail}, http.StatusUnauthorized, true}, {"fail/not-admin", p1, args{noAdmin}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.prov.AuthorizeSSHRevoke(context.Background(), tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) } } func Test_sanitizeEmail(t *testing.T) { tests := []struct { name string email string want string }{ {"equal", "name@smallstep.com", 
"name@smallstep.com"}, {"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"}, {"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"}, {"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := sanitizeEmail(tt.email); got != tt.want { t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want) } }) } } func Test_openIDPayload_IsAdmin(t *testing.T) { type fields struct { Email string Groups []string } type args struct { admins []string } tests := []struct { name string fields fields args args want bool }{ {"ok email", fields{"admin@smallstep.com", nil}, args{[]string{"admin@smallstep.com"}}, true}, {"ok email multiple", fields{"admin@smallstep.com", []string{"admin", "eng"}}, args{[]string{"eng@smallstep.com", "admin@smallstep.com"}}, true}, {"ok email sanitized", fields{"admin@Smallstep.com", nil}, args{[]string{"admin@smallStep.com"}}, true}, {"ok group", fields{"", []string{"admin"}}, args{[]string{"admin"}}, true}, {"ok group multiple", fields{"admin@smallstep.com", []string{"engineering", "admin"}}, args{[]string{"admin"}}, true}, {"fail missing", fields{"eng@smallstep.com", []string{"admin"}}, args{[]string{"admin@smallstep.com"}}, false}, {"fail email letter case", fields{"Admin@smallstep.com", []string{}}, args{[]string{"admin@smallstep.com"}}, false}, {"fail group letter case", fields{"", []string{"Admin"}}, args{[]string{"admin"}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &openIDPayload{ Email: tt.fields.Email, Groups: tt.fields.Groups, } if got := o.IsAdmin(tt.args.admins); got != tt.want { t.Errorf("openIDPayload.IsAdmin() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/provisioner/options.go ================================================ package provisioner import ( "encoding/json" "strings" "github.com/pkg/errors" "github.com/smallstep/cli-utils/step" 
"go.step.sm/crypto/jose" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner/wire" ) // CertificateOptions is an interface that returns a list of options passed when // creating a new certificate. type CertificateOptions interface { Options(SignOptions) []x509util.Option } type certificateOptionsFunc func(SignOptions) []x509util.Option func (fn certificateOptionsFunc) Options(so SignOptions) []x509util.Option { return fn(so) } // Options are a collection of custom options that can be added to // each provisioner. type Options struct { X509 *X509Options `json:"x509,omitempty"` SSH *SSHOptions `json:"ssh,omitempty"` // Webhooks is a list of webhooks that can augment template data Webhooks []*Webhook `json:"webhooks,omitempty"` // Wire holds the options used for the ACME Wire integration Wire *wire.Options `json:"wire,omitempty"` } // GetX509Options returns the X.509 options. func (o *Options) GetX509Options() *X509Options { if o == nil { return nil } return o.X509 } // GetSSHOptions returns the SSH options. func (o *Options) GetSSHOptions() *SSHOptions { if o == nil { return nil } return o.SSH } // GetWireOptions returns the Wire options if available. It // returns an error if they're not available. func (o *Options) GetWireOptions() (*wire.Options, error) { if o == nil { return nil, errors.New("no options available") } if o.Wire == nil { return nil, errors.New("no Wire options available") } return o.Wire, nil } // GetWebhooks returns the webhooks options. func (o *Options) GetWebhooks() []*Webhook { if o == nil { return nil } return o.Webhooks } // X509Options contains specific options for X.509 certificates. type X509Options struct { // Template contains a X.509 certificate template. It can be a JSON template // escaped in a string or it can be also encoded in base64. 
Template string `json:"template,omitempty"` // TemplateFile points to a file containing a X.509 certificate template. TemplateFile string `json:"templateFile,omitempty"` // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` // AllowedNames contains the SANs the provisioner is authorized to sign AllowedNames *policy.X509NameOptions `json:"-"` // DeniedNames contains the SANs the provisioner is not authorized to sign DeniedNames *policy.X509NameOptions `json:"-"` // AllowWildcardNames indicates if literal wildcard names // like *.example.com are allowed. Defaults to false. AllowWildcardNames bool `json:"-"` } // HasTemplate returns true if a template is defined in the provisioner options. func (o *X509Options) HasTemplate() bool { return o != nil && (o.Template != "" || o.TemplateFile != "") } // GetAllowedNameOptions returns the AllowedNames, which models the // SANs that a provisioner is authorized to sign x509 certificates for. func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions { if o == nil { return nil } return o.AllowedNames } // GetDeniedNameOptions returns the DeniedNames, which models the // SANs that a provisioner is NOT authorized to sign x509 certificates for. func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions { if o == nil { return nil } return o.DeniedNames } func (o *X509Options) AreWildcardNamesAllowed() bool { if o == nil { return true } return o.AllowWildcardNames } // TemplateOptions generates a CertificateOptions with the template and data // defined in the ProvisionerOptions, the provisioner generated data, and the // user data provided in the request. If no template has been provided, // x509util.DefaultLeafTemplate will be used. 
func TemplateOptions(o *Options, data x509util.TemplateData) (CertificateOptions, error) { return CustomTemplateOptions(o, data, x509util.DefaultLeafTemplate) } // CustomTemplateOptions generates a CertificateOptions with the template, data // defined in the ProvisionerOptions, the provisioner generated data and the // user data provided in the request. If no template has been provided in the // ProvisionerOptions, the given template will be used. func CustomTemplateOptions(o *Options, data x509util.TemplateData, defaultTemplate string) (CertificateOptions, error) { opts := o.GetX509Options() if data == nil { data = x509util.NewTemplateData() } if opts != nil { // Add template data if any. if len(opts.TemplateData) > 0 && string(opts.TemplateData) != "null" { if err := json.Unmarshal(opts.TemplateData, &data); err != nil { return nil, errors.Wrap(err, "error unmarshaling template data") } } } return certificateOptionsFunc(func(so SignOptions) []x509util.Option { // We're not provided user data without custom templates. if !opts.HasTemplate() { return []x509util.Option{ x509util.WithTemplate(defaultTemplate, data), } } // Add user provided data. if len(so.TemplateData) > 0 { userObject := make(map[string]interface{}) if err := json.Unmarshal(so.TemplateData, &userObject); err != nil { data.SetUserData(map[string]interface{}{}) } else { data.SetUserData(userObject) } } // Load a template from a file if Template is not defined. if opts.Template == "" && opts.TemplateFile != "" { return []x509util.Option{ x509util.WithTemplateFile(step.Abs(opts.TemplateFile), data), } } // Load a template from the Template fields // 1. As a JSON in a string. template := strings.TrimSpace(opts.Template) if strings.HasPrefix(template, "{") { return []x509util.Option{ x509util.WithTemplate(template, data), } } // 2. As a base64 encoded JSON. 
return []x509util.Option{ x509util.WithTemplateBase64(template, data), } }), nil } // unsafeParseSigned parses the given token and returns all the claims without // verifying the signature of the token. func unsafeParseSigned(s string) (map[string]interface{}, error) { token, err := jose.ParseSigned(s) if err != nil { return nil, err } claims := make(map[string]interface{}) if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, err } return claims, nil } ================================================ FILE: authority/provisioner/options_test.go ================================================ package provisioner import ( "bytes" "crypto/x509" "encoding/json" "reflect" "testing" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" ) func parseCertificateRequest(t *testing.T, filename string) *x509.CertificateRequest { t.Helper() v, err := pemutil.Read(filename) if err != nil { t.Fatal(err) } csr, ok := v.(*x509.CertificateRequest) if !ok { t.Fatalf("%s is not a certificate request", filename) } return csr } func TestOptions_GetX509Options(t *testing.T) { type fields struct { o *Options } tests := []struct { name string fields fields want *X509Options }{ {"ok", fields{&Options{X509: &X509Options{Template: "foo"}}}, &X509Options{Template: "foo"}}, {"nil", fields{&Options{}}, nil}, {"nilOptions", fields{nil}, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.fields.o.GetX509Options(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Options.GetX509Options() = %v, want %v", got, tt.want) } }) } } func TestOptions_GetSSHOptions(t *testing.T) { type fields struct { o *Options } tests := []struct { name string fields fields want *SSHOptions }{ {"ok", fields{&Options{SSH: &SSHOptions{Template: "foo"}}}, &SSHOptions{Template: "foo"}}, {"nil", fields{&Options{}}, nil}, {"nilOptions", fields{nil}, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.fields.o.GetSSHOptions(); 
!reflect.DeepEqual(got, tt.want) { t.Errorf("Options.GetSSHOptions() = %v, want %v", got, tt.want) } }) } } func TestOptions_GetWebhooks(t *testing.T) { type fields struct { o *Options } tests := []struct { name string fields fields want []*Webhook }{ {"ok", fields{&Options{Webhooks: []*Webhook{ {Name: "foo"}, {Name: "bar"}, }}}, []*Webhook{ {Name: "foo"}, {Name: "bar"}, }, }, {"nil", fields{&Options{}}, nil}, {"nilOptions", fields{nil}, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.fields.o.GetWebhooks(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Options.GetWebhooks() = %v, want %v", got, tt.want) } }) } } func TestProvisionerX509Options_HasTemplate(t *testing.T) { type fields struct { Template string TemplateFile string TemplateData json.RawMessage } tests := []struct { name string fields fields want bool }{ {"template", fields{Template: "the template"}, true}, {"templateFile", fields{TemplateFile: "the template file"}, true}, {"false", fields{}, false}, {"falseWithTemplateData", fields{TemplateData: []byte(`{"foo":"bar"}`)}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &X509Options{ Template: tt.fields.Template, TemplateFile: tt.fields.TemplateFile, TemplateData: tt.fields.TemplateData, } if got := o.HasTemplate(); got != tt.want { t.Errorf("ProvisionerOptions.HasTemplate() = %v, want %v", got, tt.want) } }) } } func TestTemplateOptions(t *testing.T) { csr := parseCertificateRequest(t, "testdata/certs/ecdsa.csr") data := x509util.TemplateData{ x509util.SubjectKey: x509util.Subject{ CommonName: "foobar", }, x509util.SANsKey: []x509util.SubjectAlternativeName{ {Type: "dns", Value: "foo.com"}, }, } type args struct { o *Options data x509util.TemplateData } tests := []struct { name string args args want x509util.Options wantErr bool }{ {"ok", args{nil, data}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{ "subject": {"commonName":"foobar"}, "sans": [{"type":"dns","value":"foo.com"}], 
"keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"okCustomTemplate", args{&Options{X509: &X509Options{Template: x509util.DefaultIIDLeafTemplate}}, data}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{ "subject": {"commonName": "foo"}, "sans": [{"type":"dns","value":"foo.com"}], "keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"fail", args{&Options{X509: &X509Options{TemplateData: []byte(`{"badJSON`)}}, data}, x509util.Options{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cof, err := TemplateOptions(tt.args.o, tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("TemplateOptions() error = %v, wantErr %v", err, tt.wantErr) return } var opts x509util.Options if cof != nil { for _, fn := range cof.Options(SignOptions{}) { if err := fn(csr, &opts); err != nil { t.Errorf("x509util.Options() error = %v", err) return } } } if !reflect.DeepEqual(opts, tt.want) { t.Errorf("x509util.Option = %v, want %v", opts, tt.want) } }) } } func TestCustomTemplateOptions(t *testing.T) { csr := parseCertificateRequest(t, "testdata/certs/ecdsa.csr") csrCertificate := `{"version":0,"subject":{"commonName":"foo"},"rawSubject":"MA4xDDAKBgNVBAMTA2Zvbw==","dnsNames":["foo"],"emailAddresses":null,"ipAddresses":null,"uris":null,"sans":null,"extensions":[{"id":"2.5.29.17","critical":false,"value":"MAWCA2Zvbw=="}],"keyUsage":null,"extKeyUsage":[],"unknownExtKeyUsage":null,"basicConstraints":null,"signatureAlgorithm":""}` data := x509util.TemplateData{ x509util.SubjectKey: x509util.Subject{ CommonName: "foobar", }, x509util.SANsKey: []x509util.SubjectAlternativeName{ {Type: "dns", Value: "foo.com"}, }, } type args struct { o *Options data x509util.TemplateData defaultTemplate string userOptions SignOptions } tests := []struct { name string args args want x509util.Options wantErr bool }{ {"ok", args{nil, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ 
CertBuffer: bytes.NewBufferString(`{ "subject": {"commonName":"foobar"}, "sans": [{"type":"dns","value":"foo.com"}], "keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"okIID", args{nil, data, x509util.DefaultIIDLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{ "subject": {"commonName": "foo"}, "sans": [{"type":"dns","value":"foo.com"}], "keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"okNoData", args{&Options{}, nil, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{ "subject": null, "sans": null, "keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"okTemplateData", args{&Options{X509: &X509Options{TemplateData: []byte(`{"foo":"bar"}`)}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{ "subject": {"commonName":"foobar"}, "sans": [{"type":"dns","value":"foo.com"}], "keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"okTemplate", args{&Options{X509: &X509Options{Template: "{{ toJson .Insecure.CR }}"}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(csrCertificate)}, false}, {"okFile", args{&Options{X509: &X509Options{TemplateFile: "./testdata/templates/cr.tpl"}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(csrCertificate)}, false}, {"okBase64", args{&Options{X509: &X509Options{Template: "e3sgdG9Kc29uIC5JbnNlY3VyZS5DUiB9fQ=="}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(csrCertificate)}, false}, {"okUserOptions", args{&Options{X509: &X509Options{Template: `{"foo": "{{.Insecure.User.foo}}"}`}}, data, x509util.DefaultLeafTemplate, SignOptions{TemplateData: []byte(`{"foo":"bar"}`)}}, 
x509util.Options{ CertBuffer: bytes.NewBufferString(`{"foo": "bar"}`), }, false}, {"okBadUserOptions", args{&Options{X509: &X509Options{Template: `{"foo": "{{.Insecure.User.foo}}"}`}}, data, x509util.DefaultLeafTemplate, SignOptions{TemplateData: []byte(`{"badJSON"}`)}}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{"foo": ""}`), }, false}, {"okNullTemplateData", args{&Options{X509: &X509Options{TemplateData: []byte(`null`)}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{ CertBuffer: bytes.NewBufferString(`{ "subject": {"commonName":"foobar"}, "sans": [{"type":"dns","value":"foo.com"}], "keyUsage": ["digitalSignature"], "extKeyUsage": ["serverAuth", "clientAuth"] }`)}, false}, {"fail", args{&Options{X509: &X509Options{TemplateData: []byte(`{"badJSON`)}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{}, true}, {"failTemplateData", args{&Options{X509: &X509Options{TemplateData: []byte(`{"badJSON}`)}}, data, x509util.DefaultLeafTemplate, SignOptions{}}, x509util.Options{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cof, err := CustomTemplateOptions(tt.args.o, tt.args.data, tt.args.defaultTemplate) if (err != nil) != tt.wantErr { t.Errorf("CustomTemplateOptions() error = %v, wantErr %v", err, tt.wantErr) return } var opts x509util.Options if cof != nil { for _, fn := range cof.Options(tt.args.userOptions) { if err := fn(csr, &opts); err != nil { t.Errorf("x509util.Options() error = %v", err) return } } } if !reflect.DeepEqual(opts, tt.want) { t.Errorf("x509util.Option = %v, want %v", opts, tt.want) } }) } } func Test_unsafeParseSigned(t *testing.T) { //nolint:gosec // no credentials here okToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJqYW5lQGRvZS5jb20iLCJpc3MiOiJodHRwczovL2RvZS5jb20iLCJqdGkiOiI4ZmYzMjQ4MS1mZDVmLTRlMmUtOTZkZi05MDhjMTI3Yzg1ZjciLCJpYXQiOjE1OTUzNjAwMjgsImV4cCI6MTU5NTM2MzYyOH0.aid8UuhFucJOFHXaob9zpNtVvhul9ulTGsA52mU6XIw" type args struct { s string } tests := 
[]struct { name string args args want map[string]interface{} wantErr bool }{ {"ok", args{okToken}, map[string]interface{}{ "sub": "jane@doe.com", "iss": "https://doe.com", "jti": "8ff32481-fd5f-4e2e-96df-908c127c85f7", "iat": float64(1595360028), "exp": float64(1595363628), }, false}, {"failToken", args{"foobar"}, nil, true}, {"failPayload", args{"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.ew.aid8UuhFucJOFHXaob9zpNtVvhul9ulTGsA52mU6XIw"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := unsafeParseSigned(tt.args.s) if (err != nil) != tt.wantErr { t.Errorf("unsafeParseSigned() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("unsafeParseSigned() = \n%v, want \n%v", got, tt.want) } }) } } func TestX509Options_IsWildcardLiteralAllowed(t *testing.T) { tests := []struct { name string options *X509Options want bool }{ { name: "nil-options", options: nil, want: true, }, { name: "set-true", options: &X509Options{ AllowWildcardNames: true, }, want: true, }, { name: "set-false", options: &X509Options{ AllowWildcardNames: false, }, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.options.AreWildcardNamesAllowed(); got != tt.want { t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/provisioner/policy.go ================================================ package provisioner import "github.com/smallstep/certificates/authority/policy" type policyEngine struct { x509Policy policy.X509Policy sshHostPolicy policy.HostPolicy sshUserPolicy policy.UserPolicy } func newPolicyEngine(options *Options) (*policyEngine, error) { if options == nil { //nolint:nilnil // legacy return nil, nil } var ( x509Policy policy.X509Policy sshHostPolicy policy.HostPolicy sshUserPolicy policy.UserPolicy err error ) // Initialize the x509 allow/deny policy engine if x509Policy, 
err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil { return nil, err } // Initialize the SSH allow/deny policy engine for host certificates if sshHostPolicy, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil { return nil, err } // Initialize the SSH allow/deny policy engine for user certificates if sshUserPolicy, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil { return nil, err } return &policyEngine{ x509Policy: x509Policy, sshHostPolicy: sshHostPolicy, sshUserPolicy: sshUserPolicy, }, nil } func (p *policyEngine) getX509() policy.X509Policy { if p == nil { return nil } return p.x509Policy } func (p *policyEngine) getSSHHost() policy.HostPolicy { if p == nil { return nil } return p.sshHostPolicy } func (p *policyEngine) getSSHUser() policy.UserPolicy { if p == nil { return nil } return p.sshUserPolicy } ================================================ FILE: authority/provisioner/provisioner.go ================================================ package provisioner import ( "context" "crypto/x509" "encoding/json" stderrors "errors" "net/http" "net/url" "strings" "github.com/pkg/errors" kmsapi "go.step.sm/crypto/kms/apiv1" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/errs" ) // Interface is the interface that all provisioner types must implement. 
type Interface interface { GetID() string GetIDForToken() string GetTokenID(token string) (string, error) GetName() string GetType() Type GetEncryptedKey() (kid string, key string, ok bool) Init(config Config) error AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) AuthorizeRevoke(ctx context.Context, token string) error AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) AuthorizeSSHRevoke(ctx context.Context, token string) error AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) } // HTTPClient is the interface implemented by the HTTP clients used by the // provisioners. type HTTPClient interface { Get(string) (*http.Response, error) Do(*http.Request) (*http.Response, error) } // Uninitialized represents a disabled provisioner. Uninitialized provisioners // are created when the Init methods fails. type Uninitialized struct { Interface Reason error } // MarshalJSON returns the JSON encoding of the provisioner with the disabled // reason. func (p Uninitialized) MarshalJSON() ([]byte, error) { provisionerJSON, err := json.Marshal(p.Interface) if err != nil { return nil, err } reasonJSON, err := json.Marshal(struct { State string `json:"state"` StateReason string `json:"stateReason"` }{"Uninitialized", p.Reason.Error()}) if err != nil { return nil, err } reasonJSON[0] = ',' return append(provisionerJSON[:len(provisionerJSON)-1], reasonJSON...), nil } // ErrAllowTokenReuse is an error that is returned by provisioners that allows // the reuse of tokens. // // This is, for example, returned by the Azure provisioner when // DisableTrustOnFirstUse is set to true. Azure caches tokens for up to 24hr and // has no mechanism for getting a different token - this can be an issue when // rebooting a VM. 
// Audiences stores all supported audiences by request type.
type Audiences struct {
	Sign      []string
	Renew     []string
	Revoke    []string
	SSHSign   []string
	SSHRevoke []string
	SSHRenew  []string
	SSHRekey  []string
}

// All returns all supported audiences across all request types in one list,
// in the order Sign, Renew, Revoke, SSHSign, SSHRevoke, SSHRenew, SSHRekey.
//
// The result is always a freshly allocated slice: it never shares a backing
// array with any field of a, so neither building nor modifying it can touch
// the receiver's data. It returns nil when there are no audiences at all.
func (a Audiences) All() (auds []string) {
	lists := [...][]string{
		a.Sign, a.Renew, a.Revoke,
		a.SSHSign, a.SSHRevoke, a.SSHRenew, a.SSHRekey,
	}

	var total int
	for _, list := range lists {
		total += len(list)
	}
	if total == 0 {
		return nil
	}

	auds = make([]string, 0, total)
	for _, list := range lists {
		auds = append(auds, list...)
	}
	return auds
}

// WithFragment returns a copy of audiences where the url audiences contains the
// given fragment. Audiences that cannot be parsed as URLs are copied unchanged.
func (a Audiences) WithFragment(fragment string) Audiences {
	return Audiences{
		Sign:      audiencesWithFragment(a.Sign, fragment),
		Renew:     audiencesWithFragment(a.Renew, fragment),
		Revoke:    audiencesWithFragment(a.Revoke, fragment),
		SSHSign:   audiencesWithFragment(a.SSHSign, fragment),
		SSHRevoke: audiencesWithFragment(a.SSHRevoke, fragment),
		SSHRenew:  audiencesWithFragment(a.SSHRenew, fragment),
		SSHRekey:  audiencesWithFragment(a.SSHRekey, fragment),
	}
}

// audiencesWithFragment returns a new slice, with the same length as auds,
// where every entry that parses as a URL has its fragment set to fragment.
func audiencesWithFragment(auds []string, fragment string) []string {
	out := make([]string, len(auds))
	ref := &url.URL{Fragment: fragment}
	for i, s := range auds {
		u, err := url.Parse(s)
		if err != nil {
			// Not a URL: keep the audience as is.
			out[i] = s
			continue
		}
		out[i] = u.ResolveReference(ref).String()
	}
	return out
}
// Type indicates the provisioner Type.
type Type int

const (
	noopType Type = 0
	// TypeJWK is used to indicate the JWK provisioners.
	TypeJWK Type = 1
	// TypeOIDC is used to indicate the OIDC provisioners.
	TypeOIDC Type = 2
	// TypeGCP is used to indicate the GCP provisioners.
	TypeGCP Type = 3
	// TypeAWS is used to indicate the AWS provisioners.
	TypeAWS Type = 4
	// TypeAzure is used to indicate the Azure provisioners.
	TypeAzure Type = 5
	// TypeACME is used to indicate the ACME provisioners.
	TypeACME Type = 6
	// TypeX5C is used to indicate the X5C provisioners.
	TypeX5C Type = 7
	// TypeK8sSA is used to indicate the K8sSA provisioners.
	TypeK8sSA Type = 8
	// TypeSSHPOP is used to indicate the SSHPOP provisioners.
	TypeSSHPOP Type = 9
	// TypeSCEP is used to indicate the SCEP provisioners
	TypeSCEP Type = 10
	// TypeNebula is used to indicate the Nebula provisioners
	TypeNebula Type = 11
)

// String returns the string representation of the type. Values without a
// name, including noopType and anything outside the declared constants,
// yield the empty string.
func (t Type) String() string {
	// Indexed by Type value; the noopType slot is left as "".
	names := [...]string{
		TypeJWK:    "JWK",
		TypeOIDC:   "OIDC",
		TypeGCP:    "GCP",
		TypeAWS:    "AWS",
		TypeAzure:  "Azure",
		TypeACME:   "ACME",
		TypeX5C:    "X5C",
		TypeK8sSA:  "K8sSA",
		TypeSSHPOP: "SSHPOP",
		TypeSCEP:   "SCEP",
		TypeNebula: "Nebula",
	}
	if t < 0 || int(t) >= len(names) {
		return ""
	}
	return names[t]
}
SSHKeys *SSHKeys // GetIdentityFunc is a function that returns an identity that will be // used by the provisioner to populate certificate attributes. GetIdentityFunc GetIdentityFunc // AuthorizeRenewFunc is a function that returns nil if a given X.509 // certificate can be renewed. AuthorizeRenewFunc AuthorizeRenewFunc // AuthorizeSSHRenewFunc is a function that returns nil if a given SSH // certificate can be renewed. AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc // WebhookClient is an HTTP client used when performing webhook requests. WebhookClient HTTPClient // SCEPKeyManager, if defined, is the interface used by SCEP provisioners. SCEPKeyManager SCEPKeyManager // HTTPClient is an HTTP client that trusts the system cert pool and the CA // roots. HTTPClient HTTPClient // WrapTransport references the function that should wrap any [http.Transport] initialized // down the Config's chain. WrapTransport TransportWrapper } type provisioner struct { Type string `json:"type"` } // List represents a list of provisioners. type List []Interface // UnmarshalJSON implements json.Unmarshaler and allows to unmarshal a list of a // interfaces into the right type. func (l *List) UnmarshalJSON(data []byte) error { ps := []json.RawMessage{} if err := json.Unmarshal(data, &ps); err != nil { return errors.Wrap(err, "error unmarshaling provisioner list") } *l = List{} for _, data := range ps { var typ provisioner if err := json.Unmarshal(data, &typ); err != nil { return errors.Errorf("error unmarshaling provisioner") } var p Interface switch strings.ToLower(typ.Type) { case "jwk": p = &JWK{} case "oidc": p = &OIDC{} case "gcp": p = &GCP{} case "aws": p = &AWS{} case "azure": p = &Azure{} case "acme": p = &ACME{} case "x5c": p = &X5C{} case "k8ssa": p = &K8sSA{} case "sshpop": p = &SSHPOP{} case "scep": p = &SCEP{} case "nebula": p = &Nebula{} default: // Skip unsupported provisioners. 
A client using this method may be // compiled with a version of smallstep/certificates that does not // support a specific provisioner type. If we don't skip unknown // provisioners, a client encountering an unknown provisioner will // break. Rather than break the client, we skip the provisioner. // TODO: accept a pluggable logger (depending on client) that can // warn the user that an unknown provisioner was found and suggest // that the user update their client's dependency on // step/certificates and recompile. continue } if err := json.Unmarshal(data, p); err != nil { return errors.Wrap(err, "error unmarshaling provisioner") } *l = append(*l, p) } return nil } type base struct{} // AuthorizeSign returns an unimplemented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing x509 Certificates. func (b *base) AuthorizeSign(context.Context, string) ([]SignOption, error) { return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented") } // AuthorizeRevoke returns an unimplemented error. Provisioners should overwrite // this method if they will support authorizing tokens for revoking x509 Certificates. func (b *base) AuthorizeRevoke(context.Context, string) error { return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented") } // AuthorizeRenew returns an unimplemented error. Provisioners should overwrite // this method if they will support authorizing tokens for renewing x509 Certificates. func (b *base) AuthorizeRenew(context.Context, *x509.Certificate) error { return errs.Unauthorized("provisioner.AuthorizeRenew not implemented") } // AuthorizeSSHSign returns an unimplemented error. Provisioners should overwrite // this method if they will support authorizing tokens for signing SSH Certificates. 
func (b *base) AuthorizeSSHSign(context.Context, string) ([]SignOption, error) {
	return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
}

// AuthorizeSSHRevoke returns an unimplemented error. Provisioners should override
// this method if they will support authorizing tokens for revoking SSH Certificates.
func (b *base) AuthorizeSSHRevoke(context.Context, string) error {
	return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
}

// AuthorizeSSHRenew returns an unimplemented error. Provisioners should override
// this method if they will support authorizing tokens for renewing SSH Certificates.
func (b *base) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) {
	return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
}

// AuthorizeSSHRekey returns an unimplemented error. Provisioners should override
// this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []SignOption, error) {
	return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
}

// Permissions defines extra extensions and critical options to grant to an SSH certificate.
type Permissions struct {
	// Extensions is serialized under the "extensions" JSON key.
	Extensions map[string]string `json:"extensions"`
	// CriticalOptions is serialized under the "criticalOptions" JSON key.
	CriticalOptions map[string]string `json:"criticalOptions"`
}

// RAInfo is the information about a provisioner present in RA tokens generated
// by StepCAS. Every field is omitted from the JSON encoding when empty.
type RAInfo struct {
	AuthorityID     string `json:"authorityId,omitempty"`
	EndpointID      string `json:"endpointId,omitempty"`
	ProvisionerID   string `json:"provisionerId,omitempty"`
	ProvisionerType string `json:"provisionerType,omitempty"`
	ProvisionerName string `json:"provisionerName,omitempty"`
}

// raProvisioner wraps a provisioner with RA data. The embedded Interface
// supplies the provisioner behavior; raInfo carries the RA details.
type raProvisioner struct {
	Interface
	raInfo *RAInfo
}

// RAInfo returns the RAInfo in the wrapped provisioner.
func (p *raProvisioner) RAInfo() *RAInfo { return p.raInfo } // MockProvisioner for testing type MockProvisioner struct { Mret1, Mret2, Mret3 interface{} Merr error MgetID func() string MgetIDForToken func() string MgetTokenID func(string) (string, error) MgetName func() string MgetType func() Type MgetEncryptedKey func() (string, string, bool) Minit func(Config) error MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error) MauthorizeRenew func(ctx context.Context, cert *x509.Certificate) error MauthorizeRevoke func(ctx context.Context, ott string) error MauthorizeSSHSign func(ctx context.Context, ott string) ([]SignOption, error) MauthorizeSSHRenew func(ctx context.Context, ott string) (*ssh.Certificate, error) MauthorizeSSHRekey func(ctx context.Context, ott string) (*ssh.Certificate, []SignOption, error) MauthorizeSSHRevoke func(ctx context.Context, ott string) error } // GetID mock func (m *MockProvisioner) GetID() string { if m.MgetID != nil { return m.MgetID() } return m.Mret1.(string) } // GetIDForToken mock func (m *MockProvisioner) GetIDForToken() string { if m.MgetIDForToken != nil { return m.MgetIDForToken() } return m.Mret1.(string) } // GetTokenID mock func (m *MockProvisioner) GetTokenID(token string) (string, error) { if m.MgetTokenID != nil { return m.MgetTokenID(token) } if m.Mret1 == nil { return "", m.Merr } return m.Mret1.(string), m.Merr } // GetName mock func (m *MockProvisioner) GetName() string { if m.MgetName != nil { return m.MgetName() } return m.Mret1.(string) } // GetType mock func (m *MockProvisioner) GetType() Type { if m.MgetType != nil { return m.MgetType() } return m.Mret1.(Type) } // GetEncryptedKey mock func (m *MockProvisioner) GetEncryptedKey() (string, string, bool) { if m.MgetEncryptedKey != nil { return m.MgetEncryptedKey() } return m.Mret1.(string), m.Mret2.(string), m.Mret3.(bool) } // Init mock func (m *MockProvisioner) Init(c Config) error { if m.Minit != nil { return m.Minit(c) } return m.Merr } // 
AuthorizeSign mock func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) { if m.MauthorizeSign != nil { return m.MauthorizeSign(ctx, ott) } return m.Mret1.([]SignOption), m.Merr } // AuthorizeRevoke mock func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, ott string) error { if m.MauthorizeRevoke != nil { return m.MauthorizeRevoke(ctx, ott) } return m.Merr } // AuthorizeRenew mock func (m *MockProvisioner) AuthorizeRenew(ctx context.Context, c *x509.Certificate) error { if m.MauthorizeRenew != nil { return m.MauthorizeRenew(ctx, c) } return m.Merr } // AuthorizeSSHSign mock func (m *MockProvisioner) AuthorizeSSHSign(ctx context.Context, ott string) ([]SignOption, error) { if m.MauthorizeSign != nil { return m.MauthorizeSign(ctx, ott) } return m.Mret1.([]SignOption), m.Merr } // AuthorizeSSHRenew mock func (m *MockProvisioner) AuthorizeSSHRenew(ctx context.Context, ott string) (*ssh.Certificate, error) { if m.MauthorizeRenew != nil { return m.MauthorizeSSHRenew(ctx, ott) } return m.Mret1.(*ssh.Certificate), m.Merr } // AuthorizeSSHRekey mock func (m *MockProvisioner) AuthorizeSSHRekey(ctx context.Context, ott string) (*ssh.Certificate, []SignOption, error) { if m.MauthorizeSSHRekey != nil { return m.MauthorizeSSHRekey(ctx, ott) } return m.Mret1.(*ssh.Certificate), m.Mret2.([]SignOption), m.Merr } // AuthorizeSSHRevoke mock func (m *MockProvisioner) AuthorizeSSHRevoke(ctx context.Context, ott string) error { if m.MauthorizeSSHRevoke != nil { return m.MauthorizeSSHRevoke(ctx, ott) } return m.Merr } ================================================ FILE: authority/provisioner/provisioner_test.go ================================================ package provisioner import ( "context" "errors" "net/http" "testing" "github.com/go-jose/go-jose/v3" "github.com/smallstep/certificates/api/render" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh" ) func TestType_String(t *testing.T) { tests := []struct { name 
string t Type want string }{ {"JWK", TypeJWK, "JWK"}, {"OIDC", TypeOIDC, "OIDC"}, {"AWS", TypeAWS, "AWS"}, {"Azure", TypeAzure, "Azure"}, {"GCP", TypeGCP, "GCP"}, {"noop", noopType, ""}, {"notFound", 1000, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.t.String(); got != tt.want { t.Errorf("Type.String() = %v, want %v", got, tt.want) } }) } } func TestSanitizeSSHUserPrincipal(t *testing.T) { type args struct { email string } tests := []struct { name string args args want string }{ {"simple", args{"foobar"}, "foobar"}, {"camelcase", args{"FooBar"}, "foobar"}, {"email", args{"foo@example.com"}, "foo"}, {"email with dots", args{"foo.bar.zar@example.com"}, "foobarzar"}, {"email with dashes", args{"foo-bar-zar@example.com"}, "foo-bar-zar"}, {"email with underscores", args{"foo_bar_zar@example.com"}, "foo_bar_zar"}, {"email with symbols", args{"Foo.Bar0123456789!#$%&'*+-/=?^_`{|}~;@example.com"}, "foobar0123456789________-___________"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := SanitizeSSHUserPrincipal(tt.args.email); got != tt.want { t.Errorf("SanitizeSSHUserPrincipal() = %v, want %v", got, tt.want) } }) } } func TestDefaultIdentityFunc(t *testing.T) { type test struct { p Interface email string usernames []string err error identity *Identity } tests := map[string]func(*testing.T) test{ "fail/unsupported-provisioner": func(t *testing.T) test { return test{ p: &X5C{}, err: errors.New("provisioner type '*provisioner.X5C' not supported by identity function"), } }, "ok": func(t *testing.T) test { return test{ p: &OIDC{}, email: "max.furman@smallstep.com", identity: &Identity{Usernames: []string{"maxfurman", "max.furman", "max.furman@smallstep.com"}}, } }, "ok letter case": func(t *testing.T) test { return test{ p: &OIDC{}, email: "Max.Furman@smallstep.com", identity: &Identity{Usernames: []string{"maxfurman", "Max.Furman", "Max.Furman@smallstep.com"}}, } }, "ok simple": func(t *testing.T) test { return 
test{ p: &OIDC{}, email: "john@smallstep.com", identity: &Identity{Usernames: []string{"john", "john@smallstep.com"}}, } }, "ok simple letter case": func(t *testing.T) test { return test{ p: &OIDC{}, email: "John@smallstep.com", identity: &Identity{Usernames: []string{"john", "John", "John@smallstep.com"}}, } }, "ok symbol": func(t *testing.T) test { return test{ p: &OIDC{}, email: "John+Doe@smallstep.com", identity: &Identity{Usernames: []string{"john_doe", "John+Doe", "John+Doe@smallstep.com"}}, } }, "ok username": func(t *testing.T) test { return test{ p: &OIDC{}, email: "john@smallstep.com", usernames: []string{"johnny"}, identity: &Identity{Usernames: []string{"john", "john@smallstep.com"}}, } }, "ok usernames": func(t *testing.T) test { return test{ p: &OIDC{}, email: "john@smallstep.com", usernames: []string{"johnny", "js", "", "johnny", ""}, identity: &Identity{Usernames: []string{"john", "john@smallstep.com"}}, } }, "ok empty username": func(t *testing.T) test { return test{ p: &OIDC{}, email: "john@smallstep.com", usernames: []string{""}, identity: &Identity{Usernames: []string{"john", "john@smallstep.com"}}, } }, "ok/badname": func(t *testing.T) test { return test{ p: &OIDC{}, email: "$%^#_>@smallstep.com", identity: &Identity{Usernames: []string{"______", "$%^#_>", "$%^#_>@smallstep.com"}}, } }, } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email) if err != nil { if assert.NotNil(t, tc.err) { assert.Equal(t, tc.err.Error(), err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equal(t, identity.Usernames, tc.identity.Usernames) } } }) } } func TestUnimplementedMethods(t *testing.T) { tests := []struct { name string p Interface method Method }{ {"jwk/sshRekey", &JWK{}, SSHRekeyMethod}, {"jwk/sshRenew", &JWK{}, SSHRenewMethod}, {"aws/revoke", &AWS{}, RevokeMethod}, {"aws/sshRenew", &AWS{}, SSHRenewMethod}, {"aws/rekey", &AWS{}, SSHRekeyMethod}, 
{"aws/sshRevoke", &AWS{}, SSHRevokeMethod}, {"azure/revoke", &Azure{}, RevokeMethod}, {"azure/sshRenew", &Azure{}, SSHRenewMethod}, {"azure/sshRekey", &Azure{}, SSHRekeyMethod}, {"azure/sshRevoke", &Azure{}, SSHRevokeMethod}, {"gcp/revoke", &GCP{}, RevokeMethod}, {"gcp/sshRenew", &GCP{}, SSHRenewMethod}, {"gcp/sshRekey", &GCP{}, SSHRekeyMethod}, {"gcp/sshRevoke", &GCP{}, SSHRevokeMethod}, {"oidc/sshRenew", &OIDC{}, SSHRenewMethod}, {"oidc/sshRekey", &OIDC{}, SSHRekeyMethod}, {"x5c/sshRenew", &X5C{}, SSHRenewMethod}, {"x5c/sshRekey", &X5C{}, SSHRekeyMethod}, {"x5c/sshRevoke", &X5C{}, SSHRekeyMethod}, {"acme/sshSign", &ACME{}, SSHSignMethod}, {"acme/sshRekey", &ACME{}, SSHRekeyMethod}, {"acme/sshRenew", &ACME{}, SSHRenewMethod}, {"acme/sshRevoke", &ACME{}, SSHRevokeMethod}, {"sshpop/sign", &SSHPOP{}, SignMethod}, {"sshpop/renew", &SSHPOP{}, RenewMethod}, {"sshpop/revoke", &SSHPOP{}, RevokeMethod}, {"sshpop/sshSign", &SSHPOP{}, SSHSignMethod}, {"k8ssa/sshRekey", &K8sSA{}, SSHRekeyMethod}, {"k8ssa/sshRenew", &K8sSA{}, SSHRenewMethod}, {"k8ssa/sshRevoke", &K8sSA{}, SSHRevokeMethod}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var ( err error msg string ) switch tt.method { case SignMethod: var signOpts []SignOption signOpts, err = tt.p.AuthorizeSign(context.Background(), "") assert.Nil(t, signOpts) msg = "provisioner.AuthorizeSign not implemented" case RenewMethod: err = tt.p.AuthorizeRenew(context.Background(), nil) msg = "provisioner.AuthorizeRenew not implemented" case RevokeMethod: err = tt.p.AuthorizeRevoke(context.Background(), "") msg = "provisioner.AuthorizeRevoke not implemented" case SSHSignMethod: var signOpts []SignOption signOpts, err = tt.p.AuthorizeSSHSign(context.Background(), "") assert.Nil(t, signOpts) msg = "provisioner.AuthorizeSSHSign not implemented" case SSHRenewMethod: var cert *ssh.Certificate cert, err = tt.p.AuthorizeSSHRenew(context.Background(), "") assert.Nil(t, cert) msg = "provisioner.AuthorizeSSHRenew not 
implemented" case SSHRekeyMethod: var ( cert *ssh.Certificate signOpts []SignOption ) cert, signOpts, err = tt.p.AuthorizeSSHRekey(context.Background(), "") assert.Nil(t, cert) assert.Nil(t, signOpts) msg = "provisioner.AuthorizeSSHRekey not implemented" case SSHRevokeMethod: err = tt.p.AuthorizeSSHRevoke(context.Background(), "") msg = "provisioner.AuthorizeSSHRevoke not implemented" default: t.Errorf("unexpected method %s", tt.method) } var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equal(t, http.StatusUnauthorized, sc.StatusCode()) } assert.Equal(t, msg, err.Error()) }) } } func TestUninitialized_MarshalJSON(t *testing.T) { p := &JWK{ Name: "bad-provisioner", Type: "JWK", Key: &jose.JSONWebKey{ Key: []byte("foo"), }, } type fields struct { Interface Interface Reason error } tests := []struct { name string fields fields want []byte assertion assert.ErrorAssertionFunc }{ {"ok", fields{p, errors.New("bad key")}, []byte(`{"type":"JWK","name":"bad-provisioner","key":{"kty":"oct","k":"Zm9v"},"state":"Uninitialized","stateReason":"bad key"}`), assert.NoError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := Uninitialized{ Interface: tt.fields.Interface, Reason: tt.fields.Reason, } got, err := p.MarshalJSON() tt.assertion(t, err) assert.Equal(t, tt.want, got) }) } } ================================================ FILE: authority/provisioner/scep.go ================================================ package provisioner import ( "context" "crypto" "crypto/rsa" "crypto/subtle" "crypto/x509" "encoding/pem" "fmt" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" "go.step.sm/crypto/kms" kmsapi "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/webhook" ) // SCEP is the SCEP provisioner type, an entity that can authorize the // SCEP provisioning flow type 
SCEP struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` ChallengePassword string `json:"challenge,omitempty"` Capabilities []string `json:"capabilities,omitempty"` // IncludeRoot makes the provisioner return the CA root in addition to the // intermediate in the GetCACerts response IncludeRoot bool `json:"includeRoot,omitempty"` // ExcludeIntermediate makes the provisioner skip the intermediate CA in the // GetCACerts response ExcludeIntermediate bool `json:"excludeIntermediate,omitempty"` // MinimumPublicKeyLength is the minimum length for public keys in CSRs MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` // TODO(hs): also support a separate signer configuration? DecrypterCertificate []byte `json:"decrypterCertificate,omitempty"` DecrypterKeyPEM []byte `json:"decrypterKeyPEM,omitempty"` DecrypterKeyURI string `json:"decrypterKey,omitempty"` DecrypterKeyPassword string `json:"decrypterKeyPassword,omitempty"` // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` Options *Options `json:"options,omitempty"` Claims *Claims `json:"claims,omitempty"` ctl *Controller encryptionAlgorithm int challengeValidationController *challengeValidationController notificationController *notificationController keyManager SCEPKeyManager decrypter crypto.Decrypter decrypterCertificate *x509.Certificate signer crypto.Signer signerCertificate *x509.Certificate } // GetID returns the provisioner unique identifier. func (s *SCEP) GetID() string { if s.ID != "" { return s.ID } return s.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. 
func (s *SCEP) GetIDForToken() string {
	return "scep/" + s.Name
}

// GetName returns the name of the provisioner.
func (s *SCEP) GetName() string {
	return s.Name
}

// GetType returns the type of provisioner.
func (s *SCEP) GetType() Type {
	return TypeSCEP
}

// GetEncryptedKey always returns empty strings and false; this provisioner
// does not expose an encrypted key.
func (s *SCEP) GetEncryptedKey() (string, string, bool) {
	return "", "", false
}

// GetTokenID returns the identifier of the token. This provisioner will always
// return [ErrTokenFlowNotSupported].
func (s *SCEP) GetTokenID(string) (string, error) {
	return "", ErrTokenFlowNotSupported
}

// GetOptions returns the configured provisioner options.
func (s *SCEP) GetOptions() *Options {
	return s.Options
}

// DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. It reads the controller's claimer, so the provisioner must
// have been initialized first.
func (s *SCEP) DefaultTLSCertDuration() time.Duration {
	return s.ctl.Claimer.DefaultTLSCertDuration()
}

// challengeValidationController validates SCEP challenges by calling the
// webhooks selected at construction time.
type challengeValidationController struct {
	client        HTTPClient
	wrapTransport httptransport.Wrapper
	webhooks      []*Webhook
}

// newChallengeValidationController creates a new challengeValidationController
// that performs challenge validation through webhooks. Only webhooks of kind
// SCEPCHALLENGE whose certificate type passes isCertTypeOK are kept.
func newChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController {
	scepHooks := []*Webhook{}
	for _, wh := range webhooks {
		if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() {
			continue
		}
		if !isCertTypeOK(wh) {
			continue
		}
		scepHooks = append(scepHooks, wh)
	}
	return &challengeValidationController{
		client:        client,
		wrapTransport: tw,
		webhooks:      scepHooks,
	}
}

var (
	// ErrSCEPChallengeInvalid is returned by Validate when no webhook allowed
	// the request.
	ErrSCEPChallengeInvalid = errors.New("webhook server did not allow request")
	// ErrSCEPNotificationFailed is wrapped into the error returned when a
	// notification webhook request fails.
	ErrSCEPNotificationFailed = errors.New("scep notification failed")
)

// Validate executes zero or more configured webhooks to
// validate the SCEP challenge. Every configured webhook is
// called in order; each one that allows the request contributes
// a template data modifier that stores the webhook's response
// data under the webhook's name. If creating or executing a
// request fails, validation stops and that error is returned.
// If none of the webhooks indicates the value of the challenge
// was accepted, ErrSCEPChallengeInvalid is returned.
func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.CertificateRequest, provisionerName, challenge, transactionID string) ([]SignCSROption, error) {
	var opts []SignCSROption

	for _, wh := range c.webhooks {
		req, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
		if err != nil {
			return nil, fmt.Errorf("failed creating new webhook request: %w", err)
		}
		req.ProvisionerName = provisionerName
		req.SCEPChallenge = challenge
		req.SCEPTransactionID = transactionID
		resp, err := wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil) // TODO(hs): support templated URL? Requires some refactoring
		if err != nil {
			return nil, fmt.Errorf("failed executing webhook request: %w", err)
		}
		if resp.Allow {
			// The modifier runs later, when the template data is built.
			opts = append(opts, TemplateDataModifierFunc(func(data x509util.TemplateData) {
				data.SetWebhook(wh.Name, resp.Data)
			}))
		}
	}

	// No option collected means no webhook allowed the challenge.
	if len(opts) == 0 {
		return nil, ErrSCEPChallengeInvalid
	}

	return opts, nil
}

// notificationController sends SCEP success and failure notifications to the
// webhooks selected at construction time.
type notificationController struct {
	client        HTTPClient
	wrapTransport httptransport.Wrapper
	webhooks      []*Webhook
}

// newNotificationController creates a new notificationController
// that performs SCEP notifications through webhooks.
func newNotificationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_NOTIFYING.String() { continue } if !isCertTypeOK(wh) { continue } scepHooks = append(scepHooks, wh) } return ¬ificationController{ client: client, wrapTransport: tw, webhooks: scepHooks, } } func (c *notificationController) Success(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error { for _, wh := range c.webhooks { req, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr), webhook.WithX509Certificate(nil, cert)) // TODO(hs): pass in the x509util.Certifiate too? if err != nil { return fmt.Errorf("failed creating new webhook request: %w", err) } req.X509Certificate.Raw = cert.Raw // adding the full certificate DER bytes req.SCEPTransactionID = transactionID if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil { return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err) } } return nil } func (c *notificationController) Failure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error { for _, wh := range c.webhooks { req, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) if err != nil { return fmt.Errorf("failed creating new webhook request: %w", err) } req.SCEPTransactionID = transactionID req.SCEPErrorCode = errorCode req.SCEPErrorDescription = errorDescription if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil { return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err) } } return nil } // isCertTypeOK returns whether or not the webhook can be used // with the SCEP challenge validation webhook controller. 
func isCertTypeOK(wh *Webhook) bool { if wh.CertType == linkedca.Webhook_ALL.String() || wh.CertType == "" { return true } return linkedca.Webhook_X509.String() == wh.CertType } // Init initializes and validates the fields of a SCEP type. func (s *SCEP) Init(config Config) (err error) { switch { case s.Type == "": return errors.New("provisioner type cannot be empty") case s.Name == "": return errors.New("provisioner name cannot be empty") } // Default to 2048 bits minimum public key length (for CSRs) if not set if s.MinimumPublicKeyLength == 0 { s.MinimumPublicKeyLength = 2048 } if s.MinimumPublicKeyLength%8 != 0 { return errors.Errorf("%d bits is not exactly divisible by 8", s.MinimumPublicKeyLength) } // Set the encryption algorithm to use s.encryptionAlgorithm = s.EncryptionAlgorithmIdentifier // TODO(hs): we might want to upgrade the default security to AES-CBC? if s.encryptionAlgorithm < 0 || s.encryptionAlgorithm > 4 { return errors.New("only encryption algorithm identifiers from 0 to 4 are valid") } // Prepare the SCEP challenge validator s.challengeValidationController = newChallengeValidationController( config.WebhookClient, config.WrapTransport, s.GetOptions().GetWebhooks(), ) // Prepare the SCEP notification controller s.notificationController = newNotificationController( config.WebhookClient, config.WrapTransport, s.GetOptions().GetWebhooks(), ) // parse the decrypter key PEM contents if available if len(s.DecrypterKeyPEM) > 0 { // try reading the PEM for validation block, rest := pem.Decode(s.DecrypterKeyPEM) if len(rest) > 0 { return errors.New("failed parsing decrypter key: trailing data") } if block == nil { return errors.New("failed parsing decrypter key: no PEM block found") } opts := kms.Options{ Type: kmsapi.SoftKMS, } km, err := kms.New(context.Background(), opts) if err != nil { return fmt.Errorf("failed initializing kms: %w", err) } scepKeyManager, ok := km.(SCEPKeyManager) if !ok { return fmt.Errorf("%q is not a kmsapi.Decrypter", opts.Type) 
} s.keyManager = scepKeyManager if s.decrypter, err = s.keyManager.CreateDecrypter(&kmsapi.CreateDecrypterRequest{ DecryptionKeyPEM: s.DecrypterKeyPEM, Password: []byte(s.DecrypterKeyPassword), PasswordPrompter: kmsapi.NonInteractivePasswordPrompter, }); err != nil { return fmt.Errorf("failed creating decrypter: %w", err) } if s.signer, err = s.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKeyPEM: s.DecrypterKeyPEM, // TODO(hs): support distinct signer key in the future? Password: []byte(s.DecrypterKeyPassword), PasswordPrompter: kmsapi.NonInteractivePasswordPrompter, }); err != nil { return fmt.Errorf("failed creating signer: %w", err) } } if s.DecrypterKeyURI != "" { kmsType, err := kmsapi.TypeOf(s.DecrypterKeyURI) if err != nil { return fmt.Errorf("failed parsing decrypter key: %w", err) } if config.SCEPKeyManager != nil { s.keyManager = config.SCEPKeyManager } else { if kmsType == kmsapi.DefaultKMS { kmsType = kmsapi.SoftKMS } opts := kms.Options{ Type: kmsType, URI: s.DecrypterKeyURI, } km, err := kms.New(context.Background(), opts) if err != nil { return fmt.Errorf("failed initializing kms: %w", err) } scepKeyManager, ok := km.(SCEPKeyManager) if !ok { return fmt.Errorf("%q is not a kmsapi.Decrypter", opts.Type) } s.keyManager = scepKeyManager } // Create decrypter and signer with the same key: // TODO(hs): support distinct signer key in the future? 
if s.decrypter, err = s.keyManager.CreateDecrypter(&kmsapi.CreateDecrypterRequest{ DecryptionKey: s.DecrypterKeyURI, Password: []byte(s.DecrypterKeyPassword), PasswordPrompter: kmsapi.NonInteractivePasswordPrompter, }); err != nil { return fmt.Errorf("failed creating decrypter: %w", err) } if s.signer, err = s.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ SigningKey: s.DecrypterKeyURI, Password: []byte(s.DecrypterKeyPassword), PasswordPrompter: kmsapi.NonInteractivePasswordPrompter, }); err != nil { return fmt.Errorf("failed creating signer: %w", err) } } // parse the decrypter certificate contents if available if len(s.DecrypterCertificate) > 0 { block, rest := pem.Decode(s.DecrypterCertificate) if len(rest) > 0 { return errors.New("failed parsing decrypter certificate: trailing data") } if block == nil { return errors.New("failed parsing decrypter certificate: no PEM block found") } if s.decrypterCertificate, err = x509.ParseCertificate(block.Bytes); err != nil { return fmt.Errorf("failed parsing decrypter certificate: %w", err) } // the decrypter certificate is also the signer certificate s.signerCertificate = s.decrypterCertificate } // TODO(hs): alternatively, check if the KMS keyManager is a CertificateManager // and load the certificate corresponding to the decryption key? // Final validation for the decrypter. if s.decrypter != nil { decrypterPublicKey, ok := s.decrypter.Public().(*rsa.PublicKey) if !ok { return fmt.Errorf("only RSA keys are supported") } if s.decrypterCertificate == nil { return fmt.Errorf("provisioner %q does not have a decrypter certificate set", s.Name) } if !decrypterPublicKey.Equal(s.decrypterCertificate.PublicKey) { return errors.New("mismatch between decrypter certificate and decrypter public keys") } } // TODO: add other, SCEP specific, options? s.ctl, err = NewController(s, s.Claims, config, s.Options) return } // AuthorizeSign does not do any verification, because all verification is handled // in the SCEP protocol. 
This method returns a list of modifiers / constraints // on the resulting certificate. func (s *SCEP) AuthorizeSign(context.Context, string) ([]SignOption, error) { return []SignOption{ s, // modifiers / withOptions newProvisionerExtensionOption(TypeSCEP, s.Name, "").WithControllerOptions(s.ctl), newForceCNOption(s.ForceCN), profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(s.ctl.getPolicy().getX509()), s.ctl.newWebhookController(nil, linkedca.Webhook_X509), }, nil } // GetCapabilities returns the CA capabilities func (s *SCEP) GetCapabilities() []string { return s.Capabilities } // ShouldIncludeRootInChain indicates if the CA should // return its intermediate, which is currently used for // both signing and decryption, as well as the root in // its chain. func (s *SCEP) ShouldIncludeRootInChain() bool { return s.IncludeRoot } // ShouldIncludeIntermediateInChain indicates if the // CA should include the intermediate CA certificate in the // GetCACerts response. This is true by default, but can be // overridden through configuration in case SCEP clients // don't pick the right recipient. func (s *SCEP) ShouldIncludeIntermediateInChain() bool { return !s.ExcludeIntermediate } // GetContentEncryptionAlgorithm returns the numeric identifier // for the pkcs7 package encryption algorithm to use. func (s *SCEP) GetContentEncryptionAlgorithm() int { return s.encryptionAlgorithm } // ValidateChallenge validates the provided challenge. It starts by // selecting the validation method to use, then performs validation // according to that method. 
func (s *SCEP) ValidateChallenge(ctx context.Context, csr *x509.CertificateRequest, challenge, transactionID string) ([]SignCSROption, error) { if s.challengeValidationController == nil { return nil, fmt.Errorf("provisioner %q wasn't initialized", s.Name) } switch s.selectValidationMethod() { case validationMethodWebhook: return s.challengeValidationController.Validate(ctx, csr, s.Name, challenge, transactionID) default: if subtle.ConstantTimeCompare([]byte(s.ChallengePassword), []byte(challenge)) == 0 { return nil, errors.New("invalid challenge password provided") } return []SignCSROption{}, nil } } func (s *SCEP) NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error { if s.notificationController == nil { return fmt.Errorf("provisioner %q wasn't initialized", s.Name) } return s.notificationController.Success(ctx, csr, cert, transactionID) } func (s *SCEP) NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error { if s.notificationController == nil { return fmt.Errorf("provisioner %q wasn't initialized", s.Name) } return s.notificationController.Failure(ctx, csr, transactionID, errorCode, errorDescription) } type validationMethod string const ( validationMethodNone validationMethod = "none" validationMethodStatic validationMethod = "static" validationMethodWebhook validationMethod = "webhook" ) // selectValidationMethod returns the method to validate SCEP // challenges. If a webhook is configured with kind `SCEPCHALLENGE`, // the webhook method will be used. If a challenge password is set, // the static method is used. It will default to the `none` method. 
func (s *SCEP) selectValidationMethod() validationMethod { if len(s.challengeValidationController.webhooks) > 0 { return validationMethodWebhook } if s.ChallengePassword != "" { return validationMethodStatic } return validationMethodNone } // GetDecrypter returns the provisioner specific decrypter, // used to decrypt SCEP request messages sent by a SCEP client. // The decrypter consists of a crypto.Decrypter (a private key) // and a certificate for the public key corresponding to the // private key. func (s *SCEP) GetDecrypter() (*x509.Certificate, crypto.Decrypter) { return s.decrypterCertificate, s.decrypter } // GetSigner returns the provisioner specific signer, used to // sign SCEP response messages for the client. The signer consists // of a crypto.Signer and a certificate for the public key // corresponding to the private key. func (s *SCEP) GetSigner() (*x509.Certificate, crypto.Signer) { return s.signerCertificate, s.signer } ================================================ FILE: authority/provisioner/scep_test.go ================================================ package provisioner import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/json" "encoding/pem" "errors" "net/http" "net/http/httptest" "os" "path/filepath" "testing" "github.com/smallstep/certificates/webhook" "github.com/smallstep/linkedca" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/softkms" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" ) func generateSCEP(t *testing.T) *SCEP { t.Helper() ca, err := minica.New() require.NoError(t, err) key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) cert, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "SCEP decrypter"}, PublicKey: key.Public(), }) require.NoError(t, err) certPEM := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, }) block, err := pemutil.Serialize(key, 
pemutil.WithPassword([]byte("password"))) require.NoError(t, err) keyPEM := pem.EncodeToMemory(block) p := &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, } require.NoError(t, p.Init(Config{Claims: globalProvisionerClaims})) return p } func Test_challengeValidationController_Validate(t *testing.T) { dummyCSR := &x509.CertificateRequest{ Raw: []byte{1}, } type request struct { ProvisionerName string `json:"provisionerName,omitempty"` Request *webhook.X509CertificateRequest `json:"x509CertificateRequest,omitempty"` Challenge string `json:"scepChallenge"` TransactionID string `json:"scepTransactionID"` } type response struct { Allow bool `json:"allow"` Data any `json:"data"` } nokServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req := &request{} err := json.NewDecoder(r.Body).Decode(req) require.NoError(t, err) assert.Equal(t, "my-scep-provisioner", req.ProvisionerName) assert.Equal(t, "not-allowed", req.Challenge) assert.Equal(t, "transaction-1", req.TransactionID) b, err := json.Marshal(response{Allow: false}) require.NoError(t, err) w.WriteHeader(200) w.Write(b) })) okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req := &request{} err := json.NewDecoder(r.Body).Decode(req) require.NoError(t, err) assert.Equal(t, "my-scep-provisioner", req.ProvisionerName) assert.Equal(t, "challenge", req.Challenge) assert.Equal(t, "transaction-1", req.TransactionID) if assert.NotNil(t, req.Request) { assert.Equal(t, []byte{1}, req.Request.Raw) } resp := response{Allow: true} if r.Header.Get("X-Smallstep-Webhook-Id") == "webhook-id-2" { resp.Data = map[string]any{ "ID": "2adcbfec-5e4a-4b93-8913-640e24faf101", "Email": "admin@example.com", } } b, err := json.Marshal(resp) require.NoError(t, err) w.WriteHeader(200) w.Write(b) 
})) t.Cleanup(func() { nokServer.Close() okServer.Close() }) type fields struct { client *http.Client webhooks []*Webhook } type args struct { provisionerName string challenge string transactionID string } tests := []struct { name string fields fields args args want x509util.TemplateData expErr error }{ { name: "fail/no-webhook", fields: fields{http.DefaultClient, nil}, args: args{"my-scep-provisioner", "no-webhook", "transaction-1"}, expErr: errors.New("webhook server did not allow request"), }, { name: "fail/wrong-cert-type", fields: fields{http.DefaultClient, []*Webhook{ { Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_SSH.String(), }, }}, args: args{"my-scep-provisioner", "wrong-cert-type", "transaction-1"}, expErr: errors.New("webhook server did not allow request"), }, { name: "fail/wrong-secret-value", fields: fields{http.DefaultClient, []*Webhook{ { ID: "webhook-id-1", Name: "webhook-name-1", Secret: "{{}}", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, }}, args: args{ provisionerName: "my-scep-provisioner", challenge: "wrong-secret-value", transactionID: "transaction-1", }, expErr: errors.New("failed executing webhook request: illegal base64 data at input byte 0"), }, { name: "fail/not-allowed", fields: fields{http.DefaultClient, []*Webhook{ { ID: "webhook-id-1", Name: "webhook-name-1", Secret: "MTIzNAo=", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: nokServer.URL, }, }}, args: args{ provisionerName: "my-scep-provisioner", challenge: "not-allowed", transactionID: "transaction-1", }, expErr: errors.New("webhook server did not allow request"), }, { name: "ok", fields: fields{http.DefaultClient, []*Webhook{ { ID: "webhook-id-1", Name: "webhook-name-1", Secret: "MTIzNAo=", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, }}, args: args{ provisionerName: 
"my-scep-provisioner", challenge: "challenge", transactionID: "transaction-1", }, want: x509util.TemplateData{ x509util.WebhooksKey: map[string]any{ "webhook-name-1": nil, }, }, }, { name: "ok with data", fields: fields{http.DefaultClient, []*Webhook{ { ID: "webhook-id-2", Name: "webhook-name-2", Secret: "MTIzNAo=", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, }}, args: args{ provisionerName: "my-scep-provisioner", challenge: "challenge", transactionID: "transaction-1", }, want: x509util.TemplateData{ x509util.WebhooksKey: map[string]any{ "webhook-name-2": map[string]any{ "ID": "2adcbfec-5e4a-4b93-8913-640e24faf101", "Email": "admin@example.com", }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := newChallengeValidationController(tt.fields.client, nil, tt.fields.webhooks) ctx := context.Background() got, err := c.Validate(ctx, dummyCSR, tt.args.provisionerName, tt.args.challenge, tt.args.transactionID) if tt.expErr != nil { assert.EqualError(t, err, tt.expErr.Error()) return } assert.NoError(t, err) data := x509util.TemplateData{} for _, o := range got { if m, ok := o.(TemplateDataModifier); ok { m.Modify(data) } else { t.Errorf("Validate() got = %T, want TemplateDataModifier", o) } } assert.Equal(t, tt.want, data) }) } } func TestController_isCertTypeOK(t *testing.T) { assert.True(t, isCertTypeOK(&Webhook{CertType: linkedca.Webhook_X509.String()})) assert.True(t, isCertTypeOK(&Webhook{CertType: linkedca.Webhook_ALL.String()})) assert.True(t, isCertTypeOK(&Webhook{CertType: ""})) assert.False(t, isCertTypeOK(&Webhook{CertType: linkedca.Webhook_SSH.String()})) } func Test_selectValidationMethod(t *testing.T) { tests := []struct { name string p *SCEP want validationMethod }{ {"webhooks", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{ Webhooks: []*Webhook{ { Name: "challenge", URL: "https://scep.challenge", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), }, }, }, }, 
"webhook"}, {"challenge", &SCEP{ Name: "SCEP", Type: "SCEP", ChallengePassword: "pass", }, "static"}, {"challenge-with-different-webhook", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{ Webhooks: []*Webhook{ { Name: "authorizing", URL: "https://scep.authorizing", Kind: linkedca.Webhook_AUTHORIZING.String(), }, }, }, ChallengePassword: "pass", }, "static"}, {"none", &SCEP{ Name: "SCEP", Type: "SCEP", }, "none"}, {"none-with-different-webhook", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{ Webhooks: []*Webhook{ { Name: "authorizing", URL: "https://scep.authorizing", Kind: linkedca.Webhook_AUTHORIZING.String(), }, }, }, }, "none"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.p.Init(Config{Claims: globalProvisionerClaims}) require.NoError(t, err) got := tt.p.selectValidationMethod() assert.Equal(t, tt.want, got) }) } } func TestSCEP_ValidateChallenge(t *testing.T) { dummyCSR := &x509.CertificateRequest{ Raw: []byte{1}, } type request struct { ProvisionerName string `json:"provisionerName,omitempty"` Request *webhook.X509CertificateRequest `json:"x509CertificateRequest,omitempty"` Challenge string `json:"scepChallenge"` TransactionID string `json:"scepTransactionID"` } type response struct { Allow bool `json:"allow"` Data any `json:"data"` } okServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req := &request{} err := json.NewDecoder(r.Body).Decode(req) require.NoError(t, err) assert.Equal(t, "SCEP", req.ProvisionerName) assert.Equal(t, "webhook-challenge", req.Challenge) assert.Equal(t, "webhook-transaction-1", req.TransactionID) if assert.NotNil(t, req.Request) { assert.Equal(t, []byte{1}, req.Request.Raw) } resp := response{Allow: true} if r.Header.Get("X-Smallstep-Webhook-Id") == "webhook-id-2" { resp.Data = map[string]any{ "ID": "2adcbfec-5e4a-4b93-8913-640e24faf101", "Email": "admin@example.com", } } b, err := json.Marshal(resp) require.NoError(t, err) w.WriteHeader(200) 
w.Write(b) })) httpclient := okServer.Client() t.Cleanup(okServer.Close) type args struct { challenge string transactionID string } tests := []struct { name string p *SCEP server *httptest.Server args args want x509util.TemplateData expErr error }{ {"ok/webhooks", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{ Webhooks: []*Webhook{ { ID: "webhook-id-1", Name: "webhook-name-1", Secret: "MTIzNAo=", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, }, }, }, okServer, args{"webhook-challenge", "webhook-transaction-1"}, x509util.TemplateData{ x509util.WebhooksKey: map[string]any{ "webhook-name-1": nil, }, }, nil}, {"ok/with-data", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{ Webhooks: []*Webhook{ { ID: "webhook-id-1", Name: "webhook-name-1", Secret: "MTIzNAo=", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, { ID: "webhook-id-2", Name: "webhook-name-2", Secret: "MTIzNAo=", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, }, }, }, okServer, args{"webhook-challenge", "webhook-transaction-1"}, x509util.TemplateData{ x509util.WebhooksKey: map[string]any{ "webhook-name-1": nil, "webhook-name-2": map[string]any{ "ID": "2adcbfec-5e4a-4b93-8913-640e24faf101", "Email": "admin@example.com", }, }, }, nil}, {"fail/webhooks-secret-configuration", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{ Webhooks: []*Webhook{ { ID: "webhook-id-1", Name: "webhook-name-1", Secret: "{{}}", Kind: linkedca.Webhook_SCEPCHALLENGE.String(), CertType: linkedca.Webhook_X509.String(), URL: okServer.URL, }, }, }, }, nil, args{"webhook-challenge", "webhook-transaction-1"}, nil, errors.New("failed executing webhook request: illegal base64 data at input byte 0")}, {"ok/static-challenge", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{}, ChallengePassword: "secret-static-challenge", }, nil, 
args{"secret-static-challenge", "static-transaction-1"}, x509util.TemplateData{}, nil}, {"fail/wrong-static-challenge", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{}, ChallengePassword: "secret-static-challenge", }, nil, args{"the-wrong-challenge-secret", "static-transaction-1"}, nil, errors.New("invalid challenge password provided")}, {"ok/no-challenge", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{}, ChallengePassword: "", }, nil, args{"", "static-transaction-1"}, x509util.TemplateData{}, nil}, {"fail/no-challenge-but-provided", &SCEP{ Name: "SCEP", Type: "SCEP", Options: &Options{}, ChallengePassword: "", }, nil, args{"a-challenge-value", "static-transaction-1"}, nil, errors.New("invalid challenge password provided")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.p.Init(Config{Claims: globalProvisionerClaims, WebhookClient: httpclient}) require.NoError(t, err) ctx := context.Background() got, err := tt.p.ValidateChallenge(ctx, dummyCSR, tt.args.challenge, tt.args.transactionID) if tt.expErr != nil { assert.EqualError(t, err, tt.expErr.Error()) return } assert.NoError(t, err) data := x509util.TemplateData{} for _, o := range got { if m, ok := o.(TemplateDataModifier); ok { m.Modify(data) } else { t.Errorf("Validate() got = %T, want TemplateDataModifier", o) } } assert.Equal(t, tt.want, data) }) } } func TestSCEP_Init(t *testing.T) { serialize := func(key crypto.PrivateKey, password string) []byte { var opts []pemutil.Options if password == "" { opts = append(opts, pemutil.WithPasswordPrompt("no password", func(s string) ([]byte, error) { return nil, nil })) } else { opts = append(opts, pemutil.WithPassword([]byte("password"))) } block, err := pemutil.Serialize(key, opts...) 
require.NoError(t, err) return pem.EncodeToMemory(block) } ca, err := minica.New() require.NoError(t, err) key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) badKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) cert, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "SCEP decryptor"}, PublicKey: key.Public(), }) require.NoError(t, err) certPEM := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, }) certPEMWithIntermediate := append(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, }), pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: ca.Intermediate.Raw, })...) keyPEM := serialize(key, "password") keyPEMNoPassword := serialize(key, "") badKeyPEM := serialize(badKey, "password") tmp := t.TempDir() path := filepath.Join(tmp, "rsa.priv") pathNoPassword := filepath.Join(tmp, "rsa.key") require.NoError(t, os.WriteFile(path, keyPEM, 0600)) require.NoError(t, os.WriteFile(pathNoPassword, keyPEMNoPassword, 0600)) type args struct { config Config } tests := []struct { name string s *SCEP args args wantErr bool }{ {"ok", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, false}, {"ok no password", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEMNoPassword, DecrypterKeyPassword: "", EncryptionAlgorithmIdentifier: 1, }, args{Config{Claims: globalProvisionerClaims}}, false}, {"ok with uri", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 1024, DecrypterCertificate: certPEM, DecrypterKeyURI: "softkms:path=" + path, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 2, }, args{Config{Claims: 
globalProvisionerClaims}}, false}, {"ok with uri no password", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 2048, DecrypterCertificate: certPEM, DecrypterKeyURI: "softkms:path=" + pathNoPassword, DecrypterKeyPassword: "", EncryptionAlgorithmIdentifier: 3, }, args{Config{Claims: globalProvisionerClaims}}, false}, {"ok with SCEPKeyManager", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 2048, DecrypterCertificate: certPEM, DecrypterKeyURI: "softkms:path=" + pathNoPassword, DecrypterKeyPassword: "", EncryptionAlgorithmIdentifier: 4, }, args{Config{Claims: globalProvisionerClaims, SCEPKeyManager: &softkms.SoftKMS{}}}, false}, {"ok intermediate", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: nil, DecrypterKeyPEM: nil, DecrypterKeyPassword: "", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, false}, {"fail type", &SCEP{ Type: "", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail name", &SCEP{ Type: "SCEP", Name: "", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail minimumPublicKeyLength", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 2001, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail encryptionAlgorithmIdentifier", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: 
"password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 5, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail negative encryptionAlgorithmIdentifier", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: -1, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail key decode", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: []byte("not a pem"), DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail certificate decode", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: []byte("not a pem"), DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail certificate with intermediate", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEMWithIntermediate, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail decrypter password", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "badpassword", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail uri", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyURI: "softkms:path=missing.key", DecrypterKeyPassword: 
"password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail uri password", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyURI: "softkms:path=" + path, DecrypterKeyPassword: "badpassword", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail uri type", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyURI: "foo:path=" + path, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail missing certificate", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: nil, DecrypterKeyPEM: keyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, {"fail key match", &SCEP{ Type: "SCEP", Name: "scep", ChallengePassword: "password123", MinimumPublicKeyLength: 0, DecrypterCertificate: certPEM, DecrypterKeyPEM: badKeyPEM, DecrypterKeyPassword: "password", EncryptionAlgorithmIdentifier: 0, }, args{Config{Claims: globalProvisionerClaims}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.s.Init(tt.args.config); (err != nil) != tt.wantErr { t.Errorf("SCEP.Init() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestSCEP_Getters(t *testing.T) { p := generateSCEP(t) assert.Equal(t, "scep/scep", p.GetID()) assert.Equal(t, "scep", p.GetName()) assert.Equal(t, TypeSCEP, p.GetType()) kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } tokenID, err := p.GetTokenID("token") assert.Empty(t, tokenID) assert.Equal(t, ErrTokenFlowNotSupported, err) } 
================================================
FILE: authority/provisioner/sign_options.go
================================================
package provisioner

import (
	"context"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/subtle"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"net"
	"net/http"
	"net/url"
	"reflect"
	"time"

	"go.step.sm/crypto/keyutil"
	"go.step.sm/crypto/x509util"

	"github.com/smallstep/certificates/authority/policy"
	"github.com/smallstep/certificates/errs"
)

// DefaultCertValidity is the default validity for a certificate if none is specified.
const DefaultCertValidity = 24 * time.Hour

// SignOptions contains the options that can be passed to the Sign method. Backdate
// is automatically filled and can only be configured in the CA.
type SignOptions struct {
	NotAfter     TimeDuration    `json:"notAfter"`
	NotBefore    TimeDuration    `json:"notBefore"`
	TemplateData json.RawMessage `json:"templateData"`
	// Backdate is excluded from the JSON representation.
	Backdate time.Duration `json:"-"`
}

// SignOption is the interface used to collect all extra options used in the
// Sign method. It is an empty interface, so any value can be used as an
// option.
type SignOption interface{}

// CertificateValidator is an interface used to validate a given X.509 certificate.
type CertificateValidator interface {
	Valid(cert *x509.Certificate, opts SignOptions) error
}

// CertificateRequestValidator is an interface used to validate a given X.509 certificate request.
type CertificateRequestValidator interface {
	Valid(cr *x509.CertificateRequest) error
}

// CertificateModifier is an interface used to modify a given X.509 certificate.
// Types implementing this interface will be validated with a
// CertificateValidator.
type CertificateModifier interface {
	Modify(cert *x509.Certificate, opts SignOptions) error
}

// CertificateEnforcer is an interface used to modify a given X.509 certificate.
// Types implementing this interface will NOT be validated with a
// CertificateValidator.
type CertificateEnforcer interface {
	Enforce(cert *x509.Certificate) error
}

// CertificateModifierFunc allows to create simple certificate modifiers just
// with a function.
type CertificateModifierFunc func(cert *x509.Certificate, opts SignOptions) error

// Modify implements CertificateModifier and just calls the defined function.
func (fn CertificateModifierFunc) Modify(cert *x509.Certificate, opts SignOptions) error {
	return fn(cert, opts)
}

// CertificateEnforcerFunc allows to create simple certificate enforcer just
// with a function.
type CertificateEnforcerFunc func(cert *x509.Certificate) error

// Enforce implements CertificateEnforcer and just calls the defined function.
func (fn CertificateEnforcerFunc) Enforce(cert *x509.Certificate) error {
	return fn(cert)
}

// AttestationData is a SignOption used to pass attestation information to the
// sign methods.
type AttestationData struct {
	PermanentIdentifier string
}

// defaultPublicKeyValidator validates the public key of a certificate request.
type defaultPublicKeyValidator struct{}

// Valid checks that the certificate request public key is an RSA, ECDSA or
// Ed25519 key, and that RSA keys are at least keyutil.MinRSAKeyBytes bytes
// long.
func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error { switch k := req.PublicKey.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) } case *ecdsa.PublicKey, ed25519.PublicKey: default: return errs.BadRequest("certificate request key of type '%T' is not supported", k) } return nil } // publicKeyMinimumLengthValidator validates the length (in bits) of the public key // of a certificate request is at least a certain length type publicKeyMinimumLengthValidator struct { length int } // newPublicKeyMinimumLengthValidator creates a new publicKeyMinimumLengthValidator // with the given length as its minimum value // TODO: change the defaultPublicKeyValidator to have a configurable length instead? func newPublicKeyMinimumLengthValidator(length int) publicKeyMinimumLengthValidator { return publicKeyMinimumLengthValidator{ length: length, } } // Valid checks that certificate request common name matches the one configured. func (v publicKeyMinimumLengthValidator) Valid(req *x509.CertificateRequest) error { switch k := req.PublicKey.(type) { case *rsa.PublicKey: minimumLengthInBytes := v.length / 8 if k.Size() < minimumLengthInBytes { return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)", v.length, minimumLengthInBytes) } case *ecdsa.PublicKey, ed25519.PublicKey: default: return errs.BadRequest("certificate request key of type '%T' is not supported", k) } return nil } // commonNameValidator validates the common name of a certificate request. type commonNameValidator string // Valid checks that certificate request common name matches the one configured. // An empty common name is considered valid. 
func (v commonNameValidator) Valid(req *x509.CertificateRequest) error { if req.Subject.CommonName == "" { return nil } if req.Subject.CommonName != string(v) { return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) } return nil } // commonNameSliceValidator validates thats the common name of a certificate // request is present in the slice. An empty common name is considered valid. type commonNameSliceValidator []string func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error { if req.Subject.CommonName == "" { return nil } for _, cn := range v { if req.Subject.CommonName == cn { return nil } } return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) } // dnsNamesValidator validates the DNS names SAN of a certificate request. type dnsNamesValidator []string // Valid checks that certificate request DNS Names match those configured in // the bootstrap (token) flow. func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error { if len(req.DNSNames) == 0 { return nil } want := make(map[string]bool) for _, s := range v { want[s] = true } got := make(map[string]bool) for _, s := range req.DNSNames { got[s] = true } if !reflect.DeepEqual(want, got) { return errs.Forbidden("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) } return nil } // dnsNamesSubsetValidator validates the DNS name SANs of a certificate request. type dnsNamesSubsetValidator []string // Valid checks that all DNS name SANs in the certificate request are present in // the allowed list of DNS names. 
func (v dnsNamesSubsetValidator) Valid(req *x509.CertificateRequest) error { if len(req.DNSNames) == 0 { return nil } allowed := make(map[string]struct{}, len(v)) for _, s := range v { allowed[s] = struct{}{} } for _, s := range req.DNSNames { if _, ok := allowed[s]; !ok { return errs.Forbidden("certificate request contains unauthorized DNS names - got %v, allowed %v", req.DNSNames, v) } } return nil } // ipAddressesValidator validates the IP addresses SAN of a certificate request. type ipAddressesValidator []net.IP // Valid checks that certificate request IP Addresses match those configured in // the bootstrap (token) flow. func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error { if len(req.IPAddresses) == 0 { return nil } want := make(map[string]bool) for _, ip := range v { want[ip.String()] = true } got := make(map[string]bool) for _, ip := range req.IPAddresses { got[ip.String()] = true } if !reflect.DeepEqual(want, got) { return errs.Forbidden("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v) } return nil } // emailAddressesValidator validates the email address SANs of a certificate request. type emailAddressesValidator []string // Valid checks that certificate request IP Addresses match those configured in // the bootstrap (token) flow. func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error { if len(req.EmailAddresses) == 0 { return nil } want := make(map[string]bool) for _, s := range v { want[s] = true } got := make(map[string]bool) for _, s := range req.EmailAddresses { got[s] = true } if !reflect.DeepEqual(want, got) { return errs.Forbidden("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v) } return nil } // urisValidator validates the URI SANs of a certificate request. 
type urisValidator struct { ctx context.Context uris []*url.URL } func newURIsValidator(ctx context.Context, uris []*url.URL) *urisValidator { return &urisValidator{ctx, uris} } // Valid checks that certificate request IP Addresses match those configured in // the bootstrap (token) flow. func (v urisValidator) Valid(req *x509.CertificateRequest) error { // SignIdentityMethod does not need to validate URIs. if MethodFromContext(v.ctx) == SignIdentityMethod { return nil } if len(req.URIs) == 0 { return nil } want := make(map[string]bool) for _, u := range v.uris { want[u.String()] = true } got := make(map[string]bool) for _, u := range req.URIs { got[u.String()] = true } if !reflect.DeepEqual(want, got) { return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v.uris) } return nil } // defaultsSANsValidator stores a set of SANs to eventually validate 1:1 against // the SANs in an x509 certificate request. type defaultSANsValidator struct { ctx context.Context sans []string } func newDefaultSANsValidator(ctx context.Context, sans []string) *defaultSANsValidator { return &defaultSANsValidator{ctx, sans} } // Valid verifies that the SANs stored in the validator match 1:1 with those // requested in the x509 certificate request. func (v defaultSANsValidator) Valid(req *x509.CertificateRequest) (err error) { dnsNames, ips, emails, uris := x509util.SplitSANs(v.sans) if err = dnsNamesValidator(dnsNames).Valid(req); err != nil { return } else if err = emailAddressesValidator(emails).Valid(req); err != nil { return } else if err = ipAddressesValidator(ips).Valid(req); err != nil { return } else if err = newURIsValidator(v.ctx, uris).Valid(req); err != nil { return } return } // profileDefaultDuration is a modifier that sets the certificate // duration. type profileDefaultDuration time.Duration // Modify sets the certificate NotBefore and NotAfter using the following order: // - From the SignOptions that we get from flags. 
// - From x509.Certificate that we get from the template. // - NotBefore from the current time with a backdate. // - NotAfter from NotBefore plus the duration in v. func (v profileDefaultDuration) Modify(cert *x509.Certificate, so SignOptions) error { var backdate time.Duration notBefore := timeOr(so.NotBefore.Time(), cert.NotBefore) if notBefore.IsZero() { notBefore = now() backdate = -1 * so.Backdate } notAfter := timeOr(so.NotAfter.RelativeTime(notBefore), cert.NotAfter) if notAfter.IsZero() { if v != 0 { notAfter = notBefore.Add(time.Duration(v)) } else { notAfter = notBefore.Add(DefaultCertValidity) } } cert.NotBefore = notBefore.Add(backdate) cert.NotAfter = notAfter return nil } // profileLimitDuration is an x509 profile option that modifies an x509 validity // period according to an imposed expiration time. type profileLimitDuration struct { def time.Duration notBefore, notAfter time.Time } // Modify sets the certificate NotBefore and NotAfter but limits the validity // period to the certificate to one that is superficially imposed. // // The expected NotBefore and NotAfter are set using the following order: // - From the SignOptions that we get from flags. // - From x509.Certificate that we get from the template. // - NotBefore from the current time with a backdate. // - NotAfter from NotBefore plus the duration v or the notAfter in v if lower. 
func (v profileLimitDuration) Modify(cert *x509.Certificate, so SignOptions) error { var backdate time.Duration notBefore := timeOr(so.NotBefore.Time(), cert.NotBefore) if notBefore.IsZero() { notBefore = now() backdate = -1 * so.Backdate } if notBefore.Before(v.notBefore) { return errs.Forbidden( "requested certificate notBefore (%s) is before the active validity window of the provisioning credential (%s)", notBefore, v.notBefore) } notAfter := timeOr(so.NotAfter.RelativeTime(notBefore), cert.NotAfter) if notAfter.After(v.notAfter) { return errs.Forbidden( "requested certificate notAfter (%s) is after the expiration of the provisioning credential (%s)", notAfter, v.notAfter) } if notAfter.IsZero() { t := notBefore.Add(v.def) if t.After(v.notAfter) { notAfter = v.notAfter } else { notAfter = t } } cert.NotBefore = notBefore.Add(backdate) cert.NotAfter = notAfter return nil } // validityValidator validates the certificate validity settings. type validityValidator struct { min time.Duration max time.Duration } // newValidityValidator return a new validity validator. func newValidityValidator(minDur, maxDur time.Duration) *validityValidator { return &validityValidator{min: minDur, max: maxDur} } // Valid validates the certificate validity settings (notBefore/notAfter) and // total duration. func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { var ( na = cert.NotAfter.Truncate(time.Second) nb = cert.NotBefore.Truncate(time.Second) now = time.Now().Truncate(time.Second) ) d := na.Sub(nb) if na.Before(now) { return errs.BadRequest("notAfter cannot be in the past; na=%v", na) } if na.Before(nb) { return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) } if d < v.min { return errs.Forbidden("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) } // NOTE: this check is not "technically correct". 
We're allowing the max // duration of a cert to be "max + backdate" and not all certificates will // be backdated (e.g. if a user passes the NotBefore value then we do not // apply a backdate). This is good enough. if d > v.max+o.Backdate { return errs.Forbidden("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) } return nil } // x509NamePolicyValidator validates that the certificate (to be signed) // contains only allowed SANs. type x509NamePolicyValidator struct { policyEngine policy.X509Policy } // newX509NamePolicyValidator return a new SANs allow/deny validator. func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator { return &x509NamePolicyValidator{ policyEngine: engine, } } // Valid validates that the certificate (to be signed) contains only allowed SANs. func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error { if v.policyEngine == nil { return nil } return v.policyEngine.IsX509CertificateAllowed(cert) } type forceCNOption struct { ForceCN bool } func newForceCNOption(forceCN bool) *forceCNOption { return &forceCNOption{forceCN} } func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { if !o.ForceCN { return nil } // Force the common name to be the first DNS if not provided. if cert.Subject.CommonName == "" { if len(cert.DNSNames) == 0 { return errs.BadRequest("cannot force common name, DNS names is empty") } cert.Subject.CommonName = cert.DNSNames[0] } return nil } type provisionerExtensionOption struct { Extension Disabled bool } func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { return &provisionerExtensionOption{ Extension: Extension{ Type: typ, Name: name, CredentialID: credentialID, KeyValuePairs: keyValuePairs, }, } } // WithControllerOptions updates the provisionerExtensionOption with options // from the controller. 
Currently only the DisableSmallstepExtensions // provisioner claim is used. func (o *provisionerExtensionOption) WithControllerOptions(c *Controller) *provisionerExtensionOption { o.Disabled = c.Claimer.IsDisableSmallstepExtensions() return o } func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { if o.Disabled { return nil } ext, err := o.ToExtension() if err != nil { return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } // Replace or append the provisioner extension to avoid the inclusions of // malicious stepOIDProvisioner using templates. for i, e := range cert.ExtraExtensions { if e.Id.Equal(StepOIDProvisioner) { cert.ExtraExtensions[i] = ext return nil } } cert.ExtraExtensions = append(cert.ExtraExtensions, ext) return nil } // csrFingerprintValidator is a CertificateRequestValidator that checks the // fingerprint of the certificate request with the provided one. type csrFingerprintValidator string func (s csrFingerprintValidator) Valid(cr *x509.CertificateRequest) error { if s != "" { expected, err := base64.RawURLEncoding.DecodeString(string(s)) if err != nil { return errs.ForbiddenErr(err, "error decoding fingerprint") } sum := sha256.Sum256(cr.Raw) if subtle.ConstantTimeCompare(expected, sum[:]) != 1 { return errs.Forbidden("certificate request fingerprint does not match %q", s) } } return nil } // SignCSROption is the interface used to collect extra options in the SignCSR // method of the SCEP authority. type SignCSROption any // TemplateDataModifier is an interface that allows to modify template data. type TemplateDataModifier interface { Modify(data x509util.TemplateData) } type templateDataModifier struct { fn func(x509util.TemplateData) } func (t *templateDataModifier) Modify(data x509util.TemplateData) { t.fn(data) } // TemplateDataModifierFunc returns a TemplateDataModifier with the given // function. 
func TemplateDataModifierFunc(fn func(data x509util.TemplateData)) TemplateDataModifier { return &templateDataModifier{ fn: fn, } } ================================================ FILE: authority/provisioner/sign_options_test.go ================================================ package provisioner import ( "context" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "fmt" "net" "net/url" "strings" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "go.step.sm/crypto/pemutil" ) func Test_defaultPublicKeyValidator_Valid(t *testing.T) { _shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr") assert.FatalError(t, err) shortRSA, ok := _shortRSA.(*x509.CertificateRequest) assert.Fatal(t, ok) _rsa, err := pemutil.Read("./testdata/certs/rsa.csr") assert.FatalError(t, err) rsaCSR, ok := _rsa.(*x509.CertificateRequest) assert.Fatal(t, ok) _ecdsa, err := pemutil.Read("./testdata/certs/ecdsa.csr") assert.FatalError(t, err) ecdsaCSR, ok := _ecdsa.(*x509.CertificateRequest) assert.Fatal(t, ok) _ed25519, err := pemutil.Read("./testdata/certs/ed25519.csr") assert.FatalError(t, err) ed25519CSR, ok := _ed25519.(*x509.CertificateRequest) assert.Fatal(t, ok) v := defaultPublicKeyValidator{} tests := []struct { name string csr *x509.CertificateRequest err error }{ { "fail/unrecognized-key-type", &x509.CertificateRequest{PublicKey: "foo"}, errors.New("certificate request key of type 'string' is not supported"), }, { "fail/rsa/too-short", shortRSA, errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"), }, { "ok/rsa", rsaCSR, nil, }, { "ok/ecdsa", ecdsaCSR, nil, }, { "ok/ed25519", ed25519CSR, nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := v.Valid(tt.csr); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) } }) } } func Test_commonNameValidator_Valid(t *testing.T) { type args struct { req *x509.CertificateRequest } tests := 
[]struct { name string v commonNameValidator args args wantErr bool }{ {"ok", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false}, {"empty", "", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, false}, {"wrong", "foo.bar.zar", args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("commonNameValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_commonNameSliceValidator_Valid(t *testing.T) { type args struct { req *x509.CertificateRequest } tests := []struct { name string v commonNameSliceValidator args args wantErr bool }{ {"ok", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false}, {"ok", []string{"example.com", "foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "foo.bar.zar"}}}, false}, {"empty", []string{""}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: ""}}}, false}, {"wrong", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{Subject: pkix.Name{CommonName: "example.com"}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("commonNameSliceValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_emailAddressesValidator_Valid(t *testing.T) { type args struct { req *x509.CertificateRequest } tests := []struct { name string v emailAddressesValidator args args wantErr bool }{ {"ok0", []string{}, args{&x509.CertificateRequest{EmailAddresses: []string{}}}, false}, {"ok1", []string{"max@smallstep.com"}, args{&x509.CertificateRequest{EmailAddresses: []string{"max@smallstep.com"}}}, false}, {"ok2", []string{"max@step.com", "mike@step.com"}, args{&x509.CertificateRequest{EmailAddresses: 
[]string{"max@step.com", "mike@step.com"}}}, false}, {"ok3", []string{"max@step.com", "mike@step.com"}, args{&x509.CertificateRequest{EmailAddresses: []string{"mike@step.com", "max@step.com"}}}, false}, {"ok3", []string{"max@step.com", "mike@step.com"}, args{&x509.CertificateRequest{}}, false}, {"fail1", []string{"max@step.com"}, args{&x509.CertificateRequest{EmailAddresses: []string{"mike@step.com"}}}, true}, {"fail2", []string{"mike@step.com"}, args{&x509.CertificateRequest{EmailAddresses: []string{"max@step.com", "mike@step.com"}}}, true}, {"fail3", []string{"mike@step.com", "max@step.com"}, args{&x509.CertificateRequest{EmailAddresses: []string{"mike@step.com", "mex@step.com"}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("emailAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_dnsNamesValidator_Valid(t *testing.T) { type args struct { req *x509.CertificateRequest } tests := []struct { name string v dnsNamesValidator args args wantErr bool }{ {"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false}, {"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false}, {"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false}, {"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false}, {"ok4", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{}}, false}, {"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true}, {"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true}, {"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true}, } for 
_, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("dnsNamesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_dnsNamesSubsetValidator_Valid(t *testing.T) { type args struct { req *x509.CertificateRequest } tests := []struct { name string v dnsNamesSubsetValidator args args wantErr bool }{ {"ok0", []string{}, args{&x509.CertificateRequest{DNSNames: []string{}}}, false}, {"ok1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, false}, {"ok2", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "bar.zar"}}}, false}, {"ok3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, false}, {"ok4", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{}}, false}, {"ok5", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, false}, {"ok6", []string{"foo", "bar", "baz", "zar", "zap"}, args{&x509.CertificateRequest{DNSNames: []string{"zap", "baz", "foo"}}}, false}, {"fail1", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar"}}}, true}, {"fail2", []string{"foo.bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"bar.zar", "foo.bar.zar"}}}, true}, {"fail3", []string{"foo.bar.zar", "bar.zar"}, args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar", "zar.bar"}}}, true}, {"fail4", []string{"foo", "bar", "baz", "zar", "zap"}, args{&x509.CertificateRequest{DNSNames: []string{"zap", "baz", "foO"}}}, true}, {"fail5", []string{"foo", "bar", "baz", "zar", "zap"}, args{&x509.CertificateRequest{DNSNames: []string{"zap", "baz", "fax", "foo"}}}, true}, {"fail6", []string{}, args{&x509.CertificateRequest{DNSNames: []string{"zap", "baz", "fax", "foo"}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err 
:= tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("dnsNamesSubsetValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_ipAddressesValidator_Valid(t *testing.T) { ip1 := net.IPv4(10, 3, 2, 1) ip2 := net.IPv4(10, 3, 2, 2) ip3 := net.IPv4(10, 3, 2, 3) type args struct { req *x509.CertificateRequest } tests := []struct { name string v ipAddressesValidator args args wantErr bool }{ {"ok0", []net.IP{}, args{&x509.CertificateRequest{IPAddresses: []net.IP{}}}, false}, {"ok1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1}}}, false}, {"ok2", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip2}}}, false}, {"ok3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, false}, {"ok4", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{}}, false}, {"fail1", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2}}}, true}, {"fail2", []net.IP{ip1}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip2, ip1}}}, true}, {"fail3", []net.IP{ip1, ip2}, args{&x509.CertificateRequest{IPAddresses: []net.IP{ip1, ip3}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("ipAddressesValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_urisValidator_Valid(t *testing.T) { u1, err := url.Parse("https://ca.smallstep.com") assert.FatalError(t, err) u2, err := url.Parse("https://google.com/index.html") assert.FatalError(t, err) u3, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") assert.FatalError(t, err) fu, err := url.Parse("https://unexpected.com") assert.FatalError(t, err) signContext := NewContextWithMethod(context.Background(), SignMethod) signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod) type args struct { req *x509.CertificateRequest } tests := []struct { name string v 
*urisValidator args args wantErr bool }{ {"ok0", newURIsValidator(signContext, []*url.URL{}), args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false}, {"ok1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false}, {"ok2", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false}, {"ok3", newURIsValidator(signContext, []*url.URL{u2, u1, u3}), args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false}, {"ok4", newURIsValidator(signIdentityContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, false}, {"fail1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true}, {"fail2", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true}, {"fail3", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.v.Valid(tt.args.req); (err != nil) != tt.wantErr { t.Errorf("urisValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) } }) } } func Test_defaultSANsValidator_Valid(t *testing.T) { type test struct { csr *x509.CertificateRequest ctx context.Context expectedSANs []string err error } signContext := NewContextWithMethod(context.Background(), SignMethod) signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod) tests := map[string]func() test{ "fail/dnsNamesValidator": func() test { return test{ csr: &x509.CertificateRequest{DNSNames: []string{"foo", "bar"}}, ctx: signContext, expectedSANs: []string{"foo"}, err: errors.New("certificate request does not contain the valid DNS names"), } }, "fail/emailAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}}, ctx: signContext, 
expectedSANs: []string{"dcow@fx.com"}, err: errors.New("certificate request does not contain the valid email addresses"), } }, "fail/ipAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}}, ctx: signContext, expectedSANs: []string{"127.0.0.1"}, err: errors.New("certificate request does not contain the valid IP addresses"), } }, "fail/urisValidator": func() test { u1, err := url.Parse("https://google.com") assert.FatalError(t, err) u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") assert.FatalError(t, err) return test{ csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, ctx: signContext, expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, err: errors.New("certificate request does not contain the valid URIs"), } }, "ok/urisBadValidator-SignIdentity": func() test { u1, err := url.Parse("https://google.com") assert.FatalError(t, err) u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") assert.FatalError(t, err) return test{ csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, ctx: signIdentityContext, expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, } }, "ok": func() test { u1, err := url.Parse("https://google.com") assert.FatalError(t, err) u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") assert.FatalError(t, err) return test{ ctx: signContext, csr: &x509.CertificateRequest{ DNSNames: []string{"foo", "bar"}, EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}, IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}, URIs: []*url.URL{u1, u2}, }, expectedSANs: []string{"foo", "127.0.0.1", "max@fx.com", "mariano@fx.com", "https://google.com", "1.1.1.1", "bar", "urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() if err := newDefaultSANsValidator(tt.ctx, 
tt.expectedSANs).Valid(tt.csr); err != nil { if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) { assert.True(t, strings.Contains(err.Error(), tt.err.Error()), fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error())) } } else { assert.Nil(t, tt.err, fmt.Sprintf("expected err = %s, but not ", tt.err)) } }) } } func Test_validityValidator_Valid(t *testing.T) { type test struct { cert *x509.Certificate opts SignOptions vv *validityValidator err error } tests := map[string]func() test{ "fail/notAfter-past": func() test { return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: &x509.Certificate{NotAfter: time.Now().Add(-5 * time.Minute)}, opts: SignOptions{}, err: errors.New("notAfter cannot be in the past"), } }, "fail/notBefore-after-notAfter": func() test { return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: &x509.Certificate{NotBefore: time.Now().Add(10 * time.Minute), NotAfter: time.Now().Add(5 * time.Minute)}, opts: SignOptions{}, err: errors.New("notAfter cannot be before notBefore"), } }, "fail/duration-too-short": func() test { n := now() return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: &x509.Certificate{NotBefore: n, NotAfter: n.Add(3 * time.Minute)}, opts: SignOptions{}, err: errors.New("is less than the authorized minimum certificate duration of "), } }, "ok/duration-exactly-min": func() test { n := now() return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: &x509.Certificate{NotBefore: n, NotAfter: n.Add(5 * time.Minute)}, opts: SignOptions{}, } }, "fail/duration-too-great": func() test { n := now() return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: &x509.Certificate{NotBefore: n, NotAfter: n.Add(24*time.Hour + time.Second)}, err: errors.New("is more than the authorized maximum certificate duration of "), } }, "ok/duration-exactly-max": func() test { n := time.Now() return test{ 
vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: &x509.Certificate{NotBefore: n, NotAfter: n.Add(24 * time.Hour)}, } }, "ok/duration-exact-min-with-backdate": func() test { now := time.Now() cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(5 * time.Minute)} time.Sleep(time.Second) return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: cert, opts: SignOptions{Backdate: time.Second}, } }, "ok/duration-exact-max-with-backdate": func() test { backdate := time.Second now := time.Now() cert := &x509.Certificate{NotBefore: now, NotAfter: now.Add(24*time.Hour + backdate)} time.Sleep(backdate) return test{ vv: &validityValidator{5 * time.Minute, 24 * time.Hour}, cert: cert, opts: SignOptions{Backdate: backdate}, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() if err := tt.vv.Valid(tt.cert, tt.opts); err != nil { if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) { assert.True(t, strings.Contains(err.Error(), tt.err.Error()), fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error())) } } else { assert.Nil(t, tt.err, fmt.Sprintf("expected err = %s, but not ", tt.err)) } }) } } func Test_forceCN_Option(t *testing.T) { type test struct { so SignOptions fcn forceCNOption cert *x509.Certificate valid func(*x509.Certificate) err error } tests := map[string]func() test{ "ok/CN-not-forced": func() test { return test{ fcn: forceCNOption{false}, so: SignOptions{}, cert: &x509.Certificate{ Subject: pkix.Name{}, DNSNames: []string{"acme.example.com", "step.example.com"}, }, valid: func(cert *x509.Certificate) { assert.Equals(t, cert.Subject.CommonName, "") }, } }, "ok/CN-forced-and-set": func() test { return test{ fcn: forceCNOption{true}, so: SignOptions{}, cert: &x509.Certificate{ Subject: pkix.Name{ CommonName: "Some Common Name", }, DNSNames: []string{"acme.example.com", "step.example.com"}, }, valid: func(cert *x509.Certificate) { 
assert.Equals(t, cert.Subject.CommonName, "Some Common Name") }, } }, "ok/CN-forced-and-not-set": func() test { return test{ fcn: forceCNOption{true}, so: SignOptions{}, cert: &x509.Certificate{ Subject: pkix.Name{}, DNSNames: []string{"acme.example.com", "step.example.com"}, }, valid: func(cert *x509.Certificate) { assert.Equals(t, cert.Subject.CommonName, "acme.example.com") }, } }, "fail/CN-forced-and-empty-DNSNames": func() test { return test{ fcn: forceCNOption{true}, so: SignOptions{}, cert: &x509.Certificate{ Subject: pkix.Name{}, DNSNames: []string{}, }, err: errors.New("cannot force common name, DNS names is empty"), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() if err := tt.fcn.Modify(tt.cert, tt.so); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { if assert.Nil(t, tt.err) { tt.valid(tt.cert) } } }) } } func Test_profileDefaultDuration_Option(t *testing.T) { type test struct { so SignOptions pdd profileDefaultDuration cert *x509.Certificate valid func(*x509.Certificate) } tests := map[string]func() test{ "ok/notBefore-notAfter-duration-empty": func() test { return test{ pdd: profileDefaultDuration(0), so: SignOptions{}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { n := now() assert.True(t, n.After(cert.NotBefore.Add(-time.Second))) assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter.Add(-time.Second))) assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter)) }, } }, "ok/notBefore-set": func() test { nb := time.Now().Add(5 * time.Minute).UTC() return test{ pdd: profileDefaultDuration(0), so: SignOptions{NotBefore: NewTimeDuration(nb)}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { assert.Equals(t, cert.NotBefore, nb) assert.Equals(t, cert.NotAfter, nb.Add(24*time.Hour)) }, } }, "ok/duration-set": func() test { d := 4 * time.Hour return 
test{ pdd: profileDefaultDuration(d), so: SignOptions{Backdate: time.Second}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { n := now() assert.True(t, n.After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n, cert.NotBefore)) assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) assert.True(t, n.Add(d).After(cert.NotAfter)) assert.True(t, n.Add(d).Add(-1*time.Minute).Before(cert.NotAfter)) }, } }, "ok/notAfter-set": func() test { na := now().Add(10 * time.Minute).UTC() return test{ pdd: profileDefaultDuration(0), so: SignOptions{NotAfter: NewTimeDuration(na)}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { n := now() assert.True(t, n.Add(3*time.Second).After(cert.NotBefore), fmt.Sprintf("expected now = %s to be after cert.NotBefore = %s", n.Add(3*time.Second), cert.NotBefore)) assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) assert.Equals(t, cert.NotAfter, na) }, } }, "ok/notBefore-and-notAfter-set": func() test { nb := time.Now().Add(5 * time.Minute).UTC() na := time.Now().Add(10 * time.Minute).UTC() d := 4 * time.Hour return test{ pdd: profileDefaultDuration(d), so: SignOptions{NotBefore: NewTimeDuration(nb), NotAfter: NewTimeDuration(na)}, cert: &x509.Certificate{ NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour), }, valid: func(cert *x509.Certificate) { assert.Equals(t, nb, cert.NotBefore) assert.Equals(t, na, cert.NotAfter) }, } }, "ok/cert-with-validity": func() test { nb := time.Now().Add(5 * time.Minute).UTC() na := time.Now().Add(10 * time.Minute).UTC() d := 4 * time.Hour return test{ pdd: profileDefaultDuration(d), so: SignOptions{}, cert: &x509.Certificate{ NotBefore: nb, NotAfter: na, }, valid: func(cert *x509.Certificate) { assert.Equals(t, nb, cert.NotBefore) assert.Equals(t, na, cert.NotAfter) }, } }, "ok/cert-notBefore-option-notafter": func() test { nb := time.Now().Add(5 * time.Minute).UTC() na := time.Now().Add(10 * time.Minute).UTC() d := 4 * 
time.Hour return test{ pdd: profileDefaultDuration(d), so: SignOptions{NotAfter: NewTimeDuration(na)}, cert: &x509.Certificate{ NotBefore: nb, }, valid: func(cert *x509.Certificate) { assert.Equals(t, nb, cert.NotBefore) assert.Equals(t, na, cert.NotAfter) }, } }, "ok/cert-notAfter-option-notBefore": func() test { nb := time.Now().Add(5 * time.Minute).UTC() na := time.Now().Add(10 * time.Minute).UTC() d := 4 * time.Hour return test{ pdd: profileDefaultDuration(d), so: SignOptions{NotBefore: NewTimeDuration(nb)}, cert: &x509.Certificate{ NotAfter: na, }, valid: func(cert *x509.Certificate) { assert.Equals(t, cert.NotBefore, nb) assert.Equals(t, cert.NotAfter, na) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() assert.FatalError(t, tt.pdd.Modify(tt.cert, tt.so), "unexpected error") time.Sleep(100 * time.Millisecond) tt.valid(tt.cert) }) } } func Test_newProvisionerExtension_Option(t *testing.T) { expectedValue, err := asn1.Marshal(extensionASN1{ Type: int(TypeJWK), Name: []byte("name"), CredentialID: []byte("credentialId"), KeyValuePairs: []string{"key", "value"}, }) if err != nil { t.Fatal(err) } // Claims with smallstep extensions disabled. 
claimer, err := NewClaimer(&Claims{ DisableSmallstepExtensions: &trueValue, }, globalProvisionerClaims) if err != nil { t.Fatal(err) } type test struct { modifier *provisionerExtensionOption cert *x509.Certificate valid func(*x509.Certificate) } tests := map[string]func() test{ "ok/one-element": func() test { return test{ modifier: newProvisionerExtensionOption(TypeJWK, "name", "credentialId", "key", "value"), cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { if assert.Len(t, 1, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] assert.Equals(t, StepOIDProvisioner, ext.Id) assert.Equals(t, expectedValue, ext.Value) assert.False(t, ext.Critical) } }, } }, "ok/replace": func() test { return test{ modifier: newProvisionerExtensionOption(TypeJWK, "name", "credentialId", "key", "value"), cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, valid: func(cert *x509.Certificate) { if assert.Len(t, 2, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] assert.Equals(t, StepOIDProvisioner, ext.Id) assert.Equals(t, expectedValue, ext.Value) assert.False(t, ext.Critical) } }, } }, "ok/disabled": func() test { return test{ modifier: newProvisionerExtensionOption(TypeJWK, "name", "credentialId", "key", "value").WithControllerOptions(&Controller{ Claimer: claimer, }), cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { assert.Len(t, 0, cert.ExtraExtensions) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() assert.FatalError(t, tt.modifier.Modify(tt.cert, SignOptions{})) tt.valid(tt.cert) }) } } func Test_profileLimitDuration_Option(t *testing.T) { n, fn := mockNow() defer fn() type test struct { pld profileLimitDuration so SignOptions cert *x509.Certificate valid func(*x509.Certificate) err error } tests := map[string]func() test{ "fail/notBefore-before-active-window": func() test { d, err := ParseTimeDuration("6h") assert.FatalError(t, 
err) return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n.Add(8 * time.Hour)}, so: SignOptions{NotBefore: d}, cert: new(x509.Certificate), err: errors.New("requested certificate notBefore ("), } }, "fail/requested-notAfter-after-limit": func() test { d, err := ParseTimeDuration("4h") assert.FatalError(t, err) return test{ pld: profileLimitDuration{def: 4 * time.Hour, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{NotBefore: NewTimeDuration(n.Add(3 * time.Hour)), NotAfter: d}, cert: new(x509.Certificate), err: errors.New("requested certificate notAfter ("), } }, "fail/cert-validity-notBefore": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{}, cert: &x509.Certificate{ NotBefore: n.Add(-time.Second), NotAfter: n.Add(5 * time.Hour), }, err: errors.New("requested certificate notBefore ("), } }, "fail/cert-validity-notAfter": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{}, cert: &x509.Certificate{ NotBefore: n, NotAfter: n.Add(6*time.Hour + time.Second), }, err: errors.New("requested certificate notAfter ("), } }, "ok/valid-notAfter-requested": func() test { d, err := ParseTimeDuration("2h") assert.FatalError(t, err) return test{ pld: profileLimitDuration{def: 4 * time.Hour, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{NotBefore: NewTimeDuration(n.Add(3 * time.Hour)), NotAfter: d, Backdate: 1 * time.Minute}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { assert.Equals(t, cert.NotBefore, n.Add(3*time.Hour)) assert.Equals(t, cert.NotAfter, n.Add(5*time.Hour)) }, } }, "ok/valid-notAfter-nil-limit-over-default": func() test { return test{ pld: profileLimitDuration{def: 1 * time.Hour, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{NotBefore: NewTimeDuration(n.Add(3 * time.Hour)), Backdate: 1 * time.Minute}, cert: new(x509.Certificate), valid: func(cert 
*x509.Certificate) { assert.Equals(t, cert.NotBefore, n.Add(3*time.Hour)) assert.Equals(t, cert.NotAfter, n.Add(4*time.Hour)) }, } }, "ok/valid-notAfter-nil-limit-under-default": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{NotBefore: NewTimeDuration(n.Add(3 * time.Hour)), Backdate: 1 * time.Minute}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { assert.Equals(t, cert.NotBefore, n.Add(3*time.Hour)) assert.Equals(t, cert.NotAfter, n.Add(6*time.Hour)) }, } }, "ok/over-limit-with-backdate": func() test { return test{ pld: profileLimitDuration{def: 24 * time.Hour, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{Backdate: 1 * time.Minute}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { assert.Equals(t, cert.NotBefore, n.Add(-time.Minute)) assert.Equals(t, cert.NotAfter, n.Add(6*time.Hour)) }, } }, "ok/under-limit-with-backdate": func() test { return test{ pld: profileLimitDuration{def: 24 * time.Hour, notAfter: n.Add(30 * time.Hour)}, so: SignOptions{Backdate: 1 * time.Minute}, cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { assert.Equals(t, cert.NotBefore, n.Add(-time.Minute)) assert.Equals(t, cert.NotAfter, n.Add(24*time.Hour)) }, } }, "ok/cert-validity": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{}, cert: &x509.Certificate{ NotBefore: n, NotAfter: n.Add(5 * time.Hour), }, valid: func(cert *x509.Certificate) { assert.Equals(t, n, cert.NotBefore) assert.Equals(t, n.Add(5*time.Hour), cert.NotAfter) }, } }, "ok/cert-notBefore-default": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{}, cert: &x509.Certificate{ NotBefore: n, }, valid: func(cert *x509.Certificate) { assert.Equals(t, n, cert.NotBefore) assert.Equals(t, n.Add(4*time.Hour), cert.NotAfter) }, } }, 
"ok/cert-notAfter-default": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{}, cert: &x509.Certificate{ NotAfter: n.Add(5 * time.Hour), }, valid: func(cert *x509.Certificate) { assert.Equals(t, n, cert.NotBefore) assert.Equals(t, n.Add(5*time.Hour), cert.NotAfter) }, } }, "ok/cert-notBefore-option": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{NotAfter: NewTimeDuration(n.Add(5 * time.Hour))}, cert: &x509.Certificate{ NotBefore: n, }, valid: func(cert *x509.Certificate) { assert.Equals(t, n, cert.NotBefore) assert.Equals(t, n.Add(5*time.Hour), cert.NotAfter) }, } }, "ok/cert-notAfter-option": func() test { return test{ pld: profileLimitDuration{def: 4 * time.Hour, notBefore: n, notAfter: n.Add(6 * time.Hour)}, so: SignOptions{NotBefore: NewTimeDuration(n.Add(4 * time.Hour))}, cert: &x509.Certificate{ NotAfter: n.Add(5 * time.Hour), }, valid: func(cert *x509.Certificate) { assert.Equals(t, n.Add(4*time.Hour), cert.NotBefore) assert.Equals(t, n.Add(5*time.Hour), cert.NotAfter) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() if err := tt.pld.Modify(tt.cert, tt.so); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { if assert.Nil(t, tt.err) { tt.valid(tt.cert) } } }) } } ================================================ FILE: authority/provisioner/sign_ssh_options.go ================================================ package provisioner import ( "crypto/rsa" "encoding/binary" "encoding/json" "fmt" "math/big" "strings" "time" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "go.step.sm/crypto/keyutil" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" ) const ( // SSHUserCert is the string used to represent 
ssh.UserCert. SSHUserCert = "user" // SSHHostCert is the string used to represent ssh.HostCert. SSHHostCert = "host" ) // SSHCertModifier is the interface used to change properties in an SSH // certificate. type SSHCertModifier interface { SignOption Modify(cert *ssh.Certificate, opts SignSSHOptions) error } // SSHCertValidator is the interface used to validate an SSH certificate. type SSHCertValidator interface { SignOption Valid(cert *ssh.Certificate, opts SignSSHOptions) error } // SSHCertOptionsValidator is the interface used to validate the custom // options used to modify the SSH certificate. type SSHCertOptionsValidator interface { SignOption Valid(got SignSSHOptions) error } // SSHPublicKeyValidator is the interface used to validate the public key of an // SSH certificate. type SSHPublicKeyValidator interface { SignOption Valid(got ssh.PublicKey) error } // SignSSHOptions contains the options that can be passed to the SignSSH method. type SignSSHOptions struct { CertType string `json:"certType"` KeyID string `json:"keyID"` Principals []string `json:"principals"` ValidAfter TimeDuration `json:"validAfter,omitempty"` ValidBefore TimeDuration `json:"validBefore,omitempty"` TemplateData json.RawMessage `json:"templateData,omitempty"` Backdate time.Duration `json:"-"` } // Validate validates the given SignSSHOptions. func (o SignSSHOptions) Validate() error { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { return errs.BadRequest("certType '%s' is not valid", o.CertType) } for _, p := range o.Principals { if p == "" { return errs.BadRequest("principals cannot contain empty values") } } return nil } // Type returns the uint32 representation of the CertType. func (o SignSSHOptions) Type() uint32 { return sshCertTypeUInt32(o.CertType) } // Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate. 
func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { switch o.CertType { case "": // ignore case SSHUserCert: cert.CertType = ssh.UserCert case SSHHostCert: cert.CertType = ssh.HostCert default: return errs.BadRequest("ssh certificate has an unknown type '%s'", o.CertType) } cert.KeyId = o.KeyID cert.ValidPrincipals = o.Principals return o.ModifyValidity(cert) } // ModifyValidity modifies only the ValidAfter and ValidBefore on the given // ssh.Certificate. func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error { t := now() if !o.ValidAfter.IsZero() { cert.ValidAfter = cast.Uint64(o.ValidAfter.RelativeTime(t).Unix()) } if !o.ValidBefore.IsZero() { cert.ValidBefore = cast.Uint64(o.ValidBefore.RelativeTime(t).Unix()) } if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { return errs.BadRequest("ssh certificate validAfter cannot be greater than validBefore") } return nil } // match compares two SSHOptions and return an error if they don't match. It // ignores zero values. 
func (o SignSSHOptions) match(got SignSSHOptions) error { if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { return errs.Forbidden("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) } if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) { return errs.Forbidden("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) } if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { return errs.Forbidden("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) } if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { return errs.Forbidden("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) } return nil } // sshCertValidAfterModifier is an SSHCertModifier that sets the // ValidAfter in the SSH certificate. type sshCertValidAfterModifier uint64 func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.ValidAfter = uint64(m) return nil } // sshCertValidBeforeModifier is an SSHCertModifier that sets the // ValidBefore in the SSH certificate. type sshCertValidBeforeModifier uint64 func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.ValidBefore = uint64(m) return nil } // sshDefaultDuration is an SSHCertModifier that sets the certificate // ValidAfter and ValidBefore if they have not been set. It will fail if a // CertType has not been set or is not valid. type sshDefaultDuration struct { *Claimer } // Modify implements SSHCertModifier and sets the validity if it has not been // set, but it always applies the backdate. 
func (m *sshDefaultDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error { d, err := m.DefaultSSHCertDuration(cert.CertType) if err != nil { return err } var backdate uint64 if cert.ValidAfter == 0 { backdate = cast.Uint64(o.Backdate / time.Second) cert.ValidAfter = cast.Uint64(now().Truncate(time.Second).Unix()) } if cert.ValidBefore == 0 { cert.ValidBefore = cert.ValidAfter + cast.Uint64(d/time.Second) } // Apply backdate safely if cert.ValidAfter > backdate { cert.ValidAfter -= backdate } return nil } // sshLimitDuration adjusts the duration to min(default, remaining provisioning // credential duration). E.g. if the default is 12hrs but the remaining validity // of the provisioning credential is only 4hrs, this option will set the value // to 4hrs (the min of the two values). It will fail if a CertType has not been // set or is not valid. type sshLimitDuration struct { *Claimer NotAfter time.Time } // Modify implements SSHCertModifier and modifies the validity of the // certificate to expire before the configured limit. func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error { if m.NotAfter.IsZero() { defaultDuration := &sshDefaultDuration{m.Claimer} return defaultDuration.Modify(cert, o) } // Make sure the duration is within the limits. 
d, err := m.DefaultSSHCertDuration(cert.CertType) if err != nil { return err } var backdate uint64 if cert.ValidAfter == 0 { backdate = cast.Uint64(o.Backdate / time.Second) cert.ValidAfter = cast.Uint64(now().Truncate(time.Second).Unix()) } certValidAfter := time.Unix(cast.Int64(cert.ValidAfter), 0) if certValidAfter.After(m.NotAfter) { return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", m.NotAfter, certValidAfter) } if cert.ValidBefore == 0 { certValidBefore := certValidAfter.Add(d) if m.NotAfter.Before(certValidBefore) { certValidBefore = m.NotAfter } cert.ValidBefore = cast.Uint64(certValidBefore.Unix()) } else { certValidBefore := time.Unix(cast.Int64(cert.ValidBefore), 0) if m.NotAfter.Before(certValidBefore) { return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", m.NotAfter, certValidBefore) } } // Apply backdate safely if cert.ValidAfter > backdate { cert.ValidAfter -= backdate } return nil } // sshCertOptionsValidator validates the user SSHOptions with the ones // usually present in the token. type sshCertOptionsValidator SignSSHOptions // Valid implements SSHCertOptionsValidator and returns nil if both // SSHOptions match. func (v sshCertOptionsValidator) Valid(got SignSSHOptions) error { want := SignSSHOptions(v) return want.match(got) } // sshCertOptionsRequireValidator defines which elements in the SignSSHOptions are required. 
type sshCertOptionsRequireValidator struct { CertType bool KeyID bool Principals bool } func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error { switch { case v.CertType && got.CertType == "": return errs.BadRequest("ssh certificate certType cannot be empty") case v.KeyID && got.KeyID == "": return errs.BadRequest("ssh certificate keyID cannot be empty") case v.Principals && len(got.Principals) == 0: return errs.BadRequest("ssh certificate principals cannot be empty") default: return nil } } type sshCertValidityValidator struct { *Claimer } func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { switch { case cert.ValidAfter == 0: return errs.BadRequest("ssh certificate validAfter cannot be 0") case cert.ValidBefore < cast.Uint64(now().Unix()): return errs.BadRequest("ssh certificate validBefore cannot be in the past") case cert.ValidBefore < cert.ValidAfter: return errs.BadRequest("ssh certificate validBefore cannot be before validAfter") } var minDur, maxDur time.Duration switch cert.CertType { case ssh.UserCert: minDur = v.MinUserSSHCertDuration() maxDur = v.MaxUserSSHCertDuration() case ssh.HostCert: minDur = v.MinHostSSHCertDuration() maxDur = v.MaxHostSSHCertDuration() case 0: return errs.BadRequest("ssh certificate type has not been set") default: return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType) } // To not take into account the backdate, time.Now() will be used to // calculate the duration if ValidAfter is in the past. 
dur := time.Duration(cast.Int64(cert.ValidBefore-cert.ValidAfter)) * time.Second switch { case dur < minDur: return errs.Forbidden("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, minDur) case dur > maxDur+opts.Backdate: return errs.Forbidden("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, maxDur+opts.Backdate) default: return nil } } // sshCertDefaultValidator implements a simple validator for all the // fields in the SSH certificate. type sshCertDefaultValidator struct{} // Valid returns an error if the given certificate does not contain the // necessary fields. We skip ValidPrincipals and Extensions as with custom // templates you can set them empty. func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { switch { case len(cert.Nonce) == 0: return errs.Forbidden("ssh certificate nonce cannot be empty") case cert.Key == nil: return errs.Forbidden("ssh certificate key cannot be nil") case cert.Serial == 0: return errs.Forbidden("ssh certificate serial cannot be 0") case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: return errs.Forbidden("ssh certificate has an unknown type '%d'", cert.CertType) case cert.KeyId == "": return errs.Forbidden("ssh certificate key id cannot be empty") case cert.ValidAfter == 0: return errs.Forbidden("ssh certificate validAfter cannot be 0") case cert.ValidBefore < cast.Uint64(now().Unix()): return errs.Forbidden("ssh certificate validBefore cannot be in the past") case cert.ValidBefore < cert.ValidAfter: return errs.Forbidden("ssh certificate validBefore cannot be before validAfter") case cert.SignatureKey == nil: return errs.Forbidden("ssh certificate signature key cannot be nil") case cert.Signature == nil: return errs.Forbidden("ssh certificate signature cannot be nil") default: return nil } } // sshDefaultPublicKeyValidator implements a validator for the certificate key. 
type sshDefaultPublicKeyValidator struct{} // Valid checks that certificate request common name matches the one configured. // // TODO: this is the only validator that checks the key type. We should execute // this before the signing. We should add a new validations interface or extend // SSHCertOptionsValidator with the key. func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { if cert.Key == nil { return errs.BadRequest("ssh certificate key cannot be nil") } switch cert.Key.Type() { case ssh.KeyAlgoRSA: _, in, ok := sshParseString(cert.Key.Marshal()) if !ok { return errs.BadRequest("ssh certificate key is invalid") } key, err := sshParseRSAPublicKey(in) if err != nil { return errs.BadRequestErr(err, "error parsing public key") } if key.Size() < keyutil.MinRSAKeyBytes { return errs.Forbidden("ssh certificate key must be at least %d bits (%d bytes)", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) } return nil case ssh.InsecureKeyAlgoDSA: //nolint:staticcheck // only using the constant for lookup; no dependent logic return errs.BadRequest("ssh certificate key algorithm (DSA) is not supported") default: return nil } } // sshNamePolicyValidator validates that the certificate (to be signed) // contains only allowed principals. type sshNamePolicyValidator struct { hostPolicyEngine policy.HostPolicy userPolicyEngine policy.UserPolicy } // newSSHNamePolicyValidator return a new SSH allow/deny validator. func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator { return &sshNamePolicyValidator{ hostPolicyEngine: host, userPolicyEngine: user, } } // Valid validates that the certificate (to be signed) contains only allowed principals. 
func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { if v.hostPolicyEngine == nil && v.userPolicyEngine == nil { // no policy configured at all; allow anything return nil } // Check the policy type to execute based on type of the certificate. // We don't allow user certs if only a host policy engine is configured and // the same for host certs: if only a user policy engine is configured, host // certs are denied. When both policy engines are configured, the type of // cert determines which policy engine is used. switch cert.CertType { case ssh.HostCert: // when no host policy engine is configured, but a user policy engine is // configured, the host certificate is denied. if v.hostPolicyEngine == nil && v.userPolicyEngine != nil { return errors.New("SSH host certificate not authorized") } return v.hostPolicyEngine.IsSSHCertificateAllowed(cert) case ssh.UserCert: // when no user policy engine is configured, but a host policy engine is // configured, the user certificate is denied. if v.userPolicyEngine == nil && v.hostPolicyEngine != nil { return errors.New("SSH user certificate not authorized") } return v.userPolicyEngine.IsSSHCertificateAllowed(cert) default: return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen } } // sshCertTypeUInt32 func sshCertTypeUInt32(ct string) uint32 { switch ct { case SSHUserCert: return ssh.UserCert case SSHHostCert: return ssh.HostCert default: return 0 } } // containsAllMembers reports whether all members of subgroup are within group. 
func containsAllMembers(group, subgroup []string) bool { lg, lsg := len(group), len(subgroup) if lsg > lg || (lg > 0 && lsg == 0) { return false } visit := make(map[string]struct{}, lg) for i := 0; i < lg; i++ { visit[strings.ToLower(group[i])] = struct{}{} } for i := 0; i < lsg; i++ { if _, ok := visit[strings.ToLower(subgroup[i])]; !ok { return false } } return true } func sshParseString(in []byte) (out, rest []byte, ok bool) { if len(in) < 4 { return } length := binary.BigEndian.Uint32(in) in = in[4:] if cast.Uint32(len(in)) < length { return } out = in[:length] rest = in[length:] ok = true return } func sshParseRSAPublicKey(in []byte) (*rsa.PublicKey, error) { var w struct { E *big.Int N *big.Int Rest []byte `ssh:"rest"` } if err := ssh.Unmarshal(in, &w); err != nil { return nil, errors.Wrap(err, "error unmarshalling public key") } if w.E.BitLen() > 24 { return nil, errors.New("invalid public key: exponent too large") } e := w.E.Int64() if e < 3 || e&1 == 0 { return nil, errors.New("invalid public key: incorrect exponent") } var key rsa.PublicKey key.E = int(e) key.N = w.N return &key, nil } ================================================ FILE: authority/provisioner/sign_ssh_options_test.go ================================================ package provisioner import ( "reflect" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" ) func TestSSHOptions_Type(t *testing.T) { type fields struct { CertType string } tests := []struct { name string fields fields want uint32 }{ {"user", fields{"user"}, 1}, {"host", fields{"host"}, 2}, {"empty", fields{""}, 0}, {"invalid", fields{"invalid"}, 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := SignSSHOptions{ CertType: tt.fields.CertType, } if got := o.Type(); got != tt.want { t.Errorf("SSHOptions.Type() = %v, want %v", got, tt.want) } }) } } func TestSSHOptions_Modify(t *testing.T) { type test struct { so SignSSHOptions cert 
*ssh.Certificate valid func(*ssh.Certificate) err error } tests := map[string]func() test{ "fail/unexpected-cert-type": func() test { return test{ so: SignSSHOptions{CertType: "foo"}, cert: new(ssh.Certificate), err: errors.Errorf("ssh certificate has an unknown type 'foo'"), } }, "fail/validAfter-greater-validBefore": func() test { return test{ so: SignSSHOptions{CertType: "user"}, cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)}, err: errors.Errorf("ssh certificate validAfter cannot be greater than validBefore"), } }, "ok/user-cert": func() test { return test{ so: SignSSHOptions{CertType: "user"}, cert: new(ssh.Certificate), valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.CertType, uint32(ssh.UserCert)) }, } }, "ok/host-cert": func() test { return test{ so: SignSSHOptions{CertType: "host"}, cert: new(ssh.Certificate), valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) }, } }, "ok": func() test { va := time.Now().Add(5 * time.Minute) vb := time.Now().Add(1 * time.Hour) so := SignSSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"}, ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)} return test{ so: so, cert: new(ssh.Certificate), valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.CertType, uint32(ssh.HostCert)) assert.Equals(t, cert.KeyId, so.KeyID) assert.Equals(t, cert.ValidPrincipals, so.Principals) assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix())) assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix())) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run() if err := tc.so.Modify(tc.cert, tc.so); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { tc.valid(tc.cert) } } }) } } func TestSSHOptions_Match(t *testing.T) { type test struct { so SignSSHOptions cmp 
SignSSHOptions err error } tests := map[string]func() test{ "fail/cert-type": func() test { return test{ so: SignSSHOptions{CertType: "foo"}, cmp: SignSSHOptions{CertType: "bar"}, err: errors.Errorf("ssh certificate type does not match - got bar, want foo"), } }, "fail/pricipals": func() test { return test{ so: SignSSHOptions{Principals: []string{"foo"}}, cmp: SignSSHOptions{Principals: []string{"bar"}}, err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"), } }, "fail/validAfter": func() test { return test{ so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))}, cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))}, err: errors.Errorf("ssh certificate validAfter does not match"), } }, "fail/validBefore": func() test { return test{ so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))}, cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))}, err: errors.Errorf("ssh certificate validBefore does not match"), } }, "ok/original-empty": func() test { return test{ so: SignSSHOptions{}, cmp: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)), ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)), }, } }, "ok/cmp-empty": func() test { return test{ cmp: SignSSHOptions{}, so: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute)), ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)), }, } }, "ok/equal": func() test { n := time.Now() va := NewTimeDuration(n.Add(1 * time.Minute)) vb := NewTimeDuration(n.Add(5 * time.Minute)) return test{ cmp: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: va, ValidBefore: vb, }, so: SignSSHOptions{ CertType: "foo", Principals: []string{"foo"}, ValidAfter: va, ValidBefore: vb, }, } }, } for name, run := range tests { 
t.Run(name, func(t *testing.T) { tc := run() if err := tc.so.match(tc.cmp); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) } }) } } func Test_sshCertValidAfterModifier_Modify(t *testing.T) { type test struct { modifier sshCertValidAfterModifier cert *ssh.Certificate expected uint64 } tests := map[string]func() test{ "ok": func() test { return test{ modifier: sshCertValidAfterModifier(15), cert: new(ssh.Certificate), expected: 15, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run() if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) { assert.Equals(t, tc.cert.ValidAfter, tc.expected) } }) } } func Test_sshCertDefaultValidator_Valid(t *testing.T) { pub, _, err := keyutil.GenerateDefaultKeyPair() assert.FatalError(t, err) sshPub, err := ssh.NewPublicKey(pub) assert.FatalError(t, err) v := sshCertDefaultValidator{} tests := []struct { name string cert *ssh.Certificate err error }{ { "fail/zero-nonce", &ssh.Certificate{}, errors.New("ssh certificate nonce cannot be empty"), }, { "fail/nil-key", &ssh.Certificate{Nonce: []byte("foo")}, errors.New("ssh certificate key cannot be nil"), }, { "fail/zero-serial", &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub}, errors.New("ssh certificate serial cannot be 0"), }, { "fail/unexpected-cert-type", // UserCert = 1, HostCert = 2 &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1}, errors.New("ssh certificate has an unknown type '3'"), }, { "fail/empty-cert-key-id", &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1}, errors.New("ssh certificate key id cannot be empty"), }, { "fail/zero-validAfter", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: 0, }, errors.New("ssh certificate validAfter cannot be 0"), }, { "fail/validBefore-past", &ssh.Certificate{ Nonce: []byte("foo"), 
Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Add(-10 * time.Minute).Unix()), ValidBefore: uint64(time.Now().Add(-5 * time.Minute).Unix()), }, errors.New("ssh certificate validBefore cannot be in the past"), }, { "fail/validAfter-after-validBefore", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Add(15 * time.Minute).Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), }, errors.New("ssh certificate validBefore cannot be before validAfter"), }, { "fail/nil-signature-key", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), Permissions: ssh.Permissions{ Extensions: map[string]string{"foo": "bar"}, }, }, errors.New("ssh certificate signature key cannot be nil"), }, { "fail/nil-signature", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), Permissions: ssh.Permissions{ Extensions: map[string]string{"foo": "bar"}, }, SignatureKey: sshPub, }, errors.New("ssh certificate signature cannot be nil"), }, { "ok/userCert", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), Permissions: ssh.Permissions{ Extensions: map[string]string{"foo": "bar"}, }, SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, { "ok/hostCert", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 2, KeyId: "foo", ValidPrincipals: []string{"foo"}, ValidAfter: 
uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, { "ok/emptyPrincipals", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, { "ok/empty-extensions", &ssh.Certificate{ Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1, KeyId: "foo", ValidPrincipals: []string{}, ValidAfter: uint64(time.Now().Unix()), ValidBefore: uint64(time.Now().Add(10 * time.Minute).Unix()), SignatureKey: sshPub, Signature: &ssh.Signature{}, }, nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := v.Valid(tt.cert, SignSSHOptions{}); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) } }) } } func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) v := sshCertValidityValidator{p.ctl.Claimer} n := now() tests := []struct { name string cert *ssh.Certificate opts SignSSHOptions err error }{ { "fail/validAfter-0", &ssh.Certificate{CertType: ssh.UserCert}, SignSSHOptions{}, errors.New("ssh certificate validAfter cannot be 0"), }, { "fail/validBefore-in-past", &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())}, SignSSHOptions{}, errors.New("ssh certificate validBefore cannot be in the past"), }, { "fail/validBefore-before-validAfter", &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())}, SignSSHOptions{}, errors.New("ssh certificate validBefore cannot be before validAfter"), }, { "fail/cert-type-not-set", &ssh.Certificate{ValidAfter: 
uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())}, SignSSHOptions{}, errors.New("ssh certificate type has not been set"), }, { "fail/unexpected-cert-type", &ssh.Certificate{ CertType: 3, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix()), }, SignSSHOptions{}, errors.New("ssh certificate has an unknown type '3'"), }, { "fail/durationmax", &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(48 * time.Hour).Unix()), }, SignSSHOptions{Backdate: time.Second}, errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"), }, { "ok/duration-exactly-max", &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()), }, SignSSHOptions{Backdate: time.Second}, nil, }, { "ok", &ssh.Certificate{ CertType: 1, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(8 * time.Hour).Unix()), }, SignSSHOptions{Backdate: time.Second}, nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := v.Valid(tt.cert, tt.opts); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) } }) } } func Test_sshValidityModifier(t *testing.T) { n, fn := mockNow() defer fn() p, err := generateX5C(nil) assert.FatalError(t, err) type test struct { svm *sshLimitDuration cert *ssh.Certificate valid func(*ssh.Certificate) err error } tests := map[string]func() test{ "fail/type-not-set": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), }, err: errors.New("ssh certificate type has not been set"), } }, "fail/type-not-recognized": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * 
time.Hour)}, cert: &ssh.Certificate{ CertType: 4, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), }, err: errors.New("ssh certificate has an unknown type: 4"), } }, "fail/requested-validAfter-after-limit": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), }, err: errors.Errorf("provisioning credential expiration ("), } }, "fail/requested-validBefore-after-limit": func() test { return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(2 * time.Hour).Unix()), }, err: errors.New("provisioning credential expiration ("), } }, "ok/no-limit": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, vb) }, } }, "ok/defaults": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, vb) }, } }, "ok/valid-requested-validBefore": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, ValidBefore: vb, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, vb) }, } }, "ok/empty-requested-validBefore-limit-after-default": func() test { va := uint64(n.Unix()) 
return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, uint64(n.Add(16*time.Hour).Unix())) }, } }, "ok/empty-requested-validBefore-limit-before-default": func() test { va := uint64(n.Unix()) return test{ svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, }, valid: func(cert *ssh.Certificate) { assert.Equals(t, cert.ValidAfter, va) assert.Equals(t, cert.ValidBefore, uint64(n.Add(3*time.Hour).Unix())) }, } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() if err := tt.svm.Modify(tt.cert, SignSSHOptions{}); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { if assert.Nil(t, tt.err) { tt.valid(tt.cert) } } }) } } func Test_sshDefaultDuration_Option(t *testing.T) { tm, fn := mockNow() defer fn() newClaimer := func(claims *Claims) *Claimer { c, err := NewClaimer(claims, globalProvisionerClaims) if err != nil { t.Fatal(err) } return c } unix := func(d time.Duration) uint64 { return uint64(tm.Add(d).Unix()) } type fields struct { Claimer *Claimer } type args struct { o SignSSHOptions cert *ssh.Certificate } tests := []struct { name string fields fields args args want *ssh.Certificate wantErr bool }{ {"user", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(16 * time.Hour)}, false}, {"host", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(30 * 24 * time.Hour)}, false}, {"user claim", fields{newClaimer(&Claims{DefaultUserSSHDur: &Duration{1 * time.Hour}})}, 
args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false}, {"host claim", fields{newClaimer(&Claims{DefaultHostSSHDur: &Duration{1 * time.Hour}})}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false}, {"user backdate", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(16 * time.Hour)}, false}, {"host backdate", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false}, {"user validAfter", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false}, {"user validBefore", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false}, {"host validAfter validBefore", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}, false}, {"fail zero", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{}}, &ssh.Certificate{}, true}, {"fail type", 
fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: 3}}, &ssh.Certificate{CertType: 3}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := &sshDefaultDuration{ Claimer: tt.fields.Claimer, } if err := m.Modify(tt.args.cert, tt.args.o); (err != nil) != tt.wantErr { t.Errorf("sshDefaultDuration.Option() error = %v, wantErr %v", err, tt.wantErr) } if !reflect.DeepEqual(tt.args.cert, tt.want) { t.Errorf("sshDefaultDuration.Option() = %v, want %v", tt.args.cert, tt.want) } }) } } ================================================ FILE: authority/provisioner/ssh_options.go ================================================ package provisioner import ( "encoding/json" "strings" "github.com/pkg/errors" "github.com/smallstep/cli-utils/step" "go.step.sm/crypto/sshutil" "github.com/smallstep/certificates/authority/policy" ) // SSHCertificateOptions is an interface that returns a list of options passed when // creating a new certificate. type SSHCertificateOptions interface { Options(SignSSHOptions) []sshutil.Option } type sshCertificateOptionsFunc func(SignSSHOptions) []sshutil.Option func (fn sshCertificateOptionsFunc) Options(so SignSSHOptions) []sshutil.Option { return fn(so) } // SSHOptions are a collection of custom options that can be added to each // provisioner. type SSHOptions struct { // Template contains an SSH certificate template. It can be a JSON template // escaped in a string or it can be also encoded in base64. Template string `json:"template,omitempty"` // TemplateFile points to a file containing a SSH certificate template. TemplateFile string `json:"templateFile,omitempty"` // TemplateData is a JSON object with variables that can be used in custom // templates. TemplateData json.RawMessage `json:"templateData,omitempty"` // User contains SSH user certificate options. User *policy.SSHUserCertificateOptions `json:"-"` // Host contains SSH host certificate options. 
// TemplateSSHOptions generates a SSHCertificateOptions with the template and
// data defined in the ProvisionerOptions, the provisioner generated data, and
// the user data provided in the request. If no template has been provided,
// sshutil.DefaultTemplate will be used.
//
// It is a convenience wrapper that calls CustomSSHTemplateOptions with
// sshutil.DefaultTemplate as the fallback template.
func TemplateSSHOptions(o *Options, data sshutil.TemplateData) (SSHCertificateOptions, error) {
	return CustomSSHTemplateOptions(o, data, sshutil.DefaultTemplate)
}
If no template has been provided in the // ProvisionerOptions, the given template will be used. func CustomSSHTemplateOptions(o *Options, data sshutil.TemplateData, defaultTemplate string) (SSHCertificateOptions, error) { opts := o.GetSSHOptions() if data == nil { data = sshutil.NewTemplateData() } if opts != nil { // Add template data if any. if len(opts.TemplateData) > 0 && string(opts.TemplateData) != "null" { if err := json.Unmarshal(opts.TemplateData, &data); err != nil { return nil, errors.Wrap(err, "error unmarshaling template data") } } } return sshCertificateOptionsFunc(func(so SignSSHOptions) []sshutil.Option { // We're not provided user data without custom templates. if !opts.HasTemplate() { return []sshutil.Option{ sshutil.WithTemplate(defaultTemplate, data), } } // Add user provided data. if len(so.TemplateData) > 0 { userObject := make(map[string]interface{}) if err := json.Unmarshal(so.TemplateData, &userObject); err != nil { data.SetUserData(map[string]interface{}{}) } else { data.SetUserData(userObject) } } // Load a template from a file if Template is not defined. if opts.Template == "" && opts.TemplateFile != "" { return []sshutil.Option{ sshutil.WithTemplateFile(step.Abs(opts.TemplateFile), data), } } // Load a template from the Template fields // 1. As a JSON in a string. template := strings.TrimSpace(opts.Template) if strings.HasPrefix(template, "{") { return []sshutil.Option{ sshutil.WithTemplate(template, data), } } // 2. As a base64 encoded JSON. 
return []sshutil.Option{ sshutil.WithTemplateBase64(template, data), } }), nil } ================================================ FILE: authority/provisioner/ssh_options_test.go ================================================ package provisioner import ( "bytes" "reflect" "testing" "go.step.sm/crypto/sshutil" ) func TestCustomSSHTemplateOptions(t *testing.T) { cr := sshutil.CertificateRequest{ Type: "user", KeyID: "foo@smallstep.com", Principals: []string{"foo"}, } crCertificate := `{"Key":null,"Type":"user","KeyID":"foo@smallstep.com","Principals":["foo"]}` data := sshutil.CreateTemplateData(sshutil.HostCert, "smallstep.com", []string{"smallstep.com"}) type args struct { o *Options data sshutil.TemplateData defaultTemplate string userOptions SignSSHOptions } tests := []struct { name string args args want sshutil.Options wantErr bool }{ {"ok", args{nil, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{ "type": "host", "keyId": "smallstep.com", "principals": ["smallstep.com"], "extensions": null, "criticalOptions": null }`), }, false}, {"okNoData", args{nil, nil, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{ "type": null, "keyId": null, "principals": null, "extensions": null, "criticalOptions": null }`), }, false}, {"okTemplateData", args{&Options{SSH: &SSHOptions{TemplateData: []byte(`{"foo":"bar"}`)}}, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{ "type": "host", "keyId": "smallstep.com", "principals": ["smallstep.com"], "extensions": null, "criticalOptions": null }`), }, false}, {"okNullTemplateData", args{&Options{SSH: &SSHOptions{TemplateData: []byte(`null`)}}, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{ "type": "host", "keyId": "smallstep.com", "principals": ["smallstep.com"], "extensions": null, "criticalOptions": null }`), }, false}, // 
Note: `{{ toJson .Insecure.CR }}` is not a valid ssh template {"okTemplate", args{&Options{SSH: &SSHOptions{Template: "{{ toJson .Insecure.CR }}"}}, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(crCertificate)}, false}, {"okFile", args{&Options{SSH: &SSHOptions{TemplateFile: "./testdata/templates/cr.tpl"}}, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(crCertificate)}, false}, {"okBase64", args{&Options{SSH: &SSHOptions{Template: "e3sgdG9Kc29uIC5JbnNlY3VyZS5DUiB9fQ=="}}, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(crCertificate)}, false}, {"okUserOptions", args{&Options{SSH: &SSHOptions{Template: `{"foo": "{{.Insecure.User.foo}}"}`}}, data, sshutil.DefaultTemplate, SignSSHOptions{TemplateData: []byte(`{"foo":"bar"}`)}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{"foo": "bar"}`), }, false}, {"okNulUserOptions", args{&Options{SSH: &SSHOptions{Template: `{"foo": "{{.Insecure.User.foo}}"}`}}, data, sshutil.DefaultTemplate, SignSSHOptions{TemplateData: []byte(`null`)}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{"foo": ""}`), }, false}, {"okBadUserOptions", args{&Options{SSH: &SSHOptions{Template: `{"foo": "{{.Insecure.User.foo}}"}`}}, data, sshutil.DefaultTemplate, SignSSHOptions{TemplateData: []byte(`{"badJSON"}`)}}, sshutil.Options{ CertBuffer: bytes.NewBufferString(`{"foo": ""}`), }, false}, {"fail", args{&Options{SSH: &SSHOptions{TemplateData: []byte(`{"badJSON`)}}, data, sshutil.DefaultTemplate, SignSSHOptions{}}, sshutil.Options{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cof, err := CustomSSHTemplateOptions(tt.args.o, tt.args.data, tt.args.defaultTemplate) if (err != nil) != tt.wantErr { t.Errorf("CustomSSHTemplateOptions() error = %v, wantErr %v", err, tt.wantErr) return } var opts sshutil.Options if cof != nil { for _, fn := range 
cof.Options(tt.args.userOptions) { if err := fn(cr, &opts); err != nil { t.Errorf("x509util.Options() error = %v", err) return } } } if !reflect.DeepEqual(opts, tt.want) { t.Errorf("CustomSSHTemplateOptions() = %v, want %v", opts, tt.want) } }) } } ================================================ FILE: authority/provisioner/ssh_test.go ================================================ package provisioner import ( "crypto" "errors" "fmt" "net/http" "reflect" "time" "golang.org/x/crypto/ssh" "go.step.sm/crypto/sshutil" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" ) func validateSSHCertificate(cert *ssh.Certificate, opts *SignSSHOptions) error { switch { case cert == nil: return fmt.Errorf("certificate is nil") case cert.Signature == nil: return fmt.Errorf("certificate signature is nil") case cert.SignatureKey == nil: return fmt.Errorf("certificate signature is nil") case !reflect.DeepEqual(cert.ValidPrincipals, opts.Principals) && (len(opts.Principals) > 0 || len(cert.ValidPrincipals) > 0): return fmt.Errorf("certificate principals are not equal, want %v, got %v", opts.Principals, cert.ValidPrincipals) case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: return fmt.Errorf("certificate type %v is not valid", cert.CertType) case opts.CertType == "user" && cert.CertType != ssh.UserCert: return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.UserCert, cert.CertType) case opts.CertType == "host" && cert.CertType != ssh.HostCert: return fmt.Errorf("certificate type is not valid, want %v, got %v", ssh.HostCert, cert.CertType) case cert.ValidAfter != uint64(opts.ValidAfter.Unix()): return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), time.Unix(cast.Int64(cert.ValidAfter), 0)) case cert.ValidBefore != uint64(opts.ValidBefore.Unix()): return fmt.Errorf("certificate valid after is not valid, want %v, got %v", opts.ValidAfter.Unix(), 
time.Unix(cast.Int64(cert.ValidAfter), 0)) case opts.CertType == "user" && len(cert.Extensions) != 5: return fmt.Errorf("certificate extensions number is invalid, want 5, got %d", len(cert.Extensions)) case opts.CertType == "host" && len(cert.Extensions) != 0: return fmt.Errorf("certificate extensions number is invalid, want 0, got %d", len(cert.Extensions)) default: return nil } } func signSSHCertificate(key crypto.PublicKey, opts SignSSHOptions, signOpts []SignOption, signKey crypto.Signer) (*ssh.Certificate, error) { pub, err := ssh.NewPublicKey(key) if err != nil { return nil, err } var mods []SSHCertModifier var certOptions []sshutil.Option var validators []SSHCertValidator for _, op := range signOpts { switch o := op.(type) { case Interface: // add options to NewCertificate case SSHCertificateOptions: certOptions = append(certOptions, o.Options(opts)...) // modify the ssh.Certificate case SSHCertModifier: mods = append(mods, o) // validate the ssh.Certificate case SSHCertValidator: validators = append(validators, o) // validate the given SSHOptions case SSHCertOptionsValidator: if err := o.Valid(opts); err != nil { return nil, err } // call webhooks case *WebhookController: default: return nil, fmt.Errorf("signSSH: invalid extra option type %T", o) } } // Simulated certificate request with request options. cr := sshutil.CertificateRequest{ Type: opts.CertType, KeyID: opts.KeyID, Principals: opts.Principals, Key: pub, } // Create certificate from template. certificate, err := sshutil.NewCertificate(cr, certOptions...) if err != nil { var templErr *sshutil.TemplateError if errors.As(err, &templErr) { return nil, errs.NewErr(http.StatusBadRequest, templErr, errs.WithMessage("%s", templErr.Error()), errs.WithKeyVal("signOptions", signOpts), ) } return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH") } // Get actual *ssh.Certificate and continue with provisioner modifiers. 
// sshPOPPayload extends jwt.Claims with step attributes.
type sshPOPPayload struct {
	// Claims holds the registered JWT claims of the token.
	jose.Claims
	// SANs is decoded from the "sans" claim.
	SANs []string `json:"sans,omitempty"`
	// Step is decoded from the "step" claim.
	Step *stepPayload `json:"step,omitempty"`
	// sshCert is the certificate taken from the token's sshpop header. It is
	// unexported, so it is never part of the JSON payload; authorizeToken
	// assigns it after the token has been validated.
	sshCert *ssh.Certificate
}
type SSHPOP struct { *base ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims,omitempty"` ctl *Controller sshPubKeys *SSHKeys } // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any SSH-POP provisioner. func (p *SSHPOP) GetID() string { if p.ID != "" { return p.ID } return p.GetIDForToken() } // GetIDForToken returns an identifier that will be used to load the provisioner // from a token. func (p *SSHPOP) GetIDForToken() string { return "sshpop/" + p.Name } // GetTokenID returns the identifier of the token. func (p *SSHPOP) GetTokenID(ott string) (string, error) { // Validate payload token, err := jose.ParseSigned(ott) if err != nil { return "", errors.Wrap(err, "error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims jose.Claims if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } return claims.ID, nil } // GetName returns the name of the provisioner. func (p *SSHPOP) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *SSHPOP) GetType() Type { return TypeSSHPOP } // GetEncryptedKey returns the base provisioner encrypted key if it's defined. func (p *SSHPOP) GetEncryptedKey() (string, string, bool) { return "", "", false } // Init initializes and validates the fields of a SSHPOP type. 
// authorizeToken performs common jwt authorization actions and returns the
// claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request.
//
// token is the signed JWT carrying the SSH certificate in its sshpop header,
// audiences is the list of audiences accepted for the operation, and
// checkValidity controls whether the certificate validity period is enforced
// here.
//
// The checks run in this order, and the first failure is returned:
//  1. the sshpop certificate can be extracted from the token;
//  2. if checkValidity is set, the current time is within the certificate
//     validity period;
//  3. the certificate public key exposes a crypto public key;
//  4. the certificate signature verifies against one of the configured CA
//     keys;
//  5. the token signature verifies with the certificate public key;
//  6. the claims validate against the provisioner name as issuer, with one
//     minute of leeway;
//  7. the audience matches and the subject is not empty.
//
// On success the returned payload has its sshCert field set to the extracted
// certificate.
//
// Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
	sshCert, jwt, err := ExtractSSHPOPCert(token)
	if err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; error extracting sshpop header from token")
	}

	// Check validity period of the certificate.
	//
	// Controller.AuthorizeSSHRenew will validate this on the renewal flow.
	if checkValidity {
		unixNow := time.Now().Unix()
		// NOTE(review): the "< 0" comparisons only matter if cast.Int64 can
		// return a negative value for a large uint64; its implementation is
		// not visible here, confirm before simplifying.
		if after := cast.Int64(sshCert.ValidAfter); after < 0 || unixNow < cast.Int64(sshCert.ValidAfter) {
			return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
		}
		// A ValidBefore of ssh.CertTimeInfinity means the certificate never
		// expires, so the expiration check is skipped for it.
		if before := cast.Int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
			return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
		}
	}

	sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
	if !ok {
		return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
	}
	pubKey := sshCryptoPubKey.CryptoPublicKey()

	var (
		found bool
		data  = bytesForSigning(sshCert)
		keys  []ssh.PublicKey
	)
	// User certificates are checked against the user CA keys; every other
	// certificate type is checked against the host CA keys.
	if sshCert.CertType == ssh.UserCert {
		keys = p.sshPubKeys.UserKeys
	} else {
		keys = p.sshPubKeys.HostKeys
	}

	// The certificate is accepted as soon as one CA key verifies its
	// signature.
	for _, k := range keys {
		if err = (&ssh.Certificate{Key: k}).Verify(data, sshCert.Signature); err == nil {
			found = true
			break
		}
	}
	if !found {
		return nil, errs.Unauthorized("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate")
	}

	// Using the ssh certificates key to validate the claims accomplishes two
	// things:
	//   1. Asserts that the private key used to sign the token corresponds
	//      to the public certificate in the `sshpop` header of the token.
	//   2. Asserts that the claims are valid - have not been tampered with.
	var claims sshPOPPayload
	if err = jwt.Claims(pubKey, &claims); err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; error parsing sshpop token claims")
	}

	// According to "rfc7519 JSON Web Token" acceptable skew should be no
	// more than a few minutes.
	if err = claims.ValidateWithLeeway(jose.Expected{
		Issuer: p.Name,
		Time:   time.Now().UTC(),
	}, time.Minute); err != nil {
		return nil, errs.Wrap(http.StatusUnauthorized, err, "sshpop.authorizeToken; invalid sshpop token")
	}

	// validate audiences with the defaults
	if !matchesAudience(claims.Audience, audiences) {
		return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token has invalid audience "+
			"claim (aud): expected %s, but got %s", audiences, claims.Audience)
	}

	if claims.Subject == "" {
		return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop token subject cannot be empty")
	}

	// Hand the verified certificate to the caller together with the claims.
	claims.sshCert = sshCert
	return &claims, nil
}
func (p *SSHPOP) AuthorizeSSHRekey(_ context.Context, token string) (*ssh.Certificate, []SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } if claims.sshCert.CertType != ssh.HostCert { return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ p, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, }, nil } // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // in the sshpop header. If the header is missing, an error is returned. func ExtractSSHPOPCert(token string) (*ssh.Certificate, *jose.JSONWebToken, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, nil, errors.Wrapf(err, "extractSSHPOPCert; error parsing token") } encodedSSHCert, ok := jwt.Headers[0].ExtraHeaders["sshpop"] if !ok { return nil, nil, errors.New("extractSSHPOPCert; token missing sshpop header") } encodedSSHCertStr, ok := encodedSSHCert.(string) if !ok { return nil, nil, errors.Errorf("extractSSHPOPCert; error unexpected type for sshpop header: "+ "want 'string', but got '%T'", encodedSSHCert) } sshCertBytes, err := base64.StdEncoding.DecodeString(encodedSSHCertStr) if err != nil { return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error base64 decoding sshpop header") } sshPub, err := ssh.ParsePublicKey(sshCertBytes) if err != nil { return nil, nil, errors.Wrap(err, "extractSSHPOPCert; error parsing ssh public key") } sshCert, ok := sshPub.(*ssh.Certificate) if !ok { return nil, nil, errors.New("extractSSHPOPCert; error converting ssh public key to ssh certificate") } return sshCert, jwt, nil } func bytesForSigning(cert *ssh.Certificate) []byte { c2 := *cert 
c2.Signature = nil out := c2.Marshal() // Drop trailing signature length. return out[:len(out)-4] } ================================================ FILE: authority/provisioner/sshpop_test.go ================================================ package provisioner import ( "context" "crypto" "crypto/rand" "encoding/base64" "errors" "fmt" "net/http" "testing" "time" "golang.org/x/crypto/ssh" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" ) func TestSSHPOP_Getters(t *testing.T) { p, err := generateSSHPOP() assert.FatalError(t, err) id := "sshpop/" + p.Name if got := p.GetID(); got != id { t.Errorf("SSHPOP.GetID() = %v, want %v", got, id) } if got := p.GetName(); got != p.Name { t.Errorf("SSHPOP.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeSSHPOP { t.Errorf("SSHPOP.GetType() = %v, want %v", got, TypeSSHPOP) } kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("SSHPOP.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err } cert.Key, err = ssh.NewPublicKey(jwk.Public().Key) if err != nil { return nil, nil, err } if cert.ValidAfter == 0 { cert.ValidAfter = uint64(now.Unix()) } if cert.ValidBefore == 0 { cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } return cert, jwk, nil } func generateSSHPOPToken(p Interface, cert *ssh.Certificate, jwk *jose.JSONWebKey) (string, error) { return generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert)) } func TestSSHPOP_authorizeToken(t *testing.T) { key, err := 
pemutil.Read("./testdata/secrets/ssh_user_ca_key")
	assert.FatalError(t, err)
	signer, ok := key.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
	// Wrap the key read from ssh_user_ca_key as an ssh.Signer; the cases below
	// sign their certificates with it.
	sshSigner, err := ssh.NewSignerFromSigner(signer)
	assert.FatalError(t, err)

	// test is one table entry: the provisioner under test, the token handed to
	// it, and, for failing cases, the expected error-message prefix and status
	// code.
	type test struct {
		p     *SSHPOP
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		// "foo" is not a parseable token.
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
			}
		},
		// ValidAfter is set one minute into the future.
		"fail/cert-not-yet-valid": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{
				CertType:   ssh.UserCert,
				ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
			}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateSSHPOPToken(p, cert, jwk)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; sshpop certificate validAfter is in the future"),
			}
		},
		// ValidBefore is set one minute into the past.
		"fail/cert-past-validity": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{
				CertType:    ssh.UserCert,
				ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
			}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateSSHPOPToken(p, cert, jwk)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; sshpop certificate validBefore is in the past"),
			}
		},
		// A host certificate signed with the user CA key: no CA key is
		// expected to verify it.
		"fail/no-signer-found": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateSSHPOPToken(p, cert, jwk)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; could not find valid ca signer to verify sshpop certificate"),
			}
		},
		// The token is signed with a freshly generated JWK instead of the one
		// returned by createSSHCert (which is discarded here).
		"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
			assert.FatalError(t, err)
			otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			tok, err := generateSSHPOPToken(p, cert, otherJWK)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; error parsing sshpop token claims"),
			}
		},
		// "bar" is passed where the other cases pass p.GetName().
		"fail/invalid-claims-issuer": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; invalid sshpop token"),
			}
		},
		// "invalid-aud" is passed where the other cases pass
		// testAudiences.Sign[0].
		"fail/invalid-audience": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("foo", p.GetName(), "invalid-aud", "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; sshpop token has invalid audience claim (aud)"),
			}
		},
		// The first generateToken argument is the empty string.
		// NOTE(review): presumably that argument is the subject - confirm
		// against generateToken.
		"fail/empty-subject": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("", p.GetName(),
testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.authorizeToken; sshpop token subject cannot be empty"),
			}
		},
		// A user certificate and a token built by generateSSHPOPToken; no error
		// is expected.
		"ok": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateSSHPOPToken(p, cert, jwk)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			// On failure the error must carry a status code equal to tc.code and
			// its message must start with tc.err's message. On success the case
			// must not expect an error and the claims must be non-nil.
			if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
				var sc render.StatusCodedError
				if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
					assert.Equals(t, sc.StatusCode(), tc.code)
				}
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else if assert.Nil(t, tc.err) {
				assert.NotNil(t, claims)
			}
		})
	}
}

// TestSSHPOP_AuthorizeSSHRevoke runs table-driven cases against
// SSHPOP.AuthorizeSSHRevoke using certificates signed with the user CA test
// key.
func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
	key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
	assert.FatalError(t, err)
	signer, ok := key.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
	sshSigner, err := ssh.NewSignerFromSigner(signer)
	assert.FatalError(t, err)

	// test is one table entry: provisioner, token, and the expected error
	// prefix and status code for failing cases.
	type test struct {
		p     *SSHPOP
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		// "foo" is not a parseable token; the expected message carries the
		// "sshpop.AuthorizeSSHRevoke: " prefix in front of the authorizeToken
		// message.
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.AuthorizeSSHRevoke: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
			}
		},
		// The certificate is created without a Serial.
		"fail/subject-not-equal-serial": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert},
sshSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			// The expected message pairs subject "foo" with serial "0".
			return test{
				p:     p,
				token: tok,
				code:  http.StatusForbidden,
				err:   errors.New(`token subject "foo" and sshpop certificate serial number "0" do not match`),
			}
		},
		// The first generateToken argument "123455" matches the certificate
		// Serial 123455; no error is expected.
		"ok": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			// On failure the error must carry a status code equal to tc.code and
			// its message must start with tc.err's message; on success the case
			// must not expect an error.
			if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
				var sc render.StatusCodedError
				if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
					assert.Equals(t, sc.StatusCode(), tc.code)
				}
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestSSHPOP_AuthorizeSSHRenew runs table-driven cases against
// SSHPOP.AuthorizeSSHRenew. It loads both the user CA and the host CA test
// keys so cases can sign user and host certificates.
func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
	key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
	assert.FatalError(t, err)
	userSigner, ok := key.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
	sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
	assert.FatalError(t, err)

	hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
	assert.FatalError(t, err)
	hostSigner, ok := hostKey.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
	sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
	assert.FatalError(t, err)

	// test is one table entry; cert is the certificate the successful case
	// expects to get back.
	type test struct {
		p     *SSHPOP
		token string
		cert  *ssh.Certificate
		err   error
		code  int
	}
	tests := map[string]func(*testing.T)
test{
		// "foo" is not a parseable token; the expected message carries the
		// "sshpop.AuthorizeSSHRenew: " prefix.
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.AuthorizeSSHRenew: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
			}
		},
		// A user certificate (signed by the user CA) is rejected with 400.
		"fail/not-host-cert": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusBadRequest,
				err:   errors.New("sshpop certificate must be a host ssh certificate"),
			}
		},
		// A host certificate signed by the host CA; no error is expected.
		"ok": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				cert:  cert,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			// On failure the case must expect an error, the error must carry
			// tc.code, and its message must start with tc.err's message. On
			// success the returned certificate is compared with the expected one
			// by Nonce.
			if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
						assert.Equals(t, sc.StatusCode(), tc.code)
					}
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, tc.cert.Nonce, cert.Nonce)
				}
			}
		})
	}
}

// TestSSHPOP_AuthorizeSSHRekey runs table-driven cases against
// SSHPOP.AuthorizeSSHRekey. It loads both the user CA and the host CA test
// keys so cases can sign user and host certificates.
func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
	key, err := pemutil.Read("./testdata/secrets/ssh_user_ca_key")
	assert.FatalError(t, err)
	userSigner, ok := key.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh user signing key to crypto signer")
	sshUserSigner, err := ssh.NewSignerFromSigner(userSigner)
	assert.FatalError(t, err)

	hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
	assert.FatalError(t, err)
	hostSigner, ok := hostKey.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
	sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
	assert.FatalError(t, err)

	// test is one table entry; cert is the certificate the successful case
	// expects to get back.
	type test struct {
		p     *SSHPOP
		token string
		cert  *ssh.Certificate
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		// "foo" is not a parseable token; the expected message carries the
		// "sshpop.AuthorizeSSHRekey: " prefix.
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("sshpop.AuthorizeSSHRekey: sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
			}
		},
		// A user certificate (signed by the user CA) is rejected with 400.
		"fail/not-host-cert": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusBadRequest,
				err:   errors.New("sshpop certificate must be a host ssh certificate"),
			}
		},
		// A host certificate signed by the host CA; no error is expected.
		"ok": func(t *testing.T) test {
			p, err := generateSSHPOP()
			assert.FatalError(t, err)
			cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				cert:  cert,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(),
tc.token); err != nil {
				// Failure path: the case must expect an error, the error must
				// carry tc.code, and its message must start with tc.err's message.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
						assert.Equals(t, sc.StatusCode(), tc.code)
					}
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					// Exactly four sign options are expected, each of one of the
					// types listed in the switch; any other type aborts the test.
					assert.Len(t, 4, opts)
					for _, o := range opts {
						switch v := o.(type) {
						case Interface:
						case *sshDefaultPublicKeyValidator:
						case *sshCertDefaultValidator:
						case *sshCertValidityValidator:
							// The validity validator must use the provisioner's
							// Claimer.
							assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
						default:
							assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
						}
					}
					assert.Equals(t, tc.cert.Nonce, cert.Nonce)
				}
			}
		})
	}
}

// TestSSHPOP_ExtractSSHPOPCert runs table-driven cases against
// ExtractSSHPOPCert, covering an unparseable token, a missing sshpop header,
// a non-string header, invalid base64, bytes that are not an SSH public key,
// and a valid host certificate.
func TestSSHPOP_ExtractSSHPOPCert(t *testing.T) {
	hostKey, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
	assert.FatalError(t, err)
	hostSigner, ok := hostKey.(crypto.Signer)
	assert.Fatal(t, ok, "could not cast ssh host signing key to crypto signer")
	sshHostSigner, err := ssh.NewSignerFromSigner(hostSigner)
	assert.FatalError(t, err)

	// test is one table entry: the token to parse, the certificate and JWK the
	// successful case expects to recognise, and the expected error prefix for
	// failing cases.
	type test struct {
		token string
		cert  *ssh.Certificate
		jwk   *jose.JSONWebKey
		err   error
	}
	tests := map[string]func(*testing.T) test{
		// "foo" is not a parseable token.
		"fail/bad-token": func(t *testing.T) test {
			return test{
				token: "foo",
				err:   errors.New("extractSSHPOPCert; error parsing token"),
			}
		},
		// The token is generated without any option that adds an sshpop header.
		"fail/sshpop-missing": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			tok, err := generateToken("sub", "sshpop-provisioner", testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk)
			assert.FatalError(t, err)
			return test{
				token: tok,
				err:   errors.New("extractSSHPOPCert; token missing sshpop header"),
			}
		},
		// The sshpop header is set to the number 12345 rather than a string.
		"fail/wrong-sshpop-type": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
				so.WithHeader("sshpop", 12345)
				return nil
			})
			assert.FatalError(t, err)
			return test{
				token: tok,
				err:   errors.New("extractSSHPOPCert; error unexpected type for sshpop header: "),
			}
		},
		// The sshpop header is a string that is not valid base64.
		"fail/base64decode-error": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
				so.WithHeader("sshpop", "!@#$%^&*")
				return nil
			})
			assert.FatalError(t, err)
			return test{
				token: tok,
				err:   errors.New("extractSSHPOPCert; error base64 decoding sshpop header: illegal base64"),
			}
		},
		// The sshpop header is valid base64 of the bytes "foo", which is not an
		// SSH public key.
		"fail/parsing-sshpop-pubkey": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, func(so *jose.SignerOptions) error {
				so.WithHeader("sshpop", base64.StdEncoding.EncodeToString([]byte("foo")))
				return nil
			})
			assert.FatalError(t, err)
			return test{
				token: tok,
				err:   errors.New("extractSSHPOPCert; error parsing ssh public key"),
			}
		},
		// A host certificate signed by the host CA, attached through
		// withSSHPOPFile; no error is expected.
		"ok": func(t *testing.T) test {
			cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
			assert.FatalError(t, err)
			tok, err := generateToken("123455", "sshpop-provisioner", testAudiences.SSHRekey[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withSSHPOPFile(cert))
			assert.FatalError(t, err)
			return test{
				token: tok,
				jwk:   jwk,
				cert:  cert,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			// On failure the case must expect an error whose message is a prefix
			// of the returned one. On success the certificate is compared by
			// Nonce and the token's first header KeyID with the JWK's KeyID.
			if cert, jwt, err := ExtractSSHPOPCert(tc.token); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, tc.cert.Nonce, cert.Nonce)
					assert.Equals(t, tc.jwk.KeyID,
jwt.Headers[0].KeyID) } } }) } } ================================================ FILE: authority/provisioner/testdata/certs/apple-att-ca.crt ================================================ -----BEGIN CERTIFICATE----- MIICJDCCAamgAwIBAgIUQsDCuyxyfFxeq/bxpm8frF15hzcwCgYIKoZIzj0EAwMw UTEtMCsGA1UEAwwkQXBwbGUgRW50ZXJwcmlzZSBBdHRlc3RhdGlvbiBSb290IENB MRMwEQYDVQQKDApBcHBsZSBJbmMuMQswCQYDVQQGEwJVUzAeFw0yMjAyMTYxOTAx MjRaFw00NzAyMjAwMDAwMDBaMFExLTArBgNVBAMMJEFwcGxlIEVudGVycHJpc2Ug QXR0ZXN0YXRpb24gUm9vdCBDQTETMBEGA1UECgwKQXBwbGUgSW5jLjELMAkGA1UE BhMCVVMwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAT6Jigq+Ps9Q4CoT8t8q+UnOe2p oT9nRaUfGhBTbgvqSGXPjVkbYlIWYO+1zPk2Sz9hQ5ozzmLrPmTBgEWRcHjA2/y7 7GEicps9wn2tj+G89l3INNDKETdxSPPIZpPj8VmjQjBAMA8GA1UdEwEB/wQFMAMB Af8wHQYDVR0OBBYEFPNqTQGd8muBpV5du+UIbVbi+d66MA4GA1UdDwEB/wQEAwIB BjAKBggqhkjOPQQDAwNpADBmAjEA1xpWmTLSpr1VH4f8Ypk8f3jMUKYz4QPG8mL5 8m9sX/b2+eXpTv2pH4RZgJjucnbcAjEA4ZSB6S45FlPuS/u4pTnzoz632rA+xW/T ZwFEh9bhKjJ+5VQ9/Do1os0u3LEkgN/r -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/aws-test.crt ================================================ -----BEGIN CERTIFICATE----- MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw FgYDVQQKEw9BbWF6b24uY29tIEluYy4xGjAYBgNVBAMTEWVjMi5hbWF6b25hd3Mu Y29tMB4XDTE0MDYwNTE0MjgwMloXDTI0MDYwNTE0MjgwMlowajELMAkGA1UEBhMC VVMxEzARBgNVBAgTCldhc2hpbmd0b24xEDAOBgNVBAcTB1NlYXR0bGUxGDAWBgNV BAoTD0FtYXpvbi5jb20gSW5jLjEaMBgGA1UEAxMRZWMyLmFtYXpvbmF3cy5jb20w gZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAIe9GN//SRK2knbjySG0ho3yqQM3 e2TDhWO8D2e8+XZqck754gFSo99AbT2RmXClambI7xsYHZFapbELC4H91ycihvrD jbST1ZjkLQgga0NE1q43eS68ZeTDccScXQSNivSlzJZS8HJZjgqzBlXjZftjtdJL XeE4hwvo0sD4f3j9AgMBAAGjgc8wgcwwHQYDVR0OBBYEFCXWzAgVyrbwnFncFFIs 77VBdlE4MIGcBgNVHSMEgZQwgZGAFCXWzAgVyrbwnFncFFIs77VBdlE4oW6kbDBq MQswCQYDVQQGEwJVUzETMBEGA1UECBMKV2FzaGluZ3RvbjEQMA4GA1UEBxMHU2Vh 
dHRsZTEYMBYGA1UEChMPQW1hem9uLmNvbSBJbmMuMRowGAYDVQQDExFlYzIuYW1h em9uYXdzLmNvbYIJAKnL4UEDMN/FMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF BQADgYEAFYcz1OgEhQBXIwIdsgCOS8vEtiJYF+j9uO6jz7VOmJqO+pRlAbRlvY8T C1haGgSI/A1uZUKs/Zfnph0oEI0/hu1IIJ/SKBDtN5lvmZ/IzbOPIJWirlsllQIQ 7zvWbGd9c9+Rm3p04oTvhup99la7kZqevJK0QRdD/6NpCKsqP/0= -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0 MjEyMjU3MzlaMBgxFjAUBgNVBAMTDUFXUyBUZXN0IENlcnQwgZ8wDQYJKoZIhvcN AQEBBQADgY0AMIGJAoGBAOHMmMXwbXN90SoRl/xXAcJs5TacaVYJ5iNAVWM5KYyF +JwqYuJp/umLztFUi0oX0luu3EzD4KurVeUJSzZjTFTX1d/NX6hA45+bvdSUOcgV UghO+2uhBZ4SNFxFRZ7SKvoWIN195l5bVX6/60Eo6+kUCKCkyxW4V/ksWzdXjHnf AgMBAAGjXzBdMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0G A1UdDgQWBBRHfLOjEddK/CWCIHNg8Oc/oJa1IzAYBgNVHREEETAPgg1BV1MgVGVz dCBDZXJ0MA0GCSqGSIb3DQEBCwUAA4GBAKNCiVM9eGb9dW2xNyHaHAmmy7ERB2OJ 7oXHfLjooOavk9lU/Gs2jfX/JSBa84+DzWg9ShmCNLti8CxU/dhzXW7jE/5CcdTa DCA6B3Yl5TmfG9+D9dtFqRB2CiMgNcsJJE5Dc6pDwBIiSj/MkE0AaGVQmSwn6Cb6 vX1TAxqeWJHq -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/aws.crt ================================================ -----BEGIN CERTIFICATE----- MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw FgYDVQQKEw9BbWF6b24uY29tIEluYy4xGjAYBgNVBAMTEWVjMi5hbWF6b25hd3Mu Y29tMB4XDTE0MDYwNTE0MjgwMloXDTI0MDYwNTE0MjgwMlowajELMAkGA1UEBhMC VVMxEzARBgNVBAgTCldhc2hpbmd0b24xEDAOBgNVBAcTB1NlYXR0bGUxGDAWBgNV BAoTD0FtYXpvbi5jb20gSW5jLjEaMBgGA1UEAxMRZWMyLmFtYXpvbmF3cy5jb20w gZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAIe9GN//SRK2knbjySG0ho3yqQM3 e2TDhWO8D2e8+XZqck754gFSo99AbT2RmXClambI7xsYHZFapbELC4H91ycihvrD jbST1ZjkLQgga0NE1q43eS68ZeTDccScXQSNivSlzJZS8HJZjgqzBlXjZftjtdJL XeE4hwvo0sD4f3j9AgMBAAGjgc8wgcwwHQYDVR0OBBYEFCXWzAgVyrbwnFncFFIs 
77VBdlE4MIGcBgNVHSMEgZQwgZGAFCXWzAgVyrbwnFncFFIs77VBdlE4oW6kbDBq MQswCQYDVQQGEwJVUzETMBEGA1UECBMKV2FzaGluZ3RvbjEQMA4GA1UEBxMHU2Vh dHRsZTEYMBYGA1UEChMPQW1hem9uLmNvbSBJbmMuMRowGAYDVQQDExFlYzIuYW1h em9uYXdzLmNvbYIJAKnL4UEDMN/FMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF BQADgYEAFYcz1OgEhQBXIwIdsgCOS8vEtiJYF+j9uO6jz7VOmJqO+pRlAbRlvY8T C1haGgSI/A1uZUKs/Zfnph0oEI0/hu1IIJ/SKBDtN5lvmZ/IzbOPIJWirlsllQIQ 7zvWbGd9c9+Rm3p04oTvhup99la7kZqevJK0QRdD/6NpCKsqP/0= -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/bad-extension.crt ================================================ -----BEGIN CERTIFICATE----- MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49 BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g= -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/bar.pub 
================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEGQIXbsr73X28pzwC1wa+ccY2H3s8 PplbkapCrwxyYVvM78y/GmeSA7fv2MKQ8iKpCw461MlOQGX+VlWT+ChRFw== -----END PUBLIC KEY----- ================================================ FILE: authority/provisioner/testdata/certs/ecdsa.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIHqMIGRAgEAMA4xDDAKBgNVBAMTA2ZvbzBZMBMGByqGSM49AgEGCCqGSM49AwEH A0IABKdDjTb7XIYCWC4QUq1xn5hgf3J4WpfWbd3C5frKrA4/VdQ+XfpHQIxDoHqh jcWke0SEETc9i6HDDtWv8bXSETegITAfBgkqhkiG9w0BCQ4xEjAQMA4GA1UdEQQH MAWCA2ZvbzAKBggqhkjOPQQDAgNIADBFAiEA1pFLT8p/YogG0o6NEEmdxzwbOzJA A+C+DvoT91c1OcQCIGUjP3s+k6Xwdf/VukUZXTfG1lobmkZhO3vYxAjPkwA7 -----END CERTIFICATE REQUEST----- ================================================ FILE: authority/provisioner/testdata/certs/ed25519.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIGuMGICAQAwDjEMMAoGA1UEAxMDZm9vMCowBQYDK2VwAyEA3yF/Igqb5UTp6XOq yj+cZL9nIfjDKrUT0fMzDAHtIqqgITAfBgkqhkiG9w0BCQ4xEjAQMA4GA1UdEQQH MAWCA2ZvbzAFBgMrZXADQQAIAx7N6ezi4NL8n0oJU8v3AmVSi0XvTuIHXUtcLGoU OZtlO3zjWI+DgcT/ADeEKn+T8OazDxcCbTBbHiM2hIsA -----END CERTIFICATE REQUEST----- ================================================ FILE: authority/provisioner/testdata/certs/foo.crt ================================================ -----BEGIN CERTIFICATE----- MIICIDCCAcagAwIBAgIQTL7pKDl8mFzRziotXbgjEjAKBggqhkjOPQQDAjAnMSUw IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyMjIy MjkyOVoXDTE5MDMyMzIyMjkyOVowHDEaMBgGA1UEAxMRZm9vLnNtYWxsc3RlcC5j b20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQbptfDonFaeUPiTr52wl9r3dcz greolwDRmsgyFgnr1EuKH56WRcgH1gjfL0pybFlO3PdgBukR4u+sveq343OAo4He MIHbMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH AwIwHQYDVR0OBBYEFP9pHiVlsx5mr4L2QirOb1G9Mo4jMB8GA1UdIwQYMBaAFKEe 9IdMyaHdURMjoJce7FN9HC9wMBwGA1UdEQQVMBOCEWZvby5zbWFsbHN0ZXAuY29t MEwGDCsGAQQBgqRkxihAAQQ8MDoCAQEECHN0ZXAtY2xpBCs0VUVMSng4ZTBhUzlt 
MENIM2ZaMEVCN0Q1YVVQSUNiNzU5ekFMSEZlanZjMAoGCCqGSM49BAMCA0gAMEUC IDxtNo1BX/4Sbf/+k1n+v//kh8ETr3clPvhjcyfvBIGTAiEAiT0kvbkPdCCnmHIw lhpgBwT5YReZzBwIYXyKyJXc07M= -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/foo.pub ================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEzriaeV2e1aEz33x62kyqVC6ootU7 rl41L8cyeOJ4SjTu4FV+o5i4NsS6DCE07JJSHlqc9PsrzjSs4LZD4gWVLQ== -----END PUBLIC KEY----- ================================================ FILE: authority/provisioner/testdata/certs/good-extension.crt ================================================ -----BEGIN CERTIFICATE----- MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0 4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w== 
-----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET /A8LXNH4M06A7vE= -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/rsa.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIICdDCCAVwCAQAwDjEMMAoGA1UEAxMDZm9vMIIBIjANBgkqhkiG9w0BAQEFAAOC AQ8AMIIBCgKCAQEA86h3t/KJylE0/aPxvF9JqPaOwSsGexuDWqDVJSOWBJi/ZqUA Ea2Gy05ZIJkQ5GOy0bUs2JCNCVXVkfPrUkX6IvIlXpTjutjMDYyYGdgQjzpKPnOA v3mO2a7mLMzJunws7pvrUPP7z5KDCKSAPf6VAcu/na8rGDWn1TUYR8hINK1rLQQf OcyNWrr7yLkR84jSsrw/Qgc8NS//F4ccca1NfZecPEtxgcHjKdDQZ3SYRAfb6Dc0 jRuvoByAd3q9okOOr70gpMXgpoFVArDynaHMPK9xJ1w2p3s2/NhOYgY9f9rtcWTo afoAcHK1jy5iQCogFUKt1bUCz5IsaYkRt+D+HQIDAQABoCEwHwYJKoZIhvcNAQkO MRIwEDAOBgNVHREEBzAFggNmb28wDQYJKoZIhvcNAQELBQADggEBAOsv1UKwEbcY 8Fj2Pl55BjkqQG4PqSQdWJZfK0ol/GRty5XFaTgOUZyTeXOag84OGw0qM0E7kkUa O5QwDOpnmIgg01Ywr4QM166l1iED+eOUscXJMonBAsS3JNYF1JxcDyKzIl/dt9+w JXQ64uquuD57amOs8++ROfKW988HzXm0OnoHj8LZ1Mq2yUmxvnnfVnmMpZWo43sA 8NQs4v9dT5wLByFvBjcaWiGVZwZiwT4Q/Msskv9L0o1On0fgCJ6PjLYdblTwMHDZ syH+X8SsUqeEmyvtiRc1XUeFbxS2hnPXJCXeyfljqwsBNGaVhBXcsV2Lg7IaloBF /RyWqQZ44eE= -----END CERTIFICATE REQUEST----- ================================================ FILE: authority/provisioner/testdata/certs/short-rsa.csr 
================================================ -----BEGIN CERTIFICATE REQUEST----- MIIBdDCB2wIBADAOMQwwCgYDVQQDEwNmb28wgaIwDQYJKoZIhvcNAQEBBQADgZAA MIGMAoGEAK8dks7oV6kcIFEaWna7CDGYPAE8IL7rNi+ruQ1dIYz+JtxT7OPjbCn/ t5iqni96+35iS/8CvMtEuquOMTMSWOWwlurrbTbLqCazuz/g233o8udxSxhny3cY wHogp4cXCX6cFll6DeUnoCEuTTSIu8IBHbK48VfNw4V4gGz6cp/H93HrAgMBAAGg ITAfBgkqhkiG9w0BCQ4xEjAQMA4GA1UdEQQHMAWCA2ZvbzANBgkqhkiG9w0BAQsF AAOBhABCZsYM+Kgje68Z9Fjl2+cBwtQHvZDarh+cz6W1SchinZ1T0aNQvSj/otOe ttnEF4Rq8zqzr4fbv+AF451Mx36AkfgZr9XWGzxidrH+fBCNWXWNR+ymhrL6UFTG 2FbarLt9jN2aJLAYQPwtSeGTAZ74tLOPRPnTP6aMfFNg4XCR0uveHA== -----END CERTIFICATE REQUEST----- ================================================ FILE: authority/provisioner/testdata/certs/ssh_host_ca_key.pub ================================================ ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY= ================================================ FILE: authority/provisioner/testdata/certs/ssh_user_ca_key.pub ================================================ ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y= ================================================ FILE: authority/provisioner/testdata/certs/x5c-leaf.crt ================================================ -----BEGIN CERTIFICATE----- MIIBuDCCAV+gAwIBAgIQFdu723gqgGaTaqjf6ny88zAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFpbnRlcm1lZGlhdGUtdGVzdDAgFw0xOTEwMDIwMzE4NTNaGA8yMTE5 MDkwODAzMTg1MVowFDESMBAGA1UEAxMJbGVhZi10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEaV6807GhWEtMxA39zjuMVHAiN2/Ri5B1R1s+Y/8mlrKIvuvr VpgSPXYruNRFduPWX564Abz/TDmb276JbKGeQqOBiDCBhTAOBgNVHQ8BAf8EBAMC BaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBReMkPW f4MNWdg7KN4xI4ZLJd0IJDAfBgNVHSMEGDAWgBSckDGJlzLaJsdy698XH32gPDMp czAUBgNVHREEDTALgglsZWFmLXRlc3QwCgYIKoZIzj0EAwIDRwAwRAIgKYLKXpTN 
wtvZZaIvDzq1p8MO/SZ8yI42Ot69dNk/QtkCIBSvg5PozYcfbvwkgX5SwsjfYu0Z AvUgkUQ2G25NBRmX -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBtjCCAVygAwIBAgIQNr+f4IkABY2n4wx4sLOMrTAKBggqhkjOPQQDAjAUMRIw EAYDVQQDEwlyb290LXRlc3QwIBcNMTkxMDAyMDI0MDM0WhgPMjExOTA5MDgwMjQw MzJaMBwxGjAYBgNVBAMTEWludGVybWVkaWF0ZS10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEflfRhPjgJXv4zsPWahXjM2UU61aRFErN0iw88ZPyxea22fxl qN9ezntTXxzsS+mZiWapl8B40ACJgvP+WLQBHKOBhTCBgjAOBgNVHQ8BAf8EBAMC AQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUnJAxiZcy2ibHcuvfFx99 oDwzKXMwHwYDVR0jBBgwFoAUpHS7FfaQ5bCrTxUeu6R2ZC3VGOowHAYDVR0RBBUw E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC lgsqsR63is+0YQ== -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/certs/yubico-piv-ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIDFzCCAf+gAwIBAgIDBAZHMA0GCSqGSIb3DQEBCwUAMCsxKTAnBgNVBAMMIFl1 YmljbyBQSVYgUm9vdCBDQSBTZXJpYWwgMjYzNzUxMCAXDTE2MDMxNDAwMDAwMFoY DzIwNTIwNDE3MDAwMDAwWjArMSkwJwYDVQQDDCBZdWJpY28gUElWIFJvb3QgQ0Eg U2VyaWFsIDI2Mzc1MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMN2 cMTNR6YCdcTFRxuPy31PabRn5m6pJ+nSE0HRWpoaM8fc8wHC+Tmb98jmNvhWNE2E ilU85uYKfEFP9d6Q2GmytqBnxZsAa3KqZiCCx2LwQ4iYEOb1llgotVr/whEpdVOq joU0P5e1j1y7OfwOvky/+AXIN/9Xp0VFlYRk2tQ9GcdYKDmqU+db9iKwpAzid4oH BVLIhmD3pvkWaRA2H3DA9t7H/HNq5v3OiO1jyLZeKqZoMbPObrxqDg+9fOdShzgf wCqgT3XVmTeiwvBSTctyi9mHQfYd2DwkaqxRnLbNVyK9zl+DzjSGp9IhVPiVtGet X02dxhQnGS7K6BO0Qe8CAwEAAaNCMEAwHQYDVR0OBBYEFMpfyvLEojGc6SJf8ez0 1d8Cv4O/MA8GA1UdEwQIMAYBAf8CAQEwDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3 DQEBCwUAA4IBAQBc7Ih8Bc1fkC+FyN1fhjWioBCMr3vjneh7MLbA6kSoyWF70N3s XhbXvT4eRh0hvxqvMZNjPU/VlRn6gLVtoEikDLrYFXN6Hh6Wmyy1GTnspnOvMvz2 lLKuym9KYdYLDgnj3BeAvzIhVzzYSeU77/Cupofj093OuAswW0jYvXsGTyix6B3d bW5yWvyS9zNXaqGaUmP3U9/b6DlHdDogMLu3VLpBB9bm5bjaKWWJYgWltCVgUbFq Fqyi4+JE014cSgR57Jcu3dZiehB6UtAPgad9L5cNvua/IWRmm+ANy3O2LH++Pyl8 
SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7 -----END CERTIFICATE----- ================================================ FILE: authority/provisioner/testdata/secrets/bar.priv ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIM8wGIzCKjAOGdBFmYHtS791Ly2I9FtmknEsR2sa63s7oAoGCCqGSM49 AwEHoUQDQgAEGQIXbsr73X28pzwC1wa+ccY2H3s8PplbkapCrwxyYVvM78y/GmeS A7fv2MKQ8iKpCw461MlOQGX+VlWT+ChRFw== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/bar_host_ssh_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIHzAUYu3h8e1gL5ONGZo+lghJJa9rl1TvP2UlqDXazxvoAoGCCqGSM49 AwEHoUQDQgAEOLScS+1Yzmqdyots9lSC0tzTSXUXEgyOD9wYrQ0BqnVZtBXlQw1p m3fnF/7Ehl6bD1YZWjrF1t+IBZQMq1uBBw== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/ecdsa.key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,54abd40e525b255542ee6161ec438721 fJvmEc5n0IG4t4FKF+ekKhpog4ods2nZjBR5KLkGH5oSGAOEADSXIRBK76Jnm/nz Kv8ZwGqxNnoJUQyeTMlyg5OnOUAQPyNBPvoItOlD2DP32WJXgQ+NSHB2h9pcBGYG yLWrCtzl9/P9REWskanPO4RujP27Ht62omcMO7SxxNI= -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/ed25519.key ================================================ -----BEGIN ENCRYPTED PRIVATE KEY----- MIGkMGAGCSqGSIb3DQEFDTBTMDIGCSqGSIb3DQEFDDAlBBDJ0vCXdpPyUiLlbge5 1g0jAgMBhqAwDAYIKoZIhvcNAgkAADAdBglghkgBZQMEASoEENtOknzU2eS2mlxl 73Yo/IoEQEyJS2EEx3+oYaKlFIB90e1Zkmi8da7d3r2iUlfc7faRAiKChcEvtEas vYF2l9LEZ9DXv1Rm1uyNuSpXuddHScE= -----END ENCRYPTED PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/foo.key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIJmnxm3N/ahRA2PWeZhRGJUKPU1lI44WcE4P1bynIim6oAoGCCqGSM49 
AwEHoUQDQgAEG6bXw6JxWnlD4k6+dsJfa93XM4K3qJcA0ZrIMhYJ69RLih+elkXI B9YI3y9KcmxZTtz3YAbpEeLvrL3qt+NzgA== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/foo.priv ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIB8ovUg0Atvz+b+XiF8QV722OivOm1geGtI3sP0F48N1oAoGCCqGSM49 AwEHoUQDQgAEzriaeV2e1aEz33x62kyqVC6ootU7rl41L8cyeOJ4SjTu4FV+o5i4 NsS6DCE07JJSHlqc9PsrzjSs4LZD4gWVLQ== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/foo_user_ssh_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEINWGD2xneE43YeytQzORItISxv6d/oH+9TXvDKHo6TyXoAoGCCqGSM49 AwEHoUQDQgAEVK/EtXgVV7+7ppnQSjCtI5qb/gIGnQUF4i//F/JKKho7kRNyMDSn BP3kndiv8Yfxg4PsyIRY5ZofbEo5eJE6bg== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/rsa.key ================================================ -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,e77ed7e2d2572b5a246a1c4b994190bb 29o7UA/L7OF7inTPKkwBrtd8CJdXVQs9R3oJPitmFk8SZbrLHEiEhF1C0uB3xK3s GWQ9O7bjERM3uAvQkd7MSkUUBpyPXS9GFacd85e85d1Ubl6miTAwkFrQnT9yn6n/ Fak5JtkmdB6ObfVTioOwT1jdtGTifKg1bIhYISgwqCWhgV2fUFk6HQAAIMXTTRc+ ZK1WbunT7LimYrnN3gQ4ylm/4C8nQl3JCGpvWZaRoH91q1LLD6IwWmX0D09F37dU X3KoKv/GvuDlV3H1dUBwxhU+GI5/lItPp9OcdLZnnr67Gs+X+do3MFT1h675TM4N c9QEIJB6RYatLBKHCS7j8W7EbuJAFZ+MCCapP92ERmVVVsWPY+V7CDVQxM2v4X/w 7C7JYx8b4xuQbdvu9KVU5irsXg8hBx7kDb/mtWjT4+8+sseLKA4oOmI6XwVMdbow MciGilAIaNtWwQHe0EK9E9tiQfc9OyzxdrfplRckAAehHuPGU7+iMCsigCLT3aiV CDHmnLdTXKIvGe8faTQoJphrb9F8bobGo5D4ZqX5f6gKuPIJsfd/r0GD8VNSF/Q7 SJQMhkVyaixFB0gQbmea7sTyScdW+Qne7nLpam3ISgo+G4CAH8W88wLnuHMLmvoC ZE3HvArSeQZ0WPHgB86AfoNRIxd6Emgb+dFyA6wPJC29nZkB8PFSrAHp0zp7KilF fe9K2dVAUBZFhQthQIYAjJmYLCukLhxUALiqdSmQZrt6DSE33K8s5ed2KJu/60G6 
lZwIzQHPXesRhwmwbkfPB8CyWM+L6osdWv8QyMdM8Wb+66zkhKWBNbm+ccMfP6Zf 1ynF/a/DRX8bf81w+nvLsCGTdxVuEVEpuzS1NclKTmYQu58Ol0RgQe2JSxL89n+A JAHUu9g9LcTg2jNPjxeA/vusSXMZRrPqrUCYhHhcgR4mE13uyyFI/9frk0gPpKXp /FislMydWov2JRp1ixzypMBqlFR/zF6j6m3P1g7gchwScWzrZQHD58xdRin4Udiv OR4huswh5v2i/0KozBoUAwbvPGERnMlTaGoBMPJ5Xe/jkBJw3uC3Dhi74uyUCjqU hMQW4RJKmuiZVfAIX0RdgeUWXPs+8pf2pXrpIiVHCAHDrxXNMC7X6/9EcBN15B88 W5/KIRngDeB2oVYrn1GfO7iLu1Rd8VFXyaVItOXq7WrL2pwm8ANhWcFDdnXf6jHW BcKss1j8rZxOchksf+ZPXhn3QkdooD9iVONky1zLIsV5GPwMe8+yXwXznzJSbHH7 dOfhK93fZqUwx4gFULwCuWIwLTfNmQ3VzdKioGt39RFDVQb+pbR7p9jv899VjsVO TBBpRa00fvbK1H2CMVHnwwIf82M4XypSNGR/tSD3AImZPb5RfZnznoXXCMEfYVsd 8/Ry4GHusA+zxCjCxHFtXVkb9sklewJtnUmN5mUzo/81szuigLB5IADR21IOyVBq A4kz96Ta885Z5owhonfZp1HD53pDEbxCuuIy+fgYfjfDSAj3L/QT3ZKrdcIdYQap PhrNRW3j38koAatTLd3+E9KqBO5BiY+T5h+Q3XesWnaXInfu5WKiiEm5hHiejA0C -----END RSA PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/ssh_host_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49 AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+ MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/ssh_user_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49 AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg== -----END EC PRIVATE KEY----- ================================================ FILE: authority/provisioner/testdata/secrets/x5c-leaf.key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIALytC4LyTTAagMLMv+rzq2vtfhFkhuyBz4kqsnRs6zioAoGCCqGSM49 AwEHoUQDQgAEaV6807GhWEtMxA39zjuMVHAiN2/Ri5B1R1s+Y/8mlrKIvuvrVpgS 
PXYruNRFduPWX564Abz/TDmb276JbKGeQg==
-----END EC PRIVATE KEY-----

================================================
FILE: authority/provisioner/testdata/templates/cr.tpl
================================================
{{ toJson .Insecure.CR }}

================================================
FILE: authority/provisioner/timeduration.go
================================================
package provisioner

import (
	"encoding/json"
	"time"

	"github.com/pkg/errors"
)

// now returns the current time in UTC. It is a package-level variable so that
// it can be replaced with a fixed clock.
var now = func() time.Time {
	return time.Now().UTC()
}

// timeOr returns the first of its arguments that is not equal to the zero time.
// If all arguments are the zero time, or none are given, it returns the zero
// time. This function can be replaced with cmp.Or when step-ca requires Go 1.22.
func timeOr(ts ...time.Time) time.Time {
	for _, t := range ts {
		if !t.IsZero() {
			return t
		}
	}
	return time.Time{}
}

// TimeDuration is a type that represents a time but the JSON unmarshaling can
// use a time using the RFC 3339 format or a time.Duration string. If a duration
// is used, the time will be set on the first call to TimeDuration.Time.
type TimeDuration struct {
	t time.Time     // absolute time; the zero time means it is not set
	d time.Duration // offset added to a base time when t is not set
}

// NewTimeDuration returns a TimeDuration with the defined time.
func NewTimeDuration(t time.Time) TimeDuration {
	return TimeDuration{t: t}
}

// ParseTimeDuration returns a new TimeDuration parsing the RFC 3339 time or
// time.Duration string. An empty string yields the zero TimeDuration and a nil
// error. A parsed time is converted to UTC. If s is neither a valid RFC 3339
// time nor a valid duration, an error is returned.
func ParseTimeDuration(s string) (TimeDuration, error) {
	if s == "" {
		return TimeDuration{}, nil
	}

	// Try to use the unquoted RFC 3339 format
	var t time.Time
	if err := t.UnmarshalText([]byte(s)); err == nil {
		return TimeDuration{t: t.UTC()}, nil
	}

	// Try to use the time.Duration string format
	if d, err := time.ParseDuration(s); err == nil {
		return TimeDuration{d: d}, nil
	}

	return TimeDuration{}, errors.Errorf("failed to parse %s", s)
}

// SetDuration initializes the TimeDuration with the given duration. If the
// time was set it will be re-set to zero.
func (t *TimeDuration) SetDuration(d time.Duration) {
	// Switching to a relative value discards any absolute time already stored.
	t.t, t.d = time.Time{}, d
}

// SetTime initializes the TimeDuration with the given time. If the duration is
// set it will be re-set to zero.
func (t *TimeDuration) SetTime(tt time.Time) {
	t.t, t.d = tt, 0
}

// IsZero returns true if the TimeDuration represents the zero value, false
// otherwise. A nil TimeDuration is reported as zero, in line with RelativeTime,
// which returns the zero time for a nil receiver.
func (t *TimeDuration) IsZero() bool {
	if t == nil {
		return true
	}
	return t.t.IsZero() && t.d == 0
}

// Equal reports whether t and other hold the same time and the same duration.
// Times are compared with time.Time.Equal, so the same instant expressed in
// different locations is considered equal. A nil pointer on either side is
// treated as the zero TimeDuration instead of causing a panic.
func (t *TimeDuration) Equal(other *TimeDuration) bool {
	var zero TimeDuration
	if t == nil {
		t = &zero
	}
	if other == nil {
		other = &zero
	}
	return t.t.Equal(other.t) && t.d == other.d
}

// MarshalJSON implements the json.Marshaler interface. If the time is set it
// will return the time in RFC 3339 format, if not it will return the duration
// string. The zero TimeDuration is encoded as an empty JSON string.
func (t TimeDuration) MarshalJSON() ([]byte, error) {
	switch {
	case !t.t.IsZero():
		return t.t.MarshalJSON()
	case t.d != 0:
		return json.Marshal(t.d.String())
	default:
		return []byte(`""`), nil
	}
}

// UnmarshalJSON implements the json.Unmarshaler interface. The time is expected
// to be a quoted string in RFC 3339 format or a quoted time.Duration string. An
// empty string resets t to the zero value. On error t is left unchanged.
func (t *TimeDuration) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return errors.Wrapf(err, "error unmarshaling %s", data)
	}

	// Empty TimeDuration
	if s == "" {
		*t = TimeDuration{}
		return nil
	}

	// Try to use the unquoted RFC 3339 format
	var tt time.Time
	if err := tt.UnmarshalText([]byte(s)); err == nil {
		*t = TimeDuration{t: tt}
		return nil
	}

	// Try to use the time.Duration string format
	if d, err := time.ParseDuration(s); err == nil {
		*t = TimeDuration{d: d}
		return nil
	}

	return errors.Errorf("failed to parse %s", data)
}

// Time calculates the time if needed and returns it.
func (t *TimeDuration) Time() time.Time {
	return t.RelativeTime(now())
}

// Unix calculates the time if needed and returns the Unix time in seconds.
func (t *TimeDuration) Unix() int64 {
	return t.RelativeTime(now()).Unix()
}

// RelativeTime returns the embedded time.Time or the base time plus the
// duration if this is not zero. The returned time is in UTC. A nil receiver,
// or a TimeDuration with neither a time nor a duration set, yields the zero
// time.
//
// When only a duration is set, the computed time is stored in the receiver, so
// later calls return that same instant regardless of the base passed in.
func (t *TimeDuration) RelativeTime(base time.Time) time.Time {
	switch {
	case t == nil:
		return time.Time{}
	case t.t.IsZero():
		if t.d == 0 {
			return time.Time{}
		}
		// Pin the relative duration to an absolute time on first use.
		t.t = base.Add(t.d)
		return t.t.UTC()
	default:
		return t.t.UTC()
	}
}

// String implements the fmt.Stringer interface. It formats the time returned
// by Time.
func (t *TimeDuration) String() string {
	return t.Time().String()
}

================================================
FILE: authority/provisioner/timeduration_test.go
================================================
package provisioner

import (
	"reflect"
	"testing"
	"time"
)

// mockNow replaces the package-level now function with one that always returns
// a fixed time. It returns that time and a function that restores the previous
// now function.
func mockNow() (time.Time, func()) {
	tm := time.Unix(1584198566, 535897000).UTC()
	nowFn := now
	now = func() time.Time {
		return tm
	}
	return tm, func() {
		now = nowFn
	}
}

// TestNewTimeDuration checks that NewTimeDuration stores the given time and
// leaves the duration unset.
func TestNewTimeDuration(t *testing.T) {
	tm := time.Unix(1584198566, 535897000).UTC()
	type args struct {
		t time.Time
	}
	tests := []struct {
		name string
		args args
		want TimeDuration
	}{
		{"ok", args{tm}, TimeDuration{t: tm}},
		{"zero", args{time.Time{}}, TimeDuration{}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := NewTimeDuration(tt.args.t); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("NewTimeDuration() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestParseTimeDuration covers RFC 3339 timestamps with and without fractional
// seconds and zone offsets, duration strings, the empty string, and inputs that
// must be rejected.
func TestParseTimeDuration(t *testing.T) {
	type args struct {
		s string
	}
	tests := []struct {
		name    string
		args    args
		want    TimeDuration
		wantErr bool
	}{
		{"timestamp", args{"2020-03-14T15:09:26.535897Z"}, TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false},
		{"timestamp", args{"2020-03-14T15:09:26Z"}, TimeDuration{t: time.Unix(1584198566, 0).UTC()}, false},
		{"timestamp", args{"2020-03-14T15:09:26.535897-07:00"}, TimeDuration{t: time.Unix(1584223766, 535897000).UTC()}, false},
		{"timestamp", args{"2020-03-14T15:09:26-07:00"}, TimeDuration{t: time.Unix(1584223766, 0).UTC()}, false},
		{"timestamp",
args{"2020-03-14T15:09:26.535897+07:00"}, TimeDuration{t: time.Unix(1584173366, 535897000).UTC()}, false}, {"timestamp", args{"2020-03-14T15:09:26+07:00"}, TimeDuration{t: time.Unix(1584173366, 0).UTC()}, false}, {"1h", args{"1h"}, TimeDuration{d: 1 * time.Hour}, false}, {"-24h60m60s", args{"-24h60m60s"}, TimeDuration{d: -24*time.Hour - 60*time.Minute - 60*time.Second}, false}, {"0", args{"0"}, TimeDuration{}, false}, {"empty", args{""}, TimeDuration{}, false}, {"fail", args{"2020-03-14T15:09:26Z07:00"}, TimeDuration{}, true}, {"fail", args{"1d"}, TimeDuration{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := ParseTimeDuration(tt.args.s) if (err != nil) != tt.wantErr { t.Errorf("ParseTimeDuration() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("ParseTimeDuration() = %v, want %v", got, tt.want) } }) } } func TestTimeDuration_SetDuration(t *testing.T) { type fields struct { t time.Time d time.Duration } type args struct { d time.Duration } tests := []struct { name string fields fields args args want *TimeDuration }{ {"new", fields{}, args{2 * time.Hour}, &TimeDuration{d: 2 * time.Hour}}, {"old", fields{time.Now(), 1 * time.Hour}, args{2 * time.Hour}, &TimeDuration{d: 2 * time.Hour}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { td := &TimeDuration{ t: tt.fields.t, d: tt.fields.d, } td.SetDuration(tt.args.d) if !reflect.DeepEqual(td, tt.want) { t.Errorf("SetDuration() = %v, want %v", td, tt.want) } }) } } func TestTimeDuration_SetTime(t *testing.T) { tm := time.Unix(1584198566, 535897000).UTC() type fields struct { t time.Time d time.Duration } type args struct { tt time.Time } tests := []struct { name string fields fields args args want *TimeDuration }{ {"new", fields{}, args{tm}, &TimeDuration{t: tm}}, {"old", fields{time.Now(), 1 * time.Hour}, args{tm}, &TimeDuration{t: tm}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { td := 
&TimeDuration{ t: tt.fields.t, d: tt.fields.d, } td.SetTime(tt.args.tt) if !reflect.DeepEqual(td, tt.want) { t.Errorf("SetTime() = %v, want %v", td, tt.want) } }) } } func TestTimeDuration_MarshalJSON(t *testing.T) { tm := time.Unix(1584198566, 535897000).UTC() tests := []struct { name string timeDuration TimeDuration want []byte wantErr bool }{ {"empty", TimeDuration{}, []byte(`""`), false}, {"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false}, {"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false}, {"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.timeDuration.MarshalJSON() if (err != nil) != tt.wantErr { t.Errorf("TimeDuration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("TimeDuration.MarshalJSON() = %s, want %s", got, tt.want) } }) } } func TestTimeDuration_UnmarshalJSON(t *testing.T) { type args struct { data []byte } tests := []struct { name string args args want *TimeDuration wantErr bool }{ {"empty", args{[]byte(`""`)}, &TimeDuration{}, false}, {"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false}, {"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false}, {"fail", args{[]byte("123")}, &TimeDuration{}, true}, {"fail", args{[]byte(`"2020-03-14T15:09:26.535897Z07:00"`)}, &TimeDuration{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { td := &TimeDuration{} if err := td.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { t.Errorf("TimeDuration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } if !reflect.DeepEqual(td, tt.want) { t.Errorf("TimeDuration.UnmarshalJSON() = %s, want %s", td, tt.want) } }) } } func TestTimeDuration_Time(t *testing.T) { tm, fn := mockNow() defer fn() tests := []struct { name string 
timeDuration *TimeDuration want time.Time }{ {"zero", nil, time.Time{}}, {"zero", &TimeDuration{}, time.Time{}}, {"timestamp", &TimeDuration{t: tm}, tm}, {"local", &TimeDuration{t: tm.Local()}, tm}, {"duration", &TimeDuration{d: 1 * time.Hour}, tm.Add(1 * time.Hour)}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.timeDuration.Time() if !reflect.DeepEqual(got, tt.want) { t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want) } }) } } func TestTimeDuration_Unix(t *testing.T) { tm, fn := mockNow() defer fn() tests := []struct { name string timeDuration *TimeDuration want int64 }{ {"zero", nil, -62135596800}, {"zero", &TimeDuration{}, -62135596800}, {"timestamp", &TimeDuration{t: tm}, 1584198566}, {"local", &TimeDuration{t: tm.Local()}, 1584198566}, {"duration", &TimeDuration{d: 1 * time.Hour}, 1584202166}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.timeDuration.Unix() if !reflect.DeepEqual(got, tt.want) { t.Errorf("TimeDuration.Unix() = %v, want %v", got, tt.want) } }) } } func TestTimeDuration_String(t *testing.T) { tm, fn := mockNow() defer fn() tests := []struct { name string timeDuration *TimeDuration want string }{ {"zero", nil, "0001-01-01 00:00:00 +0000 UTC"}, {"zero", &TimeDuration{}, "0001-01-01 00:00:00 +0000 UTC"}, {"timestamp", &TimeDuration{t: tm}, "2020-03-14 15:09:26.535897 +0000 UTC"}, {"duration", &TimeDuration{d: 1 * time.Hour}, "2020-03-14 16:09:26.535897 +0000 UTC"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.timeDuration.String(); got != tt.want { t.Errorf("TimeDuration.String() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/provisioner/utils_test.go ================================================ package provisioner import ( "crypto" "crypto/rand" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/hex" "encoding/json" "encoding/pem" "fmt" "net/http" "net/http/httptest" "os" "strings" 
"time" "github.com/pkg/errors" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/authority/provisioner/gcp" ) var ( defaultDisableRenewal = false defaultAllowRenewalAfterExpiry = false defaultEnableSSHCA = true defaultDisableSmallstepExtensions = false globalProvisionerClaims = Claims{ MinTLSDur: &Duration{5 * time.Minute}, MaxTLSDur: &Duration{24 * time.Hour}, DefaultTLSDur: &Duration{24 * time.Hour}, MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, EnableSSHCA: &defaultEnableSSHCA, DisableRenewal: &defaultDisableRenewal, AllowRenewalAfterExpiry: &defaultAllowRenewalAfterExpiry, DisableSmallstepExtensions: &defaultDisableSmallstepExtensions, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, Revoke: []string{"https://ca.smallstep.com/1.0/revoke", "https://ca.smallstep.com/revoke"}, SSHSign: []string{"https://ca.smallstep.com/1.0/ssh/sign"}, SSHRevoke: []string{"https://ca.smallstep.com/1.0/ssh/revoke"}, SSHRenew: []string{"https://ca.smallstep.com/1.0/ssh/renew"}, SSHRekey: []string{"https://ca.smallstep.com/1.0/ssh/rekey"}, } ) const awsTestCertificate = `-----BEGIN CERTIFICATE----- MIICFTCCAX6gAwIBAgIRAKmbVVYAl/1XEqRfF3eJ97MwDQYJKoZIhvcNAQELBQAw GDEWMBQGA1UEAxMNQVdTIFRlc3QgQ2VydDAeFw0xOTA0MjQyMjU3MzlaFw0yOTA0 MjEyMjU3MzlaMBgxFjAUBgNVBAMTDUFXUyBUZXN0IENlcnQwgZ8wDQYJKoZIhvcN AQEBBQADgY0AMIGJAoGBAOHMmMXwbXN90SoRl/xXAcJs5TacaVYJ5iNAVWM5KYyF +JwqYuJp/umLztFUi0oX0luu3EzD4KurVeUJSzZjTFTX1d/NX6hA45+bvdSUOcgV UghO+2uhBZ4SNFxFRZ7SKvoWIN195l5bVX6/60Eo6+kUCKCkyxW4V/ksWzdXjHnf 
AgMBAAGjXzBdMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0G A1UdDgQWBBRHfLOjEddK/CWCIHNg8Oc/oJa1IzAYBgNVHREEETAPgg1BV1MgVGVz dCBDZXJ0MA0GCSqGSIb3DQEBCwUAA4GBAKNCiVM9eGb9dW2xNyHaHAmmy7ERB2OJ 7oXHfLjooOavk9lU/Gs2jfX/JSBa84+DzWg9ShmCNLti8CxU/dhzXW7jE/5CcdTa DCA6B3Yl5TmfG9+D9dtFqRB2CiMgNcsJJE5Dc6pDwBIiSj/MkE0AaGVQmSwn6Cb6 vX1TAxqeWJHq -----END CERTIFICATE-----` const awsTestKey = `-----BEGIN RSA PRIVATE KEY----- MIICXAIBAAKBgQDhzJjF8G1zfdEqEZf8VwHCbOU2nGlWCeYjQFVjOSmMhficKmLi af7pi87RVItKF9JbrtxMw+Crq1XlCUs2Y0xU19XfzV+oQOOfm73UlDnIFVIITvtr oQWeEjRcRUWe0ir6FiDdfeZeW1V+v+tBKOvpFAigpMsVuFf5LFs3V4x53wIDAQAB AoGADZQFF9oWatyFCHeYYSdGRs/PlNIhD3h262XB/L6CPh4MTi/KVH01RAwROstP uPvnvXWtb7xTtV8PQj+l0zZzb4W/DLCSBdoRwpuNXyffUCtbI22jPupTsVu+ENWR 3x7HHzoZYjU45ADSTMxEtwD7/zyNgpRKjIA2HYpkt+fI27ECQQD5/AOr9/yQD73x cquF+FWahWgDL25YeMwdfe1HfpUxUxd9kJJKieB8E2BtBAv9XNguxIBpf7VlAKsF NFhdfWFHAkEA5zuX8vqDecSzyNNEQd3tugxt1pGOXNesHzuPbdlw3ppN9Rbd93an uU2TaAvTjr/3EkxulYNRmHs+RSVK54+uqQJAKWurhBQMAibJlzcj2ofiTz8pk9WJ GBmz4HMcHMuJlumoq8KHqtgbnRNs18Ni5TE8FMu0Z0ak3L52l98rgRokQwJBAJS8 9KTLF79AFBVeME3eH4jJbe3TeyulX4ZHnZ8fe0b1IqhAqU8A+CpuCB+pW9A7Ewam O4vZCKd4vzljH6eL+OECQHHxhYoTW7lFpKGnUDG9fPZ3eYzWpgka6w1vvBk10BAu 6fbwppM9pQ7DPMg7V6YGEjjT0gX9B9TttfHxGhvtZNQ= -----END RSA PRIVATE KEY-----` func must(args ...interface{}) []interface{} { if l := len(args); l > 0 && args[l-1] != nil { if err, ok := args[l-1].(error); ok { panic(err) } } return args } func generateJSONWebKey() (*jose.JSONWebKey, error) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) if err != nil { return nil, err } fp, err := jwk.Thumbprint(crypto.SHA256) if err != nil { return nil, err } jwk.KeyID = hex.EncodeToString(fp) return jwk, nil } func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) { var keySet jose.JSONWebKeySet for i := 0; i < n; i++ { key, err := generateJSONWebKey() if err != nil { return jose.JSONWebKeySet{}, err } keySet.Keys = append(keySet.Keys, *key) } return keySet, nil } func 
encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) { b, err := json.Marshal(jwk) if err != nil { return nil, err } salt, err := randutil.Salt(jose.PBKDF2SaltSize) if err != nil { return nil, err } opts := new(jose.EncrypterOptions) opts.WithContentType(jose.ContentType("jwk+json")) recipient := jose.Recipient{ Algorithm: jose.PBES2_HS256_A128KW, Key: []byte("password"), PBES2Count: jose.PBKDF2Iterations, PBES2Salt: salt, } encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) if err != nil { return nil, err } return encrypter.Encrypt(b) } func decryptJSONWebKey(key string) (*jose.JSONWebKey, error) { enc, err := jose.ParseEncrypted(key) if err != nil { return nil, err } b, err := enc.Decrypt([]byte("password")) if err != nil { return nil, err } jwk := new(jose.JSONWebKey) if err := json.Unmarshal(b, jwk); err != nil { return nil, err } return jwk, nil } func generateJWK() (*JWK, error) { name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } jwk, err := generateJSONWebKey() if err != nil { return nil, err } jwe, err := encryptJSONWebKey(jwk) if err != nil { return nil, err } public := jwk.Public() encrypted, err := jwe.CompactSerialize() if err != nil { return nil, err } p := &JWK{ Name: name, Type: "JWK", Key: &public, EncryptedKey: encrypted, Claims: &globalProvisionerClaims, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, }, nil) return p, err } func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { fooPubB, err := os.ReadFile("./testdata/certs/foo.pub") if err != nil { return nil, err } fooPub, err := pemutil.ParseKey(fooPubB) if err != nil { return nil, err } barPubB, err := os.ReadFile("./testdata/certs/bar.pub") if err != nil { return nil, err } barPub, err := pemutil.ParseKey(barPubB) if err != nil { return nil, err } pubKeys := []interface{}{fooPub, barPub} if inputPubKey != nil { pubKeys = append(pubKeys, inputPubKey) } p := &K8sSA{ Name: K8sSAName, Type: 
"K8sSA", Claims: &globalProvisionerClaims, pubKeys: pubKeys, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, }, nil) return p, err } func generateSSHPOP() (*SSHPOP, error) { name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") if err != nil { return nil, err } userKey, _, _, _, err := ssh.ParseAuthorizedKey(userB) if err != nil { return nil, err } hostB, err := os.ReadFile("./testdata/certs/ssh_host_ca_key.pub") if err != nil { return nil, err } hostKey, _, _, _, err := ssh.ParseAuthorizedKey(hostB) if err != nil { return nil, err } p := &SSHPOP{ Name: name, Type: "SSHPOP", Claims: &globalProvisionerClaims, sshPubKeys: &SSHKeys{ UserKeys: []ssh.PublicKey{userKey}, HostKeys: []ssh.PublicKey{hostKey}, }, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, }, nil) return p, err } func generateX5C(root []byte) (*X5C, error) { if root == nil { root = []byte(`-----BEGIN CERTIFICATE----- MIIBhTCCASqgAwIBAgIRAMalM7pKi0GCdKjO6u88OyowCgYIKoZIzj0EAwIwFDES MBAGA1UEAxMJcm9vdC10ZXN0MCAXDTE5MTAwMjAyMzk0OFoYDzIxMTkwOTA4MDIz OTQ4WjAUMRIwEAYDVQQDEwlyb290LXRlc3QwWTATBgcqhkjOPQIBBggqhkjOPQMB BwNCAAS29QTCXUu7cx9sa9wZPpRSFq/zXaw8Ai3EIygayrBsKnX42U2atBUjcBZO BWL6A+PpLzU9ja867U5SYNHERS+Oo1swWTAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0T AQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQUpHS7FfaQ5bCrTxUeu6R2ZC3VGOowFAYD VR0RBA0wC4IJcm9vdC10ZXN0MAoGCCqGSM49BAMCA0kAMEYCIQC2vgqwla0u8LHH 1MHob14qvS5o76HautbIBW7fcHzz5gIhAIx5A2+wkJYX4026kqaZCk/1sAwTxSGY M46l92gdOozT -----END CERTIFICATE-----`) } name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } rootPool := x509.NewCertPool() var ( block *pem.Block rest = root ) for rest != nil { block, rest = pem.Decode(rest) if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, errors.Wrap(err, "error parsing x509 certificate from PEM block") } rootPool.AddCert(cert) } p := &X5C{ Name: 
name, Type: "X5C", Roots: root, Claims: &globalProvisionerClaims, rootPool: rootPool, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, }, nil) return p, err } func generateOIDC() (*OIDC, error) { name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } clientID, err := randutil.Alphanumeric(10) if err != nil { return nil, err } issuer, err := randutil.Alphanumeric(10) if err != nil { return nil, err } jwk, err := generateJSONWebKey() if err != nil { return nil, err } p := &OIDC{ Name: name, Type: "OIDC", ClientID: clientID, ConfigurationEndpoint: "https://example.com/.well-known/openid-configuration", Claims: &globalProvisionerClaims, configuration: openIDConfiguration{ Issuer: issuer, JWKSetURI: "https://example.com/.well-known/jwks", }, keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, }, nil) return p, err } func generateGCP() (*GCP, error) { name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } serviceAccount, err := randutil.Alphanumeric(10) if err != nil { return nil, err } jwk, err := generateJSONWebKey() if err != nil { return nil, err } p := &GCP{ Type: "GCP", Name: name, ServiceAccounts: []string{serviceAccount}, Claims: &globalProvisionerClaims, DisableSSHCAHost: &DefaultDisableSSHCAHost, DisableSSHCAUser: &DefaultDisableSSHCAUser, config: newGCPConfig(), keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, projectValidator: &gcp.ProjectValidator{}, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences.WithFragment("gcp/" + name), }, nil) return p, err } func generateAWS() (*AWS, error) { name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } accountID, err := randutil.Alphanumeric(10) if err != nil { return nil, err } block, _ := 
pem.Decode([]byte(awsTestCertificate))
	if block == nil || block.Type != "CERTIFICATE" {
		return nil, errors.New("error decoding AWS certificate")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, errors.Wrap(err, "error parsing AWS certificate")
	}
	p := &AWS{
		Type:         "AWS",
		Name:         name,
		Accounts:     []string{accountID},
		Claims:       &globalProvisionerClaims,
		IMDSVersions: []string{"v2", "v1"},
		config: &awsConfig{
			identityURL:        awsIdentityURL,
			signatureURL:       awsSignatureURL,
			tokenURL:           awsAPITokenURL,
			tokenTTL:           awsAPITokenTTL,
			certificates:       []*x509.Certificate{cert},
			signatureAlgorithm: awsSignatureAlgorithm,
		},
	}
	p.ctl, err = NewController(p, p.Claims, Config{
		Audiences: testAudiences.WithFragment("aws/" + name),
	}, nil)
	return p, err
}

// generateAWSWithServer returns the AWS provisioner built by generateAWS and a
// test HTTP server that stands in for the instance metadata endpoints: the
// identity document, its signature (signed with awsTestKey), and the API token.
// The document and signature endpoints require the X-aws-ec2-metadata-token
// header to match the token served by the token endpoint and answer 401
// otherwise. The provisioner's identity, signature and token URLs are pointed
// at the server. The caller is responsible for closing the returned server.
func generateAWSWithServer() (*AWS, *httptest.Server, error) {
	aws, err := generateAWS()
	if err != nil {
		return nil, nil, err
	}
	block, _ := pem.Decode([]byte(awsTestKey))
	if block == nil || block.Type != "RSA PRIVATE KEY" {
		return nil, nil, errors.New("error decoding AWS key")
	}
	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return nil, nil, errors.Wrap(err, "error parsing AWS private key")
	}
	doc, err := json.MarshalIndent(awsInstanceIdentityDocument{
		AccountID:        aws.Accounts[0],
		Architecture:     "x86_64",
		AvailabilityZone: "us-west-2b",
		ImageID:          "image-id",
		InstanceID:       "instance-id",
		InstanceType:     "t2.micro",
		PendingTime:      time.Now(),
		PrivateIP:        "127.0.0.1",
		Region:           "us-west-1",
		Version:          "2017-09-30",
	}, "", " ")
	if err != nil {
		return nil, nil, err
	}

	sum := sha256.Sum256(doc)
	signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256)
	if err != nil {
		return nil, nil, errors.Wrap(err, "error signing document")
	}
	//nolint:gosec // hard-coded token is only served and checked by this test server
	token := "AQAEAEEO9-7Z88ewKFpboZuDlFYWz9A3AN-wMOVzjEhfAyXW31BvVw=="
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/latest/dynamic/instance-identity/document":
			// check for API token
			if r.Header.Get("X-aws-ec2-metadata-token") != token {
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte("401 Unauthorized"))
				// Stop here so the document is not appended to the 401 body.
				return
			}
			w.Write(doc)
		case "/latest/dynamic/instance-identity/signature":
			// check for API token
			if r.Header.Get("X-aws-ec2-metadata-token") != token {
				w.WriteHeader(http.StatusUnauthorized)
				w.Write([]byte("401 Unauthorized"))
				// Stop here so the signature is not appended to the 401 body.
				return
			}
			w.Write([]byte(base64.StdEncoding.EncodeToString(signature)))
		case "/latest/api/token":
			w.Write([]byte(token))
		case "/bad-document":
			w.Write([]byte("{}"))
		case "/bad-signature":
			w.Write([]byte("YmFkLXNpZ25hdHVyZQo="))
		case "/bad-json":
			w.Write([]byte("{"))
		default:
			http.NotFound(w, r)
		}
	}))
	aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document"
	aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature"
	aws.config.tokenURL = srv.URL + "/latest/api/token"
	return aws, srv, nil
}

// generateAWSV1Only builds an AWS provisioner with a random name and account
// ID, configured like the one from generateAWS except that IMDSVersions only
// lists "v1".
func generateAWSV1Only() (*AWS, error) {
	name, err := randutil.Alphanumeric(10)
	if err != nil {
		return nil, err
	}
	accountID, err := randutil.Alphanumeric(10)
	if err != nil {
		return nil, err
	}
	block, _ := pem.Decode([]byte(awsTestCertificate))
	if block == nil || block.Type != "CERTIFICATE" {
		return nil, errors.New("error decoding AWS certificate")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, errors.Wrap(err, "error parsing AWS certificate")
	}
	p := &AWS{
		Type:         "AWS",
		Name:         name,
		Accounts:     []string{accountID},
		Claims:       &globalProvisionerClaims,
		IMDSVersions: []string{"v1"},
		config: &awsConfig{
			identityURL:        awsIdentityURL,
			signatureURL:       awsSignatureURL,
			tokenURL:           awsAPITokenURL,
			tokenTTL:           awsAPITokenTTL,
			certificates:       []*x509.Certificate{cert},
			signatureAlgorithm: awsSignatureAlgorithm,
		},
	}
	p.ctl, err = NewController(p, p.Claims, Config{
		Audiences: testAudiences.WithFragment("aws/" + name),
	}, nil)
	return p, err
}

// generateAWSWithServerV1Only returns the provisioner from generateAWSV1Only
// and a test HTTP server that serves the identity document and its signature
// without any API token check.
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
	aws, err := generateAWSV1Only()
	if err != nil {
		return nil, nil, err
	}
	block, _ :=
pem.Decode([]byte(awsTestKey)) if block == nil || block.Type != "RSA PRIVATE KEY" { return nil, nil, errors.New("error decoding AWS key") } key, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, nil, errors.Wrap(err, "error parsing AWS private key") } doc, err := json.MarshalIndent(awsInstanceIdentityDocument{ AccountID: aws.Accounts[0], Architecture: "x86_64", AvailabilityZone: "us-west-2b", ImageID: "image-id", InstanceID: "instance-id", InstanceType: "t2.micro", PendingTime: time.Now(), PrivateIP: "127.0.0.1", Region: "us-west-1", Version: "2017-09-30", }, "", " ") if err != nil { return nil, nil, err } sum := sha256.Sum256(doc) signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256) if err != nil { return nil, nil, errors.Wrap(err, "error signing document") } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/latest/dynamic/instance-identity/document": w.Write(doc) case "/latest/dynamic/instance-identity/signature": w.Write([]byte(base64.StdEncoding.EncodeToString(signature))) case "/bad-document": w.Write([]byte("{}")) case "/bad-signature": w.Write([]byte("YmFkLXNpZ25hdHVyZQo=")) case "/bad-json": w.Write([]byte("{")) default: http.NotFound(w, r) } })) aws.config.identityURL = srv.URL + "/latest/dynamic/instance-identity/document" aws.config.signatureURL = srv.URL + "/latest/dynamic/instance-identity/signature" return aws, srv, nil } func generateAzure() (*Azure, error) { name, err := randutil.Alphanumeric(10) if err != nil { return nil, err } tenantID, err := randutil.Alphanumeric(10) if err != nil { return nil, err } jwk, err := generateJSONWebKey() if err != nil { return nil, err } p := &Azure{ Type: "Azure", Name: name, TenantID: tenantID, Audience: azureDefaultAudience, Claims: &globalProvisionerClaims, config: newAzureConfig(tenantID), oidcConfig: openIDConfiguration{ Issuer: "https://sts.windows.net/" + tenantID + "/", JWKSetURI: 
"https://login.microsoftonline.com/common/discovery/keys", }, keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, } p.ctl, err = NewController(p, p.Claims, Config{ Audiences: testAudiences, }, nil) return p, err } func generateAzureWithServer() (*Azure, *httptest.Server, error) { az, err := generateAzure() if err != nil { return nil, nil, err } writeJSON := func(w http.ResponseWriter, v interface{}) { b, err := json.Marshal(v) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(b) } getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet { var ret jose.JSONWebKeySet for _, k := range ks.Keys { ret.Keys = append(ret.Keys, k.Public()) } return ret } issuer := "https://sts.windows.net/" + az.TenantID + "/" srv := httptest.NewUnstartedServer(nil) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/error": http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) case "/" + az.TenantID + "/.well-known/openid-configuration": writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/jwks_uri"}) case "/openid-configuration-no-issuer": writeJSON(w, openIDConfiguration{Issuer: "", JWKSetURI: srv.URL + "/jwks_uri"}) case "/openid-configuration-fail-jwk": writeJSON(w, openIDConfiguration{Issuer: issuer, JWKSetURI: srv.URL + "/error"}) case "/random": keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(keySet)) case "/private": writeJSON(w, az.keyStore.keySet) case "/jwks_uri": w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(az.keyStore.keySet)) case "/metadata/identity/oauth2/token": tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", 
"resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0]) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } else { writeJSON(w, azureIdentityToken{ AccessToken: tok, }) } case "/metadata/instance/compute/azEnvironment": w.Header().Add("Content-Type", "text/plain") w.Write([]byte("AzurePublicCloud")) default: http.NotFound(w, r) } }) srv.Start() az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration" az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token" az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment" return az, srv, nil } func generateCollection(nJWK, nOIDC int) (*Collection, error) { col := NewCollection(testAudiences) for i := 0; i < nJWK; i++ { p, err := generateJWK() if err != nil { return nil, err } col.Store(p) } for i := 0; i < nOIDC; i++ { p, err := generateOIDC() if err != nil { return nil, err } col.Store(p) } return col, nil } func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk) } type tokOption func(*jose.SignerOptions) error func withX5CHdr(certs []*x509.Certificate) tokOption { return func(so *jose.SignerOptions) error { strs := make([]string, len(certs)) for i, cert := range certs { strs[i] = base64.StdEncoding.EncodeToString(cert.Raw) } so.WithHeader("x5c", strs) return nil } } func withSSHPOPFile(cert *ssh.Certificate) tokOption { return func(so *jose.SignerOptions) error { so.WithHeader("sshpop", base64.StdEncoding.EncodeToString(cert.Marshal())) return nil } } func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) for _, o := range tokOpts { if err := o(so); err != nil { return "", err } } sig, err := 
jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } id, err := randutil.ASCII(64) if err != nil { return "", err } claims := struct { jose.Claims Email string `json:"email"` SANS []string `json:"sans"` }{ Claims: jose.Claims{ ID: id, Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, }, Email: email, SANS: sans, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func generateCustomToken(sub, iss, aud string, jwk *jose.JSONWebKey, extraHeaders, extraClaims map[string]any) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) for k, v := range extraHeaders { so.WithHeader(jose.HeaderKey(k), v) } sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } id, err := randutil.ASCII(64) if err != nil { return "", err } iat := time.Now() claims := jose.Claims{ ID: id, Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, } return jose.Signed(sig).Claims(claims).Claims(extraClaims).CompactSerialize() } func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) for _, o := range tokOpts { if err := o(so); err != nil { return "", err } } sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } id, err := randutil.ASCII(64) if err != nil { return "", err } claims := struct { jose.Claims Email string `json:"email"` PreferredUsername string `json:"preferred_username,omitempty"` }{ Claims: jose.Claims{ ID: id, Subject: sub, Issuer: iss, 
IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, }, Email: email, PreferredUsername: preferredUsername, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func generateX5CSSHToken(jwk *jose.JSONWebKey, claims *x5cPayload, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", jwk.KeyID) for _, o := range tokOpts { if err := o(so); err != nil { return "", err } } sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } return jose.Signed(sig).Claims(claims).CompactSerialize() } func getK8sSAPayload() *k8sSAPayload { return &k8sSAPayload{ Claims: jose.Claims{ Issuer: k8sSAIssuer, Subject: "foo", }, Namespace: "ns-foo", SecretName: "sn-foo", ServiceAccountName: "san-foo", ServiceAccountUID: "sauid-foo", } } func generateK8sSAToken(jwk *jose.JSONWebKey, claims *k8sSAPayload, tokOpts ...tokOption) (string, error) { so := new(jose.SignerOptions) so.WithHeader("kid", jwk.KeyID) for _, o := range tokOpts { if err := o(so); err != nil { return "", err } } sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, so) if err != nil { return "", err } if claims == nil { claims = getK8sSAPayload() } return jose.Signed(sig).Claims(*claims).CompactSerialize() } func generateSimpleSSHUserToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, }, jwk) } func generateSimpleSSHHostToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { return generateSSHToken("subject@localhost", iss, aud, time.Now(), &SignSSHOptions{ CertType: "host", Principals: []string{"smallstep.com"}, }, jwk) } func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *SignSSHOptions, jwk *jose.JSONWebKey) 
(string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), ) if err != nil { return "", err } id, err := randutil.ASCII(64) if err != nil { return "", err } claims := struct { jose.Claims Step *stepPayload `json:"step,omitempty"` }{ Claims: jose.Claims{ ID: id, Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, }, Step: &stepPayload{ SSH: sshOpts, }, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func generateGCPToken(sub, iss, aud, instanceID, instanceName, projectID, zone string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), ) if err != nil { return "", err } aud, err = generateSignAudience("https://ca.smallstep.com", aud) if err != nil { return "", err } claims := gcpPayload{ Claims: jose.Claims{ Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, }, AuthorizedParty: sub, Email: "foo@developer.gserviceaccount.com", EmailVerified: true, Google: gcpGooglePayload{ ComputeEngine: gcpComputeEnginePayload{ InstanceID: instanceID, InstanceName: instanceName, InstanceCreationTimestamp: jose.NewNumericDate(iat), ProjectID: projectID, ProjectNumber: 1234567890, Zone: zone, }, }, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, region string, iat time.Time, key crypto.Signer) (string, error) { doc, err := json.MarshalIndent(awsInstanceIdentityDocument{ AccountID: accountID, Architecture: "x86_64", AvailabilityZone: "us-west-2b", ImageID: 
"ami-123123", InstanceID: instanceID, InstanceType: "t2.micro", PendingTime: iat, PrivateIP: privateIP, Region: region, Version: "2017-09-30", }, "", " ") if err != nil { return "", err } sum := sha256.Sum256(doc) signature, err := key.Sign(rand.Reader, sum[:], crypto.SHA256) if err != nil { return "", errors.Wrap(err, "error signing document") } sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.HS256, Key: signature}, new(jose.SignerOptions).WithType("JWT"), ) if err != nil { return "", err } aud, err = generateSignAudience("https://ca.smallstep.com", aud) if err != nil { return "", err } unique := fmt.Sprintf("%s.%s", p.GetID(), instanceID) sum = sha256.Sum256([]byte(unique)) claims := awsPayload{ Claims: jose.Claims{ ID: strings.ToLower(hex.EncodeToString(sum[:])), Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, }, Amazon: awsAmazonPayload{ Document: doc, Signature: signature, }, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), ) if err != nil { return "", err } var xmsMirID string switch resourceType { case "vm": xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName) case "uai": xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName) } claims := azurePayload{ Claims: jose.Claims{ Subject: sub, Issuer: iss, IssuedAt: jose.NewNumericDate(iat), NotBefore: jose.NewNumericDate(iat), Expiry: 
jose.NewNumericDate(iat.Add(5 * time.Minute)), Audience: []string{aud}, ID: "the-jti", }, AppID: "the-appid", AppIDAcr: "the-appidacr", IdentityProvider: "the-idp", ObjectID: "the-oid", TenantID: tenantID, Version: "the-version", XMSMirID: xmsMirID, } return jose.Signed(sig).Claims(claims).CompactSerialize() } func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { tok, err := jose.ParseSigned(token) if err != nil { return nil, nil, err } claims := new(jose.Claims) if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil { return nil, nil, err } return tok, claims, nil } func parseAWSToken(token string) (*jose.JSONWebToken, *awsPayload, error) { tok, err := jose.ParseSigned(token) if err != nil { return nil, nil, err } claims := new(awsPayload) if err := tok.UnsafeClaimsWithoutVerification(claims); err != nil { return nil, nil, err } var doc awsInstanceIdentityDocument if err := json.Unmarshal(claims.Amazon.Document, &doc); err != nil { return nil, nil, errors.Wrap(err, "error unmarshaling identity document") } claims.document = doc return tok, claims, nil } func generateJWKServerHandler(n int, srv *httptest.Server) http.Handler { hits := struct { Hits int `json:"hits"` }{} writeJSON := func(w http.ResponseWriter, v interface{}) { b, err := json.Marshal(v) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(b) } getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet { var ret jose.JSONWebKeySet for _, k := range ks.Keys { ret.Keys = append(ret.Keys, k.Public()) } return ret } defaultKeySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { hits.Hits++ switch r.RequestURI { case "/error": http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) case "/hits": writeJSON(w, hits) case "/.well-known/openid-configuration": writeJSON(w, 
openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"}) case "/common/.well-known/openid-configuration": writeJSON(w, openIDConfiguration{Issuer: "https://login.microsoftonline.com/{tenantid}/v2.0", JWKSetURI: srv.URL + "/jwks_uri"}) case "/random": keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(keySet)) case "/no-cache": keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") writeJSON(w, getPublic(keySet)) case "/private": writeJSON(w, defaultKeySet) default: w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(defaultKeySet)) } }) } func generateJWKServer(n int) *httptest.Server { srv := httptest.NewUnstartedServer(nil) srv.Config.Handler = generateJWKServerHandler(n, srv) srv.Start() return srv } func generateTLSJWKServer(n int) *httptest.Server { srv := httptest.NewUnstartedServer(nil) srv.Config.Handler = generateJWKServerHandler(n, srv) srv.StartTLS() return srv } func generateACME() (*ACME, error) { // Initialize provisioners p := &ACME{ Type: "ACME", Name: "test@acme-provisioner.com", } if err := p.Init(Config{Claims: globalProvisionerClaims}); err != nil { return nil, err } return p, nil } func parseCerts(b []byte) ([]*x509.Certificate, error) { var ( block *pem.Block rest = b certs = []*x509.Certificate{} ) for rest != nil { block, rest = pem.Decode(rest) if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, errors.Wrap(err, "error parsing x509 certificate from PEM block") } certs = append(certs, cert) } return certs, nil } ================================================ FILE: authority/provisioner/webhook.go ================================================ package provisioner import ( "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "fmt" "log" 
"net/http" "net/url" "strings" "text/template" "time" "github.com/pkg/errors" "github.com/smallstep/linkedca" "github.com/smallstep/certificates/authority/poolhttp" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" ) var ErrWebhookDenied = errors.New("webhook server did not allow request") type WebhookSetter interface { SetWebhook(string, any) } type WebhookController struct { client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook certType linkedca.Webhook_CertType options []webhook.RequestBodyOption TemplateData WebhookSetter } // Enrich fetches data from remote servers and adds returned data to the // templateData func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } // Apply extra options in the webhook controller for _, fn := range wc.options { if err := fn(req); err != nil { return err } } for _, wh := range wc.webhooks { if wh.Kind != linkedca.Webhook_ENRICHING.String() { continue } if !wc.isCertTypeOK(wh) { continue } whCtx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() //nolint:gocritic // every request canceled with its own timeout resp, err := wh.DoWithContext(whCtx, wc.client, wc.wrapTransport, req, wc.TemplateData) if err != nil { return err } if !resp.Allow { if resp.Error != nil { return resp.Error } return ErrWebhookDenied } wc.TemplateData.SetWebhook(wh.Name, resp.Data) } return nil } // Authorize checks that all remote servers allow the request func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } // Apply extra options in the webhook controller for _, fn := range wc.options { if err := fn(req); err != nil { return err } } for _, wh := range wc.webhooks { if wh.Kind != linkedca.Webhook_AUTHORIZING.String() { continue } if 
!wc.isCertTypeOK(wh) { continue } whCtx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() //nolint:gocritic // every request canceled with its own timeout resp, err := wh.DoWithContext(whCtx, wc.client, wc.wrapTransport, req, wc.TemplateData) if err != nil { return err } if !resp.Allow { if resp.Error != nil { return resp.Error } return ErrWebhookDenied } } return nil } func (wc *WebhookController) isCertTypeOK(wh *Webhook) bool { if wc.certType == linkedca.Webhook_ALL { return true } if wh.CertType == linkedca.Webhook_ALL.String() || wh.CertType == "" { return true } return wc.certType.String() == wh.CertType } type Webhook struct { ID string `json:"id"` Name string `json:"name"` URL string `json:"url"` Kind string `json:"kind"` DisableTLSClientAuth bool `json:"disableTLSClientAuth,omitempty"` CertType string `json:"certType"` Secret string `json:"-"` BearerToken string `json:"-"` BasicAuth struct { Username string Password string } `json:"-"` } // Validate validates a webhook, only name, url and kind are required. func (w *Webhook) Validate() error { if w == nil { return nil } // name if w.Name == "" { return errors.New("webhook name is required") } // url parsedURL, err := url.Parse(w.URL) if err != nil { return errors.New("webhook url is invalid") } if parsedURL.Host == "" { return errors.New("webhook url is invalid") } if parsedURL.Scheme != "https" { return errors.New("webhook url must use https") } if parsedURL.User != nil { return errors.New("webhook url may not contain username or password") } // kind if w.Kind == "" { return errors.New("webhook kind is required") } w.Kind = strings.ToUpper(w.Kind) kind, ok := linkedca.Webhook_Kind_value[w.Kind] if !ok || kind == 0 { return errors.New("webhook kind is invalid") } return nil } // TransportWrapper wraps the set of functions mapping [http.Transport] references to // [http.RoundTripper]. 
type TransportWrapper = httptransport.Wrapper func (w *Webhook) DoWithContext(ctx context.Context, client HTTPClient, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { return nil, err } buf := &bytes.Buffer{} if err := tmpl.Execute(buf, data); err != nil { return nil, err } webhookURL := buf.String() /* Sending the token to the webhook server is a security risk. A K8sSA token can be reused multiple times. The webhook can misuse it to get fake certificates. A webhook can misuse any other token to get its own certificate before responding. switch tmpl := data.(type) { case x509util.TemplateData: reqBody.Token = tmpl[x509util.TokenKey] case sshutil.TemplateData: reqBody.Token = tmpl[sshutil.TokenKey] } */ reqBody.Timestamp = time.Now() reqBytes, err := json.Marshal(reqBody) if err != nil { return nil, err } retries := 1 retry: req, err := http.NewRequestWithContext(ctx, "POST", webhookURL, bytes.NewReader(reqBytes)) if err != nil { return nil, err } if requestID, ok := requestid.FromContext(ctx); ok { req.Header.Set("X-Request-Id", requestID) } secret, err := base64.StdEncoding.DecodeString(w.Secret) if err != nil { return nil, err } h := hmac.New(sha256.New, secret) h.Write(reqBytes) sig := h.Sum(nil) req.Header.Set("X-Smallstep-Signature", hex.EncodeToString(sig)) req.Header.Set("X-Smallstep-Webhook-ID", w.ID) if w.BearerToken != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", w.BearerToken)) } else if w.BasicAuth.Username != "" || w.BasicAuth.Password != "" { req.SetBasicAuth(w.BasicAuth.Username, w.BasicAuth.Password) } if w.DisableTLSClientAuth { var transport *http.Transport if ct, ok := client.(poolhttp.Transporter); ok { transport = ct.Transport() } else { transport = httptransport.New() } if transport.TLSClientConfig != nil { transport.TLSClientConfig.GetClientCertificate = nil 
transport.TLSClientConfig.Certificates = nil } client = &http.Client{ Transport: tw(transport), } } resp, err := client.Do(req) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return nil, err } else if retries > 0 { retries-- time.Sleep(time.Second) goto retry } return nil, err } defer func() { if err := resp.Body.Close(); err != nil { log.Printf("Failed to close body of response from %s", w.URL) } }() if resp.StatusCode >= 500 && retries > 0 { retries-- time.Sleep(time.Second) goto retry } if resp.StatusCode >= 400 { return nil, fmt.Errorf("Webhook server responded with %d", resp.StatusCode) } respBody := &webhook.ResponseBody{} if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil { return nil, err } return respBody, nil } ================================================ FILE: authority/provisioner/webhook_test.go ================================================ package provisioner import ( "context" "crypto/hmac" "crypto/sha256" "crypto/tls" "crypto/x509" _ "embed" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/webhook" "github.com/smallstep/linkedca" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" ) func TestWebhookController_isCertTypeOK(t *testing.T) { type test struct { wc *WebhookController wh *Webhook want bool } tests := map[string]test{ "all/all": { wc: &WebhookController{certType: linkedca.Webhook_ALL}, wh: &Webhook{CertType: linkedca.Webhook_ALL.String()}, want: true, }, "all/x509": { wc: &WebhookController{certType: linkedca.Webhook_ALL}, wh: &Webhook{CertType: linkedca.Webhook_X509.String()}, want: true, }, "all/ssh": { wc: &WebhookController{certType: linkedca.Webhook_ALL}, wh: &Webhook{CertType: 
linkedca.Webhook_SSH.String()}, want: true, }, `all/""`: { wc: &WebhookController{certType: linkedca.Webhook_ALL}, wh: &Webhook{}, want: true, }, "x509/all": { wc: &WebhookController{certType: linkedca.Webhook_X509}, wh: &Webhook{CertType: linkedca.Webhook_ALL.String()}, want: true, }, "x509/x509": { wc: &WebhookController{certType: linkedca.Webhook_X509}, wh: &Webhook{CertType: linkedca.Webhook_X509.String()}, want: true, }, "x509/ssh": { wc: &WebhookController{certType: linkedca.Webhook_X509}, wh: &Webhook{CertType: linkedca.Webhook_SSH.String()}, want: false, }, `x509/""`: { wc: &WebhookController{certType: linkedca.Webhook_X509}, wh: &Webhook{}, want: true, }, "ssh/all": { wc: &WebhookController{certType: linkedca.Webhook_SSH}, wh: &Webhook{CertType: linkedca.Webhook_ALL.String()}, want: true, }, "ssh/x509": { wc: &WebhookController{certType: linkedca.Webhook_SSH}, wh: &Webhook{CertType: linkedca.Webhook_X509.String()}, want: false, }, "ssh/ssh": { wc: &WebhookController{certType: linkedca.Webhook_SSH}, wh: &Webhook{CertType: linkedca.Webhook_SSH.String()}, want: true, }, `ssh/""`: { wc: &WebhookController{certType: linkedca.Webhook_SSH}, wh: &Webhook{}, want: true, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } // withRequestID is a helper that calls into [requestid.NewContext] and returns // a new context with the requestID added. 
func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context { t.Helper() return requestid.NewContext(ctx, requestID) } func TestWebhookController_Enrich(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) require.NoError(t, err) type test struct { ctl *WebhookController ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool expectTemplateData any assertRequest func(t *testing.T, req *webhook.RequestBody) assertError func(t *testing.T, err error) } tests := map[string]test{ "ok/no enriching webhooks": { ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, TemplateData: nil, }, req: &webhook.RequestBody{}, responses: nil, expectErr: false, expectTemplateData: nil, }, "ok/one webhook": { ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, }, "ok/two webhooks": { ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{ {Name: "people", Kind: "ENRICHING"}, {Name: "devices", Kind: "ENRICHING"}, }, TemplateData: x509util.TemplateData{}, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"serial": "123"}}, }, expectErr: false, expectTemplateData: x509util.TemplateData{ "Webhooks": map[string]any{ "devices": map[string]any{"serial": "123"}, "people": map[string]any{"role": "bar"}, }, }, }, "ok/x509 only": { ctl: 
&WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{ {Name: "people", Kind: "ENRICHING", CertType: linkedca.Webhook_SSH.String()}, {Name: "devices", Kind: "ENRICHING"}, }, TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, {Allow: true, Data: map[string]any{"serial": "123"}}, }, expectErr: false, expectTemplateData: x509util.TemplateData{ "Webhooks": map[string]any{ "devices": map[string]any{"serial": "123"}, }, }, }, "ok/with options": { ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) require.NoError(t, err) assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), NotBefore: cert.NotBefore, NotAfter: cert.NotAfter, }, req.X5CCertificate) }, }, "deny": { ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, expectTemplateData: x509util.TemplateData{}, assertError: func(t *testing.T, err error) { assert.Equal(t, ErrWebhookDenied, err) }, }, "deny/with error": 
{ ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false, Error: &webhook.Error{ Code: "theCode", Message: "Some message", }}}, expectErr: true, expectTemplateData: x509util.TemplateData{}, assertError: func(t *testing.T, err error) { assert.Equal(t, &webhook.Error{ Code: "theCode", Message: "Some message", }, err) }, }, "fail/with options": { ctl: &WebhookController{ client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(&x509.Certificate{ PublicKey: []byte("bad"), })}, }, ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, expectTemplateData: x509util.TemplateData{}, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) err := json.NewEncoder(w).Encode(test.responses[j]) require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } err := test.ctl.Enrich(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } if test.assertError != nil { test.assertError(t, err) } }) } } func TestWebhookController_Authorize(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) require.NoError(t, err) type test struct { ctl 
*WebhookController
		ctx           context.Context
		req           *webhook.RequestBody
		responses     []*webhook.ResponseBody
		expectErr     bool
		assertRequest func(t *testing.T, req *webhook.RequestBody)
		assertError   func(t *testing.T, err error)
	}
	tests := map[string]test{
		"ok/no enriching webhooks": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
			},
			req:       &webhook.RequestBody{},
			responses: nil,
			expectErr: false,
		},
		"ok": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
			},
			ctx:       withRequestID(t, context.Background(), "reqID"),
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: true}},
			expectErr: false,
		},
		"ok/ssh only": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
				certType: linkedca.Webhook_SSH,
			},
			ctx:       withRequestID(t, context.Background(), "reqID"),
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: false}},
			expectErr: false,
		},
		"ok/with options": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
				options:  []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
			},
			ctx:       withRequestID(t, context.Background(), "reqID"),
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: true}},
			expectErr: false,
			assertRequest: func(t *testing.T, req *webhook.RequestBody) {
				key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
				require.NoError(t, err)
				assert.Equal(t, &webhook.X5CCertificate{
					Raw:                cert.Raw,
					PublicKey:          key,
					PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
					NotBefore:          cert.NotBefore,
					NotAfter:           cert.NotAfter,
				}, req.X5CCertificate)
			},
		},
		"deny": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
			},
			ctx:       withRequestID(t, context.Background(), "reqID"),
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: false}},
			expectErr: true,
			assertError: func(t *testing.T, err error) {
				assert.Equal(t, ErrWebhookDenied, err)
			},
		},
		"deny/withError": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
			},
			ctx: withRequestID(t, context.Background(), "reqID"),
			req: &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: false, Error: &webhook.Error{
				Code:    "theCode",
				Message: "Some message",
			}}},
			expectErr: true,
			assertError: func(t *testing.T, err error) {
				assert.Equal(t, &webhook.Error{
					Code:    "theCode",
					Message: "Some message",
				}, err)
			},
		},
		"fail/with options": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
				options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(&x509.Certificate{
					PublicKey: []byte("bad"),
				})},
			},
			ctx:       withRequestID(t, context.Background(), "reqID"),
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: false}},
			expectErr: true,
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			// One server per webhook; the i-th server answers with responses[i].
			for i, wh := range test.ctl.webhooks {
				var j = i
				ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "reqID", r.Header.Get("X-Request-ID"))
					// The handler runs on a server goroutine, where FailNow
					// (and therefore require.*) must not be called; use
					// assert so a failure is only recorded.
					assert.NoError(t, json.NewEncoder(w).Encode(test.responses[j]))
				}))
				// Close every server when the subtest finishes.
				t.Cleanup(ts.Close)
				wh.URL = ts.URL
			}

			err := test.ctl.Authorize(test.ctx, test.req)
			if (err != nil) != test.expectErr {
				t.Fatalf("Got err %v, want %v", err, test.expectErr)
			}
			if test.assertRequest != nil {
				test.assertRequest(t, test.req)
			}
			if test.assertError != nil {
				test.assertError(t, err)
			}
		})
	}
}

func TestWebhook_Do(t *testing.T) {
	csr := parseCertificateRequest(t, "testdata/certs/ecdsa.csr")
	type test struct {
		webhook         Webhook
		dataArg         any
		requestID       string
		webhookResponse webhook.ResponseBody
		expectPath      string
errStatusCode   int
		serverErrMsg    string
		expectErr       error
		// expectToken     any
	}
	tests := map[string]test{
		"ok": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			requestID: "reqID",
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/no-request-id": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/bearer": {
			webhook: Webhook{
				ID:          "abc123",
				Secret:      "c2VjcmV0Cg==",
				BearerToken: "mytoken",
			},
			requestID: "reqID",
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/basic": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
				BasicAuth: struct {
					Username string
					Password string
				}{
					Username: "myuser",
					Password: "mypass",
				},
			},
			requestID: "reqID",
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/templated-url": {
			webhook: Webhook{
				ID: "abc123",
				// scheme, host, port will come from test server
				URL:    "/users/{{ .username }}?region={{ .region }}",
				Secret: "c2VjcmV0Cg==",
			},
			requestID: "reqID",
			dataArg:   map[string]interface{}{"username": "areed", "region": "central"},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
			expectPath: "/users/areed?region=central",
		},
		/*
			"ok/token from ssh template": {
				webhook: Webhook{
					ID:     "abc123",
					Secret: "c2VjcmV0Cg==",
				},
				webhookResponse: webhook.ResponseBody{
					Data: map[string]interface{}{"role": "dba"},
				},
				dataArg:     sshutil.TemplateData{sshutil.TokenKey: "token"},
				expectToken: "token",
			},
			"ok/token from x509 template": {
				webhook: Webhook{
					ID:     "abc123",
					Secret: "c2VjcmV0Cg==",
				},
				webhookResponse: webhook.ResponseBody{
					Data: map[string]interface{}{"role": "dba"},
				},
				dataArg:     x509util.TemplateData{sshutil.TokenKey: "token"},
				expectToken: "token",
			},
		*/
		"ok/allow": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			requestID: "reqID",
			webhookResponse: webhook.ResponseBody{
				Allow: true,
			},
		},
		"fail/404": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
			requestID:     "reqID",
			errStatusCode: 404,
			serverErrMsg:  "item not found",
			expectErr:     errors.New("Webhook server responded with 404"),
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			// The handler validates the headers, HMAC signature and
			// authorization of the incoming webhook request. It runs on a
			// server goroutine, so it only uses assert (never require, which
			// would call FailNow outside the test goroutine).
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if tc.requestID != "" {
					assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID"))
				}

				assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID"))

				sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature"))
				assert.NoError(t, err)

				body, err := io.ReadAll(r.Body)
				assert.NoError(t, err)

				// Recompute the HMAC-SHA256 of the body with the webhook
				// secret and compare it with the signature header.
				secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret)
				assert.NoError(t, err)
				h := hmac.New(sha256.New, secret)
				h.Write(body)
				mac := h.Sum(nil)
				assert.True(t, hmac.Equal(sig, mac))

				switch {
				case tc.webhook.BearerToken != "":
					ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken)
					assert.Equal(t, ah, r.Header.Get("Authorization"))
				case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "":
					// Build the expected header on a scratch request and
					// compare it with the header actually received.
					whReq, err := http.NewRequest("", "", http.NoBody)
					if !assert.NoError(t, err) {
						return
					}
					whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password)
					ah := whReq.Header.Get("Authorization")
					assert.Equal(t, ah, r.Header.Get("Authorization"))
				default:
					assert.Equal(t, "", r.Header.Get("Authorization"))
				}

				if tc.expectPath != "" {
					assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery)
				}

				if tc.errStatusCode != 0 {
					http.Error(w, tc.serverErrMsg, tc.errStatusCode)
					return
				}

				reqBody := new(webhook.RequestBody)
				if !assert.NoError(t, json.Unmarshal(body, reqBody)) {
					return
				}

				assert.NoError(t, json.NewEncoder(w).Encode(tc.webhookResponse))
			}))
			defer ts.Close()

			tc.webhook.URL = ts.URL + tc.webhook.URL

			reqBody, err :=
webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) require.NoError(t, err) ctx := context.Background() if tc.requestID != "" { ctx = withRequestID(t, ctx, tc.requestID) } ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, httptransport.NoopWrapper(), reqBody, tc.dataArg) if tc.expectErr != nil { assert.Equal(t, tc.expectErr.Error(), err.Error()) return } assert.NoError(t, err) assert.Equal(t, &tc.webhookResponse, got) }) } t.Run("disableTLSClientAuth", func(t *testing.T) { ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("{}")) })) ts.TLS.ClientAuth = tls.RequireAnyClientCert wh := Webhook{ URL: ts.URL, } cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") require.NoError(t, err) transport := httptransport.New() transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, Certificates: []tls.Certificate{cert}, } client := &http.Client{ Transport: transport, } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() _, err = wh.DoWithContext(ctx, client, httptransport.NoopWrapper(), reqBody, nil) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) defer cancel() wh.DisableTLSClientAuth = true _, err = wh.DoWithContext(ctx, client, httptransport.NoopWrapper(), reqBody, nil) require.Error(t, err) }) } func TestWebhook_Validate(t *testing.T) { tests := []struct { name string webhook *Webhook assertion assert.ErrorAssertionFunc }{ {"ok enriching", &Webhook{Name: "devices", URL: "https://localhost:3000", Kind: "ENRICHING"}, assert.NoError}, {"ok authorizing", &Webhook{Name: "devices", URL: "https://localhost:3000/devices", Kind: "AUTHORIZING"}, assert.NoError}, {"fail name", &Webhook{Name: "", URL: 
"https://localhost:3000", Kind: "ENRICHING"}, assert.Error}, {"fail url", &Webhook{Name: "devices", URL: "", Kind: "ENRICHING"}, assert.Error}, {"fail bad url", &Webhook{Name: "devices", URL: "https://{{.Templated.Host}}", Kind: "ENRICHING"}, assert.Error}, {"fail host", &Webhook{Name: "devices", URL: "https:opaque", Kind: "ENRICHING"}, assert.Error}, {"fail scheme", &Webhook{Name: "devices", URL: "http://localhost", Kind: "ENRICHING"}, assert.Error}, {"fail user", &Webhook{Name: "devices", URL: "https://user:pass@localhost", Kind: "ENRICHING"}, assert.Error}, {"fail kind", &Webhook{Name: "devices", URL: "https://localhost:3000", Kind: ""}, assert.Error}, {"fail bad kind", &Webhook{Name: "devices", URL: "https://localhost:3000", Kind: "SOMETHING"}, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.assertion(t, tt.webhook.Validate()) }) } } ================================================ FILE: authority/provisioner/wire/dpop_options.go ================================================ package wire import ( "bytes" "crypto" "errors" "fmt" "text/template" "go.step.sm/crypto/pemutil" ) type DPOPOptions struct { // Public part of the signing key for DPoP access token in PEM format SigningKey []byte `json:"key"` // URI template for the URI the ACME client must call to fetch the DPoP challenge proof (an access token from wire-server) Target string `json:"target"` signingKey crypto.PublicKey target *template.Template } func (o *DPOPOptions) GetSigningKey() crypto.PublicKey { return o.signingKey } func (o *DPOPOptions) EvaluateTarget(deviceID string) (string, error) { if deviceID == "" { return "", errors.New("deviceID must not be empty") } buf := new(bytes.Buffer) if err := o.target.Execute(buf, struct{ DeviceID string }{DeviceID: deviceID}); err != nil { return "", fmt.Errorf("failed executing DPoP template: %w", err) } return buf.String(), nil } func (o *DPOPOptions) validateAndInitialize() (err error) { o.signingKey, err = 
pemutil.Parse(o.SigningKey) if err != nil { return fmt.Errorf("failed parsing key: %w", err) } o.target, err = template.New("DeviceID").Parse(o.Target) if err != nil { return fmt.Errorf("failed parsing DPoP template: %w", err) } return nil } ================================================ FILE: authority/provisioner/wire/dpop_options_test.go ================================================ package wire import ( "errors" "testing" "text/template" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDPOPOptions_EvaluateTarget(t *testing.T) { tu := "http://wire.com:15958/clients/{{.DeviceID}}/access-token" target, err := template.New("DeviceID").Parse(tu) require.NoError(t, err) fail := "https:/wire.com:15958/clients/{{.DeviceId}}/access-token" failTarget, err := template.New("DeviceID").Parse(fail) require.NoError(t, err) type fields struct { target *template.Template } type args struct { deviceID string } tests := []struct { name string fields fields args args want string expectedErr error }{ { name: "ok", fields: fields{target: target}, args: args{deviceID: "deviceID"}, want: "http://wire.com:15958/clients/deviceID/access-token", }, { name: "fail/empty", fields: fields{target: target}, args: args{deviceID: ""}, expectedErr: errors.New("deviceID must not be empty"), }, { name: "fail/template", fields: fields{target: failTarget}, args: args{deviceID: "bla"}, expectedErr: errors.New(`failed executing DPoP template: template: DeviceID:1:32: executing "DeviceID" at <.DeviceId>: can't evaluate field DeviceId in type struct { DeviceID string }`), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &DPOPOptions{ target: tt.fields.target, } got, err := o.EvaluateTarget(tt.args.deviceID) if tt.expectedErr != nil { assert.EqualError(t, err, tt.expectedErr.Error()) assert.Empty(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.want, got) }) } } ================================================ FILE: 
authority/provisioner/wire/oidc_options.go ================================================ package wire import ( "bytes" "context" "encoding/json" "errors" "fmt" "net/url" "text/template" "time" "github.com/coreos/go-oidc/v3/oidc" "go.step.sm/crypto/x509util" ) type Provider struct { DiscoveryBaseURL string `json:"discoveryBaseUrl,omitempty"` IssuerURL string `json:"issuerUrl,omitempty"` AuthURL string `json:"authorizationUrl,omitempty"` TokenURL string `json:"tokenUrl,omitempty"` JWKSURL string `json:"jwksUrl,omitempty"` UserInfoURL string `json:"userInfoUrl,omitempty"` Algorithms []string `json:"signatureAlgorithms,omitempty"` } type Config struct { ClientID string `json:"clientId,omitempty"` SignatureAlgorithms []string `json:"signatureAlgorithms,omitempty"` // the properties below are only used for testing SkipClientIDCheck bool `json:"-"` SkipExpiryCheck bool `json:"-"` SkipIssuerCheck bool `json:"-"` InsecureSkipSignatureCheck bool `json:"-"` Now func() time.Time `json:"-"` } type OIDCOptions struct { Provider *Provider `json:"provider,omitempty"` Config *Config `json:"config,omitempty"` TransformTemplate string `json:"transform,omitempty"` target *template.Template transform *template.Template oidcProviderConfig *oidc.ProviderConfig provider *oidc.Provider verifier *oidc.IDTokenVerifier } func (o *OIDCOptions) GetVerifier(ctx context.Context) (*oidc.IDTokenVerifier, error) { if o.verifier == nil { switch { case o.Provider.DiscoveryBaseURL != "": // creates a new OIDC provider using automatic discovery and the default HTTP client provider, err := oidc.NewProvider(ctx, o.Provider.DiscoveryBaseURL) if err != nil { return nil, fmt.Errorf("failed creating new OIDC provider using discovery: %w", err) } o.provider = provider default: o.provider = o.oidcProviderConfig.NewProvider(ctx) } if o.provider == nil { return nil, errors.New("no OIDC provider available") } o.verifier = o.provider.Verifier(o.getConfig()) } return o.verifier, nil } func (o *OIDCOptions) 
getConfig() *oidc.Config { if o == nil || o.Config == nil { return &oidc.Config{} } return &oidc.Config{ ClientID: o.Config.ClientID, SupportedSigningAlgs: o.Config.SignatureAlgorithms, SkipClientIDCheck: o.Config.SkipClientIDCheck, SkipExpiryCheck: o.Config.SkipExpiryCheck, SkipIssuerCheck: o.Config.SkipIssuerCheck, Now: o.Config.Now, InsecureSkipSignatureCheck: o.Config.InsecureSkipSignatureCheck, } } const defaultTemplate = `{"name": "{{ .name }}", "preferred_username": "{{ .preferred_username }}"}` func (o *OIDCOptions) validateAndInitialize() (err error) { if o.Provider == nil { return errors.New("provider not set") } if o.Provider.IssuerURL == "" && o.Provider.DiscoveryBaseURL == "" { return errors.New("either OIDC discovery or issuer URL must be set") } if o.Provider.DiscoveryBaseURL == "" { o.oidcProviderConfig, err = toOIDCProviderConfig(o.Provider) if err != nil { return fmt.Errorf("failed creationg OIDC provider config: %w", err) } } o.target, err = template.New("DeviceID").Parse(o.Provider.IssuerURL) if err != nil { return fmt.Errorf("failed parsing OIDC template: %w", err) } o.transform, err = parseTransform(o.TransformTemplate) if err != nil { return fmt.Errorf("failed parsing OIDC transformation template: %w", err) } return nil } func parseTransform(transformTemplate string) (*template.Template, error) { if transformTemplate == "" { transformTemplate = defaultTemplate } return template.New("transform").Funcs(x509util.GetFuncMap()).Parse(transformTemplate) } func (o *OIDCOptions) EvaluateTarget(deviceID string) (string, error) { buf := new(bytes.Buffer) if err := o.target.Execute(buf, struct{ DeviceID string }{DeviceID: deviceID}); err != nil { return "", fmt.Errorf("failed executing OIDC template: %w", err) } return buf.String(), nil } func (o *OIDCOptions) Transform(v map[string]any) (map[string]any, error) { if o.transform == nil || v == nil { return v, nil } // TODO(hs): add support for extracting error message from template "fail" function? 
buf := new(bytes.Buffer) if err := o.transform.Execute(buf, v); err != nil { return nil, fmt.Errorf("failed executing OIDC transformation: %w", err) } var r map[string]any if err := json.Unmarshal(buf.Bytes(), &r); err != nil { return nil, fmt.Errorf("failed unmarshaling transformed OIDC token: %w", err) } // add original claims if not yet in the transformed result for key, value := range v { if _, ok := r[key]; !ok { r[key] = value } } return r, nil } func toOIDCProviderConfig(in *Provider) (*oidc.ProviderConfig, error) { issuerURL, err := url.Parse(in.IssuerURL) if err != nil { return nil, fmt.Errorf("failed parsing issuer URL: %w", err) } // Removes query params from the URL because we use it as a way to notify client about the actual OAuth ClientId // for this provisioner. // This URL is going to look like: "https://idp:5556/dex?clientid=foo" // If we don't trim the query params here i.e. 'clientid' then the idToken verification is going to fail because // the 'iss' claim of the idToken will be "https://idp:5556/dex" issuerURL.RawQuery = "" issuerURL.Fragment = "" return &oidc.ProviderConfig{ IssuerURL: issuerURL.String(), AuthURL: in.AuthURL, TokenURL: in.TokenURL, UserInfoURL: in.UserInfoURL, JWKSURL: in.JWKSURL, Algorithms: in.Algorithms, }, nil } ================================================ FILE: authority/provisioner/wire/oidc_options_test.go ================================================ package wire import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "testing" "text/template" "github.com/coreos/go-oidc/v3/oidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" ) func TestOIDCOptions_Transform(t *testing.T) { defaultTransform, err := parseTransform(``) require.NoError(t, err) swapTransform, err := parseTransform(`{"name": "{{ .preferred_username }}", "preferred_username": "{{ .name }}"}`) require.NoError(t, err) funcTransform, err := parseTransform(`{"name": "{{ 
.name }}", "preferred_username": "{{ first .usernames }}"}`) require.NoError(t, err) type fields struct { transform *template.Template } type args struct { v map[string]any } tests := []struct { name string fields fields args args want map[string]any expectedErr error }{ { name: "ok/no-transform", fields: fields{ transform: nil, }, args: args{ v: map[string]any{ "name": "Example", "preferred_username": "Preferred", }, }, want: map[string]any{ "name": "Example", "preferred_username": "Preferred", }, }, { name: "ok/empty-data", fields: fields{ transform: nil, }, args: args{ v: map[string]any{}, }, want: map[string]any{}, }, { name: "ok/default-transform", fields: fields{ transform: defaultTransform, }, args: args{ v: map[string]any{ "name": "Example", "preferred_username": "Preferred", }, }, want: map[string]any{ "name": "Example", "preferred_username": "Preferred", }, }, { name: "ok/swap-transform", fields: fields{ transform: swapTransform, }, args: args{ v: map[string]any{ "name": "Example", "preferred_username": "Preferred", }, }, want: map[string]any{ "name": "Preferred", "preferred_username": "Example", }, }, { name: "ok/transform-with-functions", fields: fields{ transform: funcTransform, }, args: args{ v: map[string]any{ "name": "Example", "usernames": []string{"name-1", "name-2", "name-3"}, }, }, want: map[string]any{ "name": "Example", "preferred_username": "name-1", "usernames": []string{"name-1", "name-2", "name-3"}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &OIDCOptions{ transform: tt.fields.transform, } got, err := o.Transform(tt.args.v) if tt.expectedErr != nil { assert.Error(t, err) return } assert.Equal(t, tt.want, got) }) } } func TestOIDCOptions_EvaluateTarget(t *testing.T) { tu := "http://target.example.com/{{.DeviceID}}" target, err := template.New("DeviceID").Parse(tu) require.NoError(t, err) empty := "http://target.example.com" emptyTarget, err := template.New("DeviceID").Parse(empty) require.NoError(t, err) fail 
:= "https:/wire.com:15958/clients/{{.DeviceId}}/access-token" failTarget, err := template.New("DeviceID").Parse(fail) require.NoError(t, err) type fields struct { target *template.Template } type args struct { deviceID string } tests := []struct { name string fields fields args args want string expectedErr error }{ { name: "ok", fields: fields{target: target}, args: args{deviceID: "deviceID"}, want: "http://target.example.com/deviceID", }, { name: "ok/empty", fields: fields{target: emptyTarget}, args: args{deviceID: ""}, want: "http://target.example.com", }, { name: "fail/template", fields: fields{target: failTarget}, args: args{deviceID: "bla"}, expectedErr: errors.New(`failed executing OIDC template: template: DeviceID:1:32: executing "DeviceID" at <.DeviceId>: can't evaluate field DeviceId in type struct { DeviceID string }`), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &OIDCOptions{ target: tt.fields.target, } got, err := o.EvaluateTarget(tt.args.deviceID) if tt.expectedErr != nil { assert.EqualError(t, err, tt.expectedErr.Error()) assert.Empty(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func TestOIDCOptions_GetVerifier(t *testing.T) { signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) require.NoError(t, err) srv := mustDiscoveryServer(t, signerJWK.Public()) defer srv.Close() type fields struct { Provider *Provider Config *Config TransformTemplate string } tests := []struct { name string fields fields ctx context.Context want *oidc.IDTokenVerifier wantErr bool }{ { name: "fail/invalid-discovery-url", fields: fields{ Provider: &Provider{ DiscoveryBaseURL: "http://invalid.example.com", }, Config: &Config{ ClientID: "client-id", }, TransformTemplate: "http://target.example.com/{{.DeviceID}}", }, ctx: context.Background(), wantErr: true, }, { name: "ok/auto", fields: fields{ Provider: &Provider{ DiscoveryBaseURL: srv.URL, }, Config: &Config{ ClientID: 
"client-id", }, TransformTemplate: "http://target.example.com/{{.DeviceID}}", }, ctx: context.Background(), }, { name: "ok/fixed", fields: fields{ Provider: &Provider{ IssuerURL: "http://issuer.example.com", }, Config: &Config{ ClientID: "client-id", }, TransformTemplate: "http://target.example.com/{{.DeviceID}}", }, ctx: context.Background(), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &OIDCOptions{ Provider: tt.fields.Provider, Config: tt.fields.Config, TransformTemplate: tt.fields.TransformTemplate, } err := o.validateAndInitialize() require.NoError(t, err) verifier, err := o.GetVerifier(tt.ctx) if tt.wantErr { assert.Error(t, err) assert.Nil(t, verifier) return } assert.NoError(t, err) assert.NotNil(t, verifier) if assert.NotNil(t, o.provider) { assert.NotNil(t, o.provider.Endpoint()) } }) } } func mustDiscoveryServer(t *testing.T, pub jose.JSONWebKey) *httptest.Server { t.Helper() mux := http.NewServeMux() server := httptest.NewServer(mux) b, err := json.Marshal(struct { Keys []jose.JSONWebKey `json:"keys,omitempty"` }{ Keys: []jose.JSONWebKey{pub}, }) require.NoError(t, err) jwks := string(b) wellKnown := fmt.Sprintf(`{ "issuer": "%[1]s", "authorization_endpoint": "%[1]s/auth", "token_endpoint": "%[1]s/token", "jwks_uri": "%[1]s/keys", "userinfo_endpoint": "%[1]s/userinfo", "id_token_signing_alg_values_supported": ["ES256"] }`, server.URL) mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) { _, err := io.WriteString(w, wellKnown) if err != nil { w.WriteHeader(500) } }) mux.HandleFunc("/keys", func(w http.ResponseWriter, req *http.Request) { _, err := io.WriteString(w, jwks) if err != nil { w.WriteHeader(500) } }) t.Cleanup(server.Close) return server } ================================================ FILE: authority/provisioner/wire/wire_options.go ================================================ package wire import ( "errors" "fmt" ) // Options holds the Wire ACME extension options 
type Options struct { OIDC *OIDCOptions `json:"oidc,omitempty"` DPOP *DPOPOptions `json:"dpop,omitempty"` } // GetOIDCOptions returns the OIDC options. func (o *Options) GetOIDCOptions() *OIDCOptions { if o == nil { return nil } return o.OIDC } // GetDPOPOptions returns the DPoP options. func (o *Options) GetDPOPOptions() *DPOPOptions { if o == nil { return nil } return o.DPOP } // Validate validates and initializes the Wire OIDC and DPoP options. // // TODO(hs): find a good way to perform this only once. func (o *Options) Validate() error { if oidc := o.GetOIDCOptions(); oidc != nil { if err := oidc.validateAndInitialize(); err != nil { return fmt.Errorf("failed initializing OIDC options: %w", err) } } else { return errors.New("no OIDC options available") } if dpop := o.GetDPOPOptions(); dpop != nil { if err := dpop.validateAndInitialize(); err != nil { return fmt.Errorf("failed initializing DPoP options: %w", err) } } else { return errors.New("no DPoP options available") } return nil } ================================================ FILE: authority/provisioner/wire/wire_options_test.go ================================================ package wire import ( "errors" "testing" "github.com/stretchr/testify/assert" ) func TestOptions_Validate(t *testing.T) { key := []byte(`-----BEGIN PUBLIC KEY----- MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= -----END PUBLIC KEY-----`) type fields struct { OIDC *OIDCOptions DPOP *DPOPOptions } tests := []struct { name string fields fields expectedErr error }{ { name: "ok", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "https://example.com", }, Config: &Config{}, }, DPOP: &DPOPOptions{ SigningKey: key, }, }, expectedErr: nil, }, { name: "fail/no-oidc-options", fields: fields{ OIDC: nil, DPOP: &DPOPOptions{}, }, expectedErr: errors.New("no OIDC options available"), }, { name: "fail/empty-issuer-url", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "", }, Config: &Config{}, }, 
DPOP: &DPOPOptions{}, }, expectedErr: errors.New("failed initializing OIDC options: either OIDC discovery or issuer URL must be set"), }, { name: "fail/invalid-issuer-url", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "\x00", }, Config: &Config{}, }, DPOP: &DPOPOptions{}, }, expectedErr: errors.New(`failed initializing OIDC options: failed creationg OIDC provider config: failed parsing issuer URL: parse "\x00": net/url: invalid control character in URL`), }, { name: "fail/issuer-url-template", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "https://issuer.example.com/{{}", }, Config: &Config{}, }, DPOP: &DPOPOptions{}, }, expectedErr: errors.New(`failed initializing OIDC options: failed parsing OIDC template: template: DeviceID:1: unexpected "}" in command`), }, { name: "fail/invalid-transform-template", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "https://example.com", }, Config: &Config{}, TransformTemplate: "{{}", }, DPOP: &DPOPOptions{ SigningKey: key, }, }, expectedErr: errors.New(`failed initializing OIDC options: failed parsing OIDC transformation template: template: transform:1: unexpected "}" in command`), }, { name: "fail/no-dpop-options", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "https://example.com", }, Config: &Config{}, }, DPOP: nil, }, expectedErr: errors.New("no DPoP options available"), }, { name: "fail/invalid-key", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "https://example.com", }, Config: &Config{}, }, DPOP: &DPOPOptions{ SigningKey: []byte{0x00}, Target: "", }, }, expectedErr: errors.New(`failed initializing DPoP options: failed parsing key: error decoding PEM: not a valid PEM encoded block`), }, { name: "fail/target-template", fields: fields{ OIDC: &OIDCOptions{ Provider: &Provider{ IssuerURL: "https://example.com", }, Config: &Config{}, }, DPOP: &DPOPOptions{ SigningKey: key, Target: "{{}", }, }, expectedErr: 
// x5cPayload extends jose.Claims with the step-specific claims carried by an
// x5c token.
type x5cPayload struct {
	jose.Claims
	// SANs is read from the "sans" claim.
	SANs []string `json:"sans,omitempty"`
	// Step is read from the "step" claim; the authorize methods look at its
	// RA and SSH members.
	Step *stepPayload `json:"step,omitempty"`
	// Confirmation is read from the "cnf" claim; AuthorizeSign uses its
	// Fingerprint to validate the certificate request.
	Confirmation *cnfPayload `json:"cnf,omitempty"`
	// chains is not part of the token. authorizeToken stores the verified
	// certificate chains of the token's x5c header here.
	chains [][]*x509.Certificate
}

// X5C is a provisioner that authorizes requests using tokens signed with the
// key of an X.509 certificate. The certificate chain is sent in the token and
// must verify against the provisioner's configured roots.
type X5C struct {
	*base
	// ID, when set, overrides the identifier derived from the name. It is
	// never (un)marshaled.
	ID   string `json:"-"`
	Type string `json:"type"`
	Name string `json:"name"`
	// Roots holds the PEM encoded root certificates; Init parses them into
	// rootPool.
	Roots   []byte   `json:"roots"`
	Claims  *Claims  `json:"claims,omitempty"`
	Options *Options `json:"options,omitempty"`
	// ctl and rootPool are populated by Init.
	ctl      *Controller
	rootPool *x509.CertPool
}

// GetID returns the provisioner unique identifier. An explicitly set ID takes
// precedence; otherwise the identifier used to load the provisioner from a
// token ("x5c/" followed by the name) is returned.
func (p *X5C) GetID() string {
	if p.ID != "" {
		return p.ID
	}
	return p.GetIDForToken()
}

// GetIDForToken returns an identifier that will be used to load the provisioner
// from a token. It is the provisioner name prefixed with "x5c/".
func (p *X5C) GetIDForToken() string {
	return "x5c/" + p.Name
}
func (p *X5C) GetTokenID(ott string) (string, error) { // Validate payload token, err := jose.ParseSigned(ott) if err != nil { return "", errors.Wrap(err, "error parsing token") } // Get claims w/out verification. We need to look up the provisioner // key in order to verify the claims and we need the issuer from the claims // before we can look up the provisioner. var claims jose.Claims if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { return "", errors.Wrap(err, "error verifying claims") } return claims.ID, nil } // GetName returns the name of the provisioner. func (p *X5C) GetName() string { return p.Name } // GetType returns the type of provisioner. func (p *X5C) GetType() Type { return TypeX5C } // GetEncryptedKey returns the base provisioner encrypted key if it's defined. func (p *X5C) GetEncryptedKey() (string, string, bool) { return "", "", false } // Init initializes and validates the fields of a X5C type. func (p *X5C) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") case p.Name == "": return errors.New("provisioner name cannot be empty") case len(p.Roots) == 0: return errors.New("provisioner root(s) cannot be empty") } p.rootPool = x509.NewCertPool() var ( block *pem.Block rest = p.Roots count int ) for rest != nil { block, rest = pem.Decode(rest) if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return errors.Wrap(err, "error parsing x509 certificate from PEM block") } count++ p.rootPool.AddCert(cert) } // Verify that at least one root was found. if count == 0 { return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.ctl, err = NewController(p, p.Claims, config, p.Options) return } // authorizeToken performs common jwt authorization actions and returns the // claims for case specific downstream parsing. // e.g. 
a Sign request will auth/validate different fields than a Revoke request. func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c token") } verifiedChains, err := jwt.Headers[0].Certificates(x509.VerifyOptions{ Roots: p.rootPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, }) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error verifying x5c certificate chain in token") } leaf := verifiedChains[0][0] if leaf.KeyUsage&x509.KeyUsageDigitalSignature == 0 { return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature") } // Using the leaf certificate's key to validate the claims accomplishes two // things: // 1. Asserts that the private key used to sign the token corresponds // to the public certificate in the `x5c` header of the token. // 2. Asserts that the claims are valid - have not been tampered with. var claims x5cPayload if err = jwt.Claims(leaf.PublicKey, &claims); err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, "x5c.authorizeToken; error parsing x5c claims") } // According to "rfc7519 JSON Web Token" acceptable skew should be no // more than a few minutes. 
if err = claims.ValidateWithLeeway(jose.Expected{ Issuer: p.Name, Time: time.Now().UTC(), }, time.Minute); err != nil { return nil, errs.Wrapf(http.StatusUnauthorized, err, "x5c.authorizeToken; invalid x5c claims") } // validate audiences with the defaults if !matchesAudience(claims.Audience, audiences) { return nil, errs.Unauthorized("x5c.authorizeToken; x5c token has invalid audience "+ "claim (aud); expected %s, but got %s", audiences, claims.Audience) } if claims.Subject == "" { return nil, errs.Unauthorized("x5c.authorizeToken; x5c token subject cannot be empty") } // Save the verified chains on the x5c payload object. claims.chains = verifiedChains return &claims, nil } // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *X5C) AuthorizeRevoke(_ context.Context, token string) error { _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") } // NOTE: This is for backwards compatibility with older versions of cli // and certificates. Older versions added the token subject as the only SAN // in a CSR by default. if len(claims.SANs) == 0 { claims.SANs = []string{claims.Subject} } // Certificate templates data := x509util.CreateTemplateData(claims.Subject, claims.SANs) if v, err := unsafeParseSigned(token); err == nil { data.SetToken(v) } // The X509 certificate will be available using the template variable // AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be // used to get all the domains. 
x5cLeaf := claims.chains[0][0] data.SetAuthorizationCertificate(x5cLeaf) templateOptions, err := TemplateOptions(p.Options, data) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } // Wrap provisioner if the token is an RA token. var self Interface = p if claims.Step != nil && claims.Step.RA != nil { self = &raProvisioner{ Interface: p, raInfo: claims.Step.RA, } } // Check the fingerprint of the certificate request if given. var fingerprint string if claims.Confirmation != nil { fingerprint = claims.Confirmation.Fingerprint } return []SignOption{ self, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeX5C, p.Name, "").WithControllerOptions(p.ctl), profileLimitDuration{ p.ctl.Claimer.DefaultTLSCertDuration(), x5cLeaf.NotBefore, x5cLeaf.NotAfter, }, // validators csrFingerprintValidator(fingerprint), commonNameValidator(claims.Subject), newDefaultSANsValidator(ctx, claims.SANs), defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController( data, linkedca.Webhook_X509, webhook.WithX5CCertificate(x5cLeaf), webhook.WithAuthorizationPrincipal(x5cLeaf.Subject.CommonName), ), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. 
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
//
// It fails with an unauthorized error if the SSH CA is disabled for this
// provisioner or if the token carries no SSH options in its step claim. The
// token itself is validated with authorizeToken against the SSH sign
// audiences.
//
// The returned options are ordered: validators of the user supplied options
// first, then the template options, then the validity modifiers taken from the
// token, and finally the provisioner itself followed by the default validators
// and the webhook controller.
func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) {
	if !p.ctl.Claimer.IsSSHCAEnabled() {
		return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
	}

	claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
	}

	if claims.Step == nil || claims.Step.SSH == nil {
		return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token")
	}

	opts := claims.Step.SSH
	signOptions := []SignOption{
		// validates user's SSHOptions with the ones in the token
		sshCertOptionsValidator(*opts),
		// validate user's KeyID is the token subject.
		sshCertOptionsValidator(SignSSHOptions{KeyID: claims.Subject}),
	}

	// Default template attributes: a user certificate whose key id and only
	// principal are the token subject.
	certType := sshutil.UserCert
	keyID := claims.Subject
	principals := []string{claims.Subject}

	// Use options in the token; each one overrides the matching default only
	// when it is set.
	if opts.CertType != "" {
		if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
			return nil, errs.BadRequestErr(err, "%s", err.Error())
		}
	}
	if opts.KeyID != "" {
		keyID = opts.KeyID
	}
	if len(opts.Principals) > 0 {
		principals = opts.Principals
	}

	// Certificate templates.
	data := sshutil.CreateTemplateData(certType, keyID, principals)
	// Adding the token to the template data is best effort; a parse error is
	// deliberately ignored here.
	if v, err := unsafeParseSigned(token); err == nil {
		data.SetToken(v)
	}

	// The X509 certificate will be available using the template variable
	// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
	// used to get all the domains.
	x5cLeaf := claims.chains[0][0]
	data.SetAuthorizationCertificate(x5cLeaf)

	templateOptions, err := TemplateSSHOptions(p.Options, data)
	if err != nil {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
	}
	signOptions = append(signOptions, templateOptions)

	// Add modifiers from custom claims. Both bounds are resolved against the
	// same instant t.
	t := now()
	if !opts.ValidAfter.IsZero() {
		signOptions = append(signOptions, sshCertValidAfterModifier(cast.Uint64(opts.ValidAfter.RelativeTime(t).Unix())))
	}
	if !opts.ValidBefore.IsZero() {
		signOptions = append(signOptions, sshCertValidBeforeModifier(cast.Uint64(opts.ValidBefore.RelativeTime(t).Unix())))
	}

	return append(signOptions,
		p,
		// Checks the validity bounds, and set the validity if has not been set.
		// The limit passed in is the expiration of the x5c leaf certificate.
		&sshLimitDuration{p.ctl.Claimer, x5cLeaf.NotAfter},
		// Validate public key.
		&sshDefaultPublicKeyValidator{},
		// Validate the validity period.
		&sshCertValidityValidator{p.ctl.Claimer},
		// Require all the fields in the SSH certificate
		&sshCertDefaultValidator{},
		// Ensure that all principal names are allowed
		newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
		// Call webhooks
		p.ctl.newWebhookController(
			data,
			linkedca.Webhook_SSH,
			webhook.WithX5CCertificate(x5cLeaf),
			webhook.WithAuthorizationPrincipal(x5cLeaf.Subject.CommonName),
		),
	), nil
}
*testing.T) { p, err := generateX5C(nil) require.NoError(t, err) id := "x5c/" + p.Name if got := p.GetID(); got != id { t.Errorf("X5C.GetID() = %v, want %v:%v", got, p.Name, id) } if got := p.GetName(); got != p.Name { t.Errorf("X5C.GetName() = %v, want %v", got, p.Name) } if got := p.GetType(); got != TypeX5C { t.Errorf("X5C.GetType() = %v, want %v", got, TypeX5C) } kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("X5C.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } } func TestX5C_Init(t *testing.T) { type ProvisionerValidateTest struct { p *X5C err error extraValid func(*X5C) error } tests := map[string]func(*testing.T) ProvisionerValidateTest{ "fail/empty": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{}, err: errors.New("provisioner type cannot be empty"), } }, "fail/empty-name": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{ Type: "X5C", }, err: errors.New("provisioner name cannot be empty"), } }, "fail/empty-type": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{Name: "foo"}, err: errors.New("provisioner type cannot be empty"), } }, "fail/empty-key": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{Name: "foo", Type: "bar"}, err: errors.New("provisioner root(s) cannot be empty"), } }, "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { p, err := generateX5C(nil) require.NoError(t, err) p.Claims = &Claims{DefaultTLSDur: &Duration{0}} return ProvisionerValidateTest{ p: p, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) 
ProvisionerValidateTest { p, err := generateX5C(nil) require.NoError(t, err) return ProvisionerValidateTest{ p: p, } }, "ok/root-chain": func(t *testing.T) ProvisionerValidateTest { p, err := generateX5C([]byte(`-----BEGIN CERTIFICATE----- MIIBtjCCAVygAwIBAgIQNr+f4IkABY2n4wx4sLOMrTAKBggqhkjOPQQDAjAUMRIw EAYDVQQDEwlyb290LXRlc3QwIBcNMTkxMDAyMDI0MDM0WhgPMjExOTA5MDgwMjQw MzJaMBwxGjAYBgNVBAMTEWludGVybWVkaWF0ZS10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEflfRhPjgJXv4zsPWahXjM2UU61aRFErN0iw88ZPyxea22fxl qN9ezntTXxzsS+mZiWapl8B40ACJgvP+WLQBHKOBhTCBgjAOBgNVHQ8BAf8EBAMC AQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUnJAxiZcy2ibHcuvfFx99 oDwzKXMwHwYDVR0jBBgwFoAUpHS7FfaQ5bCrTxUeu6R2ZC3VGOowHAYDVR0RBBUw E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC lgsqsR63is+0YQ== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBhTCCASqgAwIBAgIRAMalM7pKi0GCdKjO6u88OyowCgYIKoZIzj0EAwIwFDES MBAGA1UEAxMJcm9vdC10ZXN0MCAXDTE5MTAwMjAyMzk0OFoYDzIxMTkwOTA4MDIz OTQ4WjAUMRIwEAYDVQQDEwlyb290LXRlc3QwWTATBgcqhkjOPQIBBggqhkjOPQMB BwNCAAS29QTCXUu7cx9sa9wZPpRSFq/zXaw8Ai3EIygayrBsKnX42U2atBUjcBZO BWL6A+PpLzU9ja867U5SYNHERS+Oo1swWTAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0T AQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQUpHS7FfaQ5bCrTxUeu6R2ZC3VGOowFAYD VR0RBA0wC4IJcm9vdC10ZXN0MAoGCCqGSM49BAMCA0kAMEYCIQC2vgqwla0u8LHH 1MHob14qvS5o76HautbIBW7fcHzz5gIhAIx5A2+wkJYX4026kqaZCk/1sAwTxSGY M46l92gdOozT -----END CERTIFICATE-----`)) require.NoError(t, err) return ProvisionerValidateTest{ p: p, extraValid: func(p *X5C) error { //nolint:staticcheck // We don't have a different way to // check the number of certificates in the pool. 
numCerts := len(p.rootPool.Subjects()) if numCerts != 2 { return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts) } return nil }, } }, } config := Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, } for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) err := tc.p.Init(config) if err != nil { if assert.NotNil(t, tc.err) { assert.EqualError(t, tc.err, err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equal(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) if tc.extraValid != nil { assert.Nil(t, tc.extraValid(tc.p)) } } } }) } } func TestX5C_authorizeToken(t *testing.T) { x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") require.NoError(t, err) x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") require.NoError(t, err) type test struct { p *X5C token string code int err error } tests := map[string]func(*testing.T) test{ "fail/bad-token": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; error parsing x5c token"), } }, "fail/invalid-cert-chain": func(t *testing.T) test { certs, err := parseCerts([]byte(`-----BEGIN CERTIFICATE----- MIIBpTCCAUugAwIBAgIRAOn2LHXjYyTXQ7PNjDTSKiIwCgYIKoZIzj0EAwIwHDEa MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwOTE0MDk1NTM2WhcNMjkw OTExMDk1NTM2WjAkMSIwIAYDVQQDExlTbWFsbHN0ZXAgSW50ZXJtZWRpYXRlIENB MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE2Cs0TY0dLM4b2s+z8+cc3JJp/W5H zQRvICX/1aJ4MuObNLcvoSguJwJEkYpGB5fhb0KvoL+ebHfEOywGNwrWkaNmMGQw DgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNLJ 4ZXoX9cI6YkGPxgs2US3ssVzMB8GA1UdIwQYMBaAFGIwpqz85wL29aF47Vj9XSVM P9K7MAoGCCqGSM49BAMCA0gAMEUCIQC5c1ldDcesDb31GlO5cEJvOcRrIrNtkk8m a5wpg+9s6QIgHIW6L60F8klQX+EO3o0SBqLeNcaskA4oSZsKjEdpSGo= -----END CERTIFICATE-----`)) require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) 
require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"), } }, "fail/doubled-up-self-signed-cert": func(t *testing.T) test { certs, err := parseCerts([]byte(`-----BEGIN CERTIFICATE----- MIIBgjCCASigAwIBAgIQIZiE9wpmSj6SMMDfHD17qjAKBggqhkjOPQQDAjAQMQ4w DAYDVQQDEwVsZWFmMjAgFw0xOTEwMDIwMzEzNTlaGA8yMTE5MDkwODAzMTM1OVow EDEOMAwGA1UEAxMFbGVhZjIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATuajJI 3YgDaj+jorioJzGJc2+V1hUM7XzN9tIHoUeItgny9GW08TrTc23h1cCZteNZvayG M0wGpGeXOnE4IlH9o2IwYDAOBgNVHQ8BAf8EBAMCBSAwHQYDVR0lBBYwFAYIKwYB BQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBT99+JChTh3LWOHaqlSwNiwND18/zAQ BgNVHREECTAHggVsZWFmMjAKBggqhkjOPQQDAgNIADBFAiB7gMRy3t81HpcnoRAS ELZmDFaEnoLCsVfbmanFykazQQIhAI0sZjoE9t6gvzQp7XQp6CoxzCc3Jv3FwZ8G EXAHTA9L -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBgjCCASigAwIBAgIQIZiE9wpmSj6SMMDfHD17qjAKBggqhkjOPQQDAjAQMQ4w DAYDVQQDEwVsZWFmMjAgFw0xOTEwMDIwMzEzNTlaGA8yMTE5MDkwODAzMTM1OVow EDEOMAwGA1UEAxMFbGVhZjIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATuajJI 3YgDaj+jorioJzGJc2+V1hUM7XzN9tIHoUeItgny9GW08TrTc23h1cCZteNZvayG M0wGpGeXOnE4IlH9o2IwYDAOBgNVHQ8BAf8EBAMCBSAwHQYDVR0lBBYwFAYIKwYB BQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBT99+JChTh3LWOHaqlSwNiwND18/zAQ BgNVHREECTAHggVsZWFmMjAKBggqhkjOPQQDAgNIADBFAiB7gMRy3t81HpcnoRAS ELZmDFaEnoLCsVfbmanFykazQQIhAI0sZjoE9t6gvzQp7XQp6CoxzCc3Jv3FwZ8G EXAHTA9L -----END CERTIFICATE-----`)) require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: 
tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; error verifying x5c certificate chain in token"), } }, "fail/digital-signature-ext-required": func(t *testing.T) test { certs, err := parseCerts([]byte(`-----BEGIN CERTIFICATE----- MIIBuTCCAV+gAwIBAgIQeRJLdDMIdn/T2ORKxYABezAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFpbnRlcm1lZGlhdGUtdGVzdDAgFw0xOTEwMDIwMjQxMTRaGA8yMTE5 MDkwODAyNDExMlowFDESMBAGA1UEAxMJbGVhZi10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEDA1nGTOujobkcBWklyvymhWE5gQlvNLarVzhhhvPDw+MK2LX yqkXrYZM10GrwQZuQ7ykHnjz00U/KXpPRQ7+0qOBiDCBhTAOBgNVHQ8BAf8EBAMC BSAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBQYv0AK 3GUOvC+m8ZTfyhn7tKQOazAfBgNVHSMEGDAWgBSckDGJlzLaJsdy698XH32gPDMp czAUBgNVHREEDTALgglsZWFmLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIhAPmertx0 lchRU3kAu647exvlhEr1xosPOu6P8kVYbtTEAiAA51w9EYIT/Zb26M3eQV817T2g Dnhl0ElPQsA92pkqbA== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBtjCCAVygAwIBAgIQNr+f4IkABY2n4wx4sLOMrTAKBggqhkjOPQQDAjAUMRIw EAYDVQQDEwlyb290LXRlc3QwIBcNMTkxMDAyMDI0MDM0WhgPMjExOTA5MDgwMjQw MzJaMBwxGjAYBgNVBAMTEWludGVybWVkaWF0ZS10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEflfRhPjgJXv4zsPWahXjM2UU61aRFErN0iw88ZPyxea22fxl qN9ezntTXxzsS+mZiWapl8B40ACJgvP+WLQBHKOBhTCBgjAOBgNVHQ8BAf8EBAMC AQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUnJAxiZcy2ibHcuvfFx99 oDwzKXMwHwYDVR0jBBgwFoAUpHS7FfaQ5bCrTxUeu6R2ZC3VGOowHAYDVR0RBBUw E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC lgsqsR63is+0YQ== -----END CERTIFICATE-----`)) require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", p.Name, testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; certificate used to sign x5c 
token cannot be used for digital signature"), } }, "fail/signature-does-not-match-x5c-pub-key": func(t *testing.T) test { certs, err := parseCerts([]byte(`-----BEGIN CERTIFICATE----- MIIBuDCCAV+gAwIBAgIQFdu723gqgGaTaqjf6ny88zAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFpbnRlcm1lZGlhdGUtdGVzdDAgFw0xOTEwMDIwMzE4NTNaGA8yMTE5 MDkwODAzMTg1MVowFDESMBAGA1UEAxMJbGVhZi10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEaV6807GhWEtMxA39zjuMVHAiN2/Ri5B1R1s+Y/8mlrKIvuvr VpgSPXYruNRFduPWX564Abz/TDmb276JbKGeQqOBiDCBhTAOBgNVHQ8BAf8EBAMC BaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBReMkPW f4MNWdg7KN4xI4ZLJd0IJDAfBgNVHSMEGDAWgBSckDGJlzLaJsdy698XH32gPDMp czAUBgNVHREEDTALgglsZWFmLXRlc3QwCgYIKoZIzj0EAwIDRwAwRAIgKYLKXpTN wtvZZaIvDzq1p8MO/SZ8yI42Ot69dNk/QtkCIBSvg5PozYcfbvwkgX5SwsjfYu0Z AvUgkUQ2G25NBRmX -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBtjCCAVygAwIBAgIQNr+f4IkABY2n4wx4sLOMrTAKBggqhkjOPQQDAjAUMRIw EAYDVQQDEwlyb290LXRlc3QwIBcNMTkxMDAyMDI0MDM0WhgPMjExOTA5MDgwMjQw MzJaMBwxGjAYBgNVBAMTEWludGVybWVkaWF0ZS10ZXN0MFkwEwYHKoZIzj0CAQYI KoZIzj0DAQcDQgAEflfRhPjgJXv4zsPWahXjM2UU61aRFErN0iw88ZPyxea22fxl qN9ezntTXxzsS+mZiWapl8B40ACJgvP+WLQBHKOBhTCBgjAOBgNVHQ8BAf8EBAMC AQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUnJAxiZcy2ibHcuvfFx99 oDwzKXMwHwYDVR0jBBgwFoAUpHS7FfaQ5bCrTxUeu6R2ZC3VGOowHAYDVR0RBBUw E4IRaW50ZXJtZWRpYXRlLXRlc3QwCgYIKoZIzj0EAwIDSAAwRQIgII8XpQ8ezDO1 2xdq3hShf155C5X/5jO8qr0VyEJgzlkCIQCTqph1Gwu/dmuf6dYLCfQqJyb371LC lgsqsR63is+0YQ== -----END CERTIFICATE-----`)) require.NoError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; error parsing x5c claims"), } }, "fail/invalid-issuer": func(t *testing.T) test { p, err := 
generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", "foobar", testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; invalid x5c claims"), } }, "fail/invalid-audience": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", p.GetName(), "foobar", "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; x5c token has invalid audience claim (aud)"), } }, "fail/empty-subject": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.authorizeToken; x5c token subject cannot be empty"), } }, "ok": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, token: tok, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equal(t, tc.code, sc.StatusCode()) } assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.NoError(t, tc.err) { assert.NotNil(t, claims) assert.NotNil(t, claims.chains) } } }) } } func TestX5C_AuthorizeSign(t *testing.T) { 
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") require.NoError(t, err) jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") require.NoError(t, err) type test struct { p *X5C token string code int err error sans []string fingerprint string } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: errors.New("x5c.AuthorizeSign: x5c.authorizeToken; error parsing x5c token"), } }, "ok/empty-sans": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, sans: []string{"foo"}, } }, "ok/multi-sans": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("foo", p.GetName(), testAudiences.Sign[0], "", []string{"127.0.0.1", "foo", "max@smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, } }, "ok/cnf": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) x5c := make([]string, len(certs)) for i, cert := range certs { x5c[i] = base64.StdEncoding.EncodeToString(cert.Raw) } extraHeaders := map[string]any{"x5c": x5c} extraClaims := map[string]any{ "sans": []string{"127.0.0.1", "foo", "max@smallstep.com"}, "cnf": map[string]any{"x5rt#S256": "fingerprint"}, } tok, err := generateCustomToken("foo", p.GetName(), testAudiences.Sign[0], jwk, extraHeaders, extraClaims) require.NoError(t, err) return test{ p: p, token: tok, sans: []string{"127.0.0.1", "foo", "max@smallstep.com"}, fingerprint: "fingerprint", } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) ctx := 
NewContextWithMethod(context.Background(), SignIdentityMethod) if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil { if assert.NotNil(t, tc.err, err.Error()) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equal(t, tc.code, sc.StatusCode()) } assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { assert.Len(t, opts, 11) for _, o := range opts { switch v := o.(type) { case *X5C: case certificateOptionsFunc: case *provisionerExtensionOption: assert.Equal(t, TypeX5C, v.Type) assert.Equal(t, tc.p.GetName(), v.Name) assert.Equal(t, "", v.CredentialID) assert.Len(t, v.KeyValuePairs, 0) case profileLimitDuration: assert.Equal(t, tc.p.ctl.Claimer.DefaultTLSCertDuration(), v.def) claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) require.NoError(t, err) assert.Equal(t, claims.chains[0][0].NotAfter, v.notAfter) case commonNameValidator: assert.Equal(t, "foo", string(v)) case defaultPublicKeyValidator: case *defaultSANsValidator: assert.Equal(t, tc.sans, v.sans) assert.Equal(t, SignIdentityMethod, MethodFromContext(v.ctx)) case *validityValidator: assert.Equal(t, tc.p.ctl.Claimer.MinTLSCertDuration(), v.min) assert.Equal(t, tc.p.ctl.Claimer.MaxTLSCertDuration(), v.max) case *x509NamePolicyValidator: assert.Equal(t, nil, v.policyEngine) case *WebhookController: assert.Len(t, v.webhooks, 0) assert.Equal(t, linkedca.Webhook_X509, v.certType) assert.Len(t, v.options, 2) case csrFingerprintValidator: assert.Equal(t, tc.fingerprint, string(v)) default: require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } } } }) } } func TestX5C_AuthorizeRevoke(t *testing.T) { type test struct { p *X5C token string code int err error } tests := map[string]func(*testing.T) test{ "fail/invalid-token": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) return test{ p: p, token: 
"foo", code: http.StatusUnauthorized, err: errors.New("x5c.AuthorizeRevoke: x5c.authorizeToken; error parsing x5c token"), } }, "ok": func(t *testing.T) test { certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") require.NoError(t, err) serialNumber := certs[0].SerialNumber.String() jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken(serialNumber, p.GetName(), testAudiences.Revoke[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, } }, "ok/different-serial-number": func(t *testing.T) test { certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") require.NoError(t, err) jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) tok, err := generateToken("123456789", p.GetName(), testAudiences.Revoke[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) return test{ p: p, token: tok, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equal(t, tc.code, sc.StatusCode()) } assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) } }) } } func TestX5C_AuthorizeRenew(t *testing.T) { now := time.Now().Truncate(time.Second) type test struct { p *X5C code int err error } tests := map[string]func(*testing.T) test{ "fail/renew-disabled": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} p.ctl.Claimer, err = NewClaimer(p.Claims, 
globalProvisionerClaims) require.NoError(t, err) return test{ p: p, code: http.StatusUnauthorized, err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) return test{ p: p, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{ NotBefore: now, NotAfter: now.Add(time.Hour), }); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equal(t, tc.code, sc.StatusCode()) } assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { assert.Nil(t, tc.err) } }) } } func TestX5C_AuthorizeSSHSign(t *testing.T) { x5cCerts, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") require.NoError(t, err) x5cJWK, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") require.NoError(t, err) _, fn := mockNow() defer fn() type test struct { p *X5C token string claims *x5cPayload fingerprint string count int code int err error } tests := map[string]func(*testing.T) test{ "fail/sshCA-disabled": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) // disable sshCA enable := false p.Claims = &Claims{EnableSSHCA: &enable} p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) require.NoError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, err: errors.New("x5c.AuthorizeSSHSign: x5c.authorizeToken; error parsing x5c token"), } }, "fail/no-Step-claim": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) tok, err := 
generateToken("foo", p.GetName(), testAudiences.SSHSign[0], "", []string{"test.smallstep.com"}, time.Now(), x5cJWK, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"), } }, "fail/no-SSH-subattribute-in-claims": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) id, err := randutil.ASCII(64) require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ ID: id, Subject: "foo", Issuer: p.GetName(), IssuedAt: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Audience: []string{testAudiences.SSHSign[0]}, }, Step: &stepPayload{}, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, token: tok, code: http.StatusUnauthorized, err: errors.New("x5c.AuthorizeSSHSign; x5c token must be an SSH provisioning token"), } }, "ok/with-claims": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) id, err := randutil.ASCII(64) require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ ID: id, Subject: "foo", Issuer: p.GetName(), IssuedAt: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Audience: []string{testAudiences.SSHSign[0]}, }, Step: &stepPayload{SSH: &SignSSHOptions{ CertType: SSHUserCert, KeyID: "foo", Principals: []string{"max", "mariano", "alan"}, ValidAfter: TimeDuration{d: 5 * time.Minute}, ValidBefore: TimeDuration{d: 10 * time.Minute}, }}, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, claims: claims, token: tok, count: 12, } }, "ok/without-claims": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) id, err := randutil.ASCII(64) 
require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ ID: id, Subject: "foo", Issuer: p.GetName(), IssuedAt: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Audience: []string{testAudiences.SSHSign[0]}, }, Step: &stepPayload{SSH: &SignSSHOptions{}}, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, claims: claims, token: tok, count: 10, } }, "ok/cnf": func(t *testing.T) test { p, err := generateX5C(nil) require.NoError(t, err) id, err := randutil.ASCII(64) require.NoError(t, err) now := time.Now() claims := &x5cPayload{ Claims: jose.Claims{ ID: id, Subject: "foo", Issuer: p.GetName(), IssuedAt: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Audience: []string{testAudiences.SSHSign[0]}, }, Step: &stepPayload{SSH: &SignSSHOptions{ CertType: SSHHostCert, Principals: []string{"host.smallstep.com"}, }}, Confirmation: &cnfPayload{ Fingerprint: "fingerprint", }, } tok, err := generateX5CSSHToken(x5cJWK, claims, withX5CHdr(x5cCerts)) require.NoError(t, err) return test{ p: p, claims: claims, token: tok, fingerprint: "fingerprint", count: 10, } }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equal(t, tc.code, sc.StatusCode()) } assertHasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { tot := 0 firstValidator := true nw := now() for _, o := range opts { switch v := o.(type) { case Interface: case sshCertOptionsValidator: tc.claims.Step.SSH.ValidAfter.t = time.Time{} tc.claims.Step.SSH.ValidBefore.t = time.Time{} 
if firstValidator { assert.Equal(t, *tc.claims.Step.SSH, SignSSHOptions(v)) } else { assert.Equal(t, SignSSHOptions{KeyID: tc.claims.Subject}, SignSSHOptions(v)) } firstValidator = false case sshCertValidAfterModifier: assert.Equal(t, tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix(), int64(v)) case sshCertValidBeforeModifier: assert.Equal(t, tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix(), int64(v)) case *sshLimitDuration: assert.Equal(t, tc.p.ctl.Claimer, v.Claimer) assert.Equal(t, x5cCerts[0].NotAfter, v.NotAfter) case *sshCertValidityValidator: assert.Equal(t, tc.p.ctl.Claimer, v.Claimer) case *sshNamePolicyValidator: assert.Nil(t, v.userPolicyEngine) assert.Nil(t, v.hostPolicyEngine) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: case *WebhookController: assert.Len(t, v.webhooks, 0) assert.Equal(t, linkedca.Webhook_SSH, v.certType) assert.Len(t, v.options, 2) default: require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } assert.Equal(t, tc.count, tot) } } } }) } } ================================================ FILE: authority/provisioners.go ================================================ package authority import ( "bytes" "context" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "os" "github.com/pkg/errors" "github.com/smallstep/cli-utils/step" "github.com/smallstep/cli-utils/ui" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" ) type raProvisioner interface { RAInfo() *provisioner.RAInfo } type attProvisioner interface { AttestationData() *provisioner.AttestationData } // wrapProvisioner wraps the given provisioner with RA 
information and // attestation data. func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationData) *wrappedProvisioner { var raInfo *provisioner.RAInfo if rap, ok := p.(raProvisioner); ok { raInfo = rap.RAInfo() } return &wrappedProvisioner{ Interface: p, attestationData: attData, raInfo: raInfo, } } // wrapRAProvisioner wraps the given provisioner with RA information. func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner { return &wrappedProvisioner{ Interface: p, raInfo: raInfo, } } // isRAProvisioner returns if the given provisioner is an RA provisioner. func isRAProvisioner(p provisioner.Interface) bool { if rap, ok := p.(raProvisioner); ok { return rap.RAInfo() != nil } return false } // wrappedProvisioner implements raProvisioner and attProvisioner. type wrappedProvisioner struct { provisioner.Interface attestationData *provisioner.AttestationData raInfo *provisioner.RAInfo } func (p *wrappedProvisioner) AttestationData() *provisioner.AttestationData { return p.attestationData } func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo { return p.raInfo } // GetEncryptedKey returns the JWE key corresponding to the given kid argument. func (a *Authority) GetEncryptedKey(kid string) (string, error) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() key, ok := a.provisioners.LoadEncryptedKey(kid) if !ok { return "", errs.NotFound("encrypted key with kid %s was not found", kid) } return key, nil } // GetProvisioners returns a map listing each provisioner and the JWK Key Set // with their public keys. func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, string, error) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() provisioners, nextCursor := a.provisioners.Find(cursor, limit) return provisioners, nextCursor, nil } // LoadProvisionerByCertificate returns an interface to the provisioner that // provisioned the certificate. 
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { a.adminMutex.RLock() defer a.adminMutex.RUnlock() if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil { return p, nil } return a.unsafeLoadProvisionerFromExtension(crt) } func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) { p, ok := a.provisioners.LoadByCertificate(crt) if !ok || p.GetType() == 0 { return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") } return p, nil } func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) { // certificateDataGetter is an interface that can be used to retrieve the // provisioner from a db or a linked ca. type certificateDataGetter interface { GetCertificateData(string) (*db.CertificateData, error) } var err error var data *db.CertificateData if cdg, ok := a.adminDB.(certificateDataGetter); ok { data, err = cdg.GetCertificateData(crt.SerialNumber.String()) } else if cdg, ok := a.db.(certificateDataGetter); ok { data, err = cdg.GetCertificateData(crt.SerialNumber.String()) } if err == nil && data != nil && data.Provisioner != nil { if p, ok := a.provisioners.Load(data.Provisioner.ID); ok { if data.RaInfo != nil { return wrapRAProvisioner(p, data.RaInfo), nil } return p, nil } } return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") } // LoadProvisionerByToken returns an interface to the provisioner that // provisioned the token. 
func (a *Authority) LoadProvisionerByToken(token *jose.JSONWebToken, claims *jose.Claims) (provisioner.Interface, error) {
	a.adminMutex.RLock()
	defer a.adminMutex.RUnlock()
	p, ok := a.provisioners.LoadByToken(token, claims)
	if !ok {
		return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from token")
	}
	return p, nil
}

// LoadProvisionerByID returns an interface to the provisioner with the given
// ID, or a not-found admin error when the collection has no such provisioner.
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
	a.adminMutex.RLock()
	defer a.adminMutex.RUnlock()
	p, ok := a.provisioners.Load(id)
	if !ok {
		return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id)
	}
	return p, nil
}

// LoadProvisionerByName returns an interface to the provisioner with the given
// Name, or a not-found admin error when the collection has no such provisioner.
func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
	a.adminMutex.RLock()
	defer a.adminMutex.RUnlock()
	p, ok := a.provisioners.LoadByName(name)
	if !ok {
		return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", name)
	}
	return p, nil
}

// generateProvisionerConfig builds the provisioner.Config handed to a
// provisioner's Init: the authority claims merged with the global defaults,
// the configured audiences, the SSH root keys, and the authority's callbacks
// and clients.
func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.Config, error) {
	// Merge global and configuration claims
	claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims)
	if err != nil {
		return provisioner.Config{}, err
	}
	// TODO: should we also be combining the ssh federated roots here?
	// If we rotate ssh roots keys, sshpop provisioner will lose ability to
	// validate old SSH certificates, unless they are added as federated certs.
	sshKeys, err := a.GetSSHRoots(ctx)
	if err != nil {
		return provisioner.Config{}, err
	}
	return provisioner.Config{
		Claims:    claimer.Claims(),
		Audiences: a.config.GetAudiences(),
		SSHKeys: &provisioner.SSHKeys{
			UserKeys: sshKeys.UserKeys,
			HostKeys: sshKeys.HostKeys,
		},
		GetIdentityFunc:       a.getIdentityFunc,
		AuthorizeRenewFunc:    a.authorizeRenewFunc,
		AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
		WebhookClient:         a.webhookClient,
		HTTPClient:            a.httpClient,
		WrapTransport:         a.wrapTransport,
		SCEPKeyManager:        a.scepKeyManager,
	}, nil
}

// StoreProvisioner stores a provisioner to the authority.
//
// While holding the admin write lock it rejects a provisioner whose name or
// token ID is already present, checks the provisioner policy, validates the
// configuration by initializing a converted copy, persists the provisioner in
// the admin database and finally adds a freshly converted, initialized copy to
// the in-memory collection.
func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
	a.adminMutex.Lock()
	defer a.adminMutex.Unlock()
	certProv, err := ProvisionerToCertificates(prov)
	if err != nil {
		return admin.WrapErrorISE(err, "error converting to certificates provisioner from linkedca provisioner")
	}
	// Names and token IDs must be unique within the collection.
	if _, ok := a.provisioners.LoadByName(prov.GetName()); ok {
		return admin.NewError(admin.ErrorBadRequestType, "provisioner with name %s already exists", prov.GetName())
	}
	if _, ok := a.provisioners.LoadByTokenID(certProv.GetIDForToken()); ok {
		return admin.NewError(admin.ErrorBadRequestType, "provisioner with token ID %s already exists", certProv.GetIDForToken())
	}
	provisionerConfig, err := a.generateProvisionerConfig(ctx)
	if err != nil {
		return admin.WrapErrorISE(err, "error generating provisioner config")
	}
	if err := a.checkProvisionerPolicy(ctx, prov.Name, prov.Policy); err != nil {
		return err
	}
	// This first Init only validates the configuration; a failure here is
	// reported as a bad request and nothing has been persisted yet.
	if err := certProv.Init(provisionerConfig); err != nil {
		return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %q", prov.Name)
	}
	// Store to database -- this will set the ID.
	if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil {
		return admin.WrapErrorISE(err, "error creating provisioner")
	}
	// We need a new conversion that has the newly set ID.
	certProv, err = ProvisionerToCertificates(prov)
	if err != nil {
		return admin.WrapErrorISE(err, "error converting to certificates provisioner from linkedca provisioner")
	}
	if err := certProv.Init(provisionerConfig); err != nil {
		return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name)
	}
	if err := a.provisioners.Store(certProv); err != nil {
		// The database already holds the provisioner, so reload the admin
		// resources before reporting the cache failure.
		if err := a.ReloadAdminResources(ctx); err != nil {
			return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner store")
		}
		return admin.WrapErrorISE(err, "error storing provisioner in authority cache")
	}
	return nil
}

// UpdateProvisioner stores a provisioner.Interface to the authority.
//
// While holding the admin write lock it converts and initializes the updated
// provisioner, replaces the entry in the in-memory collection and then
// persists the change in the admin database.
func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error {
	a.adminMutex.Lock()
	defer a.adminMutex.Unlock()
	certProv, err := ProvisionerToCertificates(nu)
	if err != nil {
		return admin.WrapErrorISE(err, "error converting to certificates provisioner from linkedca provisioner")
	}
	provisionerConfig, err := a.generateProvisionerConfig(ctx)
	if err != nil {
		return admin.WrapErrorISE(err, "error generating provisioner config")
	}
	if err := a.checkProvisionerPolicy(ctx, nu.Name, nu.Policy); err != nil {
		return err
	}
	if err := certProv.Init(provisionerConfig); err != nil {
		return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name)
	}
	if err := a.provisioners.Update(certProv); err != nil {
		return admin.WrapErrorISE(err, "error updating provisioner '%s' in authority cache", nu.Name)
	}
	if err := a.adminDB.UpdateProvisioner(ctx, nu); err != nil {
		// The cache was updated before the database write failed, so reload
		// the admin resources before reporting the error.
		if err := a.ReloadAdminResources(ctx); err != nil {
			return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner update")
		}
		return admin.WrapErrorISE(err, "error updating provisioner '%s'", nu.Name)
	}
	return nil
}

// RemoveProvisioner removes a provisioner.Interface from the authority.
func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error { a.adminMutex.Lock() defer a.adminMutex.Unlock() p, ok := a.provisioners.Load(id) if !ok { return admin.NewError(admin.ErrorBadRequestType, "provisioner %s not found", id) } provName, provID := p.GetName(), p.GetID() if a.IsAdminAPIEnabled() { // Validate // - Check that there will be SUPER_ADMINs that remain after we // remove this provisioner. if a.IsAdminAPIEnabled() && a.admins.SuperCount() == a.admins.SuperCountByProvisioner(provName) { return admin.NewError(admin.ErrorBadRequestType, "cannot remove provisioner %s because no super admins will remain", provName) } // Delete all admins associated with the provisioner. admins, ok := a.admins.LoadByProvisioner(provName) if ok { for _, adm := range admins { if err := a.removeAdmin(ctx, adm.Id); err != nil { return admin.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, provName) } } } } // Remove provisioner from authority caches. if err := a.provisioners.Remove(provID); err != nil { return admin.WrapErrorISE(err, "error removing provisioner from authority cache") } // Remove provisioner from database. if err := a.adminDB.DeleteProvisioner(ctx, provID); err != nil { if err := a.ReloadAdminResources(ctx); err != nil { return admin.WrapErrorISE(err, "error reloading admin resources on failed provisioner remove") } return admin.WrapErrorISE(err, "error deleting provisioner %s", provName) } return nil } // CreateFirstProvisioner creates and stores the first provisioner when using // admin database provisioner storage. 
func CreateFirstProvisioner(ctx context.Context, adminDB admin.DB, password string) (*linkedca.Provisioner, error) { if password == "" { pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") if err != nil { return nil, err } password = string(pass) } jwk, jwe, err := jose.GenerateDefaultKeyPair([]byte(password)) if err != nil { return nil, admin.WrapErrorISE(err, "error generating JWK key pair") } jwkPubBytes, err := jwk.MarshalJSON() if err != nil { return nil, admin.WrapErrorISE(err, "error marshaling JWK") } jwePrivStr, err := jwe.CompactSerialize() if err != nil { return nil, admin.WrapErrorISE(err, "error serializing JWE") } p := &linkedca.Provisioner{ Name: "Admin JWK", Type: linkedca.Provisioner_JWK, Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_JWK{ JWK: &linkedca.JWKProvisioner{ PublicKey: jwkPubBytes, EncryptedPrivateKey: []byte(jwePrivStr), }, }, }, Claims: &linkedca.Claims{ X509: &linkedca.X509Claims{ Enabled: true, Durations: &linkedca.Durations{ Default: "5m", }, }, }, } if err := adminDB.CreateProvisioner(ctx, p); err != nil { return nil, admin.WrapErrorISE(err, "error creating provisioner") } return p, nil } // ValidateClaims validates the Claims type. func ValidateClaims(c *linkedca.Claims) error { if c == nil { return nil } if c.X509 != nil { if c.X509.Durations != nil { if err := ValidateDurations(c.X509.Durations); err != nil { return err } } } if c.Ssh != nil { if c.Ssh.UserDurations != nil { if err := ValidateDurations(c.Ssh.UserDurations); err != nil { return err } } if c.Ssh.HostDurations != nil { if err := ValidateDurations(c.Ssh.HostDurations); err != nil { return err } } } return nil } // ValidateDurations validates the Durations type. 
func ValidateDurations(d *linkedca.Durations) error { var ( err error minDur, maxDur, def *provisioner.Duration ) if d.Min != "" { minDur, err = provisioner.NewDuration(d.Min) if err != nil { return admin.WrapError(admin.ErrorBadRequestType, err, "min duration '%s' is invalid", d.Min) } if minDur.Value() < 0 { return admin.WrapError(admin.ErrorBadRequestType, err, "min duration '%s' cannot be less than 0", d.Min) } } if d.Max != "" { maxDur, err = provisioner.NewDuration(d.Max) if err != nil { return admin.WrapError(admin.ErrorBadRequestType, err, "max duration '%s' is invalid", d.Max) } if maxDur.Value() < 0 { return admin.WrapError(admin.ErrorBadRequestType, err, "max duration '%s' cannot be less than 0", d.Max) } } if d.Default != "" { def, err = provisioner.NewDuration(d.Default) if err != nil { return admin.WrapError(admin.ErrorBadRequestType, err, "default duration '%s' is invalid", d.Default) } if def.Value() < 0 { return admin.WrapError(admin.ErrorBadRequestType, err, "default duration '%s' cannot be less than 0", d.Default) } } if d.Min != "" && d.Max != "" && minDur.Value() > maxDur.Value() { return admin.NewError(admin.ErrorBadRequestType, "min duration '%s' cannot be greater than max duration '%s'", d.Min, d.Max) } if d.Min != "" && d.Default != "" && minDur.Value() > def.Value() { return admin.NewError(admin.ErrorBadRequestType, "min duration '%s' cannot be greater than default duration '%s'", d.Min, d.Default) } if d.Default != "" && d.Max != "" && minDur.Value() > def.Value() { return admin.NewError(admin.ErrorBadRequestType, "default duration '%s' cannot be greater than max duration '%s'", d.Default, d.Max) } return nil } func provisionerListToCertificates(l []*linkedca.Provisioner) (provisioner.List, error) { var nu provisioner.List for _, p := range l { certProv, err := ProvisionerToCertificates(p) if err != nil { return nil, err } nu = append(nu, certProv) } return nu, nil } func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options 
{ ops := &provisioner.Options{ X509: &provisioner.X509Options{}, SSH: &provisioner.SSHOptions{}, } if p.X509Template != nil { ops.X509.Template = string(p.X509Template.Template) ops.X509.TemplateData = p.X509Template.Data } if p.SshTemplate != nil { ops.SSH.Template = string(p.SshTemplate.Template) ops.SSH.TemplateData = p.SshTemplate.Data } if pol := p.GetPolicy(); pol != nil { if x := pol.GetX509(); x != nil { if allow := x.GetAllow(); allow != nil { ops.X509.AllowedNames = &policy.X509NameOptions{ DNSDomains: allow.Dns, IPRanges: allow.Ips, EmailAddresses: allow.Emails, URIDomains: allow.Uris, } } if deny := x.GetDeny(); deny != nil { ops.X509.DeniedNames = &policy.X509NameOptions{ DNSDomains: deny.Dns, IPRanges: deny.Ips, EmailAddresses: deny.Emails, URIDomains: deny.Uris, } } } if ssh := pol.GetSsh(); ssh != nil { if host := ssh.GetHost(); host != nil { ops.SSH.Host = &policy.SSHHostCertificateOptions{} if allow := host.GetAllow(); allow != nil { ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{ DNSDomains: allow.Dns, IPRanges: allow.Ips, Principals: allow.Principals, } } if deny := host.GetDeny(); deny != nil { ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{ DNSDomains: deny.Dns, IPRanges: deny.Ips, Principals: deny.Principals, } } } if user := ssh.GetUser(); user != nil { ops.SSH.User = &policy.SSHUserCertificateOptions{} if allow := user.GetAllow(); allow != nil { ops.SSH.User.AllowedNames = &policy.SSHNameOptions{ EmailAddresses: allow.Emails, Principals: allow.Principals, } } if deny := user.GetDeny(); deny != nil { ops.SSH.User.DeniedNames = &policy.SSHNameOptions{ EmailAddresses: deny.Emails, Principals: deny.Principals, } } } } } for _, wh := range p.Webhooks { whCert := webhookToCertificates(wh) ops.Webhooks = append(ops.Webhooks, whCert) } return ops } func webhookToCertificates(wh *linkedca.Webhook) *provisioner.Webhook { pwh := &provisioner.Webhook{ ID: wh.Id, Name: wh.Name, URL: wh.Url, Kind: wh.Kind.String(), Secret: wh.Secret, 
DisableTLSClientAuth: wh.DisableTlsClientAuth, CertType: wh.CertType.String(), } switch a := wh.GetAuth().(type) { case *linkedca.Webhook_BearerToken: pwh.BearerToken = a.BearerToken.BearerToken case *linkedca.Webhook_BasicAuth: pwh.BasicAuth.Username = a.BasicAuth.Username pwh.BasicAuth.Password = a.BasicAuth.Password } return pwh } func provisionerWebhookToLinkedca(pwh *provisioner.Webhook) *linkedca.Webhook { lwh := &linkedca.Webhook{ Id: pwh.ID, Name: pwh.Name, Url: pwh.URL, Kind: linkedca.Webhook_Kind(linkedca.Webhook_Kind_value[pwh.Kind]), Secret: pwh.Secret, DisableTlsClientAuth: pwh.DisableTLSClientAuth, CertType: linkedca.Webhook_CertType(linkedca.Webhook_CertType_value[pwh.CertType]), } if pwh.BearerToken != "" { lwh.Auth = &linkedca.Webhook_BearerToken{ BearerToken: &linkedca.BearerToken{ BearerToken: pwh.BearerToken, }, } } else if pwh.BasicAuth.Username != "" || pwh.BasicAuth.Password != "" { lwh.Auth = &linkedca.Webhook_BasicAuth{ BasicAuth: &linkedca.BasicAuth{ Username: pwh.BasicAuth.Username, Password: pwh.BasicAuth.Password, }, } } return lwh } func durationsToCertificates(d *linkedca.Durations) (minDur, maxDur, def *provisioner.Duration, err error) { if d.Min != "" { minDur, err = provisioner.NewDuration(d.Min) if err != nil { return nil, nil, nil, admin.WrapErrorISE(err, "error parsing minimum duration '%s'", d.Min) } } if d.Max != "" { maxDur, err = provisioner.NewDuration(d.Max) if err != nil { return nil, nil, nil, admin.WrapErrorISE(err, "error parsing maximum duration '%s'", d.Max) } } if d.Default != "" { def, err = provisioner.NewDuration(d.Default) if err != nil { return nil, nil, nil, admin.WrapErrorISE(err, "error parsing default duration '%s'", d.Default) } } return } func durationsToLinkedca(d *provisioner.Duration) string { if d == nil { return "" } return d.Duration.String() } // claimsToCertificates converts the linkedca provisioner claims type to the // certifictes claims type. 
func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { if c == nil { //nolint:nilnil // nil claims do not pose an issue. return nil, nil } pc := &provisioner.Claims{ DisableRenewal: &c.DisableRenewal, AllowRenewalAfterExpiry: &c.AllowRenewalAfterExpiry, DisableSmallstepExtensions: &c.DisableSmallstepExtensions, } var err error if xc := c.X509; xc != nil { if d := xc.Durations; d != nil { pc.MinTLSDur, pc.MaxTLSDur, pc.DefaultTLSDur, err = durationsToCertificates(d) if err != nil { return nil, err } } } if sc := c.Ssh; sc != nil { pc.EnableSSHCA = &sc.Enabled if d := sc.UserDurations; d != nil { pc.MinUserSSHDur, pc.MaxUserSSHDur, pc.DefaultUserSSHDur, err = durationsToCertificates(d) if err != nil { return nil, err } } if d := sc.HostDurations; d != nil { pc.MinHostSSHDur, pc.MaxHostSSHDur, pc.DefaultHostSSHDur, err = durationsToCertificates(d) if err != nil { return nil, err } } } return pc, nil } func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims { if c == nil { return nil } disableRenewal := config.DefaultDisableRenewal allowRenewalAfterExpiry := config.DefaultAllowRenewalAfterExpiry disableSmallstepExtensions := config.DefaultDisableSmallstepExtensions if c.DisableRenewal != nil { disableRenewal = *c.DisableRenewal } if c.AllowRenewalAfterExpiry != nil { allowRenewalAfterExpiry = *c.AllowRenewalAfterExpiry } if c.DisableSmallstepExtensions != nil { disableSmallstepExtensions = *c.DisableSmallstepExtensions } lc := &linkedca.Claims{ DisableRenewal: disableRenewal, AllowRenewalAfterExpiry: allowRenewalAfterExpiry, DisableSmallstepExtensions: disableSmallstepExtensions, } if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { lc.X509 = &linkedca.X509Claims{ Enabled: true, Durations: &linkedca.Durations{ Default: durationsToLinkedca(c.DefaultTLSDur), Min: durationsToLinkedca(c.MinTLSDur), Max: durationsToLinkedca(c.MaxTLSDur), }, } } if c.EnableSSHCA != nil && *c.EnableSSHCA { lc.Ssh = &linkedca.SSHClaims{ Enabled: 
true, } if c.DefaultUserSSHDur != nil || c.MinUserSSHDur != nil || c.MaxUserSSHDur != nil { lc.Ssh.UserDurations = &linkedca.Durations{ Default: durationsToLinkedca(c.DefaultUserSSHDur), Min: durationsToLinkedca(c.MinUserSSHDur), Max: durationsToLinkedca(c.MaxUserSSHDur), } } if c.DefaultHostSSHDur != nil || c.MinHostSSHDur != nil || c.MaxHostSSHDur != nil { lc.Ssh.HostDurations = &linkedca.Durations{ Default: durationsToLinkedca(c.DefaultHostSSHDur), Min: durationsToLinkedca(c.MinHostSSHDur), Max: durationsToLinkedca(c.MaxHostSSHDur), } } } return lc } func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *linkedca.Template, []*linkedca.Webhook, error) { var err error var x509Template, sshTemplate *linkedca.Template if p == nil { return nil, nil, nil, nil } if p.X509 != nil && p.X509.HasTemplate() { x509Template = &linkedca.Template{ Template: nil, Data: nil, } if p.X509.Template != "" { x509Template.Template = []byte(p.X509.Template) } else if p.X509.TemplateFile != "" { filename := step.Abs(p.X509.TemplateFile) if x509Template.Template, err = os.ReadFile(filename); err != nil { return nil, nil, nil, errors.Wrap(err, "error reading x509 template") } } if p.X509.TemplateData != nil { x509Template.Data = p.X509.TemplateData } } if p.SSH != nil && p.SSH.HasTemplate() { sshTemplate = &linkedca.Template{ Template: nil, Data: nil, } if p.SSH.Template != "" { sshTemplate.Template = []byte(p.SSH.Template) } else if p.SSH.TemplateFile != "" { filename := step.Abs(p.SSH.TemplateFile) if sshTemplate.Template, err = os.ReadFile(filename); err != nil { return nil, nil, nil, errors.Wrap(err, "error reading ssh template") } } if p.SSH.TemplateData != nil { sshTemplate.Data = p.SSH.TemplateData } } var webhooks []*linkedca.Webhook for _, pwh := range p.Webhooks { webhooks = append(webhooks, provisionerWebhookToLinkedca(pwh)) } return x509Template, sshTemplate, webhooks, nil } func provisionerPEMToLinkedca(b []byte) [][]byte { var roots [][]byte var block 
*pem.Block for { if block, b = pem.Decode(b); block == nil { break } roots = append(roots, pem.EncodeToMemory(block)) } return roots } func provisionerPEMToCertificates(bs [][]byte) []byte { var roots []byte for i, root := range bs { if i > 0 && !bytes.HasSuffix(root, []byte{'\n'}) { roots = append(roots, '\n') } roots = append(roots, root...) } return roots } // ProvisionerToCertificates converts the linkedca provisioner type to the certificates provisioner // interface. func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, error) { claims, err := claimsToCertificates(p.Claims) if err != nil { return nil, err } details := p.Details.GetData() if details == nil { return nil, errors.New("provisioner does not have any details") } options := optionsToCertificates(p) switch d := details.(type) { case *linkedca.ProvisionerDetails_JWK: jwk := new(jose.JSONWebKey) if err := json.Unmarshal(d.JWK.PublicKey, &jwk); err != nil { return nil, errors.Wrap(err, "error unmarshaling public key") } return &provisioner.JWK{ ID: p.Id, Type: p.Type.String(), Name: p.Name, Key: jwk, EncryptedKey: string(d.JWK.EncryptedPrivateKey), Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_X5C: var roots []byte for i, root := range d.X5C.GetRoots() { if i > 0 { roots = append(roots, '\n') } roots = append(roots, root...) } return &provisioner.X5C{ ID: p.Id, Type: p.Type.String(), Name: p.Name, Roots: roots, Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_K8SSA: var publicKeys []byte for i, k := range d.K8SSA.GetPublicKeys() { if i > 0 { publicKeys = append(publicKeys, '\n') } publicKeys = append(publicKeys, k...) 
} return &provisioner.K8sSA{ ID: p.Id, Type: p.Type.String(), Name: p.Name, PubKeys: publicKeys, Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_SSHPOP: return &provisioner.SSHPOP{ ID: p.Id, Type: p.Type.String(), Name: p.Name, Claims: claims, }, nil case *linkedca.ProvisionerDetails_ACME: cfg := d.ACME return &provisioner.ACME{ ID: p.Id, Type: p.Type.String(), Name: p.Name, ForceCN: cfg.ForceCn, TermsOfService: cfg.TermsOfService, Website: cfg.Website, CaaIdentities: cfg.CaaIdentities, RequireEAB: cfg.RequireEab, Challenges: challengesToCertificates(cfg.Challenges), AttestationFormats: attestationFormatsToCertificates(cfg.AttestationFormats), AttestationRoots: provisionerPEMToCertificates(cfg.AttestationRoots), Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_OIDC: cfg := d.OIDC return &provisioner.OIDC{ ID: p.Id, Type: p.Type.String(), Name: p.Name, TenantID: cfg.TenantId, ClientID: cfg.ClientId, ClientSecret: cfg.ClientSecret, ConfigurationEndpoint: cfg.ConfigurationEndpoint, Admins: cfg.Admins, Domains: cfg.Domains, Groups: cfg.Groups, ListenAddress: cfg.ListenAddress, Scopes: cfg.Scopes, AuthParams: cfg.AuthParams, Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_AWS: cfg := d.AWS instanceAge, err := parseInstanceAge(cfg.InstanceAge) if err != nil { return nil, err } return &provisioner.AWS{ ID: p.Id, Type: p.Type.String(), Name: p.Name, Accounts: cfg.Accounts, DisableCustomSANs: cfg.DisableCustomSans, DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, InstanceAge: instanceAge, Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_GCP: cfg := d.GCP instanceAge, err := parseInstanceAge(cfg.InstanceAge) if err != nil { return nil, err } return &provisioner.GCP{ ID: p.Id, Type: p.Type.String(), Name: p.Name, ServiceAccounts: cfg.ServiceAccounts, ProjectIDs: cfg.ProjectIds, OrganizationID: cfg.OrganizationId, DisableCustomSANs: cfg.DisableCustomSans, 
DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, DisableSSHCAUser: cfg.DisableSshCaUser, DisableSSHCAHost: cfg.DisableSshCaHost, InstanceAge: instanceAge, Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_Azure: cfg := d.Azure return &provisioner.Azure{ ID: p.Id, Type: p.Type.String(), Name: p.Name, TenantID: cfg.TenantId, ResourceGroups: cfg.ResourceGroups, SubscriptionIDs: cfg.SubscriptionIds, ObjectIDs: cfg.ObjectIds, Audience: cfg.Audience, DisableCustomSANs: cfg.DisableCustomSans, DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, Claims: claims, Options: options, }, nil case *linkedca.ProvisionerDetails_SCEP: cfg := d.SCEP s := &provisioner.SCEP{ ID: p.Id, Type: p.Type.String(), Name: p.Name, ForceCN: cfg.ForceCn, ChallengePassword: cfg.Challenge, Capabilities: cfg.Capabilities, IncludeRoot: cfg.IncludeRoot, ExcludeIntermediate: cfg.ExcludeIntermediate, MinimumPublicKeyLength: int(cfg.MinimumPublicKeyLength), EncryptionAlgorithmIdentifier: int(cfg.EncryptionAlgorithmIdentifier), Claims: claims, Options: options, } if decrypter := cfg.GetDecrypter(); decrypter != nil { s.DecrypterCertificate = decrypter.Certificate s.DecrypterKeyPEM = decrypter.Key s.DecrypterKeyURI = decrypter.KeyUri s.DecrypterKeyPassword = string(decrypter.KeyPassword) } return s, nil case *linkedca.ProvisionerDetails_Nebula: var roots []byte for i, root := range d.Nebula.GetRoots() { if i > 0 && !bytes.HasSuffix(root, []byte{'\n'}) { roots = append(roots, '\n') } roots = append(roots, root...) } return &provisioner.Nebula{ ID: p.Id, Type: p.Type.String(), Name: p.Name, Roots: roots, Claims: claims, Options: options, }, nil default: return nil, fmt.Errorf("provisioner %s not implemented", p.Type) } } // ProvisionerToLinkedca converts a provisioner.Interface to a // linkedca.Provisioner type. 
func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, error) { switch p := p.(type) { case *provisioner.JWK: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } publicKey, err := json.Marshal(p.Key) if err != nil { return nil, errors.Wrap(err, "error marshaling key") } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_JWK, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_JWK{ JWK: &linkedca.JWKProvisioner{ PublicKey: publicKey, EncryptedPrivateKey: []byte(p.EncryptedKey), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.OIDC: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_OIDC, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_OIDC{ OIDC: &linkedca.OIDCProvisioner{ ClientId: p.ClientID, ClientSecret: p.ClientSecret, ConfigurationEndpoint: p.ConfigurationEndpoint, Admins: p.Admins, Domains: p.Domains, Groups: p.Groups, ListenAddress: p.ListenAddress, TenantId: p.TenantID, Scopes: p.Scopes, AuthParams: p.AuthParams, }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.GCP: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_GCP, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_GCP{ GCP: &linkedca.GCPProvisioner{ ServiceAccounts: p.ServiceAccounts, ProjectIds: p.ProjectIDs, OrganizationId: p.OrganizationID, DisableCustomSans: p.DisableCustomSANs, DisableTrustOnFirstUse: 
p.DisableTrustOnFirstUse, DisableSshCaUser: p.DisableSSHCAUser, DisableSshCaHost: p.DisableSSHCAHost, InstanceAge: p.InstanceAge.String(), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.AWS: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_AWS, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_AWS{ AWS: &linkedca.AWSProvisioner{ Accounts: p.Accounts, DisableCustomSans: p.DisableCustomSANs, DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, InstanceAge: p.InstanceAge.String(), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.Azure: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_AZURE, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_Azure{ Azure: &linkedca.AzureProvisioner{ TenantId: p.TenantID, ResourceGroups: p.ResourceGroups, SubscriptionIds: p.SubscriptionIDs, ObjectIds: p.ObjectIDs, Audience: p.Audience, DisableCustomSans: p.DisableCustomSANs, DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.ACME: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_ACME, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{ ForceCn: p.ForceCN, TermsOfService: p.TermsOfService, Website: 
p.Website, CaaIdentities: p.CaaIdentities, RequireEab: p.RequireEAB, Challenges: challengesToLinkedca(p.Challenges), AttestationFormats: attestationFormatsToLinkedca(p.AttestationFormats), AttestationRoots: provisionerPEMToLinkedca(p.AttestationRoots), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.X5C: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_X5C, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_X5C{ X5C: &linkedca.X5CProvisioner{ Roots: provisionerPEMToLinkedca(p.Roots), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.K8sSA: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_K8SSA, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_K8SSA{ K8SSA: &linkedca.K8SSAProvisioner{ PublicKeys: provisionerPEMToLinkedca(p.PubKeys), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.SSHPOP: return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_SSHPOP, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_SSHPOP{ SSHPOP: &linkedca.SSHPOPProvisioner{}, }, }, Claims: claimsToLinkedca(p.Claims), }, nil case *provisioner.SCEP: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_SCEP, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: 
&linkedca.ProvisionerDetails_SCEP{ SCEP: &linkedca.SCEPProvisioner{ ForceCn: p.ForceCN, Challenge: p.ChallengePassword, Capabilities: p.Capabilities, MinimumPublicKeyLength: cast.Int32(p.MinimumPublicKeyLength), IncludeRoot: p.IncludeRoot, ExcludeIntermediate: p.ExcludeIntermediate, EncryptionAlgorithmIdentifier: cast.Int32(p.EncryptionAlgorithmIdentifier), Decrypter: &linkedca.SCEPDecrypter{ Certificate: p.DecrypterCertificate, Key: p.DecrypterKeyPEM, KeyUri: p.DecrypterKeyURI, KeyPassword: []byte(p.DecrypterKeyPassword), }, }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil case *provisioner.Nebula: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { return nil, err } return &linkedca.Provisioner{ Id: p.ID, Type: linkedca.Provisioner_NEBULA, Name: p.GetName(), Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_Nebula{ Nebula: &linkedca.NebulaProvisioner{ Roots: provisionerPEMToLinkedca(p.Roots), }, }, }, Claims: claimsToLinkedca(p.Claims), X509Template: x509Template, SshTemplate: sshTemplate, Webhooks: webhooks, }, nil default: return nil, fmt.Errorf("provisioner %s not implemented", p.GetType()) } } func parseInstanceAge(age string) (provisioner.Duration, error) { var instanceAge provisioner.Duration if age != "" { iap, err := provisioner.NewDuration(age) if err != nil { return instanceAge, err } instanceAge = *iap } return instanceAge, nil } // challengesToCertificates converts linkedca challenges to provisioner ones // skipping the unknown ones. 
func challengesToCertificates(challenges []linkedca.ACMEProvisioner_ChallengeType) []provisioner.ACMEChallenge { ret := make([]provisioner.ACMEChallenge, 0, len(challenges)) for _, ch := range challenges { switch ch { case linkedca.ACMEProvisioner_HTTP_01: ret = append(ret, provisioner.HTTP_01) case linkedca.ACMEProvisioner_DNS_01: ret = append(ret, provisioner.DNS_01) case linkedca.ACMEProvisioner_TLS_ALPN_01: ret = append(ret, provisioner.TLS_ALPN_01) case linkedca.ACMEProvisioner_DEVICE_ATTEST_01: ret = append(ret, provisioner.DEVICE_ATTEST_01) } } return ret } // challengesToLinkedca converts provisioner challenges to linkedca ones // skipping the unknown ones. func challengesToLinkedca(challenges []provisioner.ACMEChallenge) []linkedca.ACMEProvisioner_ChallengeType { ret := make([]linkedca.ACMEProvisioner_ChallengeType, 0, len(challenges)) for _, ch := range challenges { switch provisioner.ACMEChallenge(ch.String()) { case provisioner.HTTP_01: ret = append(ret, linkedca.ACMEProvisioner_HTTP_01) case provisioner.DNS_01: ret = append(ret, linkedca.ACMEProvisioner_DNS_01) case provisioner.TLS_ALPN_01: ret = append(ret, linkedca.ACMEProvisioner_TLS_ALPN_01) case provisioner.DEVICE_ATTEST_01: ret = append(ret, linkedca.ACMEProvisioner_DEVICE_ATTEST_01) } } return ret } // attestationFormatsToCertificates converts linkedca attestation formats to // provisioner ones skipping the unknown ones. 
func attestationFormatsToCertificates(formats []linkedca.ACMEProvisioner_AttestationFormatType) []provisioner.ACMEAttestationFormat { ret := make([]provisioner.ACMEAttestationFormat, 0, len(formats)) for _, f := range formats { switch f { case linkedca.ACMEProvisioner_APPLE: ret = append(ret, provisioner.APPLE) case linkedca.ACMEProvisioner_STEP: ret = append(ret, provisioner.STEP) case linkedca.ACMEProvisioner_TPM: ret = append(ret, provisioner.TPM) } } return ret } // attestationFormatsToLinkedca converts provisioner attestation formats to // linkedca ones skipping the unknown ones. func attestationFormatsToLinkedca(formats []provisioner.ACMEAttestationFormat) []linkedca.ACMEProvisioner_AttestationFormatType { ret := make([]linkedca.ACMEProvisioner_AttestationFormatType, 0, len(formats)) for _, f := range formats { switch provisioner.ACMEAttestationFormat(f.String()) { case provisioner.APPLE: ret = append(ret, linkedca.ACMEProvisioner_APPLE) case provisioner.STEP: ret = append(ret, linkedca.ACMEProvisioner_STEP) case provisioner.TPM: ret = append(ret, linkedca.ACMEProvisioner_TPM) } } return ret } ================================================ FILE: authority/provisioners_test.go ================================================ package authority import ( "context" "crypto/x509" "errors" "net/http" "reflect" "testing" "time" "github.com/stretchr/testify/require" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" ) func TestGetEncryptedKey(t *testing.T) { type ek struct { a *Authority kid string err error code int } tests := map[string]func(t *testing.T) *ek{ "ok": func(t *testing.T) *ek { c, err := LoadConfiguration("../ca/testdata/ca.json") require.NoError(t, err) a, err := New(c) require.NoError(t, err) 
return &ek{ a: a, kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID, } }, "fail-not-found": func(t *testing.T) *ek { c, err := LoadConfiguration("../ca/testdata/ca.json") require.NoError(t, err) a, err := New(c) require.NoError(t, err) return &ek{ a: a, kid: "foo", err: errors.New("encrypted key with kid foo was not found"), code: http.StatusNotFound, } }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) ek, err := tc.a.GetEncryptedKey(tc.kid) if err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equals(t, sc.StatusCode(), tc.code) } assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { val, ok := tc.a.provisioners.Load("mike:" + tc.kid) assert.Fatal(t, ok) p, ok := val.(*provisioner.JWK) assert.Fatal(t, ok) assert.Equals(t, p.EncryptedKey, ek) } } }) } } type mockAdminDB struct { admin.MockDB MGetCertificateData func(string) (*db.CertificateData, error) } func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) { return c.MGetCertificateData(sn) } func TestGetProvisioners(t *testing.T) { type gp struct { a *Authority err error code int } tests := map[string]func(t *testing.T) *gp{ "ok": func(t *testing.T) *gp { c, err := LoadConfiguration("../ca/testdata/ca.json") require.NoError(t, err) a, err := New(c) require.NoError(t, err) return &gp{a: a} }, "ok/rsa": func(t *testing.T) *gp { c, err := LoadConfiguration("../ca/testdata/rsaca.json") require.NoError(t, err) a, err := New(c) require.NoError(t, err) return &gp{a: a} }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) ps, next, err := tc.a.GetProvisioners("", 0) if err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError 
interface") { assert.Equals(t, tc.code, sc.StatusCode()) } assert.HasPrefix(t, tc.err.Error(), err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.a.config.AuthorityConfig.Provisioners, ps) assert.Equals(t, "", next) } } }) } } func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { _, priv, err := keyutil.GenerateDefaultKeyPair() require.NoError(t, err) csr := getCSR(t, priv) sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate { key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) require.NoError(t, err) token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) require.NoError(t, err) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) opts, err := a.Authorize(ctx, token) require.NoError(t, err) opts = append(opts, extraOpts...) certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) require.NoError(t, err) return certs[0] } getProvisioner := func(a *Authority, name string) provisioner.Interface { p, ok := a.provisioners.LoadByName(name) if !ok { t.Fatalf("provisioner %s does not exists", name) } return p } removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { for i, ext := range cert.ExtraExtensions { if ext.Id.Equal(provisioner.StepOIDProvisioner) { cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...) 
break } } return nil }) a0 := testAuthority(t) a1 := testAuthority(t) a1.db = &db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { p, err := a1.LoadProvisionerByName("dev") require.NoError(t, err) return &db.CertificateData{ Provisioner: &db.ProvisionerData{ ID: p.GetID(), Name: p.GetName(), Type: p.GetType().String(), }, }, nil }, } a2 := testAuthority(t) a2.adminDB = &mockAdminDB{ MGetCertificateData: (func(s string) (*db.CertificateData, error) { p, err := a2.LoadProvisionerByName("dev") require.NoError(t, err) return &db.CertificateData{ Provisioner: &db.ProvisionerData{ ID: p.GetID(), Name: p.GetName(), Type: p.GetType().String(), }, }, nil }), } a3 := testAuthority(t) a3.db = &db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { return &db.CertificateData{ Provisioner: &db.ProvisionerData{ ID: "foo", Name: "foo", Type: "foo", }, }, nil }, } a4 := testAuthority(t) a4.adminDB = &mockAdminDB{ MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { return &db.CertificateData{ Provisioner: &db.ProvisionerData{ ID: "foo", Name: "foo", Type: "foo", }, }, nil }, } type args struct { crt *x509.Certificate } tests := []struct { name string authority *Authority args args want provisioner.Interface wantErr bool }{ {"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false}, {"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false}, {"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false}, {"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true}, {"fail from db", a3, args{sign(a3, removeExtension)}, nil, true}, {"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := 
tt.authority.LoadProvisionerByCertificate(tt.args.crt) if (err != nil) != tt.wantErr { t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want) } }) } } func TestProvisionerWebhookToLinkedca(t *testing.T) { type test struct { lwh *linkedca.Webhook pwh *provisioner.Webhook } tests := map[string]test{ "empty": test{ lwh: &linkedca.Webhook{}, pwh: &provisioner.Webhook{Kind: "NO_KIND", CertType: "ALL"}, }, "enriching ssh basic auth": test{ lwh: &linkedca.Webhook{ Id: "abc123", Name: "people", Url: "https://localhost", Kind: linkedca.Webhook_ENRICHING, Secret: "secret", Auth: &linkedca.Webhook_BasicAuth{ BasicAuth: &linkedca.BasicAuth{ Username: "user", Password: "pass", }, }, DisableTlsClientAuth: true, CertType: linkedca.Webhook_SSH, }, pwh: &provisioner.Webhook{ ID: "abc123", Name: "people", URL: "https://localhost", Kind: "ENRICHING", Secret: "secret", BasicAuth: struct { Username string Password string }{ Username: "user", Password: "pass", }, DisableTLSClientAuth: true, CertType: "SSH", }, }, "authorizing x509 bearer auth": test{ lwh: &linkedca.Webhook{ Id: "abc123", Name: "people", Url: "https://localhost", Kind: linkedca.Webhook_AUTHORIZING, Secret: "secret", Auth: &linkedca.Webhook_BearerToken{ BearerToken: &linkedca.BearerToken{ BearerToken: "tkn", }, }, CertType: linkedca.Webhook_X509, }, pwh: &provisioner.Webhook{ ID: "abc123", Name: "people", URL: "https://localhost", Kind: "AUTHORIZING", Secret: "secret", BearerToken: "tkn", CertType: "X509", }, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { gotLWH := provisionerWebhookToLinkedca(test.pwh) assert.Equals(t, test.lwh, gotLWH) gotPWH := webhookToCertificates(test.lwh) assert.Equals(t, test.pwh, gotPWH) }) } } func Test_wrapRAProvisioner(t *testing.T) { type args struct { p provisioner.Interface raInfo *provisioner.RAInfo } 
tests := []struct { name string args args want *wrappedProvisioner }{ {"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{ Interface: &provisioner.JWK{Name: "jwt"}, raInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) { t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want) } }) } } func Test_isRAProvisioner(t *testing.T) { type args struct { p provisioner.Interface } tests := []struct { name string args args want bool }{ {"true", args{&wrappedProvisioner{ Interface: &provisioner.JWK{Name: "jwt"}, raInfo: &provisioner.RAInfo{ProvisionerName: "ra"}, }}, true}, {"nil ra", args{&wrappedProvisioner{ Interface: &provisioner.JWK{Name: "jwt"}, }}, false}, {"not ra", args{&provisioner.JWK{Name: "jwt"}}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := isRAProvisioner(tt.args.p); got != tt.want { t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: authority/root.go ================================================ package authority import ( "crypto/x509" "github.com/smallstep/certificates/errs" ) // Root returns the certificate corresponding to the given SHA sum argument. func (a *Authority) Root(sum string) (*x509.Certificate, error) { val, ok := a.certificates.Load(sum) if !ok { return nil, errs.NotFound("certificate with fingerprint %s was not found", sum) } crt, ok := val.(*x509.Certificate) if !ok { return nil, errs.InternalServer("stored value is not a *x509.Certificate") } return crt, nil } // GetRootCertificate returns the server root certificate. func (a *Authority) GetRootCertificate() *x509.Certificate { return a.rootX509Certs[0] } // GetRootCertificates returns the server root certificates. 
// // In the Authority interface we also have a similar method, GetRoots, at the // moment the functionality of these two methods are almost identical, but this // method is intended to be used internally by CA HTTP server to load the roots // that will be set in the tls.Config while GetRoots will be used by the // Authority interface and might have extra checks in the future. func (a *Authority) GetRootCertificates() []*x509.Certificate { return a.rootX509Certs } // GetRoots returns all the root certificates for this CA. // This method implements the Authority interface. func (a *Authority) GetRoots() ([]*x509.Certificate, error) { return a.rootX509Certs, nil } // GetFederation returns all the root certificates in the federation. // This method implements the Authority interface. func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) { a.certificates.Range(func(_, v interface{}) bool { crt, ok := v.(*x509.Certificate) if !ok { federation = nil err = errs.InternalServer("stored value is not a *x509.Certificate") return false } federation = append(federation, crt) return true }) return } // GetIntermediateCertificate return the intermediate certificate that issues // the leaf certificates in the CA. // // This method can return nil if the CA is configured with a Certificate // Authority Service (CAS) that does not implement the // CertificateAuthorityGetter interface. func (a *Authority) GetIntermediateCertificate() *x509.Certificate { if len(a.intermediateX509Certs) > 0 { return a.intermediateX509Certs[0] } return nil } // GetIntermediateCertificates returns a list of all intermediate certificates // configured. The first certificate in the list will be the issuer certificate. // // This method can return an empty list or nil if the CA is configured with a // Certificate Authority Service (CAS) that does not implement the // CertificateAuthorityGetter interface. 
func (a *Authority) GetIntermediateCertificates() []*x509.Certificate { return a.intermediateX509Certs } ================================================ FILE: authority/root_test.go ================================================ package authority import ( "crypto/x509" "crypto/x509/pkix" "errors" "net/http" "reflect" "testing" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" "github.com/stretchr/testify/require" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" ) func TestRoot(t *testing.T) { a := testAuthority(t) a.certificates.Store("invaliddata", "a string") // invalid cert for testing tests := map[string]struct { sum string err error code int }{ "not-found": {"foo", errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound}, "invalid-stored-certificate": {"invaliddata", errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil, http.StatusOK}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { crt, err := a.Root(tc.sum) if err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, crt, a.rootX509Certs[0]) } } }) } } func TestAuthority_GetRootCertificate(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt") if err != nil { t.Fatal(err) } tests := []struct { name string want *x509.Certificate }{ {"ok", cert}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) if got := a.GetRootCertificate(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetRootCertificate() = %v, want %v", got, tt.want) } }) } } func 
TestAuthority_GetRootCertificates(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt") if err != nil { t.Fatal(err) } tests := []struct { name string want []*x509.Certificate }{ {"ok", []*x509.Certificate{cert}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) if got := a.GetRootCertificates(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetRootCertificates() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetRoots(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt") if err != nil { t.Fatal(err) } tests := []struct { name string want []*x509.Certificate wantErr bool }{ {"ok", []*x509.Certificate{cert}, false}, } for _, tt := range tests { a := testAuthority(t) t.Run(tt.name, func(t *testing.T) { got, err := a.GetRoots() if (err != nil) != tt.wantErr { t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetRoots() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetFederation(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/root_ca.crt") if err != nil { t.Fatal(err) } tests := []struct { name string wantFederation []*x509.Certificate wantErr bool fn func(a *Authority) }{ {"ok", []*x509.Certificate{cert}, false, nil}, {"fail", nil, true, func(a *Authority) { a.certificates.Store("foo", "bar") }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) if tt.fn != nil { tt.fn(a) } gotFederation, err := a.GetFederation() if (err != nil) != tt.wantErr { t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(gotFederation, tt.wantFederation) { t.Errorf("Authority.GetFederation() = %v, want %v", gotFederation, tt.wantFederation) } }) } } func TestAuthority_GetIntermediateCertificate(t *testing.T) { ca, err := minica.New(minica.WithRootTemplate(`{ 
"subject": {{ toJson .Subject }}, "issuer": {{ toJson .Subject }}, "keyUsage": ["certSign", "crlSign"], "basicConstraints": { "isCA": true, "maxPathLen": -1 } }`), minica.WithIntermediateTemplate(`{ "subject": {{ toJson .Subject }}, "keyUsage": ["certSign", "crlSign"], "basicConstraints": { "isCA": true, "maxPathLen": 1 } }`)) require.NoError(t, err) signer, err := keyutil.GenerateDefaultSigner() require.NoError(t, err) cert, err := ca.Sign(&x509.Certificate{ Subject: pkix.Name{CommonName: "MiniCA Intermediate CA 0"}, PublicKey: signer.Public(), BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, }) require.NoError(t, err) type fields struct { intermediateX509Certs []*x509.Certificate } tests := []struct { name string fields fields want *x509.Certificate wantSlice []*x509.Certificate }{ {"ok one", fields{[]*x509.Certificate{ca.Intermediate}}, ca.Intermediate, []*x509.Certificate{ca.Intermediate}}, {"ok multiple", fields{[]*x509.Certificate{cert, ca.Intermediate}}, cert, []*x509.Certificate{cert, ca.Intermediate}}, {"ok empty", fields{[]*x509.Certificate{}}, nil, []*x509.Certificate{}}, {"ok nil", fields{nil}, nil, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := &Authority{ intermediateX509Certs: tt.fields.intermediateX509Certs, } if got := a.GetIntermediateCertificate(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetIntermediateCertificate() = %v, want %v", got, tt.want) } if got := a.GetIntermediateCertificates(); !reflect.DeepEqual(got, tt.wantSlice) { t.Errorf("Authority.GetIntermediateCertificates() = %v, want %v", got, tt.wantSlice) } }) } } ================================================ FILE: authority/ssh.go ================================================ package authority import ( "context" "crypto/rand" "crypto/x509" "encoding/binary" "errors" "net/http" "strings" "time" "golang.org/x/crypto/ssh" "go.step.sm/crypto/randutil" "go.step.sm/crypto/sshutil" "github.com/smallstep/certificates/authority/config" 
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/internal/cast" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" ) const ( // SSHAddUserPrincipal is the principal that will run the add user command. // Defaults to "provisioner" but it can be changed in the configuration. SSHAddUserPrincipal = "provisioner" // SSHAddUserCommand is the default command to run to add a new user. // Defaults to "sudo useradd -m ; nc -q0 localhost 22" but it can be changed in the // configuration. The string "" will be replace by the new // principal to add. SSHAddUserCommand = "sudo useradd -m ; nc -q0 localhost 22" ) // GetSSHRoots returns the SSH User and Host public keys. func (a *Authority) GetSSHRoots(context.Context) (*config.SSHKeys, error) { return &config.SSHKeys{ HostKeys: a.sshCAHostCerts, UserKeys: a.sshCAUserCerts, }, nil } // GetSSHFederation returns the public keys for federated SSH signers. func (a *Authority) GetSSHFederation(context.Context) (*config.SSHKeys, error) { return &config.SSHKeys{ HostKeys: a.sshCAHostFederatedCerts, UserKeys: a.sshCAUserFederatedCerts, }, nil } // GetSSHConfig returns rendered templates for clients (user) or servers (host). 
func (a *Authority) GetSSHConfig(_ context.Context, typ string, data map[string]string) ([]templates.Output, error) { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { return nil, errs.NotFound("getSSHConfig: ssh is not configured") } if a.templates == nil { return nil, errs.NotFound("getSSHConfig: ssh templates are not configured") } var ts []templates.Template switch typ { case provisioner.SSHUserCert: if a.templates != nil && a.templates.SSH != nil { ts = a.templates.SSH.User } case provisioner.SSHHostCert: if a.templates != nil && a.templates.SSH != nil { ts = a.templates.SSH.Host } default: return nil, errs.BadRequest("invalid certificate type '%s'", typ) } // Merge user and default data var mergedData map[string]interface{} if len(data) == 0 { mergedData = a.templates.Data } else { mergedData = make(map[string]interface{}, len(a.templates.Data)+1) mergedData["User"] = data for k, v := range a.templates.Data { mergedData[k] = v } } // Render templates output := []templates.Output{} for _, t := range ts { if err := t.Load(); err != nil { return nil, err } // Check for required variables. if err := t.ValidateRequiredData(data); err != nil { return nil, errs.BadRequestErr(err, "%v, please use `--set ` flag", err) } o, err := t.Output(mergedData) if err != nil { return nil, err } // Backwards compatibility for version of the cli older than v0.18.0. // Before v0.18.0 we were not passing any value for SSHTemplateVersionKey // from the cli. if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" { o.Type = templates.File o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/") } output = append(output, o) } return output, nil } // GetSSHBastion returns the bastion configuration, for the given pair user, // hostname. 
func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*config.Bastion, error) { if a.sshBastionFunc != nil { bs, err := a.sshBastionFunc(ctx, user, hostname) return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") } if a.config.SSH != nil { if a.config.SSH.Bastion != nil && a.config.SSH.Bastion.Hostname != "" { // Do not return a bastion for a bastion host. // // This condition might fail if a different name or IP is used. // Trying to resolve hostnames to IPs and compare them won't be a // complete solution because it depends on the network // configuration, of the CA and clients and can also return false // positives. Although not perfect, this simple solution will work // in most cases. if !strings.EqualFold(hostname, a.config.SSH.Bastion.Hostname) { return a.config.SSH.Bastion, nil } } //nolint:nilnil // legacy return nil, nil } return nil, errs.NotFound("authority.GetSSHBastion; ssh is not configured") } // SignSSH creates a signed SSH certificate with the given public key and options. func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { cert, prov, err := a.signSSH(ctx, key, opts, signOpts...) 
a.meter.SSHSigned(cert, prov, err) return cert, err } func (a *Authority) signSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { var ( certOptions []sshutil.Option mods []provisioner.SSHCertModifier validators []provisioner.SSHCertValidator keyValidators []provisioner.SSHPublicKeyValidator ) // Validate given key and options if key == nil { return nil, nil, errs.BadRequest("ssh public key cannot be nil") } if err := opts.Validate(); err != nil { return nil, nil, err } // Set backdate with the configured value opts.Backdate = a.config.AuthorityConfig.Backdate.Duration var prov provisioner.Interface var webhookCtl webhookController for _, op := range signOpts { switch o := op.(type) { // Capture current provisioner case provisioner.Interface: prov = o // add options to NewCertificate case provisioner.SSHCertificateOptions: certOptions = append(certOptions, o.Options(opts)...) // modify the ssh.Certificate case provisioner.SSHCertModifier: mods = append(mods, o) // validate the ssh public key case provisioner.SSHPublicKeyValidator: keyValidators = append(keyValidators, o) // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) // validate the given SSHOptions case provisioner.SSHCertOptionsValidator: if err := o.Valid(opts); err != nil { return nil, prov, errs.BadRequestErr(err, "error validating ssh certificate options") } // call webhooks case webhookController: webhookCtl = o default: return nil, prov, errs.InternalServer("authority.SignSSH: invalid extra option type %T", o) } } // Validate public key for _, v := range keyValidators { if err := v.Valid(key); err != nil { return nil, nil, errs.ApplyOptions( errs.ForbiddenErr(err, "%s", err.Error()), errs.WithKeyVal("signOptions", signOpts), ) } } // Simulated certificate request with request options. 
cr := sshutil.CertificateRequest{ Type: opts.CertType, KeyID: opts.KeyID, Principals: opts.Principals, Key: key, } // Call enriching webhooks if err := a.callEnrichingWebhooksSSH(ctx, prov, webhookCtl, cr); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "%s", err.Error()), errs.WithKeyVal("signOptions", signOpts), ) } // Create certificate from template. certificate, err := sshutil.NewCertificate(cr, certOptions...) if err != nil { var te *sshutil.TemplateError switch { case errors.As(err, &te): return nil, prov, errs.ApplyOptions( errs.BadRequestErr(err, "%s", err.Error()), errs.WithKeyVal("signOptions", signOpts), ) case strings.HasPrefix(err.Error(), "error unmarshaling certificate"): // explicitly check for unmarshaling errors, which are most probably caused by JSON template syntax errors return nil, prov, errs.InternalServerErr(templatingError(err), errs.WithKeyVal("signOptions", signOpts), errs.WithMessage("error applying certificate template"), ) default: return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH") } } // Get actual *ssh.Certificate and continue with provisioner modifiers. certTpl := certificate.GetCertificate() // Use SignSSHOptions to modify the certificate validity. It will be later // checked or set if not defined. if err := opts.ModifyValidity(certTpl); err != nil { return nil, prov, errs.BadRequestErr(err, "%s", err.Error()) } // Use provisioner modifiers. 
for _, m := range mods { if err := m.Modify(certTpl, opts); err != nil { return nil, prov, errs.ForbiddenErr(err, "error creating ssh certificate") } } // Get signer from authority keys var signer ssh.Signer switch certTpl.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { return nil, prov, errs.NotImplemented("authority.SignSSH: user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { return nil, prov, errs.NotImplemented("authority.SignSSH: host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: return nil, prov, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType) } // Check if authority is allowed to sign the certificate if err := a.isAllowedToSignSSHCertificate(certTpl); err != nil { var ee *errs.Error if errors.As(err, &ee) { return nil, prov, ee } return nil, prov, errs.InternalServerErr(err, errs.WithMessage("authority.SignSSH: error creating ssh certificate"), ) } // Send certificate to webhooks for authorization if err := a.callAuthorizingWebhooksSSH(ctx, prov, webhookCtl, certificate, certTpl); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), ) } // Sign certificate. cert, err := sshutil.CreateCertificate(certTpl, signer) if err != nil { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error signing certificate") } // User provisioners validators. 
for _, v := range validators { if err := v.Valid(cert, opts); err != nil { return nil, prov, errs.ForbiddenErr(err, "error validating ssh certificate") } } if err := a.storeSSHCertificate(prov, cert); err != nil && !errors.Is(err, db.ErrNotImplemented) { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db") } return cert, prov, nil } // isAllowedToSignSSHCertificate checks if the Authority is allowed to sign the SSH certificate. func (a *Authority) isAllowedToSignSSHCertificate(cert *ssh.Certificate) error { return a.policyEngine.IsSSHCertificateAllowed(cert) } // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { cert, prov, err := a.renewSSH(ctx, oldCert) a.meter.SSHRenewed(cert, prov, err) return cert, err } func (a *Authority) renewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, provisioner.Interface, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { return nil, nil, errs.BadRequest("cannot renew a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { return nil, nil, err } // Attempt to extract the provisioner from the token. var prov provisioner.Interface if token, ok := provisioner.TokenFromContext(ctx); ok { prov, _, _ = a.getProvisionerFromToken(token) } backdate := a.config.AuthorityConfig.Backdate.Duration duration := time.Duration(cast.Int64(oldCert.ValidBefore-oldCert.ValidAfter)) * time.Second now := time.Now() va := now.Add(-1 * backdate) vb := now.Add(duration - backdate) // Build base certificate with the old key. // Nonce and serial will be automatically generated on signing. 
certTpl := &ssh.Certificate{ Key: oldCert.Key, CertType: oldCert.CertType, KeyId: oldCert.KeyId, ValidPrincipals: oldCert.ValidPrincipals, Permissions: oldCert.Permissions, Reserved: oldCert.Reserved, ValidAfter: cast.Uint64(va.Unix()), ValidBefore: cast.Uint64(vb.Unix()), } // Get signer from authority keys var signer ssh.Signer switch certTpl.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { return nil, prov, errs.NotImplemented("renewSSH: user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { return nil, prov, errs.NotImplemented("renewSSH: host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: return nil, prov, errs.InternalServer("renewSSH: unexpected ssh certificate type: %d", certTpl.CertType) } // Sign certificate. cert, err := sshutil.CreateCertificate(certTpl, signer) if err != nil { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } if err := a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && !errors.Is(err, db.ErrNotImplemented) { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") } return cert, prov, nil } // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { cert, prov, err := a.rekeySSH(ctx, oldCert, pub, signOpts...) 
a.meter.SSHRekeyed(cert, prov, err) return cert, err } func (a *Authority) rekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, provisioner.Interface, error) { var prov provisioner.Interface var validators []provisioner.SSHCertValidator for _, op := range signOpts { switch o := op.(type) { // Capture current provisioner case provisioner.Interface: prov = o // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) default: return nil, prov, errs.InternalServer("rekeySSH; invalid extra option type %T", o) } } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { return nil, prov, errs.BadRequest("cannot rekey a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { return nil, prov, err } backdate := a.config.AuthorityConfig.Backdate.Duration duration := time.Duration(cast.Int64(oldCert.ValidBefore-oldCert.ValidAfter)) * time.Second now := time.Now() va := now.Add(-1 * backdate) vb := now.Add(duration - backdate) // Build base certificate with the new key. // Nonce and serial will be automatically generated on signing. 
cert := &ssh.Certificate{ Key: pub, CertType: oldCert.CertType, KeyId: oldCert.KeyId, ValidPrincipals: oldCert.ValidPrincipals, Permissions: oldCert.Permissions, Reserved: oldCert.Reserved, ValidAfter: cast.Uint64(va.Unix()), ValidBefore: cast.Uint64(vb.Unix()), } // Get signer from authority keys var signer ssh.Signer switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { return nil, prov, errs.NotImplemented("rekeySSH; user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { return nil, prov, errs.NotImplemented("rekeySSH; host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: return nil, prov, errs.BadRequest("unexpected certificate type '%d'", cert.CertType) } var err error // Sign certificate. cert, err = sshutil.CreateCertificate(cert, signer) if err != nil { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } // Apply validators from provisioner. 
for _, v := range validators { if err := v.Valid(cert, provisioner.SignSSHOptions{Backdate: backdate}); err != nil { return nil, prov, errs.ForbiddenErr(err, "error validating ssh certificate") } } if err := a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && !errors.Is(err, db.ErrNotImplemented) { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") } return cert, prov, nil } func (a *Authority) storeSSHCertificate(prov provisioner.Interface, cert *ssh.Certificate) error { type sshCertificateStorer interface { StoreSSHCertificate(provisioner.Interface, *ssh.Certificate) error } // Store certificate in admindb or linkedca switch s := a.adminDB.(type) { case sshCertificateStorer: return s.StoreSSHCertificate(prov, cert) case db.CertificateStorer: return s.StoreSSHCertificate(cert) } // Store certificate in localdb switch s := a.db.(type) { case sshCertificateStorer: return s.StoreSSHCertificate(prov, cert) case db.CertificateStorer: return s.StoreSSHCertificate(cert) default: return nil } } func (a *Authority) storeRenewedSSHCertificate(prov provisioner.Interface, parent, cert *ssh.Certificate) error { type sshRenewerCertificateStorer interface { StoreRenewedSSHCertificate(p provisioner.Interface, parent, cert *ssh.Certificate) error } // Store certificate in admindb or linkedca switch s := a.adminDB.(type) { case sshRenewerCertificateStorer: return s.StoreRenewedSSHCertificate(prov, parent, cert) case db.CertificateStorer: return s.StoreSSHCertificate(cert) } // Store certificate in localdb switch s := a.db.(type) { case sshRenewerCertificateStorer: return s.StoreRenewedSSHCertificate(prov, parent, cert) case db.CertificateStorer: return s.StoreSSHCertificate(cert) default: return nil } } // IsValidForAddUser checks if a user provisioner certificate can be issued to // the given certificate. 
func IsValidForAddUser(cert *ssh.Certificate) error { if cert.CertType != ssh.UserCert { return errs.Forbidden("certificate is not a user certificate") } switch len(cert.ValidPrincipals) { case 0: return errs.Forbidden("certificate does not have any principals") case 1: return nil case 2: // OIDC provisioners adds a second principal with the email address. // @ cannot be the first character. if strings.Index(cert.ValidPrincipals[1], "@") > 0 { return nil } return errs.Forbidden("certificate does not have only one principal") default: return errs.Forbidden("certificate does not have only one principal") } } // SignSSHAddUser signs a certificate that provisions a new user in a server. func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { if a.sshCAUserCertSignKey == nil { return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") } if err := IsValidForAddUser(subject); err != nil { return nil, err } nonce, err := randutil.ASCII(32) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser") } var serial uint64 if err := binary.Read(rand.Reader, binary.BigEndian, &serial); err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number") } // Attempt to extract the provisioner from the token. 
var prov provisioner.Interface if token, ok := provisioner.TokenFromContext(ctx); ok { prov, _, _ = a.getProvisionerFromToken(token) } signer := a.sshCAUserCertSignKey principal := subject.ValidPrincipals[0] addUserPrincipal := a.getAddUserPrincipal() cert := &ssh.Certificate{ Nonce: []byte(nonce), Key: key, Serial: serial, CertType: ssh.UserCert, KeyId: principal + "-" + addUserPrincipal, ValidPrincipals: []string{addUserPrincipal}, ValidAfter: subject.ValidAfter, ValidBefore: subject.ValidBefore, Permissions: ssh.Permissions{ CriticalOptions: map[string]string{ "force-command": a.getAddUserCommand(principal), }, }, SignatureKey: signer.PublicKey(), } // Get bytes for signing trailing the signature length. data := cert.Marshal() data = data[:len(data)-4] // Sign the certificate sig, err := signer.Sign(rand.Reader, data) if err != nil { return nil, err } cert.Signature = sig if err = a.storeRenewedSSHCertificate(prov, subject, cert); err != nil && !errors.Is(err, db.ErrNotImplemented) { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") } return cert, nil } // CheckSSHHost checks the given principal has been registered before. func (a *Authority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { if a.sshCheckHostFunc != nil { exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates()) if err != nil { return false, errs.Wrap(http.StatusInternalServerError, err, "checkSSHHost: error from injected checkSSHHost func") } return exists, nil } exists, err := a.db.IsSSHHost(principal) if err != nil { if errors.Is(err, db.ErrNotImplemented) { return false, errs.Wrap(http.StatusNotImplemented, err, "checkSSHHost: isSSHHost is not implemented") } return false, errs.Wrap(http.StatusInternalServerError, err, "checkSSHHost: error checking if hosts exists") } return exists, nil } // GetSSHHosts returns a list of valid host principals. 
func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) { if a.GetConfig().AuthorityConfig.DisableGetSSHHosts { return nil, errs.New(http.StatusNotFound, "ssh hosts list api disabled") } if a.sshGetHostsFunc != nil { hosts, err := a.sshGetHostsFunc(ctx, cert) return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") } hostnames, err := a.db.GetSSHHostPrincipals() if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") } hosts := make([]config.Host, len(hostnames)) for i, hn := range hostnames { hosts[i] = config.Host{Hostname: hn} } return hosts, nil } func (a *Authority) getAddUserPrincipal() (cmd string) { if a.config.SSH.AddUserPrincipal == "" { return SSHAddUserPrincipal } return a.config.SSH.AddUserPrincipal } func (a *Authority) getAddUserCommand(principal string) string { var cmd string if a.config.SSH.AddUserCommand == "" { cmd = SSHAddUserCommand } else { cmd = a.config.SSH.AddUserCommand } return strings.ReplaceAll(cmd, "", principal) } func (a *Authority) callEnrichingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cr sshutil.CertificateRequest) (err error) { if webhookCtl == nil { return } defer func() { a.meter.SSHWebhookEnriched(prov, err) }() var whEnrichReq *webhook.RequestBody if whEnrichReq, err = webhook.NewRequestBody( webhook.WithSSHCertificateRequest(cr), ); err == nil { err = webhookCtl.Enrich(ctx, whEnrichReq) } return } func (a *Authority) callAuthorizingWebhooksSSH(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) (err error) { if webhookCtl == nil { return } defer func() { a.meter.SSHWebhookAuthorized(prov, err) }() var whAuthBody *webhook.RequestBody if whAuthBody, err = webhook.NewRequestBody( webhook.WithSSHCertificate(cert, certTpl), ); err == nil { err = webhookCtl.Authorize(ctx, whAuthBody) } return } 
================================================ FILE: authority/ssh_test.go ================================================ package authority import ( "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "encoding/base64" "errors" "fmt" "net/http" "reflect" "testing" "time" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "golang.org/x/crypto/ssh" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/templates" ) type sshTestModifier ssh.Certificate func (m sshTestModifier) Modify(cert *ssh.Certificate, _ provisioner.SignSSHOptions) error { if m.CertType != 0 { cert.CertType = m.CertType } if m.KeyId != "" { cert.KeyId = m.KeyId } if m.ValidAfter != 0 { cert.ValidAfter = m.ValidAfter } if m.ValidBefore != 0 { cert.ValidBefore = m.ValidBefore } if len(m.ValidPrincipals) != 0 { cert.ValidPrincipals = m.ValidPrincipals } if m.Permissions.CriticalOptions != nil { cert.Permissions.CriticalOptions = m.Permissions.CriticalOptions } if m.Permissions.Extensions != nil { cert.Permissions.Extensions = m.Permissions.Extensions } return nil } type sshTestCertModifier string func (m sshTestCertModifier) Modify(*ssh.Certificate, provisioner.SignSSHOptions) error { if m == "" { return nil } return errors.New(string(m)) } type sshTestCertValidator string func (v sshTestCertValidator) Valid(*ssh.Certificate, provisioner.SignSSHOptions) error { if v == "" { return nil } return errors.New(string(v)) } type sshTestOptionsValidator string func (v sshTestOptionsValidator) Valid(provisioner.SignSSHOptions) error { if v == "" { return nil } return errors.New(string(v)) } type sshTestOptionsModifier string func (m sshTestOptionsModifier) Modify(*ssh.Certificate, provisioner.SignSSHOptions) error { if m == "" { return nil } return errors.New(string(m)) } func 
TestAuthority_initHostOnly(t *testing.T) { auth := testAuthority(t, func(a *Authority) error { a.config.SSH.UserKey = "" return nil }) // Check keys keys, err := auth.GetSSHRoots(context.Background()) assert.NoError(t, err) assert.Len(t, 1, keys.HostKeys) assert.Len(t, 0, keys.UserKeys) // Check templates, user templates should work fine. _, err = auth.GetSSHConfig(context.Background(), "user", nil) assert.NoError(t, err) _, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{ "Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key", }) assert.Error(t, err) } func TestAuthority_initUserOnly(t *testing.T) { auth := testAuthority(t, func(a *Authority) error { a.config.SSH.HostKey = "" return nil }) // Check keys keys, err := auth.GetSSHRoots(context.Background()) assert.NoError(t, err) assert.Len(t, 0, keys.HostKeys) assert.Len(t, 1, keys.UserKeys) // Check templates, host templates should work fine. _, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{ "Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key", }) assert.NoError(t, err) _, err = auth.GetSSHConfig(context.Background(), "user", nil) assert.Error(t, err) } func TestAuthority_SignSSH(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) signer, err := ssh.NewSignerFromKey(signKey) assert.FatalError(t, err) userOptions := sshTestModifier{ CertType: ssh.UserCert, } hostOptions := sshTestModifier{ CertType: ssh.HostCert, } userTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", nil)) assert.FatalError(t, err) hostTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", nil)) assert.FatalError(t, err) userTemplateWithUser, err 
:= provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) hostTemplateWithHosts, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"foo.test.com", "bar.test.com"})) assert.FatalError(t, err) userTemplateWithRoot, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"root"})) assert.FatalError(t, err) hostTemplateWithExampleDotCom, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"example.com"})) assert.FatalError(t, err) badUserTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"127.0.0.1"})) assert.FatalError(t, err) badHostTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"host...local"})) assert.FatalError(t, err) userCustomTemplate, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{Template: `{ "type": "{{ .Type }}", "keyId": "{{ .KeyID }}", "principals": {{ append .Principals "admin" | toJson }}, "extensions": {{ set .Extensions "login@github.com" .Insecure.User.username | toJson }}, "criticalOptions": {{ toJson .CriticalOptions }} }`}, }, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) enrichTemplateData := sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"}) enrichTemplate, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{Template: `{ "type": "{{ .Type }}", "keyId": "{{ .KeyID }}", "principals": {{ toJson .Webhooks.people.role }}, "extensions": {{ set .Extensions "login@github.com" .Insecure.User.username | toJson }}, "criticalOptions": {{ toJson .CriticalOptions }} }`}, }, enrichTemplateData) assert.FatalError(t, err) 
userFailTemplate, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{Template: `{{ fail "an error"}}`}, }, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) userJSONSyntaxErrorTemplateFile, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{TemplateFile: "./testdata/templates/badjsonsyntax.tpl"}, }, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) userJSONValueErrorTemplateFile, err := provisioner.TemplateSSHOptions(&provisioner.Options{ SSH: &provisioner.SSHOptions{TemplateFile: "./testdata/templates/badjsonvalue.tpl"}, }, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"})) assert.FatalError(t, err) userPolicyOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ User: &policy.SSHUserCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ Principals: []string{"user"}, }, }, }, } userPolicy, err := policy.New(userPolicyOptions) assert.FatalError(t, err) hostPolicyOptions := &policy.Options{ SSH: &policy.SSHPolicyOptions{ Host: &policy.SSHHostCertificateOptions{ AllowedNames: &policy.SSHNameOptions{ DNSDomains: []string{"*.test.com"}, }, }, }, } hostPolicy, err := policy.New(hostPolicyOptions) assert.FatalError(t, err) now := time.Now() type fields struct { sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer policyEngine *policy.Engine } type args struct { key ssh.PublicKey opts provisioner.SignSSHOptions signOpts []provisioner.SignOption } type want struct { CertType uint32 Principals []string ValidAfter uint64 ValidBefore uint64 } tests := []struct { name string fields fields args args want want wantErr bool }{ {"ok-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, {"ok-host", fields{signer, signer, nil}, args{pub, 
provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, {"ok-user-only", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, {"ok-host-only", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-type-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-type-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, {"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, {"ok-opts-valid-after", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false}, {"ok-opts-valid-before", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false}, {"ok-cert-validator", fields{signer, signer, nil}, args{pub, 
provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false}, {"ok-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false}, {"ok-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false}, {"ok-enrich-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{enrichTemplate, userOptions, &mockWebhookController{templateData: enrichTemplateData, respData: map[string]any{"people": map[string]any{"role": []string{"user", "eng"}}}}}}, want{CertType: ssh.UserCert, Principals: []string{"user", "eng"}}, false}, {"ok-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, {"ok-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false}, {"fail-opts-type", fields{signer, signer, nil}, args{pub, 
provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true}, {"fail-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true}, {"fail-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true}, {"fail-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true}, {"fail-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true}, {"fail-bad-sign-options", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true}, {"fail-no-user-key", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true}, {"fail-no-host-key", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true}, {"fail-bad-type", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true}, {"fail-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true}, {"fail-custom-template-syntax-error-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true}, {"fail-custom-template-syntax-value-file", fields{signer, signer, nil}, 
args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true}, {"fail-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"root"}}, []provisioner.SignOption{userTemplateWithRoot}}, want{}, true}, {"fail-user-policy-with-host-cert", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true}, {"fail-user-policy-with-bad-user", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{badUserTemplate}}, want{}, true}, {"fail-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true}, {"fail-host-policy-with-user-cert", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{}, true}, {"fail-host-policy-with-bad-host", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{badHostTemplate}}, want{}, true}, {"fail-enriching-webhooks", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, &mockWebhookController{enrichErr: provisioner.ErrWebhookDenied}}}, want{}, true}, {"fail-authorizing-webhooks", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, &mockWebhookController{authorizeErr: provisioner.ErrWebhookDenied}}}, want{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := 
testAuthority(t) a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey a.policyEngine = tt.fields.policyEngine got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr) return } if err == nil && assert.NotNil(t, got) { assert.Equals(t, tt.want.CertType, got.CertType) assert.Equals(t, tt.want.Principals, got.ValidPrincipals) assert.Equals(t, tt.want.ValidAfter, got.ValidAfter) assert.Equals(t, tt.want.ValidBefore, got.ValidBefore) assert.NotNil(t, got.Key) assert.NotNil(t, got.Nonce) assert.NotEquals(t, 0, got.Serial) assert.NotNil(t, got.Signature) assert.NotNil(t, got.SignatureKey) } }) } } func TestAuthority_SignSSHAddUser(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) pub, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) signer, err := ssh.NewSignerFromKey(signKey) assert.FatalError(t, err) type fields struct { sshCAUserCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer addUserPrincipal string addUserCommand string } type args struct { key ssh.PublicKey subject *ssh.Certificate } type want struct { CertType uint32 Principals []string ValidAfter uint64 ValidBefore uint64 ForceCommand string } now := time.Now() validCert := &ssh.Certificate{ CertType: ssh.UserCert, ValidPrincipals: []string{"user"}, ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), } validWant := want{ CertType: ssh.UserCert, Principals: []string{"provisioner"}, ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()), ForceCommand: "sudo useradd -m user; nc -q0 localhost 22", } tests := []struct { name string fields fields args args want want wantErr bool }{ {"ok", fields{signer, signer, "", 
""}, args{pub, validCert}, validWant, false}, {"ok-no-host-key", fields{signer, nil, "", ""}, args{pub, validCert}, validWant, false}, {"ok-custom-principal", fields{signer, signer, "my-principal", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "sudo useradd -m user; nc -q0 localhost 22"}, false}, {"ok-custom-command", fields{signer, signer, "", "foo "}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"provisioner"}, ForceCommand: "foo user user"}, false}, {"ok-custom-principal-and-command", fields{signer, signer, "my-principal", "foo "}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"user"}}}, want{CertType: ssh.UserCert, Principals: []string{"my-principal"}, ForceCommand: "foo user user"}, false}, {"fail-no-user-key", fields{nil, signer, "", ""}, args{pub, validCert}, want{}, true}, {"fail-no-user-cert", fields{signer, signer, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"foo"}}}, want{}, true}, {"fail-no-principals", fields{signer, signer, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, want{}, true}, {"fail-many-principals", fields{signer, signer, "", ""}, args{pub, &ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"foo", "bar"}}}, want{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey a.config.SSH = &SSHConfig{ AddUserPrincipal: tt.fields.addUserPrincipal, AddUserCommand: tt.fields.addUserCommand, } got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSHAddUser() error = %v, 
wantErr %v", err, tt.wantErr) return } if err == nil && assert.NotNil(t, got) { assert.Equals(t, tt.want.CertType, got.CertType) assert.Equals(t, tt.want.Principals, got.ValidPrincipals) assert.Equals(t, tt.args.subject.ValidPrincipals[0]+"-"+tt.want.Principals[0], got.KeyId) assert.Equals(t, tt.want.ValidAfter, got.ValidAfter) assert.Equals(t, tt.want.ValidBefore, got.ValidBefore) assert.Equals(t, map[string]string{"force-command": tt.want.ForceCommand}, got.CriticalOptions) assert.Equals(t, nil, got.Extensions) assert.NotNil(t, got.Key) assert.NotNil(t, got.Nonce) assert.NotEquals(t, 0, got.Serial) assert.NotNil(t, got.Signature) assert.NotNil(t, got.SignatureKey) } }) } } func TestAuthority_GetSSHRoots(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) user, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) host, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) type fields struct { sshCAUserCerts []ssh.PublicKey sshCAHostCerts []ssh.PublicKey } tests := []struct { name string fields fields want *SSHKeys wantErr bool }{ {"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false}, {"nil", fields{}, &SSHKeys{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserCerts = tt.fields.sshCAUserCerts a.sshCAHostCerts = tt.fields.sshCAHostCerts got, err := a.GetSSHRoots(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHRoots() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetSSHFederation(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) user, err := 
ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) host, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) type fields struct { sshCAUserFederatedCerts []ssh.PublicKey sshCAHostFederatedCerts []ssh.PublicKey } tests := []struct { name string fields fields want *SSHKeys wantErr bool }{ {"ok", fields{[]ssh.PublicKey{user}, []ssh.PublicKey{host}}, &SSHKeys{UserKeys: []ssh.PublicKey{user}, HostKeys: []ssh.PublicKey{host}}, false}, {"nil", fields{}, &SSHKeys{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts got, err := a.GetSSHFederation(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHFederation() = %v, want %v", got, tt.want) } }) } } func TestAuthority_GetSSHConfig(t *testing.T) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) user, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) userSigner, err := ssh.NewSignerFromSigner(key) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.FatalError(t, err) host, err := ssh.NewPublicKey(key.Public()) assert.FatalError(t, err) hostSigner, err := ssh.NewSignerFromSigner(key) assert.FatalError(t, err) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) tmplConfig := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "known_host.tpl", Type: templates.File, TemplatePath: "./testdata/templates/known_hosts.tpl", Path: "ssh/known_host", Comment: "#"}, }, Host: []templates.Template{ {Name: "ca.tpl", Type: 
templates.File, TemplatePath: "./testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"}, }, }, Data: map[string]interface{}{ "Step": &templates.Step{ SSH: templates.StepSSH{ UserKey: user, HostKey: host, }, }, }, } userOutput := []templates.Output{ {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte(fmt.Sprintf("@cert-authority * %s %s", host.Type(), hostB64))}, } hostOutput := []templates.Output{ {Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte(user.Type() + " " + userB64)}, } tmplConfigWithUserData := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "include.tpl", Type: templates.File, TemplatePath: "./testdata/templates/include.tpl", Path: "ssh/include", Comment: "#"}, {Name: "config.tpl", Type: templates.File, TemplatePath: "./testdata/templates/config.tpl", Path: "ssh/config", Comment: "#"}, }, Host: []templates.Template{ { Name: "sshd_config.tpl", Type: templates.File, TemplatePath: "./testdata/templates/sshd_config.tpl", Path: "/etc/ssh/sshd_config", Comment: "#", RequiredData: []string{"Certificate", "Key"}, }, }, }, Data: map[string]interface{}{ "Step": &templates.Step{ SSH: templates.StepSSH{ UserKey: user, HostKey: host, }, }, }, } userOutputWithUserData := []templates.Output{ {Name: "include.tpl", Type: templates.File, Comment: "#", Path: "ssh/include", Content: []byte("Host *\n\tInclude /home/user/.step/ssh/config")}, {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("Match exec \"step ssh check-host %h\"\n\tUserKnownHostsFile /home/user/.step/ssh/known_hosts\n\tProxyCommand step ssh proxycommand %r %h %p\n")}, } hostOutputWithUserData := []templates.Output{ {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate 
/etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, } tmplConfigUserIncludes := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"}, }, }, Data: map[string]interface{}{ "Step": &templates.Step{ SSH: templates.StepSSH{ UserKey: user, HostKey: host, }, }, }, } userOutputEmptyData := []templates.Output{ {Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/ssh/config\"\n")}, } userOutputWithoutTemplateVersion := []templates.Output{ {Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")}, } userOutputWithTemplateVersion := []templates.Output{ {Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")}, } tmplConfigErr := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "error.tpl", Type: templates.File, TemplatePath: "./testdata/templates/error.tpl", Path: "ssh/error", Comment: "#"}, }, Host: []templates.Template{ {Name: "error.tpl", Type: templates.File, TemplatePath: "./testdata/templates/error.tpl", Path: "ssh/error", Comment: "#"}, }, }, } tmplConfigFail := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ {Name: "fail.tpl", Type: templates.File, TemplatePath: "./testdata/templates/fail.tpl", Path: "ssh/fail", Comment: "#"}, }, }, } type fields struct { templates *templates.Templates userSigner ssh.Signer hostSigner ssh.Signer } type args struct { typ string data map[string]string } tests := []struct { name string fields fields args args want []templates.Output wantErr bool }{ {"user", fields{tmplConfig, userSigner, hostSigner}, 
args{"user", nil}, userOutput, false}, {"user", fields{tmplConfig, userSigner, nil}, args{"user", nil}, userOutput, false}, {"host", fields{tmplConfig, userSigner, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false}, {"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false}, {"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"hostError", fields{tmplConfigErr, userSigner, hostSigner}, args{"host", map[string]string{"Function": "foo"}}, nil, true}, {"noTemplates", fields{nil, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"missingData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub"}}, nil, true}, {"failError", fields{tmplConfigFail, userSigner, hostSigner}, args{"user", nil}, nil, true}, } for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.templates = tt.fields.templates a.sshCAUserCertSignKey = tt.fields.userSigner a.sshCAHostCertSignKey = tt.fields.hostSigner got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHConfig() = %v, want %v", got, tt.want) } }) } } func TestAuthority_CheckSSHHost(t *testing.T) { type fields struct { exists bool err error } type args struct { ctx context.Context principal string token string } tests := []struct { name string fields fields args args want bool wantErr bool }{ {"true", fields{true, nil}, args{context.Background(), "foo.internal.com", ""}, true, false}, {"false", fields{false, nil}, args{context.Background(), "foo.internal.com", ""}, false, false}, {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true}, {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.Background(), "foo.internal.com", ""}, false, true}, {"internal", fields{false, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true}, {"internal", fields{true, fmt.Errorf("an error")}, args{context.Background(), "foo.internal.com", ""}, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := testAuthority(t) a.db = &db.MockAuthDB{ MIsSSHHost: func(_ string) (bool, error) { return tt.fields.exists, tt.fields.err }, } got, err := a.CheckSSHHost(tt.args.ctx, tt.args.principal, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Authority.CheckSSHHost() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("Authority.CheckSSHHost() = %v, want %v", got, tt.want) } }) } } func TestSSHConfig_Validate(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0) 
assert.FatalError(t, err)

// Validate is invoked directly on tt.sshConfig, so the "nil" row calls the
// method on a nil *SSHConfig and expects no error.
tests := []struct {
	name      string
	sshConfig *SSHConfig
	wantErr   bool
}{
	{"nil", nil, false},
	{"ok", &SSHConfig{Keys: []*SSHPublicKey{{Type: "user", Key: key.Public()}}}, false},
	{"ok", &SSHConfig{Keys: []*SSHPublicKey{{Type: "host", Key: key.Public()}}}, false},
	{"badType", &SSHConfig{Keys: []*SSHPublicKey{{Type: "bad", Key: key.Public()}}}, true},
	// Uses the generated JWK itself instead of key.Public(); presumably
	// rejected because it is not a public key - confirm against Validate.
	{"badKey", &SSHConfig{Keys: []*SSHPublicKey{{Type: "user", Key: *key}}}, true},
}
for _, tt := range tests {
	t.Run(tt.name, func(t *testing.T) {
		if err := tt.sshConfig.Validate(); (err != nil) != tt.wantErr {
			t.Errorf("SSHConfig.Validate() error = %v, wantErr %v", err, tt.wantErr)
		}
	})
}
}

// TestAuthority_GetSSHBastion checks Authority.GetSSHBastion both when the
// bastion comes from the static configuration and when it is provided by a
// custom sshBastionFunc.
func TestAuthority_GetSSHBastion(t *testing.T) {
	bastion := &Bastion{
		Hostname: "bastion.local",
		Port:     "2222",
	}
	// fields are copied into the Authority under test.
	type fields struct {
		config         *Config
		sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
	}
	type args struct {
		user     string
		hostname string
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *Bastion
		wantErr bool
	}{
		{"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false},
		// Asking for the bastion host itself is expected to yield no bastion.
		{"bastion", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "bastion.local"}, nil, false},
		{"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false},
		// A bastion without hostname or port is expected to be treated as absent.
		{"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false},
		{"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
		{"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
		// No SSH configuration and no custom function: an error is expected.
		{"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The authority is built by hand with only the two fields
			// that the table provides.
			a := &Authority{
				config: tt.fields.config,
				sshBastionFunc:
tt.fields.sshBastionFunc,
}
got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname)
if (err != nil) != tt.wantErr {
	t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
	return
} else if err != nil {
	// An expected failure must additionally be usable as a
	// render.StatusCodedError.
	var sc render.StatusCodedError
	assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
}
if !reflect.DeepEqual(got, tt.want) {
	t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
}
})
}
}

// TestAuthority_GetSSHHosts checks Authority.GetSSHHosts when a custom
// sshGetHostsFunc is installed and when the hosts are read from the database
// through MGetSSHHostPrincipals.
func TestAuthority_GetSSHHosts(t *testing.T) {
	// a is the default authority, used by cases that do not set test.auth.
	a := testAuthority(t)

	type test struct {
		getHostsFunc func(context.Context, *x509.Certificate) ([]Host, error)
		auth         *Authority
		cert         *x509.Certificate
		cmp          func(got []Host) // checks the result of a successful call
		err          error            // expected error; its message is matched as a prefix
		code         int              // expected HTTP status code of the error
	}
	// Each entry is a generator so that the case is built inside its own subtest.
	tests := map[string]func(t *testing.T) *test{
		"fail/getHostsFunc-fail": func(t *testing.T) *test {
			return &test{
				getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]Host, error) {
					return nil, errors.New("force")
				},
				cert: &x509.Certificate{},
				err:  errors.New("getSSHHosts: force"),
				code: http.StatusInternalServerError,
			}
		},
		"ok/getHostsFunc-defined": func(t *testing.T) *test {
			hosts := []Host{
				{HostID: "1", Hostname: "foo"},
				{HostID: "2", Hostname: "bar"},
			}
			return &test{
				getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]Host, error) {
					return hosts, nil
				},
				cert: &x509.Certificate{},
				cmp: func(got []Host) {
					assert.Equals(t, got, hosts)
				},
			}
		},
		"fail/db-get-fail": func(t *testing.T) *test {
			return &test{
				auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
					MGetSSHHostPrincipals: func() ([]string, error) {
						return nil, errors.New("force")
					},
				})),
				cert: &x509.Certificate{},
				err:  errors.New("getSSHHosts: force"),
				code: http.StatusInternalServerError,
			}
		},
		// Without a custom function the principals stored in the database
		// are expected back as hosts that only carry a Hostname.
		"ok": func(t *testing.T) *test {
			return &test{
				auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
					MGetSSHHostPrincipals: func() ([]string, error) {
						return []string{"foo", "bar"}, nil
					},
				})),
				cert: &x509.Certificate{},
				cmp: func(got []Host) {
					assert.Equals(t, got, []Host{
{Hostname: "foo"},
{Hostname: "bar"},
})
},
}
},
}
for name, genTestCase := range tests {
	t.Run(name, func(t *testing.T) {
		tc := genTestCase(t)

		// Fall back to the shared authority when the case does not
		// provide its own.
		auth := tc.auth
		if auth == nil {
			auth = a
		}
		auth.sshGetHostsFunc = tc.getHostsFunc

		hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
		if err != nil {
			if assert.NotNil(t, tc.err) {
				// Failures must expose a status code and start with
				// the expected message.
				var sc render.StatusCodedError
				if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
					assert.Equals(t, sc.StatusCode(), tc.code)
				}
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		} else {
			if assert.Nil(t, tc.err) {
				tc.cmp(hosts)
			}
		}
	})
}
}

// TestAuthority_RekeySSH exercises Authority.RekeySSH with named cases. Each
// case can bring its own authority, the user and host signers, the
// certificate to rekey, the new public key, extra sign options and either a
// result comparison or the expected error and status code.
func TestAuthority_RekeySSH(t *testing.T) {
	// key provides the public key that the rekeyed certificate is issued for.
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	assert.FatalError(t, err)
	pub, err := ssh.NewPublicKey(key.Public())
	assert.FatalError(t, err)

	// signKey backs the SSH signer that the cases use as CA signing key.
	signKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	assert.FatalError(t, err)
	signer, err := ssh.NewSignerFromKey(signKey)
	assert.FatalError(t, err)

	userOptions := sshTestModifier{
		CertType: ssh.UserCert,
	}

	now := time.Now().UTC()

	// a is the default authority; its database reports every serial as not
	// revoked.
	a := testAuthority(t)
	a.db = &db.MockAuthDB{
		MIsSSHRevoked: func(sn string) (bool, error) {
			return false, nil
		},
	}

	type test struct {
		auth       *Authority
		userSigner ssh.Signer
		hostSigner ssh.Signer
		cert       *ssh.Certificate
		key        ssh.PublicKey
		signOpts   []provisioner.SignOption
		cmpResult  func(old, n *ssh.Certificate) // compares the old and the new certificate
		err        error
		code       int
	}
	tests := map[string]func(t *testing.T) *test{
		// The database reports the old certificate as revoked.
		"fail/is-revoked": func(t *testing.T) *test {
			auth := testAuthority(t)
			auth.db = &db.MockAuthDB{
				MIsSSHRevoked: func(sn string) (bool, error) {
					return true, nil
				},
			}
			return &test{
				auth:       auth,
				userSigner: signer,
				hostSigner: signer,
				cert: &ssh.Certificate{
					Serial:          1234567890,
					ValidAfter:      uint64(now.Unix()),
					ValidBefore:     uint64(now.Add(time.Hour).Unix()),
					CertType:        ssh.UserCert,
					ValidPrincipals: []string{"foo", "bar"},
					KeyId:           "foo",
				},
				key:      pub,
				signOpts: []provisioner.SignOption{},
				err: errors.New("authority.authorizeSSHCertificate: certificate has been
revoked"),
code: http.StatusUnauthorized,
}
},
// The revocation lookup itself fails.
"fail/is-revoked-error": func(t *testing.T) *test {
	auth := testAuthority(t)
	auth.db = &db.MockAuthDB{
		MIsSSHRevoked: func(sn string) (bool, error) {
			return false, errors.New("an error")
		},
	}
	return &test{
		auth:       auth,
		userSigner: signer,
		hostSigner: signer,
		cert: &ssh.Certificate{
			Serial:          1234567890,
			ValidAfter:      uint64(now.Unix()),
			ValidBefore:     uint64(now.Add(time.Hour).Unix()),
			CertType:        ssh.UserCert,
			ValidPrincipals: []string{"foo", "bar"},
			KeyId:           "foo",
		},
		key:      pub,
		signOpts: []provisioner.SignOption{},
		err:      errors.New("authority.authorizeSSHCertificate: an error"),
		code:     http.StatusInternalServerError,
	}
},
// userOptions is passed as an extra sign option and is expected to be
// rejected as an invalid option type.
"fail/opts-type": func(t *testing.T) *test {
	return &test{
		userSigner: signer,
		hostSigner: signer,
		key:        pub,
		signOpts:   []provisioner.SignOption{userOptions},
		err:        errors.New("rekeySSH; invalid extra option type"),
		code:       http.StatusInternalServerError,
	}
},
// A zero-value certificate has neither ValidAfter nor ValidBefore set.
"fail/old-cert-validAfter": func(t *testing.T) *test {
	return &test{
		userSigner: signer,
		hostSigner: signer,
		cert:       &ssh.Certificate{},
		key:        pub,
		signOpts:   []provisioner.SignOption{},
		err:        errors.New("cannot rekey a certificate without validity period"),
		code:       http.StatusBadRequest,
	}
},
// Only ValidAfter is set; ValidBefore is still zero.
"fail/old-cert-validBefore": func(t *testing.T) *test {
	return &test{
		userSigner: signer,
		hostSigner: signer,
		cert:       &ssh.Certificate{ValidAfter: uint64(now.Unix())},
		key:        pub,
		signOpts:   []provisioner.SignOption{},
		err:        errors.New("cannot rekey a certificate without validity period"),
		code:       http.StatusBadRequest,
	}
},
// A user certificate cannot be rekeyed without a user signer.
"fail/old-cert-no-user-key": func(t *testing.T) *test {
	return &test{
		userSigner: nil,
		hostSigner: signer,
		cert:       &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
		key:        pub,
		signOpts:   []provisioner.SignOption{},
		err:        errors.New("rekeySSH; user certificate signing is not enabled"),
		code:       http.StatusNotImplemented,
	}
},
// A host certificate cannot be rekeyed without a host signer.
"fail/old-cert-no-host-key": func(t *testing.T) *test {
	return &test{
userSigner: signer,
hostSigner: nil,
cert:       &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.HostCert},
key:        pub,
signOpts:   []provisioner.SignOption{},
err:        errors.New("rekeySSH; host certificate signing is not enabled"),
code:       http.StatusNotImplemented,
}
},
// CertType 0 is expected to be rejected as an unexpected certificate type.
"fail/unexpected-old-cert-type": func(t *testing.T) *test {
	return &test{
		userSigner: signer,
		hostSigner: signer,
		cert:       &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
		key:        pub,
		signOpts:   []provisioner.SignOption{},
		err:        errors.New("unexpected certificate type '0'"),
		code:       http.StatusBadRequest,
	}
},
// The certificate is not revoked, but storing the new certificate fails.
"fail/db-store": func(t *testing.T) *test {
	return &test{
		auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
			MIsSSHRevoked: func(sn string) (bool, error) {
				return false, nil
			},
			MStoreSSHCertificate: func(cert *ssh.Certificate) error {
				return errors.New("force")
			},
		})),
		userSigner: signer,
		hostSigner: nil,
		cert:       &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: ssh.UserCert},
		key:        pub,
		signOpts:   []provisioner.SignOption{},
		err:        errors.New("rekeySSH; error storing certificate in db: force"),
		code:       http.StatusInternalServerError,
	}
},
"ok": func(t *testing.T) *test {
	// The old certificate was valid for one hour, starting a day ago.
	va1 := now.Add(-24 * time.Hour)
	vb1 := now.Add(-23 * time.Hour)
	return &test{
		userSigner: signer,
		hostSigner: nil,
		cert: &ssh.Certificate{
			ValidAfter:      uint64(va1.Unix()),
			ValidBefore:     uint64(vb1.Unix()),
			CertType:        ssh.UserCert,
			ValidPrincipals: []string{"foo", "bar"},
			KeyId:           "foo",
		},
		key:      pub,
		signOpts: []provisioner.SignOption{},
		cmpResult: func(old, n *ssh.Certificate) {
			// Type, principals and key id must carry over unchanged.
			assert.Equals(t, n.CertType, old.CertType)
			assert.Equals(t, n.ValidPrincipals, old.ValidPrincipals)
			assert.Equals(t, n.KeyId, old.KeyId)

			// The new validity must start within five minutes of now.
			assert.True(t, n.ValidAfter > uint64(now.Add(-5*time.Minute).Unix()))
			assert.True(t, n.ValidAfter < uint64(now.Add(5*time.Minute).Unix()))

			l8r :=
now.Add(1 * time.Hour) assert.True(t, n.ValidBefore > uint64(l8r.Add(-5*time.Minute).Unix())) assert.True(t, n.ValidBefore < uint64(l8r.Add(5*time.Minute).Unix())) }, } }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc := genTestCase(t) auth := tc.auth if auth == nil { auth = a } a.sshCAUserCertSignKey = tc.userSigner a.sshCAHostCertSignKey = tc.hostSigner cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { assert.Equals(t, sc.StatusCode(), tc.code) } assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { tc.cmpResult(tc.cert, cert) } } }) } } func TestIsValidForAddUser(t *testing.T) { type args struct { cert *ssh.Certificate } tests := []struct { name string args args wantErr bool }{ {"ok", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"john"}}}, false}, {"ok oidc", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"jane", "jane@smallstep.com"}}}, false}, {"fail at", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"jane", "@smallstep.com"}}}, true}, {"fail host", args{&ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"john"}}}, true}, {"fail principals", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"john", "jane"}}}, true}, {"fail no principals", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, true}, {"fail extra principals", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"john", "jane", "doe"}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := IsValidForAddUser(tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("IsValidForAddUser() error = %v, wantErr %v", err, tt.wantErr) } }) } } 
================================================ FILE: authority/testdata/certs/badsig.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5 OI+cWOIc/IGwqZul/zEF5dani5ihOL7UwA== -----END CERTIFICATE REQUEST----- ================================================ FILE: authority/testdata/certs/foo.crt ================================================ -----BEGIN CERTIFICATE----- MIICIDCCAcagAwIBAgIQTL7pKDl8mFzRziotXbgjEjAKBggqhkjOPQQDAjAnMSUw IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyMjIy MjkyOVoXDTE5MDMyMzIyMjkyOVowHDEaMBgGA1UEAxMRZm9vLnNtYWxsc3RlcC5j b20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQbptfDonFaeUPiTr52wl9r3dcz greolwDRmsgyFgnr1EuKH56WRcgH1gjfL0pybFlO3PdgBukR4u+sveq343OAo4He MIHbMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH AwIwHQYDVR0OBBYEFP9pHiVlsx5mr4L2QirOb1G9Mo4jMB8GA1UdIwQYMBaAFKEe 9IdMyaHdURMjoJce7FN9HC9wMBwGA1UdEQQVMBOCEWZvby5zbWFsbHN0ZXAuY29t MEwGDCsGAQQBgqRkxihAAQQ8MDoCAQEECHN0ZXAtY2xpBCs0VUVMSng4ZTBhUzlt MENIM2ZaMEVCN0Q1YVVQSUNiNzU5ekFMSEZlanZjMAoGCCqGSM49BAMCA0gAMEUC IDxtNo1BX/4Sbf/+k1n+v//kh8ETr3clPvhjcyfvBIGTAiEAiT0kvbkPdCCnmHIw lhpgBwT5YReZzBwIYXyKyJXc07M= -----END CERTIFICATE----- ================================================ FILE: authority/testdata/certs/foo.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIIBBTCBqwIBADAbMRkwFwYDVQQDExBleGFtcGxlLmFjbWUuY29tMFkwEwYHKoZI zj0CAQYIKoZIzj0DAQcDQgAEk67TNST5NIdTgAutRDPfO0wa8CGAFjO7D1IoUlJI cOA48D4pSkar8v/l4dmKvxdiCNEaU8G0S16zI6dZoBGYAaAuMCwGCSqGSIb3DQEJ DjEfMB0wGwYDVR0RBBQwEoIQZXhhbXBsZS5hY21lLmNvbTAKBggqhkjOPQQDAgNJ 
ADBGAiEAiuk3HO986dhTjxNBBUsw7sorDWSX2+6sWvYsYkDfJrQCIQDS32JVK0P5 OI+cWOIc/IGwqZul/zEF5dani5ihOR7UwA== -----END CERTIFICATE REQUEST----- ================================================ FILE: authority/testdata/certs/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBxTCCAWugAwIBAgIQfkaUVV4yh8gQZa/EsIECpTAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFzbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODA4MTgxOTAxNDZaFw0yODA4 MTUxOTAxNDZaMCQxIjAgBgNVBAMTGXNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATfuJeqP7FHMaVq1uMU9avTZ9JW+VzL NS7rJrkhs41j38Oru9UpZWCqXr5uNNioqElRLB6xRfTPd1mCNctQoTUpo4GGMIGD MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU1rz/ojOuK6vKFH4Qi8mwpXtv OzkwHwYDVR0jBBgwFoAUjoa24fWu22FipFrMI2rjBkzVDhEwCgYIKoZIzj0EAwID SAAwRQIgWDEWlEaleq5ubnm21k4Zc+agdh1pwOQ41uS4GxXEY5ACIQDkY+MvTLLe uBjherwnoVagcftox+GmRwgFpLJC/gRLzw== -----END CERTIFICATE----- ================================================ FILE: authority/testdata/certs/provisioner-not-found.crt ================================================ -----BEGIN CERTIFICATE----- MIICTDCCAfGgAwIBAgIQH1JRmbStwdCkiuqf7SM8dzAKBggqhkjOPQQDAjAnMSUw IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyMjIz MDI0OVoXDTE5MDMyMzIzMDI0OVowLjEsMCoGA1UEAxMjcHJvdmlzaW9uZXItbm90 LWZvdW5kLnNtYWxsc3RlcC5jb20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARw DOZEqgkXXY0PqnEvl5ADX4xXMDNgX4lraK8SP48Ljo3vUn5FqARjKaBgPLfowFkQ gnjsAbBPwzt4SUWZW0ybo4H3MIH0MA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAU BggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0OBBYEFDLOyjWD26FV5lfIwPqegYIt PdmSMB8GA1UdIwQYMBaAFKEe9IdMyaHdURMjoJce7FN9HC9wMC4GA1UdEQQnMCWC I3Byb3Zpc2lvbmVyLW5vdC1mb3VuZC5zbWFsbHN0ZXAuY29tMFMGDCsGAQQBgqRk xihAAQRDMEECAQEED2dpZkBleGFtcGxlLmNvbQQrRVdDQThsdFJCdEwxN2VFQS1I dW4zQWtCN0sxTERhUXItNkdvdXc3RXBoVTAKBggqhkjOPQQDAgNJADBGAiEAkaHR dE706JI8eLio/AqPbH8A/qK1INlbKbrkZ03K5wECIQCqTGY4TYopJqLYt3HkQeTy cJfHpuPfIzvpT8X0h3zlwQ== -----END CERTIFICATE----- 
================================================ FILE: authority/testdata/certs/renew-disabled.crt ================================================ -----BEGIN CERTIFICATE----- MIICJjCCAcygAwIBAgIQWhtLLuWC1foM7eq1jefkGDAKBggqhkjOPQQDAjAnMSUw IwYDVQQDExxFeGFtcGxlIEluYy4gSW50ZXJtZWRpYXRlIENBMB4XDTE5MDMyNzIz Mzk0M1oXDTE5MDMyODIzMzk0M1owHDEaMBgGA1UEAxMRYmF6LnNtYWxsc3RlcC5j b20wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATxC77uJiCHgxIoctoHZbEauQwV 1FStMSKnEQwNkm88GD0HVUcz3g9OEHJbdMuY7VJjefD2NfdMil2N1jOw8VzMo4Hk MIHhMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH AwIwHQYDVR0OBBYEFCEoFgFtPV3v3YsJt7uYoz7GgChEMB8GA1UdIwQYMBaAFKEe 9IdMyaHdURMjoJce7FN9HC9wMBwGA1UdEQQVMBOCEWJhei5zbWFsbHN0ZXAuY29t MFIGDCsGAQQBgqRkxihAAQRCMEACAQEEDnJlbmV3X2Rpc2FibGVkBCtJTWk5NFdC Tkk2Z1A1Y05IWGxaWU5VenZNakdkSHlCUm1Gb28tbENFYXFrMAoGCCqGSM49BAMC A0gAMEUCIQD1uGcIQYdEEtVtOFWZGhDk+QJTznH5C182k74Kj/Ns3QIgeNtqYeto Ur1bgN1pwEwjTyr4aNz+pUWHZhyodduVaCE= -----END CERTIFICATE----- ================================================ FILE: authority/testdata/certs/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBezCCASGgAwIBAgIQO4IwgRBrTxUIHlMdV9j5NDAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFzbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODA4MTgxOTAxNDZaFw0yODA4 MTUxOTAxNDZaMBwxGjAYBgNVBAMTEXNtYWxsc3RlcCBSb290IENBMFkwEwYHKoZI zj0CAQYIKoZIzj0DAQcDQgAEsA5O9AoNi/LslXQ2LRXrcWsTH3Urlyrw4RNLs4nK Fep6C/kRk83eD4eGr0Nfh0EYvUc4J6kYIQl62/bD2RjqCqNFMEMwDgYDVR0PAQH/ BAQDAgGmMBIGA1UdEwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFI6GtuH1rtthYqRa zCNq4wZM1Q4RMAoGCCqGSM49BAMCA0gAMEUCIQCiC+3oVXGMmUp1xeQ/vOwRWTat I96I5ms2tY8LA6z9RQIgdhiWiYwvvgIMlm57sGpol7evVuAibYH6CE3Mqn4jIE4= -----END CERTIFICATE----- ================================================ FILE: authority/testdata/certs/ssh_host_ca_key.pub ================================================ ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJj80EJXJR9vxefhdqOLSdzRzBw24t9YKPxb+eCYLf7BU50pJQnB/jK2ZM3qLFbieLaYjngZ86T4DzHxlPAnlAY= 
================================================ FILE: authority/testdata/certs/ssh_user_ca_key.pub ================================================ ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ8einS88ZaWpcTZG27D5N9JDKfGv0rzjDByLGsZzMsLYl3XcsN9IWKXB6b+5GJ3UaoZf/pFxzRzIdDIh7Ypw3Y= ================================================ FILE: authority/testdata/scep/intermediate.crt ================================================ -----BEGIN CERTIFICATE----- MIICZTCCAgugAwIBAgIQDPpOQXW7OLMFNR/+iOUdQjAKBggqhkjOPQQDAjAXMRUw EwYDVQQDEwxzY2VwdGVzdHJvb3QwHhcNMjEwNTA3MTUyMjU2WhcNMzEwNTA1MTUy MjU2WjAfMR0wGwYDVQQDExRzY2VwdGVzdGludGVybWVkaWF0ZTCCASIwDQYJKoZI hvcNAQEBBQADggEPADCCAQoCggEBAJTw49z9/MeZ/YeRO89ylMV3HnYpw52/Vs2G NsgYZRKiPz2RjixUp1iWRPoDONdlEOIAo0TALNOqz4EqJHB+FpBPBA1ZfwG/PlP/ eWFubNXLXIhZPSQOiHmL4dIw0FS/VFGZm1eqc9JPG/V2G6UaKvOa8+W9/nhi4eeL +/9nTwG4cTav9ltaVxQ55kcoJtMcvouYQ4oPSZ6yNuVYbFAoaqZnJqNQhxDvKsFH lHmvl28FAVM+otmEQNTm91uPwXuVusxEGn9N/d7M4iojCiMGg0S3luBS8IrGRI1Y bSKZvGsFnqUjHh2cLL1lqqo5+QvhvP9ut6+g8QGoq8NTc2yCRy8CAwEAAaNmMGQw DgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFGfO jTNTKTAyra+rAd/NL2ydarSFMB8GA1UdIwQYMBaAFKJr1p5QRfkHzewG3YEhPAtv FQNrMAoGCCqGSM49BAMCA0gAMEUCIEYK76FN9a/hWkMZcQ+NXyzGtfW+bnwsX3oN wT6jfyO0AiEAojTeSwf/H2l/E1lvsWJfNr8nOokWz+ZsbmMm5PU0Y+g= -----END CERTIFICATE----- ================================================ FILE: authority/testdata/scep/intermediate.key ================================================ -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,a54ae9388ce050f0a479a258d105fbb7 VkJp9kKZQ7O9Gy9orvXaO+klt4Lrqp9oSABSBy8yFcc3neniLixqcyZZ4+CC/OG2 TGTm4TiB9RBucrUyPwoxBraWbtTLHvS4nfPwr2feSTKoHDhSIr4Z1VMDF8PWiOSg vD3iYs5F1lz78hcB/SNdSZ2jm0ze84DFC2E49agWeiFLwezcLhXKQ2HHRJ6PmJv7 IYB7+aLw8cUis/eJquWv7vrmlnshXBXLOrDekNq/mGhdpUmguDNEGX/3yT+8QYRv yeCqLVWcfkQ7KkXAeet0tVPNGQQF0+yS80Hv2/LBcskhL467qa79Xm+QPbBbhsEB aa4rettMLEdxk3IB1dgXdWhdJ4zBD+RFjczJbQlZRfmPb8sR20V/xp3x9i+SLqKp 
seVoNF+LhLhEwJdMF23t2KpuiOShzC60ApjALN6/O2/XGCl0KQ+NzucX+wpirS6z d2XfEYpsUaUFEFraOwfGXxLmluRtS6Q3+0+NPgwVQuH7EE7KuoTDUoSrUG4OFjaq CeUeZv1IVf0sYqZQVRiMxxdoFBKUSgcaR1gzzLZgHeoZCGP0PewmZDfJMQ5rWe0D zYYIKXUg8+oytHsz+5pQ277psXsl7iApZu56s6w3rD45w/zBeEyBhyL5JMBP8Y6y 7ReaUGsoFu3WEvrMcOsN+0Vag/SdQsvEH0PGA/ltlrlhaHKq+4t/ZwP6WxUmnaVV JNtTWB8IqxtO0zbwK1owxjrO7t42K2isSryg/y2sQb4wgokoOzg1PqEaM8PIUvjl qkGhwrOz4lNNQ9b6Hgy81DpnXnJkRNY7B5yKi62TCc6K/DHrFs0fHKb9Qxac5KKf paasGWuEC5IP0lUyn81BmAVlfByBvnGmYiDmmGXLmfsyqtGFL9fpOl1Txq3/URfT f705lzeUt9r2BT5FJtV5lkTntRzjpi5QeRiJsvfXA7nCPZj2hoLWgIm/D/HRgfVR PIX1M7nxefRgES+T6UJNsBbGjSTgEVIPqVnyWs0JUyg4+KQ5VMU8g8SGA0dtnJyF 9JrZHy2OA/AYt/c96vJj4WdFvqw3kodIKOipBbKjBBGokaOTsLADFEYgOr51BfvO QmxGZoXsRpD4sBOAwW039Ka5uCfuBETa+XQPtlHailaRZLlK9cZaDlzQr/K9jAgM qOmZIKr3L8YPK3mQV+mWVYchPXTf+UyTFiWIt30z1JlyrTw1H+h62pV9f1QXDB6P FIlfWHUK2mohWqzBnv4zFRBTVUnUDC9ONT+cVLh0cvlbRt2yy2ZgR4+d6IGH6mRH VLgWAFpS3KS1/4NfwWRBaMvIBfqfXCzXSqVJsq7RlBSW/EBwe9TDXhcTzOLHjx4E vdp+hqyXT62cTd7oWe78BBw3xOgpQwQ8bUdhye0kXMLNpU9j70pA7CjLVoVsdzH6 n1EG7Mz/5NmXLy7LP8RuVU90mNQzNu8PFWtfjZ/jr3/OxoOc0Wx6mFykXkZbxKXI xOlaOnUHKnEmsCLnZUkIxEqwKo+RYWBRtKxYsS8x8TLXyFGEfHidI75ulZM7eAS8 jWtVNKbPIyal+nQMpqa/lKW6fiGGUVp0u2x3Pnd8luRCs2htBmXSB7W7mJ2SMCui -----END RSA PRIVATE KEY----- ================================================ FILE: authority/testdata/scep/root.crt ================================================ -----BEGIN CERTIFICATE----- MIIBczCCARigAwIBAgIRAImbSwfqrrI6p72t0b9f6l4wCgYIKoZIzj0EAwIwFzEV MBMGA1UEAxMMc2NlcHRlc3Ryb290MB4XDTIxMDUwNzE1MjEzMFoXDTMxMDUwNTE1 MjEzMFowFzEVMBMGA1UEAxMMc2NlcHRlc3Ryb290MFkwEwYHKoZIzj0CAQYIKoZI zj0DAQcDQgAE3fyAgJsDICrnXhhoxHKmXMHLoW0EM9bYiBmx1xRyol0Qa3SZMW43 rtTykqVP3HUA3rIrLdX106s9IFcA3eIYiaNFMEMwDgYDVR0PAQH/BAQDAgEGMBIG A1UdEwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFKJr1p5QRfkHzewG3YEhPAtvFQNr MAoGCCqGSM49BAMCA0kAMEYCIQDlXU695zKmSSfVPaPbM2cx7OlKr2n6NSyifatH 9zDITwIhAJUbbHzRJVgscxx+VSMqC2TkFvug6ryNu6kQIKNRwolr -----END CERTIFICATE----- 
================================================ FILE: authority/testdata/scep/root.key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,0ea78864d21de199d3a737e4337589c2 ZD3ggzw3eDYJp8NovTWgTxk6MagLutgU2UfwbYliAl7wKvVyzwkPytwRkyAXPBM6 jMfiAdq6wY2wEpc8OSfrvAXrGuYqlCakDhdMaFDPcS3K29VLl4BaO2X2Rfk55nBd ASBNREKVb+hg2HV22DO7r6t+EYXTSD6iO7EB90bvKdE= -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/foo.key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIJmnxm3N/ahRA2PWeZhRGJUKPU1lI44WcE4P1bynIim6oAoGCCqGSM49 AwEHoUQDQgAEG6bXw6JxWnlD4k6+dsJfa93XM4K3qJcA0ZrIMhYJ69RLih+elkXI B9YI3y9KcmxZTtz3YAbpEeLvrL3qt+NzgA== -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-128-CBC,856c18a6a0d6654d0e3aed6e3211a285 J1j4qQjtBsh6+ETLy/wlG4eSmQSkmxNQkyzt5zkpqFozS8yssAmTdkIFM6JGnQcc e0jGRXCy+Sx/vYQCY1uKR5FKlVpcT9I02r1nwgNHfd6zVmbQcXuYKvZQjJKLP27p gqluC9+nPA+NLJM/oP0GjNtQGasCc7oX6jYP4f1XFpw= -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/max_priv.jwk ================================================ { "protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IkpsNkZLWUp4V1UwdGRIbG9UanA1aGcifQ", "encrypted_key": "Qy0EP6u5-t0ggOweoc3Z1DCzR5BllsQi", "iv": "KUkviZ_TJKY4c0Mi", "ciphertext": "h7QZqgh_Fl2MZpmVy4h375yC0DORjB1dQULbNqc6MuUCW2iweWVRysFImUXiXMUKRarJC5adwWy1GhyAqUj6Xj1iOZDGLjYnqMETGWcI0rKDBwcSU7y7Y-2VYBRDSM2b7aWtTBfz3_kvEaw_vc3b5CEPJ86UlZc-jhKFRr_IcGWU-vXX5-bppoH15IPreyzi55YdjCll338lYpDecB_Paym3XBXotyd2iGXXUwoA1npEFwuyRMMEhl9zLp7rVcMW6A_32EzB8cZANEnA0C4FXGHQalY6u_2UeqxcC8_FuXPay6VIYODyRqcABvvkft3nwOcrI0pYDGBdk2w2Euk", 
"tag": "kOAFq3Tg6s4vBGS_plMpSw" } ================================================ FILE: authority/testdata/secrets/max_pub.jwk ================================================ { "use": "sig", "kty": "EC", "kid": "IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", "crv": "P-256", "alg": "ES256", "x": "XmaY0c9Cc_kjfn9uhimiDiKnKn00gmFzzsvElg4KxoE", "y": "ZhYcFQBqtErdC_pA7sOXrO7AboCEPIKP9Ik4CHJqANk" } ================================================ FILE: authority/testdata/secrets/provisioner-not-found.key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEILWLnE+pkh9QQ0CcM89sCBAWMEK7EtoJOmHvvFpugj2joAoGCCqGSM49 AwEHoUQDQgAEcAzmRKoJF12ND6pxL5eQA1+MVzAzYF+Ja2ivEj+PC46N71J+RagE YymgYDy36MBZEIJ47AGwT8M7eElFmVtMmw== -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/renew-disabled.key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIKmDvbNqeIZA9zssZxixJzAQBEUEBSyVnjCKvTWGMAd2oAoGCCqGSM49 AwEHoUQDQgAE8Qu+7iYgh4MSKHLaB2WxGrkMFdRUrTEipxEMDZJvPBg9B1VHM94P ThByW3TLmO1SY3nw9jX3TIpdjdYzsPFczA== -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/ssh_host_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIKZCgb5pTSSCbr/xcHCOkl9O6tQtZmNahr3Ap3/c2nBLoAoGCCqGSM49 AwEHoUQDQgAEmPzQQlclH2/F5+F2o4tJ3NHMHDbi31go/Fv54Jgt/sFTnSklCcH+ MrZkzeosVuJ4tpiOeBnzpPgPMfGU8CeUBg== -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/ssh_user_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIDuzykyPM6rLnSoyF4jnOpPAlyKZERqtaB8PTh179DMgoAoGCCqGSM49 AwEHoUQDQgAEnx6KdLzxlpalxNkbbsPk30kMp8a/SvOMMHIsaxnMywtiXddyw30h YpcHpv7kYndRqhl/+kXHNHMh0MiHtinDdg== -----END EC PRIVATE KEY----- ================================================ FILE: 
authority/testdata/secrets/step_cli_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-128-CBC,e2c9c7cdad45b5032f1990b929cf83fd k3Yd307VgDrdllCBGN7PP8dOMQvEAUkq1lYtyxAWa7u/DuxeDP7SYlDB+xEk/UL8 bgoYYCProydEElYFzGg8Z98WYAzbNoP2p6PPPpAhOZsxJjc5OfTHf/OQleR8PjD5 ryN4woGuq7Tiq5xritlyhluPc91ODqMsm4P98X1sPYA= -----END EC PRIVATE KEY----- ================================================ FILE: authority/testdata/secrets/step_cli_key.public ================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE7ZdAAMZCFU4XwgblI5RfZouBi8lY mF6DlZusNNnsbm+xCvYl3PAPZ+DKvKYERdazEPEU2OOo3riostJst0tn1g== -----END PUBLIC KEY----- ================================================ FILE: authority/testdata/secrets/step_cli_key_priv.jwk ================================================ { "protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ", "encrypted_key": "XaN9zcPQeWt49zchUDm34FECUTHfQTn_", "iv": "tmNHPQDqR3ebsWfd", "ciphertext": "9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw", "tag": "thPcx3t1AUcWuEygXIY3Fg" } ================================================ FILE: authority/testdata/secrets/step_cli_key_pub.jwk ================================================ { "use": "sig", "kty": "EC", "kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", "crv": "P-256", "alg": "ES256", "x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8", "y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y" } ================================================ FILE: authority/testdata/templates/badjsonsyntax.tpl 
================================================ { "subject": "badjson.localhost, } ================================================ FILE: authority/testdata/templates/badjsonvalue.tpl ================================================ { "subject": 1, "sans": {{ toJson .SANs }}, {{- if typeIs "*rsa.PublicKey" .Insecure.CR.PublicKey }} "keyUsage": ["keyEncipherment", "digitalSignature"], {{- else }} "keyUsage": ["digitalSignature"], {{- end }} "extKeyUsage": ["serverAuth", "clientAuth"] } ================================================ FILE: authority/testdata/templates/ca.tpl ================================================ {{.Step.SSH.UserKey.Type}} {{.Step.SSH.UserKey.Marshal | toString | b64enc}} {{- range .Step.SSH.UserFederatedKeys}} {{.Type}} {{.Marshal | toString | b64enc}} {{- end}} ================================================ FILE: authority/testdata/templates/config.tpl ================================================ Match exec "step ssh check-host %h" {{- if .User.User }} User {{.User.User}} {{- end }} {{- if or .User.GOOS "none" | eq "windows" }} UserKnownHostsFile {{.User.StepPath}}\ssh\known_hosts ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand %r %h %p {{- else }} UserKnownHostsFile {{.User.StepPath}}/ssh/known_hosts ProxyCommand step ssh proxycommand %r %h %p {{- end }} ================================================ FILE: authority/testdata/templates/error.tpl ================================================ Missing function {{Function}} ================================================ FILE: authority/testdata/templates/fail.tpl ================================================ {{ fail "This template will fail" }} ================================================ FILE: authority/testdata/templates/include.tpl ================================================ Host * {{- if or .User.GOOS "linux" | eq "windows" }} Include {{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config {{- else }} Include 
{{.User.StepPath}}/ssh/config {{- end }} ================================================ FILE: authority/testdata/templates/known_hosts.tpl ================================================ @cert-authority * {{.Step.SSH.HostKey.Type}} {{.Step.SSH.HostKey.Marshal | toString | b64enc}} {{- range .Step.SSH.HostFederatedKeys}} @cert-authority * {{.Type}} {{.Marshal | toString | b64enc}} {{- end}} ================================================ FILE: authority/testdata/templates/sshd_config.tpl ================================================ Match all TrustedUserCAKeys /etc/ssh/ca.pub HostCertificate /etc/ssh/{{.User.Certificate}} HostKey /etc/ssh/{{.User.Key}} ================================================ FILE: authority/testdata/templates/step_includes.tpl ================================================ {{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }} ================================================ FILE: authority/tls.go ================================================ package authority import ( "context" "crypto" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "math/big" "net" "net/http" "strings" "time" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/webhook" "github.com/smallstep/nosql/database" ) type tokenKey struct{} // NewTokenContext adds the given token to the context. 
func NewTokenContext(ctx context.Context, token string) context.Context { return context.WithValue(ctx, tokenKey{}, token) } // TokenFromContext returns the token from the given context. func TokenFromContext(ctx context.Context) (token string, ok bool) { token, ok = ctx.Value(tokenKey{}).(string) return } // GetTLSOptions returns the tls options configured. func (a *Authority) GetTLSOptions() *config.TLSOptions { return a.config.TLS } var ( oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35} oidSubjectKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 14} oidExtensionIssuingDistributionPoint = asn1.ObjectIdentifier{2, 5, 29, 28} ) func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc { return func(crt *x509.Certificate, _ provisioner.SignOptions) error { if def == nil { return errors.New("default ASN1DN template cannot be nil") } if len(crt.Subject.Country) == 0 && def.Country != "" { crt.Subject.Country = append(crt.Subject.Country, def.Country) } if len(crt.Subject.Organization) == 0 && def.Organization != "" { crt.Subject.Organization = append(crt.Subject.Organization, def.Organization) } if len(crt.Subject.OrganizationalUnit) == 0 && def.OrganizationalUnit != "" { crt.Subject.OrganizationalUnit = append(crt.Subject.OrganizationalUnit, def.OrganizationalUnit) } if len(crt.Subject.Locality) == 0 && def.Locality != "" { crt.Subject.Locality = append(crt.Subject.Locality, def.Locality) } if len(crt.Subject.Province) == 0 && def.Province != "" { crt.Subject.Province = append(crt.Subject.Province, def.Province) } if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" { crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress) } if crt.Subject.SerialNumber == "" && def.SerialNumber != "" { crt.Subject.SerialNumber = def.SerialNumber } if crt.Subject.CommonName == "" && def.CommonName != "" { crt.Subject.CommonName = def.CommonName } return nil } } // GetX509Signer returns a [crypto.Signer] 
implementation using the intermediate // key. // // This method can return a [NotImplementedError] if the CA is configured with a // Certificate Authority Service (CAS) that does not implement the // CertificateAuthoritySigner interface. // // [NotImplementedError]: https://pkg.go.dev/github.com/smallstep/certificates/cas/apiv1#NotImplementedError func (a *Authority) GetX509Signer() (crypto.Signer, error) { if s, ok := a.x509CAService.(casapi.CertificateAuthoritySigner); ok { return s.GetSigner() } return nil, casapi.NotImplementedError{} } // Sign creates a signed certificate from a certificate signing request. It // creates a new context.Context, and calls into SignWithContext. // // Deprecated: Use authority.SignWithContext with an actual context.Context. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...) } // SignWithContext creates a signed certificate from a certificate signing // request, taking the provided context.Context. func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { chain, prov, err := a.signX509(ctx, csr, signOpts, extraOpts...) 
a.meter.X509Signed(chain, prov, err) return chain, err } func (a *Authority) signX509(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, provisioner.Interface, error) { var ( certOptions []x509util.Option certValidators []provisioner.CertificateValidator certModifiers []provisioner.CertificateModifier certEnforcers []provisioner.CertificateEnforcer ) opts := []any{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} if err := csr.CheckSignature(); err != nil { return nil, nil, errs.ApplyOptions( errs.BadRequestErr(err, "invalid certificate request"), opts..., ) } // Set backdate with the configured value signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration var ( prov provisioner.Interface pInfo *casapi.ProvisionerInfo attData *provisioner.AttestationData webhookCtl webhookController ) for _, op := range extraOpts { switch k := op.(type) { // Capture current provisioner case provisioner.Interface: prov = k pInfo = &casapi.ProvisionerInfo{ ID: prov.GetID(), Type: prov.GetType().String(), Name: prov.GetName(), } // Adds new options to NewCertificate case provisioner.CertificateOptions: certOptions = append(certOptions, k.Options(signOpts)...) // Validate the given certificate request. case provisioner.CertificateRequestValidator: if err := k.Valid(csr); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error validating certificate request"), opts..., ) } // Validates the unsigned certificate template. case provisioner.CertificateValidator: certValidators = append(certValidators, k) // Modifies a certificate before validating it. case provisioner.CertificateModifier: certModifiers = append(certModifiers, k) // Modifies a certificate after validating it. case provisioner.CertificateEnforcer: certEnforcers = append(certEnforcers, k) // Extra information from ACME attestations. 
case provisioner.AttestationData: attData = &k // Capture the provisioner's webhook controller case webhookController: webhookCtl = k default: return nil, prov, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]any{k}, opts...)...) } } if err := a.callEnrichingWebhooksX509(ctx, prov, webhookCtl, attData, csr); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "%s", err.Error()), errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts), ) } crt, err := x509util.NewCertificate(csr, certOptions...) if err != nil { var te *x509util.TemplateError switch { case errors.As(err, &te): return nil, prov, errs.ApplyOptions( errs.BadRequestErr(err, "%s", err.Error()), errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts), ) case strings.HasPrefix(err.Error(), "error unmarshaling certificate"): // explicitly check for unmarshaling errors, which are most probably caused by JSON template (syntax) errors return nil, prov, errs.InternalServerErr(templatingError(err), errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts), errs.WithMessage("error applying certificate template"), ) default: return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign", opts...) } } // Certificate modifiers before validation leaf := crt.GetCertificate() // Set default subject if err := withDefaultASN1DN(a.config.AuthorityConfig.Template).Modify(leaf, signOpts); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., ) } for _, m := range certModifiers { if err := m.Modify(leaf, signOpts); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., ) } } // Certificate validation. 
for _, v := range certValidators { if err := v.Valid(leaf, signOpts); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error validating certificate"), opts..., ) } } // Certificate modifiers after validation for _, m := range certEnforcers { if err = m.Enforce(leaf); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., ) } } // Process injected modifiers after validation for _, m := range a.x509Enforcers { if err = m.Enforce(leaf); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., ) } } // Check if authority is allowed to sign the certificate if err = a.isAllowedToSignX509Certificate(leaf); err != nil { var ee *errs.Error if errors.As(err, &ee) { return nil, prov, errs.ApplyOptions(ee, opts...) } return nil, prov, errs.InternalServerErr(err, errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts), errs.WithMessage("error creating certificate"), ) } // Send certificate to webhooks for authorization if err := a.callAuthorizingWebhooksX509(ctx, prov, webhookCtl, crt, leaf, attData); err != nil { return nil, prov, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., ) } // Sign certificate lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ Template: leaf, CSR: csr, Lifetime: lifetime, Backdate: signOpts.Backdate, Provisioner: pInfo, }) if err != nil { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...) } chain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) // Wrap provisioner with extra information, if not nil if prov != nil { prov = wrapProvisioner(prov, attData) } // Store certificate in the db. 
if err := a.storeCertificate(prov, chain); err != nil && !errors.Is(err, db.ErrNotImplemented) { return nil, prov, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error storing certificate in db", opts...) } return chain, prov, nil } // isAllowedToSignX509Certificate checks if the Authority is allowed // to sign the X.509 certificate. func (a *Authority) isAllowedToSignX509Certificate(cert *x509.Certificate) error { if err := a.constraintsEngine.ValidateCertificate(cert); err != nil { return err } return a.policyEngine.IsX509CertificateAllowed(cert) } // AreSANsAllowed evaluates the provided sans against the // authority X.509 policy. func (a *Authority) AreSANsAllowed(_ context.Context, sans []string) error { return a.policyEngine.AreSANsAllowed(sans) } // Renew creates a new Certificate identical to the old certificate, except with // a validity window that begins 'now'. func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { return a.RenewContext(context.Background(), oldCert, nil) } // Rekey is used for rekeying and renewing based on the public key. If the // public key is 'nil' then it's assumed that the cert should be renewed using // the existing public key. If the public key is not 'nil' then it's assumed // that the cert should be rekeyed. // // For both Rekey and Renew all other attributes of the new certificate should // match the old certificate. The exceptions are 'AuthorityKeyId' (which may // have changed), 'SubjectKeyId' (different in case of rekey), and // 'NotBefore/NotAfter' (the validity duration of the new certificate should be // equal to the old one, but starting 'now'). func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { return a.RenewContext(context.Background(), oldCert, pk) } // RenewContext creates a new certificate identical to the old one, but it can // optionally replace the public key with the given one. 
When running on RA // mode, it can only renew a certificate using a renew token instead. // // For both rekey and renew operations, all other attributes of the new // certificate should match the old certificate. The exceptions are // 'AuthorityKeyId' (which may have changed), 'SubjectKeyId' (different in case // of rekey), and 'NotBefore/NotAfter' (the validity duration of the new // certificate should be equal to the old one, but starting 'now'). func (a *Authority) RenewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { chain, prov, err := a.renewContext(ctx, oldCert, pk) if pk == nil { a.meter.X509Renewed(chain, prov, err) } else { a.meter.X509Rekeyed(chain, prov, err) } return chain, err } func (a *Authority) renewContext(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, provisioner.Interface, error) { isRekey := (pk != nil) opts := []errs.Option{ errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()), } // Check step provisioner extensions prov, err := a.authorizeRenew(ctx, oldCert) if err != nil { return nil, prov, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } // Durations backdate := a.config.AuthorityConfig.Backdate.Duration duration := oldCert.NotAfter.Sub(oldCert.NotBefore) lifetime := duration - backdate // Create new certificate from previous values. // Issuer, NotBefore, NotAfter and SubjectKeyId will be set by the CAS. 
newCert := &x509.Certificate{ RawSubject: oldCert.RawSubject, KeyUsage: oldCert.KeyUsage, UnhandledCriticalExtensions: oldCert.UnhandledCriticalExtensions, ExtKeyUsage: oldCert.ExtKeyUsage, UnknownExtKeyUsage: oldCert.UnknownExtKeyUsage, BasicConstraintsValid: oldCert.BasicConstraintsValid, IsCA: oldCert.IsCA, MaxPathLen: oldCert.MaxPathLen, MaxPathLenZero: oldCert.MaxPathLenZero, OCSPServer: oldCert.OCSPServer, IssuingCertificateURL: oldCert.IssuingCertificateURL, PermittedDNSDomainsCritical: oldCert.PermittedDNSDomainsCritical, PermittedEmailAddresses: oldCert.PermittedEmailAddresses, DNSNames: oldCert.DNSNames, EmailAddresses: oldCert.EmailAddresses, IPAddresses: oldCert.IPAddresses, URIs: oldCert.URIs, PermittedDNSDomains: oldCert.PermittedDNSDomains, ExcludedDNSDomains: oldCert.ExcludedDNSDomains, PermittedIPRanges: oldCert.PermittedIPRanges, ExcludedIPRanges: oldCert.ExcludedIPRanges, ExcludedEmailAddresses: oldCert.ExcludedEmailAddresses, PermittedURIDomains: oldCert.PermittedURIDomains, ExcludedURIDomains: oldCert.ExcludedURIDomains, CRLDistributionPoints: oldCert.CRLDistributionPoints, PolicyIdentifiers: oldCert.PolicyIdentifiers, } if isRekey { newCert.PublicKey = pk } else { newCert.PublicKey = oldCert.PublicKey } // Copy all extensions except: // // 1. Authority Key Identifier - This one might be different if we rotate // the intermediate certificate and it will cause a TLS bad certificate // error. // // 2. Subject Key Identifier, if rekey - For rekey, SubjectKeyIdentifier // extension will be calculated for the new public key by // x509util.CreateCertificate() for _, ext := range oldCert.Extensions { if ext.Id.Equal(oidAuthorityKeyIdentifier) { continue } if ext.Id.Equal(oidSubjectKeyIdentifier) && isRekey { newCert.SubjectKeyId = nil continue } newCert.ExtraExtensions = append(newCert.ExtraExtensions, ext) } // Check if the certificate is allowed to be renewed, name constraints might // change over time. 
// // TODO(hslatman,maraino): consider adding policies too and consider if // RenewSSH should check policies. if err = a.constraintsEngine.ValidateCertificate(newCert); err != nil { var ee *errs.Error switch { case errors.As(err, &ee): return nil, prov, errs.StatusCodeError(ee.StatusCode(), err, opts...) default: return nil, prov, errs.InternalServerErr(err, errs.WithKeyVal("serialNumber", oldCert.SerialNumber.String()), errs.WithMessage("error renewing certificate"), ) } } // The token can optionally be in the context. If the CA is running in RA // mode, this can be used to renew a certificate. token, _ := TokenFromContext(ctx) resp, err := a.x509CAService.RenewCertificate(&casapi.RenewCertificateRequest{ Template: newCert, Lifetime: lifetime, Backdate: backdate, Token: token, }) if err != nil { return nil, prov, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } chain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) if err = a.storeRenewedCertificate(oldCert, chain); err != nil && !errors.Is(err, db.ErrNotImplemented) { return nil, prov, errs.StatusCodeError(http.StatusInternalServerError, err, opts...) } return chain, prov, nil } // storeCertificate allows to use an extension of the db.AuthDB interface that // can log the full chain of certificates. // // TODO: at some point we should replace the db.AuthDB interface to implement // `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(*x509.Certificate) error`. func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error { type certificateChainStorer interface { StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error } type certificateChainSimpleStorer interface { StoreCertificateChain(...*x509.Certificate) error } // Store certificate in linkedca switch s := a.adminDB.(type) { case certificateChainStorer: return s.StoreCertificateChain(prov, fullchain...) 
case certificateChainSimpleStorer: return s.StoreCertificateChain(fullchain...) } // Store certificate in local db switch s := a.db.(type) { case certificateChainStorer: return s.StoreCertificateChain(prov, fullchain...) case certificateChainSimpleStorer: return s.StoreCertificateChain(fullchain...) case db.CertificateStorer: return s.StoreCertificate(fullchain[0]) default: return nil } } // storeRenewedCertificate allows to use an extension of the db.AuthDB interface // that can log if a certificate has been renewed or rekeyed. // // TODO: at some point we should implement this in the standard implementation. func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error { type renewedCertificateChainStorer interface { StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error } // Store certificate in linkedca if s, ok := a.adminDB.(renewedCertificateChainStorer); ok { return s.StoreRenewedCertificate(oldCert, fullchain...) } // Store certificate in local db switch s := a.db.(type) { case renewedCertificateChainStorer: return s.StoreRenewedCertificate(oldCert, fullchain...) case db.CertificateStorer: return s.StoreCertificate(fullchain[0]) default: return nil } } // RevokeOptions are the options for the Revoke API. type RevokeOptions struct { Serial string Reason string ReasonCode int PassiveOnly bool MTLS bool ACME bool Crt *x509.Certificate OTT string } // Revoke revokes a certificate. // // NOTE: Only supports passive revocation - prevent existing certificates from // being renewed. // // TODO: Add OCSP and CRL support. 
func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error { opts := []interface{}{ errs.WithKeyVal("serialNumber", revokeOpts.Serial), errs.WithKeyVal("reasonCode", revokeOpts.ReasonCode), errs.WithKeyVal("reason", revokeOpts.Reason), errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly), errs.WithKeyVal("MTLS", revokeOpts.MTLS), errs.WithKeyVal("ACME", revokeOpts.ACME), errs.WithKeyVal("context", provisioner.MethodFromContext(ctx).String()), } if revokeOpts.MTLS || revokeOpts.ACME { opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw))) } else { opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT)) } rci := &db.RevokedCertificateInfo{ Serial: revokeOpts.Serial, ReasonCode: revokeOpts.ReasonCode, Reason: revokeOpts.Reason, MTLS: revokeOpts.MTLS, ACME: revokeOpts.ACME, RevokedAt: time.Now().UTC(), } // For X509 CRLs attempt to get the expiration date of the certificate. if provisioner.MethodFromContext(ctx) == provisioner.RevokeMethod { if revokeOpts.Crt == nil { cert, err := a.db.GetCertificate(revokeOpts.Serial) if err == nil { rci.ExpiresAt = cert.NotAfter } } else { rci.ExpiresAt = revokeOpts.Crt.NotAfter } } // If not mTLS nor ACME, then get the TokenID of the token. if !revokeOpts.MTLS && !revokeOpts.ACME { token, err := jose.ParseSigned(revokeOpts.OTT) if err != nil { return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke; error parsing token", opts...) } // Get claims w/out verification. var claims Claims if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...) } // Verify that the serial in the token matches the serial from the request. 
if revokeOpts.Serial != claims.Subject { return errs.ApplyOptions( errs.Forbidden( "request serial number %q and token subject %q do not match", revokeOpts.Serial, claims.Subject, ), opts..., ) } // This method will also validate the audiences for JWK provisioners. p, err := a.LoadProvisionerByToken(token, &claims.Claims) if err != nil { return err } rci.ProvisionerID = p.GetID() rci.TokenID, err = p.GetTokenID(revokeOpts.OTT) if err != nil && !errors.Is(err, provisioner.ErrAllowTokenReuse) { return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke; could not get ID for token") } opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID), errs.WithKeyVal("tokenID", rci.TokenID), ) } else if p, err := a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil { // Load the Certificate provisioner if one exists. rci.ProvisionerID = p.GetID() opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) } failRevoke := func(err error) error { switch { case errors.Is(err, db.ErrNotImplemented): return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) case errors.Is(err, db.ErrAlreadyExists): return errs.ApplyOptions( errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial), opts..., ) default: return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } } if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod { if err := a.revokeSSH(nil, rci); err != nil { return failRevoke(err) } } else { // Revoke an X.509 certificate using CAS. If the certificate is not // provided we will try to read it from the db. If the read fails we // won't throw an error as it will be responsibility of the CAS // implementation to require a certificate. 
var revokedCert *x509.Certificate if revokeOpts.Crt != nil { revokedCert = revokeOpts.Crt } else if rci.Serial != "" { revokedCert, _ = a.db.GetCertificate(rci.Serial) } // CAS operation, note that SoftCAS (default) is a noop. // The revoke happens when this is stored in the db. _, err := a.x509CAService.RevokeCertificate(&casapi.RevokeCertificateRequest{ Certificate: revokedCert, SerialNumber: rci.Serial, Reason: rci.Reason, ReasonCode: rci.ReasonCode, PassiveOnly: revokeOpts.PassiveOnly, }) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } // Save as revoked in the Db. if err := a.revoke(revokedCert, rci); err != nil { return failRevoke(err) } // Generate a new CRL so CRL requesters will always get an up-to-date // CRL whenever they request it. if a.config.CRL.IsEnabled() && a.config.CRL.GenerateOnRevoke { if err := a.GenerateCertificateRevocationList(); err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } } } return nil } func (a *Authority) revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error { if lca, ok := a.adminDB.(interface { Revoke(*x509.Certificate, *db.RevokedCertificateInfo) error }); ok { return lca.Revoke(crt, rci) } return a.db.Revoke(rci) } func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateInfo) error { if lca, ok := a.adminDB.(interface { RevokeSSH(*ssh.Certificate, *db.RevokedCertificateInfo) error }); ok { return lca.RevokeSSH(crt, rci) } return a.db.RevokeSSH(rci) } // CertificateRevocationListInfo contains a CRL in DER format and associated metadata. 
type CertificateRevocationListInfo struct { Number int64 ExpiresAt time.Time Duration time.Duration Data []byte } // GetCertificateRevocationList will return the currently generated CRL from the DB, or a not implemented // error if the underlying AuthDB does not support CRLs func (a *Authority) GetCertificateRevocationList() (*CertificateRevocationListInfo, error) { if !a.config.CRL.IsEnabled() { return nil, errs.Wrap(http.StatusNotFound, errors.Errorf("Certificate Revocation Lists are not enabled"), "authority.GetCertificateRevocationList") } crlDB, ok := a.db.(db.CertificateRevocationListDB) if !ok { return nil, errs.Wrap(http.StatusNotImplemented, errors.Errorf("Database does not support Certificate Revocation Lists"), "authority.GetCertificateRevocationList") } crlInfo, err := crlDB.GetCRL() if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetCertificateRevocationList") } return &CertificateRevocationListInfo{ Number: crlInfo.Number, ExpiresAt: crlInfo.ExpiresAt, Duration: crlInfo.Duration, Data: crlInfo.DER, }, nil } // GenerateCertificateRevocationList generates a DER representation of a signed CRL and stores it in the // database. 
Returns nil if CRL generation has been disabled in the config func (a *Authority) GenerateCertificateRevocationList() error { if !a.config.CRL.IsEnabled() { return nil } crlDB, ok := a.db.(db.CertificateRevocationListDB) if !ok { return errors.Errorf("Database does not support CRL generation") } // some CAS may not implement the CRLGenerator interface, so check before we proceed caCRLGenerator, ok := a.x509CAService.(casapi.CertificateAuthorityCRLGenerator) if !ok { return errors.Errorf("CA does not support CRL Generation") } // use a mutex to ensure only one CRL is generated at a time to avoid // concurrency issues a.crlMutex.Lock() defer a.crlMutex.Unlock() crlInfo, err := crlDB.GetCRL() if err != nil && !database.IsErrNotFound(err) { return errors.Wrap(err, "could not retrieve CRL from database") } now := time.Now().Truncate(time.Second).UTC() revokedList, err := crlDB.GetRevokedCertificates() if err != nil { return errors.Wrap(err, "could not retrieve revoked certificates list from database") } // Number is a monotonically increasing integer (essentially the CRL version // number) that we need to keep track of and increase every time we generate // a new CRL var bn big.Int if crlInfo != nil { bn.SetInt64(crlInfo.Number + 1) } // Convert our database db.RevokedCertificateInfo types into the x509 // representation ready for the CAS to sign it var revokedCertificateEntries []x509.RevocationListEntry skipExpiredTime := now.Add(-config.DefaultCRLExpiredDuration) for _, revokedCert := range *revokedList { // skip expired certificates if !revokedCert.ExpiresAt.IsZero() && revokedCert.ExpiresAt.Before(skipExpiredTime) { continue } var sn big.Int sn.SetString(revokedCert.Serial, 10) revokedCertificateEntries = append(revokedCertificateEntries, x509.RevocationListEntry{ SerialNumber: &sn, RevocationTime: revokedCert.RevokedAt, ReasonCode: revokedCert.ReasonCode, }) } var updateDuration time.Duration if a.config.CRL.CacheDuration != nil { updateDuration = 
a.config.CRL.CacheDuration.Duration
	} else if crlInfo != nil {
		updateDuration = crlInfo.Duration
	}

	// Create a RevocationList representation ready for the CAS to sign
	// TODO: allow SignatureAlgorithm to be specified?
	revocationList := x509.RevocationList{
		SignatureAlgorithm:        0,
		RevokedCertificateEntries: revokedCertificateEntries,
		Number:                    &bn,
		ThisUpdate:                now,
		NextUpdate:                now.Add(updateDuration),
	}

	// Set CRL IDP to config item, otherwise, leave as default
	var fullName string
	if a.config.CRL.IDPurl != "" {
		fullName = a.config.CRL.IDPurl
	} else {
		fullName = a.config.Audience("/1.0/crl")[0]
	}

	// Add distribution point.
	//
	// Note that this is currently using the port 443 by default.
	if b, err := marshalDistributionPoint(fullName); err == nil {
		revocationList.ExtraExtensions = []pkix.Extension{
			{Id: oidExtensionIssuingDistributionPoint, Critical: true, Value: b},
		}
	}

	certificateRevocationList, err := caCRLGenerator.CreateCRL(&casapi.CreateCRLRequest{RevocationList: &revocationList})
	if err != nil {
		return errors.Wrap(err, "could not create CRL")
	}

	// Create a new db.CertificateRevocationListInfo, which stores the new Number we just generated, the
	// expiry time, duration, and the DER-encoded CRL
	newCRLInfo := db.CertificateRevocationListInfo{
		Number:    bn.Int64(),
		ExpiresAt: revocationList.NextUpdate,
		DER:       certificateRevocationList.CRL,
		Duration:  updateDuration,
	}

	// Store the CRL in the database ready for retrieval by api endpoints
	err = crlDB.StoreCRL(&newCRLInfo)
	if err != nil {
		return errors.Wrap(err, "could not store CRL in database")
	}

	return nil
}

// GetTLSCertificate issues a fresh leaf certificate for the CA's own HTTPS
// server.
//
// A new default key is generated, a certificate request is built from the
// configured common name and DNS names, the resulting template is checked
// against the authority's name constraints, and the certificate is signed by
// the configured X.509 CA service. Every failure is returned wrapped as an
// internal server error tagged with "authority.GetTLSCertificate".
func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
	// fail converts any error into the (nil, wrapped error) pair returned here.
	fail := func(err error) (*tls.Certificate, error) {
		return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetTLSCertificate")
	}

	// Fresh key pair for the server certificate; it must be able to sign the CSR.
	key, err := keyutil.GenerateDefaultKey()
	if err != nil {
		return fail(err)
	}
	signer, isSigner := key.(crypto.Signer)
	if !isSigner {
		return fail(errors.New("private key is not a crypto.Signer"))
	}

	// Build the SAN list. A bracketed name such as "[::1]" that parses as an IP
	// is replaced by the plain IP string; anything else is passed through.
	sans := make([]string, 0, len(a.config.DNSNames))
	for _, name := range a.config.DNSNames {
		if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
			if ip := net.ParseIP(name[1 : len(name)-1]); ip != nil {
				name = ip.String()
			}
		}
		sans = append(sans, name)
	}

	// Certificate request for the configured common name and SANs.
	csr, err := x509util.CreateCertificateRequest(a.config.CommonName, sans, signer)
	if err != nil {
		return fail(err)
	}

	// The certificate template is derived directly from the request.
	tpl, err := x509util.NewCertificate(csr)
	if err != nil {
		return fail(err)
	}

	// Validity window: backdated one minute, valid for 24 hours.
	issuedAt := time.Now()
	leafTpl := tpl.GetCertificate()
	leafTpl.NotBefore = issuedAt.Add(-1 * time.Minute)
	leafTpl.NotAfter = issuedAt.Add(24 * time.Hour)

	// Policy and constraints need the names on the template itself; at this
	// point they are only present in the extra extension.
	leafTpl.DNSNames = csr.DNSNames
	leafTpl.IPAddresses = csr.IPAddresses
	leafTpl.EmailAddresses = csr.EmailAddresses
	leafTpl.URIs = csr.URIs

	// Refuse to issue when name constraints do not allow the server names.
	err = a.constraintsEngine.ValidateCertificate(leafTpl)
	if err != nil {
		return fail(err)
	}

	// Lifetime requested from the CA service:
	//   - anything other than a StepCAS RA: 24h
	//   - StepCAS RA: left at zero so the upstream CA's provisioner decides.
	var lifetime time.Duration
	if casapi.TypeOf(a.x509CAService) != casapi.StepCAS {
		lifetime = 24 * time.Hour
	}

	resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
		Template:       leafTpl,
		CSR:            csr,
		Lifetime:       lifetime,
		Backdate:       1 * time.Minute,
		IsCAServerCert: true,
	})
	if err != nil {
		return fail(err)
	}

	// Concatenate leaf + chain as PEM so tls.X509KeyPair can consume them.
	chainPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: resp.Certificate.Raw,
	})
	for _, crt := range resp.CertificateChain {
		encoded := pem.EncodeToMemory(&pem.Block{
			Type:  "CERTIFICATE",
			Bytes: crt.Raw,
		})
		chainPEM = append(chainPEM, encoded...)
	}

	keyBlock, err := pemutil.Serialize(key)
	if err != nil {
		return fail(err)
	}

	pair, err := tls.X509KeyPair(chainPEM, pem.EncodeToMemory(keyBlock))
	if err != nil {
		return fail(err)
	}

	// Attach the already-parsed leaf certificate.
	pair.Leaf = resp.Certificate
	return &pair, nil
}

// RFC 5280, 5.2.5
type distributionPoint struct {
	DistributionPoint          distributionPointName `asn1:"optional,tag:0"`
	OnlyContainsUserCerts      bool                  `asn1:"optional,tag:1"`
	OnlyContainsCACerts        bool                  `asn1:"optional,tag:2"`
	OnlySomeReasons            asn1.BitString        `asn1:"optional,tag:3"`
	IndirectCRL                bool                  `asn1:"optional,tag:4"`
	OnlyContainsAttributeCerts bool                  `asn1:"optional,tag:5"`
}

type distributionPointName struct {
	FullName     []asn1.RawValue  `asn1:"optional,tag:0"`
	RelativeName pkix.RDNSequence `asn1:"optional,tag:1"`
}

/*
marshalDistributionPoint currently marshals only DP, citing spec https://datatracker.ietf.org/doc/html/rfc5280#section-5.2.5:

	That is, if onlyContainsUserCerts, onlyContainsCACerts,
	indirectCRL, and onlyContainsAttributeCerts are all FALSE,
	then either the distributionPoint field or the onlySomeReasons
	field MUST be present.
*/ func marshalDistributionPoint(fullName string) ([]byte, error) { return asn1.Marshal(distributionPoint{ DistributionPoint: distributionPointName{ FullName: []asn1.RawValue{ {Class: 2, Tag: 6, Bytes: []byte(fullName)}, }, }, }) } // templatingError tries to extract more information about the cause of // an error related to (most probably) malformed template data and adds // this to the error message. func templatingError(err error) error { cause := errors.Cause(err) var ( syntaxError *json.SyntaxError typeError *json.UnmarshalTypeError ) if errors.As(err, &syntaxError) { // offset is arguably not super clear to the user, but it's the best we can do here cause = fmt.Errorf("%w at offset %d", cause, syntaxError.Offset) } else if errors.As(err, &typeError) { // slightly rewriting the default error message to include the offset cause = fmt.Errorf("cannot unmarshal %s at offset %d into Go value of type %s", typeError.Value, typeError.Offset, typeError.Type) } return errors.Wrap(cause, "error applying certificate template") } func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) (err error) { if webhookCtl == nil { return } defer func() { a.meter.X509WebhookEnriched(prov, err) }() var attested *webhook.AttestationData if attData != nil { attested = &webhook.AttestationData{ PermanentIdentifier: attData.PermanentIdentifier, } } var whEnrichReq *webhook.RequestBody if whEnrichReq, err = webhook.NewRequestBody( webhook.WithX509CertificateRequest(csr), webhook.WithAttestationData(attested), ); err == nil { err = webhookCtl.Enrich(ctx, whEnrichReq) } return } func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provisioner.Interface, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) (err error) { if webhookCtl == nil { return } defer func() { 
a.meter.X509WebhookAuthorized(prov, err) }() var attested *webhook.AttestationData if attData != nil { attested = &webhook.AttestationData{ PermanentIdentifier: attData.PermanentIdentifier, } } var whAuthBody *webhook.RequestBody if whAuthBody, err = webhook.NewRequestBody( webhook.WithX509Certificate(cert, leaf), webhook.WithAttestationData(attested), ); err == nil { err = webhookCtl.Authorize(ctx, whAuthBody) } return } ================================================ FILE: authority/tls_test.go ================================================ package authority import ( "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/sha1" //nolint:gosec // used to create the Subject Key Identifier by RFC 5280 "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/pem" "errors" "fmt" "net/http" "reflect" "strings" "testing" "time" "go.step.sm/crypto/fingerprint" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/cas/softcas" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/errs" "github.com/smallstep/nosql/database" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var ( stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) 
)

// provisionerTypeJWK is the Type value written into (and expected from) the
// provisioner extension by the tests in this file.
const provisionerTypeJWK = 1

// stepProvisionerASN1 is the ASN.1 payload of the provisioner extension.
type stepProvisionerASN1 struct {
	Type         int
	Name         []byte
	CredentialID []byte
}

// certificateDurationEnforcer is a sign option that overwrites the validity
// window of a certificate with fixed values.
type certificateDurationEnforcer struct {
	NotBefore time.Time
	NotAfter  time.Time
}

// Enforce copies the enforcer's NotBefore/NotAfter onto cert and never fails.
func (m *certificateDurationEnforcer) Enforce(cert *x509.Certificate) error {
	cert.NotBefore, cert.NotAfter = m.NotBefore, m.NotAfter
	return nil
}

// certificateChainDB is a mock database whose StoreCertificateChain is
// delegated to a test-supplied function.
type certificateChainDB struct {
	db.MockAuthDB
	MStoreCertificateChain func(provisioner.Interface, ...*x509.Certificate) error
}

// StoreCertificateChain forwards to MStoreCertificateChain.
func (d *certificateChainDB) StoreCertificateChain(p provisioner.Interface, certs ...*x509.Certificate) error {
	return d.MStoreCertificateChain(p, certs...)
}

// getDefaultIssuer returns the last certificate of the SoftCAS chain
// configured on the authority.
func getDefaultIssuer(a *Authority) *x509.Certificate {
	chain := a.x509CAService.(*softcas.SoftCAS).CertificateChain
	return chain[len(chain)-1]
}

// getDefaultSigner returns the signer of the SoftCAS configured on the
// authority.
func getDefaultSigner(a *Authority) crypto.Signer {
	softCAS := a.x509CAService.(*softcas.SoftCAS)
	return softCAS.Signer
}

// generateCertificate builds a certificate for commonName and sans on a fresh
// P-256 key and then applies opts in order. Each option must be either a
// provisioner.CertificateModifierFunc (mutates the template) or a signerFunc
// (replaces the certificate with a signed one); any other type fails the test.
func generateCertificate(t *testing.T, commonName string, sans []string, opts ...interface{}) *x509.Certificate {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	csr, err := x509util.CreateCertificateRequest(commonName, sans, key)
	require.NoError(t, err)

	tpl, err := x509util.NewCertificate(csr)
	require.NoError(t, err)

	cert := tpl.GetCertificate()
	for _, opt := range opts {
		switch apply := opt.(type) {
		case provisioner.CertificateModifierFunc:
			require.NoError(t, apply.Modify(cert, provisioner.SignOptions{}))
		case signerFunc:
			cert, err = apply(cert, key.Public())
			require.NoError(t, err)
		default:
			require.Fail(t, "", "unknown type %T", opt)
		}
	}
	return cert
}

// generateRootCertificate creates a self-signed "TestRootCA" certificate on a
// fresh P-256 key and returns it together with its signer.
func generateRootCertificate(t *testing.T) (*x509.Certificate, crypto.Signer) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	csr, err := x509util.CreateCertificateRequest("TestRootCA", nil, key)
	require.NoError(t, err)

	tplData := x509util.CreateTemplateData("TestRootCA", nil)
	tpl, err := x509util.NewCertificate(csr, x509util.WithTemplate(x509util.DefaultRootTemplate, tplData))
	require.NoError(t, err)

	// Self-signed: the template acts as both subject and issuer.
	unsigned := tpl.GetCertificate()
	root, err := x509util.CreateCertificate(unsigned, unsigned, key.Public(), key)
	require.NoError(t, err)
	return root, key
}

// generateIntermidiateCertificate creates a "TestIntermediateCA" certificate on
// a fresh P-256 key, signed by the given issuer and signer, and returns it
// together with its own signer.
func generateIntermidiateCertificate(t *testing.T, issuer *x509.Certificate, signer crypto.Signer) (*x509.Certificate, crypto.Signer) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	csr, err := x509util.CreateCertificateRequest("TestIntermediateCA", nil, key)
	require.NoError(t, err)

	tplData := x509util.CreateTemplateData("TestIntermediateCA", nil)
	tpl, err := x509util.NewCertificate(csr, x509util.WithTemplate(x509util.DefaultRootTemplate, tplData))
	require.NoError(t, err)

	unsigned := tpl.GetCertificate()
	intermediate, err := x509util.CreateCertificate(unsigned, issuer, key.Public(), signer)
	require.NoError(t, err)
	return intermediate, key
}

// withSubject returns a modifier that replaces the certificate subject.
func withSubject(sub pkix.Name) provisioner.CertificateModifierFunc {
	return func(crt *x509.Certificate, _ provisioner.SignOptions) error {
		crt.Subject = sub
		return nil
	}
}

// withProvisionerOID returns a modifier that appends a non-critical
// provisioner extension of type JWK carrying the given name and key ID.
func withProvisionerOID(name, kid string) provisioner.CertificateModifierFunc {
	return func(crt *x509.Certificate, _ provisioner.SignOptions) error {
		payload, err := asn1.Marshal(stepProvisionerASN1{
			Type:         provisionerTypeJWK,
			Name:         []byte(name),
			CredentialID: []byte(kid),
		})
		if err != nil {
			return err
		}
		ext := pkix.Extension{
			Id:       stepOIDProvisioner,
			Critical: false,
			Value:    payload,
		}
		crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
		return nil
	}
}

// withNotBeforeNotAfter returns a modifier that sets the validity window.
func withNotBeforeNotAfter(notBefore, notAfter time.Time) provisioner.CertificateModifierFunc {
	return func(crt *x509.Certificate, _ provisioner.SignOptions) error {
		crt.NotBefore, crt.NotAfter = notBefore, notAfter
		return nil
	}
}

// signerFunc turns a certificate template and public key into a signed
// certificate.
type signerFunc func(crt *x509.Certificate, pub crypto.PublicKey) (*x509.Certificate, error)

// withSigner returns a signerFunc that signs with the given issuer and signer.
func withSigner(issuer *x509.Certificate, signer crypto.Signer) signerFunc {
	return func(crt *x509.Certificate, pub crypto.PublicKey) (*x509.Certificate,
error) {
		return x509util.CreateCertificate(crt, issuer, pub, signer)
	}
}

// getCSR creates, serializes and re-parses a certificate request signed with
// priv. The request starts with common name "smallstep test" and DNS name
// "test.smallstep.com"; opts may mutate it before it is signed.
func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateRequest)) *x509.CertificateRequest {
	t.Helper()
	_csr := &x509.CertificateRequest{
		Subject:  pkix.Name{CommonName: "smallstep test"},
		DNSNames: []string{"test.smallstep.com"},
	}
	for _, opt := range opts {
		opt(_csr)
	}
	csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv)
	require.NoError(t, err)
	csr, err := x509.ParseCertificateRequest(csrBytes)
	require.NoError(t, err)
	return csr
}

// setExtraExtsCSR returns a getCSR option that sets the request's extra
// extensions.
func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) {
	return func(csr *x509.CertificateRequest) {
		csr.ExtraExtensions = exts
	}
}

// generateSubjectKeyID returns the SHA-1 hash of the subjectPublicKey bit
// string of pub.
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
	b, err := x509.MarshalPKIXPublicKey(pub)
	if err != nil {
		return nil, fmt.Errorf("error marshaling public key: %w", err)
	}
	info := struct {
		Algorithm        pkix.AlgorithmIdentifier
		SubjectPublicKey asn1.BitString
	}{}
	if _, err = asn1.Unmarshal(b, &info); err != nil {
		return nil, fmt.Errorf("error unmarshaling public key: %w", err)
	}
	//nolint:gosec // used to create the Subject Key Identifier by RFC 5280
	hash := sha1.Sum(info.SubjectPublicKey.Bytes)
	return hash[:], nil
}

// basicConstraints is the ASN.1 payload of the basic constraints extension.
type basicConstraints struct {
	IsCA       bool `asn1:"optional"`
	MaxPathLen int  `asn1:"optional,default:-1"`
}

// testEnforcer is an enforcer whose behavior is supplied by the test.
type testEnforcer struct {
	enforcer func(*x509.Certificate) error
}

// Enforce calls the configured enforcer function, or succeeds if none is set.
func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
	if e.enforcer != nil {
		return e.enforcer(cert)
	}
	return nil
}

// assertHasPrefix asserts that s starts with p.
func assertHasPrefix(t *testing.T, s, p string) bool {
	t.Helper()
	return assert.True(t, strings.HasPrefix(s, p), "%q is not a prefix of %q", p, s)
}

// TestAuthority_SignWithContext exercises Authority.SignWithContext through a
// table of failure and success scenarios. Failure cases check the status code,
// the error prefix and the details attached to the error; success cases check
// the issued leaf (validity, subject, SANs, key identifiers, extensions) and
// the returned intermediate.
func TestAuthority_SignWithContext(t *testing.T) {
	pub, priv, err := keyutil.GenerateDefaultKeyPair()
	require.NoError(t, err)

	a := testAuthority(t)
	require.NoError(t, err)
	a.config.AuthorityConfig.Template = &ASN1DN{
		Country:       "Tazmania",
		Organization:  "Acme Co",
		Locality:      "Landscapes",
		Province:      "Sudden Cliffs",
		StreetAddress: "TNT",
		CommonName:    "test.smallstep.com",
	}

	nb := time.Now()
	signOpts := provisioner.SignOptions{
		NotBefore: provisioner.NewTimeDuration(nb),
		NotAfter:  provisioner.NewTimeDuration(nb.Add(time.Minute * 5)),
		Backdate:  1 * time.Minute,
	}

	// Create a token to get test extra opts.
	p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
	key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	require.NoError(t, err)
	token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
	require.NoError(t, err)
	ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
	extraOpts, err := a.Authorize(ctx, token)
	require.NoError(t, err)

	type signTest struct {
		auth            *Authority
		csr             *x509.CertificateRequest
		signOpts        provisioner.SignOptions
		extraOpts       []provisioner.SignOption
		notBefore       time.Time
		notAfter        time.Time
		extensionsCount int
		err             error
		code            int
	}
	tests := map[string]func(*testing.T) *signTest{
		"fail invalid signature": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			csr.Signature = []byte("foo")
			return &signTest{
				auth:      a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("invalid certificate request"),
				code:      http.StatusBadRequest,
			}
		},
		"fail invalid extra option": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			csr.Raw = []byte("foo")
			return &signTest{
				auth:      a,
				csr:       csr,
				extraOpts: append(extraOpts, "42"),
				signOpts:  signOpts,
				err:       errors.New("authority.Sign; invalid extra option type string"),
				code:      http.StatusInternalServerError,
			}
		},
		"fail merge default ASN1DN": func(t *testing.T) *signTest {
			_a := testAuthority(t)
			_a.config.AuthorityConfig.Template = nil
			csr := getCSR(t, priv)
			return &signTest{
				auth:      _a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("default ASN1DN template cannot be nil"),
				code:      http.StatusForbidden,
			}
		},
		"fail create cert": func(t *testing.T) *signTest {
			_a := testAuthority(t)
			_a.x509CAService.(*softcas.SoftCAS).Signer = nil
			csr := getCSR(t, priv)
			return &signTest{
				auth:      _a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("authority.Sign; error creating certificate"),
				code:      http.StatusInternalServerError,
			}
		},
		"fail provisioner duration claim": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			_signOpts := provisioner.SignOptions{
				NotBefore: provisioner.NewTimeDuration(nb),
				NotAfter:  provisioner.NewTimeDuration(nb.Add(time.Hour * 25)),
			}
			return &signTest{
				auth:      a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  _signOpts,
				err:       errors.New("requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
				code:      http.StatusForbidden,
			}
		},
		"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
			csr := getCSR(t, priv, func(csr *x509.CertificateRequest) {
				csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
			})
			return &signTest{
				auth:      a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
				code:      http.StatusForbidden,
			}
		},
		"fail rsa key too short": func(t *testing.T) *signTest {
			shortRSAKeyPEM := `-----BEGIN CERTIFICATE REQUEST-----
MIIBhDCB7gIBADAZMRcwFQYDVQQDEw5zbWFsbHN0ZXAgdGVzdDCBnzANBgkqhkiG
9w0BAQEFAAOBjQAwgYkCgYEA5JlgH99HvHHsCD6XTqqYj3bXU2oIlnYGoLVs7IJ4
k205rv5/YWky2gjdpIv0Tnaf3o57IJ891lB7GiyO5iHIEUv5N9dVzrdUboyzk2uZ
7JMMNB43CSLB2oNuwJjLeAM/yBzlhRnvpKjrNSfSV+cH54FXdnbFbcTFMStnjqKG
MeECAwEAAaAsMCoGCSqGSIb3DQEJDjEdMBswGQYDVR0RBBIwEIIOc21hbGxzdGVw
IHRlc3QwDQYJKoZIhvcNAQELBQADgYEAKwsbr8Zfcq05DgOoJ//cXMFK1SP8ktRU
N2++E8Ww0Tet9oyNRArqxxS/UyVio63D3wynzRAB25PFGpYG1cN4b81Gv/foFUT6
W5kR63lNVHBHgQmv5mA8YFsfrJHstaz5k727v2LMHEYIf5/3i16d5zhuxUoaPTYr
ZYtQ9Ot36qc=
-----END CERTIFICATE REQUEST-----`
			// pem.Decode signals failure with a nil block, not with an error.
			block, _ := pem.Decode([]byte(shortRSAKeyPEM))
			require.NotNil(t, block)
			csr, err := x509.ParseCertificateRequest(block.Bytes)
			require.NoError(t, err)

			return &signTest{
				auth:      a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"),
				code:      http.StatusForbidden,
			}
		},
		"fail store cert in db": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			_a := testAuthority(t)
			_a.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					return errors.New("force")
				},
			}
			return &signTest{
				auth:      _a,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("authority.Sign; error storing certificate in db: force"),
				code:      http.StatusInternalServerError,
			}
		},
		"fail custom template": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			testAuthority := testAuthority(t)
			p, ok := testAuthority.provisioners.Load("step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
			if !ok {
				t.Fatal("provisioner not found")
			}
			p.(*provisioner.JWK).Options = &provisioner.Options{
				X509: &provisioner.X509Options{Template: `{{ fail "fail message" }}`},
			}
			testExtraOpts, err := testAuthority.Authorize(ctx, token)
			require.NoError(t, err)
			testAuthority.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:      testAuthority,
				csr:       csr,
				extraOpts: testExtraOpts,
				signOpts:  signOpts,
				err:       errors.New("fail message"),
				code:      http.StatusBadRequest,
			}
		},
		"fail bad JSON syntax template file": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			testAuthority := testAuthority(t)
			p, ok := testAuthority.provisioners.Load("step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
			if !ok {
				t.Fatal("provisioner not found")
			}
			p.(*provisioner.JWK).Options = &provisioner.Options{
				X509: &provisioner.X509Options{
					TemplateFile: "./testdata/templates/badjsonsyntax.tpl",
				},
			}
			testExtraOpts, err := testAuthority.Authorize(ctx, token)
			require.NoError(t, err)
			testAuthority.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:      testAuthority,
				csr:       csr,
				extraOpts: testExtraOpts,
				signOpts:  signOpts,
				err:       errors.New("error applying certificate template: invalid character"),
				code:      http.StatusInternalServerError,
			}
		},
		"fail bad JSON value template file": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			testAuthority := testAuthority(t)
			p, ok := testAuthority.provisioners.Load("step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
			if !ok {
				t.Fatal("provisioner not found")
			}
			p.(*provisioner.JWK).Options = &provisioner.Options{
				X509: &provisioner.X509Options{
					TemplateFile: "./testdata/templates/badjsonvalue.tpl",
				},
			}
			testExtraOpts, err := testAuthority.Authorize(ctx, token)
			require.NoError(t, err)
			testAuthority.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:      testAuthority,
				csr:       csr,
				extraOpts: testExtraOpts,
				signOpts:  signOpts,
				err:       errors.New("error applying certificate template: cannot unmarshal"),
				code:      http.StatusInternalServerError,
			}
		},
		"fail with provisioner enforcer": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			aa := testAuthority(t)
			aa.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}

			return &signTest{
				auth: aa,
				csr:  csr,
				extraOpts: append(extraOpts, &testEnforcer{
					enforcer: func(crt *x509.Certificate) error { return fmt.Errorf("an error") },
				}),
				signOpts: signOpts,
				err:      errors.New("error creating certificate"),
				code:     http.StatusForbidden,
			}
		},
		"fail with custom enforcer": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			aa := testAuthority(t, WithX509Enforcers(&testEnforcer{
				enforcer: func(cert *x509.Certificate) error {
					return fmt.Errorf("an error")
				},
			}))
			aa.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:      aa,
				csr:       csr,
				extraOpts: extraOpts,
				signOpts:  signOpts,
				err:       errors.New("error creating certificate"),
				code:      http.StatusForbidden,
			}
		},
		"fail with policy": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			aa := testAuthority(t)
			aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			aa.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			options := &policy.Options{
				X509: &policy.X509PolicyOptions{
					DeniedNames: &policy.X509NameOptions{
						DNSDomains: []string{"test.smallstep.com"},
					},
				},
			}
			engine, err := policy.New(options)
			require.NoError(t, err)
			aa.policyEngine = engine
			return &signTest{
				auth:            aa,
				csr:             csr,
				extraOpts:       extraOpts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
				err:             errors.New("dns name \"test.smallstep.com\" not allowed"),
				code:            http.StatusForbidden,
			}
		},
		"fail enriching webhooks": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			csr.Raw = []byte("foo")
			return &signTest{
				auth:            a,
				csr:             csr,
				extensionsCount: 7,
				extraOpts: append(extraOpts, &mockWebhookController{
					enrichErr: provisioner.ErrWebhookDenied,
				}),
				signOpts: signOpts,
				err:      provisioner.ErrWebhookDenied,
				code:     http.StatusForbidden,
			}
		},
		"fail authorizing webhooks": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			csr.Raw = []byte("foo")
			return &signTest{
				auth:            a,
				csr:             csr,
				extensionsCount: 7,
				extraOpts: append(extraOpts, &mockWebhookController{
					authorizeErr: provisioner.ErrWebhookDenied,
				}),
				signOpts: signOpts,
				err:      provisioner.ErrWebhookDenied,
				code:     http.StatusForbidden,
			}
		},
		"fail with cnf": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)

			auth := testAuthority(t)
			auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			auth.db = &db.MockAuthDB{
				MUseToken: func(id, tok string) (bool, error) {
					return true, nil
				},
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					assert.Equal(t, []string{"test.smallstep.com"}, crt.DNSNames)
					return nil
				},
			}

			// Create a token with cnf
			tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{
				"sans": []string{"test.smallstep.com"},
				"cnf":  map[string]any{"x5rt#S256": "bad-fingerprint"},
			})
			require.NoError(t, err)

			opts, err := auth.Authorize(ctx, tok)
			require.NoError(t, err)

			return &signTest{
				auth:      auth,
				csr:       csr,
				extraOpts: opts,
				signOpts:  signOpts,
				notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:  signOpts.NotAfter.Time().Truncate(time.Second),
				err:       errors.New(`certificate request fingerprint does not match "bad-fingerprint"`),
				code:      http.StatusForbidden,
			}
		},
		"ok": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			// NOTE(review): _a is configured but the case runs against a;
			// kept as in the original.
			_a := testAuthority(t)
			_a.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:            a,
				csr:             csr,
				extraOpts:       extraOpts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
			}
		},
		"ok with enforced modifier": func(t *testing.T) *signTest {
			bcExt := pkix.Extension{}
			bcExt.Id = asn1.ObjectIdentifier{2, 5, 29, 19}
			bcExt.Critical = false
			bcExt.Value, err = asn1.Marshal(basicConstraints{IsCA: true, MaxPathLen: 4})
			require.NoError(t, err)

			csr := getCSR(t, priv, setExtraExtsCSR([]pkix.Extension{
				bcExt,
				{Id: stepOIDProvisioner, Value: []byte("foo")},
				{Id: []int{1, 1, 1}, Value: []byte("bar")}}))
			now := time.Now().UTC()
			//nolint:gocritic
			enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
				NotBefore: now,
				NotAfter:  now.Add(365 * 24 * time.Hour),
			})
			// NOTE(review): _a is configured but the case runs against a;
			// kept as in the original.
			_a := testAuthority(t)
			_a.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:            a,
				csr:             csr,
				extraOpts:       enforcedExtraOptions,
				signOpts:        signOpts,
				notBefore:       now.Truncate(time.Second),
				notAfter:        now.Add(365 * 24 * time.Hour).Truncate(time.Second),
				extensionsCount: 6,
			}
		},
		"ok with custom template": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			testAuthority := testAuthority(t)
			testAuthority.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			p, ok := testAuthority.provisioners.Load("step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
			if !ok {
				t.Fatal("provisioner not found")
			}
			p.(*provisioner.JWK).Options = &provisioner.Options{
				X509: &provisioner.X509Options{Template: `{
					"subject": {{toJson .Subject}},
					"dnsNames": {{ toJson .Insecure.CR.DNSNames }},
					"keyUsage": ["digitalSignature"],
					"extKeyUsage": ["serverAuth","clientAuth"]
				}`},
			}
			testExtraOpts, err := testAuthority.Authorize(ctx, token)
			require.NoError(t, err)
			testAuthority.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			return &signTest{
				auth:            testAuthority,
				csr:             csr,
				extraOpts:       testExtraOpts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
			}
		},
		"ok with enriching webhook": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			testAuthority := testAuthority(t)
			testAuthority.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			p, ok := testAuthority.provisioners.Load("step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
			if !ok {
				t.Fatal("provisioner not found")
			}
			p.(*provisioner.JWK).Options = &provisioner.Options{
				X509: &provisioner.X509Options{Template: `{
					"subject": {"commonName": {{ toJson .Webhooks.people.role }} },
					"dnsNames": {{ toJson .Insecure.CR.DNSNames }},
					"keyUsage": ["digitalSignature"],
					"extKeyUsage": ["serverAuth","clientAuth"]
				}`},
			}
			testExtraOpts, err := testAuthority.Authorize(ctx, token)
			require.NoError(t, err)
			testAuthority.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			// Swap the real webhook controller for a mock that supplies the
			// data the template reads.
			for i, o := range testExtraOpts {
				if wc, ok := o.(*provisioner.WebhookController); ok {
					testExtraOpts[i] = &mockWebhookController{
						templateData: wc.TemplateData,
						respData:     map[string]any{"people": map[string]any{"role": "smallstep test"}},
					}
				}
			}
			return &signTest{
				auth:            testAuthority,
				csr:             csr,
				extraOpts:       testExtraOpts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
			}
		},
		"ok/csr with no template critical SAN extension": func(t *testing.T) *signTest {
			csr := getCSR(t, priv, func(csr *x509.CertificateRequest) {
				csr.Subject = pkix.Name{}
			}, func(csr *x509.CertificateRequest) {
				csr.DNSNames = []string{"foo", "bar"}
			})
			now := time.Now().UTC()
			enforcedExtraOptions := []provisioner.SignOption{&certificateDurationEnforcer{
				NotBefore: now,
				NotAfter:  now.Add(365 * 24 * time.Hour),
			}}
			_a := testAuthority(t)
			_a.config.AuthorityConfig.Template = &ASN1DN{}
			_a.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, pkix.Name{}, crt.Subject)
					return nil
				},
			}
			return &signTest{
				auth:            _a,
				csr:             csr,
				extraOpts:       enforcedExtraOptions,
				signOpts:        provisioner.SignOptions{},
				notBefore:       now.Truncate(time.Second),
				notAfter:        now.Add(365 * 24 * time.Hour).Truncate(time.Second),
				extensionsCount: 5,
			}
		},
		"ok with custom enforcer": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			aa := testAuthority(t, WithX509Enforcers(&testEnforcer{
				enforcer: func(cert *x509.Certificate) error {
					cert.CRLDistributionPoints = []string{"http://ca.example.org/leaf.crl"}
					return nil
				},
			}))
			aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			aa.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					assert.Equal(t, []string{"http://ca.example.org/leaf.crl"}, crt.CRLDistributionPoints)
					return nil
				},
			}
			return &signTest{
				auth:            aa,
				csr:             csr,
				extraOpts:       extraOpts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 7,
			}
		},
		"ok with policy": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			aa := testAuthority(t)
			aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			aa.db = &db.MockAuthDB{
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					return nil
				},
			}
			options := &policy.Options{
				X509: &policy.X509PolicyOptions{
					AllowedNames: &policy.X509NameOptions{
						CommonNames: []string{"smallstep test"},
						DNSDomains:  []string{"*.smallstep.com"},
					},
				},
			}
			engine, err := policy.New(options)
			require.NoError(t, err)
			aa.policyEngine = engine
			return &signTest{
				auth:            aa,
				csr:             csr,
				extraOpts:       extraOpts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
			}
		},
		"ok with attestation data": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			aa := testAuthority(t)
			aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			aa.db = &certificateChainDB{
				MStoreCertificateChain: func(prov provisioner.Interface, certs ...*x509.Certificate) error {
					p, ok := prov.(attProvisioner)
					if assert.True(t, ok) {
						assert.Equal(t, &provisioner.AttestationData{
							PermanentIdentifier: "1234567890",
						}, p.AttestationData())
					}
					if assert.Len(t, certs, 2) {
						assert.Equal(t, "smallstep test", certs[0].Subject.CommonName)
						assert.Equal(t, "smallstep Intermediate CA", certs[1].Subject.CommonName)
					}
					return nil
				},
			}

			return &signTest{
				auth: aa,
				csr:  csr,
				extraOpts: append(extraOpts, provisioner.AttestationData{
					PermanentIdentifier: "1234567890",
				}),
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
			}
		},
		"ok with cnf": func(t *testing.T) *signTest {
			csr := getCSR(t, priv)
			fingerprint, err := fingerprint.New(csr.Raw, crypto.SHA256, fingerprint.Base64RawURLFingerprint)
			require.NoError(t, err)

			auth := testAuthority(t)
			auth.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
			auth.db = &db.MockAuthDB{
				MUseToken: func(id, tok string) (bool, error) {
					return true, nil
				},
				MStoreCertificate: func(crt *x509.Certificate) error {
					assert.Equal(t, "smallstep test", crt.Subject.CommonName)
					assert.Equal(t, []string{"test.smallstep.com"}, crt.DNSNames)
					return nil
				},
			}

			// Create a token with cnf
			tok, err := generateCustomToken("smallstep test", "step-cli", testAudiences.Sign[0], key, nil, map[string]any{
				"sans": []string{"test.smallstep.com"},
				"cnf":  map[string]any{"x5rt#S256": fingerprint},
			})
			require.NoError(t, err)

			opts, err := auth.Authorize(ctx, tok)
			require.NoError(t, err)

			return &signTest{
				auth:            auth,
				csr:             csr,
				extraOpts:       opts,
				signOpts:        signOpts,
				notBefore:       signOpts.NotBefore.Time().Truncate(time.Second),
				notAfter:        signOpts.NotAfter.Time().Truncate(time.Second),
				extensionsCount: 6,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...)
			if err != nil {
				if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
					assert.Nil(t, certChain)
					var sc render.StatusCodedError
					require.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equal(t, tc.code, sc.StatusCode())
					assertHasPrefix(t, err.Error(), tc.err.Error())

					var ctxErr *errs.Error
					require.True(t, errors.As(err, &ctxErr), "error is not of type *errs.Error")
					assert.Equal(t, tc.csr, ctxErr.Details["csr"])
					assert.Equal(t, tc.signOpts, ctxErr.Details["signOptions"])
				}
			} else {
				leaf := certChain[0]
				intermediate := certChain[1]
				if assert.Nil(t, tc.err) {
					assert.Equal(t, tc.notBefore, leaf.NotBefore)
					assert.Equal(t, tc.notAfter, leaf.NotAfter)
					tmplt := a.config.AuthorityConfig.Template
					if tc.csr.Subject.CommonName == "" {
						assert.Equal(t, pkix.Name{}, leaf.Subject)
					} else {
						assert.Equal(t, pkix.Name{
							Country:       []string{tmplt.Country},
							Organization:  []string{tmplt.Organization},
							Locality:      []string{tmplt.Locality},
							StreetAddress: []string{tmplt.StreetAddress},
							Province:      []string{tmplt.Province},
							CommonName:    "smallstep test",
						}.String(), leaf.Subject.String())
						assert.Equal(t, []string{"test.smallstep.com"}, leaf.DNSNames)
					}
					assert.Equal(t, intermediate.Subject, leaf.Issuer)
					assert.Equal(t, x509.ECDSAWithSHA256, leaf.SignatureAlgorithm)
					assert.Equal(t, x509.ECDSA, leaf.PublicKeyAlgorithm)
					assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, leaf.ExtKeyUsage)

					issuer := getDefaultIssuer(a)
					subjectKeyID, err := generateSubjectKeyID(pub)
					require.NoError(t, err)
					assert.Equal(t, subjectKeyID, leaf.SubjectKeyId)
					assert.Equal(t, issuer.SubjectKeyId, leaf.AuthorityKeyId)

					// Verify Provisioner OID
					found := 0
					for _, ext := range leaf.Extensions {
						switch {
						case ext.Id.Equal(stepOIDProvisioner):
							found++
							val := stepProvisionerASN1{}
							_, err := asn1.Unmarshal(ext.Value, &val)
							require.NoError(t, err)
							assert.Equal(t, provisionerTypeJWK, val.Type)
							assert.Equal(t, []byte(p.Name), val.Name)
							assert.Equal(t, []byte(p.Key.KeyID), val.CredentialID)

						// Basic Constraints
						case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})):
							val := basicConstraints{}
							_, err := asn1.Unmarshal(ext.Value, &val)
							require.NoError(t, err)
							assert.False(t, val.IsCA)
							assert.Equal(t, 0, val.MaxPathLen)

						// SAN extension
						case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 17})):
							if tc.csr.Subject.CommonName == "" {
								// Empty CSR subject test does not use any provisioner extensions.
								// So provisioner ID ext will be missing.
								found = 1
							}
						}
					}
					assert.Equal(t, 1, found)
					realIntermediate, err := x509.ParseCertificate(issuer.Raw)
					require.NoError(t, err)
					assert.Equal(t, realIntermediate, intermediate)
					assert.Len(t, leaf.Extensions, tc.extensionsCount)
				}
			}
		})
	}
}

func TestAuthority_Renew(t *testing.T) {
	a := testAuthority(t)
	a.config.AuthorityConfig.Template = &ASN1DN{
		Country:       "Tazmania",
		Organization:  "Acme Co",
		Locality:      "Landscapes",
		Province:      "Sudden Cliffs",
		StreetAddress: "TNT",
		CommonName:    "renew",
	}

	now := time.Now().UTC()
	nb1 := now.Add(-time.Minute * 7)
	na1 := now.Add(time.Hour)
	so := &provisioner.SignOptions{
		NotBefore: provisioner.NewTimeDuration(nb1),
		NotAfter:  provisioner.NewTimeDuration(na1),
	}

	issuer := getDefaultIssuer(a)
	signer := getDefaultSigner(a)

	cert := generateCertificate(t, "renew", []string{"test.smallstep.com", "test"},
		withNotBeforeNotAfter(so.NotBefore.Time(), so.NotAfter.Time()),
		withDefaultASN1DN(a.config.AuthorityConfig.Template),
		withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID),
		withSigner(issuer, signer))

	certExtraNames := generateCertificate(t, "renew", []string{"test.smallstep.com", "test"},
		withSubject(pkix.Name{
			CommonName: "renew",
			ExtraNames: []pkix.AttributeTypeAndValue{
				{Type: asn1.ObjectIdentifier{0, 9, 2342, 19200300, 100, 1, 25}, Value: "dc"},
			},
		}),
		withNotBeforeNotAfter(so.NotBefore.Time(), so.NotAfter.Time()),
		withDefaultASN1DN(a.config.AuthorityConfig.Template),
withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID), withSigner(issuer, signer)) certNoRenew := generateCertificate(t, "renew", []string{"test.smallstep.com", "test"}, withNotBeforeNotAfter(so.NotBefore.Time(), so.NotAfter.Time()), withDefaultASN1DN(a.config.AuthorityConfig.Template), withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), withSigner(issuer, signer)) type renewTest struct { auth *Authority cert *x509.Certificate err error code int } tests := map[string]func() (*renewTest, error){ "fail/create-cert": func() (*renewTest, error) { _a := testAuthority(t) _a.x509CAService.(*softcas.SoftCAS).Signer = nil return &renewTest{ auth: _a, cert: cert, err: errors.New("error creating certificate"), code: http.StatusInternalServerError, }, nil }, "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, "fail/WithAuthorizeRenewFunc": func() (*renewTest, error) { aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { return errs.Unauthorized("not authorized") })) aa.x509CAService = a.x509CAService aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template return &renewTest{ auth: aa, cert: cert, err: errors.New("authority.authorizeRenew: not authorized"), code: http.StatusUnauthorized, }, nil }, "ok": func() (*renewTest, error) { return &renewTest{ auth: a, cert: cert, }, nil }, "ok/WithExtraNames": func() (*renewTest, error) { return &renewTest{ auth: a, cert: certExtraNames, }, nil }, "ok/success-new-intermediate": func() (*renewTest, error) { rootCert, rootSigner := generateRootCertificate(t) intCert, intSigner := generateIntermidiateCertificate(t, rootCert, rootSigner) _a := testAuthority(t) 
_a.x509CAService.(*softcas.SoftCAS).CertificateChain = []*x509.Certificate{intCert} _a.x509CAService.(*softcas.SoftCAS).Signer = intSigner return &renewTest{ auth: _a, cert: cert, }, nil }, "ok/WithAuthorizeRenewFunc": func() (*renewTest, error) { aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { return nil })) aa.x509CAService = a.x509CAService aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template return &renewTest{ auth: aa, cert: cert, }, nil }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc, err := genTestCase() require.NoError(t, err) var certChain []*x509.Certificate if tc.auth != nil { certChain, err = tc.auth.Renew(tc.cert) } else { certChain, err = a.Renew(tc.cert) } if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) var sc render.StatusCodedError require.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equal(t, tc.code, sc.StatusCode()) assertHasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error require.True(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") assert.Equal(t, tc.cert.SerialNumber.String(), ctxErr.Details["serialNumber"]) } } else { leaf := certChain[0] intermediate := certChain[1] if assert.Nil(t, tc.err) { assert.Equal(t, tc.cert.NotAfter.Sub(cert.NotBefore), leaf.NotAfter.Sub(leaf.NotBefore)) assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equal(t, tc.cert.RawSubject, leaf.RawSubject) assert.Equal(t, []string{tmplt.Country}, leaf.Subject.Country) assert.Equal(t, []string{tmplt.Organization}, 
leaf.Subject.Organization) assert.Equal(t, []string{tmplt.Locality}, leaf.Subject.Locality) assert.Equal(t, []string{tmplt.StreetAddress}, leaf.Subject.StreetAddress) assert.Equal(t, []string{tmplt.Province}, leaf.Subject.Province) assert.Equal(t, tmplt.CommonName, leaf.Subject.CommonName) assert.Equal(t, intermediate.Subject, leaf.Issuer) assert.Equal(t, x509.ECDSAWithSHA256, leaf.SignatureAlgorithm) assert.Equal(t, x509.ECDSA, leaf.PublicKeyAlgorithm) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, leaf.ExtKeyUsage) assert.Equal(t, []string{"test.smallstep.com", "test"}, leaf.DNSNames) subjectKeyID, err := generateSubjectKeyID(leaf.PublicKey) require.NoError(t, err) assert.Equal(t, subjectKeyID, leaf.SubjectKeyId) // We did not change the intermediate before renewing. authIssuer := getDefaultIssuer(tc.auth) if issuer.SerialNumber == authIssuer.SerialNumber { assert.Equal(t, issuer.SubjectKeyId, leaf.AuthorityKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier if ext1.Id.Equal(oidSubjectKeyIdentifier) { continue } found := false for _, ext2 := range leaf.Extensions { if reflect.DeepEqual(ext1, ext2) { found = true break } } if !found { t.Errorf("x509 extension %s not found in renewed certificate", ext1.Id.String()) } } } else { // We did change the intermediate before renewing. assert.Equal(t, authIssuer.SubjectKeyId, leaf.AuthorityKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier if ext1.Id.Equal(oidSubjectKeyIdentifier) { continue } // The authority key id extension should be different b/c the intermediates are different. 
if ext1.Id.Equal(oidAuthorityKeyIdentifier) { for _, ext2 := range leaf.Extensions { assert.False(t, reflect.DeepEqual(ext1, ext2)) } continue } found := false for _, ext2 := range leaf.Extensions { if reflect.DeepEqual(ext1, ext2) { found = true break } } if !found { t.Errorf("x509 extension %s not found in renewed certificate", ext1.Id.String()) } } } realIntermediate, err := x509.ParseCertificate(authIssuer.Raw) require.NoError(t, err) assert.Equal(t, realIntermediate, intermediate) } } }) } } func TestAuthority_Rekey(t *testing.T) { pub, _, err := keyutil.GenerateDefaultKeyPair() require.NoError(t, err) a := testAuthority(t) a.config.AuthorityConfig.Template = &ASN1DN{ Country: "Tazmania", Organization: "Acme Co", Locality: "Landscapes", Province: "Sudden Cliffs", StreetAddress: "TNT", CommonName: "renew", } now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), } issuer := getDefaultIssuer(a) signer := getDefaultSigner(a) cert := generateCertificate(t, "renew", []string{"test.smallstep.com", "test"}, withNotBeforeNotAfter(so.NotBefore.Time(), so.NotAfter.Time()), withDefaultASN1DN(a.config.AuthorityConfig.Template), withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID), withSigner(issuer, signer)) certNoRenew := generateCertificate(t, "renew", []string{"test.smallstep.com", "test"}, withNotBeforeNotAfter(so.NotBefore.Time(), so.NotAfter.Time()), withDefaultASN1DN(a.config.AuthorityConfig.Template), withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), withSigner(issuer, signer)) type renewTest struct { auth *Authority cert *x509.Certificate pk crypto.PublicKey err error code int } tests := map[string]func() (*renewTest, error){ "fail/create-cert": func() (*renewTest, error) { _a := testAuthority(t) 
_a.x509CAService.(*softcas.SoftCAS).Signer = nil return &renewTest{ auth: _a, cert: cert, err: errors.New("error creating certificate"), code: http.StatusInternalServerError, }, nil }, "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, "ok/renew": func() (*renewTest, error) { return &renewTest{ auth: a, cert: cert, }, nil }, "ok/rekey": func() (*renewTest, error) { return &renewTest{ auth: a, cert: cert, pk: pub, }, nil }, "ok/renew/success-new-intermediate": func() (*renewTest, error) { rootCert, rootSigner := generateRootCertificate(t) intCert, intSigner := generateIntermidiateCertificate(t, rootCert, rootSigner) _a := testAuthority(t) _a.x509CAService.(*softcas.SoftCAS).CertificateChain = []*x509.Certificate{intCert} _a.x509CAService.(*softcas.SoftCAS).Signer = intSigner return &renewTest{ auth: _a, cert: cert, }, nil }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc, err := genTestCase() require.NoError(t, err) var certChain []*x509.Certificate if tc.auth != nil { certChain, err = tc.auth.Rekey(tc.cert, tc.pk) } else { certChain, err = a.Rekey(tc.cert, tc.pk) } if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) var sc render.StatusCodedError require.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equal(t, tc.code, sc.StatusCode()) assertHasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error require.True(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") assert.Equal(t, tc.cert.SerialNumber.String(), ctxErr.Details["serialNumber"]) } } else { leaf := certChain[0] intermediate := certChain[1] if assert.Nil(t, tc.err) { assert.Equal(t, tc.cert.NotAfter.Sub(cert.NotBefore), leaf.NotAfter.Sub(leaf.NotBefore)) assert.True(t, 
leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equal(t, pkix.Name{ Country: []string{tmplt.Country}, Organization: []string{tmplt.Organization}, Locality: []string{tmplt.Locality}, StreetAddress: []string{tmplt.StreetAddress}, Province: []string{tmplt.Province}, CommonName: tmplt.CommonName, }.String(), leaf.Subject.String()) assert.Equal(t, intermediate.Subject, leaf.Issuer) assert.Equal(t, x509.ECDSAWithSHA256, leaf.SignatureAlgorithm) assert.Equal(t, x509.ECDSA, leaf.PublicKeyAlgorithm) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, leaf.ExtKeyUsage) assert.Equal(t, []string{"test.smallstep.com", "test"}, leaf.DNSNames) // Test Public Key and SubjectKeyId expectedPK := tc.pk if tc.pk == nil { expectedPK = cert.PublicKey } assert.Equal(t, expectedPK, leaf.PublicKey) subjectKeyID, err := generateSubjectKeyID(expectedPK) require.NoError(t, err) assert.Equal(t, subjectKeyID, leaf.SubjectKeyId) if tc.pk == nil { assert.Equal(t, cert.SubjectKeyId, leaf.SubjectKeyId) } // We did not change the intermediate before renewing. authIssuer := getDefaultIssuer(tc.auth) if issuer.SerialNumber == authIssuer.SerialNumber { assert.Equal(t, issuer.SubjectKeyId, leaf.AuthorityKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier if ext1.Id.Equal(oidSubjectKeyIdentifier) { continue } found := false for _, ext2 := range leaf.Extensions { if reflect.DeepEqual(ext1, ext2) { found = true break } } if !found { t.Errorf("x509 extension %s not found in renewed certificate", ext1.Id.String()) } } } else { // We did change the intermediate before renewing. 
assert.Equal(t, authIssuer.SubjectKeyId, leaf.AuthorityKeyId) // Compare extensions: they can be in a different order for _, ext1 := range tc.cert.Extensions { //skip SubjectKeyIdentifier if ext1.Id.Equal(oidSubjectKeyIdentifier) { continue } // The authority key id extension should be different b/c the intermediates are different. if ext1.Id.Equal(oidAuthorityKeyIdentifier) { for _, ext2 := range leaf.Extensions { assert.False(t, reflect.DeepEqual(ext1, ext2)) } continue } found := false for _, ext2 := range leaf.Extensions { if reflect.DeepEqual(ext1, ext2) { found = true break } } if !found { t.Errorf("x509 extension %s not found in renewed certificate", ext1.Id.String()) } } } realIntermediate, err := x509.ParseCertificate(authIssuer.Raw) require.NoError(t, err) assert.Equal(t, realIntermediate, intermediate) } } }) } } func TestAuthority_GetTLSOptions(t *testing.T) { type renewTest struct { auth *Authority opts *TLSOptions } tests := map[string]func() (*renewTest, error){ "default": func() (*renewTest, error) { a := testAuthority(t) return &renewTest{auth: a, opts: &DefaultTLSOptions}, nil }, "non-default": func() (*renewTest, error) { a := testAuthority(t) a.config.TLS = &TLSOptions{ CipherSuites: CipherSuites{ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", }, MinVersion: 1.0, MaxVersion: 1.1, Renegotiation: true, } return &renewTest{auth: a, opts: a.config.TLS}, nil }, } for name, genTestCase := range tests { t.Run(name, func(t *testing.T) { tc, err := genTestCase() require.NoError(t, err) opts := tc.auth.GetTLSOptions() assert.Equal(t, tc.opts, opts) }) } } func TestAuthority_Revoke(t *testing.T) { reasonCode := 2 reason := "bob was let go" validIssuer := "step-cli" validAudience := testAudiences.Revoke now := time.Now().UTC() jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) require.NoError(t, err) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: 
jwk.Key}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) require.NoError(t, err) a := testAuthority(t) tlsRevokeCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) type test struct { auth *Authority ctx context.Context opts *RevokeOptions err error code int checkErrDetails func(err *errs.Error) } tests := map[string]func() test{ "fail/token/authorizeRevoke error": func() test { return test{ auth: a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ OTT: "foo", Serial: "sn", ReasonCode: reasonCode, Reason: reason, }, err: errors.New("authority.Revoke; error parsing token"), code: http.StatusUnauthorized, } }, "fail/nil-db": func() test { cl := jose.Claims{ Subject: "sn", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) return test{ auth: a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, err: errors.New("authority.Revoke; no persistence layer configured"), code: http.StatusNotImplemented, checkErrDetails: func(err *errs.Error) { assert.Equal(t, raw, err.Details["token"]) assert.Equal(t, "44", err.Details["tokenID"]) assert.Equal(t, "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", err.Details["provisionerID"]) }, } }, "fail/db-revoke": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, errors.New("not found") }, Err: errors.New("force"), })) cl := jose.Claims{ Subject: "sn", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) return test{ auth: _a, 
ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, err: errors.New("authority.Revoke: force"), code: http.StatusInternalServerError, checkErrDetails: func(err *errs.Error) { assert.Equal(t, raw, err.Details["token"]) assert.Equal(t, "44", err.Details["tokenID"]) assert.Equal(t, "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", err.Details["provisionerID"]) }, } }, "fail/already-revoked": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, errors.New("not found") }, Err: db.ErrAlreadyExists, })) cl := jose.Claims{ Subject: "sn", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) return test{ auth: _a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, err: errors.New("certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { assert.Equal(t, raw, err.Details["token"]) assert.Equal(t, "44", err.Details["tokenID"]) assert.Equal(t, "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", err.Details["provisionerID"]) }, } }, "fail/serial-number": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, errors.New("not found") }, })) cl := jose.Claims{ Subject: "token-sn", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) return test{ 
auth: _a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "request-sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, err: errors.New(`request serial number "request-sn" and token subject "token-sn" do not match`), code: http.StatusForbidden, checkErrDetails: func(err *errs.Error) { assert.Equal(t, raw, err.Details["token"]) }, } }, "ok/token": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, errors.New("not found") }, })) cl := jose.Claims{ Subject: "sn", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) return test{ auth: _a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, } }, "ok/mTLS": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") require.NoError(t, err) return test{ auth: _a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", ReasonCode: reasonCode, Reason: reason, MTLS: true, }, } }, "ok/mTLS-no-provisioner": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") require.NoError(t, err) // Filter out provisioner extension. for i, ext := range crt.Extensions { if ext.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}) { crt.Extensions = append(crt.Extensions[:i], crt.Extensions[i+1:]...) 
break } } return test{ auth: _a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", ReasonCode: reasonCode, Reason: reason, MTLS: true, }, } }, "ok/ACME": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") require.NoError(t, err) return test{ auth: _a, ctx: tlsRevokeCtx, opts: &RevokeOptions{ Crt: crt, Serial: "102012593071130646873265215610956555026", ReasonCode: reasonCode, Reason: reason, ACME: true, }, } }, "ok/ssh": func() test { a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MRevoke: func(rci *db.RevokedCertificateInfo) error { return errors.New("Revoke was called") }, MRevokeSSH: func(rci *db.RevokedCertificateInfo) error { return nil }, })) cl := jose.Claims{ Subject: "sn", Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: "44", } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) return test{ auth: a, ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), opts: &RevokeOptions{ Serial: "sn", ReasonCode: reasonCode, Reason: reason, OTT: raw, }, } }, } for name, f := range tests { tc := f() t.Run(name, func(t *testing.T) { if err := tc.auth.Revoke(tc.ctx, tc.opts); err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { var sc render.StatusCodedError require.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") assert.Equal(t, tc.code, sc.StatusCode()) assertHasPrefix(t, err.Error(), tc.err.Error()) var ctxErr *errs.Error require.True(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") assert.Equal(t, tc.opts.Serial, ctxErr.Details["serialNumber"]) assert.Equal(t, tc.opts.ReasonCode, ctxErr.Details["reasonCode"]) assert.Equal(t, tc.opts.Reason, ctxErr.Details["reason"]) assert.Equal(t, tc.opts.MTLS, 
ctxErr.Details["MTLS"]) assert.Equal(t, provisioner.RevokeMethod.String(), ctxErr.Details["context"]) if tc.checkErrDetails != nil { tc.checkErrDetails(ctxErr) } } } else { assert.Nil(t, tc.err) } }) } } func TestAuthority_constraints(t *testing.T) { ca, err := minica.New( minica.WithIntermediateTemplate(`{ "subject": {{ toJson .Subject }}, "keyUsage": ["certSign", "crlSign"], "basicConstraints": { "isCA": true, "maxPathLen": 0 }, "nameConstraints": { "critical": true, "permittedDNSDomains": ["internal.example.org"], "excludedDNSDomains": ["internal.example.com"], "permittedIPRanges": ["192.168.1.0/24", "192.168.2.1/32"], "excludedIPRanges": ["192.168.3.0/24", "192.168.4.0/28"], "permittedEmailAddresses": ["root@example.org", "example.org", ".acme.org"], "excludedEmailAddresses": ["root@example.com", "example.com", ".acme.com"], "permittedURIDomains": ["uuid.example.org", ".acme.org"], "excludedURIDomains": ["uuid.example.com", ".acme.com"] } }`), ) if err != nil { t.Fatal(err) } auth, err := NewEmbedded(WithX509RootCerts(ca.Root), WithX509Signer(ca.Intermediate, ca.Signer)) if err != nil { t.Fatal(err) } signer, err := keyutil.GenerateDefaultSigner() if err != nil { t.Fatal(err) } tests := []struct { name string sans []string wantErr bool }{ {"ok dns", []string{"internal.example.org", "host.internal.example.org"}, false}, {"ok ip", []string{"192.168.1.10", "192.168.2.1"}, false}, {"ok email", []string{"root@example.org", "info@example.org", "info@www.acme.org"}, false}, {"ok uri", []string{"https://uuid.example.org/b908d973-5167-4a62-abe3-6beda358d82a", "https://uuid.acme.org/1724aae1-1bb3-44fb-83c3-9a1a18df67c8"}, false}, {"fail permitted dns", []string{"internal.acme.org"}, true}, {"fail excluded dns", []string{"internal.example.com"}, true}, {"fail permitted ips", []string{"192.168.2.10"}, true}, {"fail excluded ips", []string{"192.168.3.1"}, true}, {"fail permitted emails", []string{"root@acme.org"}, true}, {"fail excluded emails", 
[]string{"root@example.com"}, true}, {"fail permitted uris", []string{"https://acme.org/uuid/7848819c-9d0b-4e12-bbff-cd66079a3444"}, true}, {"fail excluded uris", []string{"https://uuid.example.com/d325eda7-6356-4d60-b8f6-3d64724afeb3"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { csr, err := x509util.CreateCertificateRequest(tt.sans[0], tt.sans, signer) if err != nil { t.Fatal(err) } cert, err := ca.SignCSR(csr) if err != nil { t.Fatal(err) } data := x509util.CreateTemplateData(tt.sans[0], tt.sans) templateOption, err := provisioner.TemplateOptions(nil, data) if err != nil { t.Fatal(err) } _, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr) } _, err = auth.Renew(cert) if (err != nil) != tt.wantErr { t.Errorf("Authority.Renew() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestAuthority_CRL(t *testing.T) { reasonCode := 2 reason := "bob was let go" validIssuer := "step-cli" validAudience := testAudiences.Revoke now := time.Now().UTC() jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) require.NoError(t, err) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) require.NoError(t, err) crlCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) var crlStore db.CertificateRevocationListInfo var revokedList []db.RevokedCertificateInfo type test struct { auth *Authority ctx context.Context expected []string expectedReasonCode *int err error } tests := map[string]func() test{ "fail/empty-crl": func() test { a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, 
errors.New("not found") }, MStoreCRL: func(i *db.CertificateRevocationListInfo) error { crlStore = *i return nil }, MGetCRL: func() (*db.CertificateRevocationListInfo, error) { return nil, database.ErrNotFound }, MGetRevokedCertificates: func() (*[]db.RevokedCertificateInfo, error) { return &revokedList, nil }, MRevoke: func(rci *db.RevokedCertificateInfo) error { revokedList = append(revokedList, *rci) return nil }, })) a.config.CRL = &config.CRLConfig{ Enabled: true, } return test{ auth: a, ctx: crlCtx, expected: nil, err: errors.New("authority.GetCertificateRevocationList: not found"), } }, "ok/crl-full": func() test { a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, errors.New("not found") }, MStoreCRL: func(i *db.CertificateRevocationListInfo) error { crlStore = *i return nil }, MGetCRL: func() (*db.CertificateRevocationListInfo, error) { return &crlStore, nil }, MGetRevokedCertificates: func() (*[]db.RevokedCertificateInfo, error) { return &revokedList, nil }, MRevoke: func(rci *db.RevokedCertificateInfo) error { revokedList = append(revokedList, *rci) return nil }, })) a.config.CRL = &config.CRLConfig{ Enabled: true, GenerateOnRevoke: true, } var ex []string for i := 0; i < 100; i++ { sn := fmt.Sprintf("%v", i) cl := jose.Claims{ Subject: sn, Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: sn, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) err = a.Revoke(crlCtx, &RevokeOptions{ Serial: sn, ReasonCode: reasonCode, Reason: reason, OTT: raw, }) require.NoError(t, err) ex = append(ex, sn) } return test{ auth: a, ctx: crlCtx, expected: ex, expectedReasonCode: &reasonCode, } }, "ok/crl-no-reason-code": func() test { var localRevokedList []db.RevokedCertificateInfo var localCRLStore 
db.CertificateRevocationListInfo a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { return true, nil }, MGetCertificate: func(sn string) (*x509.Certificate, error) { return nil, errors.New("not found") }, MStoreCRL: func(i *db.CertificateRevocationListInfo) error { localCRLStore = *i return nil }, MGetCRL: func() (*db.CertificateRevocationListInfo, error) { return &localCRLStore, nil }, MGetRevokedCertificates: func() (*[]db.RevokedCertificateInfo, error) { return &localRevokedList, nil }, MRevoke: func(rci *db.RevokedCertificateInfo) error { localRevokedList = append(localRevokedList, *rci) return nil }, })) a.config.CRL = &config.CRLConfig{ Enabled: true, GenerateOnRevoke: true, } var ex []string zeroReasonCode := 0 for i := 0; i < 5; i++ { sn := fmt.Sprintf("%v", i) cl := jose.Claims{ Subject: sn, Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: validAudience, ID: sn, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() require.NoError(t, err) err = a.Revoke(crlCtx, &RevokeOptions{ Serial: sn, ReasonCode: zeroReasonCode, Reason: reason, OTT: raw, }) require.NoError(t, err) ex = append(ex, sn) } return test{ auth: a, ctx: crlCtx, expected: ex, expectedReasonCode: &zeroReasonCode, } }, } for name, f := range tests { tc := f() t.Run(name, func(t *testing.T) { crlInfo, err := tc.auth.GetCertificateRevocationList() if tc.err != nil { assert.EqualError(t, err, tc.err.Error()) assert.Nil(t, crlInfo) return } crl, parseErr := x509.ParseRevocationList(crlInfo.Data) require.NoError(t, parseErr) var cmpList []string for _, c := range crl.RevokedCertificateEntries { cmpList = append(cmpList, c.SerialNumber.String()) // ReasonCode 0 causes Go's x509 package to omit the reasonCode // extension entirely. Parsing it back yields 0 as the zero value // of the field, not from an explicit extension. 
This confirms the // zero-code path produces a well-formed CRL, not that the // extension round-trips. if tc.expectedReasonCode != nil { assert.Equal(t, *tc.expectedReasonCode, c.ReasonCode) } } assert.Equal(t, tc.expected, cmpList) }) } } type notImplementedCAS struct{} func (notImplementedCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { return nil, apiv1.NotImplementedError{} } func (notImplementedCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { return nil, apiv1.NotImplementedError{} } func (notImplementedCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { return nil, apiv1.NotImplementedError{} } func TestAuthority_GetX509Signer(t *testing.T) { auth := testAuthority(t) require.IsType(t, &softcas.SoftCAS{}, auth.x509CAService) signer := auth.x509CAService.(*softcas.SoftCAS).Signer require.NotNil(t, signer) tests := []struct { name string authority *Authority want crypto.Signer assertion assert.ErrorAssertionFunc }{ {"ok", auth, signer, assert.NoError}, {"fail", testAuthority(t, WithX509CAService(notImplementedCAS{})), nil, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.authority.GetX509Signer() tt.assertion(t, err) assert.Equal(t, tt.want, got) }) } } ================================================ FILE: authority/version.go ================================================ package authority // GlobalVersion stores the version information of the server. var GlobalVersion = Version{ Version: "0.0.0", } // Version defines the type Version struct { Version string RequireClientAuthentication bool } // Version returns the version information of the server. 
func (a *Authority) Version() Version { return GlobalVersion } ================================================ FILE: authority/webhook.go ================================================ package authority import ( "context" "github.com/smallstep/certificates/webhook" ) type webhookController interface { Enrich(context.Context, *webhook.RequestBody) error Authorize(context.Context, *webhook.RequestBody) error } ================================================ FILE: authority/webhook_test.go ================================================ package authority import ( "context" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/webhook" ) type mockWebhookController struct { enrichErr error authorizeErr error templateData provisioner.WebhookSetter respData map[string]any } var _ webhookController = &mockWebhookController{} func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error { for key, data := range wc.respData { wc.templateData.SetWebhook(key, data) } return wc.enrichErr } func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error { return wc.authorizeErr } ================================================ FILE: autocert/README.md ================================================ # ⚠️ Autocert has moved to https://github.com/smallstep/autocert If you're looking for hello-mTLS examples they're at https://github.com/smallstep/autocert/tree/master/examples/hello-mtls ================================================ FILE: ca/acmeClient.go ================================================ package ca import ( "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "io" "net/http" "strings" "github.com/pkg/errors" "go.step.sm/crypto/jose" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" ) // ACMEClient implements an HTTP client to an ACME API. 
type ACMEClient struct { client *http.Client dirLoc string dir *acmeAPI.Directory acc *acme.Account Key *jose.JSONWebKey kid string } // NewACMEClient initializes a new ACMEClient. func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*ACMEClient, error) { // Retrieve transport from options. o := defaultClientOptions() if err := o.apply(opts); err != nil { return nil, err } tr, err := o.getTransport(endpoint) if err != nil { return nil, err } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: endpoint, } req, err := http.NewRequest("GET", endpoint, http.NoBody) if err != nil { return nil, errors.Wrapf(err, "creating GET request %s failed", endpoint) } req.Header.Set("User-Agent", UserAgent) enforceRequestID(req) resp, err := ac.client.Do(req) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", endpoint) } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var dir acmeAPI.Directory if err := readJSON(resp.Body, &dir); err != nil { return nil, errors.Wrapf(err, "error reading %s", endpoint) } ac.dir = &dir ac.Key, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) if err != nil { return nil, err } nar := &acmeAPI.NewAccountRequest{ Contact: contact, TermsOfServiceAgreed: true, } payload, err := json.Marshal(nar) if err != nil { return nil, errors.Wrap(err, "error marshaling new account request") } resp, err = ac.post(payload, ac.dir.NewAccount, withJWK(ac)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var acc acme.Account if err := readJSON(resp.Body, &acc); err != nil { return nil, errors.Wrapf(err, "error reading %s", dir.NewAccount) } ac.acc = &acc ac.kid = resp.Header.Get("Location") return ac, nil } // GetDirectory makes a directory request to the ACME api and returns an // ACME directory object. 
func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) { return c.dir, nil } // GetNonce makes a nonce request to the ACME api and returns an // ACME directory object. func (c *ACMEClient) GetNonce() (string, error) { req, err := http.NewRequest("GET", c.dir.NewNonce, http.NoBody) if err != nil { return "", errors.Wrapf(err, "creating GET request %s failed", c.dir.NewNonce) } req.Header.Set("User-Agent", UserAgent) enforceRequestID(req) resp, err := c.client.Do(req) if err != nil { return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce) } defer resp.Body.Close() if resp.StatusCode >= 400 { return "", readACMEError(resp.Body) } return resp.Header.Get("Replay-Nonce"), nil } type withHeaderOption func(so *jose.SignerOptions) func withJWK(c *ACMEClient) withHeaderOption { return func(so *jose.SignerOptions) { so.WithHeader("jwk", c.Key.Public()) } } func withKid(c *ACMEClient) withHeaderOption { return func(so *jose.SignerOptions) { so.WithHeader("kid", c.kid) } } // serialize serializes a json web signature and doesn't omit empty fields. 
func serialize(obj *jose.JSONWebSignature) (string, error) { raw, err := obj.CompactSerialize() if err != nil { return "", errors.Wrap(err, "error serializing JWS") } parts := strings.Split(raw, ".") msg := struct { Protected string `json:"protected"` Payload string `json:"payload"` Signature string `json:"signature"` }{Protected: parts[0], Payload: parts[1], Signature: parts[2]} b, err := json.Marshal(msg) if err != nil { return "", errors.Wrap(err, "error marshaling jws message") } return string(b), nil } func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOption) (*http.Response, error) { if c.Key == nil { return nil, errors.New("acme client not configured with account") } nonce, err := c.GetNonce() if err != nil { return nil, err } so := new(jose.SignerOptions) so.WithHeader("nonce", nonce) so.WithHeader("url", url) for _, hop := range headerOps { hop(so) } signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.SignatureAlgorithm(c.Key.Algorithm), Key: c.Key.Key, }, so) if err != nil { return nil, errors.Wrap(err, "error creating JWS signer") } signed, err := signer.Sign(payload) if err != nil { return nil, errors.Errorf("error signing payload: %s", jose.TrimPrefix(err)) } raw, err := serialize(signed) if err != nil { return nil, err } req, err := http.NewRequest("POST", url, strings.NewReader(raw)) if err != nil { return nil, errors.Wrapf(err, "creating POST request %s failed", url) } req.Header.Set("Content-Type", "application/jose+json") req.Header.Set("User-Agent", UserAgent) enforceRequestID(req) resp, err := c.client.Do(req) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", c.dir.NewOrder) } return resp, nil } // NewOrder creates and returns the information for a new ACME order. 
func (c *ACMEClient) NewOrder(payload []byte) (*acme.Order, error) { resp, err := c.post(payload, c.dir.NewOrder, withKid(c)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var o acme.Order if err := readJSON(resp.Body, &o); err != nil { return nil, errors.Wrapf(err, "error reading %s", c.dir.NewOrder) } o.ID = resp.Header.Get("Location") return &o, nil } // GetChallenge returns the Challenge at the given path. // With the validate parameter set to True this method will attempt to validate the // challenge before returning it. func (c *ACMEClient) GetChallenge(url string) (*acme.Challenge, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var ch acme.Challenge if err := readJSON(resp.Body, &ch); err != nil { return nil, errors.Wrapf(err, "error reading %s", url) } return &ch, nil } // ValidateChallenge returns the Challenge at the given path. // With the validate parameter set to True this method will attempt to validate the // challenge before returning it. func (c *ACMEClient) ValidateChallenge(url string) error { resp, err := c.post([]byte("{}"), url, withKid(c)) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 400 { return readACMEError(resp.Body) } return nil } // ValidateWithPayload will attempt to validate the challenge at the given url // with the given attestation payload. func (c *ACMEClient) ValidateWithPayload(url string, payload []byte) error { resp, err := c.post(payload, url, withKid(c)) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 400 { return readACMEError(resp.Body) } return nil } // GetAuthz returns the Authz at the given path. 
func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var az acme.Authorization if err := readJSON(resp.Body, &az); err != nil { return nil, errors.Wrapf(err, "error reading %s", url) } return &az, nil } // GetOrder returns the Order at the given path. func (c *ACMEClient) GetOrder(url string) (*acme.Order, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var o acme.Order if err := readJSON(resp.Body, &o); err != nil { return nil, errors.Wrapf(err, "error reading %s", url) } return &o, nil } // FinalizeOrder makes a finalize request to the ACME api. func (c *ACMEClient) FinalizeOrder(url string, csr *x509.CertificateRequest) error { payload, err := json.Marshal(acmeAPI.FinalizeRequest{ CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), }) if err != nil { return errors.Wrap(err, "error marshaling finalize request") } resp, err := c.post(payload, url, withKid(c)) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 400 { return readACMEError(resp.Body) } return nil } // GetCertificate retrieves the certificate along with all intermediates. 
func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Certificate, error) { resp, err := c.post(nil, url, withKid(c)) if err != nil { return nil, nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, nil, readACMEError(resp.Body) } defer resp.Body.Close() bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return nil, nil, errors.Wrap(err, "error reading GET certificate response") } var certs []*x509.Certificate block, rest := pem.Decode(bodyBytes) if block == nil { return nil, nil, errors.New("failed to parse any certificates from response") } for block != nil { cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, nil, errors.Wrap(err, "error parsing certificate pem response") } certs = append(certs, cert) block, rest = pem.Decode(rest) } return certs[0], certs[1:], nil } // GetAccountOrders retrieves the orders belonging to the given account. func (c *ACMEClient) GetAccountOrders() ([]string, error) { if c.acc == nil { return nil, errors.New("acme client not configured with account") } resp, err := c.post(nil, c.acc.OrdersURL, withKid(c)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { return nil, readACMEError(resp.Body) } var orders []string if err := readJSON(resp.Body, &orders); err != nil { return nil, errors.Wrapf(err, "error reading %s", c.acc.OrdersURL) } return orders, nil } func readACMEError(r io.ReadCloser) error { defer r.Close() b, err := io.ReadAll(r) if err != nil { return errors.Wrap(err, "error reading from body") } ae := new(acme.Error) err = json.Unmarshal(b, &ae) // If we successfully marshaled to an ACMEError then return the ACMEError. if err != nil || ae.Error() == "" { fmt.Printf("b = %s\n", b) // Throw up our hands. 
return errors.Errorf("%s", b) } return ae } ================================================ FILE: ca/acmeClient_test.go ================================================ package ca import ( "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/assert" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" "github.com/smallstep/certificates/api/render" ) func TestNewACMEClient(t *testing.T) { type test struct { ops []ClientOption r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", NewAccount: srv.URL + "/bar", NewOrder: srv.URL + "/baz", RevokeCert: srv.URL + "/zip", KeyChange: srv.URL + "/blorp", } acc := acme.Account{ Contact: []string{"max", "mariano"}, Status: "valid", OrdersURL: "orders-url", } tests := map[string]func(t *testing.T) test{ "fail/client-option-error": func(t *testing.T) test { return test{ ops: []ClientOption{ func(o *clientOptions) error { return errors.New("force") }, }, err: errors.New("force"), } }, "fail/get-directory": func(t *testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-directory": func(t *testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: "foo", rc1: 200, err: errors.New("error reading http://127.0.0.1"), } }, "fail/error-post-newAccount": func(t *testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: dir, rc1: 200, r2: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), rc2: 400, err: errors.New("Account does not exist"), } }, "fail/error-bad-account": func(t 
*testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: dir, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ ops: []ClientOption{WithTransport(http.DefaultTransport)}, r1: dir, rc1: 200, r2: acc, rc2: 200, } }, } accLocation := "linkitylink" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header switch i { case 0: render.JSONStatus(w, r, tc.r1, tc.rc1) i++ case 1: w.Header().Set("Replay-Nonce", "abc123") render.JSONStatus(w, r, []byte{}, 200) i++ default: w.Header().Set("Location", accLocation) render.JSONStatus(w, r, tc.r2, tc.rc2) } }) if client, err := NewACMEClient(srv.URL, []string{"max", "mariano"}, tc.ops...); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, *client.dir, dir) assert.NotNil(t, client.Key) assert.NotNil(t, client.acc) assert.Equals(t, client.kid, accLocation) } } }) } } func TestACMEClient_GetDirectory(t *testing.T) { c := &ACMEClient{ dir: &acmeAPI.Directory{ NewNonce: "/foo", NewAccount: "/bar", NewOrder: "/baz", RevokeCert: "/zip", KeyChange: "/blorp", }, } dir, err := c.GetDirectory() assert.FatalError(t, err) assert.Equals(t, c.dir, dir) } func TestACMEClient_GetNonce(t *testing.T) { type test struct { r1 interface{} rc1 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. 
o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, } tests := map[string]func(t *testing.T) test{ "fail/GET-nonce": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, } }, } expectedNonce := "abc123" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) render.JSONStatus(w, r, tc.r1, tc.rc1) }) if nonce, err := ac.GetNonce(); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, expectedNonce, nonce) } } }) } } func TestACMEClient_post(t *testing.T) { type test struct { payload []byte Key *jose.JSONWebKey ops []withHeaderOption r1, r2 interface{} rc1, rc2 int jwkInJWS bool client *ACMEClient err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. 
o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := acme.Account{ Contact: []string{"max", "mariano"}, Status: "valid", OrdersURL: "orders-url", } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/account-not-configured": func(t *testing.T) test { return test{ client: &ACMEClient{}, r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("acme client not configured with account"), } }, "fail/GET-nonce": func(t *testing.T) test { return test{ client: ac, r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "ok/jwk": func(t *testing.T) test { return test{ client: ac, r1: []byte{}, rc1: 200, r2: acc, rc2: 200, ops: []withHeaderOption{withJWK(ac)}, jwkInJWS: true, } }, "ok/kid": func(t *testing.T) test { return test{ client: ac, r1: []byte{}, rc1: 200, r2: acc, rc2: 200, ops: []withHeaderOption{withKid(ac)}, } }, } expectedNonce := "abc123" url := srv.URL + "/foo" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := 
hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, url) if tc.jwkInJWS { assert.Equals(t, hdr.JSONWebKey.KeyID, ac.Key.KeyID) } else { assert.Equals(t, hdr.KeyID, ac.kid) } render.JSONStatus(w, r, tc.r2, tc.rc2) }) if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { var res acme.Account assert.FatalError(t, readJSON(resp.Body, &res)) assert.Equals(t, res, acc) } } }) } } func TestACMEClient_NewOrder(t *testing.T) { type test struct { ops []withHeaderOption r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", NewOrder: srv.URL + "/bar", } // Retrieve transport from options. o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) now := time.Now().UTC().Round(time.Second) nor := acmeAPI.NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "dns", Value: "example.com"}, {Type: "dns", Value: "acme.example.com"}, }, NotBefore: now, NotAfter: now.Add(time.Minute), } norb, err := json.Marshal(nor) assert.FatalError(t, err) ord := acme.Order{ Status: "valid", ExpiresAt: now, // "soon" FinalizeURL: "finalize-url", } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/newOrder-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), 
rc2: 400, ops: []withHeaderOption{withKid(ac)}, err: errors.New("The request message was malformed"), } }, "fail/bad-order": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, ops: []withHeaderOption{withKid(ac)}, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: ord, rc2: 200, ops: []withHeaderOption{withKid(ac)}, } }, } expectedNonce := "abc123" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, dir.NewOrder) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, payload, norb) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.NewOrder(norb); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, *res, ord) } } }) } } func TestACMEClient_GetOrder(t *testing.T) { type test struct { r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. 
o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ord := acme.Order{ Status: "valid", ExpiresAt: time.Now().UTC().Round(time.Second), // "soon" FinalizeURL: "finalize-url", } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/getOrder-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-order": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: ord, rc2: 200, } }, } expectedNonce := "abc123" url := srv.URL + "/hullaballoo" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, 
jwsURL, url) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, len(payload), 0) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.GetOrder(url); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, *res, ord) } } }) } } func TestACMEClient_GetAuthz(t *testing.T) { type test struct { r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) az := acme.Authorization{ Status: "valid", ExpiresAt: time.Now().UTC().Round(time.Second), Identifier: acme.Identifier{Type: "dns", Value: "example.com"}, } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/getChallenge-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-challenge": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: az, rc2: 200, } }, } expectedNonce := "abc123" url := srv.URL + "/hullaballoo" for name, run := range tests { 
t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, url) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, len(payload), 0) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.GetAuthz(url); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, *res, az) } } }) } } func TestACMEClient_GetChallenge(t *testing.T) { type test struct { r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. 
o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ch := acme.Challenge{ Type: "http-01", Status: "valid", Token: "foo", } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/getChallenge-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-challenge": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: ch, rc2: 200, } }, } expectedNonce := "abc123" url := srv.URL + "/hullaballoo" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, url) assert.Equals(t, hdr.KeyID, ac.kid) 
payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, len(payload), 0) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := ac.GetChallenge(url); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, *res, ch) } } }) } } func TestACMEClient_ValidateChallenge(t *testing.T) { type test struct { r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ch := acme.Challenge{ Type: "http-01", Status: "valid", Token: "foo", } ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/getChallenge-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-challenge": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: ch, rc2: 200, } }, } expectedNonce := "abc123" url := srv.URL + "/hullaballoo" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, url) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, payload, []byte("{}")) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if err := ac.ValidateChallenge(url); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } }) } } func TestACMEClient_ValidateWithPayload(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header t.Log(r.RequestURI) w.Header().Set("Replay-Nonce", "nonce") switch r.RequestURI { case "/nonce": render.JSONStatus(w, r, []byte{}, 200) return case "/fail-nonce": render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, "nonce") _, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, hdr.KeyID, "kid") payload, err := jws.Verify(key.Public()) assert.FatalError(t, err) assert.Equals(t, payload, []byte("the-payload")) switch r.RequestURI { 
case "/ok": render.JSONStatus(w, r, acme.Challenge{ Type: "device-attestation-01", Status: "valid", Token: "foo", }, 200) case "/fail": render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400) } })) defer srv.Close() type fields struct { client *http.Client dirLoc string dir *acmeAPI.Directory acc *acme.Account Key *jose.JSONWebKey kid string } type args struct { url string payload []byte } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{srv.Client(), srv.URL, &acmeAPI.Directory{ NewNonce: srv.URL + "/nonce", }, nil, key, "kid"}, args{srv.URL + "/ok", []byte("the-payload")}, false}, {"fail nonce", fields{srv.Client(), srv.URL, &acmeAPI.Directory{ NewNonce: srv.URL + "/fail-nonce", }, nil, key, "kid"}, args{srv.URL + "/ok", []byte("the-payload")}, true}, {"fail payload", fields{srv.Client(), srv.URL, &acmeAPI.Directory{ NewNonce: srv.URL + "/nonce", }, nil, key, "kid"}, args{srv.URL + "/fail", []byte("the-payload")}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &ACMEClient{ client: tt.fields.client, dirLoc: tt.fields.dirLoc, dir: tt.fields.dir, acc: tt.fields.acc, Key: tt.fields.Key, kid: tt.fields.kid, } if err := c.ValidateWithPayload(tt.args.url, tt.args.payload); (err != nil) != tt.wantErr { t.Errorf("ACMEClient.ValidateWithPayload() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestACMEClient_FinalizeOrder(t *testing.T) { type test struct { r1, r2 interface{} rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. 
o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ord := acme.Order{ Status: "valid", ExpiresAt: time.Now(), // "soon" FinalizeURL: "finalize-url", CertificateURL: "cert-url", } _csr, err := pemutil.Read("../authority/testdata/certs/foo.csr") assert.FatalError(t, err) csr, ok := _csr.(*x509.CertificateRequest) assert.Fatal(t, ok) fr := acmeAPI.FinalizeRequest{CSR: base64.RawURLEncoding.EncodeToString(csr.Raw)} frb, err := json.Marshal(fr) assert.FatalError(t, err) ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/finalizeOrder-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-order": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: ord, rc2: 200, } }, } expectedNonce := "abc123" url := srv.URL + "/hullaballoo" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers 
and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, url) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, payload, frb) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if err := ac.FinalizeOrder(url, csr); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } }) } } func TestACMEClient_GetAccountOrders(t *testing.T) { type test struct { r1, r2 interface{} rc1, rc2 int err error client *ACMEClient } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) orders := []string{"foo", "bar", "baz"} ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", acc: &acme.Account{ Contact: []string{"max", "mariano"}, Status: "valid", OrdersURL: srv.URL + "/orders-url", }, } tests := map[string]func(t *testing.T) test{ "fail/account-not-configured": func(t *testing.T) test { return test{ client: &ACMEClient{}, err: errors.New("acme client not configured with account"), } }, "fail/client-post": func(t *testing.T) test { return test{ client: ac, r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/getAccountOrders-error": func(t *testing.T) test { return test{ client: ac, r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, 
"malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-accountOrders": func(t *testing.T) test { return test{ client: ac, r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("error reading http://127.0.0.1"), } }, "ok": func(t *testing.T) test { return test{ client: ac, r1: []byte{}, rc1: 200, r2: orders, rc2: 200, } }, } expectedNonce := "abc123" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, ac.acc.OrdersURL) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, len(payload), 0) render.JSONStatus(w, r, tc.r2, tc.rc2) }) if res, err := tc.client.GetAccountOrders(); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, res, orders) } } }) } } func TestACMEClient_GetCertificate(t *testing.T) { type test struct { r1, r2 interface{} certBytes []byte rc1, rc2 int err error } srv := httptest.NewServer(nil) defer srv.Close() dir := acmeAPI.Directory{ NewNonce: srv.URL + "/foo", } // Retrieve transport from options. 
o := defaultClientOptions() assert.FatalError(t, o.apply([]ClientOption{WithTransport(http.DefaultTransport)})) tr, err := o.getTransport(srv.URL) assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) leaf, err := pemutil.ReadCertificate("../authority/testdata/certs/foo.crt") assert.FatalError(t, err) leafb := pem.EncodeToMemory(&pem.Block{ Type: "Certificate", Bytes: leaf.Raw, }) //nolint:gocritic certBytes := append(leafb, leafb...) certBytes = append(certBytes, leafb...) ac := &ACMEClient{ client: &http.Client{ Transport: tr, }, dirLoc: srv.URL, dir: &dir, Key: jwk, kid: "foobar", acc: &acme.Account{ Contact: []string{"max", "mariano"}, Status: "valid", OrdersURL: srv.URL + "/orders-url", }, } tests := map[string]func(t *testing.T) test{ "fail/client-post": func(t *testing.T) test { return test{ r1: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc1: 400, err: errors.New("The request message was malformed"), } }, "fail/getAccountOrders-error": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: acme.NewError(acme.ErrorMalformedType, "malformed request"), rc2: 400, err: errors.New("The request message was malformed"), } }, "fail/bad-certificate": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, r2: "foo", rc2: 200, err: errors.New("failed to parse any certificates from response"), } }, "ok": func(t *testing.T) test { return test{ r1: []byte{}, rc1: 200, certBytes: certBytes, } }, } expectedNonce := "abc123" url := srv.URL + "/cert/foo" for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) i := 0 srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { render.JSONStatus(w, r, tc.r1, tc.rc1) i++ return } // validate jws request protected headers and 
body body, err := io.ReadAll(r.Body) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(body)) assert.FatalError(t, err) hdr := jws.Signatures[0].Protected assert.Equals(t, hdr.Nonce, expectedNonce) jwsURL, ok := hdr.ExtraHeaders["url"].(string) assert.Fatal(t, ok) assert.Equals(t, jwsURL, url) assert.Equals(t, hdr.KeyID, ac.kid) payload, err := jws.Verify(ac.Key.Public()) assert.FatalError(t, err) assert.Equals(t, len(payload), 0) if tc.certBytes != nil { w.Write(tc.certBytes) } else { render.JSONStatus(w, r, tc.r2, tc.rc2) } }) if crt, chain, err := ac.GetCertificate(url); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { assert.Equals(t, crt, leaf) assert.Equals(t, chain, []*x509.Certificate{leaf, leaf}) } } }) } } ================================================ FILE: ca/adminClient.go ================================================ package ca import ( "bytes" "crypto/x509" "encoding/json" "fmt" "io" "net/http" "net/url" "path" "strconv" "time" "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" "github.com/smallstep/cli-utils/token" "github.com/smallstep/cli-utils/token/provision" "github.com/smallstep/linkedca" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" adminAPI "github.com/smallstep/certificates/authority/admin/api" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) const ( adminURLPrefix = "admin" adminIssuer = "step-admin-client/1.0" ) // AdminClient implements an HTTP client for the CA server. 
type AdminClient struct { client *uaClient endpoint *url.URL retryFunc RetryFunc opts []ClientOption x5cJWK *jose.JSONWebKey x5cCertFile string x5cCertStrs []string x5cCert *x509.Certificate x5cSubject string } var ErrAdminAPINotImplemented = errors.New("admin API not implemented") var ErrAdminAPINotAuthorized = errors.New("admin API not authorized") // AdminClientError is the client side representation of an // AdminError returned by the CA. type AdminClientError struct { Type string `json:"type"` Detail string `json:"detail"` Message string `json:"message"` } // Error returns the AdminClientError message as the error message func (e *AdminClientError) Error() string { return e.Message } // defaultClientOptions returns a new [clientOptions] with a // default timeout set. func defaultClientOptions() clientOptions { return clientOptions{ timeout: 15 * time.Second, } } // NewAdminClient creates a new AdminClient with the given endpoint and options. func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error) { u, err := parseEndpoint(endpoint) if err != nil { return nil, err } // Retrieve transport from options. 
o := defaultClientOptions() if err := o.apply(opts); err != nil { return nil, err } tr, err := o.getTransport(endpoint) if err != nil { return nil, err } return &AdminClient{ client: newClient(tr, o.timeout), endpoint: u, retryFunc: o.retryFunc, opts: opts, x5cJWK: o.x5cJWK, x5cCertFile: o.x5cCertFile, x5cCertStrs: o.x5cCertStrs, x5cCert: o.x5cCert, x5cSubject: o.x5cSubject, }, nil } func (c *AdminClient) generateAdminToken(aud *url.URL) (string, error) { // A random jwt id will be used to identify duplicated tokens jwtID, err := randutil.Hex(64) // 256 bits if err != nil { return "", err } // Drop any query string parameter from the token audience aud = &url.URL{ Scheme: aud.Scheme, Host: aud.Host, Path: aud.Path, } now := time.Now() tokOptions := []token.Options{ token.WithJWTID(jwtID), token.WithKid(c.x5cJWK.KeyID), token.WithIssuer(adminIssuer), token.WithAudience(aud.String()), token.WithValidity(now, now.Add(token.DefaultValidity)), token.WithX5CCerts(c.x5cCertStrs), } tok, err := provision.New(c.x5cSubject, tokOptions...) if err != nil { return "", err } return tok.SignedString(c.x5cJWK.Algorithm, c.x5cJWK.Key) } func (c *AdminClient) retryOnError(r *http.Response) bool { if c.retryFunc != nil { if c.retryFunc(r.StatusCode) { o := defaultClientOptions() if err := o.apply(c.opts); err != nil { return false } tr, err := o.getTransport(c.endpoint.String()) if err != nil { return false } r.Body.Close() c.client.SetTransport(tr) return true } } return false } // IsEnabled checks if the admin API is enabled. 
func (c *AdminClient) IsEnabled() error { u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins")}) resp, err := c.client.Get(u.String()) if err != nil { return clientError(err) } defer resp.Body.Close() if resp.StatusCode < http.StatusBadRequest { return nil } switch resp.StatusCode { case http.StatusNotFound, http.StatusNotImplemented: return ErrAdminAPINotImplemented case http.StatusUnauthorized: return ErrAdminAPINotAuthorized default: return errors.Errorf("unexpected status code when performing is-enabled check for Admin API: %d", resp.StatusCode) } } // GetAdmin performs the GET /admin/admin/{id} request to the CA. func (c *AdminClient) GetAdmin(id string) (*linkedca.Admin, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var adm = new(linkedca.Admin) if err := readProtoJSON(resp.Body, adm); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return adm, nil } // AdminOption is the type of options passed to the Admin method. type AdminOption func(o *adminOptions) error type adminOptions struct { cursor string limit int } func (o *adminOptions) apply(opts []AdminOption) (err error) { for _, fn := range opts { if err = fn(o); err != nil { return } } return } func (o *adminOptions) rawQuery() string { v := url.Values{} if o.cursor != "" { v.Set("cursor", o.cursor) } if o.limit > 0 { v.Set("limit", strconv.Itoa(o.limit)) } return v.Encode() } // WithAdminCursor will request the admins starting with the given cursor. func WithAdminCursor(cursor string) AdminOption { return func(o *adminOptions) error { o.cursor = cursor return nil } } // WithAdminLimit will request the given number of admins. 
func WithAdminLimit(limit int) AdminOption { return func(o *adminOptions) error { o.limit = limit return nil } } // GetAdminsPaginate returns a page from the the GET /admin/admins request to the CA. func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdminsResponse, error) { var retried bool o := new(adminOptions) if err := o.apply(opts); err != nil { return nil, err } u := c.endpoint.ResolveReference(&url.URL{ Path: "/admin/admins", RawQuery: o.rawQuery(), }) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("GET", u.String(), http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var body = new(adminAPI.GetAdminsResponse) if err := readJSON(resp.Body, body); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return body, nil } // GetAdmins returns all admins from the GET /admin/admins request to the CA. func (c *AdminClient) GetAdmins(...AdminOption) ([]*linkedca.Admin, error) { var ( cursor = "" admins = []*linkedca.Admin{} ) for { resp, err := c.GetAdminsPaginate(WithAdminCursor(cursor), WithAdminLimit(100)) if err != nil { return nil, err } admins = append(admins, resp.Admins...) if resp.NextCursor == "" { return admins, nil } cursor = resp.NextCursor } } // CreateAdmin performs the POST /admin/admins request to the CA. 
func (c *AdminClient) CreateAdmin(createAdminRequest *adminAPI.CreateAdminRequest) (*linkedca.Admin, error) { var retried bool body, err := json.Marshal(createAdminRequest) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/admin/admins"}) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var adm = new(linkedca.Admin) if err := readProtoJSON(resp.Body, adm); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return adm, nil } // RemoveAdmin performs the DELETE /admin/admins/{id} request to the CA. func (c *AdminClient) RemoveAdmin(id string) error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("DELETE", u.String(), http.NoBody) if err != nil { return errors.Wrapf(err, "create DELETE %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } // UpdateAdmin performs the PUT /admin/admins/{id} request to the CA. 
func (c *AdminClient) UpdateAdmin(id string, uar *adminAPI.UpdateAdminRequest) (*linkedca.Admin, error) { var retried bool body, err := json.Marshal(uar) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)}) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("PATCH", u.String(), bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "create PATCH %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var adm = new(linkedca.Admin) if err := readProtoJSON(resp.Body, adm); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return adm, nil } // GetProvisioner performs the GET /admin/provisioners/{name} request to the CA. 
func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provisioner, error) { var retried bool o := new(ProvisionerOptions) if err := o.Apply(opts); err != nil { return nil, err } var u *url.URL switch { case o.ID != "": u = c.endpoint.ResolveReference(&url.URL{ Path: "/admin/provisioners/id", RawQuery: o.rawQuery(), }) case o.Name != "": u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.Name)}) default: return nil, errors.New("must set either name or id in method options") } tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("GET", u.String(), http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var prov = new(linkedca.Provisioner) if err := readProtoJSON(resp.Body, prov); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return prov, nil } // GetProvisionersPaginate performs the GET /admin/provisioners request to the CA. 
func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*adminAPI.GetProvisionersResponse, error) { var retried bool o := new(ProvisionerOptions) if err := o.Apply(opts); err != nil { return nil, err } u := c.endpoint.ResolveReference(&url.URL{ Path: "/admin/provisioners", RawQuery: o.rawQuery(), }) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("GET", u.String(), http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var body = new(adminAPI.GetProvisionersResponse) if err := readJSON(resp.Body, body); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return body, nil } // GetProvisioners returns all admins from the GET /admin/admins request to the CA. func (c *AdminClient) GetProvisioners(...AdminOption) (provisioner.List, error) { var ( cursor = "" provs = provisioner.List{} ) for { resp, err := c.GetProvisionersPaginate(WithProvisionerCursor(cursor), WithProvisionerLimit(100)) if err != nil { return nil, err } provs = append(provs, resp.Provisioners...) if resp.NextCursor == "" { return provs, nil } cursor = resp.NextCursor } } // RemoveProvisioner performs the DELETE /admin/provisioners/{name} request to the CA. 
func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error { var ( u *url.URL retried bool ) o := new(ProvisionerOptions) if err := o.Apply(opts); err != nil { return err } switch { case o.ID != "": u = c.endpoint.ResolveReference(&url.URL{ Path: path.Join(adminURLPrefix, "provisioners/id"), RawQuery: o.rawQuery(), }) case o.Name != "": u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.Name)}) default: return errors.New("must set either name or id in method options") } tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("DELETE", u.String(), http.NoBody) if err != nil { return errors.Wrapf(err, "create DELETE %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } // CreateProvisioner performs the POST /admin/provisioners request to the CA. 
func (c *AdminClient) CreateProvisioner(prov *linkedca.Provisioner) (*linkedca.Provisioner, error) { var retried bool body, err := protojson.Marshal(prov) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "create POST %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var nuProv = new(linkedca.Provisioner) if err := readProtoJSON(resp.Body, nuProv); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return nuProv, nil } // UpdateProvisioner performs the PUT /admin/provisioners/{name} request to the CA. 
func (c *AdminClient) UpdateProvisioner(name string, prov *linkedca.Provisioner) error { var retried bool body, err := protojson.Marshal(prov) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", name)}) tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(body)) if err != nil { return errors.Wrapf(err, "create PUT %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } // GetExternalAccountKeysPaginate returns a page from the GET /admin/acme/eab request to the CA. func (c *AdminClient) GetExternalAccountKeysPaginate(provisionerName, reference string, opts ...AdminOption) (*adminAPI.GetExternalAccountKeysResponse, error) { var retried bool o := new(adminOptions) if err := o.apply(opts); err != nil { return nil, err } p := path.Join(adminURLPrefix, "acme/eab", provisionerName) if reference != "" { p = path.Join(p, "/", reference) } u := c.endpoint.ResolveReference(&url.URL{ Path: p, RawQuery: o.rawQuery(), }) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("GET", u.String(), http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var body = new(adminAPI.GetExternalAccountKeysResponse) if err := 
readJSON(resp.Body, body); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return body, nil } // CreateExternalAccountKey performs the POST /admin/acme/eab request to the CA. func (c *AdminClient) CreateExternalAccountKey(provisionerName string, eakRequest *adminAPI.CreateExternalAccountKeyRequest) (*linkedca.EABKey, error) { var retried bool body, err := json.Marshal(eakRequest) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab/", provisionerName)}) tok, err := c.generateAdminToken(u) if err != nil { return nil, errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "create POST %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var eabKey = new(linkedca.EABKey) if err := readProtoJSON(resp.Body, eabKey); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return eabKey, nil } // RemoveExternalAccountKey performs the DELETE /admin/acme/eab/{prov}/{key_id} request to the CA. 
func (c *AdminClient) RemoveExternalAccountKey(provisionerName, keyID string) error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab", provisionerName, "/", keyID)}) tok, err := c.generateAdminToken(u) if err != nil { return errors.Wrapf(err, "error generating admin token") } req, err := http.NewRequest("DELETE", u.String(), http.NoBody) if err != nil { return errors.Wrapf(err, "create DELETE %s request failed", u) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } func (c *AdminClient) GetAuthorityPolicy() (*linkedca.Policy, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) if err != nil { return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) CreateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) { var retried bool body, err := protojson.Marshal(p) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error 
generating admin token: %w", err) } req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) UpdateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) { var retried bool body, err := protojson.Marshal(p) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) RemoveAuthorityPolicy() error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) if err != nil 
{ return fmt.Errorf("creating DELETE %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } func (c *AdminClient) GetProvisionerPolicy(provisionerName string) (*linkedca.Policy, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) if err != nil { return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) CreateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) { var retried bool body, err := protojson.Marshal(p) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { 
return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) UpdateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) { var retried bool body, err := protojson.Marshal(p) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) RemoveProvisionerPolicy(provisionerName string) error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")}) tok, err := c.generateAdminToken(u) if err != nil { return fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) if err != nil { return fmt.Errorf("creating DELETE %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return 
clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } func (c *AdminClient) GetACMEPolicy(provisionerName, reference, keyID string) (*linkedca.Policy, error) { var retried bool var urlPath string switch { case keyID != "": urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) default: urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) } u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody) if err != nil { return nil, fmt.Errorf("creating GET %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) CreateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) { var retried bool body, err := protojson.Marshal(p) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } var urlPath string switch { case keyID != "": urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) default: urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) } u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodPost, 
u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) UpdateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) { var retried bool body, err := protojson.Marshal(p) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } var urlPath string switch { case keyID != "": urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID) default: urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) } u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var policy = new(linkedca.Policy) if err := readProtoJSON(resp.Body, policy); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return policy, nil } func (c *AdminClient) RemoveACMEPolicy(provisionerName, reference, keyID string) error { var retried bool var urlPath string switch { case keyID != "": urlPath = path.Join(adminURLPrefix, 
"acme", "policy", provisionerName, "key", keyID) default: urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference) } u := c.endpoint.ResolveReference(&url.URL{Path: urlPath}) tok, err := c.generateAdminToken(u) if err != nil { return fmt.Errorf("error generating admin token: %w", err) } req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) if err != nil { return fmt.Errorf("creating DELETE %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) retry: resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } func (c *AdminClient) CreateProvisionerWebhook(provisionerName string, wh *linkedca.Webhook) (*linkedca.Webhook, error) { var retried bool body, err := protojson.Marshal(wh) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "webhooks")}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } retry: req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating POST %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var webhook = new(linkedca.Webhook) if err := readProtoJSON(resp.Body, webhook); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return webhook, nil } func (c *AdminClient) UpdateProvisionerWebhook(provisionerName string, wh *linkedca.Webhook) (*linkedca.Webhook, error) { var retried bool body, err := 
protojson.Marshal(wh) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "webhooks", wh.Name)}) tok, err := c.generateAdminToken(u) if err != nil { return nil, fmt.Errorf("error generating admin token: %w", err) } retry: req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return nil, readAdminError(resp.Body) } var webhook = new(linkedca.Webhook) if err := readProtoJSON(resp.Body, webhook); err != nil { return nil, fmt.Errorf("error reading %s: %w", u, err) } return webhook, nil } func (c *AdminClient) DeleteProvisionerWebhook(provisionerName, webhookName string) error { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "webhooks", webhookName)}) tok, err := c.generateAdminToken(u) if err != nil { return fmt.Errorf("error generating admin token: %w", err) } retry: req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody) if err != nil { return fmt.Errorf("creating DELETE %s request failed: %w", u, err) } req.Header.Add("Authorization", tok) resp, err := c.client.Do(req) if err != nil { return clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { retried = true goto retry } return readAdminError(resp.Body) } return nil } func readAdminError(r io.ReadCloser) error { // TODO: not all errors can be read (i.e. 
404); seems to be a bigger issue defer r.Close() adminErr := new(AdminClientError) if err := json.NewDecoder(r).Decode(adminErr); err != nil { return err } return adminErr } ================================================ FILE: ca/bootstrap.go ================================================ package ca import ( "context" "crypto" "crypto/tls" "net" "net/http" "strings" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "go.step.sm/crypto/jose" ) type tokenClaims struct { SHA string `json:"sha"` jose.Claims } // Bootstrap is a helper function that initializes a client with the // configuration in the bootstrap token. func Bootstrap(token string) (*Client, error) { tok, err := jose.ParseSigned(token) if err != nil { return nil, errors.Wrap(err, "error parsing token") } var claims tokenClaims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, errors.Wrap(err, "error parsing token") } // Validate bootstrap token switch { case claims.SHA == "": return nil, errors.New("invalid bootstrap token: sha claim is not present") case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"): return nil, errors.New("invalid bootstrap token: aud claim is not a url") } return NewClient(claims.Audience[0], WithRootSHA256(claims.SHA)) } // BootstrapClient is a helper function that using the given bootstrap token // return an http.Client configured with a Transport prepared to do TLS // connections using the client certificate returned by the certificate // authority. By default the server will kick off a routine that will renew the // certificate after 2/3rd of the certificate's lifetime has expired. // // Usage: // // // Default example with certificate rotation. // client, err := ca.BootstrapClient(ctx.Background(), token) // // // Example canceling automatic certificate rotation. 
// ctx, cancel := context.WithCancel(context.Background()) // defer cancel() // client, err := ca.BootstrapClient(ctx, token) // if err != nil { // return err // } // resp, err := client.Get("https://internal.smallstep.com") func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary if err != nil { return nil, err } // Make sure the tlsConfig has all supported roots on RootCAs. // // The roots request is only supported if identity certificates are not // required. In all cases the current root is also added after applying all // options too. if !b.RequireClientAuth { options = append(options, AddRootsToRootCAs()) } transport, err := b.Client.Transport(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } return &http.Client{ Transport: transport, }, nil } // BootstrapServer is a helper function that using the given token returns the // given http.Server configured with a TLS certificate signed by the Certificate // Authority. By default the server will kick off a routine that will renew the // certificate after 2/3rd of the certificate's lifetime has expired. // // Without any extra option the server will be configured for mTLS, it will // require and verify clients certificates, but options can be used to drop this // requirement, the most common will be only verify the certs if given with // ca.VerifyClientCertIfGiven(), or add extra CAs with // ca.AddClientCA(*x509.Certificate). // // Usage: // // // Default example with certificate rotation. // srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{ // Addr: ":443", // Handler: handler, // }) // // // Example canceling automatic certificate rotation. 
// ctx, cancel := context.WithCancel(context.Background()) // defer cancel() // srv, err := ca.BootstrapServer(ctx, token, &http.Server{ // Addr: ":443", // Handler: handler, // }) // if err != nil { // return err // } // srv.ListenAndServeTLS("", "") func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) { if base.TLSConfig != nil { return nil, errors.New("server TLSConfig is already set") } b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary if err != nil { return nil, err } // Make sure the tlsConfig has all supported roots on RootCAs. // // The roots request is only supported if identity certificates are not // required. In all cases the current root is also added after applying all // options too. if !b.RequireClientAuth { options = append(options, AddRootsToCAs()) } tlsConfig, err := b.Client.GetServerTLSConfig(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } base.TLSConfig = tlsConfig return base, nil } // BootstrapListener is a helper function that using the given token returns a // TLS listener which accepts connections from an inner listener and wraps each // connection with Server. // // Without any extra option the server will be configured for mTLS, it will // require and verify clients certificates, but options can be used to drop this // requirement, the most common will be only verify the certs if given with // ca.VerifyClientCertIfGiven(), or add extra CAs with // ca.AddClientCA(*x509.Certificate). // // Usage: // // inner, err := net.Listen("tcp", ":443") // if err != nil { // return nil // } // ctx, cancel := context.WithCancel(context.Background()) // defer cancel() // lis, err := ca.BootstrapListener(ctx, token, inner) // if err != nil { // return err // } // srv := grpc.NewServer() // ... 
// register services // srv.Serve(lis) func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary if err != nil { return nil, err } // Make sure the tlsConfig has all supported roots on RootCAs. // // The roots request is only supported if identity certificates are not // required. In all cases the current root is also added after applying all // options too. if !b.RequireClientAuth { options = append(options, AddRootsToCAs()) } tlsConfig, err := b.Client.GetServerTLSConfig(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } return tls.NewListener(inner, tlsConfig), nil } type bootstrap struct { Client *Client RequireClientAuth bool SignResponse *api.SignResponse PrivateKey crypto.PrivateKey } func createBootstrap(token string) (*bootstrap, error) { client, err := Bootstrap(token) if err != nil { return nil, err } version, err := client.Version() if err != nil { return nil, err } req, pk, err := CreateSignRequest(token) if err != nil { return nil, err } sign, err := client.Sign(req) if err != nil { return nil, err } return &bootstrap{ Client: client, RequireClientAuth: version.RequireClientAuthentication, SignResponse: sign, PrivateKey: pk, }, nil } ================================================ FILE: ca/bootstrap_test.go ================================================ package ca import ( "context" "crypto/tls" "io" "net" "net/http" "net/http/httptest" "os" "reflect" "strings" "sync" "testing" "time" "github.com/pkg/errors" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/errs" ) func newLocalListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", 
"[::1]:0"); err != nil { panic(errors.Wrap(err, "failed to listen on a port")) } } return l } func setMinCertDuration(time.Duration) func() { tmp := minCertDuration minCertDuration = 1 * time.Second return func() { minCertDuration = tmp } } func startCABootstrapServer() *httptest.Server { config, err := authority.LoadConfiguration("testdata/ca.json") if err != nil { panic(err) } srv := httptest.NewUnstartedServer(nil) config.Address = srv.Listener.Addr().String() ca, err := New(config) if err != nil { panic(err) } baseContext := buildContext(ca.auth, nil, nil, nil) srv.Config.Handler = ca.srv.Handler srv.Config.BaseContext = func(net.Listener) context.Context { return baseContext } srv.TLS = ca.srv.TLSConfig srv.StartTLS() // Force the use of GetCertificate on IPs srv.TLS.Certificates = nil return srv } func startCAServer(configFile string) (*CA, string, error) { config, err := authority.LoadConfiguration(configFile) if err != nil { return nil, "", err } listener := newLocalListener() config.Address = listener.Addr().String() caURL := "https://" + listener.Addr().String() ca, err := New(config) if err != nil { return nil, "", err } go func() { ca.srv.Serve(listener) }() return ca, caURL, nil } func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/version" { render.JSON(w, r, api.VersionResponse{ Version: "test", RequireClientAuthentication: true, }) return } for _, s := range nonAuthenticatedPaths { if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) { next.ServeHTTP(w, r) return } } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 if !isMTLS { render.Error(w, r, errs.Unauthorized("missing peer certificate")) } else { next.ServeHTTP(w, r) } }) } func generateBootstrapToken(ca, subject, sha string) string { now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", 
jose.WithPassword([]byte("password"))) if err != nil { panic(err) } opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) if err != nil { panic(err) } id, err := randutil.ASCII(64) if err != nil { panic(err) } cl := struct { SHA string `json:"sha"` jose.Claims SANS []string `json:"sans"` }{ SHA: sha, Claims: jose.Claims{ ID: id, Subject: subject, Issuer: "mariano", NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), Audience: []string{ca + "/sign"}, }, SANS: []string{subject}, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() if err != nil { panic(err) } return raw } func TestBootstrap(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() token := generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) if err != nil { t.Fatal(err) } type args struct { token string } tests := []struct { name string args args want *Client wantErr bool }{ {"ok", args{token}, client, false}, {"token err", args{"badtoken"}, nil, true}, {"bad claims", args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.foo.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"}, nil, true}, {"bad sha", args{generateBootstrapToken(srv.URL, "subject", "")}, nil, true}, {"bad aud", args{generateBootstrapToken("", "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := Bootstrap(tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Bootstrap() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { if !reflect.DeepEqual(got, tt.want) { t.Errorf("Bootstrap() = %v, want %v", got, tt.want) } } else { if got == nil { t.Error("Bootstrap() = nil, want not nil") } else { if 
!reflect.DeepEqual(got.endpoint, tt.want.endpoint) { t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint) } gotTR := got.client.GetTransport().(*http.Transport) wantTR := tt.want.client.GetTransport().(*http.Transport) if !equalPools(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) { t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) } } } }) } } //nolint:gosec // insecure test servers func TestBootstrapServerWithoutMTLS(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } mtlsServer := startCABootstrapServer() next := mtlsServer.Config.Handler mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") defer mtlsServer.Close() mtlsToken := func() string { return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } type args struct { ctx context.Context token string base *http.Server } tests := []struct { name string args args wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base, VerifyClientCertIfGiven()) if (err != nil) != tt.wantErr { t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { if got != nil { t.Errorf("BootstrapServer() = %v, want nil", got) } } else { expected := &http.Server{ TLSConfig: got.TLSConfig, } //nolint:govet // not comparing errors if !reflect.DeepEqual(got, 
expected) { t.Errorf("BootstrapServer() = %v, want %v", got, expected) } if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil { t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig) } } }) } } //nolint:gosec // insecure test servers func TestBootstrapServerWithMTLS(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } mtlsServer := startCABootstrapServer() next := mtlsServer.Config.Handler mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") defer mtlsServer.Close() mtlsToken := func() string { return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } type args struct { ctx context.Context token string base *http.Server } tests := []struct { name string args args wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base) if (err != nil) != tt.wantErr { t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { if got != nil { t.Errorf("BootstrapServer() = %v, want nil", got) } } else { expected := &http.Server{ TLSConfig: got.TLSConfig, } //nolint:govet // not comparing errors if !reflect.DeepEqual(got, expected) { t.Errorf("BootstrapServer() = %v, want %v", got, expected) } if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || 
got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil { t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig) } } }) } } func TestBootstrapClient(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } mtlsServer := startCABootstrapServer() next := mtlsServer.Config.Handler mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") defer mtlsServer.Close() mtlsToken := func() string { return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } type args struct { ctx context.Context token string } tests := []struct { name string args args wantErr bool }{ {"ok", args{context.Background(), token()}, false}, {"ok mtls", args{context.Background(), mtlsToken()}, false}, {"fail", args{context.Background(), "bad-token"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := BootstrapClient(tt.args.ctx, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("BootstrapClient() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { if got != nil { t.Errorf("BootstrapClient() = %v, want nil", got) } } else { tlsConfig := got.Transport.(*http.Transport).TLSClientConfig if tlsConfig == nil || tlsConfig.ClientCAs != nil || tlsConfig.GetClientCertificate == nil || tlsConfig.RootCAs == nil || tlsConfig.GetCertificate != nil { t.Errorf("BootstrapClient() invalid Transport = %#v", tlsConfig) } resp, err := got.Post(srv.URL+"/renew", "application/json", http.NoBody) if err != nil { t.Errorf("BootstrapClient() failed renewing certificate") return } var renewal api.SignResponse if err := readJSON(resp.Body, &renewal); err != nil { t.Errorf("BootstrapClient() error reading response: %v", err) return } if renewal.CaPEM.Certificate == nil || 
// TestBootstrapClientServerRotation checks that a bootstrap client keeps
// working while the CA is reloaded through a sequence of configurations
// (testdata/rotate-ca-0.json to rotate-ca-3.json). After every reload the
// same client must still be able to renew against the CA and to reach a
// bootstrap server that requires a verified client certificate.
//
// The test is skipped when the CI environment variable is "true".
func TestBootstrapClientServerRotation(t *testing.T) {
	if os.Getenv("CI") == "true" {
		t.Skipf("skip until we fix https://github.com/smallstep/certificates/issues/873")
	}

	// NOTE(review): presumably lowers the minimum certificate duration so that
	// renewals happen within the sleeps below; the returned func undoes it.
	reset := setMinCertDuration(1 * time.Second)
	defer reset()

	// Configuration with current root
	config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json")
	if err != nil {
		t.Fatal(err)
	}

	// Get local address
	listener := newLocalListener()
	config.Address = listener.Addr().String()
	caURL := "https://" + listener.Addr().String()

	// Start CA server
	ca, err := New(config)
	if err != nil {
		t.Fatal(err)
	}
	go func() {
		ca.srv.Serve(listener)
	}()
	defer ca.Stop()
	// Give the CA goroutine time to start serving before it is used.
	time.Sleep(1 * time.Second)

	// Create bootstrap server
	token := generateBootstrapToken(caURL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
	//nolint:gosec // insecure test server
	server, err := BootstrapServer(context.Background(), token, &http.Server{
		Addr: ":0",
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("ok"))
		}),
	}, RequireAndVerifyClientCert())
	if err != nil {
		t.Fatal(err)
	}
	// The listener variable is reused: from here on it belongs to the
	// bootstrap server, not to the CA.
	listener = newLocalListener()
	srvURL := "https://" + listener.Addr().String()
	go func() {
		server.ServeTLS(listener, "", "")
	}()
	defer server.Close()
	// Give the bootstrap server goroutine time to start serving.
	time.Sleep(1 * time.Second)

	// Create bootstrap client
	token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
	client, err := BootstrapClient(context.Background(), token)
	if err != nil {
		t.Errorf("BootstrapClient() error = %v", err)
		return
	}

	// doTest does a request that requires mTLS: first a renewal against the
	// CA, then a GET against the bootstrap server, which must answer "ok".
	doTest := func(client *http.Client) error {
		// test with ca
		resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
		if err != nil {
			return errors.Wrap(err, "client.Post() failed")
		}
		var renew api.SignResponse
		if err := readJSON(resp.Body, &renew); err != nil {
			return errors.Wrap(err, "client.Post() error reading response")
		}
		if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil || len(renew.CertChainPEM) == 0 {
			return errors.New("client.Post() unexpected response found")
		}

		// test with bootstrap server
		resp, err = client.Get(srvURL)
		if err != nil {
			return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
		}
		defer resp.Body.Close()
		b, err := io.ReadAll(resp.Body)
		if err != nil {
			return errors.Wrap(err, "client.Get() error reading response")
		}
		if string(b) != "ok" {
			return errors.New("client.Get() unexpected response found")
		}
		return nil
	}

	// Test with default root
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-0.json failed: %v", err)
	}

	// wait for renew
	time.Sleep(5 * time.Second)

	// Reload with configuration with current and future root
	ca.opts.configFile = "testdata/rotate-ca-1.json"
	if err := doReload(ca); err != nil {
		t.Errorf("ca.Reload() error = %v", err)
		return
	}
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-1.json failed: %v", err)
	}

	// wait for renew
	time.Sleep(5 * time.Second)

	// Reload with new and old root
	ca.opts.configFile = "testdata/rotate-ca-2.json"
	if err := doReload(ca); err != nil {
		t.Errorf("ca.Reload() error = %v", err)
		return
	}
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-2.json failed: %v", err)
	}

	// wait for renew
	time.Sleep(5 * time.Second)

	// Reload with only the new root
	ca.opts.configFile = "testdata/rotate-ca-3.json"
	if err := doReload(ca); err != nil {
		t.Errorf("ca.Reload() error = %v", err)
		return
	}
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-3.json failed: %v", err)
	}
}
startCAServer("testdata/federated-ca.json") if err != nil { t.Fatal(err) } defer ca2.Stop() // Create bootstrap server token := generateBootstrapToken(caURL1, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") //nolint:gosec // insecure test server server, err := BootstrapServer(context.Background(), token, &http.Server{ Addr: ":0", Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) }), }, RequireAndVerifyClientCert(), AddFederationToClientCAs()) if err != nil { t.Fatal(err) } listener := newLocalListener() srvURL := "https://" + listener.Addr().String() go func() { server.ServeTLS(listener, "", "") }() defer server.Close() // Create bootstrap client token = generateBootstrapToken(caURL2, "client", "c86f74bb7eb2eabef45c4f7fc6c146359ed3a5bbad416b31da5dce8093bcbffd") client, err := BootstrapClient(context.Background(), token, AddFederationToRootCAs()) if err != nil { t.Errorf("BootstrapClient() error = %v", err) return } // doTest does a request that requires mTLS doTest := func(client *http.Client) error { // test with ca resp, err := client.Post(caURL2+"/renew", "application/json", http.NoBody) if err != nil { return errors.Wrap(err, "client.Post() failed") } var renew api.SignResponse if err := readJSON(resp.Body, &renew); err != nil { return errors.Wrap(err, "client.Post() error reading response") } if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil || len(renew.CertChainPEM) == 0 { return errors.New("client.Post() unexpected response found") } // test with bootstrap server resp, err = client.Get(srvURL) if err != nil { return errors.Wrapf(err, "client.Get(%s) failed", srvURL) } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { return errors.Wrap(err, "client.Get() error reading response") } if string(b) != "ok" { return errors.New("client.Get() unexpected response found") } return nil } // Test with default root if err := doTest(client); err != 
nil { t.Errorf("Test with rotate-ca-0.json failed: %v", err) } } // doReload uses the reload implementation but overwrites the new address with // the one being used. func doReload(ca *CA) error { config, err := authority.LoadConfiguration(ca.opts.configFile) if err != nil { return errors.Wrap(err, "error reloading ca") } newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile), WithDatabase(ca.auth.GetDatabase())) if err != nil { return errors.Wrap(err, "error reloading ca") } // Use same address in new server newCA.srv.Addr = ca.srv.Addr if err := ca.srv.Reload(newCA.srv); err != nil { return err } // Wait a few ms until the http server calls listener.Accept() time.Sleep(100 * time.Millisecond) return nil } func TestBootstrapListener(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() token := func() string { return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } mtlsServer := startCABootstrapServer() next := mtlsServer.Config.Handler mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") defer mtlsServer.Close() mtlsToken := func() string { return generateBootstrapToken(mtlsServer.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } type args struct { token string } tests := []struct { name string args args wantErr bool }{ {"ok", args{token()}, false}, {"ok mtls", args{mtlsToken()}, false}, {"fail", args{"bad-token"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { inner := newLocalListener() defer inner.Close() lis, err := BootstrapListener(context.Background(), tt.args.token, inner) if (err != nil) != tt.wantErr { t.Errorf("BootstrapListener() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr { if lis != nil { t.Errorf("BootstrapListener() = %v, want nil", lis) } return } wg := new(sync.WaitGroup) wg.Add(1) go func() { http.Serve(lis, http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })) wg.Done() }() defer wg.Wait() defer lis.Close() client, err := BootstrapClient(context.Background(), token()) if err != nil { t.Errorf("BootstrapClient() error = %v", err) return } resp, err := client.Get("https://" + lis.Addr().String()) if err != nil { t.Errorf("client.Get() error = %v", err) return } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("io.ReadAll() error = %v", err) return } if string(b) != "ok" { t.Errorf("client.Get() = %s, want ok", string(b)) return } }) } } ================================================ FILE: ca/ca.go ================================================ package ca import ( "bytes" "context" "crypto/tls" "crypto/x509" "errors" "fmt" "log" "net" "net/http" "net/url" "reflect" "strings" "time" "github.com/coreos/go-systemd/v22/daemon" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "golang.org/x/sync/errgroup" "github.com/smallstep/cli-utils/step" "github.com/smallstep/nosql" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" adminAPI "github.com/smallstep/certificates/authority/admin/api" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/internal/metrix" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/scep" scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" ) type options struct { configFile string 
linkedCAToken string quiet bool password []byte issuerPassword []byte sshHostPassword []byte sshUserPassword []byte database db.AuthDB x509CAService apiv1.CertificateAuthorityService tlsConfig *tls.Config } func (o *options) apply(opts []Option) { for _, fn := range opts { fn(o) } } // Option is the type of options passed to the CA constructor. type Option func(o *options) // WithConfigFile sets the given name as the configuration file name in the CA // options. func WithConfigFile(name string) Option { return func(o *options) { o.configFile = name } } // WithX509CAService provides the x509CAService to be used for signing x509 requests func WithX509CAService(svc apiv1.CertificateAuthorityService) Option { return func(o *options) { o.x509CAService = svc } } // WithPassword sets the given password as the configured password in the CA // options. func WithPassword(password []byte) Option { return func(o *options) { o.password = password } } // WithSSHHostPassword sets the given password to decrypt the key used to sign // ssh host certificates. func WithSSHHostPassword(password []byte) Option { return func(o *options) { o.sshHostPassword = password } } // WithSSHUserPassword sets the given password to decrypt the key used to sign // ssh user certificates. func WithSSHUserPassword(password []byte) Option { return func(o *options) { o.sshUserPassword = password } } // WithIssuerPassword sets the given password as the configured certificate // issuer password in the CA options. func WithIssuerPassword(password []byte) Option { return func(o *options) { o.issuerPassword = password } } // WithDatabase sets the given authority database to the CA options. func WithDatabase(d db.AuthDB) Option { return func(o *options) { o.database = d } } // WithTLSConfig sets the TLS configuration to be used by the HTTP(s) server // spun by step-ca. 
func WithTLSConfig(t *tls.Config) Option { return func(o *options) { o.tlsConfig = t } } // WithLinkedCAToken sets the token used to authenticate with the linkedca. func WithLinkedCAToken(token string) Option { return func(o *options) { o.linkedCAToken = token } } // WithQuiet sets the quiet flag. func WithQuiet(quiet bool) Option { return func(o *options) { o.quiet = quiet } } // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { auth *authority.Authority config *config.Config srv *server.Server insecureSrv *server.Server metricsSrv *server.Server opts *options renewer *TLSRenewer compactStop chan struct{} } // New creates and initializes the CA with the given configuration and options. func New(cfg *config.Config, opts ...Option) (*CA, error) { ca := &CA{ config: cfg, opts: new(options), compactStop: make(chan struct{}), } ca.opts.apply(opts) return ca.Init(cfg) } // Init initializes the CA with the given configuration. func (ca *CA) Init(cfg *config.Config) (*CA, error) { // Set password, it's ok to set nil password, the ca will prompt for them if // they are required. 
opts := []authority.Option{ authority.WithPassword(ca.opts.password), authority.WithSSHHostPassword(ca.opts.sshHostPassword), authority.WithSSHUserPassword(ca.opts.sshUserPassword), authority.WithIssuerPassword(ca.opts.issuerPassword), } if ca.opts.linkedCAToken != "" { opts = append(opts, authority.WithLinkedCAToken(ca.opts.linkedCAToken)) } if ca.opts.database != nil { opts = append(opts, authority.WithDatabase(ca.opts.database)) } if ca.opts.quiet { opts = append(opts, authority.WithQuietInit()) } if ca.opts.x509CAService != nil { opts = append(opts, authority.WithX509CAService(ca.opts.x509CAService)) } var meter *metrix.Meter if ca.config.MetricsAddress != "" { meter = metrix.New() opts = append(opts, authority.WithMeter(meter)) } webhookTransport := httptransport.New() opts = append(opts, authority.WithWebhookClient(&http.Client{Transport: webhookTransport}), ) auth, err := authority.New(cfg, opts...) if err != nil { return nil, err } ca.auth = auth var tlsConfig *tls.Config var clientTLSConfig *tls.Config if ca.opts.tlsConfig != nil { // try using the tls Configuration supplied by the caller log.Print("Using tls configuration supplied by the application") tlsConfig = ca.opts.tlsConfig clientTLSConfig = ca.opts.tlsConfig } else { // default to using the step-ca x509 Signer Interface log.Print("Building new tls configuration using step-ca x509 Signer Interface") tlsConfig, clientTLSConfig, err = ca.getTLSConfig(auth) if err != nil { return nil, err } } webhookTransport.TLSClientConfig = clientTLSConfig // Using chi as the main router mux := chi.NewRouter() handler := http.Handler(mux) insecureMux := chi.NewRouter() insecureHandler := http.Handler(insecureMux) // Add HEAD middleware mux.Use(middleware.GetHead) insecureMux.Use(middleware.GetHead) // Add regular CA api endpoints in / and /1.0 api.Route(mux) mux.Route("/1.0", func(r chi.Router) { api.Route(r) }) // Mount the CRL to the insecure mux insecureMux.Get("/crl", api.CRL) insecureMux.Get("/1.0/crl", 
api.CRL) // Add ACME api endpoints in /acme and /1.0/acme dns := cfg.DNSNames[0] u, err := url.Parse("https://" + cfg.Address) if err != nil { return nil, err } port := u.Port() if port != "" && port != "443" { dns = fmt.Sprintf("%s:%s", dns, port) } // ACME Router is only available if we have a database. var acmeDB acme.DB var acmeLinker acme.Linker if cfg.DB == nil && auth.HasACMEProvisioner() { log.Println("WARNING: No database is configured. ACME provisioners are disabled.") } if cfg.DB != nil { acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) if err != nil { return nil, fmt.Errorf("error configuring ACME DB interface: %w", err) } acmeLinker = acme.NewLinker(dns, "acme") mux.Route("/acme", func(r chi.Router) { acmeAPI.Route(r) }) // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 // of the ACME spec. mux.Route("/2.0/acme", func(r chi.Router) { acmeAPI.Route(r) }) } // Admin API Router if cfg.AuthorityConfig.EnableAdmin { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() policyAdminResponder := adminAPI.NewPolicyAdminResponder() webhookAdminResponder := adminAPI.NewWebhookAdminResponder() mux.Route("/admin", func(r chi.Router) { adminAPI.Route( r, adminAPI.WithACMEResponder(acmeAdminResponder), adminAPI.WithPolicyResponder(policyAdminResponder), adminAPI.WithWebhookResponder(webhookAdminResponder), ) }) } } var scepAuthority *scep.Authority if ca.shouldServeSCEPEndpoints() { // get the SCEP authority configuration. Validation is // performed within the authority instantiation process. scepAuthority = auth.GetSCEP() // According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10), // SCEP operations are performed using HTTP, so that's why the API is mounted // to the insecure mux. 
scepPrefix := "scep" insecureMux.Route("/"+scepPrefix, func(r chi.Router) { scepAPI.Route(r) }) // The RFC also mentions usage of HTTPS, but seems to advise // against it, because of potential interoperability issues. // Currently I think it's not bad to use HTTPS also, so that's // why I've kept the API endpoints in both muxes and both HTTP // as well as HTTPS can be used to request certificates // using SCEP. mux.Route("/"+scepPrefix, func(r chi.Router) { scepAPI.Route(r) }) } // helpful routine for logging all routes //dumpRoutes(mux) //dumpRoutes(insecureMux) // Add monitoring if configured if len(cfg.Monitoring) > 0 { m, err := monitoring.New(cfg.Monitoring) if err != nil { return nil, err } handler = m.Middleware(handler) insecureHandler = m.Middleware(insecureHandler) } // Add logger if configured var legacyTraceHeader string if len(cfg.Logger) > 0 { logger, err := logging.New("ca", cfg.Logger) if err != nil { return nil, err } legacyTraceHeader = logger.GetTraceHeader() handler = logger.Middleware(handler) insecureHandler = logger.Middleware(insecureHandler) } // always use request ID middleware; traceHeader is provided for backwards compatibility (for now) handler = requestid.New(legacyTraceHeader).Middleware(handler) insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler) // Create context with all the necessary values. baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { return baseContext } // only start the insecure server if the insecure address is configured // and, currently, also only when it should serve SCEP endpoints. if ca.shouldServeInsecureServer() { // TODO: instead opt for having a single server.Server but two // http.Servers handling the HTTP and HTTPS handler? The latter // will probably introduce more complexity in terms of graceful // reload. 
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil) ca.insecureSrv.BaseContext = func(net.Listener) context.Context { return baseContext } } if meter != nil { ca.metricsSrv = server.New(ca.config.MetricsAddress, meter, nil) ca.metricsSrv.BaseContext = func(net.Listener) context.Context { return baseContext } } return ca, nil } // shouldServeInsecureServer returns whether or not the insecure // server should also be started. This is (currently) only the case // if the insecure address has been configured AND when a SCEP // provisioner is configured or when a CRL is configured. func (ca *CA) shouldServeInsecureServer() bool { switch { case ca.config.InsecureAddress == "": return false case ca.shouldServeSCEPEndpoints(): return true case ca.config.CRL.IsEnabled(): return true default: return false } } // buildContext builds the server base context. func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) } if adminDB := a.GetAdminDatabase(); adminDB != nil { ctx = admin.NewContext(ctx, adminDB) } if scepAuthority != nil { ctx = scep.NewContext(ctx, scepAuthority) } if acmeDB != nil { ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) } return ctx } // Run starts the CA calling to the server ListenAndServe method. 
func (ca *CA) Run() error { if !ca.opts.quiet { authorityInfo := ca.auth.GetInfo() log.Printf("Starting %s", step.Version()) log.Printf("Documentation: https://u.step.sm/docs/ca") log.Printf("Community Discord: https://u.step.sm/discord") if step.Contexts().GetCurrent() != nil { log.Printf("Current context: %s", step.Contexts().GetCurrent().Name) } log.Printf("Config file: %s", ca.getConfigFileOutput()) baseURL := fmt.Sprintf("https://%s%s", authorityInfo.DNSNames[0], ca.config.Address[strings.LastIndex(ca.config.Address, ":"):]) log.Printf("The primary server URL is %s", baseURL) log.Printf("Root certificates are available at %s/roots.pem", baseURL) if len(authorityInfo.DNSNames) > 1 { log.Printf("Additional configured hostnames: %s", strings.Join(authorityInfo.DNSNames[1:], ", ")) } for _, crt := range authorityInfo.RootX509Certs { log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt)) } if authorityInfo.SSHCAHostPublicKey != nil { log.Printf("SSH Host CA Key: %s\n", bytes.TrimSpace(authorityInfo.SSHCAHostPublicKey)) } if authorityInfo.SSHCAUserPublicKey != nil { log.Printf("SSH User CA Key: %s\n", bytes.TrimSpace(authorityInfo.SSHCAUserPublicKey)) } } eg := new(errgroup.Group) eg.Go(func() error { ca.runCompactJob() return nil }) if ca.insecureSrv != nil { eg.Go(func() error { return ca.insecureSrv.ListenAndServe() }) } if ca.metricsSrv != nil { eg.Go(func() error { return ca.metricsSrv.ListenAndServe() }) } eg.Go(func() error { return ca.srv.ListenAndServe() }) _, _ = daemon.SdNotify(true, daemon.SdNotifyReady) err := eg.Wait() _, _ = daemon.SdNotify(true, daemon.SdNotifyStopping) // if the error is not the usual HTTP server closed error, it is // highly likely that an error occurred when starting one of the // CA servers, possibly because of a port already being in use or // some part of the configuration not being correct. This case is // handled by stopping the CA in its entirety. 
if !errors.Is(err, http.ErrServerClosed) { log.Println("shutting down due to startup error ...") if stopErr := ca.Stop(); stopErr != nil { err = fmt.Errorf("failed stopping CA after error occurred: %w: %w", err, stopErr) } else { err = fmt.Errorf("stopped CA after error occurred: %w", err) } } return err } // Stop stops the CA calling to the server Shutdown method. func (ca *CA) Stop() error { close(ca.compactStop) if ca.renewer != nil { ca.renewer.Stop() } if err := ca.auth.Shutdown(); err != nil { log.Printf("error stopping ca.Authority: %+v\n", err) } // Concurrently shutdown services var eg errgroup.Group if ca.insecureSrv != nil { eg.Go(func() error { return ca.insecureSrv.Shutdown() }) } if ca.metricsSrv != nil { eg.Go(func() error { return ca.metricsSrv.Shutdown() }) } if ca.srv != nil { eg.Go(func() error { return ca.srv.Shutdown() }) } // Return first error return eg.Wait() } // Reload reloads the configuration of the CA and calls to the server Reload // method. func (ca *CA) Reload() error { _, _ = daemon.SdNotify(true, daemon.SdNotifyReloading) cfg, err := config.LoadConfiguration(ca.opts.configFile) if err != nil { return fmt.Errorf("error reloading ca configuration: %w", err) } logContinue := func(reason string) { log.Println(reason) log.Println("Continuing to run with the original configuration.") log.Println("You can force a restart by sending a SIGTERM signal and then restarting the step-ca.") } // Do not allow reload if the database configuration has changed. 
if !reflect.DeepEqual(ca.config.DB, cfg.DB) { logContinue("Reload failed because the database configuration has changed.") return errors.New("error reloading ca: database configuration cannot change") } newCA, err := New(cfg, WithPassword(ca.opts.password), WithSSHHostPassword(ca.opts.sshHostPassword), WithSSHUserPassword(ca.opts.sshUserPassword), WithIssuerPassword(ca.opts.issuerPassword), WithLinkedCAToken(ca.opts.linkedCAToken), WithQuiet(ca.opts.quiet), WithConfigFile(ca.opts.configFile), WithDatabase(ca.auth.GetDatabase()), ) if err != nil { logContinue("Reload failed because the CA with new configuration could not be initialized.") return fmt.Errorf("error reloading ca: %w", err) } if ca.insecureSrv != nil { if err = ca.insecureSrv.Reload(newCA.insecureSrv); err != nil { logContinue("Reload failed because insecure server could not be replaced.") return fmt.Errorf("error reloading insecure server: %w", err) } } if ca.metricsSrv != nil { if err = ca.metricsSrv.Reload(newCA.metricsSrv); err != nil { logContinue("Reload failed because metrics server could not be replaced.") return fmt.Errorf("error reloading metrics server: %w", err) } } if err = ca.srv.Reload(newCA.srv); err != nil { logContinue("Reload failed because server could not be replaced.") return fmt.Errorf("error reloading server: %w", err) } // 1. Stop previous renewer // 2. Safely shutdown any internal resources (e.g. key manager) // 3. Replace ca properties // Do not replace ca.srv if ca.renewer != nil { ca.renewer.Stop() } ca.auth.CloseForReload() ca.auth = newCA.auth ca.config = newCA.config ca.opts = newCA.opts ca.renewer = newCA.renewer _, _ = daemon.SdNotify(true, daemon.SdNotifyReady) return nil } // get TLSConfig returns separate TLSConfigs for server and client with the // same self-renewing certificate. 
func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, *tls.Config, error) { // Create initial TLS certificate tlsCrt, err := auth.GetTLSCertificate() if err != nil { return nil, nil, err } // Start tls renewer with the new certificate. // If a renewer was started, attempt to stop it before. if ca.renewer != nil { ca.renewer.Stop() } ca.renewer, err = NewTLSRenewer(tlsCrt, auth.GetTLSCertificate) if err != nil { return nil, nil, err } ca.renewer.Run() var serverTLSConfig *tls.Config if ca.config.TLS != nil { serverTLSConfig = ca.config.TLS.TLSConfig() } else { serverTLSConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } // GetCertificate will only be called if the client supplies SNI // information or if tlsConfig.Certificates is empty. // When client requests are made using an IP address (as opposed to a domain // name) the server does not receive any SNI and may fallback to using the // first entry in the Certificates attribute; by setting the attribute to // empty we are implicitly forcing GetCertificate to be the only mechanism // by which the server can find it's own leaf Certificate. serverTLSConfig.Certificates = []tls.Certificate{} clientTLSConfig := serverTLSConfig.Clone() serverTLSConfig.GetCertificate = ca.renewer.GetCertificateForCA clientTLSConfig.GetClientCertificate = ca.renewer.GetClientCertificate // initialize a certificate pool with root CA certificates to trust when doing mTLS. certPool := x509.NewCertPool() // initialize a certificate pool with root CA certificates to trust when connecting // to webhook servers rootCAsPool, err := x509.SystemCertPool() if err != nil { return nil, nil, err } for _, crt := range auth.GetRootCertificates() { certPool.AddCert(crt) rootCAsPool.AddCert(crt) } // adding the intermediate CA certificates to the pool will allow clients that // do mTLS but don't send an intermediate to successfully connect. The intermediates // added here are used when building a certificate chain. 
intermediates := tlsCrt.Certificate[1:] for _, certBytes := range intermediates { cert, err := x509.ParseCertificate(certBytes) if err != nil { return nil, nil, err } certPool.AddCert(cert) rootCAsPool.AddCert(cert) } // Add support for mutual tls to renew certificates serverTLSConfig.ClientAuth = tls.VerifyClientCertIfGiven serverTLSConfig.ClientCAs = certPool clientTLSConfig.RootCAs = rootCAsPool return serverTLSConfig, clientTLSConfig, nil } // shouldServeSCEPEndpoints returns if the CA should be // configured with endpoints for SCEP. This is assumed to be // true if a SCEPService exists, which is true in case at // least one SCEP provisioner was configured. func (ca *CA) shouldServeSCEPEndpoints() bool { return ca.auth.GetSCEP() != nil } //nolint:unused // useful for debugging func dumpRoutes(mux chi.Routes) { // helpful routine for logging all routes walkFunc := func(method string, route string, _ http.Handler, _ ...func(http.Handler) http.Handler) error { fmt.Printf("%s %s\n", method, route) return nil } if err := chi.Walk(mux, walkFunc); err != nil { fmt.Printf("Logging err: %s\n", err.Error()) } } func (ca *CA) getConfigFileOutput() string { if ca.config.WasLoadedFromFile() { return ca.config.Filepath() } return "loaded from token" } // runCompactJob will run the value log garbage collector if the nosql database // supports it. func (ca *CA) runCompactJob() { caDB, ok := ca.auth.GetDatabase().(*db.DB) if !ok { return } compactor, ok := caDB.DB.(nosql.Compactor) if !ok { return } // Compact database at start. runCompact(compactor) // Compact database every minute. ticker := time.NewTicker(time.Minute) defer ticker.Stop() for { select { case <-ca.compactStop: return case <-ticker.C: runCompact(compactor) } } } // runCompact executes the compact job until it returns an error. 
func runCompact(c nosql.Compactor) {
	// The loop has no post statement: it simply calls Compact again and again
	// and ends on the first non-nil error.
	for err := error(nil); err == nil; {
		err = c.Compact(0.7)
	}
}

================================================
FILE: ca/ca_test.go
================================================
package ca

import (
	"bytes"
	"context"
	"crypto"
	"crypto/rand"
	"crypto/sha1" //nolint:gosec // used to create the Subject Key Identifier by RFC 5280
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/api"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/errs"
	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/keyutil"
	"go.step.sm/crypto/pemutil"
	"go.step.sm/crypto/randutil"
	"go.step.sm/crypto/x509util"
)

// ClosingBuffer adds a no-op Close method to a bytes.Buffer so a recorded
// response body can be used as the Body of an http.Response.
type ClosingBuffer struct {
	*bytes.Buffer
}

// Close does nothing and always returns nil.
func (cb *ClosingBuffer) Close() error {
	return nil
}

// getCSR creates a certificate request for "test.smallstep.com" signed with
// priv and returns it parsed back from its DER encoding.
func getCSR(priv interface{}) (*x509.CertificateRequest, error) {
	_csr := &x509.CertificateRequest{
		Subject:  pkix.Name{CommonName: "test.smallstep.com"},
		DNSNames: []string{"test.smallstep.com"},
	}
	csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv)
	if err != nil {
		return nil, err
	}
	return x509.ParseCertificateRequest(csrBytes)
}

// generateSubjectKeyID returns the SHA-1 digest of the SubjectPublicKey bit
// string taken from the PKIX encoding of pub. The tests compare it with the
// SubjectKeyId of issued certificates.
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
	b, err := x509.MarshalPKIXPublicKey(pub)
	if err != nil {
		return nil, errors.Wrap(err, "error marshaling public key")
	}
	// Decode just enough of the SubjectPublicKeyInfo structure to reach the
	// raw public key bits.
	info := struct {
		Algorithm        pkix.AlgorithmIdentifier
		SubjectPublicKey asn1.BitString
	}{}
	if _, err = asn1.Unmarshal(b, &info); err != nil {
		return nil, errors.Wrap(err, "error unmarshaling public key")
	}
	//nolint:gosec // used to create the Subject Key Identifier by RFC 5280
	hash := sha1.Sum(info.SubjectPublicKey.Bytes)
	return hash[:], nil
}

// TestMain sets DisableIdentity before running the tests of this package.
func TestMain(m *testing.M) {
	DisableIdentity = true
	os.Exit(m.Run())
}

// TestCASign sends POST /sign requests straight to the CA handler and checks
// the status code, the error prefix on failures, and the fields of the issued
// leaf and intermediate certificates on success.
func TestCASign(t *testing.T) {
	pub, priv, err := keyutil.GenerateDefaultKeyPair()
	assert.FatalError(t, err)

	asn1dn := &authority.ASN1DN{
		Country:       "Tazmania",
		Organization:  "Acme Co",
		Locality:      "Landscapes",
		Province:      "Sudden Cliffs",
		StreetAddress: "TNT",
		CommonName:    "test.smallstep.com",
	}

	config, err := authority.LoadConfiguration("testdata/ca.json")
	assert.FatalError(t, err)
	config.AuthorityConfig.Template = asn1dn
	ca, err := New(config)
	assert.FatalError(t, err)

	intermediateCert, err := pemutil.ReadCertificate("testdata/secrets/intermediate_ca.crt")
	assert.FatalError(t, err)

	// Signer used to mint the one-time tokens sent in the requests below.
	clijwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
	assert.FatalError(t, err)
	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: clijwk.Key},
		(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", clijwk.KeyID))
	assert.FatalError(t, err)

	validAud := []string{"https://127.0.0.1:0/sign"}
	now := time.Now().UTC()
	leafExpiry := now.Add(time.Minute * 5)

	type signTest struct {
		ca     *CA
		body   string
		status int
		errMsg string
	}
	tests := map[string]func(t *testing.T) *signTest{
		"fail invalid-json-body": func(t *testing.T) *signTest {
			return &signTest{
				ca:     ca,
				body:   "invalid json",
				status: http.StatusBadRequest,
				errMsg: errs.BadRequestPrefix,
			}
		},
		"fail invalid-csr-sig": func(t *testing.T) *signTest {
			der := []byte(`-----BEGIN CERTIFICATE REQUEST-----
MIIDNjCCAh4CAQAwYzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQH
DA1TYW4gRnJhbmNpc2NvMRIwEAYDVQQKDAlzbWFsbHN0ZXAxGzAZBgNVBAMMEnRl
c3Quc21hbGxzdGVwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ANPahliigZ38QpBLmQMS3MVKKZ5gapNjqR7LIEYoYWa4lTFiUnbwg8tSfIFcgLZr
jNIxn7/98+JOJHKgS03NhFJoS5hej0LyypleOGJ0nk2qawYVKnn1ftoKjkfxkfZI
a/5rsDF1jhNBspB/KPHWE0eimKQJbUiVG1zA1sExnXDecF3vJfBj+DPDWngx4yxR
/jYEKjt4tQ6Ei752TbosrCHYeYXzkr6iAwiNz6vT/ewLb6b8JmuN8X6Y1I9ogDGx
hntBJ1jAK8x3IGTjYbkm+mqVuCyhNcHtGfEHcBnUEzLAPrVFn8kGiAnU17FJ0uQ7
1C9CtUzgBRZCxSBm6Qs+Zs8CAwEAAaCBjTCBigYJKoZIhvcNAQkOMX0wezAMBgNV
HRMBAf8EAjAAMB0GA1UdJQQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAOBgNVHQ8B
Af8EBAMCBaAwHQYDVR0RBBYwFIISdGVzdC5zbWFsbHN0ZXAuY29tMB0GA1UdDgQW
BBQj6N4RTAAjhV3UBYXH72mkdOGpqzANBgkqhkiG9w0BAQsFAAOCAQEAN0/ivCBk
FD53SqtRmqqc7C9saoRNvV+wDi4Sg6YGLFQLjbZPJrqQURWdHtV9O3sb3p8O5erX
9Kgq3C7fqd//0mro4GZ1GTpjsPKIMocZFfH7zEhAZlvQLRKWICjoBaOwxQum2qY/
B3+ltAXb4uqGdbI0jPkkyWGN5CQhK+ZHoYe/zGtTEmHBcPxRtJJkukQQjUgZhjU2
Z7K+w3AjOxj47XLNHHlW83QYUJ2mN+mEZF9DhrZb2ydYOlpy0V2NJwv7QrmnFaDj
R0v3BFLTblIp100li3oV2QaM/yESrgo9XIjEEGzCGz5cNs5ovNadufUZDCJyyT4q
ZEp7knvU2psWRw==
-----END CERTIFICATE REQUEST-----`)
			block, _ := pem.Decode(der)
			assert.NotNil(t, block)
			csr, err := x509.ParseCertificateRequest(block.Bytes)
			assert.FatalError(t, err)

			body, err := json.Marshal(&api.SignRequest{
				CsrPEM: api.CertificateRequest{CertificateRequest: csr},
				OTT:    "foo",
			})
			assert.FatalError(t, err)
			return &signTest{
				ca:     ca,
				body:   string(body),
				status: http.StatusBadRequest,
				errMsg: errs.BadRequestPrefix,
			}
		},
		"fail unauthorized-ott": func(t *testing.T) *signTest {
			csr, err := getCSR(priv)
			assert.FatalError(t, err)
			body, err := json.Marshal(&api.SignRequest{
				CsrPEM: api.CertificateRequest{CertificateRequest: csr},
				OTT:    "foo",
			})
			assert.FatalError(t, err)
			return &signTest{
				ca:     ca,
				body:   string(body),
				status: http.StatusUnauthorized,
				errMsg: errs.UnauthorizedDefaultMsg,
			}
		},
		"fail commonname-claim": func(t *testing.T) *signTest {
			jti, err := randutil.ASCII(32)
			assert.FatalError(t, err)
			// The token subject and SANs ("invalid") do not match the CSR
			// created by getCSR.
			cl := struct {
				jose.Claims
				SANS []string `json:"sans"`
			}{
				Claims: jose.Claims{
					Subject:   "invalid",
					Issuer:    "step-cli",
					NotBefore: jose.NewNumericDate(now),
					Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
					Audience:  validAud,
					ID:        jti,
				},
				SANS: []string{"invalid"},
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			csr, err := getCSR(priv)
			assert.FatalError(t, err)
			body, err := json.Marshal(&api.SignRequest{
				CsrPEM: api.CertificateRequest{CertificateRequest: csr},
				OTT:    raw,
			})
			assert.FatalError(t, err)
			return &signTest{
				ca:     ca,
				body:   string(body),
				status: http.StatusForbidden,
				errMsg: errs.ForbiddenPrefix,
			}
		},
		"ok": func(t *testing.T) *signTest {
			jti, err := randutil.ASCII(32)
			assert.FatalError(t, err)
			cl := struct {
				jose.Claims
				SANS []string `json:"sans"`
			}{
				Claims: jose.Claims{
					Subject:   "test.smallstep.com",
					Issuer:    "step-cli",
					NotBefore: jose.NewNumericDate(now),
					Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
					Audience:  validAud,
					ID:        jti,
				},
				SANS: []string{"test.smallstep.com"},
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			csr, err := getCSR(priv)
			assert.FatalError(t, err)
			body, err := json.Marshal(&api.SignRequest{
				CsrPEM:    api.CertificateRequest{CertificateRequest: csr},
				OTT:       raw,
				NotBefore: api.NewTimeDuration(now),
				NotAfter:  api.NewTimeDuration(leafExpiry),
			})
			assert.FatalError(t, err)
			return &signTest{
				ca:     ca,
				body:   string(body),
				status: http.StatusCreated,
			}
		},
		"ok-backwards-compat-missing-subject-SAN": func(t *testing.T) *signTest {
			jti, err := randutil.ASCII(32)
			assert.FatalError(t, err)
			// Same as "ok" but the token carries no SANS claim value.
			cl := struct {
				jose.Claims
				SANS []string `json:"sans"`
			}{
				Claims: jose.Claims{
					Subject:   "test.smallstep.com",
					Issuer:    "step-cli",
					NotBefore: jose.NewNumericDate(now),
					Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
					Audience:  validAud,
					ID:        jti,
				},
			}
			raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
			assert.FatalError(t, err)
			csr, err := getCSR(priv)
			assert.FatalError(t, err)
			body, err := json.Marshal(&api.SignRequest{
				CsrPEM:    api.CertificateRequest{CertificateRequest: csr},
				OTT:       raw,
				NotBefore: api.NewTimeDuration(now),
				NotAfter:  api.NewTimeDuration(leafExpiry),
			})
			assert.FatalError(t, err)
			return &signTest{
				ca:     ca,
				body:   string(body),
				status: http.StatusCreated,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			rq, err := http.NewRequest("POST", "/sign", strings.NewReader(tc.body))
			assert.FatalError(t, err)
			rr := httptest.NewRecorder()
			ctx := authority.NewContext(context.Background(), tc.ca.auth)
			tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))

			if assert.Equals(t, rr.Code, tc.status) {
				body := &ClosingBuffer{rr.Body}
				resp := &http.Response{
					Body: body,
				}
				if rr.Code < http.StatusBadRequest {
					var sign api.SignResponse
					assert.FatalError(t, readJSON(body, &sign))
					leaf := sign.ServerPEM.Certificate
					intermediate := sign.CaPEM.Certificate

					assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
					assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))

					assert.Equals(t, leaf.Subject.String(),
						pkix.Name{
							Country:       []string{asn1dn.Country},
							Organization:  []string{asn1dn.Organization},
							Locality:      []string{asn1dn.Locality},
							StreetAddress: []string{asn1dn.StreetAddress},
							Province:      []string{asn1dn.Province},
							CommonName:    asn1dn.CommonName,
						}.String())
					assert.Equals(t, leaf.Issuer, intermediate.Subject)

					assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
					assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
					assert.Equals(t, leaf.ExtKeyUsage,
						[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
					assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})

					subjectKeyID, err := generateSubjectKeyID(pub)
					assert.FatalError(t, err)
					assert.Equals(t, leaf.SubjectKeyId, subjectKeyID)

					assert.Equals(t, leaf.AuthorityKeyId, intermediateCert.SubjectKeyId)

					realIntermediate, err := x509.ParseCertificate(intermediateCert.Raw)
					assert.FatalError(t, err)
					assert.Equals(t, intermediate, realIntermediate)
				} else {
					err := readError(resp)
					// Every failing case must declare the error prefix it expects.
					if tc.errMsg == "" {
						assert.FatalError(t, errors.New("must validate response error"))
					}
					assert.HasPrefix(t, err.Error(), tc.errMsg)
				}
			}
		})
	}
}

// TestCAProvisioners requests GET /provisioners from the CA handler and
// compares the JSON encoding of the returned provisioners with the JSON
// encoding of the provisioners in the loaded configuration.
func TestCAProvisioners(t *testing.T) {
	config, err := authority.LoadConfiguration("testdata/ca.json")
	assert.FatalError(t, err)
	ca, err := New(config)
	assert.FatalError(t, err)

	type ekt struct {
		ca     *CA
		status int
		errMsg string
	}
	tests := map[string]func(t *testing.T) *ekt{
		"ok": func(t *testing.T) *ekt {
			return &ekt{
				ca:     ca,
				status: http.StatusOK,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			rq, err := http.NewRequest("GET", "/provisioners", strings.NewReader(""))
			assert.FatalError(t, err)
			rr := httptest.NewRecorder()
			ctx := authority.NewContext(context.Background(), tc.ca.auth)
			tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))

			if assert.Equals(t, rr.Code, tc.status) {
				body := &ClosingBuffer{rr.Body}
				resp := &http.Response{
					Body: body,
				}
				if rr.Code < http.StatusBadRequest {
					// NOTE: this resp shadows the *http.Response declared above.
					var resp api.ProvisionersResponse

					assert.FatalError(t, readJSON(body, &resp))
					a, err := json.Marshal(config.AuthorityConfig.Provisioners)
					assert.FatalError(t, err)
					b, err := json.Marshal(resp.Provisioners)
					assert.FatalError(t, err)
					assert.Equals(t, a, b)
				} else {
					err := readError(resp)
					if tc.errMsg == "" {
						assert.FatalError(t, errors.New("must validate response error"))
					}
					assert.HasPrefix(t, err.Error(), tc.errMsg)
				}
			}
		})
	}
}

// TestCAProvisionerEncryptedKey requests GET /provisioners/{kid}/encrypted-key
// for an unknown key ID and for the key ID of the JWK provisioner at index 2
// of the loaded configuration.
func TestCAProvisionerEncryptedKey(t *testing.T) {
	config, err := authority.LoadConfiguration("testdata/ca.json")
	assert.FatalError(t, err)
	ca, err := New(config)
	assert.FatalError(t, err)

	type ekt struct {
		ca          *CA
		kid         string
		expectedKey string
		status      int
		errMsg      string
	}
	tests := map[string]func(t *testing.T) *ekt{
		"not-found": func(t *testing.T) *ekt {
			return &ekt{
				ca:     ca,
				kid:    "foo",
				status: http.StatusNotFound,
				errMsg: errs.NotFoundDefaultMsg,
			}
		},
		"ok": func(t *testing.T) *ekt {
			p := config.AuthorityConfig.Provisioners[2].(*provisioner.JWK)
			return &ekt{
				ca:          ca,
				kid:         p.Key.KeyID,
				expectedKey: p.EncryptedKey,
				status:      http.StatusOK,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			rq, err := http.NewRequest("GET", fmt.Sprintf("/provisioners/%s/encrypted-key", tc.kid), strings.NewReader(""))
			assert.FatalError(t, err)
			rr := httptest.NewRecorder()
			ctx := authority.NewContext(context.Background(), tc.ca.auth)
			tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))

			if assert.Equals(t, rr.Code, tc.status) {
				body := &ClosingBuffer{rr.Body}
				resp := &http.Response{
					Body: body,
				}
				if rr.Code < http.StatusBadRequest {
					var ek api.ProvisionerKeyResponse
					assert.FatalError(t, readJSON(body, &ek))
					assert.Equals(t, ek.Key, tc.expectedKey)
				} else {
					err := readError(resp)
					if tc.errMsg == "" {
						assert.FatalError(t, errors.New("must validate response error"))
					}
					assert.HasPrefix(t, err.Error(), tc.errMsg)
				}
			}
		})
	}
}

// TestCARoot requests GET /root/{sha} for an unknown fingerprint and for a
// fixed fingerprint whose response must equal testdata/secrets/root_ca.crt.
func TestCARoot(t *testing.T) {
	config, err := authority.LoadConfiguration("testdata/ca.json")
	assert.FatalError(t, err)
	ca, err := New(config)
	assert.FatalError(t, err)

	rootCrt, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
	assert.FatalError(t, err)

	type rootTest struct {
		ca     *CA
		sha    string
		status int
		errMsg string
	}
	tests := map[string]func(t *testing.T) *rootTest{
		"not-found": func(t *testing.T) *rootTest {
			return &rootTest{
				ca:     ca,
				sha:    "foo",
				status: http.StatusNotFound,
				errMsg: `root certificate with fingerprint "foo" was not found`,
			}
		},
		"success": func(t *testing.T) *rootTest {
			return &rootTest{
				ca:     ca,
				sha:    "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7",
				status: http.StatusOK,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			rq, err := http.NewRequest("GET", fmt.Sprintf("/root/%s", tc.sha), strings.NewReader(""))
			assert.FatalError(t, err)
			rr := httptest.NewRecorder()
			ctx := authority.NewContext(context.Background(), tc.ca.auth)
			tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))

			if assert.Equals(t, rr.Code, tc.status) {
				body := &ClosingBuffer{rr.Body}
				resp := &http.Response{
					Body: body,
				}
				if rr.Code < http.StatusBadRequest {
					var root api.RootResponse
					assert.FatalError(t, readJSON(body, &root))
					assert.Equals(t, root.RootPEM.Certificate, rootCrt)
				} else {
					err := readError(resp)
					if tc.errMsg == "" {
						assert.FatalError(t, errors.New("must validate response error"))
					}
					assert.HasPrefix(t, err.Error(), tc.errMsg)
				}
			}
		})
	}
}

// TestCAHealth requests GET /health from the CA handler and expects a 200
// response whose body decodes to a HealthResponse with status "ok".
func TestCAHealth(t *testing.T) {
	config, err := authority.LoadConfiguration("testdata/ca.json")
	assert.FatalError(t, err)
	ca, err := New(config)
	assert.FatalError(t, err)

	type rootTest struct {
		ca     *CA
		status int
	}
	tests := map[string]func(t *testing.T) *rootTest{
		"success": func(t *testing.T) *rootTest {
			return &rootTest{
				ca:     ca,
				status: http.StatusOK,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			rq, err := http.NewRequest("GET", "/health", strings.NewReader(""))
			assert.FatalError(t, err)
			rr := httptest.NewRecorder()
			ctx := authority.NewContext(context.Background(), tc.ca.auth)
			tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))

			if assert.Equals(t, rr.Code, tc.status) {
				body := &ClosingBuffer{rr.Body}
				if rr.Code < http.StatusBadRequest {
					var health api.HealthResponse
					assert.FatalError(t, readJSON(body, &health))
					assert.Equals(t, health, api.HealthResponse{Status: "ok"})
				}
			}
		})
	}
}

// TestCARenew sends POST /renew requests to the CA handler with different TLS
// connection states: none, one without peer certificates, and one carrying a
// leaf certificate signed by the test intermediate.
func TestCARenew(t *testing.T) {
	pub, priv, err := keyutil.GenerateDefaultKeyPair()
	assert.FatalError(t, err)

	asn1dn := &authority.ASN1DN{
		Country:       "Tazmania",
		Organization:  "Acme Co",
		Locality:      "Landscapes",
		Province:      "Sudden Cliffs",
		StreetAddress: "TNT",
		CommonName:    "test",
	}

	config, err := authority.LoadConfiguration("testdata/ca.json")
	assert.FatalError(t, err)
	config.AuthorityConfig.Template = asn1dn
	ca, err := New(config)
	assert.FatalError(t, err)
	assert.FatalError(t, err)

	intermediateCert, err := pemutil.ReadCertificate("testdata/secrets/intermediate_ca.crt")
	assert.FatalError(t, err)
	intermediateKey, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("password")))
	assert.FatalError(t, err)

	now := time.Now().UTC()
	leafExpiry := now.Add(time.Minute * 5)

	type renewTest struct {
		ca           *CA
		tlsConnState *tls.ConnectionState
		status       int
		errMsg       string
	}
	tests := map[string]func(t *testing.T) *renewTest{
		"request-missing-tls": func(t *testing.T) *renewTest {
			return &renewTest{
				ca:           ca,
				tlsConnState: nil,
				status:       http.StatusBadRequest,
				errMsg:       errs.BadRequestPrefix,
			}
		},
		"request-missing-peer-certificate": func(t *testing.T) *renewTest {
			return &renewTest{
				ca:           ca,
				tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
				status:       http.StatusBadRequest,
				errMsg:       errs.BadRequestPrefix,
			}
		},
		"success": func(t *testing.T) *renewTest {
			// Issue a leaf for "test" / "funk" directly from the test
			// intermediate and present it as the peer certificate.
			cr, err := x509util.CreateCertificateRequest("test", []string{"funk"}, priv.(crypto.Signer))
			assert.FatalError(t, err)
			cert, err := x509util.NewCertificate(cr)
			assert.FatalError(t, err)
			crt := cert.GetCertificate()
			crt.NotBefore = now
			crt.NotAfter = leafExpiry
			crt, err = x509util.CreateCertificate(crt, intermediateCert, pub, intermediateKey.(crypto.Signer))
			assert.FatalError(t, err)

			return &renewTest{
				ca: ca,
				tlsConnState: &tls.ConnectionState{
					PeerCertificates: []*x509.Certificate{crt},
				},
				status: http.StatusCreated,
			}
		},
	}

	for name, genTestCase := range tests {
		t.Run(name, func(t *testing.T) {
			tc := genTestCase(t)

			rq, err := http.NewRequest("POST", "/renew", strings.NewReader(""))
			assert.FatalError(t, err)
			rq.TLS = tc.tlsConnState
			rr := httptest.NewRecorder()
			ctx := authority.NewContext(context.Background(), tc.ca.auth)
			tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))

			if assert.Equals(t, rr.Code, tc.status) {
				body := &ClosingBuffer{rr.Body}
				resp := &http.Response{
					Body: body,
				}
				if rr.Code < http.StatusBadRequest {
					var sign api.SignResponse
					assert.FatalError(t, readJSON(body, &sign))
					leaf := sign.ServerPEM.Certificate
					intermediate := sign.CaPEM.Certificate

					assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
					assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))

					assert.Equals(t, leaf.Subject.String(),
						pkix.Name{
							CommonName: asn1dn.CommonName,
						}.String())
					assert.Equals(t, leaf.Issuer, intermediate.Subject)

					assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
					assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
					assert.Equals(t, leaf.ExtKeyUsage,
						[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
					assert.Equals(t, leaf.DNSNames, []string{"funk"})

					subjectKeyID, err := generateSubjectKeyID(pub)
					assert.FatalError(t, err)
					assert.Equals(t, leaf.SubjectKeyId, subjectKeyID)

					assert.Equals(t, leaf.AuthorityKeyId, intermediateCert.SubjectKeyId)

					realIntermediate, err := x509.ParseCertificate(intermediateCert.Raw)
					assert.FatalError(t, err)
					assert.Equals(t, intermediate, realIntermediate)

					assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
				} else {
					err := readError(resp)
					if tc.errMsg == "" {
						assert.FatalError(t, errors.New("must validate response error"))
					}
					assert.HasPrefix(t, err.Error(), tc.errMsg)
				}
			}
		})
	}
}

================================================
FILE: ca/client/requestid.go
================================================
package client

import "context"

// contextKey is the unexported type of the context key under which the
// request ID is stored.
type contextKey struct{}

// NewRequestIDContext returns a new context with the given request ID added to the
// context.
func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
	return context.WithValue(ctx, contextKey{}, requestID)
}

// RequestIDFromContext returns the request ID from the context if it exists
// and is not empty.
func RequestIDFromContext(ctx context.Context) (string, bool) {
	// The boolean is true only when a string value is stored under
	// contextKey{} and that string is not empty.
	v, ok := ctx.Value(contextKey{}).(string)
	return v, ok && v != ""
}

================================================
FILE: ca/client.go
================================================
package ca

import (
	"bytes"
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/hex"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/pkg/errors"
	"golang.org/x/net/http2"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"

	"github.com/smallstep/cli-utils/step"
	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/keyutil"
	"go.step.sm/crypto/pemutil"
	"go.step.sm/crypto/randutil"
	"go.step.sm/crypto/x509util"

	"github.com/smallstep/certificates/api"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/ca/client"
	"github.com/smallstep/certificates/ca/identity"
	"github.com/smallstep/certificates/errs"
)

// DisableIdentity is a global variable to disable the identity. When it is
// true, client options are built without loading the default identity.
var DisableIdentity = false

// UserAgent will set the User-Agent header in the client requests.
var UserAgent = "step-http-client/1.0" type uaClient struct { Client *http.Client } func newClient(transport http.RoundTripper, timeout time.Duration) *uaClient { return &uaClient{ Client: &http.Client{ Transport: transport, Timeout: timeout, }, } } //nolint:gosec // used in bootstrap protocol func newInsecureClient() *uaClient { return &uaClient{ Client: &http.Client{ Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}), }, } } func (c *uaClient) GetTransport() http.RoundTripper { return c.Client.Transport } func (c *uaClient) SetTransport(tr http.RoundTripper) { c.Client.Transport = tr } func (c *uaClient) CloseIdleConnections() { c.Client.CloseIdleConnections() } func (c *uaClient) Get(u string) (*http.Response, error) { return c.GetWithContext(context.Background(), u) } func (c *uaClient) GetWithContext(ctx context.Context, u string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, "GET", u, http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } return c.Do(req) } func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) { return c.PostWithContext(context.Background(), u, contentType, body) } func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, body io.Reader) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, "POST", u, body) if err != nil { return nil, errors.Wrapf(err, "create POST %s request failed", u) } req.Header.Set("Content-Type", contentType) return c.Do(req) } // requestIDHeader is the header name used for propagating request IDs from // the CA client to the CA and back again. const requestIDHeader = "X-Request-Id" // newRequestID generates a new random UUIDv4 request ID. If it fails, // the request ID will be the empty string. 
func newRequestID() string { requestID, err := randutil.UUIDv4() if err != nil { return "" } return requestID } // enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's // empty, the context is searched for a request ID. If that's also empty, a new // request ID is generated. func enforceRequestID(r *http.Request) { if requestID := r.Header.Get(requestIDHeader); requestID == "" { if reqID, ok := client.RequestIDFromContext(r.Context()); ok { // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been // used before by the client (unless it's a retry for the same request)? requestID = reqID } else { requestID = newRequestID() } r.Header.Set(requestIDHeader, requestID) } } func (c *uaClient) Do(req *http.Request) (*http.Response, error) { req.Header.Set("User-Agent", UserAgent) enforceRequestID(req) return c.Client.Do(req) //nolint:gosec // request to user-configured CA server } // RetryFunc defines the method used to retry a request. If it returns true, the // request will be retried once. type RetryFunc func(code int) bool // ClientOption is the type of options passed to the Client constructor. type ClientOption func(o *clientOptions) error // TransportDecorator is the type used to support customization of the HTTP // transport. 
type TransportDecorator func(http.RoundTripper) http.RoundTripper type clientOptions struct { transport http.RoundTripper transportDecorator TransportDecorator timeout time.Duration rootSHA256 string rootFilename string rootBundle []byte certificate tls.Certificate getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) retryFunc RetryFunc x5cJWK *jose.JSONWebKey x5cCertFile string x5cCertStrs []string x5cCert *x509.Certificate x5cSubject string } func (o *clientOptions) apply(opts []ClientOption) (err error) { o.applyDefaultIdentity() for _, fn := range opts { if err = fn(o); err != nil { return } } return } // applyDefaultIdentity sets the options for the default identity if the // identity file is present. The identity is enabled by default. func (o *clientOptions) applyDefaultIdentity() { if DisableIdentity { return } // Do not load an identity if something fails i, err := identity.LoadDefaultIdentity() if err != nil { return } if err := i.Validate(); err != nil { return } crt, err := i.TLSCertificate() if err != nil { return } o.certificate = crt o.getClientCertificate = i.GetClientCertificateFunc() } // checkTransport checks if other ways to set up a transport have been provided. // If they have it returns an error. func (o *clientOptions) checkTransport() error { if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" || o.rootBundle != nil { return errors.New("multiple transport methods have been configured") } return nil } // getTransport returns the transport configured in the clientOptions. 
func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err error) { if o.transport != nil { tr = o.transport } if o.rootFilename != "" { if tr, err = getTransportFromFile(o.rootFilename); err != nil { return nil, err } } if o.rootSHA256 != "" { if tr, err = getTransportFromSHA256(endpoint, o.rootSHA256); err != nil { return nil, err } } if o.rootBundle != nil { if tr, err = getTransportFromCABundle(o.rootBundle); err != nil { return nil, err } } // As the last option attempt to load the default root ca if tr == nil { rootFile := getRootCAPath() if _, err := os.Stat(rootFile); err == nil { if tr, err = getTransportFromFile(rootFile); err != nil { return nil, err } } if tr == nil { return nil, errors.New("a transport, a root cert, or a root sha256 must be used") } } // Add client certificate if available if o.certificate.Certificate != nil { switch tr := tr.(type) { case *http.Transport: if tr.TLSClientConfig == nil { tr.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } case *http2.Transport: if tr.TLSClientConfig == nil { tr.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, } } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } default: return nil, errors.Errorf("unsupported transport type %T", tr) } } // Wrap the transport using the decorator function if necessary return decorateRoundTripper(tr, o.transportDecorator), nil } // WithTransport adds a custom transport to the Client. It will fail if a // previous option to create the transport has been configured. 
func WithTransport(tr http.RoundTripper) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { return err } o.transport = tr return nil } } // WithTransportDecorator allows customization of the HTTP transport used by the // client. The provided function receives the configured [http.RoundTripper] and // can wrap it with additional functionality. func WithTransportDecorator(fn TransportDecorator) ClientOption { return func(o *clientOptions) error { o.transportDecorator = fn return nil } } // WithInsecure adds a insecure transport that bypasses TLS verification. func WithInsecure() ClientOption { return func(o *clientOptions) error { o.transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, //nolint:gosec // insecure option InsecureSkipVerify: true, }, } return nil } } // WithRootFile will create the transport using the given root certificate. It // will fail if a previous option to create the transport has been configured. func WithRootFile(filename string) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { return err } o.rootFilename = filename return nil } } // WithRootSHA256 will create the transport using an insecure client to retrieve // the root certificate using its fingerprint. It will fail if a previous option // to create the transport has been configured. func WithRootSHA256(sum string) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { return err } o.rootSHA256 = sum return nil } } // WithCABundle will create the transport using the given root certificates. It // will fail if a previous option to create the transport has been configured. 
func WithCABundle(bundle []byte) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { return err } o.rootBundle = bundle return nil } } // WithCertificate will set the given certificate as the TLS client certificate // in the client. func WithCertificate(cert tls.Certificate) ClientOption { return func(o *clientOptions) error { o.certificate = cert return nil } } // WithAdminX5C will set the given file as the X5C certificate for use // by the client. func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile string) ClientOption { return func(o *clientOptions) error { // Get private key from given key file var ( err error opts []jose.Option ) if passwordFile != "" { opts = append(opts, jose.WithPasswordFile(passwordFile)) } blk, err := pemutil.Serialize(key) if err != nil { return errors.Wrap(err, "error serializing private key") } o.x5cJWK, err = jose.ParseKey(pem.EncodeToMemory(blk), opts...) if err != nil { return err } o.x5cCertStrs, err = jose.ValidateX5C(certs, o.x5cJWK.Key) if err != nil { return errors.Wrap(err, "error validating x5c certificate chain and key for use in x5c header") } o.x5cCert = certs[0] switch leaf := certs[0]; { case leaf.Subject.CommonName != "": o.x5cSubject = leaf.Subject.CommonName case len(leaf.DNSNames) > 0: o.x5cSubject = leaf.DNSNames[0] case len(leaf.EmailAddresses) > 0: o.x5cSubject = leaf.EmailAddresses[0] } return nil } } // WithRetryFunc defines a method used to retry a request. func WithRetryFunc(fn RetryFunc) ClientOption { return func(o *clientOptions) error { o.retryFunc = fn return nil } } // WithTimeout defines the time limit for requests made by this client. The // timeout includes connection time, any redirects, and reading the response // body. 
func WithTimeout(d time.Duration) ClientOption { return func(o *clientOptions) error { o.timeout = d return nil } } func getTransportFromFile(filename string) (http.RoundTripper, error) { data, err := os.ReadFile(filename) // #nosec G703 -- filename is based on configuration; data read from file is processed with expected format if err != nil { return nil, errors.Wrapf(err, "error reading %s", filename) } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(data) { return nil, errors.Errorf("error parsing %s: no certificates found", filename) } return getDefaultTransport(&tls.Config{ MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, }), nil } func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { u, err := parseEndpoint(endpoint) if err != nil { return nil, err } caClient := &Client{endpoint: u} root, err := caClient.Root(sum) if err != nil { return nil, err } pool := x509.NewCertPool() pool.AddCert(root.RootPEM.Certificate) return getDefaultTransport(&tls.Config{ MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, }), nil } func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(bundle) { return nil, errors.New("error parsing ca bundle: no certificates found") } return getDefaultTransport(&tls.Config{ MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, }), nil } // parseEndpoint parses and validates the given endpoint. It supports general // URLs like https://ca.smallstep.com[:port][/path], and incomplete URLs like // ca.smallstep.com[:port][/path]. 
func parseEndpoint(endpoint string) (*url.URL, error) { u, err := url.Parse(endpoint) if err != nil { return nil, errors.Wrapf(err, "error parsing endpoint '%s'", endpoint) } // URLs are generally parsed as: // [scheme:][//[userinfo@]host][/]path[?query][#fragment] // But URLs that do not start with a slash after the scheme are interpreted as // scheme:opaque[?query][#fragment] if u.Opaque == "" { if u.Scheme == "" { u.Scheme = "https" } if u.Host == "" { // endpoint looks like ca.smallstep.com or ca.smallstep.com/1.0/sign if u.Path != "" { parts := strings.SplitN(u.Path, "/", 2) u.Host = parts[0] if len(parts) == 2 { u.Path = parts[1] } else { u.Path = "" } return parseEndpoint(u.String()) } return nil, errors.Errorf("error parsing endpoint: url '%s' is not valid", endpoint) } return u, nil } // scheme:opaque[?query][#fragment] // endpoint looks like ca.smallstep.com:443 or ca.smallstep.com:443/1.0/sign return parseEndpoint("https://" + endpoint) } // ProvisionerOption is the type of options passed to the Provisioner method. type ProvisionerOption func(o *ProvisionerOptions) error // ProvisionerOptions stores options for the provisioner CRUD API. type ProvisionerOptions struct { Cursor string Limit int ID string Name string } // Apply caches provisioner options on a struct for later use. func (o *ProvisionerOptions) Apply(opts []ProvisionerOption) (err error) { for _, fn := range opts { if err = fn(o); err != nil { return } } return } func (o *ProvisionerOptions) rawQuery() string { v := url.Values{} if o.Cursor != "" { v.Set("cursor", o.Cursor) } if o.Limit > 0 { v.Set("limit", strconv.Itoa(o.Limit)) } if o.ID != "" { v.Set("id", o.ID) } if o.Name != "" { v.Set("name", o.Name) } return v.Encode() } // WithProvisionerCursor will request the provisioners starting with the given cursor. 
func WithProvisionerCursor(cursor string) ProvisionerOption { return func(o *ProvisionerOptions) error { o.Cursor = cursor return nil } } // WithProvisionerLimit will request the given number of provisioners. func WithProvisionerLimit(limit int) ProvisionerOption { return func(o *ProvisionerOptions) error { o.Limit = limit return nil } } // WithProvisionerID will request the given provisioner. func WithProvisionerID(id string) ProvisionerOption { return func(o *ProvisionerOptions) error { o.ID = id return nil } } // WithProvisionerName will request the given provisioner. func WithProvisionerName(name string) ProvisionerOption { return func(o *ProvisionerOptions) error { o.Name = name return nil } } // Client implements an HTTP client for the CA server. type Client struct { client *uaClient endpoint *url.URL retryFunc RetryFunc timeout time.Duration opts []ClientOption transportDecorator TransportDecorator } // NewClient creates a new Client with the given endpoint and options. func NewClient(endpoint string, opts ...ClientOption) (*Client, error) { u, err := parseEndpoint(endpoint) if err != nil { return nil, err } // Retrieve transport from options. o := defaultClientOptions() if err := o.apply(opts); err != nil { return nil, err } tr, err := o.getTransport(endpoint) if err != nil { return nil, err } return &Client{ client: newClient(tr, o.timeout), endpoint: u, retryFunc: o.retryFunc, timeout: o.timeout, opts: opts, transportDecorator: o.transportDecorator, }, nil } func (c *Client) retryOnError(r *http.Response) bool { if c.retryFunc != nil { if c.retryFunc(r.StatusCode) { o := defaultClientOptions() if err := o.apply(c.opts); err != nil { return false } tr, err := o.getTransport(c.endpoint.String()) if err != nil { return false } r.Body.Close() c.client.SetTransport(tr) return true } } return false } // GetCaURL returns the configured CA url. 
func (c *Client) GetCaURL() string {
	caURL := c.endpoint.String()
	return caURL
}

// GetRootCAs returns the RootCAs certificate pool from the configured
// transport. It returns nil when the transport is neither an *http.Transport
// nor an *http2.Transport, or when the transport has no TLS client config.
func (c *Client) GetRootCAs() *x509.CertPool {
	var tlsConfig *tls.Config
	switch transport := c.client.GetTransport().(type) {
	case *http.Transport:
		tlsConfig = transport.TLSClientConfig
	case *http2.Transport:
		tlsConfig = transport.TLSClientConfig
	}
	if tlsConfig == nil {
		return nil
	}
	return tlsConfig.RootCAs
}

// SetTransport replaces the transport of the internal HTTP client with tr.
func (c *Client) SetTransport(tr http.RoundTripper) {
	c.client.SetTransport(tr)
}

// CloseIdleConnections closes any connections on its Transport which were
// previously connected from previous requests but are now sitting idle in a
// "keep-alive" state. It does not interrupt any connections currently in use.
func (c *Client) CloseIdleConnections() {
	c.client.CloseIdleConnections()
}

// Version performs the version request to the CA using context.Background and
// returns the api.VersionResponse struct.
func (c *Client) Version() (*api.VersionResponse, error) {
	ctx := context.Background()
	return c.VersionWithContext(ctx)
}

// VersionWithContext performs the version request to the CA with the provided context
// and returns the api.VersionResponse struct.
func (c *Client) VersionWithContext(ctx context.Context) (*api.VersionResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/version"}) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var version api.VersionResponse if err := readJSON(resp.Body, &version); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Version; error reading %s", u) } return &version, nil } // Health performs the health request to the CA with an empty context // and returns the api.HealthResponse struct. func (c *Client) Health() (*api.HealthResponse, error) { return c.HealthWithContext(context.Background()) } // HealthWithContext performs the health request to the CA with the provided context // and returns the api.HealthResponse struct. func (c *Client) HealthWithContext(ctx context.Context) (*api.HealthResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var health api.HealthResponse if err := readJSON(resp.Body, &health); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Health; error reading %s", u) } return &health, nil } // Root performs the root request to the CA with an empty context and the provided // SHA256 and returns the api.RootResponse struct. It uses an insecure client, but // it checks the resulting root certificate with the given SHA256, returning an error // if they do not match. 
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
	ctx := context.Background()
	return c.RootWithContext(ctx, sha256Sum)
}

// RootWithContext performs the root request to the CA with the provided
// context and SHA256 and returns the api.RootResponse struct. The sum is
// normalized by removing dashes and lowercasing it. The request uses an
// insecure client, but the returned root certificate is checked against the
// given SHA256, and an error is returned if they do not match.
func (c *Client) RootWithContext(ctx context.Context, sha256Sum string) (*api.RootResponse, error) {
	sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
	target := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})

	retried := false
	for {
		resp, err := newInsecureClient().GetWithContext(ctx, target.String())
		if err != nil {
			return nil, clientError(err)
		}
		if resp.StatusCode < 400 {
			var root api.RootResponse
			if err := readJSON(resp.Body, &root); err != nil {
				return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Root; error reading %s", target)
			}

			// The connection was not verified, so the certificate is only
			// trusted if its SHA256 matches the expected fingerprint.
			digest := sha256.Sum256(root.RootPEM.Raw)
			fingerprint := strings.ToLower(hex.EncodeToString(digest[:]))
			if !strings.EqualFold(sha256Sum, fingerprint) {
				return nil, errs.BadRequest("root certificate fingerprint does not match")
			}
			return &root, nil
		}
		// An error status is retried at most once, and only if retryOnError agrees.
		if retried || !c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
			return nil, readError(resp)
		}
		retried = true
	}
}

// Sign performs the sign request to the CA using context.Background and
// returns the api.SignResponse struct.
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
	ctx := context.Background()
	return c.SignWithContext(ctx, req)
}

// SignWithContext performs the sign request to the CA with the provided context
// and returns the api.SignResponse struct.
func (c *Client) SignWithContext(ctx context.Context, req *api.SignRequest) (*api.SignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "client.Sign; error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Sign; error reading %s", u) } // Add tls.ConnectionState: // We'll extract the root certificate from the verified chains sign.TLS = resp.TLS return &sign, nil } // Renew performs the renew request to the CA with an empty context and // returns the api.SignResponse struct. func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { return c.RenewWithContext(context.Background(), tr) } // RenewWithContext performs the renew request to the CA with the provided context // and returns the api.SignResponse struct. 
func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) httpClient := &http.Client{Transport: tr} retry: req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") resp, err := httpClient.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Renew; error reading %s", u) } return &sign, nil } // RenewWithToken performs the renew request to the CA with the given // authorization token and and empty context and returns the api.SignResponse struct. // This method is generally used to renew an expired certificate. func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) { return c.RenewWithTokenAndContext(context.Background(), token) } // RenewWithTokenAndContext performs the renew request to the CA with the given // authorization token and context and returns the api.SignResponse struct. // This method is generally used to renew an expired certificate. 
func (c *Client) RenewWithTokenAndContext(ctx context.Context, token string) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create POST %s request failed", u) } req.Header.Add("Authorization", "Bearer "+token) retry: resp, err := c.client.Do(req) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u) } return &sign, nil } // Rekey performs the rekey request to the CA with an empty context and // returns the api.SignResponse struct. func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { return c.RekeyWithContext(context.Background(), req, tr) } // RekeyWithContext performs the rekey request to the CA with the provided context // and returns the api.SignResponse struct. 
func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) httpClient := &http.Client{Transport: tr} retry: httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") resp, err := httpClient.Do(httpReq) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; error reading %s", u) } return &sign, nil } // Revoke performs the revoke request to the CA with an empty context and returns // the api.RevokeResponse struct. func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { return c.RevokeWithContext(context.Background(), req, tr) } // RevokeWithContext performs the revoke request to the CA with the provided context and // returns the api.RevokeResponse struct. 
func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } var uaClient *uaClient retry: if tr != nil { uaClient = newClient(tr, c.timeout) } else { uaClient = c.client } u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"}) resp, err := uaClient.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var revoke api.RevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &revoke, nil } // Provisioners performs the provisioners request to the CA with an empty context // and returns the api.ProvisionersResponse struct with a map of provisioners. // // ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to // paginate the provisioners. func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { return c.ProvisionersWithContext(context.Background(), opts...) } // ProvisionersWithContext performs the provisioners request to the CA with the provided context // and returns the api.ProvisionersResponse struct with a map of provisioners. // // ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to // paginate the provisioners. 
func (c *Client) ProvisionersWithContext(ctx context.Context, opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { var retried bool o := new(ProvisionerOptions) if err := o.Apply(opts); err != nil { return nil, err } u := c.endpoint.ResolveReference(&url.URL{ Path: "/provisioners", RawQuery: o.rawQuery(), }) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var provisioners api.ProvisionersResponse if err := readJSON(resp.Body, &provisioners); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &provisioners, nil } // ProvisionerKey performs the request to the CA with an empty context to get // the encrypted key for the given provisioner kid and returns the api.ProvisionerKeyResponse // struct with the encrypted key. func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) { return c.ProvisionerKeyWithContext(context.Background(), kid) } // ProvisionerKeyWithContext performs the request to the CA with the provided context to get // the encrypted key for the given provisioner kid and returns the api.ProvisionerKeyResponse // struct with the encrypted key. 
func (c *Client) ProvisionerKeyWithContext(ctx context.Context, kid string) (*api.ProvisionerKeyResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"}) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var key api.ProvisionerKeyResponse if err := readJSON(resp.Body, &key); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &key, nil } // Roots performs the get roots request to the CA with an empty context // and returns the api.RootsResponse struct. func (c *Client) Roots() (*api.RootsResponse, error) { return c.RootsWithContext(context.Background()) } // RootsWithContext performs the get roots request to the CA with the provided context // and returns the api.RootsResponse struct. func (c *Client) RootsWithContext(ctx context.Context) (*api.RootsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var roots api.RootsResponse if err := readJSON(resp.Body, &roots); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &roots, nil } // Federation performs the get federation request to the CA with an empty context // and returns the api.FederationResponse struct. 
func (c *Client) Federation() (*api.FederationResponse, error) {
	ctx := context.Background()
	return c.FederationWithContext(ctx)
}

// FederationWithContext performs the get federation request to the CA with
// the provided context and returns the api.FederationResponse struct.
func (c *Client) FederationWithContext(ctx context.Context) (*api.FederationResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})

	retried := false
	for {
		resp, err := c.client.GetWithContext(ctx, target.String())
		if err != nil {
			return nil, clientError(err)
		}
		if resp.StatusCode < 400 {
			var federation api.FederationResponse
			if err := readJSON(resp.Body, &federation); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return &federation, nil
		}
		// An error status is retried at most once, and only if retryOnError agrees.
		if retried || !c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
			return nil, readError(resp)
		}
		retried = true
	}
}

// SSHSign performs the POST /ssh/sign request to the CA using
// context.Background and returns the api.SSHSignResponse struct.
func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
	ctx := context.Background()
	return c.SSHSignWithContext(ctx, req)
}

// SSHSignWithContext performs the POST /ssh/sign request to the CA with the provided context
// and returns the api.SSHSignResponse struct.
func (c *Client) SSHSignWithContext(ctx context.Context, req *api.SSHSignRequest) (*api.SSHSignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var sign api.SSHSignResponse if err := readJSON(resp.Body, &sign); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &sign, nil } // SSHRenew performs the POST /ssh/renew request to the CA with an empty context // and returns the api.SSHRenewResponse struct. func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) { return c.SSHRenewWithContext(context.Background(), req) } // SSHRenewWithContext performs the POST /ssh/renew request to the CA with the provided context // and returns the api.SSHRenewResponse struct. 
func (c *Client) SSHRenewWithContext(ctx context.Context, req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var renew api.SSHRenewResponse if err := readJSON(resp.Body, &renew); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &renew, nil } // SSHRekey performs the POST /ssh/rekey request to the CA with an empty context // and returns the api.SSHRekeyResponse struct. func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) { return c.SSHRekeyWithContext(context.Background(), req) } // SSHRekeyWithContext performs the POST /ssh/rekey request to the CA with the provided context // and returns the api.SSHRekeyResponse struct. 
func (c *Client) SSHRekeyWithContext(ctx context.Context, req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var rekey api.SSHRekeyResponse if err := readJSON(resp.Body, &rekey); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &rekey, nil } // SSHRevoke performs the POST /ssh/revoke request to the CA with an empty context // and returns the api.SSHRevokeResponse struct. func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) { return c.SSHRevokeWithContext(context.Background(), req) } // SSHRevokeWithContext performs the POST /ssh/revoke request to the CA with the provided context // and returns the api.SSHRevokeResponse struct. 
func (c *Client) SSHRevokeWithContext(ctx context.Context, req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var revoke api.SSHRevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &revoke, nil } // SSHRoots performs the GET /ssh/roots request to the CA with an empty context // and returns the api.SSHRootsResponse struct. func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) { return c.SSHRootsWithContext(context.Background()) } // SSHRootsWithContext performs the GET /ssh/roots request to the CA with the provided context // and returns the api.SSHRootsResponse struct. func (c *Client) SSHRootsWithContext(ctx context.Context) (*api.SSHRootsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"}) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &keys, nil } // SSHFederation performs the get /ssh/federation request to the CA with an empty context // and returns the api.SSHRootsResponse struct. 
func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
	ctx := context.Background()
	return c.SSHFederationWithContext(ctx)
}

// SSHFederationWithContext performs the GET /ssh/federation request to the CA
// with the provided context and returns the api.SSHRootsResponse struct.
func (c *Client) SSHFederationWithContext(ctx context.Context) (*api.SSHRootsResponse, error) {
	target := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"})

	retried := false
	for {
		resp, err := c.client.GetWithContext(ctx, target.String())
		if err != nil {
			return nil, clientError(err)
		}
		if resp.StatusCode < 400 {
			var keys api.SSHRootsResponse
			if err := readJSON(resp.Body, &keys); err != nil {
				return nil, errors.Wrapf(err, "error reading %s", target)
			}
			return &keys, nil
		}
		// An error status is retried at most once, and only if retryOnError agrees.
		if retried || !c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
			return nil, readError(resp)
		}
		retried = true
	}
}

// SSHConfig performs the POST /ssh/config request to the CA, using
// context.Background, to get the ssh configuration templates.
func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
	ctx := context.Background()
	return c.SSHConfigWithContext(ctx, req)
}

// SSHConfigWithContext performs the POST /ssh/config request to the CA with the provided context
// to get the ssh configuration templates.
func (c *Client) SSHConfigWithContext(ctx context.Context, req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var cfg api.SSHConfigResponse if err := readJSON(resp.Body, &cfg); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &cfg, nil } // SSHCheckHost performs the POST /ssh/check-host request to the CA with an empty context, // the principal and a token and returns the api.SSHCheckPrincipalResponse. func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) { return c.SSHCheckHostWithContext(context.Background(), principal, token) } // SSHCheckHostWithContext performs the POST /ssh/check-host request to the CA with the provided context, // principal and token and returns the api.SSHCheckPrincipalResponse. 
func (c *Client) SSHCheckHostWithContext(ctx context.Context, principal, token string) (*api.SSHCheckPrincipalResponse, error) { var retried bool body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ Type: provisioner.SSHHostCert, Principal: principal, Token: token, }) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request", errs.WithMessage("Failed to marshal the check-host request")) } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", []any{u, errs.WithMessage("Failed to parse response from /ssh/check-host endpoint")}...) } return &check, nil } // SSHGetHosts performs the GET /ssh/get-hosts request to the CA with an empty context. func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) { return c.SSHGetHostsWithContext(context.Background()) } // SSHGetHostsWithContext performs the GET /ssh/get-hosts request to the CA with the provided context. 
func (c *Client) SSHGetHostsWithContext(ctx context.Context) (*api.SSHGetHostsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/hosts"}) retry: resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var hosts api.SSHGetHostsResponse if err := readJSON(resp.Body, &hosts); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &hosts, nil } // SSHBastion performs the POST /ssh/bastion request to the CA with an empty context. func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { return c.SSHBastionWithContext(context.Background(), req) } // SSHBastionWithContext performs the POST /ssh/bastion request to the CA with the provided context. func (c *Client) SSHBastionWithContext(ctx context.Context, req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "client.SSHBastion; error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) retry: resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } if resp.StatusCode >= 400 { if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } return nil, readError(resp) } var bastion api.SSHBastionResponse if err := readJSON(resp.Body, &bastion); err != nil { return nil, errors.Wrapf(err, "client.SSHBastion; error reading %s", u) } return &bastion, nil } // RootFingerprint is a helper method that returns the current root fingerprint. 
// It does an health connection and gets the fingerprint from the TLS verified chains. func (c *Client) RootFingerprint() (string, error) { return c.RootFingerprintWithContext(context.Background()) } // RootFingerprintWithContext is a helper method that returns the current root fingerprint. // It does an health connection and gets the fingerprint from the TLS verified chains. func (c *Client) RootFingerprintWithContext(ctx context.Context) (string, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return "", clientError(err) } defer resp.Body.Close() if resp.TLS == nil || len(resp.TLS.VerifiedChains) == 0 { return "", errors.New("missing verified chains") } lastChain := resp.TLS.VerifiedChains[len(resp.TLS.VerifiedChains)-1] if len(lastChain) == 0 { return "", errors.New("missing verified chains") } return x509util.Fingerprint(lastChain[len(lastChain)-1]), nil } // CreateSignRequest is a helper function that given an x509 OTT returns a // simple but secure sign request as well as the private key used. 
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { token, err := jose.ParseSigned(ott) if err != nil { return nil, nil, errors.Wrap(err, "error parsing ott") } var claims authority.Claims if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, nil, errors.Wrap(err, "error parsing ott") } pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, nil, errors.Wrap(err, "error generating key") } dnsNames, ips, emails, uris := x509util.SplitSANs(claims.SANs) if claims.Email != "" { emails = append(emails, claims.Email) } template := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: claims.Subject, }, SignatureAlgorithm: x509.ECDSAWithSHA256, DNSNames: dnsNames, IPAddresses: ips, EmailAddresses: emails, URIs: uris, } csr, err := x509.CreateCertificateRequest(rand.Reader, template, pk) if err != nil { return nil, nil, errors.Wrap(err, "error creating certificate request") } cr, err := x509.ParseCertificateRequest(csr) if err != nil { return nil, nil, errors.Wrap(err, "error parsing certificate request") } if err := cr.CheckSignature(); err != nil { return nil, nil, errors.Wrap(err, "error signing certificate request") } return &api.SignRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: cr}, OTT: ott, }, pk, nil } // CreateCertificateRequest creates a new CSR with the given common name and // SANs. If no san is provided the commonName will set also a SAN. func CreateCertificateRequest(commonName string, sans ...string) (*api.CertificateRequest, crypto.PrivateKey, error) { key, err := keyutil.GenerateDefaultKey() if err != nil { return nil, nil, err } return createCertificateRequest(commonName, sans, key) } // CreateIdentityRequest returns a new CSR to create the identity. If an // identity was already present it reuses the private key. 
func CreateIdentityRequest(commonName string, sans ...string) (*api.CertificateRequest, crypto.PrivateKey, error) { var identityKey crypto.PrivateKey if i, err := identity.LoadDefaultIdentity(); err == nil && i.Key != "" { if k, err := pemutil.Read(i.Key); err == nil { identityKey = k } } if identityKey == nil { return CreateCertificateRequest(commonName, sans...) } return createCertificateRequest(commonName, sans, identityKey) } // LoadDefaultIdentity is a wrapper for identity.LoadDefaultIdentity. func LoadDefaultIdentity() (*identity.Identity, error) { return identity.LoadDefaultIdentity() } // WriteDefaultIdentity is a wrapper for identity.WriteDefaultIdentity. func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { return identity.WriteDefaultIdentity(certChain, key) } func createCertificateRequest(commonName string, sans []string, key crypto.PrivateKey) (*api.CertificateRequest, crypto.PrivateKey, error) { if len(sans) == 0 { sans = []string{commonName} } dnsNames, ips, emails, uris := x509util.SplitSANs(sans) template := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: commonName, }, DNSNames: dnsNames, IPAddresses: ips, EmailAddresses: emails, URIs: uris, } csr, err := x509.CreateCertificateRequest(rand.Reader, template, key) if err != nil { return nil, nil, err } cr, err := x509.ParseCertificateRequest(csr) if err != nil { return nil, nil, err } if err := cr.CheckSignature(); err != nil { return nil, nil, err } return &api.CertificateRequest{CertificateRequest: cr}, key, nil } // getRootCAPath returns the path where the root CA is stored based on the // STEPPATH environment variable. 
func getRootCAPath() string { return filepath.Join(step.Path(), "certs", "root_ca.crt") } func readJSON(r io.ReadCloser, v interface{}) error { defer r.Close() return json.NewDecoder(r).Decode(v) } func readProtoJSON(r io.ReadCloser, m proto.Message) error { defer r.Close() data, err := io.ReadAll(r) if err != nil { return err } return protojson.Unmarshal(data, m) } func readError(r *http.Response) error { defer r.Body.Close() apiErr := new(errs.Error) if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { return fmt.Errorf("failed decoding CA error response: %w", err) } apiErr.RequestID = r.Header.Get("X-Request-Id") return apiErr } func clientError(err error) error { var uerr *url.Error if errors.As(err, &uerr) { return fmt.Errorf("client %s %s failed: %w", strings.ToUpper(uerr.Op), uerr.URL, uerr.Err) } return fmt.Errorf("client request failed: %w", err) } func decorateRoundTripper(tr http.RoundTripper, td TransportDecorator) http.RoundTripper { if td != nil { return td(tr) } return tr } ================================================ FILE: ca/client_test.go ================================================ package ca import ( "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "encoding/json" "encoding/pem" "errors" "net/http" "net/http/httptest" "net/url" "reflect" "strings" "testing" "time" "github.com/google/uuid" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) const ( rootPEM = `-----BEGIN CERTIFICATE----- MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i 
YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY /iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx -----END CERTIFICATE-----` certPEM = `-----BEGIN CERTIFICATE----- MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q 5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC 7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy 
LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+ gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283 TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq 0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` csrPEM = `-----BEGIN CERTIFICATE REQUEST----- MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu 7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2 HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H /qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0 hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x /Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs 0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+ g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8 
zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4 2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU 58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` ) func mustKey(t *testing.T) *ecdsa.PrivateKey { t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) return priv } func parseCertificate(t *testing.T, data string) *x509.Certificate { t.Helper() block, _ := pem.Decode([]byte(data)) if block == nil { require.Fail(t, "failed to parse certificate PEM") return nil } cert, err := x509.ParseCertificate(block.Bytes) require.NoError(t, err, "failed to parse certificate") return cert } func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest { t.Helper() block, _ := pem.Decode([]byte(csrPEM)) if block == nil { require.Fail(t, "failed to parse certificate request PEM") return nil } csr, err := x509.ParseCertificateRequest(block.Bytes) require.NoError(t, err, "failed to parse certificate request") return csr } func equalJSON(t *testing.T, a, b interface{}) bool { t.Helper() if reflect.DeepEqual(a, b) { return true } ab, err := json.Marshal(a) require.NoError(t, err) bb, err := json.Marshal(b) require.NoError(t, err) return bytes.Equal(ab, bb) } func TestClient_Version(t *testing.T) { ok := &api.VersionResponse{Version: "test"} tests := []struct { name string response interface{} responseCode int wantErr bool expectedErr error }{ {"ok", ok, 200, false, nil}, {"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, {"404", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) 
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Version() if tt.wantErr { if assert.Error(t, err) { assert.EqualError(t, err, tt.expectedErr.Error()) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Health(t *testing.T) { ok := &api.HealthResponse{Status: "ok"} tests := []struct { name string response interface{} responseCode int wantErr bool expectedErr error }{ {"ok", ok, 200, false, nil}, {"not ok", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Health() if tt.wantErr { if assert.Error(t, err) { assert.EqualError(t, err, tt.expectedErr.Error()) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Root(t *testing.T) { ok := &api.RootResponse{ RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, } tests := []struct { name string shasum string response interface{} responseCode int wantErr bool expectedErr error }{ {"ok", "a047a37fa2d2e118a4f5095fe074d6cfe0e352425a7632bf8659c03919a6c81d", ok, 200, false, nil}, {"not found", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expected := "/root/" + tt.shasum if 
r.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected) } render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) if tt.wantErr { if assert.Error(t, err) { assert.EqualError(t, err, tt.expectedErr.Error()) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Sign(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: parseCertificate(t, certPEM)}, {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.SignRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, OTT: "the-ott", NotBefore: api.NewTimeDuration(time.Now()), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), } tests := []struct { name string request *api.SignRequest response interface{} responseCode int wantErr bool expectedErr error }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body := new(api.SignRequest) if err := read.JSON(r.Body, body); err != nil { e, ok := tt.response.(error) require.True(t, ok, "response expected to be error type") render.Error(w, r, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { if 
!reflect.DeepEqual(body, &api.SignRequest{}) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } else { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) if tt.wantErr { if assert.Error(t, err) { assert.EqualError(t, err, tt.expectedErr.Error()) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Revoke(t *testing.T) { ok := &api.RevokeResponse{Status: "ok"} request := &api.RevokeRequest{ Serial: "sn", OTT: "the-ott", ReasonCode: 4, } tests := []struct { name string request *api.RevokeRequest response interface{} responseCode int wantErr bool expectedErr error }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body := new(api.RevokeRequest) if err := read.JSON(r.Body, body); err != nil { e, ok := tt.response.(error) require.True(t, ok, "response expected to be error type") render.Error(w, r, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { if !reflect.DeepEqual(body, &api.RevokeRequest{}) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } else { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) if tt.wantErr { if assert.Error(t, err) { assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error())) } assert.Nil(t, got) return } 
assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Renew(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: parseCertificate(t, certPEM)}, {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { name string response interface{} responseCode int wantErr bool err error }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Renew(nil) if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_RenewWithToken(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: parseCertificate(t, certPEM)}, {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { name string response interface{} responseCode int wantErr bool err error }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, 
true, errors.New(errs.UnauthorizedDefaultMsg)}, {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer token" { render.JSONStatus(w, r, errs.InternalServer("force"), 500) } else { render.JSONStatus(w, r, tt.response, tt.responseCode) } }) got, err := c.RenewWithToken("token") if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: parseCertificate(t, certPEM)}, {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.RekeyRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, } tests := []struct { name string request *api.RekeyRequest response interface{} responseCode int wantErr bool err error }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := 
httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Provisioners(t *testing.T) { ok := &api.ProvisionersResponse{ Provisioners: provisioner.List{}, } internalServerError := errs.InternalServer("Internal Server Error") tests := []struct { name string args []ProvisionerOption expectedURI string response interface{} responseCode int wantErr bool }{ {"ok", nil, "/provisioners", ok, 200, false}, {"ok with cursor", []ProvisionerOption{WithProvisionerCursor("abc")}, "/provisioners?cursor=abc", ok, 200, false}, {"ok with limit", []ProvisionerOption{WithProvisionerLimit(10)}, "/provisioners?limit=10", ok, 200, false}, {"ok with cursor+limit", []ProvisionerOption{WithProvisionerCursor("abc"), WithProvisionerLimit(10)}, "/provisioners?cursor=abc&limit=10", ok, 200, false}, {"fail", nil, "/provisioners", internalServerError, 500, true}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI != tt.expectedURI { t.Errorf("RequestURI = %s, want %s", r.RequestURI, tt.expectedURI) } render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) 
if tt.wantErr { if assert.Error(t, err) { assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg)) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_ProvisionerKey(t *testing.T) { ok := &api.ProvisionerKeyResponse{ Key: "an encrypted key", } tests := []struct { name string kid string response interface{} responseCode int wantErr bool err error }{ {"ok", "kid", ok, 200, false, nil}, {"fail", "invalid", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expected := "/provisioners/" + tt.kid + "/encrypted-key" if r.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected) } render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Roots(t *testing.T) { ok := &api.RootsResponse{ Certificates: []api.Certificate{ {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { name string response interface{} responseCode int wantErr bool err error }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := 
NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Roots() if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_Federation(t *testing.T) { ok := &api.FederationResponse{ Certificates: []api.Certificate{ {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { name string response interface{} responseCode int wantErr bool err error }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.Federation() if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_SSHRoots(t *testing.T) { key, err := ssh.NewPublicKey(mustKey(t).Public()) require.NoError(t, err) ok := &api.SSHRootsResponse{ HostKeys: []api.SSHPublicKey{{PublicKey: key}}, UserKeys: []api.SSHPublicKey{{PublicKey: key}}, } tests := []struct { name string response interface{} responseCode int wantErr bool err error }{ {"ok", ok, 200, false, nil}, 
{"not found", errs.NotFound("force"), 404, true, errors.New(errs.NotFoundDefaultMsg)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.SSHRoots() if tt.wantErr { if assert.Error(t, err) { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func Test_parseEndpoint(t *testing.T) { expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"} type args struct { endpoint string } tests := []struct { name string args args want *url.URL wantErr bool }{ {"ok", args{"https://ca.smallstep.com"}, expected1, false}, {"ok no scheme", args{"//ca.smallstep.com"}, expected1, false}, {"ok only host", args{"ca.smallstep.com"}, expected1, false}, {"ok no bars", args{"https://ca.smallstep.com"}, expected1, false}, {"ok schema, host and path", args{"https://ca.smallstep.com/1.0/sign"}, expected2, false}, {"ok no bars with path", args{"https://ca.smallstep.com/1.0/sign"}, expected2, false}, {"ok host and path", args{"ca.smallstep.com/1.0/sign"}, expected2, false}, {"ok host and port", args{"ca.smallstep.com:443"}, &url.URL{Scheme: "https", Host: "ca.smallstep.com:443"}, false}, {"ok host, path and port", args{"ca.smallstep.com:443/1.0/sign"}, &url.URL{Scheme: "https", Host: "ca.smallstep.com:443", Path: "/1.0/sign"}, false}, {"fail bad url", args{"://ca.smallstep.com"}, nil, true}, {"fail no host", args{"https://"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, 
func(t *testing.T) { got, err := parseEndpoint(tt.args.endpoint) if tt.wantErr { assert.Error(t, err) assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func TestClient_RootFingerprint(t *testing.T) { ok := &api.HealthResponse{Status: "ok"} nok := errs.InternalServer("Internal Server Error") httpsServer := httptest.NewTLSServer(nil) defer httpsServer.Close() httpsServerFingerprint := x509util.Fingerprint(httpsServer.Certificate()) httpServer := httptest.NewServer(nil) defer httpServer.Close() tests := []struct { name string server *httptest.Server response interface{} responseCode int want string wantErr bool }{ {"ok", httpsServer, ok, 200, httpsServerFingerprint, false}, {"ok with error", httpsServer, nok, 500, httpsServerFingerprint, false}, {"fail", httpServer, ok, 200, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tr := tt.server.Client().Transport c, err := NewClient(tt.server.URL, WithTransport(tr)) require.NoError(t, err) tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func TestClient_RootFingerprintWithServer(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) require.NoError(t, err) fp, err := caClient.RootFingerprint() assert.NoError(t, err) assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } func TestClient_SSHBastion(t *testing.T) { ok := &api.SSHBastionResponse{ Hostname: "host.local", Bastion: &authority.Bastion{ Hostname: "bastion.local", }, } tests := []struct { name string request *api.SSHBastionRequest response interface{} responseCode int wantErr bool err error }{ {"ok", 
&api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) defer srv.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { render.JSONStatus(w, r, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) if tt.wantErr { if assert.Error(t, err) { if tt.responseCode != 200 { var sc render.StatusCodedError if assert.ErrorAs(t, err, &sc) { assert.Equal(t, tt.responseCode, sc.StatusCode()) } assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } } assert.Nil(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.response, got) }) } } func TestClient_GetCaURL(t *testing.T) { tests := []struct { name string caURL string want string }{ {"ok", "https://ca.com", "https://ca.com"}, {"ok no schema", "ca.com", "https://ca.com"}, {"ok with port", "https://ca.com:9000", "https://ca.com:9000"}, {"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) require.NoError(t, err) got := c.GetCaURL() assert.Equal(t, tt.want, got) }) } } func TestClient_WithTimeout(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(200 * time.Millisecond) render.JSONStatus(w, r, api.HealthResponse{Status: "ok"}, 200) })) defer srv.Close() tests := []struct { name string options []ClientOption assertion assert.ErrorAssertionFunc }{ {"ok", []ClientOption{WithTransport(http.DefaultTransport)}, assert.NoError}, {"ok with timeout", 
[]ClientOption{WithTransport(http.DefaultTransport), WithTimeout(5 * time.Second)}, assert.NoError}, {"fail with timeout", []ClientOption{WithTransport(http.DefaultTransport), WithTimeout(10 * time.Millisecond)}, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, tt.options...) require.NoError(t, err) assert.NotZero(t, c.timeout) _, err = c.Health() tt.assertion(t, err) }) } } type decoratedRoundTripper func(*http.Request) (*http.Response, error) func (rt decoratedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return rt(req) } func TestClient_WithTransportDecorator(t *testing.T) { var srv *httptest.Server srv = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.RequestURI, "/root") { render.JSONStatus(w, r, api.RootResponse{ RootPEM: api.NewCertificate(srv.Certificate()), }, 200) return } if s := r.Header.Get("X-Test-Header"); s != "" { render.JSONStatus(w, r, api.HealthResponse{Status: s}, 200) } else { render.JSONStatus(w, r, api.HealthResponse{Status: "ok"}, 200) } })) defer srv.Close() fp := x509util.Fingerprint(srv.Certificate()) c, err := NewClient(srv.URL, WithRootSHA256(fp), WithTransportDecorator(func(tr http.RoundTripper) http.RoundTripper { return decoratedRoundTripper(func(r *http.Request) (*http.Response, error) { r.Header.Add("X-Test-Header", "some-data") return tr.RoundTrip(r) }) })) require.NoError(t, err) resp, err := c.Health() require.NoError(t, err) assert.Equal(t, "some-data", resp.Status) } func Test_enforceRequestID(t *testing.T) { set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) set.Header.Set("X-Request-Id", "already-set") inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) newRequestID := httptest.NewRequest(http.MethodGet, 
"https://example.com", http.NoBody) tests := []struct { name string r *http.Request want string }{ { name: "set", r: set, want: "already-set", }, { name: "context", r: inContext, want: "from-context", }, { name: "new", r: newRequestID, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { enforceRequestID(tt.r) v := tt.r.Header.Get("X-Request-Id") if assert.NotEmpty(t, v) { if tt.want != "" { assert.Equal(t, tt.want, v) } } }) } } func Test_newRequestID(t *testing.T) { requestID := newRequestID() u, err := uuid.Parse(requestID) assert.NoError(t, err) assert.Equal(t, uuid.Version(0x4), u.Version()) assert.Equal(t, uuid.RFC4122, u.Variant()) assert.Equal(t, requestID, u.String()) } ================================================ FILE: ca/identity/client.go ================================================ package identity import ( "crypto/tls" "crypto/x509" "encoding/json" "fmt" "net/http" "net/url" "os" "github.com/pkg/errors" "github.com/smallstep/certificates/internal/httptransport" ) // Client wraps http.Client with a transport using the step root and identity. type Client struct { CaURL *url.URL *http.Client } // ResolveReference resolves the given reference from the CaURL. 
func (c *Client) ResolveReference(ref *url.URL) *url.URL { return c.CaURL.ResolveReference(ref) } // LoadClient configures an http.Client with the root in // $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/identity.json func LoadClient() (*Client, error) { defaultsFile := DefaultsFile() b, err := os.ReadFile(defaultsFile) if err != nil { return nil, errors.Wrapf(err, "error reading %s", defaultsFile) } var defaults defaultsConfig if err := json.Unmarshal(b, &defaults); err != nil { return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile) } if err := defaults.Validate(); err != nil { return nil, errors.Wrapf(err, "error validating %s", defaultsFile) } caURL, err := url.Parse(defaults.CaURL) if err != nil { return nil, errors.Wrapf(err, "error validating %s", defaultsFile) } if caURL.Scheme == "" { caURL.Scheme = "https" } identity, err := LoadDefaultIdentity() if err != nil { return nil, err } if err := identity.Validate(); err != nil { return nil, errors.Wrapf(err, "error validating %s", IdentityFile()) } if kind := identity.Kind(); kind != MutualTLS { return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) } // Prepare transport with information in defaults.json and identity.json tr := httptransport.New() tr.TLSClientConfig = &tls.Config{ MinVersion: tls.VersionTLS12, GetClientCertificate: identity.GetClientCertificateFunc(), } // RootCAs b, err = os.ReadFile(defaults.Root) if err != nil { return nil, errors.Wrapf(err, "error loading %s", defaults.Root) } pool := x509.NewCertPool() if pool.AppendCertsFromPEM(b) { tr.TLSClientConfig.RootCAs = pool } return &Client{ CaURL: caURL, Client: &http.Client{ Transport: tr, }, }, nil } type defaultsConfig struct { CaURL string `json:"ca-url"` Root string `json:"root"` } func (c *defaultsConfig) Validate() error { switch { case c.CaURL == "": return fmt.Errorf("missing or invalid `ca-url` property") case c.Root == "": return fmt.Errorf("missing 
or invalid `root` property") default: return nil } } ================================================ FILE: ca/identity/client_test.go ================================================ package identity import ( "crypto/tls" "crypto/x509" "net/http" "net/http/httptest" "net/url" "os" "reflect" "sort" "testing" "github.com/smallstep/certificates/internal/httptransport" ) func returnInput(val string) func() string { return func() string { return val } } func TestClient(t *testing.T) { oldIdentityFile := IdentityFile oldDefaultsFile := DefaultsFile defer func() { IdentityFile = oldIdentityFile DefaultsFile = oldDefaultsFile }() IdentityFile = returnInput("testdata/config/identity.json") DefaultsFile = returnInput("testdata/config/defaults.json") client, err := LoadClient() if err != nil { t.Fatal(err) } okServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { w.WriteHeader(http.StatusUnauthorized) } else { w.WriteHeader(http.StatusOK) } })) defer okServer.Close() crt, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/secrets/server_key") if err != nil { t.Fatal(err) } b, err := os.ReadFile("testdata/certs/root_ca.crt") if err != nil { t.Fatal(err) } pool := x509.NewCertPool() pool.AppendCertsFromPEM(b) okServer.TLS = &tls.Config{ Certificates: []tls.Certificate{crt}, ClientCAs: pool, ClientAuth: tls.VerifyClientCertIfGiven, MinVersion: tls.VersionTLS12, } okServer.StartTLS() badServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })) defer badServer.Close() if resp, err := client.Get(okServer.URL); err != nil { t.Errorf("client.Get() error = %v", err) } else { resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("client.Get() = %d, want %d", resp.StatusCode, http.StatusOK) } } if _, err := client.Get(badServer.URL); err == nil { t.Errorf("client.Get() error = %v, wantErr true", 
err) } } func TestClient_ResolveReference(t *testing.T) { type fields struct { CaURL *url.URL } type args struct { ref *url.URL } tests := []struct { name string fields fields args args want *url.URL }{ {"ok", fields{&url.URL{Scheme: "https", Host: "localhost"}}, args{&url.URL{Path: "/foo"}}, &url.URL{Scheme: "https", Host: "localhost", Path: "/foo"}}, {"ok", fields{&url.URL{Scheme: "https", Host: "localhost", Path: "/bar"}}, args{&url.URL{Path: "/foo"}}, &url.URL{Scheme: "https", Host: "localhost", Path: "/foo"}}, {"ok", fields{&url.URL{Scheme: "https", Host: "localhost"}}, args{&url.URL{Path: "/foo", RawQuery: "foo=bar"}}, &url.URL{Scheme: "https", Host: "localhost", Path: "/foo", RawQuery: "foo=bar"}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Client{ CaURL: tt.fields.CaURL, } if got := c.ResolveReference(tt.args.ref); !reflect.DeepEqual(got, tt.want) { t.Errorf("Client.ResolveReference() = %v, want %v", got, tt.want) } }) } } func TestLoadClient(t *testing.T) { oldIdentityFile := IdentityFile oldDefaultsFile := DefaultsFile defer func() { IdentityFile = oldIdentityFile DefaultsFile = oldDefaultsFile }() crt, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key") if err != nil { t.Fatal(err) } b, err := os.ReadFile("testdata/certs/root_ca.crt") if err != nil { t.Fatal(err) } pool := x509.NewCertPool() pool.AppendCertsFromPEM(b) tr := httptransport.New() tr.TLSClientConfig = &tls.Config{ Certificates: []tls.Certificate{crt}, RootCAs: pool, MinVersion: tls.VersionTLS12, } expected := &Client{ CaURL: &url.URL{Scheme: "https", Host: "127.0.0.1"}, Client: &http.Client{ Transport: tr, }, } tests := []struct { name string prepare func() want *Client wantErr bool }{ {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") DefaultsFile = returnInput("testdata/config/defaults.json") }, expected, false}, {"fail identity", func() { IdentityFile = 
returnInput("testdata/config/missing.json") DefaultsFile = returnInput("testdata/config/defaults.json") }, nil, true}, {"fail identity", func() { IdentityFile = returnInput("testdata/config/fail.json") DefaultsFile = returnInput("testdata/config/defaults.json") }, nil, true}, {"fail defaults", func() { IdentityFile = returnInput("testdata/config/identity.json") DefaultsFile = returnInput("testdata/config/missing.json") }, nil, true}, {"fail defaults", func() { IdentityFile = returnInput("testdata/config/identity.json") DefaultsFile = returnInput("testdata/config/fail.json") }, nil, true}, {"fail ca", func() { IdentityFile = returnInput("testdata/config/identity.json") DefaultsFile = returnInput("testdata/config/badca.json") }, nil, true}, {"fail root", func() { IdentityFile = returnInput("testdata/config/identity.json") DefaultsFile = returnInput("testdata/config/badroot.json") }, nil, true}, {"fail type", func() { IdentityFile = returnInput("testdata/config/badIdentity.json") DefaultsFile = returnInput("testdata/config/defaults.json") }, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.prepare() got, err := LoadClient() if (err != nil) != tt.wantErr { t.Errorf("LoadClient() error = %v, wantErr %v", err, tt.wantErr) return } if tt.want == nil { if !reflect.DeepEqual(got, tt.want) { t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) } } else { gotTransport := got.Client.Transport.(*http.Transport) wantTransport := tt.want.Client.Transport.(*http.Transport) switch { case gotTransport.TLSClientConfig.GetClientCertificate == nil: t.Error("LoadClient() transport does not define GetClientCertificate") case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs): t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) default: crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) if err != nil { t.Errorf("LoadClient() GetClientCertificate error = 
%v", err) } else if !reflect.DeepEqual(*crt, wantTransport.TLSClientConfig.Certificates[0]) { t.Errorf("LoadClient() GetClientCertificate crt = %#v, want %#v", *crt, wantTransport.TLSClientConfig.Certificates[0]) } } } }) } } func Test_defaultsConfig_Validate(t *testing.T) { type fields struct { CaURL string Root string } tests := []struct { name string fields fields wantErr bool }{ {"ok", fields{"https://127.0.0.1", "root_ca.crt"}, false}, {"fail ca-url", fields{"", "root_ca.crt"}, true}, {"fail root", fields{"https://127.0.0.1", ""}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &defaultsConfig{ CaURL: tt.fields.CaURL, Root: tt.fields.Root, } if err := c.Validate(); (err != nil) != tt.wantErr { t.Errorf("defaultsConfig.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } //nolint:staticcheck,gocritic func equalPools(a, b *x509.CertPool) bool { if reflect.DeepEqual(a, b) { return true } subjects := a.Subjects() sA := make([]string, len(subjects)) for i := range subjects { sA[i] = string(subjects[i]) } subjects = b.Subjects() sB := make([]string, len(subjects)) for i := range subjects { sB[i] = string(subjects[i]) } sort.Strings(sA) sort.Strings(sB) return reflect.DeepEqual(sA, sB) } ================================================ FILE: ca/identity/identity.go ================================================ package identity import ( "bytes" "crypto" "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "net/http" "os" "path/filepath" "strings" "time" "github.com/pkg/errors" "github.com/smallstep/cli-utils/step" "go.step.sm/crypto/pemutil" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/internal/httptransport" ) // Type represents the different types of identity files. type Type string // Disabled represents a disabled identity type const Disabled Type = "" // MutualTLS represents the identity using mTLS. const MutualTLS Type = "mTLS" // TunnelTLS represents an identity using a (m)TLS tunnel. 
// // TunnelTLS can be optionally configured with client certificates and a root // file with the CAs to trust. By default it will use the system truststore // instead of the CA truststore. const TunnelTLS Type = "tTLS" // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute var ( identityDir = step.IdentityPath configDir = step.ConfigPath // IdentityFile contains a pointer to a function that outputs the location of // the identity file. IdentityFile = step.IdentityFile // DefaultsFile contains a prointer a function that outputs the location of the // defaults configuration file. DefaultsFile = step.DefaultsFile ) // Identity represents the identity file that can be used to authenticate with // the CA. type Identity struct { Type string `json:"type"` Certificate string `json:"crt"` Key string `json:"key"` // Host is the tunnel host for a TunnelTLS (tTLS) identity. Host string `json:"host,omitempty"` // Root is the CA bundle of root CAs used in TunnelTLS to trust the // certificate of the host. Root string `json:"root,omitempty"` } // LoadIdentity loads an identity present in the given filename. func LoadIdentity(filename string) (*Identity, error) { b, err := os.ReadFile(filename) if err != nil { return nil, errors.Wrapf(err, "error reading %s", filename) } identity := new(Identity) if err := json.Unmarshal(b, &identity); err != nil { return nil, errors.Wrapf(err, "error unmarshaling %s", filename) } return identity, nil } // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { return LoadIdentity(IdentityFile()) } // WriteDefaultIdentity writes the given certificates and key and the // identity.json pointing to the new files. 
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { if err := os.MkdirAll(configDir(), 0700); err != nil { return errors.Wrap(err, "error creating config directory") } identityDir := identityDir() if err := os.MkdirAll(identityDir, 0700); err != nil { return errors.Wrap(err, "error creating identity directory") } certFilename := filepath.Join(identityDir, "identity.crt") keyFilename := filepath.Join(identityDir, "identity_key") // Write certificate if err := writeCertificate(certFilename, certChain); err != nil { return err } // Write key buf := new(bytes.Buffer) block, err := pemutil.Serialize(key) if err != nil { return err } if err := pem.Encode(buf, block); err != nil { return errors.Wrap(err, "error encoding identity key") } if err := os.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } // Write identity.json buf.Reset() enc := json.NewEncoder(buf) enc.SetIndent("", " ") if err := enc.Encode(Identity{ Type: string(MutualTLS), Certificate: certFilename, Key: keyFilename, }); err != nil { return errors.Wrap(err, "error writing identity json") } if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } return nil } // WriteIdentityCertificate writes the identity certificate to disk. func WriteIdentityCertificate(certChain []api.Certificate) error { filename := filepath.Join(identityDir(), "identity.crt") return writeCertificate(filename, certChain) } // writeCertificate writes the given certificate on disk. 
func writeCertificate(filename string, certChain []api.Certificate) error { buf := new(bytes.Buffer) for _, crt := range certChain { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { return errors.Wrap(err, "error encoding certificate") } } if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing certificate") } return nil } // Kind returns the type for the given identity. func (i *Identity) Kind() Type { switch strings.ToLower(i.Type) { case "": return Disabled case "mtls": return MutualTLS case "ttls": return TunnelTLS default: return Type(i.Type) } } // Validate validates the identity object. func (i *Identity) Validate() error { switch i.Kind() { case Disabled: return nil case MutualTLS: if i.Certificate == "" { return errors.New("identity.crt cannot be empty") } if i.Key == "" { return errors.New("identity.key cannot be empty") } if err := fileExists(i.Certificate); err != nil { return err } return fileExists(i.Key) case TunnelTLS: if i.Host == "" { return errors.New("tunnel.host cannot be empty") } if i.Certificate != "" { if err := fileExists(i.Certificate); err != nil { return err } if i.Key == "" { return errors.New("tunnel.key cannot be empty") } if err := fileExists(i.Key); err != nil { return err } } if i.Root != "" { if err := fileExists(i.Root); err != nil { return err } } return nil default: return errors.Errorf("unsupported identity type %s", i.Type) } } // TLSCertificate returns a tls.Certificate for the identity. func (i *Identity) TLSCertificate() (tls.Certificate, error) { fail := func(err error) (tls.Certificate, error) { return tls.Certificate{}, err } switch i.Kind() { case Disabled: return tls.Certificate{}, nil case MutualTLS, TunnelTLS: crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) } // Check if certificate is expired. 
x509Cert, err := x509.ParseCertificate(crt.Certificate[0]) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) } now := time.Now().Truncate(time.Second) if now.Add(DefaultLeeway).Before(x509Cert.NotBefore) { return fail(errors.New("certificate is not yet valid")) } if now.After(x509Cert.NotAfter) { return fail(errors.New("certificate is already expired")) } return crt, nil default: return fail(errors.Errorf("unsupported identity type %s", i.Type)) } } // GetClientCertificateFunc returns a method that can be used as the // GetClientCertificate property in a tls.Config. func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { return func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return nil, errors.Wrap(err, "error loading identity certificate") } return &crt, nil } } // GetCertPool returns a x509.CertPool if the identity defines a custom root. func (i *Identity) GetCertPool() (*x509.CertPool, error) { if i.Root == "" { //nolint:nilnil // legacy return nil, nil } b, err := os.ReadFile(i.Root) if err != nil { return nil, errors.Wrap(err, "error reading identity root") } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(b) { return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root) } return pool, nil } // Renewer is that interface that a renew client must implement. type Renewer interface { GetRootCAs() *x509.CertPool Renew(tr http.RoundTripper) (*api.SignResponse, error) } // Renew renews the current identity certificate using a client with a renew // method. 
func (i *Identity) Renew(client Renewer) error { switch i.Kind() { case Disabled: return nil case MutualTLS, TunnelTLS: cert, err := i.TLSCertificate() if err != nil { return err } tr := httptransport.New() tr.TLSClientConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, RootCAs: client.GetRootCAs(), MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, } sign, err := client.Renew(tr) if err != nil { return err } if len(sign.CertChainPEM) == 0 { sign.CertChainPEM = []api.Certificate{sign.ServerPEM, sign.CaPEM} } // Write certificate buf := new(bytes.Buffer) for _, crt := range sign.CertChainPEM { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { return errors.Wrap(err, "error encoding identity certificate") } } certFilename := filepath.Join(identityDir(), "identity.crt") if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } return nil default: return errors.Errorf("unsupported identity type %s", i.Type) } } func fileExists(filename string) error { info, err := os.Stat(filename) if err != nil { return errors.Wrapf(err, "error reading %s", filename) } if info.IsDir() { return errors.Errorf("error reading %s: file is a directory", filename) } return nil } ================================================ FILE: ca/identity/identity_test.go ================================================ package identity import ( "crypto" "crypto/tls" "crypto/x509" "fmt" "net/http" "os" "path/filepath" "reflect" "testing" "github.com/smallstep/certificates/api" "go.step.sm/crypto/pemutil" ) func TestLoadDefaultIdentity(t *testing.T) { oldFile := IdentityFile defer func() { IdentityFile = oldFile }() expected := &Identity{ Type: "mTLS", Certificate: "testdata/identity/identity.crt", Key: "testdata/identity/identity_key", } tests := []struct { name string prepare func() want *Identity wantErr bool }{ {"ok", func() { IdentityFile = 
returnInput("testdata/config/identity.json") }, expected, false}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.prepare() got, err := LoadDefaultIdentity() if (err != nil) != tt.wantErr { t.Errorf("LoadDefaultIdentity() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("LoadDefaultIdentity() = %v, want %v", got, tt.want) } }) } } func TestIdentity_Kind(t *testing.T) { type fields struct { Type string } tests := []struct { name string fields fields want Type }{ {"disabled", fields{""}, Disabled}, {"mutualTLS", fields{"mTLS"}, MutualTLS}, {"tunnelTLS", fields{"tTLS"}, TunnelTLS}, {"unknown", fields{"unknown"}, Type("unknown")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &Identity{ Type: tt.fields.Type, } if got := i.Kind(); got != tt.want { t.Errorf("Identity.Kind() = %v, want %v", got, tt.want) } }) } } func TestIdentity_Validate(t *testing.T) { type fields struct { Type string Certificate string Key string Host string Root string } tests := []struct { name string fields fields wantErr bool }{ {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false}, {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false}, {"ok disabled", fields{}, false}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true}, {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true}, {"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true}, {"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, 
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true}, {"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true}, {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, {"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true}, {"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &Identity{ Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, Host: tt.fields.Host, Root: tt.fields.Root, } if err := i.Validate(); (err != nil) != tt.wantErr { t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestIdentity_TLSCertificate(t *testing.T) { expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key") if err != nil { t.Fatal(err) } type fields struct { Type string Certificate string Key string } tests := []struct { name string fields fields want tls.Certificate wantErr bool }{ {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok disabled", fields{}, tls.Certificate{}, false}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail certificate", fields{"mTLS", "testdata/certs/server.crt", 
"testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail not after", fields{"mTLS", "testdata/identity/expired.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail not before", fields{"mTLS", "testdata/identity/not_before.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &Identity{ Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, } got, err := i.TLSCertificate() if (err != nil) != tt.wantErr { t.Errorf("Identity.TLSCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Identity.TLSCertificate() = %v, want %v", got, tt.want) } }) } } func Test_fileExists(t *testing.T) { type args struct { filename string } tests := []struct { name string args args wantErr bool }{ {"ok", args{"testdata/identity/identity.crt"}, false}, {"missing", args{"testdata/identity/missing.crt"}, true}, {"directory", args{"testdata/identity"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fileExists(tt.args.filename); (err != nil) != tt.wantErr { t.Errorf("fileExists() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestWriteDefaultIdentity(t *testing.T) { tmpDir := t.TempDir() oldConfigDir := configDir oldIdentityDir := identityDir oldIdentityFile := IdentityFile defer func() { configDir = oldConfigDir identityDir = oldIdentityDir IdentityFile = oldIdentityFile os.RemoveAll(tmpDir) }() certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt") if err != nil { t.Fatal(err) } key, err := pemutil.Read("testdata/identity/identity_key") if err != nil { t.Fatal(err) } var certChain []api.Certificate for _, c := range certs { certChain = append(certChain, api.Certificate{Certificate: c}) } configDir = returnInput(filepath.Join(tmpDir, "config")) identityDir = returnInput(filepath.Join(tmpDir, "identity")) IdentityFile = 
returnInput(filepath.Join(tmpDir, "config", "identity.json")) type args struct { certChain []api.Certificate key crypto.PrivateKey } tests := []struct { name string prepare func() args args wantErr bool }{ {"ok", func() {}, args{certChain, key}, false}, {"fail mkdir config", func() { configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt")) identityDir = returnInput(filepath.Join(tmpDir, "identity")) }, args{certChain, key}, true}, {"fail mkdir identity", func() { configDir = returnInput(filepath.Join(tmpDir, "config")) identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt")) }, args{certChain, key}, true}, {"fail certificate", func() { configDir = returnInput(filepath.Join(tmpDir, "config")) identityDir = returnInput(filepath.Join(tmpDir, "bad-dir")) os.MkdirAll(identityDir(), 0600) }, args{certChain, key}, true}, {"fail key", func() { configDir = returnInput(filepath.Join(tmpDir, "config")) identityDir = returnInput(filepath.Join(tmpDir, "identity")) }, args{certChain, "badKey"}, true}, {"fail write identity", func() { configDir = returnInput(filepath.Join(tmpDir, "bad-dir")) identityDir = returnInput(filepath.Join(tmpDir, "identity")) IdentityFile = returnInput(filepath.Join(configDir(), "identity.json")) os.MkdirAll(configDir(), 0600) }, args{certChain, key}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.prepare() if err := WriteDefaultIdentity(tt.args.certChain, tt.args.key); (err != nil) != tt.wantErr { t.Errorf("WriteDefaultIdentity() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestIdentity_GetClientCertificateFunc(t *testing.T) { expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key") if err != nil { t.Fatal(err) } type fields struct { Type string Certificate string Key string Host string Root string } tests := []struct { name string fields fields want *tls.Certificate wantErr bool }{ {"ok mTLS", fields{"mtls", 
"testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false}, {"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false}, {"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true}, {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &Identity{ Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, Host: tt.fields.Host, Root: tt.fields.Root, } fn := i.GetClientCertificateFunc() got, err := fn(&tls.CertificateRequestInfo{}) if (err != nil) != tt.wantErr { t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr) } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want) } }) } } func TestIdentity_GetCertPool(t *testing.T) { type fields struct { Type string Certificate string Key string Host string Root string } tests := []struct { name string fields fields wantSubjects [][]byte wantErr bool }{ {"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false}, {"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false}, {"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true}, {"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t 
*testing.T) { i := &Identity{ Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, Host: tt.fields.Host, Root: tt.fields.Root, } got, err := i.GetCertPool() if (err != nil) != tt.wantErr { t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr) return } if got != nil { //nolint:staticcheck // we don't have a different way to check // the certificates in the pool. subjects := got.Subjects() if !reflect.DeepEqual(subjects, tt.wantSubjects) { t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) } } }) } } type renewer struct { pool *x509.CertPool sign *api.SignResponse err error } func (r *renewer) GetRootCAs() *x509.CertPool { return r.pool } func (r *renewer) Renew(http.RoundTripper) (*api.SignResponse, error) { return r.sign, r.err } func TestIdentity_Renew(t *testing.T) { tmpDir := t.TempDir() oldIdentityDir := identityDir identityDir = returnInput("testdata/identity") defer func() { identityDir = oldIdentityDir os.RemoveAll(tmpDir) }() certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt") if err != nil { t.Fatal(err) } ok := &renewer{ sign: &api.SignResponse{ ServerPEM: api.Certificate{Certificate: certs[0]}, CaPEM: api.Certificate{Certificate: certs[1]}, CertChainPEM: []api.Certificate{ {Certificate: certs[0]}, {Certificate: certs[1]}, }, }, } okOld := &renewer{ sign: &api.SignResponse{ ServerPEM: api.Certificate{Certificate: certs[0]}, CaPEM: api.Certificate{Certificate: certs[1]}, }, } fail := &renewer{ err: fmt.Errorf("an error"), } type fields struct { Type string Certificate string Key string } type args struct { client Renewer } tests := []struct { name string prepare func() fields fields args args wantErr bool }{ {"ok", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, false}, {"ok old", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{okOld}, false}, {"ok 
disabled", func() {}, fields{}, args{nil}, false}, {"fail type", func() {}, fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail write identity", func() { identityDir = returnInput(filepath.Join(tmpDir, "bad-dir")) os.MkdirAll(identityDir(), 0600) }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.prepare() i := &Identity{ Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, } if err := i.Renew(tt.args.client); (err != nil) != tt.wantErr { t.Errorf("Identity.Renew() error = %v, wantErr %v", err, tt.wantErr) } }) } } ================================================ FILE: ca/identity/testdata/certs/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBozCCAUqgAwIBAgIQF4UYp5uEiuq/BO0cOWTq9DAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTEyMTIwMjQ1MThaFw0yOTEy MDkwMjQ1MThaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQGECLvDj+ZSqW78DRmUaugh0EU4NQ5 PoZxsLpB0gUsvNDGE0V5/2Q85GmsYzlBjBuoM+RfvF2fSP+dDTs3Hwjgo2YwZDAO BgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU95Au B82vrt2UJyDTNBQH3B8sePUwHwYDVR0jBBgwFoAUgwZucvb+H/1chTPLQ1GYTJwK CXQwCgYIKoZIzj0EAwIDRwAwRAIgSaHuI61rNsFf1ke5WSUyuqy51DIE/ONCSWKT VQgTVJMCIAMsE+Eibk43hL4qQi5vBJiFLfGQDDN/9HUi6w4w5EZ7 -----END CERTIFICATE----- ================================================ FILE: ca/identity/testdata/certs/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBfDCCASGgAwIBAgIQE8W0gyMruWxRDfegdPHrdDAKBggqhkjOPQQDAjAcMRow 
GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTEyMTIwMjQ1MThaFw0yOTEy MDkwMjQ1MThaMBwxGjAYBgNVBAMTEVNtYWxsc3RlcCBSb290IENBMFkwEwYHKoZI zj0CAQYIKoZIzj0DAQcDQgAEgd74QbUDcEj3aV5Oxv5eAMzwnejj7S/iDFAp89t9 kEb+Ux4NZC3Pay+92yRL//dBUI5WOopLXBniYomH4SFJg6NFMEMwDgYDVR0PAQH/ BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFIMGbnL2/h/9XIUz y0NRmEycCgl0MAoGCCqGSM49BAMCA0kAMEYCIQD3/IUBL5/9Hpdp2+t4XnA42cwQ j5WkGY5hJIhdQ5P8qgIhAMf19nAIUlSbXKPf21Gv6eYEoNuuLfpcqnfBt5NJX64M -----END CERTIFICATE----- ================================================ FILE: ca/identity/testdata/certs/server.crt ================================================ -----BEGIN CERTIFICATE----- MIICHDCCAcKgAwIBAgIQQ4n25nGGKm6uGyVQ4cDNCTAKBggqhkjOPQQDAjAkMSIw IAYDVQQDExlTbWFsbHN0ZXAgSW50ZXJtZWRpYXRlIENBMB4XDTE5MTIxMjAyNTAz OVoXDTI5MTIwOTAyNTAzOVowFjEUMBIGA1UEAxMLdGVzdCBzZXJ2ZXIwWTATBgcq hkjOPQIBBggqhkjOPQMBBwNCAATmQRMCzRP1hBcYhAXlbiyR9QtsQosQfCZTS+en g6TtL9VjWsQXqd1SSStfi0grPyiTQLIPhPbSho/VJzSpf59Do4HjMIHgMA4GA1Ud DwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0O BBYEFBvz34jDFrb3G4qiGkZZj99BnabAMB8GA1UdIwQYMBaAFPeQLgfNr67dlCcg 0zQUB9wfLHj1MBoGA1UdEQQTMBGCCWxvY2FsaG9zdIcEfwAAATBTBgwrBgEEAYKk ZMYoQAEEQzBBAgEBBA9qb2VAZXhhbXBsZS5jb20EKzJ3U05fQ21leFhXaWdfRG5w VlpzWUZkTUgxU3RjODZCSUJ6TjBydDVpcEUwCgYIKoZIzj0EAwIDSAAwRQIhAOt6 /x9LWQyBtx3RcyyALF2//OCfGjAx0zLGmUsXIHGIAiAZGVwTxbhxiYU95AXncS3F 3tXNaaIJyyO7atiVPhCR1A== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBozCCAUqgAwIBAgIQF4UYp5uEiuq/BO0cOWTq9DAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTEyMTIwMjQ1MThaFw0yOTEy MDkwMjQ1MThaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQGECLvDj+ZSqW78DRmUaugh0EU4NQ5 PoZxsLpB0gUsvNDGE0V5/2Q85GmsYzlBjBuoM+RfvF2fSP+dDTs3Hwjgo2YwZDAO BgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU95Au B82vrt2UJyDTNBQH3B8sePUwHwYDVR0jBBgwFoAUgwZucvb+H/1chTPLQ1GYTJwK CXQwCgYIKoZIzj0EAwIDRwAwRAIgSaHuI61rNsFf1ke5WSUyuqy51DIE/ONCSWKT 
VQgTVJMCIAMsE+Eibk43hL4qQi5vBJiFLfGQDDN/9HUi6w4w5EZ7 -----END CERTIFICATE----- ================================================ FILE: ca/identity/testdata/config/badIdentity.json ================================================ { "type": "", "crt": "testdata/identity/identity.crt", "key": "testdata/identity/identity_key" } ================================================ FILE: ca/identity/testdata/config/badca.json ================================================ { "ca-url": ":", "ca-config": "testdata/config/ca.json", "fingerprint": "9dc35eef23a234b2520516a3169090d7ec2fc61323bdd6e4fde08bcfec5d0931", "root": "testdata/certs/root_ca.crt" } ================================================ FILE: ca/identity/testdata/config/badroot.json ================================================ { "ca-url": "https://127.0.0.1", "ca-config": "testdata/config/ca.json", "fingerprint": "9dc35eef23a234b2520516a3169090d7ec2fc61323bdd6e4fde08bcfec5d0931", "root": "testdata/certs/missing.crt" } ================================================ FILE: ca/identity/testdata/config/ca.json ================================================ { "root": "testdata/certs/root_ca.crt", "federatedRoots": [], "crt": "testdata/certs/intermediate_ca.crt", "key": "testdata/secrets/intermediate_ca_key", "address": ":443", "dnsNames": [ "127.0.0.1", "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "type": "jwk", "name": "joe@example.com", "key": { "use": "sig", "kty": "EC", "kid": "2wSN_CmexXWig_DnpVZsYFdMH1Stc86BIBzN0rt5ipE", "crv": "P-256", "alg": "ES256", "x": "QqYaIULUQqP0EOmogorCcQIxEtI7-zCRcUVFxyNwq4Q", "y": "YeIMipM7uMHjlxpFIUbfCBC1xEXczXNYRzJCMyrGcH0" }, "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiSVQ3MVNUMTNNMTd1S3Y4VHRDczYyUSJ9.TXShNLPcITS0bFvQeMjjCDhQLICQs1ShECkgUkUsAm9ZWpSq6Yu03w.SWxtxscivS3L5Yo5.O-XY9YKK8wEJgVs7X1-FxiM_6w4s7iJQNXRD2JrZRsXtDqUz7diPfXuBOFPUFsNzykvob1qCsU4B23Ek2nbaS2HqPrIOGbOvOsR8Pt6kNoraH1QDp3Hyzkv0S-VGM0MCGYDDmmH33PZmsdS36Aw8v9xBnDHlwlMg4NjTskxpqggfQl01433B0lCJqJdrmeBeGL1ZCKixvc-wAQxU8GH5iiD925ViLY7RlVo-tmIBXpxRgheLgKiuMxmgPvf15qCdgU5TRqeuJbYJLzvPpoai0W4WHjpM1zLjjmp5OYRFW4m4ZRZf5g1Cm4lstFPUlTn85fkMZFdBh4_bFbjAv7k.epXp8DZKHj_dxP9EohwDIg" } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: ca/identity/testdata/config/defaults.json ================================================ { "ca-url": "https://127.0.0.1", "ca-config": "testdata/config/ca.json", "fingerprint": "9dc35eef23a234b2520516a3169090d7ec2fc61323bdd6e4fde08bcfec5d0931", "root": "testdata/certs/root_ca.crt" } ================================================ FILE: ca/identity/testdata/config/fail.json ================================================ This is not a json file ================================================ FILE: ca/identity/testdata/config/identity.json ================================================ { "type": "mTLS", "crt": "testdata/identity/identity.crt", "key": "testdata/identity/identity_key" } ================================================ FILE: ca/identity/testdata/config/tunnel.json ================================================ { "type": "mTLS", "crt": "testdata/identity/identity.crt", "key": "testdata/identity/identity_key", "host": "tunnel:443", "root": "testdata/certs/root_ca.crt" } ================================================ FILE: ca/identity/testdata/identity/expired.crt ================================================ -----BEGIN CERTIFICATE----- 
MIICIDCCAcegAwIBAgIRAM1GK1TLmvWLVOjP0dqVCiEwCgYIKoZIzj0EAwIwJDEi MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0xODEyMTIwMzI2 MzZaFw0xODEyMTMwMzI2MzZaMBoxGDAWBgNVBAMMD2pvZUBleGFtcGxlLmNvbTBZ MBMGByqGSM49AgEGCCqGSM49AwEHA0IABI0+NSjg3+vGhAeZGrxPksrXFqq0AIUB D3nQPmGPuUWIEmbt6qp3EVF/o+KwzWgDv5fzBmDlBkdBRz9xc3XIcQ2jgeMwgeAw HwYDVR0jBBgwFoAU95AuB82vrt2UJyDTNBQH3B8sePUwDgYDVR0PAQH/BAQDAgWg MB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4EFgQU1Ht6zX2M eVXcnxhM4hxU0RCblNowGgYDVR0RBBMwEYEPam9lQGV4YW1wbGUuY29tMFMGDCsG AQQBgqRkxihAAQRDMEECAQEED2pvZUBleGFtcGxlLmNvbQQrMndTTl9DbWV4WFdp Z19EbnBWWnNZRmRNSDFTdGM4NkJJQnpOMHJ0NWlwRTAKBggqhkjOPQQDAgNHADBE AiBgoPACCRJ6s+C5Yz3BWeyM6VnWewctnaMsVJKyPdb98AIgV/7HRZsc5Xgi8iVt D4XxVOZDu/y1V4VIH5W4INfg6JA= -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBozCCAUqgAwIBAgIQF4UYp5uEiuq/BO0cOWTq9DAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTEyMTIwMjQ1MThaFw0yOTEy MDkwMjQ1MThaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQGECLvDj+ZSqW78DRmUaugh0EU4NQ5 PoZxsLpB0gUsvNDGE0V5/2Q85GmsYzlBjBuoM+RfvF2fSP+dDTs3Hwjgo2YwZDAO BgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU95Au B82vrt2UJyDTNBQH3B8sePUwHwYDVR0jBBgwFoAUgwZucvb+H/1chTPLQ1GYTJwK CXQwCgYIKoZIzj0EAwIDRwAwRAIgSaHuI61rNsFf1ke5WSUyuqy51DIE/ONCSWKT VQgTVJMCIAMsE+Eibk43hL4qQi5vBJiFLfGQDDN/9HUi6w4w5EZ7 -----END CERTIFICATE----- ================================================ FILE: ca/identity/testdata/identity/identity.crt ================================================ -----BEGIN CERTIFICATE----- MIICHzCCAcagAwIBAgIQfVgJ4dZ2AhS88uthvlIzyjAKBggqhkjOPQQDAjAkMSIw IAYDVQQDExlTbWFsbHN0ZXAgSW50ZXJtZWRpYXRlIENBMB4XDTE5MTIxMjAyNDgy MVoXDTI5MTIwOTAyNDgyMVowGjEYMBYGA1UEAwwPam9lQGV4YW1wbGUuY29tMFkw EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEjT41KODf68aEB5kavE+SytcWqrQAhQEP edA+YY+5RYgSZu3qqncRUX+j4rDNaAO/l/MGYOUGR0FHP3FzdchxDaOB4zCB4DAO BgNVHQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0G 
A1UdDgQWBBTUe3rNfYx5VdyfGEziHFTREJuU2jAfBgNVHSMEGDAWgBT3kC4Hza+u 3ZQnINM0FAfcHyx49TAaBgNVHREEEzARgQ9qb2VAZXhhbXBsZS5jb20wUwYMKwYB BAGCpGTGKEABBEMwQQIBAQQPam9lQGV4YW1wbGUuY29tBCsyd1NOX0NtZXhYV2ln X0RucFZac1lGZE1IMVN0Yzg2QklCek4wcnQ1aXBFMAoGCCqGSM49BAMCA0cAMEQC IHkYnKUBrXc/GIosKgnhHqVeRMi2O1JhnZdTE1uoy2C0AiA9ZrmGqPvpQ86f5yq5 llsieqBTzIum6A45q0/4XeN3QA== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBozCCAUqgAwIBAgIQF4UYp5uEiuq/BO0cOWTq9DAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTEyMTIwMjQ1MThaFw0yOTEy MDkwMjQ1MThaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQGECLvDj+ZSqW78DRmUaugh0EU4NQ5 PoZxsLpB0gUsvNDGE0V5/2Q85GmsYzlBjBuoM+RfvF2fSP+dDTs3Hwjgo2YwZDAO BgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU95Au B82vrt2UJyDTNBQH3B8sePUwHwYDVR0jBBgwFoAUgwZucvb+H/1chTPLQ1GYTJwK CXQwCgYIKoZIzj0EAwIDRwAwRAIgSaHuI61rNsFf1ke5WSUyuqy51DIE/ONCSWKT VQgTVJMCIAMsE+Eibk43hL4qQi5vBJiFLfGQDDN/9HUi6w4w5EZ7 -----END CERTIFICATE----- ================================================ FILE: ca/identity/testdata/identity/identity_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIJ4A5QcJioS5I89uT/hkuWPy/nlW5qy8vM8Tm2sgUCDyoAoGCCqGSM49 AwEHoUQDQgAEjT41KODf68aEB5kavE+SytcWqrQAhQEPedA+YY+5RYgSZu3qqncR UX+j4rDNaAO/l/MGYOUGR0FHP3FzdchxDQ== -----END EC PRIVATE KEY----- ================================================ FILE: ca/identity/testdata/identity/not_before.crt ================================================ -----BEGIN CERTIFICATE----- MIICIDCCAcagAwIBAgIQHRUI8eJv55I9/5IHi1mpmjAKBggqhkjOPQQDAjAkMSIw IAYDVQQDExlTbWFsbHN0ZXAgSW50ZXJtZWRpYXRlIENBMB4XDTI5MTIwOTAzMzAx NFoXDTI5MTIxMDAzMzAxNFowGjEYMBYGA1UEAwwPam9lQGV4YW1wbGUuY29tMFkw EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEjT41KODf68aEB5kavE+SytcWqrQAhQEP edA+YY+5RYgSZu3qqncRUX+j4rDNaAO/l/MGYOUGR0FHP3FzdchxDaOB4zCB4DAf BgNVHSMEGDAWgBT3kC4Hza+u3ZQnINM0FAfcHyx49TAOBgNVHQ8BAf8EBAMCBaAw 
HQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBTUe3rNfYx5 VdyfGEziHFTREJuU2jAaBgNVHREEEzARgQ9qb2VAZXhhbXBsZS5jb20wUwYMKwYB BAGCpGTGKEABBEMwQQIBAQQPam9lQGV4YW1wbGUuY29tBCsyd1NOX0NtZXhYV2ln X0RucFZac1lGZE1IMVN0Yzg2QklCek4wcnQ1aXBFMAoGCCqGSM49BAMCA0gAMEUC IQDJVzxQ0lY9+haZLs5qxhbaWoTmXwCbYdkwhThDfM/izwIgRZCmshc1flfimIPO eblT85Gk16ND/diV6pmtUaMT73I= -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIBozCCAUqgAwIBAgIQF4UYp5uEiuq/BO0cOWTq9DAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTEyMTIwMjQ1MThaFw0yOTEy MDkwMjQ1MThaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQGECLvDj+ZSqW78DRmUaugh0EU4NQ5 PoZxsLpB0gUsvNDGE0V5/2Q85GmsYzlBjBuoM+RfvF2fSP+dDTs3Hwjgo2YwZDAO BgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU95Au B82vrt2UJyDTNBQH3B8sePUwHwYDVR0jBBgwFoAUgwZucvb+H/1chTPLQ1GYTJwK CXQwCgYIKoZIzj0EAwIDRwAwRAIgSaHuI61rNsFf1ke5WSUyuqy51DIE/ONCSWKT VQgTVJMCIAMsE+Eibk43hL4qQi5vBJiFLfGQDDN/9HUi6w4w5EZ7 -----END CERTIFICATE----- ================================================ FILE: ca/identity/testdata/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,37e3019a1aa420225bbd4f342a3ce330 3SNIIXzE11cGKTPnErv8S1HIrd2lbQo+lsMT9GrU33GAi/MTvp0hx0txy7E3CsrU DbuPXs3zLCjgoNLOeyAWLqGjPLRt4YNnZGVDi3F/dFUAWxgXH8gZQ2d9ZqAXwxdd bhT4ZcRFgFzCPlHExtxBrJe+Tmeuq1HqD+8gpOSYbt0= -----END EC PRIVATE KEY----- ================================================ FILE: ca/identity/testdata/secrets/root_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,48fc92ab6885b2377d8bbac5b035bde2 BE07EXlLmJbAfjt2c9GwQoTT07DzjLWgiGWqxMKC0bOLQdmHe2pFudeQldDhTOme xnr9rRj9h+GRWV+sIzp+ilGd4/F6lfzWMl44GA5y7uBNWKhnI1uB9m9oo69hBNRg dQuDmAx5EWXvg7Mgg1MQZIPY8539RXWJdAs+uRSI12g= -----END EC PRIVATE KEY----- ================================================ FILE: 
ca/identity/testdata/secrets/server_key ================================================ -----BEGIN EC PRIVATE KEY----- MHcCAQEEIIGgfuMfx7h1VaCYzzEPZhrbTLsAr6dtyuQ2RLl6jKqBoAoGCCqGSM49 AwEHoUQDQgAE5kETAs0T9YQXGIQF5W4skfULbEKLEHwmU0vnp4Ok7S/VY1rEF6nd UkkrX4tIKz8ok0CyD4T20oaP1Sc0qX+fQw== -----END EC PRIVATE KEY----- ================================================ FILE: ca/mutable_tls_config.go ================================================ package ca import ( "crypto/tls" "crypto/x509" "sync" "github.com/smallstep/certificates/api" ) // mutableTLSConfig allows to use a tls.Config with mutable cert pools. type mutableTLSConfig struct { sync.RWMutex config *tls.Config clientCerts []*x509.Certificate rootCerts []*x509.Certificate mutClientCerts []*x509.Certificate mutRootCerts []*x509.Certificate } // newMutableTLSConfig creates a new mutableTLSConfig that will be later // initialized with a tls.Config. func newMutableTLSConfig() *mutableTLSConfig { return &mutableTLSConfig{ clientCerts: []*x509.Certificate{}, rootCerts: []*x509.Certificate{}, mutClientCerts: []*x509.Certificate{}, mutRootCerts: []*x509.Certificate{}, } } // Init initializes the mutable tls.Config with the given tls.Config. func (c *mutableTLSConfig) Init(base *tls.Config) { c.Lock() c.config = base.Clone() c.Unlock() } // TLSConfig returns the updated tls.Config it it has changed. It's used in the // tls.Config GetConfigForClient. func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) { c.RLock() config = c.config.Clone() c.RUnlock() return } // Reload reloads the tls.Config with the new CAs. 
func (c *mutableTLSConfig) Reload() { // Prepare new pools c.RLock() rootCAs := x509.NewCertPool() clientCAs := x509.NewCertPool() // Fixed certs for _, cert := range c.rootCerts { rootCAs.AddCert(cert) } for _, cert := range c.clientCerts { clientCAs.AddCert(cert) } // Mutable certs for _, cert := range c.mutRootCerts { rootCAs.AddCert(cert) } for _, cert := range c.mutClientCerts { clientCAs.AddCert(cert) } c.RUnlock() // Set new pool c.Lock() c.config.RootCAs = rootCAs c.config.ClientCAs = clientCAs c.mutRootCerts = []*x509.Certificate{} c.mutClientCerts = []*x509.Certificate{} c.Unlock() } // AddImmutableClientCACert add an immutable cert to ClientCAs. func (c *mutableTLSConfig) AddImmutableClientCACert(cert *x509.Certificate) { c.Lock() c.clientCerts = append(c.clientCerts, cert) c.Unlock() } // AddImmutableRootCACert add an immutable cert to RootCas. func (c *mutableTLSConfig) AddImmutableRootCACert(cert *x509.Certificate) { c.Lock() c.rootCerts = append(c.rootCerts, cert) c.Unlock() } // AddClientCAs add mutable certs to ClientCAs. func (c *mutableTLSConfig) AddClientCAs(certs []api.Certificate) { c.Lock() for _, cert := range certs { c.mutClientCerts = append(c.mutClientCerts, cert.Certificate) } c.Unlock() } // AddRootCAs add mutable certs to RootCAs. func (c *mutableTLSConfig) AddRootCAs(certs []api.Certificate) { c.Lock() for _, cert := range certs { c.mutRootCerts = append(c.mutRootCerts, cert.Certificate) } c.Unlock() } ================================================ FILE: ca/provisioner.go ================================================ package ca import ( "encoding/json" "net/url" "time" "github.com/pkg/errors" "github.com/smallstep/cli-utils/token" "github.com/smallstep/cli-utils/token/provision" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" "github.com/smallstep/certificates/authority/provisioner" ) const tokenLifetime = 5 * time.Minute // Provisioner is an authorized entity that can sign tokens necessary for // signature requests. 
type Provisioner struct { *Client name string kid string audience string sshAudience string fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration } // NewProvisioner loads and decrypts key material from the CA for the named // provisioner. The key identified by `kid` will be used if specified. If `kid` // is the empty string we'll use the first key for the named provisioner that // decrypts using `password`. func NewProvisioner(name, kid, caURL string, password []byte, opts ...ClientOption) (*Provisioner, error) { client, err := NewClient(caURL, opts...) if err != nil { return nil, err } // Get the fingerprint of the current connection fp, err := client.RootFingerprint() if err != nil { return nil, err } var jwk *jose.JSONWebKey switch { case name == "": return nil, errors.New("provisioner name cannot be empty") case kid == "": jwk, err = loadProvisionerJWKByName(client, name, password) default: jwk, err = loadProvisionerJWKByKid(client, kid, password) } if err != nil { return nil, err } return &Provisioner{ Client: client, name: name, kid: jwk.KeyID, audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(), sshAudience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/ssh/sign"}).String(), fingerprint: fp, jwk: jwk, tokenLifetime: tokenLifetime, }, nil } // Name returns the provisioner's name. func (p *Provisioner) Name() string { return p.name } // Kid returns the provisioners key ID. func (p *Provisioner) Kid() string { return p.kid } // Fingerprint root certificate fingerprint. func (p *Provisioner) Fingerprint() string { return p.fingerprint } // Audience returns the audience for tokens used with X.509 certificates. func (p *Provisioner) Audience() string { return p.audience } // SSHAudience returns audience used with SSH certificates. func (p *Provisioner) SSHAudience() string { return p.sshAudience } // SetFingerprint overwrites the default fingerprint used. 
func (p *Provisioner) SetFingerprint(sum string) { p.fingerprint = sum } // SetAudience overwrites the default audience used with X.509 certificates. func (p *Provisioner) SetAudience(s string) { p.audience = s } // SetSSHAudience overwrites the default audience used with SSH certificates. func (p *Provisioner) SetSSHAudience(s string) { p.sshAudience = s } // Token generates a bootstrap token for a subject. func (p *Provisioner) Token(subject string, sans ...string) (string, error) { if len(sans) == 0 { sans = []string{subject} } // A random jwt id will be used to identify duplicated tokens jwtID, err := randutil.Hex(64) // 256 bits if err != nil { return "", err } notBefore := time.Now() notAfter := notBefore.Add(tokenLifetime) tokOptions := []token.Options{ token.WithJWTID(jwtID), token.WithKid(p.kid), token.WithIssuer(p.name), token.WithAudience(p.audience), token.WithValidity(notBefore, notAfter), token.WithSANS(sans), } if p.fingerprint != "" { tokOptions = append(tokOptions, token.WithSHA(p.fingerprint)) } tok, err := provision.New(subject, tokOptions...) if err != nil { return "", err } return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) } // SSHToken generates a SSH token. func (p *Provisioner) SSHToken(certType, keyID string, principals []string) (string, error) { jwtID, err := randutil.Hex(64) if err != nil { return "", err } notBefore := time.Now() notAfter := notBefore.Add(tokenLifetime) tokOptions := []token.Options{ token.WithJWTID(jwtID), token.WithKid(p.kid), token.WithIssuer(p.name), token.WithAudience(p.sshAudience), token.WithValidity(notBefore, notAfter), token.WithSSH(provisioner.SignSSHOptions{ CertType: certType, Principals: principals, KeyID: keyID, }), } if p.fingerprint != "" { tokOptions = append(tokOptions, token.WithSHA(p.fingerprint)) } tok, err := provision.New(keyID, tokOptions...) 
if err != nil { return "", err } return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) } func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebKey, error) { enc, err := jose.ParseEncrypted(encryptedKey) if err != nil { return nil, errors.Wrap(err, "error parsing provisioner encrypted key") } data, err := enc.Decrypt(password) if err != nil { return nil, errors.Wrap(err, "error decrypting provisioner key with provided password") } jwk := new(jose.JSONWebKey) if err := json.Unmarshal(data, jwk); err != nil { return nil, errors.Wrap(err, "error unmarshaling provisioning key") } return jwk, nil } // loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and // decrypts it using the specified password. func loadProvisionerJWKByKid(client *Client, kid string, password []byte) (*jose.JSONWebKey, error) { encrypted, err := getProvisionerKey(client, kid) if err != nil { return nil, err } return decryptProvisionerJWK(encrypted, password) } // loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then // returns the key of the first provisioner with a matching name that can be successfully // decrypted with the specified password. func loadProvisionerJWKByName(client *Client, name string, password []byte) (*jose.JSONWebKey, error) { provisioners, err := getProvisioners(client) if err != nil { return nil, errors.Wrap(err, "error getting the provisioners") } for _, provisioner := range provisioners { if provisioner.GetName() == name { if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok { if key, err := decryptProvisionerJWK(encryptedKey, password); err == nil { return key, nil } } } } return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name) } // getProvisioners returns the list of provisioners using the configured client. 
func getProvisioners(client *Client) (provisioner.List, error) { var cursor string var provisioners provisioner.List for { resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100)) if err != nil { return nil, err } provisioners = append(provisioners, resp.Provisioners...) if resp.NextCursor == "" { return provisioners, nil } cursor = resp.NextCursor } } // getProvisionerKey returns the encrypted provisioner key for the given kid. func getProvisionerKey(client *Client, kid string) (string, error) { resp, err := client.ProvisionerKey(kid) if err != nil { return "", err } return resp.Key, nil } ================================================ FILE: ca/provisioner_test.go ================================================ package ca import ( "net/url" "os" "reflect" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" ) func getTestProvisioner(t *testing.T, caURL string) *Provisioner { jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) if err != nil { t.Fatal(err) } cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") if err != nil { t.Fatal(err) } client, err := NewClient(caURL, WithRootFile("testdata/secrets/root_ca.crt")) if err != nil { t.Fatal(err) } return &Provisioner{ Client: client, name: "mariano", kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(), sshAudience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/ssh/sign"}).String(), fingerprint: x509util.Fingerprint(cert), jwk: jwk, tokenLifetime: 5 * time.Minute, } } func mustParseSigned(t *testing.T, tok string, key, dest any) { t.Helper() jwt, err := jose.ParseSigned(tok) require.NoError(t, err) require.NoError(t, jwt.Claims(key, dest)) } func TestNewProvisioner(t *testing.T) { ca := startCATestServer(t) 
defer ca.Close() want := getTestProvisioner(t, ca.URL) caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt") require.NoError(t, err) type args struct { name string kid string caURL string password []byte clientOption ClientOption } tests := []struct { name string args args want *Provisioner wantErr bool }{ {"ok", args{want.name, want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, want, false}, {"ok-by-name", args{want.name, "", ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, want, false}, {"ok-with-bundle", args{want.name, want.kid, ca.URL, []byte("password"), WithCABundle(caBundle)}, want, false}, {"ok-with-fingerprint", args{want.name, want.kid, ca.URL, []byte("password"), WithRootSHA256(want.fingerprint)}, want, false}, {"fail-bad-kid", args{want.name, "bad-kid", ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true}, {"fail-empty-name", args{"", want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true}, {"fail-bad-name", args{"bad-name", "", ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true}, {"fail-by-password", args{want.name, want.kid, ca.URL, []byte("bad-password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true}, {"fail-by-password-no-kid", args{want.name, "", ca.URL, []byte("bad-password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true}, {"fail-bad-certificate", args{want.name, want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/federated_ca.crt")}, nil, true}, {"fail-not-found-certificate", args{want.name, want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/missing.crt")}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.password, tt.args.clientOption) if (err != nil) != tt.wantErr { t.Errorf("NewProvisioner() error = %v, wantErr 
%v", err, tt.wantErr) return } // Client won't match. // Make sure it does. if got != nil { got.Client = want.Client } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewProvisioner() = %v, want %v", got, tt.want) } }) } } func TestProvisioner_Getters(t *testing.T) { p := getTestProvisioner(t, "https://127.0.0.1:9000") if got := p.Name(); got != p.name { t.Errorf("Provisioner.Name() = %v, want %v", got, p.name) } if got := p.Kid(); got != p.kid { t.Errorf("Provisioner.Kid() = %v, want %v", got, p.kid) } if got := p.Fingerprint(); got != p.fingerprint { t.Errorf("Provisioner.Fingerprint() = %v, want %v", got, p.kid) } if got := p.Audience(); got != p.audience { t.Errorf("Provisioner.Audience() = %v, want %v", got, p.kid) } if got := p.SSHAudience(); got != p.sshAudience { t.Errorf("Provisioner.SSHAudience() = %v, want %v", got, p.kid) } } func TestProvisioner_Setters(t *testing.T) { p := getTestProvisioner(t, "https://127.0.0.1:9000") u, err := url.Parse(p.GetCaURL()) require.NoError(t, err) p.SetFingerprint("71498a347624fe99e1baff52d57d04a75be9c695c67bd6b0a08903e809f7497d") p.SetAudience(u.ResolveReference(&url.URL{Path: "/1.0/revoke"}).String()) p.SetSSHAudience(u.ResolveReference(&url.URL{Path: "/1.0/ssh/revoke"}).String()) tok, err := p.Token("test@example.com") require.NoError(t, err) claims := make(map[string]any) mustParseSigned(t, tok, p.jwk.Public(), &claims) assert.Equal(t, "71498a347624fe99e1baff52d57d04a75be9c695c67bd6b0a08903e809f7497d", claims["sha"]) assert.Equal(t, "https://127.0.0.1:9000/1.0/revoke", claims["aud"]) tok, err = p.SSHToken("user", "test@example.com", []string{"test"}) require.NoError(t, err) claims = make(map[string]any) mustParseSigned(t, tok, p.jwk.Public(), &claims) assert.Equal(t, "71498a347624fe99e1baff52d57d04a75be9c695c67bd6b0a08903e809f7497d", claims["sha"]) assert.Equal(t, "https://127.0.0.1:9000/1.0/ssh/revoke", claims["aud"]) } func TestProvisioner_Token(t *testing.T) { p := getTestProvisioner(t, "https://127.0.0.1:9000") 
sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" type fields struct { name string kid string fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration } type args struct { subject string sans []string } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false}, {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false}, {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false}, {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true}, {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Provisioner{ name: tt.fields.name, kid: tt.fields.kid, audience: "https://127.0.0.1:9000/1.0/sign", fingerprint: tt.fields.fingerprint, jwk: tt.fields.jwk, tokenLifetime: tt.fields.tokenLifetime, } got, err := p.Token(tt.args.subject, tt.args.sans...) 
if (err != nil) != tt.wantErr { t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr == false { jwt, err := jose.ParseSigned(got) if err != nil { t.Error(err) return } var claims jose.Claims if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { t.Error(err) return } if err := claims.ValidateWithLeeway(jose.Expected{ Audience: []string{"https://127.0.0.1:9000/1.0/sign"}, Issuer: tt.fields.name, Subject: tt.args.subject, Time: time.Now().UTC(), }, time.Minute); err != nil { t.Error(err) return } lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) if lifetime != tt.fields.tokenLifetime { t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) } allClaims := make(map[string]interface{}) if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { t.Error(err) return } if v, ok := allClaims["sha"].(string); !ok || v != sha { t.Errorf("Claim sha = %s, want %s", v, sha) } if len(tt.args.sans) == 0 { if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) } } else { want := []interface{}{} for _, s := range tt.args.sans { want = append(want, s) } if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) { t.Errorf("Claim sans = %s, want %s", v, want) } } if v, ok := allClaims["jti"].(string); !ok || v == "" { t.Errorf("Claim jti = %s, want not blank", v) } } }) } } func TestProvisioner_IPv6Token(t *testing.T) { p := getTestProvisioner(t, "https://[::1]:9000") sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" type fields struct { name string kid string fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration } type args struct { subject string sans []string } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, 
args{"subject", nil}, false}, {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false}, {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false}, {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true}, {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Provisioner{ name: tt.fields.name, kid: tt.fields.kid, audience: "https://[::1]:9000/1.0/sign", fingerprint: tt.fields.fingerprint, jwk: tt.fields.jwk, tokenLifetime: tt.fields.tokenLifetime, } got, err := p.Token(tt.args.subject, tt.args.sans...) if (err != nil) != tt.wantErr { t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr == false { jwt, err := jose.ParseSigned(got) if err != nil { t.Error(err) return } var claims jose.Claims if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { t.Error(err) return } if err := claims.ValidateWithLeeway(jose.Expected{ Audience: []string{"https://[::1]:9000/1.0/sign"}, Issuer: tt.fields.name, Subject: tt.args.subject, Time: time.Now().UTC(), }, time.Minute); err != nil { t.Error(err) return } lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) if lifetime != tt.fields.tokenLifetime { t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) } allClaims := make(map[string]interface{}) if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { t.Error(err) return } if v, ok := allClaims["sha"].(string); !ok || v != sha { t.Errorf("Claim sha = %s, want %s", v, sha) } if len(tt.args.sans) == 0 { if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { t.Errorf("Claim sans = %s, want %s", v, 
[]interface{}{tt.args.subject}) } } else { want := []interface{}{} for _, s := range tt.args.sans { want = append(want, s) } if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) { t.Errorf("Claim sans = %s, want %s", v, want) } } if v, ok := allClaims["jti"].(string); !ok || v == "" { t.Errorf("Claim jti = %s, want not blank", v) } } }) } } func TestProvisioner_SSHToken(t *testing.T) { p := getTestProvisioner(t, "https://127.0.0.1:9000") sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" type fields struct { name string kid string fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration } type args struct { certType string keyID string principals []string } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"user", "foo@smallstep.com", []string{"foo"}}, false}, {"ok host", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"host", "foo.smallstep.com", []string{"foo.smallstep.com"}}, false}, {"ok multiple principals", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"user", "foo@smallstep.com", []string{"foo", "bar"}}, false}, {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"user", "", []string{"foo"}}, true}, {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"user", "foo@smallstep.com", []string{"foo"}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Provisioner{ name: tt.fields.name, kid: tt.fields.kid, audience: "https://127.0.0.1:9000/1.0/sign", sshAudience: "https://127.0.0.1:9000/1.0/ssh/sign", fingerprint: tt.fields.fingerprint, jwk: tt.fields.jwk, tokenLifetime: tt.fields.tokenLifetime, } got, err := p.SSHToken(tt.args.certType, tt.args.keyID, tt.args.principals) if (err != nil) != tt.wantErr { t.Errorf("Provisioner.SSHToken() error = %v, wantErr %v", err, tt.wantErr) return } if tt.wantErr == false { 
jwt, err := jose.ParseSigned(got) if err != nil { t.Error(err) return } var claims jose.Claims if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { t.Error(err) return } if err := claims.ValidateWithLeeway(jose.Expected{ Audience: []string{"https://127.0.0.1:9000/1.0/ssh/sign"}, Issuer: tt.fields.name, Subject: tt.args.keyID, Time: time.Now().UTC(), }, time.Minute); err != nil { t.Error(err) return } lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) if lifetime != tt.fields.tokenLifetime { t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) } allClaims := make(map[string]interface{}) if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { t.Error(err) return } if v, ok := allClaims["sha"].(string); !ok || v != sha { t.Errorf("Claim sha = %s, want %s", v, sha) } principals := make([]interface{}, len(tt.args.principals)) for i, p := range tt.args.principals { principals[i] = p } want := map[string]interface{}{ "ssh": map[string]interface{}{ "certType": tt.args.certType, "keyID": tt.args.keyID, "principals": principals, "validAfter": "", "validBefore": "", }, } if !reflect.DeepEqual(allClaims["step"], want) { t.Errorf("Claim step = %s, want %s", allClaims["step"], want) } if v, ok := allClaims["jti"].(string); !ok || v == "" { t.Errorf("Claim jti = %s, want not blank", v) } } }) } } ================================================ FILE: ca/renew.go ================================================ package ca import ( "context" "crypto/tls" "math/rand" "sync" "time" "github.com/pkg/errors" ) // RenewFunc defines the type of the functions used to get a new tls // certificate. type RenewFunc func() (*tls.Certificate, error) var minCertDuration = time.Minute // TLSRenewer automatically renews a tls certificate using a RenewFunc. 
type TLSRenewer struct { renewMutex sync.RWMutex RenewCertificate RenewFunc cert *tls.Certificate timer *time.Timer renewBefore time.Duration renewJitter time.Duration certNotAfter time.Time } type tlsRenewerOptions func(r *TLSRenewer) error // WithRenewBefore modifies a tlsRenewer by setting the renewBefore attribute. func WithRenewBefore(b time.Duration) func(r *TLSRenewer) error { return func(r *TLSRenewer) error { r.renewBefore = b return nil } } // WithRenewJitter modifies a tlsRenewer by setting the renewJitter attribute. func WithRenewJitter(j time.Duration) func(r *TLSRenewer) error { return func(r *TLSRenewer) error { r.renewJitter = j return nil } } // NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given // RenewFunc to get a new certificate when required. func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) { r := &TLSRenewer{ RenewCertificate: fn, cert: cert, certNotAfter: cert.Leaf.NotAfter.Add(-1 * time.Minute), } for _, f := range opts { if err := f(r); err != nil { return nil, errors.Wrap(err, "error applying options") } } // Use the current time to calculate the initial period. Using a notBefore // in the past might set a renewBefore too large, causing continuous // renewals due to the negative values in nextRenewDuration. period := cert.Leaf.NotAfter.Sub(time.Now().Truncate(time.Second)) if period < minCertDuration { return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period) } // By default we will try to renew the cert before 2/3 of the validity // period have expired. if r.renewBefore == 0 { r.renewBefore = period / 3 } // By default we set the jitter to 1/20th of the validity period. if r.renewJitter == 0 { r.renewJitter = period / 20 } return r, nil } // Run starts the certificate renewer for the given certificate. 
func (r *TLSRenewer) Run() { cert := r.getCertificate() next := r.nextRenewDuration(cert.Leaf.NotAfter) r.renewMutex.Lock() r.timer = time.AfterFunc(next, r.renewCertificate) r.renewMutex.Unlock() } // RunContext starts the certificate renewer for the given certificate. func (r *TLSRenewer) RunContext(ctx context.Context) { r.Run() go func() { <-ctx.Done() r.Stop() }() } // Stop prevents the renew timer from firing. func (r *TLSRenewer) Stop() bool { if r.timer != nil { return r.timer.Stop() } return true } // GetCertificate returns the current server certificate. // // This method is set in the tls.Config GetCertificate property. func (r *TLSRenewer) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { return r.getCertificate(), nil } // GetCertificateForCA returns the current server certificate. It can only be // used if the renew function creates the new certificate and do not uses a TLS // request. It's intended to be use by the certificate authority server. // // This method is set in the tls.Config GetCertificate property. func (r *TLSRenewer) GetCertificateForCA(*tls.ClientHelloInfo) (*tls.Certificate, error) { return r.getCertificateForCA(), nil } // GetClientCertificate returns the current client certificate. // // This method is set in the tls.Config GetClientCertificate property. func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { return r.getCertificate(), nil } // getCertificate returns the certificate using a read-only lock. // // Known issue: It cannot renew an expired certificate because the /renew // endpoint requires a valid client certificate. The certificate can expire // if the timer does not fire e.g. when the CA is run from a laptop that // enters sleep mode. func (r *TLSRenewer) getCertificate() *tls.Certificate { r.renewMutex.RLock() cert := r.cert r.renewMutex.RUnlock() return cert } // getCertificateForCA returns the certificate using a read-only lock. 
It will // automatically renew the certificate if it has expired. func (r *TLSRenewer) getCertificateForCA() *tls.Certificate { r.renewMutex.RLock() // Force certificate renewal if the timer didn't run. // This is an special case that can happen after a computer sleep. if time.Now().After(r.certNotAfter) { r.renewMutex.RUnlock() r.renewCertificate() r.renewMutex.RLock() } cert := r.cert r.renewMutex.RUnlock() return cert } // setCertificate updates the certificate using a read-write lock. It also // updates certNotAfter with 1m of delta; this will force the renewal of the // certificate if it is about to expire. func (r *TLSRenewer) setCertificate(cert *tls.Certificate) { r.renewMutex.Lock() r.cert = cert r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute) r.renewMutex.Unlock() } func (r *TLSRenewer) renewCertificate() { var next time.Duration cert, err := r.RenewCertificate() if err != nil { next = r.renewJitter / 2 next += time.Duration(mathRandInt63n(int64(next))) } else { r.setCertificate(cert) next = r.nextRenewDuration(cert.Leaf.NotAfter) } r.renewMutex.Lock() r.timer.Reset(next) r.renewMutex.Unlock() } func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { d := time.Until(notAfter).Truncate(time.Second) - r.renewBefore n := mathRandInt63n(int64(r.renewJitter)) d -= time.Duration(n) if d < 0 { d = 0 } return d } //nolint:gosec // not used for cryptographic security func mathRandInt63n(n int64) int64 { return rand.Int63n(n) } ================================================ FILE: ca/signal.go ================================================ package ca import ( "log" "os" "os/signal" "syscall" ) // Stopper is the interface that external commands can implement to stop the // server. type Stopper interface { Stop() error } // StopReloader is the interface that external commands can implement to stop // the server and reload the configuration while running. 
type StopReloader interface { Stop() error Reload() error } // StopHandler watches SIGINT, SIGTERM on a list of servers implementing the // Stopper interface, and when one of those signals is caught we'll run Stop // (SIGINT, SIGTERM) on all servers. func StopHandler(servers ...Stopper) { signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) defer signal.Stop(signals) for sig := range signals { switch sig { case syscall.SIGINT, syscall.SIGTERM: log.Println("shutting down ...") for _, server := range servers { err := server.Stop() if err != nil { log.Printf("error stopping server: %s", err.Error()) } } return } } } // StopReloaderHandler watches SIGINT, SIGTERM and SIGHUP on a list of servers // implementing the StopReloader interface, and when one of those signals is // caught we'll run Stop (SIGINT, SIGTERM) or Reload (SIGHUP) on all servers. func StopReloaderHandler(servers ...StopReloader) { signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) defer signal.Stop(signals) for sig := range signals { switch sig { case syscall.SIGHUP: log.Println("reloading ...") for _, server := range servers { err := server.Reload() if err != nil { log.Printf("error reloading server: %+v", err) } } case syscall.SIGINT, syscall.SIGTERM: log.Println("shutting down ...") for _, server := range servers { err := server.Stop() if err != nil { log.Printf("error stopping server: %s", err.Error()) } } return } } } ================================================ FILE: ca/testdata/ca.json ================================================ { "root": "../ca/testdata/secrets/root_ca.crt", "federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"], "crt": "../ca/testdata/secrets/intermediate_ca.crt", "key": "../ca/testdata/secrets/intermediate_ca_key", "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.3, 
"renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ] }, "authority": { "backdate": "0s", "provisioners": [ { "name": "max", "type": "jwk", "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IkpsNkZLWUp4V1UwdGRIbG9UanA1aGcifQ.Qy0EP6u5-t0ggOweoc3Z1DCzR5BllsQi.KUkviZ_TJKY4c0Mi.h7QZqgh_Fl2MZpmVy4h375yC0DORjB1dQULbNqc6MuUCW2iweWVRysFImUXiXMUKRarJC5adwWy1GhyAqUj6Xj1iOZDGLjYnqMETGWcI0rKDBwcSU7y7Y-2VYBRDSM2b7aWtTBfz3_kvEaw_vc3b5CEPJ86UlZc-jhKFRr_IcGWU-vXX5-bppoH15IPreyzi55YdjCll338lYpDecB_Paym3XBXotyd2iGXXUwoA1npEFwuyRMMEhl9zLp7rVcMW6A_32EzB8cZANEnA0C4FXGHQalY6u_2UeqxcC8_FuXPay6VIYODyRqcABvvkft3nwOcrI0pYDGBdk2w2Euk.kOAFq3Tg6s4vBGS_plMpSw", "key": { "use": "sig", "kty": "EC", "kid": "IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", "crv": "P-256", "alg": "ES256", "x": "XmaY0c9Cc_kjfn9uhimiDiKnKn00gmFzzsvElg4KxoE", "y": "ZhYcFQBqtErdC_pA7sOXrO7AboCEPIKP9Ik4CHJqANk" } }, { "name": "mike", "type": "jwk", "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlZsWnl0dUxrWTR5enlqZXJybnN0aGcifQ.QP15wQYjZ12BLgl-XTq2Vb12G3OHAfic.X35QqAaXwnlmeCUU._2qIUp0TI8yDI7c2e9upIRdrnmB5OvtLfrYN-Su2NLBpaoYtr9O55Wo0Iryc0W2pYqnVDPvgPPes4P4nQAnzw5WhFYc1Xf1ZEetfdNhwi1x2FNwPbACBAgxm5AW40O5AAlbLcWushYASfeMBZocTGXuSGUzwFqoWD-5EDJ80TWQ7cAj3ttHrJ_3QV9hi4O9KJUCiXngN-Yz2zXrhBL4NOH2fmRbaf5c0rF8xUJIIW-TcyYJeX_Fbx1IzzKKPd9USUwkDhxD4tLa51I345xVqjuwG1PEn6nF8JKqLRVUKEKFin-ShXrfE61KceyAvm4YhWKrbJWIm3bH5Hxaphy4.TexIrIhsRxJStpE3EJ925Q", "key": { "use": "sig", "kty": "EC", "kid": "DC06fatJ5nALkfEubR3VVgQ2XNy_DXSKZhwGoRO8cWU", "crv": "P-256", "alg": "ES256", "x": "SuaL-GJ3LmgBF43Da9ZCY-BzmvlkMJ61MAZ1UELPpTw", "y": "wnqZSMuXpmUxORq20t83LyY4BDYmqDGV9P7FGR6mw84" } }, { "name": "step-cli", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg", "key": { "use": "sig", "kty": "EC", "kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", "crv": "P-256", "alg": "ES256", "x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8", "y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y" } }, { "name": "mariano", "type": "jwk", "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", "key": { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { "minTLSCertDuration": "1s" } }, { "name": "maxey", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6Ik5SLTk5ZkVMSm1CLW1FZGllUlFFc3cifQ.Fr314BEUGTda4ICJl2uxFdjpEUGGqJEV.gBbu_DZE1ONDu14r.X-7MKMyokZIF1HTCVqqL0tTWgaC1ZGZBLLltd11ZUhQTswo_8kvgiTv3cFShj7ATF0tAY8HStyJmzLO8mKPVOPDXSwjdNsPriZclI6JWGi9iOu8pEiN9pZM6-itxan1JMcDUNg2U-P1BmKppHRbDKsOTivymfRyeUk51dBIlS54p5xNK1HFLc1YtWC1Rc_ngYVqOgqlhIrCHArAEBe3jrfUaH2ym-8fkVdwVqtxmte3XXK9g8FchsygRNnOKtRcr0TyzTUV-7bPi8_t02Zi-EHLFaSawVXWV_Qk1GeLYJR22Rp74beo-b5-lCNVp10btO0xdGySUWmCJ4v4_QZw.c8unwWycwtfdJMM_0b0fuA", "key": { "use": "sig", "kty": "EC", "kid": "kA5qxq_k8VFc2vzriBUU1FdzHpRfQ5Uq4W3803l1m5U", "crv": "P-256", "alg": "ES256", "x": "qGXXrT1vgRKVpqLoVwdgIut5VjvxrHa_V4xhh2kQvY0", "y": "8YHQPb031kQ9gMG8ue-YRy0Fm8Gc-v6TnYYLxRGcSjw" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/federated-ca.json ================================================ { "root": "testdata/rotated/root_ca.crt", "federatedRoots": ["testdata/secrets/root_ca.crt"], "crt": "testdata/rotated/intermediate_ca.crt", "key": "testdata/rotated/intermediate_ca_key", "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" ] }, "authority": { "provisioners": [ { "name": "mariano", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", "key": { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { "minTLSCertDuration": "1s", "defaultTLSCertDuration": "5s" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/rotate-ca-0.json ================================================ { "root": "testdata/secrets/root_ca.crt", "crt": "testdata/secrets/intermediate_ca.crt", "key": "testdata/secrets/intermediate_ca_key", "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" ] }, "authority": { "provisioners": [ { "name": "mariano", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", "key": { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { "minTLSCertDuration": "1s", "defaultTLSCertDuration": "5s" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/rotate-ca-1.json ================================================ { "root": ["testdata/secrets/root_ca.crt", "testdata/rotated/root_ca.crt"], "crt": "testdata/secrets/intermediate_ca.crt", "key": "testdata/secrets/intermediate_ca_key", "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" ] }, "authority": { "provisioners": [ { "name": "mariano", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", "key": { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { "minTLSCertDuration": "1s", "defaultTLSCertDuration": "5s" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/rotate-ca-2.json ================================================ { "root": ["testdata/rotated/root_ca.crt", "testdata/secrets/root_ca.crt"], "crt": "testdata/rotated/intermediate_ca.crt", "key": "testdata/rotated/intermediate_ca_key", "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" ] }, "authority": { "provisioners": [ { "name": "mariano", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", "key": { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { "minTLSCertDuration": "1s", "defaultTLSCertDuration": "5s" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/rotate-ca-3.json ================================================ { "root": "testdata/rotated/root_ca.crt", "crt": "testdata/rotated/intermediate_ca.crt", "key": "testdata/rotated/intermediate_ca_key", "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" ] }, "authority": { "provisioners": [ { "name": "mariano", "type": "jwk", "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", "key": { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { "minTLSCertDuration": "1s", "defaultTLSCertDuration": "5s" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/rotated/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBxTCCAWugAwIBAgIQLIY6MR/1fBRQY4ZTTsPAJjAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTAxMDcyMDExMzBaFw0yOTAx MDQyMDExMzBaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARgtjL/KLNpdq81YYWaek1lrkPM/QF1 m+ujwv5jya21fAXljdBLh6m2xco1GPfwPBbwUGlNOdEqE9Nq3Qx3ngPKo4GGMIGD MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUqixeZ/K1HW9N6SVw7ONya98S u8UwHwYDVR0jBBgwFoAUgIzlCLxh/RlwEany4JQHOorLAIEwCgYIKoZIzj0EAwID SAAwRQIgdGX6lxThrKlt3v+3HJZlaWdmoeQ3vYwpJb9uHExZdVYCIQDCxsdI8EnB bxjnJscbT4zvqVsq6AmycdbFwgy8RIeVzg== -----END CERTIFICATE----- ================================================ FILE: ca/testdata/rotated/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: 
AES-256-CBC,7dcc0a8c1d73c8d438184e0928875329 r6yrQrHg6zBZRSjQpe8RzyQALEfiT3/8lMvvPu3BX6yign5skMfCVMXZhzbmAwmR BJBIX+5hkudR2VN+hrsOyuU7FvIk4gx2c8buIlFObfYXIml0mpuThfm52ciAtOTE S0hkfYvPcOAjzaDZ+8Po/mYhkODgyvijogn4ioTF/Ss= -----END EC PRIVATE KEY----- ================================================ FILE: ca/testdata/rotated/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB /wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc Sw== -----END CERTIFICATE----- ================================================ FILE: ca/testdata/rotated/root_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,8ce79d28601b9809905ef7c362a20749 H+pTTL3B5fLYycgHLxFOW0fZsayr7Y+BW8THKf12h8dk0/eOE1wNoX2TuMtpbZgO lMJdFPL+SAPCCmuZOZIcQDejRHVcYBq1wvrrnw/yfVawXC4xze+J4Y+q0J2WY+rM xcLGlEOIRZkvdDVGmSitEZBl0Ibk0p9tG++7QGqAvnk= -----END EC PRIVATE KEY----- ================================================ FILE: ca/testdata/rsaca.json ================================================ { "root": "../ca/testdata/secrets/rsa_root_ca.crt", "federatedRoots": [], "crt": "../ca/testdata/secrets/rsa_intermediate_ca.crt", "key": "../ca/testdata/secrets/rsa_intermediate_ca_key", "password": "1234", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.3, "renegotiation": false, "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", 
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ] }, "authority": { "backdate": "0s", "provisioners": [ { "name": "scep", "type": "scep", "challenge": "not-so-secret" }, { "name": "step-cli", "type": "jwk", "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg", "key": { "use": "sig", "kty": "EC", "kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", "crv": "P-256", "alg": "ES256", "x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8", "y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y" } } ], "template": { "country": "US", "locality": "San Francisco", "organization": "Smallstep" } } } ================================================ FILE: ca/testdata/secrets/federated_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB /wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc Sw== -----END CERTIFICATE----- ================================================ FILE: ca/testdata/secrets/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- 
MIIB0DCCAXWgAwIBAgIQaYEAv6hTHRU+ZEnIJ6VB7zAKBggqhkjOPQQDAjAhMR8w HQYDVQQDExZTbWFsbHN0ZXAgVGVzdCBSb290IENBMB4XDTE4MDkyNzE4MTgwOVoX DTI4MDkyNDE4MTgwOVowKTEnMCUGA1UEAxMeU21hbGxzdGVwIFRlc3QgSW50ZXJt ZWRpYXRlIENBMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUnFoY688av7AhSsP vAMXHuA66zdzujzw/Wx0F/ZkWagbo52zskTxElrTt/Qkiotv33EKTUaJ7mSV/ZhW DaI6TqOBhjCBgzAOBgNVHQ8BAf8EBAMCAaYwHQYDVR0lBBYwFAYIKwYBBQUHAwEG CCsGAQUFBwMCMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFAKELAm5/V3t 40xrDbKcDn5VWYThMB8GA1UdIwQYMBaAFAdgQF1Ej2WxY52Olc2wKVePE596MAoG CCqGSM49BAMCA0kAMEYCIQCoCUGx0W5wv3iQjlGIhux/zWZiDkyIbGj3ASeUL5v9 QgIhAJ8dVOcqW3oq2TF9hHv8tXjhwmK44krO/FMK4gHljo4i -----END CERTIFICATE----- ================================================ FILE: ca/testdata/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,62bb1ccb9ed22ed553a479e34a4a0765 6lqTXwNel3jJjj+LdkA1E3Xr7bbeSukQLouFq2cbjh9Zyqb2xuhS2goxWZw0DDmG rhCCKyiQnR+ImuHAwZnKBouWvp6po8CR4C1STNAX45wPfIhPV3UA49xbiA1sM+AE QrlwCWVk9x/JhkZURK0T/3TWtdk9llcnhSKfAXnekAA= -----END EC PRIVATE KEY----- ================================================ FILE: ca/testdata/secrets/ott_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,f6870a50902e9397844faaf37f6196fc BVotbStC8KUiRyR6azjNu5nM1ER3/DtrdS/DxzDWJdWCPfayvQAU47DwoZdZ8Id2 Cu92bfKB0gQsgckPSfQhMC6sCd9JEiV7NqyLztDLnJJBmhml6fPMhoQaHAZy+qgW RiVrBaYXR92DTbtzFuYb03nmHeUVCjAT/R8Q21SCAfE= -----END EC PRIVATE KEY----- ================================================ FILE: ca/testdata/secrets/ott_key.public ================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEtTKthEHN7RuybhkaC43J2oLfBG99 5FNSWbtahLAiK7Z7fDJxfBUHfroXTAsTkn2AimrwQhDj3TSccE2kgZ5sQA== -----END PUBLIC KEY----- ================================================ FILE: ca/testdata/secrets/ott_mariano_priv.jwk 
================================================ { "protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ", "encrypted_key": "7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB", "iv": "u-54daK2y-0UO9na", "ciphertext": "3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts", "tag": "vSYfxsi2UU9LQeySDjAnnQ" } ================================================ FILE: ca/testdata/secrets/ott_mariano_pub.jwk ================================================ { "use": "sig", "kty": "EC", "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", "crv": "P-256", "alg": "ES256", "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" } ================================================ FILE: ca/testdata/secrets/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET /A8LXNH4M06A7vE= -----END CERTIFICATE----- ================================================ FILE: ca/testdata/secrets/root_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,3e0252253bf2ca8a21087f2f36c3bb4d 
YlSY9zZ7jEMEWqgk3IT3B+WuJrnAMn9OBtMeWMo9FL1eQFLfAJBwKiKdEUYyeAwi qi4nxx4MvfpkN02B53rmObUmAWQsxOPlMY3/KVkwQ1ovT/+eC/BGieBMvm/1aOYu 7/rnNAvI/3gWrbQ59mW6pr2qjK2eHr08s6S6GUx3C2E= -----END EC PRIVATE KEY----- ================================================ FILE: ca/testdata/secrets/rsa_intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIFJTCCAw2gAwIBAgIRAMBEHdXQtHUla+J13aUn/0gwDQYJKoZIhvcNAQELBQAw FjEUMBIGA1UEAxMLcnNhLXJvb3QtY2EwHhcNMjIxMjAyMTE0MzE2WhcNMzIxMTI5 MTE0MzE2WjAeMRwwGgYDVQQDExNyc2EtaW50ZXJtZWRpYXRlLWNhMIICIjANBgkq hkiG9w0BAQEFAAOCAg8AMIICCgKCAgEArxVkidtUrM6KIdGZ8a2QtJWezrTxTiEM lDeYqLd4CKp1bjQ7JOi1uc0mBG0Y4u5NwQRDk3L2aulLrENsPx4PMsPwMPXZgw67 zTTuug1/uec8phW9IvEqu8FDQhFCMzZZMmc/0UTLmhJq5NZhIU8SQ6XYF/5s11Gm zBbBG1CEV6KcwVul8+T/GcHr60h2/X4uRkibEdUsDy0jHFLMPOWMeKQXoA8hVWHc QRYInRS5q+aFZ79YqMTUFT2tKdgSCiDsm6MqAPhFVB20ZrxMU6zco67+DBKAzSGy qO0H6fxkStN4RBrCFTgUdyUPwSe5xCOVfR4JbF8pXMI9cA7iCT0Mw9ZgbTncKVdn epwIZfqqYMP0C3EL+BZOSfEQeXIq7qlmHKwRRkc010ZaLmbKB9Kug/HcsS3CevU2 J0Efosi2xfMcfhi11rAfKvZpyAuOVap7BONro3yYXjv6Co9sDWtyK6VkLsczp2MM NHxhzjGXAcQdnU79UbGxO67imZm6FYLTwcg/6SVrfh+slLJ5nCyXqC/LaQ+Mc7Q+ mdibgOzHSYg/QHVamic0uqn4BLw8QjICIZAnWWJHYjVgCieZrvK/7BGOjQ8+LT/8 NhjI6MSuNMcXLxyOciiPw1r8fT/NUbJZblMDhibGTaOFCoMc3niY/fwxPb3p1J8I tmOLoK8HCysCAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYB Af8CAQAwHQYDVR0OBBYEFO0ULj6Dt1RakbRqV4rVFUdRHK3KMB8GA1UdIwQYMBaA FCd9WZYMPpfDLBKjySFENwIXJpuzMA0GCSqGSIb3DQEBCwUAA4ICAQCiWxrj4HqV J9tGj59Ea2cMZUcBfGPYh4dZ0af6IlNZnqW9ZlmNNF/h0VvCpd28STZlkW7hp2Xb RcJ0tXs3MvnU0Sqzw8ZTevJgIIbiOIwndfmi4apfSC63JXftBkThP0xpR5LI/4pH UPYyeGA13fynH4YmO4QBsGEXlKMKSYSjwrheYKkSB73AYlc7r8OqE/NAVHc1xzov 9GT4p7w+tF6vrgzUtwqpAEVM/3USmSx4rgSdkI4DPkrYb1HEqT8ixOIH/3IG42ag UZgICckBPqcki8UbnU4nbxWVGJd18FE2n4wC2erewlBL+1PJFTmgDEKmOlcabot8 QEk/YOpMThCm79VGuFB7frXoFefLCl5q1K5yV1eDsmr79ZFIy2WM2alnVk2Cvk/9 oJQQ42AWRVHGFuaIrG+hLLtwq17MnoeyQ/A2IRlpWu7DpaCVfuPA+3yQC06qo98u 
A3vGpifN8eohTSEMYNGQAsUsArYPwMEp/QrP4EwK8YnaJtd2HCnG4VS3D+RenRIF 04b8EXX64ePD07uzPh7dKpWfmdJf1xj8GSndw2vk14KYDOvjrXirVkNCXFxgU9jp uTLGU/7Panm81xQgjeNwRaXxWvvDSrQaKMZ1QL6i0U7OTso0Q4VHivGG7IDhYSkA zNRdjmJnuap8XWGs/4xKjMJcv12UtnaMgw== -----END CERTIFICATE----- ================================================ FILE: ca/testdata/secrets/rsa_intermediate_ca_key ================================================ -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,03e26f42f8642e55946bcad62fef0c2e 54jydVXdnixOccnF90L9pkfsy4mrRC9xyl4BbZMaYwplZC+LE+U80GAdXOqSxBEo sQBz+OTYaq2bmT2MDnoHty8I4vdDTVmxovc+NdtCJdC+etc2bSEKt68K57BPEqa0 o7SE5Lk39zSDIFkyltQeYNII8sCX7H26kRsfZhmDYPoFXGCfnxrEQoASaF8S9n3l 9yERxk4untsVpvOPPde6Vn3b40ALqg0J0PaqzIbWifbWL8Uu3IeP27VHJLS4AH23 emkWaZiT1bjWNevwWiU0REZ1CxyShaggJa4YwXPJJyRcQlvnVMZ8+DjXoQ1EdSGA EGMfG6i5zDrRAdRDRgbJM56wZqIWup+/Kd0WyVGOteGFhzyl8Pad65NGYP9saPE/ P0/Wi51t30KllF6i6XHATeAKPgGAMkl8E9x9KCQVqGEWi8Ceu3w5AMxC1tcwB0Xy 1X9NBipHaDh0DneTTdRRpwGCEIkZefDwy0z4rgsxrbKyY0YP1NKsFt+rNFkdNSnK RevnejtYHSDjOyGImnLRJ0c2nxwet93hfY1g3yzagKtWUp/TXOO7EkggqUPObQhC n9U5tkPxvHTCXSzeK3QqrbReyb3AlEay8Th8R8roxcClV83E4vcjjuvitcJ0MbSW +/jCU1WhCanat67je749MB19msA95XYxNsAmCn17vJIVRI/QBS9HQCkf7UW1Jptm hU06/7sytuOFboXh/xhfoQUomlx8Hl/GqV2yGZyL7SsH4sxoT9cVCO2vXCnr/r63 Uo1nkEHQNddbBCR7yvjoeeq5PypGxZibC7YzWx87Hwcr8dEhBwzoqeIFkhnEVMyq Y3xFIilqqRaLwG1c77wy5jReTv/OTJ2OU2VFDu3Pf7zAOcGtcQazcNv7PMkiqunK Bp/vDLL+LiaWO/5Zl49DFPTGkRj5kNK/aNajfHyw0hvYYytiaGaH9DM3L+7kC7YG 2le8eLbUgZ7tqw3P2KueCK1F6Ef5X2It2sjxv/w5hz6lDtGfEIVXJuOamSUEewkY 9xM9njmqFhQjb71Khm3+/HUoxvmOebpuQ884xORfvzJ1rl8IHA84VTo8/XKp3EST yMC39rGhtVuADHvNz3Y/WAWIbrJkkdZvMXyYKoOTosNVeFjJxyfKlz8ZMYmy6cM6 mjOcsaI8xYUslYtpj/7vAjtcF4tJv94cQB/KGdUc/Z5JQ3r8zooG8ghEPt/5jiEr 4ECCK7btew0mexVv+HY3rX7UiPCHugfX6+XEIxQ8+AsM27FNFaKxjxTE2r9h95mP jmcWO7YqyqyeEZmKoxNo5oLMKXIDKxzK6ianJYg65xMnT+cH5vcnVaQKaC2QcnMI TiLOz/+ZdJSz2FiyE4myjnp9COKQhsDOfQA/1xzPF/4dqWMyWijGnlcozCHlU0i+ 
2oG7izmDl9zn79v8VH6y0WjeEywoH5XlrF5eKBA2g7AtB8MCJTpIRVazTRbvhjaP EXr+Zk6vPVlDS0KOIUJ4V8iYcatdoaJz1fM3XjVZ6Wwy8TaYd9EBwWlWdFDx6r3s 1aT5fDDyZNjnTx80OHyWT2IS/+/FrColWGc9s/t5raFm3KEnvVpFc+7/AKOV3keB +3KVSg4ILLDYf7PfMrT2IPrWObuUXZ2InZPEG3T7BOtbbdO8BDbDng1xLxPGDFgQ zKUFngsPO90PoDmNUZ9dBZ/oOI54e38hqUGB7vdTsNlX+VTK4n+qb8w7GzNhGgnR fTP927HeuFBdq8Y2ngxt2i6vg9yo7Ojd+nG5OLj2T7uyNraKdaaBx4Nd9ZUZbNGt 4EueDSHCALKBsimLl4DfnMDnUK3G79dsoazs/nUr5y7kaUlkBGNZ/iSuoqpgeTKU jsTmVjRj4W5opC+UUBiY/tE7qHGczLDw/mw/NP14nQ5iFdwi6EJv3viprYL+zL/A zRTkcQ0KqBfc1ChVWhvxIg7QCsnPT6+y0yn2k6n4a9cUvQXcOQKqF5eOJWPE3ZeC 7fgIwt7ZdqHPZHyMxAnwWbmsj1Tn09SBW1b7S4t50aAPTRjDmrp5iC4vK59L7qPZ ekoft0VduaJlKqq90Bh5ouRvTO6ytDI261bbEIGQqH1nJVt12bhNA6h3xI3Iwn/E qlMLAN1M36LenUEp9l77AfiFU1f+d8ZP2U6bJo4FKTnRR33R6+89sezUmVEuqozt qONJo0DE9XSAVhxpVX1QF91RjrJSiQtNyRkaOEyTsw2VvJpQNI5GAQN0TbcqgCVD aUqPUuwntC7Wx7PkF6OR07rVxSIvhXs1NlG09nPZVByVCRJf/zKp9jMcRVJSXp16 +Sqw4qifz/INEPGPgM3vr0GdvEN27S1IEFUZDU0M+e6KcHeIoLEPnhQPnZfO/kPT 69gRFOZAcONvnGyP+Fj74fRWpWWIIN6b8oIzPN8tez9g+DdmXHf/LnD0fGIfhqPI GjjZcNJ8oa2F2qZfmwtrYs8UIChJxfZXK/lV7Jgf48ZDSF73war8nGHA/Sir4NsF 9cp3TxTSpXo2iXqb8ZH679q7OJ3UE7OiVKr2XzVEo7T/QSPnV4l9eiq9lDb/1cnS AFfm3m0+Zqy+uE+Qfkigt5jWXBLQ3DbJEUNriumsit5dMeh2zCMwtYsWC8fumJw1 6kJVZ7yEFXhFggTkHrgTZCI/9ym8FxCcz7W9qNy47h3aDOMs+yRidyl279FsKMR3 gkjZmvGyAuZRqNttqldexMGwH1qVPIwDtCHdwesdefAydr/9h/ElDzAyBG31u3zN 7Bp5/JkN9OycTvUB7SIMR80Q7wwPJngovRu1wdKQVZC+y/snJR6tQx9u+OuSHrB+ X0J4LFuxSj5PjsTH5y2o3UFbuKzxaIwbEibPvUc7FqW7O9/N4gYZaANgcodo0ozb ZjhcL+oE90AGQyKSKGna5bZWdokLQBOUyro442gKXAOVARMzEHwIIWwD3bm6Mj0a AmaMta3/LoCj54ESPFqRm7lCTmTj4gR6t5TED810hEimbxE8CBB6yrGTyj+vn+nH 9Wn1D+Pgo0QuHp1yBZI5xrFtX2Dm6TW7cKuv0oohgjd2WFKNIqzDhOeIslk3K9TL kcBqeYMDJ5xi/R5/dfE5yLg7WhsPcH5QcMO2I6Sm+smXWytB8zo3NkX5UXUTdWNp -----END RSA PRIVATE KEY----- ================================================ FILE: ca/testdata/secrets/rsa_root_ca.crt ================================================ -----BEGIN CERTIFICATE----- 
MIIE/DCCAuSgAwIBAgIRAOFB5q6CzRilW0ERurTeSQ4wDQYJKoZIhvcNAQELBQAw FjEUMBIGA1UEAxMLcnNhLXJvb3QtY2EwHhcNMjIxMjAyMTE0MjI1WhcNMzIxMTI5 MTE0MjI1WjAWMRQwEgYDVQQDEwtyc2Etcm9vdC1jYTCCAiIwDQYJKoZIhvcNAQEB BQADggIPADCCAgoCggIBALIcD6VfJ6NZLWOhrLHr9au3WhKOmvt2gp+l53rjmwP3 PLApSnFi3PGE9gvwzdGd0XeIIithgj+FiZEk/gdWfjx3abjpNM4uTsjBweQ4d3uT zgH5h/AmGbSVUweqOCvmK5cingcvc2UGVbDo5VOP50bZR8O9NY2OQNgFHig7Z+xT eZSkGF7Sxm1zNMNU7BZqBNofFcwYDIaR/sBFuE9Im2qXj0duHbC1GXuVivE+iTDI ir52qsuobnXwEQyGe3EOwIAD9AMPsmmJ/vZSaVLFO0dIbSwTqB3nXaNC8+hA/dyX a9gEdVsSzKUiXfsk5awAOHOAEpCusywyJzZhhIyqot4rr3A3nuVOmg5utvJX2jMr wtGT7n7YhJWJVIcB/ahx/G7qwkcphEM7jnfweVgdDGTjcJ2tZchqx4U0axo+5wQy hebLz6z9QLkmfIMW0qjV6JcrYz2U1T4xSFmyNBhOrJQw4OFufSEWqYSJxoUHHOBn Dy4V98AhoIkK5UDTeTrQea5QJRGRhiCfl6VpuO1YAP/4oNrJa+rWrzYPU5bq3FF2 z2aCb9MAxnDQmfHfCSn6avioM2BcRQ8SfVVj1XsI4JtS7i7kqsHzuezJp28Jvll5 sOTGp6CNASLJg2zRE3LZbNuuZ3JlVDZPDHqOqci7Gw8xwNXZv1SNNVDBDLsN3sSd AgMBAAGjRTBDMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0G A1UdDgQWBBQnfVmWDD6XwywSo8khRDcCFyabszANBgkqhkiG9w0BAQsFAAOCAgEA mV/q1xjM9k+2Z9MhC7RXT0a/9bMVry9RiWp4xD09bPLRso+T9Pys/m222DxTjW6+ JAM1fwm6HKESeWHToIBnB1htIG2jMSC5wn2/oKfEFnJU16f4lE7aoFMHP6Pxhf9w dGXvb7Pbze1MHNtNabx5x2uVp5DLTjOjL2o7pufSXNpB3djx20jADx5KqqXQiIqk rMDi1rpWRnNT/IqkkmDdGbG9WyKp28z8HPW2Iyq80zp1d3diJvtRZTeDTBrc8NGk 96RpK1IVY0c8Z56UfecILuthm18ChSxm8DTXdc1CA1e89fiZ/pfEXPrbYLdcq8/b WQjA39z0zTiGC6gjd0g5hGeXZ5ThuW0s1EwpWmcF5bvHOxK2SOtYzxxy6bhbOzU7 4J0uCj+GIR7eKtdrHdRv0cHFPE4/XDEI/93UCJjOphNekSKGUiQKzTZhjP7g6DdM bBtsdEwkVckqFTrOlHy1aDfoUzuOB8DDwSs/59h/0a2MtGBq1MAjLaZUlDAUUbYO x8VbloQHxcEdrUYmIGEhoI+zPz6Bm2xsaIs72R10y9PfFV5xY9JcsnA9AvJ2KOHo RH7gmqh7GyqCNcQf7bfhC2SLMa8luEn0tQFVx7F/vbO1rzpvsEvtsvHka1SEEqS/ ctNS8RyWPh92jaJQ4U9nWMHOJJZ4LYW38gsMc/om7+k= -----END CERTIFICATE----- ================================================ FILE: ca/testdata/secrets/rsa_root_ca_key ================================================ -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: 
AES-256-CBC,2b2138d58dc4fe659251306226ee53ef f6H0Rvs3FmbjNk31qTe/CikGW6oFT3p+6A/g6E7gnloHuxVv4HdM0RDHOUvSMG49 hb2kLYfbztJ7+RyGdc5JhozgfwRJsSP+iT0JbDQyHlTWzG5YasMnGrMbeLayn0Bw fhawiDOaBzK36hBFnx3aE5D3MOEbJM81/tZ7SoAovvgLZmhTH5w6cGYSJle6Fgey 47skoiuRX4JJ1Us6aiME203l6AdPEs01XVPRQFZHMdbTCQ5ZVeH/BS2GHn2vfosg PLJ6RUQIILuBytRwiWZIXoVZDI7T0d6eiUizj2cyIi58rypkGDDeDvN6Uzq/r8Or epwo28YlIDRz4H40XIGVDnD8LcIbAmcfe2FTz+TbTTcQcBySnuQQJhJ6aGDnE2LG 5QPSZOLAStlMP6ceGB6oeo7nBYLnqxUbDyeNeeIfBmf2NDgFpIwjgjs+QMXg1XFP /Z0BnKm/bmKKc94w3BwsAsZ0RwZTS+WyK+xKoXpQNVRoECKZt2oDTIPUX5no7dQO CQPOvJoYjGd+IS1jykvViYZYW4Lae08thWOMbWVTyV882/wpR7DN697w38VFLw2x q9dhd4wFZfzwGndO6xUq5h3qHGXg5xPS/ArvK5KGRXFusHI0HKqK5TeJwW2NBFky AhYPr/wdGdyL+mjU4ynjG4AekdAUi8t2Jpxf7+NuWGIbD/J00GPExOUuM68gqmuw 9wGXj0EaPEyBSWc7Sq95o5+eg5VjwLsGnEKLtYLJRuwOnQ+6LZddOsaNNHgafP3N yrDN4Xu2NBowzrbAPv4nFxQF6pNkAJTtTOimRQA5qwf04Er1KSw3SAs0s8WXTDnp kySMpvSibBo4CQE+XvjGISg+yTY7/Uj6lZrJFzwrl4Nne1k960qofY7B6D8sGSxk 8DZpsNkfY86juIBUri9pp2+nqEmj8NcK0gGpNgomYbPQHoQuudfWKER8JEX5hp9g ik3RZIpes3yKJYbzEKpeAOMRy2yS75B9DOpvIO7YPfUsjGVWnV55Cqni+z2QUOWR laRnRReQRQ/C3sinoFCEDZNmw5W5ex+iaGxj7d88tolFzvN6P7JdJTrq9kZ6pEIV yJnWT6dxoabxtyArpOAIwsEbeVXyFq1o0UF5x8Y1xOJOvlWallj0cZo/mHO+sFVT VLR1Ijh+klcKjJnU1s7yk/Ls/eRMJzSnk3iAV9WJuuOFyvpzmO26uTQh0f1rSrk/ k4DA9Klywo9OFlCvGU5xuRhISDEBBrxtKQMkefFQRBxclqZldDmbss2Zr4vfmhjx 5JdETi7q40Nt0kWsXi/XXIriEILvVIShYuER84aYSG2LQw3kOREA2BLOfJXYvRxc g3UzHviOYRpzPb7fJmOsSa1sRMWTKbZn1eBwFZbmqZmFboVzUmYUiFqFFCMGaPq0 afhkZGmM/dPgStruKEyXCcAHnsIFruNZGICnDUbQyAwXlw66fJG/IwL0FbTdhr3i 68wlLKA3uAAdTPkNQvef9Ed5b2xu9Yazt3ub93sKTbSzv1PZU+VVyrfmCVXXRrku ybRoLd4HAeuMKZ7jF4dLNzPDvJ6SfdMP7Qw/NoeCBogbtsstsHe+3hOEmBYlPZpT +AyXV/BNEvli9uBUlwy9B7B6s0hj1bxMnxHTCxEuCBf3EYgRlwcIRqSCi7EV6FuT 1ScpROJP3U1+FSF8b2pP7W23xAGtXUOBSoGvMlZcxF3+xB4L3zVMqfqwlbLXvSix QoXKYtESBmVVLT5jc+sWUelEynXowG+YaDVUyEBx99vlXAznQ3D99rD1dzvzx9Um TI0aV2IjeUgOXWP5b7rVs+GJ82DDUBBsZYEYK6JIpiBtdhAYhWbplusJCwdOAmyj +9JtfLrdJTohAn1smp285wHrHgdhLECEotaHqh8Cubrw5u66couCw7ibVSDrmshi 
8xiPL3hp0jWbE10Lah4MK5pMLfjq2wOta435RuD3HNJu1nGGvEIb/+Malef8JzBW y7iABGlAHPNhcOheNQX/nXuTnUOwv69N03/i+/hWzGHIjYH/nI02EZ+CHtuCbeUd JCP3Ia1xCDEZJEtb2GmswgB5P06U7z2rZemb2HIWC6/Sors72WlEZhtis1y9/mRF 1pmGFsqmQCHk7XNrdZB56KjB4Kkj7eOE5xO01ALdZXs7nIhB7S9Sqk6Rtf+Th85N 1BT/esB3d30ORVu3TbV1uashC71ThtdNEpYNi441Yfs8u/c/c+7NtUoxBcIIvMEs FMCLs4Nqt1y2UxocWGQtii2EvjwStAgtNIGhq+/6SZVRIU3CyYm3RRx5eQ1VdfNh i+bOJlf5l8/gZXsaWwD2tBOCibml9GbFJeGPQi9Rc6AUUeTGmNRnA1PUcwbs96uz F/lmo1dms5jiV2+d+SFQgAujrJSRsST4GxpqDlU3T/anIknusTkOyyuP3Z3EY6eA LiY6sdKYj40IFdpM3aLl6LAIgkTXS1ji4nvfu5CAdBAsntTRVRB2Ew3ux8+ZsShg Rg/LMEmEP8oMq1JFrx9q2rlBghWyUdY5M+ZY/e8hGheMuaUGs8SeqWlI513+CvLw sWOUwnox+j9rjvj43Q3ac9mbqjwjykMpBDAMhAeJkW5FSK5gc6LPmRvUhfyv0De7 bgA6dpQYh6+l3yKoWmNQdFZ0YtuEc+wzzgbyUE1s/BOTB3WDLaBnUAw7R3nkTUyX 05t5b1NCcrj2fpe0DhRa7KqNQTVazEgZIkd0nPVGP8bmfMEMCXw2ri0wls0F4KkB Y52Ctx+/kQkP8HYJMV79RURNvI9204C8a+w09++w9rmHuUlGXfJ7/iVADRaXI1pM E+N4q7KrhcQYlRWthmwsol2unqtnTHjSyHiYtHeagNTt2eNkAqG61E+mtYsjQ6Al +aL3vi73hJ6oNLpT8Cb2S4XYDziIlKTtX4biZYJgkc/P4Ado0Z5ZhXqLnt+BsrDv FuqpZoHp0BA9qaCPuocL7Ne6cVTY1PGKS+Gkh9u+QWmrp1QGltNQNUiNUiuSKP79 41tdta3UYstwtuTydQPGbg71YPSXM6CqEUuYINP5yVSiO3k1aPA82Uxr3TYdnym7 D54ctp9HHk3SYpA/zdT5clNwyNiTv/bZ2Wa0DUpBRK3epvLVB6fyGlmSFnOtyelP -----END RSA PRIVATE KEY----- ================================================ FILE: ca/testdata/secrets/step_cli_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-128-CBC,e2c9c7cdad45b5032f1990b929cf83fd k3Yd307VgDrdllCBGN7PP8dOMQvEAUkq1lYtyxAWa7u/DuxeDP7SYlDB+xEk/UL8 bgoYYCProydEElYFzGg8Z98WYAzbNoP2p6PPPpAhOZsxJjc5OfTHf/OQleR8PjD5 ryN4woGuq7Tiq5xritlyhluPc91ODqMsm4P98X1sPYA= -----END EC PRIVATE KEY----- ================================================ FILE: ca/testdata/secrets/step_cli_key.public ================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE7ZdAAMZCFU4XwgblI5RfZouBi8lY 
mF6DlZusNNnsbm+xCvYl3PAPZ+DKvKYERdazEPEU2OOo3riostJst0tn1g== -----END PUBLIC KEY----- ================================================ FILE: ca/testdata/secrets/step_cli_key_priv.jwk ================================================ { "protected": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ", "encrypted_key": "XaN9zcPQeWt49zchUDm34FECUTHfQTn_", "iv": "tmNHPQDqR3ebsWfd", "ciphertext": "9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw", "tag": "thPcx3t1AUcWuEygXIY3Fg" } ================================================ FILE: ca/testdata/secrets/step_cli_key_pub.jwk ================================================ { "use": "sig", "kty": "EC", "kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc", "crv": "P-256", "alg": "ES256", "x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8", "y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y" } ================================================ FILE: ca/tls.go ================================================ package ca import ( "context" "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/pem" "net" "net/http" "os" "runtime" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/ca/identity" ) // mTLSDialContext will hold the dial context function to use in // getDefaultTransport. var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error) // localAddr is the local address to use when dialing an address. This address // is defined by the environment variable STEP_CLIENT_ADDR. 
var localAddr net.Addr

// init configures the package-level dialing behavior from the environment
// variables STEP_TLS_TUNNEL and STEP_CLIENT_ADDR. Any invalid configuration
// makes it panic.
func init() {
	// STEP_TLS_TUNNEL is an environment variable that can be set to do a TLS
	// over (m)TLS tunnel to step-ca using identity-like credentials. The value
	// is a path to a json file with the tunnel host, certificate, key and root
	// used to create the (m)TLS tunnel.
	//
	// The configuration should look like:
	//   {
	//     "type": "tTLS",
	//     "host": "tunnel.example.com:443",
	//     "crt": "/path/to/tunnel.crt",
	//     "key": "/path/to/tunnel.key",
	//     "root": "/path/to/tunnel-root.crt"
	//   }
	//
	// This feature is EXPERIMENTAL and might change at any time.
	if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
		id, err := identity.LoadIdentity(path)
		if err != nil {
			panic(err)
		}
		if err := id.Validate(); err != nil {
			panic(err)
		}
		host, port, err := net.SplitHostPort(id.Host)
		if err != nil {
			panic(err)
		}
		pool, err := id.GetCertPool()
		if err != nil {
			panic(err)
		}
		// Each call to mTLSDialContext creates a new tls.Dialer. The returned
		// dial function ignores the requested network and address and always
		// connects to the tunnel host.
		mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
			d := &tls.Dialer{
				NetDialer: createDefaultDialer(),
				Config: &tls.Config{
					MinVersion:           tls.VersionTLS12,
					RootCAs:              pool,
					GetClientCertificate: id.GetClientCertificateFunc(),
				},
			}
			return func(ctx context.Context, _, _ string) (net.Conn, error) {
				return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
			}
		}
	}

	// STEP_CLIENT_ADDR is an environment variable that can be set to define the
	// local address to use when dialing an address. This can be useful when
	// step is run behind a CIDR-based ACL.
	//
	// STEP_CLIENT_ADDR can be set to an IP ("127.0.0.1", "[::1]"), a hostname
	// ("localhost"), or a host:port ("[::1]:0"). If the port is set to
	// something other than ":0" and the dialer is created multiple times it
	// will fail with an "address already in use" error.
	//
	// See https://github.com/smallstep/cli/issues/730
	if v := os.Getenv("STEP_CLIENT_ADDR"); v != "" {
		_, _, err := net.SplitHostPort(v)
		if err != nil {
			// assuming that the error is a missing port, if it's not it will
			// panic below.
			v += ":0"
		}
		// This assigns the package-level localAddr; err is the variable
		// declared above.
		localAddr, err = net.ResolveTCPAddr("tcp", v)
		if err != nil {
			panic(err)
		}
	}
}

// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring.
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
	tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options)
	if err != nil {
		return nil, err
	}
	return tlsConfig, nil
}

// getClientTLSConfig builds the client tls.Config together with the
// [http.RoundTripper] that uses it. It creates a renewer for the certificate in
// the sign response, sets the resulting round tripper as the client transport,
// and starts the renewer with the given context.
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, http.RoundTripper, error) {
	cert, err := TLSCertificate(sign, pk)
	if err != nil {
		return nil, nil, err
	}
	renewer, err := NewTLSRenewer(cert, nil)
	if err != nil {
		return nil, nil, err
	}

	tlsConfig := getDefaultTLSConfig(sign)
	// Note that with GetClientCertificate tlsConfig.Certificates is not used.
	// Without tlsConfig.Certificates there's no need to use tlsConfig.BuildNameToCertificate()
	tlsConfig.GetClientCertificate = renewer.GetClientCertificate

	// Apply options and initialize mutable tls.Config
	tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
	if err := tlsCtx.apply(options); err != nil {
		return nil, nil, err
	}

	tr := getDefaultTransport(tlsConfig)
	tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)

	// Add decorator if available, and use the resulting [http.RoundTripper]
	// going forward
	rt := decorateRoundTripper(tr, c.transportDecorator)
	renewer.RenewCertificate = getRenewFunc(tlsCtx, c, rt, pk) //nolint:contextcheck // deeply nested context

	// Update client transport
	c.SetTransport(rt)

	// Start renewer
	renewer.RunContext(ctx)
	return tlsConfig, rt, nil
}

// GetServerTLSConfig returns a tls.Config for server use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The returned tls.Config will only verify the client certificate if provided. // The server certificate will automatically rotate before expiring. func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { cert, err := TLSCertificate(sign, pk) if err != nil { return nil, err } renewer, err := NewTLSRenewer(cert, nil) if err != nil { return nil, err } tlsConfig := getDefaultTLSConfig(sign) // Note that GetCertificate will only be called if the client supplies SNI // information or if tlsConfig.Certificates is empty. // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert // Apply options and initialize mutable tls.Config tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { return nil, err } // GetConfigForClient allows seamless root and federated roots rotation. // If the return of the callback is not-nil, it will use the returned // tls.Config instead of the default one. tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) // Update renew function with transport tr := getDefaultTransport(tlsConfig) tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // Add decorator if available, and use the resulting [http.RoundTripper] // going forward rt := decorateRoundTripper(tr, c.transportDecorator) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, rt, pk) //nolint:contextcheck // deeply nested context // Update client transport c.SetTransport(rt) // Start renewer renewer.RunContext(ctx) return tlsConfig, nil } // Transport returns an [http.RoundTripper] configured to use the client // certificate from the sign response. 
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (http.RoundTripper, error) { _, tr, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } return tr, nil } // buildGetConfigForClient returns an implementation of GetConfigForClient // callback in tls.Config. // // If the implementation returns a nil tls.Config, the original Config will be // used, but if it's non-nil, the returned Config will be used to handle this // connection. func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) { return func(*tls.ClientHelloInfo) (*tls.Config, error) { return ctx.mutableConfig.TLSConfig(), nil } } // buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport. func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { d := createDefaultDialer() // TLS dialers do not support context, but we can use the context // deadline if it is set. if t, ok := ctx.Deadline(); ok { d.Deadline = t } return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig()) } } // Certificate returns the server or client certificate from the sign response. func Certificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.ServerPEM.Certificate == nil { return nil, errors.New("ca: certificate does not exist") } return sign.ServerPEM.Certificate, nil } // IntermediateCertificate returns the CA intermediate certificate from the sign // response. func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign.CaPEM.Certificate == nil { return nil, errors.New("ca: certificate does not exist") } return sign.CaPEM.Certificate, nil } // RootCertificate returns the root certificate from the sign response. 
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) { if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 { return nil, errors.New("ca: certificate does not exist") } lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1] if len(lastChain) == 0 { return nil, errors.New("ca: certificate does not exist") } return lastChain[len(lastChain)-1], nil } // TLSCertificate creates a new TLS certificate from the sign response and the // private key used. func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certificate, error) { certPEM, err := getPEM(sign.ServerPEM) if err != nil { return nil, err } caPEM, err := getPEM(sign.CaPEM) if err != nil { return nil, err } keyPEM, err := getPEM(pk) if err != nil { return nil, err } //nolint:gocritic // using a new variable for clarity chain := append(certPEM, caPEM...) cert, err := tls.X509KeyPair(chain, keyPEM) if err != nil { return nil, errors.Wrap(err, "error creating tls certificate") } leaf, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { return nil, errors.Wrap(err, "error parsing tls certificate") } cert.Leaf = leaf return &cert, nil } func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { if sign.TLSOptions != nil { return sign.TLSOptions.TLSConfig() } return &tls.Config{ MinVersion: tls.VersionTLS12, } } // createDefaultDialer returns a new dialer with the default configuration. func createDefaultDialer() *net.Dialer { // With the KeepAlive parameter set to 0, it will be use Golang's default. return &net.Dialer{ Timeout: 30 * time.Second, LocalAddr: localAddr, } } // getDefaultTransport returns an http.Transport with the same parameters than // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. 
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport { var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) switch { case runtime.GOOS == "js" && runtime.GOARCH == "wasm": // when running in js/wasm and using the default dialer context all requests // performed by the CA client resulted in a "protocol not supported" error. // By setting the dial context to nil requests will be handled by the browser // fetch API instead. Currently this will always set the dial context to nil, // but we could implement some additional logic similar to what's found in // https://github.com/golang/go/pull/46923/files to support a different dial // context if it is available, required and expected to work. dialContext = nil case mTLSDialContext == nil: d := createDefaultDialer() dialContext = d.DialContext default: dialContext = mTLSDialContext() } return &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: dialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, } } func getPEM(i interface{}) ([]byte, error) { block := new(pem.Block) switch i := i.(type) { case api.Certificate: block.Type = "CERTIFICATE" block.Bytes = i.Raw case *x509.Certificate: block.Type = "CERTIFICATE" block.Bytes = i.Raw case *rsa.PrivateKey: block.Type = "RSA PRIVATE KEY" block.Bytes = x509.MarshalPKCS1PrivateKey(i) case *ecdsa.PrivateKey: var err error block.Type = "EC PRIVATE KEY" block.Bytes, err = x509.MarshalECPrivateKey(i) if err != nil { return nil, errors.Wrap(err, "error marshaling private key") } case ed25519.PrivateKey: var err error block.Type = "PRIVATE KEY" block.Bytes, err = x509.MarshalPKCS8PrivateKey(i) if err != nil { return nil, errors.Wrap(err, "error marshaling private key") } default: return nil, errors.Errorf("unsupported key type %T", i) } return pem.EncodeToMemory(block), nil } func 
getRenewFunc(ctx *TLSOptionCtx, client *Client, tr http.RoundTripper, pk crypto.PrivateKey) RenewFunc { return func() (*tls.Certificate, error) { // Close connections in keep-alive state defer client.CloseIdleConnections() // Get updated list of roots if err := ctx.applyRenew(); err != nil { return nil, err } // Get new certificate sign, err := client.Renew(tr) if err != nil { return nil, err } return TLSCertificate(sign, pk) } } ================================================ FILE: ca/tls_options.go ================================================ package ca import ( "crypto/tls" "crypto/x509" "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. type TLSOption func(ctx *TLSOptionCtx) error // TLSOptionCtx is the context modified on TLSOption methods. type TLSOptionCtx struct { Client *Client Config *tls.Config Sign *api.SignResponse OnRenewFunc []TLSOption mutableConfig *mutableTLSConfig hasRootCA bool hasClientCA bool } // newTLSOptionCtx creates the TLSOption context. 
func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx {
	return &TLSOptionCtx{
		Client:        c,
		Config:        config,
		Sign:          sign,
		mutableConfig: newMutableTLSConfig(),
	}
}

// apply runs the given options against ctx and then finalizes the tls.Config.
//
// Once the options have run, the mutable config is initialized from
// ctx.Config. If the sign response provides a root certificate (see
// RootCertificate), it is added to RootCAs unless an option set hasRootCA, and
// to ClientCAs unless an option set hasClientCA or ClientAuth is
// tls.NoClientCert. Finally, the certificates collected in the mutable config
// are added to the pools of ctx.Config and the mutable config is reloaded.
//
// It returns the first error returned by an option; in that case none of the
// finalization steps run.
func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
	for _, fn := range options {
		if err := fn(ctx); err != nil {
			return err
		}
	}

	// Initialize mutable config with the fully configured tls.Config
	ctx.mutableConfig.Init(ctx.Config)

	// Build RootCAs and ClientCAs with given root certificate if necessary.
	// A sign response without a root certificate is not an error here.
	if root, err := RootCertificate(ctx.Sign); err == nil {
		if !ctx.hasRootCA {
			if ctx.Config.RootCAs == nil {
				ctx.Config.RootCAs = x509.NewCertPool()
			}
			ctx.Config.RootCAs.AddCert(root)
			ctx.mutableConfig.AddImmutableRootCACert(root)
		}
		if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert {
			if ctx.Config.ClientCAs == nil {
				ctx.Config.ClientCAs = x509.NewCertPool()
			}
			ctx.Config.ClientCAs.AddCert(root)
			ctx.mutableConfig.AddImmutableClientCACert(root)
		}
	}

	// Update tls.Config with mutable data. The pools are only created when
	// there is at least one certificate to add to them.
	if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 {
		ctx.Config.RootCAs = x509.NewCertPool()
	}
	if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 {
		ctx.Config.ClientCAs = x509.NewCertPool()
	}
	// Add mutable certificates
	for _, cert := range ctx.mutableConfig.mutRootCerts {
		ctx.Config.RootCAs.AddCert(cert)
	}
	for _, cert := range ctx.mutableConfig.mutClientCerts {
		ctx.Config.ClientCAs.AddCert(cert)
	}
	ctx.mutableConfig.Reload()
	return nil
}

// applyRenew runs again every function registered in OnRenewFunc and then
// reloads the mutable config. It stops at the first error, without reloading.
func (ctx *TLSOptionCtx) applyRenew() error {
	for _, fn := range ctx.OnRenewFunc {
		if err := fn(ctx); err != nil {
			return err
		}
	}

	// Reload mutable config with the changes
	ctx.mutableConfig.Reload()
	return nil
}

// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
// a valid TLS client certificate. This is the default option for mTLS servers.
func RequireAndVerifyClientCert() TLSOption { return func(ctx *TLSOptionCtx) error { ctx.Config.ClientAuth = tls.RequireAndVerifyClientCert return nil } } // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { return func(ctx *TLSOptionCtx) error { ctx.Config.ClientAuth = tls.VerifyClientCertIfGiven return nil } } // AddRootCA adds to the tls.Config RootCAs the given certificate. RootCAs // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { return func(ctx *TLSOptionCtx) error { if ctx.Config.RootCAs == nil { ctx.Config.RootCAs = x509.NewCertPool() } ctx.hasRootCA = true ctx.Config.RootCAs.AddCert(cert) ctx.mutableConfig.AddImmutableRootCACert(cert) return nil } } // AddClientCA adds to the tls.Config ClientCAs the given certificate. ClientCAs // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { return func(ctx *TLSOptionCtx) error { if ctx.Config.ClientCAs == nil { ctx.Config.ClientCAs = x509.NewCertPool() } ctx.hasClientCA = true ctx.Config.ClientCAs.AddCert(cert) ctx.mutableConfig.AddImmutableClientCACert(cert) return nil } } // AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all // the certificates in the response. RootCAs defines the set of root certificate // authorities that clients use when verifying server certificates. // // BootstrapServer and BootstrapClient methods include this option by default. 
func AddRootsToRootCAs() TLSOption { // var once sync.Once fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } ctx.hasRootCA = true ctx.mutableConfig.AddRootCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) return fn(ctx) } } // AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs // all the certificates in the response. ClientCAs defines the set of root // certificate authorities that servers use if required to verify a client // certificate by the policy in ClientAuth. // // BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { // var once sync.Once fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } ctx.hasClientCA = true ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) return fn(ctx) } } // AddFederationToRootCAs does a federation request and adds to the tls.Config // RootCAs all the certificates in the response. RootCAs defines the set of root // certificate authorities that clients use when verifying server certificates. func AddFederationToRootCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Federation() if err != nil { return err } ctx.mutableConfig.AddRootCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) return fn(ctx) } } // AddFederationToClientCAs does a federation request and adds to the tls.Config // ClientCAs all the certificates in the response. ClientCAs defines the set of // root certificate authorities that servers use if required to verify a client // certificate by the policy in ClientAuth. 
func AddFederationToClientCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Federation() if err != nil { return err } ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) return fn(ctx) } } // AddRootsToCAs does a roots request and adds the resulting certs to the // tls.Config RootCAs and ClientCAs. Combines the functionality of // AddRootsToRootCAs and AddRootsToClientCAs. func AddRootsToCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } ctx.hasRootCA = true ctx.hasClientCA = true ctx.mutableConfig.AddRootCAs(certs.Certificates) ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) return fn(ctx) } } // AddFederationToCAs does a federation request and adds the resulting certs to the // tls.Config RootCAs and ClientCAs. Combines the functionality of // AddFederationToRootCAs and AddFederationToClientCAs. 
func AddFederationToCAs() TLSOption {
	fn := func(ctx *TLSOptionCtx) error {
		certs, err := ctx.Client.Federation()
		if err != nil {
			return err
		}
		if ctx.mutableConfig == nil {
			// Without a mutable config the certificates are added directly to
			// the pools of the tls.Config.
			if ctx.Config.RootCAs == nil {
				ctx.Config.RootCAs = x509.NewCertPool()
			}
			if ctx.Config.ClientCAs == nil {
				ctx.Config.ClientCAs = x509.NewCertPool()
			}
			for _, cert := range certs.Certificates {
				ctx.Config.RootCAs.AddCert(cert.Certificate)
				ctx.Config.ClientCAs.AddCert(cert.Certificate)
			}
		} else {
			ctx.mutableConfig.AddRootCAs(certs.Certificates)
			ctx.mutableConfig.AddClientCAs(certs.Certificates)
		}
		return nil
	}
	return func(ctx *TLSOptionCtx) error {
		// The request is registered for renewals before it runs the first time.
		ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
		return fn(ctx)
	}
}

================================================
FILE: ca/tls_options_test.go
================================================
package ca

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net/http"
	"os"
	"reflect"
	"sort"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/smallstep/certificates/api"
)

// Test_newTLSOptionCtx checks that newTLSOptionCtx stores the given client,
// tls.Config and sign response, and sets a new mutable config.
//
//nolint:gosec // test tls config
func Test_newTLSOptionCtx(t *testing.T) {
	client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport))
	if err != nil {
		t.Fatalf("NewClient() error = %v", err)
	}

	type args struct {
		c      *Client
		config *tls.Config
		sign   *api.SignResponse
	}
	tests := []struct {
		name string
		args args
		want *TLSOptionCtx
	}{
		{"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestTLSOptionCtx_apply checks that apply succeeds when every option
// succeeds and returns an error as soon as one option fails.
//
//nolint:gosec // test tls config
func TestTLSOptionCtx_apply(t *testing.T) {
	// fail returns an option that always returns an error.
	fail := func() TLSOption {
		return func(ctx *TLSOptionCtx) error {
			return fmt.Errorf("an error")
		}
	}

	type fields struct {
		Config *tls.Config
	}
	type args struct {
		options []TLSOption
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{"ok", fields{&tls.Config{}}, args{[]TLSOption{RequireAndVerifyClientCert()}}, false},
		{"ok", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven()}}, false},
		{"fail", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Config:        tt.fields.Config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr {
				t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestRequireAndVerifyClientCert checks that the option sets ClientAuth to
// tls.RequireAndVerifyClientCert and leaves the rest of the tls.Config unset.
//
//nolint:gosec // test tls config
func TestRequireAndVerifyClientCert(t *testing.T) {
	tests := []struct {
		name string
		want *tls.Config
	}{
		{"ok", &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Config:        &tls.Config{},
				mutableConfig: newMutableTLSConfig(),
			}
			if err := RequireAndVerifyClientCert()(ctx); err != nil {
				t.Errorf("RequireAndVerifyClientCert() error = %v", err)
				return
			}
			if !reflect.DeepEqual(ctx.Config, tt.want) {
				t.Errorf("RequireAndVerifyClientCert() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestVerifyClientCertIfGiven checks that the option sets ClientAuth to
// tls.VerifyClientCertIfGiven and leaves the rest of the tls.Config unset.
//
//nolint:gosec // test tls config
func TestVerifyClientCertIfGiven(t *testing.T) {
	tests := []struct {
		name string
		want *tls.Config
	}{
		{"ok", &tls.Config{ClientAuth: tls.VerifyClientCertIfGiven}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Config:        &tls.Config{},
				mutableConfig: newMutableTLSConfig(),
			}
			if err := VerifyClientCertIfGiven()(ctx); err != nil {
				t.Errorf("VerifyClientCertIfGiven() error = %v", err)
				return
			}
			if !reflect.DeepEqual(ctx.Config, tt.want) {
				t.Errorf("VerifyClientCertIfGiven() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestAddRootCA checks that AddRootCA adds the given certificate to the
// RootCAs pool of the tls.Config.
//
//nolint:gosec // test tls config
func TestAddRootCA(t *testing.T) {
	cert := parseCertificate(t, rootPEM)
	pool :=
		x509.NewCertPool()
	pool.AddCert(cert)

	type args struct {
		cert *x509.Certificate
	}
	tests := []struct {
		name string
		args args
		want *tls.Config
	}{
		{"ok", args{cert}, &tls.Config{RootCAs: pool}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Config:        &tls.Config{},
				mutableConfig: newMutableTLSConfig(),
			}
			if err := AddRootCA(tt.args.cert)(ctx); err != nil {
				t.Errorf("AddRootCA() error = %v", err)
				return
			}
			// The test only fails if both the deep comparison and the
			// comparison of the pools fail.
			if !reflect.DeepEqual(ctx.Config, tt.want) && !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) {
				t.Errorf("AddRootCA() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestAddClientCA checks that AddClientCA adds the given certificate to the
// ClientCAs pool of the tls.Config.
//
//nolint:gosec // test tls config
func TestAddClientCA(t *testing.T) {
	cert := parseCertificate(t, rootPEM)
	pool := x509.NewCertPool()
	pool.AddCert(cert)

	type args struct {
		cert *x509.Certificate
	}
	tests := []struct {
		name string
		args args
		want *tls.Config
	}{
		{"ok", args{cert}, &tls.Config{ClientCAs: pool}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Config:        &tls.Config{},
				mutableConfig: newMutableTLSConfig(),
			}
			if err := AddClientCA(tt.args.cert)(ctx); err != nil {
				t.Errorf("AddClientCA() error = %v", err)
				return
			}
			// The test only fails if both the deep comparison and the
			// comparison of the pools fail.
			if !reflect.DeepEqual(ctx.Config, tt.want) && !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) {
				t.Errorf("AddClientCA() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestAddRootsToRootCAs checks that applying AddRootsToRootCAs fills RootCAs
// with the root certificate of the test CA. The client created with
// http.DefaultTransport is expected to fail.
//
//nolint:gosec // test tls config
func TestAddRootsToRootCAs(t *testing.T) {
	ca := startCATestServer(t)
	defer ca.Close()

	client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
	require.NoError(t, err)

	root, err := os.ReadFile("testdata/secrets/root_ca.crt")
	require.NoError(t, err)

	cert := parseCertificate(t, string(root))
	pool := x509.NewCertPool()
	pool.AddCert(cert)

	type args struct {
		client *Client
		config *tls.Config
	}
	tests := []struct {
		name    string
		args    args
		want    *tls.Config
		wantErr bool
	}{
		{"ok", args{client,
			&tls.Config{}}, &tls.Config{RootCAs: pool}, false},
		{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Client:        tt.args.client,
				Config:        tt.args.config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply([]TLSOption{AddRootsToRootCAs()}); (err != nil) != tt.wantErr {
				t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) {
				t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestAddRootsToClientCAs checks that applying AddRootsToClientCAs fills
// ClientCAs with the root certificate of the test CA. The client created with
// http.DefaultTransport is expected to fail.
//
//nolint:gosec // test tls config
func TestAddRootsToClientCAs(t *testing.T) {
	ca := startCATestServer(t)
	defer ca.Close()

	client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
	require.NoError(t, err)

	root, err := os.ReadFile("testdata/secrets/root_ca.crt")
	require.NoError(t, err)

	cert := parseCertificate(t, string(root))
	pool := x509.NewCertPool()
	pool.AddCert(cert)

	type args struct {
		client *Client
		config *tls.Config
	}
	tests := []struct {
		name    string
		args    args
		want    *tls.Config
		wantErr bool
	}{
		{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
		{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Client:        tt.args.client,
				Config:        tt.args.config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply([]TLSOption{AddRootsToClientCAs()}); (err != nil) != tt.wantErr {
				t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) {
				t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestAddFederationToRootCAs checks that applying AddFederationToRootCAs fills
// RootCAs with the root and federated certificates and leaves ClientCAs unset.
//
//nolint:gosec // test tls config
func TestAddFederationToRootCAs(t *testing.T) {
	ca := startCATestServer(t)
	defer ca.Close()

	client, err := NewClient(ca.URL,
		WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
	require.NoError(t, err)

	root, err := os.ReadFile("testdata/secrets/root_ca.crt")
	require.NoError(t, err)

	federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
	require.NoError(t, err)

	// The expected pool contains both the root and the federated root.
	crt1 := parseCertificate(t, string(root))
	crt2 := parseCertificate(t, string(federated))
	pool := x509.NewCertPool()
	pool.AddCert(crt1)
	pool.AddCert(crt2)

	type args struct {
		client *Client
		config *tls.Config
	}
	tests := []struct {
		name    string
		args    args
		want    *tls.Config
		wantErr bool
	}{
		{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
		{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Client:        tt.args.client,
				Config:        tt.args.config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply([]TLSOption{AddFederationToRootCAs()}); (err != nil) != tt.wantErr {
				t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(ctx.Config, tt.want) {
				// Federated roots are randomly sorted
				if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) || ctx.Config.ClientCAs != nil {
					t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want)
				}
			}
		})
	}
}

// TestAddFederationToClientCAs checks that applying AddFederationToClientCAs
// fills ClientCAs with the root and federated certificates and leaves RootCAs
// unset.
//
//nolint:gosec // test tls config
func TestAddFederationToClientCAs(t *testing.T) {
	ca := startCATestServer(t)
	defer ca.Close()

	client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
	require.NoError(t, err)

	root, err := os.ReadFile("testdata/secrets/root_ca.crt")
	require.NoError(t, err)

	federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
	require.NoError(t, err)

	// The expected pool contains both the root and the federated root.
	crt1 := parseCertificate(t, string(root))
	crt2 := parseCertificate(t, string(federated))
	pool := x509.NewCertPool()
	pool.AddCert(crt1)
	pool.AddCert(crt2)

	type args struct {
		client *Client
		config *tls.Config
	}
	tests := []struct {
		name    string
		args    args
		want    *tls.Config
		wantErr bool
	}{
		{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
		{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Client:        tt.args.client,
				Config:        tt.args.config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply([]TLSOption{AddFederationToClientCAs()}); (err != nil) != tt.wantErr {
				t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(ctx.Config, tt.want) {
				// Federated roots are randomly sorted
				if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || ctx.Config.RootCAs != nil {
					t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want)
				}
			}
		})
	}
}

// TestAddRootsToCAs checks that applying AddRootsToCAs fills both RootCAs and
// ClientCAs with the root certificate of the test CA. The client created with
// http.DefaultTransport is expected to fail.
//
//nolint:gosec // test tls config
func TestAddRootsToCAs(t *testing.T) {
	ca := startCATestServer(t)
	defer ca.Close()

	client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
	require.NoError(t, err)

	root, err := os.ReadFile("testdata/secrets/root_ca.crt")
	require.NoError(t, err)

	cert := parseCertificate(t, string(root))
	pool := x509.NewCertPool()
	pool.AddCert(cert)

	type args struct {
		client *Client
		config *tls.Config
	}
	tests := []struct {
		name    string
		args    args
		want    *tls.Config
		wantErr bool
	}{
		{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
		{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Client:        tt.args.client,
				Config:        tt.args.config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply([]TLSOption{AddRootsToCAs()}); (err != nil) != tt.wantErr {
				t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) || !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) {
				t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
			}
		})
	}
}

// TestAddFederationToCAs checks that applying AddFederationToCAs fills both
// RootCAs and ClientCAs with the root and federated certificates. The client
// created with http.DefaultTransport is expected to fail.
//
//nolint:gosec // test tls config
func TestAddFederationToCAs(t *testing.T) {
	ca := startCATestServer(t)
	defer ca.Close()

	client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
	require.NoError(t, err)

	root, err := os.ReadFile("testdata/secrets/root_ca.crt")
	require.NoError(t, err)

	federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
	require.NoError(t, err)

	// The expected pool contains both the root and the federated root.
	crt1 := parseCertificate(t, string(root))
	crt2 := parseCertificate(t, string(federated))
	pool := x509.NewCertPool()
	pool.AddCert(crt1)
	pool.AddCert(crt2)

	type args struct {
		client *Client
		config *tls.Config
	}
	tests := []struct {
		name    string
		args    args
		want    *tls.Config
		wantErr bool
	}{
		{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
		{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := &TLSOptionCtx{
				Client:        tt.args.client,
				Config:        tt.args.config,
				mutableConfig: newMutableTLSConfig(),
			}
			if err := ctx.apply([]TLSOption{AddFederationToCAs()}); (err != nil) != tt.wantErr {
				t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(ctx.Config, tt.want) {
				// Federated roots are randomly sorted
				if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) {
					t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want)
				}
			}
		})
	}
}

// equalPools reports whether the two pools are deeply equal or, failing that,
// whether their sorted lists of subjects are equal. Only the subjects are
// compared in the second case, not the certificates themselves.
//
//nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool {
	if reflect.DeepEqual(a, b) {
		return true
	}
	subjects := a.Subjects()
	sA := make([]string, len(subjects))
	for i := range subjects {
		sA[i] = string(subjects[i])
	}
	subjects = b.Subjects()
	sB :=
		make([]string, len(subjects))
	for i := range subjects {
		sB[i] = string(subjects[i])
	}
	// Sort both lists so that the comparison does not depend on the order in
	// which the certificates were added.
	sort.Strings(sA)
	sort.Strings(sB)
	return reflect.DeepEqual(sA, sB)
}

================================================
FILE: ca/tls_test.go
================================================
package ca

import (
	"bytes"
	"context"
	"crypto"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/randutil"

	"github.com/smallstep/certificates/api"
	"github.com/smallstep/certificates/authority"
)

// generateOTT returns a signed JWT for the given subject, valid for one
// minute. It is signed with the key in testdata/secrets/ott_mariano_priv.jwk,
// uses "mariano" as issuer, and carries the subject as its only SAN.
func generateOTT(t *testing.T, subject string) string {
	t.Helper()
	now := time.Now()
	jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
	require.NoError(t, err)

	opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID)
	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts)
	require.NoError(t, err)

	id, err := randutil.ASCII(64)
	require.NoError(t, err)

	cl := struct {
		jose.Claims
		SANS []string `json:"sans"`
	}{
		Claims: jose.Claims{
			ID:        id,
			Subject:   subject,
			Issuer:    "mariano",
			NotBefore: jose.NewNumericDate(now),
			Expiry:    jose.NewNumericDate(now.Add(time.Minute)),
			Audience:  []string{"https://127.0.0.1:0/sign"},
		},
		SANS: []string{subject},
	}
	raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
	require.NoError(t, err)
	return raw
}

// startTestServer starts an httptest TLS server with the given tls.Config and
// handler, using baseContext as the base context of the server.
func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler http.Handler) *httptest.Server {
	srv := httptest.NewUnstartedServer(handler)
	srv.TLS = tlsConfig
	// Base context MUST be set before the start of the server
	srv.Config.BaseContext = func(l net.Listener) context.Context {
		return baseContext
	}
	srv.StartTLS()
	// Force the use of GetCertificate on IPs
	srv.TLS.Certificates = nil
	return srv
}

// startCATestServer starts a CA configured with testdata/ca.json behind an
// httptest server.
func startCATestServer(t *testing.T) *httptest.Server {
	config, err :=
		authority.LoadConfiguration("testdata/ca.json")
	require.NoError(t, err)
	ca, err := New(config)
	require.NoError(t, err)
	// Use a httptest.Server instead
	baseContext := buildContext(ca.auth, nil, nil, nil)
	srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler)
	return srv
}

// sign starts a temporary test CA and returns a client for it, the sign
// response of a certificate for the given domain, and its private key. The CA
// is closed before returning.
func sign(t *testing.T, domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {
	t.Helper()
	srv := startCATestServer(t)
	defer srv.Close()
	return signDuration(t, srv, domain, 0)
}

// signDuration signs a certificate for the given domain using the CA running
// on srv. If duration is greater than zero, the request asks for a certificate
// valid from now for that duration. It returns the client, the sign response
// and the private key.
func signDuration(t *testing.T, srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) {
	t.Helper()
	req, pk, err := CreateSignRequest(generateOTT(t, domain))
	require.NoError(t, err)

	if duration > 0 {
		req.NotBefore = api.NewTimeDuration(time.Now())
		req.NotAfter = api.NewTimeDuration(req.NotBefore.Time().Add(duration))
	}

	client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))
	require.NoError(t, err)

	sr, err := client.Sign(req)
	require.NoError(t, err)

	return client, sr, pk
}

// serverHandler returns a handler that writes "ok". For any request URI other
// than "/no-cert" it first checks that the peer certificate has clientDomain
// as its common name and only DNS name; on a mismatch it writes "fail" and
// reports a test error. On success it sets the x-fingerprint header to the
// SHA-256 of the peer certificate.
func serverHandler(t *testing.T, clientDomain string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.RequestURI != "/no-cert" {
			if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
				w.Write([]byte("fail"))
				t.Error("http.Request.TLS does not have peer certificates")
				return
			}
			if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
				w.Write([]byte("fail"))
				t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
				return
			}
			if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
				w.Write([]byte("fail"))
				t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
				return
			}
			// Add the certificate fingerprint to check rotation
			sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
			w.Header().Set("x-fingerprint",
				hex.EncodeToString(sum[:]))
		}
		w.Write([]byte("ok"))
	})
}

// TestClient_GetServerTLSConfig_http starts two servers configured with
// GetServerTLSConfig, one with the default options and one with
// VerifyClientCertIfGiven, and sends requests to them from clients built in
// different ways. Requests without a client certificate are expected to fail
// against the server using the default options.
func TestClient_GetServerTLSConfig_http(t *testing.T) {
	clientDomain := "test.domain"
	client, sr, pk := sign(t, "127.0.0.1")

	// Create mTLS server
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvMTLS.Close()

	// Create TLS server
	ctx, cancel = context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
	if err != nil {
		t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
	}
	srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvTLS.Close()

	// wantErr maps each URL to request to whether the request must fail.
	tests := []struct {
		name      string
		getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
		wantErr   map[string]bool
	}{
		{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			tr, err := client.Transport(context.Background(), sr, pk)
			if err != nil {
				t.Errorf("Client.Transport() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: tr,
			}
		}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
		{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
			if err != nil {
				t.Errorf("Client.GetClientTLSConfig() error = %v", err)
				return nil
			}
			return &http.Client{
				Transport: getDefaultTransport(tlsConfig),
			}
		}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
		{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			root, err := RootCertificate(sr)
			if err != nil {
				t.Errorf("RootCertificate() error = %v", err)
				return nil
			}
			// This client trusts the root but does not present a certificate.
			tlsConfig := getDefaultTLSConfig(sr)
			tlsConfig.RootCAs = x509.NewCertPool()
			tlsConfig.RootCAs.AddCert(root)
			return &http.Client{
				Transport: getDefaultTransport(tlsConfig),
			}
		}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
		{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
			return &http.Client{}
		}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Each case gets its own client certificate for clientDomain.
			client, sr, pk := sign(t, clientDomain)
			cli := tt.getClient(t, client, sr, pk)
			if cli == nil {
				return
			}
			for path, wantErr := range tt.wantErr {
				t.Run(path, func(t *testing.T) {
					resp, err := cli.Get(path)
					if (err != nil) != wantErr {
						t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
						return
					}
					if wantErr {
						return
					}
					defer resp.Body.Close()
					b, err := io.ReadAll(resp.Body)
					if err != nil {
						t.Fatalf("io.ReadAll() error = %v", err)
					}
					if !bytes.Equal(b, []byte("ok")) {
						t.Errorf("response body unexpected, got %s, want ok", b)
					}
				})
			}
		})
	}
}

// TestClient_GetServerTLSConfig_renew runs servers configured with
// GetServerTLSConfig using certificates requested with a duration of five
// seconds, after lowering the minimum certificate duration to one second.
func TestClient_GetServerTLSConfig_renew(t *testing.T) {
	reset := setMinCertDuration(1 * time.Second)
	defer reset()

	// Start CA
	ca := startCATestServer(t)
	defer ca.Close()

	clientDomain := "test.domain"
	client, sr, pk := signDuration(t, ca, "127.0.0.1", 5*time.Second)

	// Start mTLS server
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
	require.NoError(t, err)

	srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvMTLS.Close()

	// Start TLS server
	ctx, cancel = context.WithCancel(context.Background())
	defer cancel()
	tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
	require.NoError(t, err)

	srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain))
	defer srvTLS.Close()

	// Transport
	client, sr, pk =
signDuration(t, ca, clientDomain, 5*time.Second) tr, err := client.Transport(context.Background(), sr, pk) require.NoError(t, err) tr1, ok := tr.(*http.Transport) require.True(t, ok) // Transport with tlsConfig client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) require.NoError(t, err) tr2 := getDefaultTransport(tlsConfig) // No client cert root, err := RootCertificate(sr) require.NoError(t, err) tlsConfig = getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) tr3 := getDefaultTransport(tlsConfig) // Disable keep alives to force TLS handshake tr1.DisableKeepAlives = true tr2.DisableKeepAlives = true tr3.DisableKeepAlives = true tests := []struct { name string client *http.Client wantErr map[string]bool }{ {"with transport", &http.Client{Transport: tr1}, map[string]bool{ srvTLS.URL: false, srvMTLS.URL: false, }}, {"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{ srvTLS.URL: false, srvMTLS.URL: false, }}, {"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{ srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true, }}, {"fail with default", &http.Client{}, map[string]bool{ srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true, }}, } // To count different cert fingerprints fingerprints := map[string]struct{}{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for path, wantErr := range tt.wantErr { t.Run(path, func(t *testing.T) { resp, err := tt.client.Get(path) if (err != nil) != wantErr { t.Errorf("http.Client.Get() error = %v", err) return } if wantErr { return } if fp := resp.Header.Get("x-fingerprint"); fp != "" { fingerprints[fp] = struct{}{} } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("io.ReadAll() error = %v", err) return } if !bytes.Equal(b, []byte("ok")) { t.Errorf("response body unexpected, got %s, want ok", b) return } }) } }) } if 
l := len(fingerprints); l != 2 { t.Errorf("number of fingerprints unexpected, got %d, want 2", l) } // Wait for renewal log.Printf("Sleeping for %s ...\n", 5*time.Second) time.Sleep(5 * time.Second) for _, tt := range tests { t.Run("renewed "+tt.name, func(t *testing.T) { for path, wantErr := range tt.wantErr { t.Run(path, func(t *testing.T) { resp, err := tt.client.Get(path) if (err != nil) != wantErr { t.Errorf("http.Client.Get() error = %v", err) return } if wantErr { return } if fp := resp.Header.Get("x-fingerprint"); fp != "" { fingerprints[fp] = struct{}{} } defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("io.ReadAll() error = %v", err) return } if !bytes.Equal(b, []byte("ok")) { t.Errorf("response body unexpected, got %s, want ok", b) return } }) } }) } if l := len(fingerprints); l != 4 { t.Errorf("number of fingerprints unexpected, got %d, want 4", l) } } func TestCertificate(t *testing.T) { cert := parseCertificate(t, certPEM) ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: cert}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: cert}, {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { name string sign *api.SignResponse want *x509.Certificate wantErr bool }{ {"ok", ok, cert, false}, {"fail", &api.SignResponse{}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := Certificate(tt.sign) if (err != nil) != tt.wantErr { t.Errorf("Certificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Certificate() = %v, want %v", got, tt.want) } }) } } func TestIntermediateCertificate(t *testing.T) { intermediate := parseCertificate(t, rootPEM) ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: intermediate}, CertChainPEM: []api.Certificate{ {Certificate: 
parseCertificate(t, certPEM)}, {Certificate: intermediate}, }, } tests := []struct { name string sign *api.SignResponse want *x509.Certificate wantErr bool }{ {"ok", ok, intermediate, false}, {"fail", &api.SignResponse{}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := IntermediateCertificate(tt.sign) if (err != nil) != tt.wantErr { t.Errorf("IntermediateCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("IntermediateCertificate() = %v, want %v", got, tt.want) } }) } } func TestRootCertificateCertificate(t *testing.T) { root := parseCertificate(t, rootPEM) ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: parseCertificate(t, certPEM)}, {Certificate: parseCertificate(t, rootPEM)}, }, TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ {root, root}, }}, } noTLS := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: parseCertificate(t, certPEM)}, {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { name string sign *api.SignResponse want *x509.Certificate wantErr bool }{ {"ok", ok, root, false}, {"fail", &api.SignResponse{}, nil, true}, {"no tls", noTLS, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := RootCertificate(tt.sign) if (err != nil) != tt.wantErr { t.Errorf("RootCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("RootCertificate() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/apiv1/extension.go ================================================ package apiv1 import ( 
"crypto/x509" "crypto/x509/pkix" "encoding/asn1" "github.com/pkg/errors" ) var ( oidStepRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} oidStepCertificateAuthority = append(asn1.ObjectIdentifier(nil), append(oidStepRoot, 2)...) ) // CertificateAuthorityExtension type is used to encode the certificate // authority extension. type CertificateAuthorityExtension struct { Type string CertificateID string `asn1:"optional,omitempty"` KeyValuePairs []string `asn1:"optional,omitempty"` } // CreateCertificateAuthorityExtension returns a X.509 extension that shows the // CAS type, id and a list of optional key value pairs. func CreateCertificateAuthorityExtension(typ Type, certificateID string, keyValuePairs ...string) (pkix.Extension, error) { b, err := asn1.Marshal(CertificateAuthorityExtension{ Type: typ.String(), CertificateID: certificateID, KeyValuePairs: keyValuePairs, }) if err != nil { return pkix.Extension{}, errors.Wrapf(err, "error marshaling certificate id extension") } return pkix.Extension{ Id: oidStepCertificateAuthority, Critical: false, Value: b, }, nil } // FindCertificateAuthorityExtension returns the certificate authority extension // from a signed certificate. func FindCertificateAuthorityExtension(cert *x509.Certificate) (pkix.Extension, bool) { for _, ext := range cert.Extensions { if ext.Id.Equal(oidStepCertificateAuthority) { return ext, true } } return pkix.Extension{}, false } // RemoveCertificateAuthorityExtension removes the certificate authority // extension from a certificate template. func RemoveCertificateAuthorityExtension(cert *x509.Certificate) { for i, ext := range cert.ExtraExtensions { if ext.Id.Equal(oidStepCertificateAuthority) { cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...) 
return } } } ================================================ FILE: cas/apiv1/extension_test.go ================================================ package apiv1 import ( "crypto/x509" "crypto/x509/pkix" "reflect" "testing" ) func TestCreateCertificateAuthorityExtension(t *testing.T) { type args struct { typ Type certificateID string keyValuePairs []string } tests := []struct { name string args args want pkix.Extension wantErr bool }{ {"ok", args{Type(CloudCAS), "1ac75689-cd3f-482e-a695-8a13daf39dc4", nil}, pkix.Extension{ Id: oidStepCertificateAuthority, Critical: false, Value: []byte{ 0x30, 0x30, 0x13, 0x08, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x63, 0x61, 0x73, 0x13, 0x24, 0x31, 0x61, 0x63, 0x37, 0x35, 0x36, 0x38, 0x39, 0x2d, 0x63, 0x64, 0x33, 0x66, 0x2d, 0x34, 0x38, 0x32, 0x65, 0x2d, 0x61, 0x36, 0x39, 0x35, 0x2d, 0x38, 0x61, 0x31, 0x33, 0x64, 0x61, 0x66, 0x33, 0x39, 0x64, 0x63, 0x34, }, }, false}, {"ok", args{Type(CloudCAS), "1ac75689-cd3f-482e-a695-8a13daf39dc4", []string{"foo", "bar"}}, pkix.Extension{ Id: oidStepCertificateAuthority, Critical: false, Value: []byte{ 0x30, 0x3c, 0x13, 0x08, 0x63, 0x6c, 0x6f, 0x75, 0x64, 0x63, 0x61, 0x73, 0x13, 0x24, 0x31, 0x61, 0x63, 0x37, 0x35, 0x36, 0x38, 0x39, 0x2d, 0x63, 0x64, 0x33, 0x66, 0x2d, 0x34, 0x38, 0x32, 0x65, 0x2d, 0x61, 0x36, 0x39, 0x35, 0x2d, 0x38, 0x61, 0x31, 0x33, 0x64, 0x61, 0x66, 0x33, 0x39, 0x64, 0x63, 0x34, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x13, 0x03, 0x62, 0x61, 0x72, }, }, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := CreateCertificateAuthorityExtension(tt.args.typ, tt.args.certificateID, tt.args.keyValuePairs...) 
if (err != nil) != tt.wantErr { t.Errorf("CreateCertificateAuthorityExtension() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CreateCertificateAuthorityExtension() = %v, want %v", got, tt.want) } }) } } func TestFindCertificateAuthorityExtension(t *testing.T) { expected := pkix.Extension{ Id: oidStepCertificateAuthority, Value: []byte("fake data"), } type args struct { cert *x509.Certificate } tests := []struct { name string args args want pkix.Extension want1 bool }{ {"first", args{&x509.Certificate{Extensions: []pkix.Extension{ expected, {Id: []int{1, 2, 3, 4}}, }}}, expected, true}, {"last", args{&x509.Certificate{Extensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, {Id: []int{2, 3, 4, 5}}, expected, }}}, expected, true}, {"fail", args{&x509.Certificate{Extensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, }}}, pkix.Extension{}, false}, {"fail ExtraExtensions", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ expected, {Id: []int{1, 2, 3, 4}}, }}}, pkix.Extension{}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1 := FindCertificateAuthorityExtension(tt.args.cert) if !reflect.DeepEqual(got, tt.want) { t.Errorf("FindCertificateAuthorityExtension() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { t.Errorf("FindCertificateAuthorityExtension() got1 = %v, want %v", got1, tt.want1) } }) } } func TestRemoveCertificateAuthorityExtension(t *testing.T) { caExt := pkix.Extension{ Id: oidStepCertificateAuthority, Value: []byte("fake data"), } type args struct { cert *x509.Certificate } tests := []struct { name string args args want *x509.Certificate }{ {"first", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ caExt, {Id: []int{1, 2, 3, 4}}, }}}, &x509.Certificate{ExtraExtensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, }}}, {"last", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, caExt, }}}, 
&x509.Certificate{ExtraExtensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, }}}, {"missing", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, }}}, &x509.Certificate{ExtraExtensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}}, }}}, {"extensions", args{&x509.Certificate{Extensions: []pkix.Extension{ caExt, {Id: []int{1, 2, 3, 4}}, }}}, &x509.Certificate{Extensions: []pkix.Extension{ caExt, {Id: []int{1, 2, 3, 4}}, }}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { RemoveCertificateAuthorityExtension(tt.args.cert) if !reflect.DeepEqual(tt.args.cert, tt.want) { t.Errorf("RemoveCertificateAuthorityExtension() cert = %v, want %v", tt.args.cert, tt.want) } }) } } ================================================ FILE: cas/apiv1/options.go ================================================ package apiv1 import ( "crypto" "crypto/x509" "encoding/json" "github.com/pkg/errors" "go.step.sm/crypto/kms" ) // Options represents the configuration options used to select and configure the // CertificateAuthorityService (CAS) to use. type Options struct { // AuthorityID is the the id oc the current authority. This is used on // StepCAS to add information about the origin of a certificate. AuthorityID string `json:"-"` // The type of the CAS to use. Type string `json:"type"` // CertificateAuthority reference: // In StepCAS the value is the CA url, e.g., "https://ca.smallstep.com:9000". // In CloudCAS the format is "projects/*/locations/*/certificateAuthorities/*". // In VaultCAS the value is the url, e.g., "https://vault.smallstep.com". CertificateAuthority string `json:"certificateAuthority,omitempty"` // CertificateAuthorityFingerprint is the root fingerprint used to // authenticate the connection to the CA when using StepCAS. CertificateAuthorityFingerprint string `json:"certificateAuthorityFingerprint,omitempty"` // CertificateIssuer contains the configuration used in StepCAS. 
CertificateIssuer *CertificateIssuer `json:"certificateIssuer,omitempty"` // Path to the credentials file used in CloudCAS. If not defined the default // authentication mechanism provided by Google SDK will be used. See // https://cloud.google.com/docs/authentication. CredentialsFile string `json:"credentialsFile,omitempty"` // CertificateChain contains the issuer certificate, along with any other // bundled certificates to be returned in the chain to consumers. It is used // used in SoftCAS and it is configured in the crt property of the ca.json. CertificateChain []*x509.Certificate `json:"-"` // Signer is the private key or a KMS signer for the issuer certificate. It // is used in SoftCAS and it is configured in the key property of the // ca.json. Signer crypto.Signer `json:"-"` // CertificateSigner combines CertificateChain and Signer in a callback that // returns the chain of certificate and signer used to sign X.509 // certificates in SoftCAS. CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) `json:"-"` // IsCreator is set to true when we're creating a certificate authority. It // is used to skip some validations when initializing a // CertificateAuthority. This option is used on SoftCAS and CloudCAS. IsCreator bool `json:"-"` // IsCAGetter is set to true when we're just using the // CertificateAuthorityGetter interface to retrieve the root certificate. It // is used to skip some validations when initializing a // CertificateAuthority. This option is used on StepCAS. IsCAGetter bool `json:"-"` // KeyManager is the KMS used to generate keys in SoftCAS. KeyManager kms.KeyManager `json:"-"` // Project, Location, CaPool and GCSBucket are parameters used in CloudCAS // to create a new certificate authority. If a CaPool does not exist it will // be created. GCSBucket is optional, if not provided GCloud will create a // managed bucket. 
Project string `json:"-"` Location string `json:"-"` CaPool string `json:"-"` CaPoolTier string `json:"-"` GCSBucket string `json:"-"` // Generic structure to configure any CAS Config json.RawMessage `json:"config,omitempty"` } // CertificateIssuer contains the properties used to use the StepCAS certificate // authority service. type CertificateIssuer struct { Type string `json:"type"` Provisioner string `json:"provisioner,omitempty"` Certificate string `json:"crt,omitempty"` Key string `json:"key,omitempty"` Password string `json:"password,omitempty"` } // Validate checks the fields in Options. func (o *Options) Validate() error { var typ Type if o == nil { typ = Type(SoftCAS) } else { typ = Type(o.Type) } // Check that the type can be loaded. if _, ok := LoadCertificateAuthorityServiceNewFunc(typ); !ok { return errors.Errorf("unsupported cas type %s", typ) } return nil } // Is returns if the options have the given type. func (o *Options) Is(t Type) bool { if o == nil { return t.String() == SoftCAS } return Type(o.Type).String() == t.String() } ================================================ FILE: cas/apiv1/options_test.go ================================================ package apiv1 import ( "context" "crypto" "crypto/x509" "sync" "testing" ) type testCAS struct { name string } func (t *testCAS) CreateCertificate(*CreateCertificateRequest) (*CreateCertificateResponse, error) { return nil, nil } func (t *testCAS) RenewCertificate(*RenewCertificateRequest) (*RenewCertificateResponse, error) { return nil, nil } func (t *testCAS) RevokeCertificate(*RevokeCertificateRequest) (*RevokeCertificateResponse, error) { return nil, nil } //nolint:gocritic // ignore sloppy test func name func mockRegister(t *testing.T) { t.Helper() Register(SoftCAS, func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { return &testCAS{name: SoftCAS}, nil }) Register(CloudCAS, func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { return 
&testCAS{name: CloudCAS}, nil }) t.Cleanup(func() { registry = new(sync.Map) }) } func TestOptions_Validate(t *testing.T) { mockRegister(t) type fields struct { Type string CredentialsFile string CertificateAuthority string Issuer *x509.Certificate Signer crypto.Signer } tests := []struct { name string fields fields wantErr bool }{ {"empty", fields{}, false}, {"SoftCAS", fields{SoftCAS, "", "", nil, nil}, false}, {"CloudCAS", fields{CloudCAS, "", "", nil, nil}, false}, {"softcas", fields{"softcas", "", "", nil, nil}, false}, {"CLOUDCAS", fields{"CLOUDCAS", "", "", nil, nil}, false}, {"fail", fields{"FailCAS", "", "", nil, nil}, true}, } t.Run("nil", func(t *testing.T) { var o *Options if err := o.Validate(); err != nil { t.Errorf("Options.Validate() error = %v, wantErr %v", err, false) } }) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &Options{ Type: tt.fields.Type, CredentialsFile: tt.fields.CredentialsFile, CertificateAuthority: tt.fields.CertificateAuthority, CertificateChain: []*x509.Certificate{tt.fields.Issuer}, Signer: tt.fields.Signer, } if err := o.Validate(); (err != nil) != tt.wantErr { t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestOptions_Is(t *testing.T) { mockRegister(t) type fields struct { Type string CredentialsFile string CertificateAuthority string Issuer *x509.Certificate Signer crypto.Signer } type args struct { t Type } tests := []struct { name string fields fields args args want bool }{ {"empty", fields{}, args{}, true}, {"SoftCAS", fields{SoftCAS, "", "", nil, nil}, args{"SoftCAS"}, true}, {"CloudCAS", fields{CloudCAS, "", "", nil, nil}, args{"CloudCAS"}, true}, {"softcas", fields{"softcas", "", "", nil, nil}, args{SoftCAS}, true}, {"CLOUDCAS", fields{"CLOUDCAS", "", "", nil, nil}, args{CloudCAS}, true}, {"UnknownCAS", fields{"UnknownCAS", "", "", nil, nil}, args{"UnknownCAS"}, true}, {"fail", fields{CloudCAS, "", "", nil, nil}, args{"SoftCAS"}, false}, {"fail", 
fields{SoftCAS, "", "", nil, nil}, args{"CloudCAS"}, false}, } t.Run("nil", func(t *testing.T) { var o *Options if got := o.Is(SoftCAS); got != true { t.Errorf("Options.Is() = %v, want %v", got, true) } }) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { o := &Options{ Type: tt.fields.Type, CredentialsFile: tt.fields.CredentialsFile, CertificateAuthority: tt.fields.CertificateAuthority, CertificateChain: []*x509.Certificate{tt.fields.Issuer}, Signer: tt.fields.Signer, } if got := o.Is(tt.args.t); got != tt.want { t.Errorf("Options.Is() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/apiv1/registry.go ================================================ package apiv1 import ( "context" "sync" ) var ( registry = new(sync.Map) ) // CertificateAuthorityServiceNewFunc is the type that represents the method to initialize a new // CertificateAuthorityService. type CertificateAuthorityServiceNewFunc func(ctx context.Context, opts Options) (CertificateAuthorityService, error) // Register adds to the registry a method to create a KeyManager of type t. func Register(t Type, fn CertificateAuthorityServiceNewFunc) { registry.Store(t.String(), fn) } // LoadCertificateAuthorityServiceNewFunc returns the function to initialize a KeyManager. 
func LoadCertificateAuthorityServiceNewFunc(t Type) (CertificateAuthorityServiceNewFunc, bool) { v, ok := registry.Load(t.String()) if !ok { return nil, false } fn, ok := v.(CertificateAuthorityServiceNewFunc) return fn, ok } ================================================ FILE: cas/apiv1/registry_test.go ================================================ package apiv1 import ( "context" "fmt" "reflect" "sync" "testing" ) func TestRegister(t *testing.T) { t.Cleanup(func() { registry = new(sync.Map) }) type args struct { t Type fn CertificateAuthorityServiceNewFunc } tests := []struct { name string args args want CertificateAuthorityService wantErr bool }{ {"ok", args{"TestCAS", func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { return &testCAS{}, nil }}, &testCAS{}, false}, {"error", args{"ErrorCAS", func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { return nil, fmt.Errorf("an error") }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { Register(tt.args.t, tt.args.fn) fmt.Println(registry) fn, ok := registry.Load(tt.args.t.String()) if !ok { t.Errorf("Register() failed") return } got, err := fn.(CertificateAuthorityServiceNewFunc)(context.Background(), Options{}) if (err != nil) != tt.wantErr { t.Errorf("CertificateAuthorityServiceNewFunc() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CertificateAuthorityServiceNewFunc() = %v, want %v", got, tt.want) } }) } } func TestLoadCertificateAuthorityServiceNewFunc(t *testing.T) { mockRegister(t) type args struct { t Type } tests := []struct { name string args args want CertificateAuthorityService wantOk bool }{ {"default", args{""}, &testCAS{name: SoftCAS}, true}, {"SoftCAS", args{"SoftCAS"}, &testCAS{name: SoftCAS}, true}, {"CloudCAS", args{"CloudCAS"}, &testCAS{name: CloudCAS}, true}, {"softcas", args{"softcas"}, &testCAS{name: SoftCAS}, true}, {"cloudcas", args{"cloudcas"}, 
// CertificateAuthorityType indicates the type of Certificate Authority to
// create.
//
// The defined values start at 1, so the zero value of
// CertificateAuthorityType matches neither RootCA nor IntermediateCA.
type CertificateAuthorityType int

const (
	// RootCA is the type used to create a self-signed certificate suitable for
	// use as a root CA. Its value is 1.
	RootCA CertificateAuthorityType = iota + 1

	// IntermediateCA is the type used to create a subordinate certificate that
	// can be used to sign additional leaf certificates. Its value is 2.
	IntermediateCA
)
// GetCertificateAuthorityResponse is the response to a
// GetCertificateAuthorityRequest. It carries the root certificate and a list
// of intermediate certificates.
type GetCertificateAuthorityResponse struct {
	// RootCertificate is the root certificate returned by the CAS.
	RootCertificate *x509.Certificate
	// IntermediateCertificates holds intermediate certificates returned with
	// the root.
	// NOTE(review): ordering and whether this may be empty are not visible
	// here; confirm against the CAS implementations.
	IntermediateCertificates []*x509.Certificate
}
// CreateCRLResponse is the response to a create Certificate Revocation List
// (CRL) request.
type CreateCRLResponse struct {
	// CRL is the CRL in DER format.
	CRL []byte
}
// CertificateAuthorityCRLGenerator is an optional interface implemented by a
// CertificateAuthorityService that has a method to create a CRL.
type CertificateAuthorityCRLGenerator interface {
	// CreateCRL creates a CRL from the given request.
	CreateCRL(req *CreateCRLRequest) (*CreateCRLResponse, error)
}
// ValidationError is the type of error returned if a request is not properly
// validated.
type ValidationError struct {
	// Message is the error text; when empty, Error falls back to
	// "bad request".
	Message string
}

// Error implements the error interface. It returns Message, or "bad request"
// when Message is empty.
func (e ValidationError) Error() string {
	if e.Message == "" {
		return "bad request"
	}
	return e.Message
}

// StatusCode implements the StatusCoder interface and always returns the HTTP
// 400 (Bad Request) status code.
func (e ValidationError) StatusCode() int {
	return http.StatusBadRequest
}
got := e.Error(); got != tt.want { t.Errorf("NotImplementedError.Error() = %v, want %v", got, tt.want) } }) } } func TestNotImplementedError_StatusCode(t *testing.T) { type fields struct { Message string } tests := []struct { name string fields fields want int }{ {"default", fields{""}, 501}, {"with message", fields{"method not supported"}, 501}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NotImplementedError{ Message: tt.fields.Message, } if got := s.StatusCode(); got != tt.want { t.Errorf("NotImplementedError.StatusCode() = %v, want %v", got, tt.want) } }) } } func TestValidationError_Error(t *testing.T) { type fields struct { Message string } tests := []struct { name string fields fields want string }{ {"default", fields{""}, "bad request"}, {"with message", fields{"token is empty"}, "token is empty"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := ValidationError{ Message: tt.fields.Message, } if got := e.Error(); got != tt.want { t.Errorf("ValidationError.Error() = %v, want %v", got, tt.want) } }) } } func TestValidationError_StatusCode(t *testing.T) { type fields struct { Message string } tests := []struct { name string fields fields want int }{ {"default", fields{""}, 400}, {"with message", fields{"token is empty"}, 400}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := ValidationError{ Message: tt.fields.Message, } if got := e.StatusCode(); got != tt.want { t.Errorf("ValidationError.StatusCode() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/cas.go ================================================ package cas import ( "context" "strings" "github.com/pkg/errors" "github.com/smallstep/certificates/cas/apiv1" // Enable default implementation _ "github.com/smallstep/certificates/cas/softcas" ) // CertificateAuthorityService is the interface implemented by all the CAS. 
type CertificateAuthorityService = apiv1.CertificateAuthorityService // CertificateAuthorityCreator is the interface implemented by all CAS that can create a new authority. type CertificateAuthorityCreator = apiv1.CertificateAuthorityCreator // New creates a new CertificateAuthorityService using the given options. func New(ctx context.Context, opts apiv1.Options) (CertificateAuthorityService, error) { if err := opts.Validate(); err != nil { return nil, err } t := apiv1.Type(strings.ToLower(opts.Type)) if t == apiv1.DefaultCAS { t = apiv1.SoftCAS } fn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(t) if !ok { return nil, errors.Errorf("unsupported cas type '%s'", t) } return fn(ctx, opts) } // NewCreator creates a new CertificateAuthorityCreator using the given options. func NewCreator(ctx context.Context, opts apiv1.Options) (CertificateAuthorityCreator, error) { opts.IsCreator = true t := apiv1.Type(strings.ToLower(opts.Type)) if t == apiv1.DefaultCAS { t = apiv1.SoftCAS } svc, err := New(ctx, opts) if err != nil { return nil, err } creator, ok := svc.(CertificateAuthorityCreator) if !ok { return nil, errors.Errorf("cas type '%s' does not implements CertificateAuthorityCreator", t) } return creator, nil } ================================================ FILE: cas/cas_test.go ================================================ package cas import ( "context" "crypto/ed25519" "crypto/x509" "crypto/x509/pkix" "fmt" "reflect" "testing" "go.step.sm/crypto/kms" kmsapi "go.step.sm/crypto/kms/apiv1" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/cas/softcas" ) type mockCAS struct{} func (m *mockCAS) CreateCertificate(*apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { panic("not implemented") } func (m *mockCAS) RenewCertificate(*apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { panic("not implemented") } func (m *mockCAS) RevokeCertificate(*apiv1.RevokeCertificateRequest) 
(*apiv1.RevokeCertificateResponse, error) { panic("not implemented") } func TestNew(t *testing.T) { expected := &softcas.SoftCAS{ CertificateChain: []*x509.Certificate{{Subject: pkix.Name{CommonName: "Test Issuer"}}}, Signer: ed25519.PrivateKey{}, } apiv1.Register(apiv1.Type("nockCAS"), func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) { return nil, fmt.Errorf("an error") }) type args struct { ctx context.Context opts apiv1.Options } tests := []struct { name string args args want CertificateAuthorityService wantErr bool }{ {"ok default", args{context.Background(), apiv1.Options{ CertificateChain: []*x509.Certificate{{Subject: pkix.Name{CommonName: "Test Issuer"}}}, Signer: ed25519.PrivateKey{}, }}, expected, false}, {"ok softcas", args{context.Background(), apiv1.Options{ Type: "softcas", CertificateChain: []*x509.Certificate{{Subject: pkix.Name{CommonName: "Test Issuer"}}}, Signer: ed25519.PrivateKey{}, }}, expected, false}, {"ok SoftCAS", args{context.Background(), apiv1.Options{ Type: "SoftCAS", CertificateChain: []*x509.Certificate{{Subject: pkix.Name{CommonName: "Test Issuer"}}}, Signer: ed25519.PrivateKey{}, }}, expected, false}, {"fail empty", args{context.Background(), apiv1.Options{}}, (*softcas.SoftCAS)(nil), true}, {"fail type", args{context.Background(), apiv1.Options{Type: "FailCAS"}}, nil, true}, {"fail load", args{context.Background(), apiv1.Options{Type: "nockCAS"}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("New() = %#v, want %v", got, tt.want) } }) } } func TestNewCreator(t *testing.T) { keyManager, err := kms.New(context.Background(), kmsapi.Options{}) if err != nil { t.Fatal(err) } apiv1.Register(apiv1.Type("nockCAS"), func(ctx context.Context, opts apiv1.Options) 
(apiv1.CertificateAuthorityService, error) { return &mockCAS{}, nil }) type args struct { ctx context.Context opts apiv1.Options } tests := []struct { name string args args want CertificateAuthorityCreator wantErr bool }{ {"ok empty", args{context.Background(), apiv1.Options{}}, &softcas.SoftCAS{}, false}, {"ok softcas", args{context.Background(), apiv1.Options{ Type: "softcas", }}, &softcas.SoftCAS{}, false}, {"ok SoftCAS", args{context.Background(), apiv1.Options{ Type: "SoftCAS", KeyManager: keyManager, }}, &softcas.SoftCAS{KeyManager: keyManager}, false}, {"fail type", args{context.Background(), apiv1.Options{Type: "FailCAS"}}, nil, true}, {"fail no creator", args{context.Background(), apiv1.Options{Type: "nockCAS"}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewCreator(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("NewCreator() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewCreator() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/cloudcas/certificate.go ================================================ package cloudcas import ( "crypto" "crypto/ecdsa" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/pem" "fmt" pb "cloud.google.com/go/security/privateca/apiv1/privatecapb" "github.com/pkg/errors" kmsapi "go.step.sm/crypto/kms/apiv1" "github.com/smallstep/certificates/internal/cast" ) var ( oidExtensionSubjectKeyID = []int{2, 5, 29, 14} oidExtensionKeyUsage = []int{2, 5, 29, 15} oidExtensionExtendedKeyUsage = []int{2, 5, 29, 37} oidExtensionAuthorityKeyID = []int{2, 5, 29, 35} oidExtensionBasicConstraints = []int{2, 5, 29, 19} oidExtensionSubjectAltName = []int{2, 5, 29, 17} oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31} oidExtensionCertificatePolicies = []int{2, 5, 29, 32} oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1} ) var extraExtensions 
= [...]asn1.ObjectIdentifier{ oidExtensionSubjectKeyID, // Added by CAS oidExtensionKeyUsage, // Added in CertificateConfig.ReusableConfig oidExtensionExtendedKeyUsage, // Added in CertificateConfig.ReusableConfig oidExtensionAuthorityKeyID, // Added by CAS oidExtensionBasicConstraints, // Added in CertificateConfig.ReusableConfig oidExtensionSubjectAltName, // Added in CertificateConfig.SubjectConfig.SubjectAltName oidExtensionCRLDistributionPoints, // Added by CAS oidExtensionCertificatePolicies, // Added in CertificateConfig.ReusableConfig oidExtensionAuthorityInfoAccess, // Added in CertificateConfig.ReusableConfig and by CAS } var ( oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0} oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5} oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6} oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7} oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3} oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1} oidExtKeyUsageMicrosoftCommercialCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22} oidExtKeyUsageMicrosoftKernelCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1} ) const ( nameTypeEmail = 1 nameTypeDNS = 2 nameTypeURI = 6 nameTypeIP = 7 ) func createCertificateConfig(tpl *x509.Certificate) (*pb.Certificate_Config, error) { pk, err := createPublicKey(tpl.PublicKey) if err != nil { return nil, err } config := &pb.CertificateConfig{ SubjectConfig: &pb.CertificateConfig_SubjectConfig{ Subject: createSubject(tpl), SubjectAltName: createSubjectAlternativeNames(tpl), }, X509Config: createX509Parameters(tpl), PublicKey: pk, } return &pb.Certificate_Config{ Config: config, }, nil } func createPublicKey(key crypto.PublicKey) (*pb.PublicKey, error) { switch key := key.(type) { case *ecdsa.PublicKey: asn1Bytes, err := 
x509.MarshalPKIXPublicKey(key) if err != nil { return nil, errors.Wrap(err, "error marshaling public key") } return &pb.PublicKey{ Format: pb.PublicKey_PEM, Key: pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", Bytes: asn1Bytes, }), }, nil case *rsa.PublicKey: return &pb.PublicKey{ Format: pb.PublicKey_PEM, Key: pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(key), }), }, nil default: return nil, errors.Errorf("unsupported public key type: %T", key) } } func createSubject(cert *x509.Certificate) *pb.Subject { sub := cert.Subject ret := &pb.Subject{ CommonName: sub.CommonName, } if len(sub.Country) > 0 { ret.CountryCode = sub.Country[0] } if len(sub.Organization) > 0 { ret.Organization = sub.Organization[0] } if len(sub.OrganizationalUnit) > 0 { ret.OrganizationalUnit = sub.OrganizationalUnit[0] } if len(sub.Locality) > 0 { ret.Locality = sub.Locality[0] } if len(sub.Province) > 0 { ret.Province = sub.Province[0] } if len(sub.StreetAddress) > 0 { ret.StreetAddress = sub.StreetAddress[0] } if len(sub.PostalCode) > 0 { ret.PostalCode = sub.PostalCode[0] } return ret } func createSubjectAlternativeNames(cert *x509.Certificate) *pb.SubjectAltNames { ret := new(pb.SubjectAltNames) ret.DnsNames = cert.DNSNames ret.EmailAddresses = cert.EmailAddresses if n := len(cert.IPAddresses); n > 0 { ret.IpAddresses = make([]string, n) for i, ip := range cert.IPAddresses { ret.IpAddresses[i] = ip.String() } } if n := len(cert.URIs); n > 0 { ret.Uris = make([]string, n) for i, u := range cert.URIs { ret.Uris[i] = u.String() } } // Add extra SANs coming from the extensions if ext, ok := findExtraExtension(cert, oidExtensionSubjectAltName); ok { var rawValues []asn1.RawValue if _, err := asn1.Unmarshal(ext.Value, &rawValues); err == nil { var newValues []asn1.RawValue for _, v := range rawValues { if v.Class == asn1.ClassContextSpecific { switch v.Tag { case nameTypeDNS: if len(ret.DnsNames) == 0 { newValues = append(newValues, v) } case 
nameTypeEmail: if len(ret.EmailAddresses) == 0 { newValues = append(newValues, v) } case nameTypeIP: if len(ret.IpAddresses) == 0 { newValues = append(newValues, v) } case nameTypeURI: if len(ret.Uris) == 0 { newValues = append(newValues, v) } default: newValues = append(newValues, v) } } else { newValues = append(newValues, v) } } if len(newValues) > 0 { if b, err := asn1.Marshal(newValues); err == nil { ret.CustomSans = []*pb.X509Extension{{ ObjectId: createObjectID(ext.Id), Critical: ext.Critical, Value: b, }} } } } } return ret } func createX509Parameters(cert *x509.Certificate) *pb.X509Parameters { var unknownEKUs []*pb.ObjectId var ekuOptions = &pb.KeyUsage_ExtendedKeyUsageOptions{} for _, eku := range cert.ExtKeyUsage { switch eku { case x509.ExtKeyUsageAny: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageAny)) case x509.ExtKeyUsageServerAuth: ekuOptions.ServerAuth = true case x509.ExtKeyUsageClientAuth: ekuOptions.ClientAuth = true case x509.ExtKeyUsageCodeSigning: ekuOptions.CodeSigning = true case x509.ExtKeyUsageEmailProtection: ekuOptions.EmailProtection = true case x509.ExtKeyUsageIPSECEndSystem: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageIPSECEndSystem)) case x509.ExtKeyUsageIPSECTunnel: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageIPSECTunnel)) case x509.ExtKeyUsageIPSECUser: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageIPSECUser)) case x509.ExtKeyUsageTimeStamping: ekuOptions.TimeStamping = true case x509.ExtKeyUsageOCSPSigning: ekuOptions.OcspSigning = true case x509.ExtKeyUsageMicrosoftServerGatedCrypto: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageMicrosoftServerGatedCrypto)) case x509.ExtKeyUsageNetscapeServerGatedCrypto: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageNetscapeServerGatedCrypto)) case x509.ExtKeyUsageMicrosoftCommercialCodeSigning: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageMicrosoftCommercialCodeSigning)) 
case x509.ExtKeyUsageMicrosoftKernelCodeSigning: unknownEKUs = append(unknownEKUs, createObjectID(oidExtKeyUsageMicrosoftKernelCodeSigning)) } } for _, oid := range cert.UnknownExtKeyUsage { unknownEKUs = append(unknownEKUs, createObjectID(oid)) } var policyIDs []*pb.ObjectId for _, oid := range cert.PolicyIdentifiers { policyIDs = append(policyIDs, createObjectID(oid)) } var caOptions *pb.X509Parameters_CaOptions if cert.BasicConstraintsValid { caOptions = new(pb.X509Parameters_CaOptions) var maxPathLength int32 switch { case cert.MaxPathLenZero: maxPathLength = 0 caOptions.MaxIssuerPathLength = &maxPathLength case cert.MaxPathLen > 0: maxPathLength = cast.Int32(cert.MaxPathLen) caOptions.MaxIssuerPathLength = &maxPathLength } caOptions.IsCa = &cert.IsCA } var extraExtensions []*pb.X509Extension for _, ext := range cert.ExtraExtensions { if isExtraExtension(ext.Id) { extraExtensions = append(extraExtensions, &pb.X509Extension{ ObjectId: createObjectID(ext.Id), Critical: ext.Critical, Value: ext.Value, }) } } return &pb.X509Parameters{ KeyUsage: &pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DigitalSignature: cert.KeyUsage&x509.KeyUsageDigitalSignature > 0, ContentCommitment: cert.KeyUsage&x509.KeyUsageContentCommitment > 0, KeyEncipherment: cert.KeyUsage&x509.KeyUsageKeyEncipherment > 0, DataEncipherment: cert.KeyUsage&x509.KeyUsageDataEncipherment > 0, KeyAgreement: cert.KeyUsage&x509.KeyUsageKeyAgreement > 0, CertSign: cert.KeyUsage&x509.KeyUsageCertSign > 0, CrlSign: cert.KeyUsage&x509.KeyUsageCRLSign > 0, EncipherOnly: cert.KeyUsage&x509.KeyUsageEncipherOnly > 0, DecipherOnly: cert.KeyUsage&x509.KeyUsageDecipherOnly > 0, }, ExtendedKeyUsage: ekuOptions, UnknownExtendedKeyUsages: unknownEKUs, }, CaOptions: caOptions, PolicyIds: policyIDs, AiaOcspServers: cert.OCSPServer, AdditionalExtensions: extraExtensions, } } // isExtraExtension returns true if the extension oid is not managed in a // different way. 
func isExtraExtension(oid asn1.ObjectIdentifier) bool { for _, id := range extraExtensions { if id.Equal(oid) { return false } } return true } func createObjectID(oid asn1.ObjectIdentifier) *pb.ObjectId { ret := make([]int32, len(oid)) for i, v := range oid { ret[i] = cast.Int32(v) } return &pb.ObjectId{ ObjectIdPath: ret, } } func findExtraExtension(cert *x509.Certificate, oid asn1.ObjectIdentifier) (pkix.Extension, bool) { for _, ext := range cert.ExtraExtensions { if ext.Id.Equal(oid) { return ext, true } } return pkix.Extension{}, false } func createKeyVersionSpec(alg kmsapi.SignatureAlgorithm, bits int) (*pb.CertificateAuthority_KeyVersionSpec, error) { switch alg { case kmsapi.UnspecifiedSignAlgorithm, kmsapi.ECDSAWithSHA256: return &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_EC_P256_SHA256, }, }, nil case kmsapi.ECDSAWithSHA384: return &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_EC_P384_SHA384, }, }, nil case kmsapi.SHA256WithRSA: algo, err := getRSAPKCS1Algorithm(bits) if err != nil { return nil, err } return &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: algo, }, }, nil case kmsapi.SHA256WithRSAPSS: algo, err := getRSAPSSAlgorithm(bits) if err != nil { return nil, err } return &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: algo, }, }, nil default: return nil, fmt.Errorf("unknown or unsupported signature algorithm '%s'", alg) } } func getRSAPKCS1Algorithm(bits int) (pb.CertificateAuthority_SignHashAlgorithm, error) { switch bits { case 0, 3072: return pb.CertificateAuthority_RSA_PKCS1_3072_SHA256, nil case 2048: return pb.CertificateAuthority_RSA_PKCS1_2048_SHA256, nil case 4096: return pb.CertificateAuthority_RSA_PKCS1_4096_SHA256, nil 
default: return 0, fmt.Errorf("unsupported RSA PKCS #1 key size '%d'", bits) } } func getRSAPSSAlgorithm(bits int) (pb.CertificateAuthority_SignHashAlgorithm, error) { switch bits { case 0, 3072: return pb.CertificateAuthority_RSA_PSS_3072_SHA256, nil case 2048: return pb.CertificateAuthority_RSA_PSS_2048_SHA256, nil case 4096: return pb.CertificateAuthority_RSA_PSS_4096_SHA256, nil default: return 0, fmt.Errorf("unsupported RSA-PSS key size '%d'", bits) } } ================================================ FILE: cas/cloudcas/certificate_test.go ================================================ package cloudcas import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "net" "net/url" "reflect" "testing" pb "cloud.google.com/go/security/privateca/apiv1/privatecapb" kmsapi "go.step.sm/crypto/kms/apiv1" ) var ( testLeafPublicKey = `-----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEAdUSRBrpgHFilN4eaGlNnX2+xfjX a1Iwk2/+AensjFTXJi1UAIB0e+4pqi7Sen5E2QVBhntEHCrA3xOf7czgPw== -----END PUBLIC KEY----- ` testRSACertificate = `-----BEGIN CERTIFICATE----- MIICozCCAkmgAwIBAgIRANNhMpODj7ThgviZCoF6kj8wCgYIKoZIzj0EAwIwKjEo MCYGA1UEAxMfR29vZ2xlIENBUyBUZXN0IEludGVybWVkaWF0ZSBDQTAeFw0yMDA5 MTUwMTUxMDdaFw0zMDA5MTMwMTUxMDNaMB0xGzAZBgNVBAMTEnRlc3Quc21hbGxz dGVwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANPRjuIlsP5Z 672syAsHlbILFabG/xmrlsO0UdcLo4Yjf9WPAFA+7q+CsVDFh4dQbMv96fsHtdYP E9wlWyMqYG+5E8QT2i0WNFEoYcXOGZuXdyD/TA5Aucu1RuYLrZXQrXWDnvaWOgvr EZ6s9VsPCzzkL8KBejIMQIMY0KXEJfB/HgXZNn8V2trZkWT5CzxbcOF3s3UC1Z6F Ja6zjpxhSyRkqgknJxv6yK4t7HEwdhrDI8uyxJYHPQWKNRjWecHWE9E+MtoS7D08 mTh8qlAKoBbkGolR2nJSXffU09F3vSg+MIfjPiRqjf6394cQ3T9D5yZK//rCrxWU 8KKBQMEmdKcCAwEAAaOBkTCBjjAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYI KwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBQffuoYvH1+IF1cipl35gXJxSJE SjAfBgNVHSMEGDAWgBRIOVqyLDSlErJLuWWEvRm5UU1r1TAdBgNVHREEFjAUghJ0 ZXN0LnNtYWxsc3RlcC5jb20wCgYIKoZIzj0EAwIDSAAwRQIhAL9AAw/LVLvvxBkM 
sJnHd+RIk7ZblkgcArwpIS2+Z5xNAiBtUED4zyimz9b4aQiXdw4IMd2CKxVyW8eE 6x1vSZMvzQ== -----END CERTIFICATE-----` testRSAPublicKey = `-----BEGIN RSA PUBLIC KEY----- MIIBCgKCAQEA09GO4iWw/lnrvazICweVsgsVpsb/GauWw7RR1wujhiN/1Y8AUD7u r4KxUMWHh1Bsy/3p+we11g8T3CVbIypgb7kTxBPaLRY0UShhxc4Zm5d3IP9MDkC5 y7VG5gutldCtdYOe9pY6C+sRnqz1Ww8LPOQvwoF6MgxAgxjQpcQl8H8eBdk2fxXa 2tmRZPkLPFtw4XezdQLVnoUlrrOOnGFLJGSqCScnG/rIri3scTB2GsMjy7LElgc9 BYo1GNZ5wdYT0T4y2hLsPTyZOHyqUAqgFuQaiVHaclJd99TT0Xe9KD4wh+M+JGqN /rf3hxDdP0PnJkr/+sKvFZTwooFAwSZ0pwIDAQAB -----END RSA PUBLIC KEY----- ` ) func Test_createCertificateConfig(t *testing.T) { cert := mustParseCertificate(t, testLeafCertificate) type args struct { tpl *x509.Certificate } tests := []struct { name string args args want *pb.Certificate_Config wantErr bool }{ {"ok", args{cert}, &pb.Certificate_Config{ Config: &pb.CertificateConfig{ SubjectConfig: &pb.CertificateConfig_SubjectConfig{ Subject: &pb.Subject{ CommonName: "test.smallstep.com", }, SubjectAltName: &pb.SubjectAltNames{ DnsNames: []string{"test.smallstep.com"}, }, }, X509Config: &pb.X509Parameters{ KeyUsage: &pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DigitalSignature: true, }, ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ ClientAuth: true, ServerAuth: true, }, }, }, PublicKey: &pb.PublicKey{ Key: []byte(testLeafPublicKey), Format: pb.PublicKey_PEM, }, }, }, false}, {"fail", args{&x509.Certificate{}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := createCertificateConfig(tt.args.tpl) if (err != nil) != tt.wantErr { t.Errorf("createCertificateConfig() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("createCertificateConfig() = %v, want %v", got.Config, tt.want.Config) } }) } } func Test_createPublicKey(t *testing.T) { edpub, _, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } ecCert := mustParseCertificate(t, testLeafCertificate) ecCertPublicKey := 
ecCert.PublicKey.(*ecdsa.PublicKey) rsaCert := mustParseCertificate(t, testRSACertificate) type args struct { key crypto.PublicKey } tests := []struct { name string args args want *pb.PublicKey wantErr bool }{ {"ok ec", args{ecCert.PublicKey}, &pb.PublicKey{ Format: pb.PublicKey_PEM, Key: []byte(testLeafPublicKey), }, false}, {"ok rsa", args{rsaCert.PublicKey}, &pb.PublicKey{ Format: pb.PublicKey_PEM, Key: []byte(testRSAPublicKey), }, false}, {"fail ed25519", args{edpub}, nil, true}, {"fail ec marshal", args{&ecdsa.PublicKey{ Curve: &elliptic.CurveParams{ Name: "FOO", BitSize: 256, P: ecCertPublicKey.Params().P, B: ecCertPublicKey.Params().B, }, X: ecCertPublicKey.X, Y: ecCertPublicKey.Y, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := createPublicKey(tt.args.key) if (err != nil) != tt.wantErr { t.Errorf("createPublicKey() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("createPublicKey() = %v, want %v", got, tt.want) } }) } } func Test_createSubject(t *testing.T) { type args struct { cert *x509.Certificate } tests := []struct { name string args args want *pb.Subject }{ {"ok empty", args{&x509.Certificate{}}, &pb.Subject{}}, {"ok all", args{&x509.Certificate{ Subject: pkix.Name{ Country: []string{"US"}, Organization: []string{"Smallstep Labs"}, OrganizationalUnit: []string{"Engineering"}, Locality: []string{"San Francisco"}, Province: []string{"California"}, StreetAddress: []string{"1 A St."}, PostalCode: []string{"12345"}, SerialNumber: "1234567890", CommonName: "test.smallstep.com", }, }}, &pb.Subject{ CountryCode: "US", Organization: "Smallstep Labs", OrganizationalUnit: "Engineering", Locality: "San Francisco", Province: "California", StreetAddress: "1 A St.", PostalCode: "12345", CommonName: "test.smallstep.com", }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := createSubject(tt.args.cert); !reflect.DeepEqual(got, tt.want) { 
t.Errorf("createSubject() = %v, want %v", got, tt.want) } }) } } func Test_createSubjectAlternativeNames(t *testing.T) { marshalRawValues := func(rawValues []asn1.RawValue) []byte { b, err := asn1.Marshal(rawValues) if err != nil { t.Fatal(err) } return b } uri := func(s string) *url.URL { u, err := url.Parse(s) if err != nil { t.Fatal(err) } return u } type args struct { cert *x509.Certificate } tests := []struct { name string args args want *pb.SubjectAltNames }{ {"ok empty", args{&x509.Certificate{}}, &pb.SubjectAltNames{}}, {"ok dns", args{&x509.Certificate{DNSNames: []string{ "doe.com", "doe.org", }}}, &pb.SubjectAltNames{DnsNames: []string{"doe.com", "doe.org"}}}, {"ok emails", args{&x509.Certificate{EmailAddresses: []string{ "john@doe.com", "jane@doe.com", }}}, &pb.SubjectAltNames{EmailAddresses: []string{"john@doe.com", "jane@doe.com"}}}, {"ok ips", args{&x509.Certificate{IPAddresses: []net.IP{ net.ParseIP("127.0.0.1"), net.ParseIP("1.2.3.4"), net.ParseIP("::1"), net.ParseIP("2001:0db8:85a3:a0b:12f0:8a2e:0370:7334"), net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), }}}, &pb.SubjectAltNames{IpAddresses: []string{"127.0.0.1", "1.2.3.4", "::1", "2001:db8:85a3:a0b:12f0:8a2e:370:7334", "2001:db8:85a3::8a2e:370:7334"}}}, {"ok uris", args{&x509.Certificate{URIs: []*url.URL{ uri("mailto:john@doe.com"), uri("https://john@doe.com/hello"), }}}, &pb.SubjectAltNames{Uris: []string{"mailto:john@doe.com", "https://john@doe.com/hello"}}}, {"ok extensions", args{&x509.Certificate{ ExtraExtensions: []pkix.Extension{{ Id: []int{2, 5, 29, 17}, Critical: true, Value: []byte{ 0x30, 0x48, 0x82, 0x0b, 0x77, 0x77, 0x77, 0x2e, 0x64, 0x6f, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x81, 0x0c, 0x6a, 0x61, 0x6e, 0x65, 0x40, 0x64, 0x6f, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x87, 0x04, 0x01, 0x02, 0x03, 0x04, 0x87, 0x10, 0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x0a, 0x0b, 0x12, 0xf0, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34, 0x86, 0x13, 0x6d, 0x61, 0x69, 0x6c, 0x74, 0x6f, 0x3a, 0x6a, 0x61, 0x6e, 0x65, 0x40, 
0x64, 0x6f, 0x65, 0x2e, 0x63, 0x6f, 0x6d, }, }}, }}, &pb.SubjectAltNames{ CustomSans: []*pb.X509Extension{{ ObjectId: &pb.ObjectId{ObjectIdPath: []int32{2, 5, 29, 17}}, Critical: true, Value: []byte{ 0x30, 0x48, 0x82, 0x0b, 0x77, 0x77, 0x77, 0x2e, 0x64, 0x6f, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x81, 0x0c, 0x6a, 0x61, 0x6e, 0x65, 0x40, 0x64, 0x6f, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x87, 0x04, 0x01, 0x02, 0x03, 0x04, 0x87, 0x10, 0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x0a, 0x0b, 0x12, 0xf0, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34, 0x86, 0x13, 0x6d, 0x61, 0x69, 0x6c, 0x74, 0x6f, 0x3a, 0x6a, 0x61, 0x6e, 0x65, 0x40, 0x64, 0x6f, 0x65, 0x2e, 0x63, 0x6f, 0x6d, }, }}, }}, {"ok extra extensions", args{&x509.Certificate{ DNSNames: []string{"doe.com"}, ExtraExtensions: []pkix.Extension{{ Id: []int{2, 5, 29, 17}, Critical: true, Value: marshalRawValues([]asn1.RawValue{ {Class: asn1.ClassApplication, Tag: 2, IsCompound: true, Bytes: []byte{}}, {Class: asn1.ClassContextSpecific, Tag: nameTypeDNS, Bytes: []byte("doe.com")}, {Class: asn1.ClassContextSpecific, Tag: nameTypeEmail, Bytes: []byte("jane@doe.com")}, {Class: asn1.ClassContextSpecific, Tag: 8, Bytes: []byte("foo.bar")}, }), }}, }}, &pb.SubjectAltNames{ DnsNames: []string{"doe.com"}, CustomSans: []*pb.X509Extension{{ ObjectId: &pb.ObjectId{ObjectIdPath: []int32{2, 5, 29, 17}}, Critical: true, Value: marshalRawValues([]asn1.RawValue{ {Class: asn1.ClassApplication, Tag: 2, IsCompound: true, Bytes: []byte{}}, {Class: asn1.ClassContextSpecific, Tag: nameTypeEmail, Bytes: []byte("jane@doe.com")}, {Class: asn1.ClassContextSpecific, Tag: 8, Bytes: []byte("foo.bar")}, }), }}, }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := createSubjectAlternativeNames(tt.args.cert); !reflect.DeepEqual(got, tt.want) { t.Errorf("createSubjectAlternativeNames() = %v, want %v", got, tt.want) } }) } } func Test_createX509Parameters(t *testing.T) { withKU := func(ku *pb.KeyUsage) *pb.X509Parameters { if ku.BaseKeyUsage == nil { 
ku.BaseKeyUsage = &pb.KeyUsage_KeyUsageOptions{} } if ku.ExtendedKeyUsage == nil { ku.ExtendedKeyUsage = &pb.KeyUsage_ExtendedKeyUsageOptions{} } return &pb.X509Parameters{ KeyUsage: ku, } } withRCV := func(rcv *pb.X509Parameters) *pb.X509Parameters { if rcv.KeyUsage == nil { rcv.KeyUsage = &pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{}, ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{}, } } return rcv } vTrue := true vFalse := false vZero := int32(0) vOne := int32(1) type args struct { cert *x509.Certificate } tests := []struct { name string args args want *pb.X509Parameters }{ {"keyUsageDigitalSignature", args{&x509.Certificate{ KeyUsage: x509.KeyUsageDigitalSignature, }}, &pb.X509Parameters{ KeyUsage: &pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DigitalSignature: true, }, ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{}, UnknownExtendedKeyUsages: nil, }, CaOptions: nil, PolicyIds: nil, AiaOcspServers: nil, AdditionalExtensions: nil, }}, // KeyUsage {"KeyUsageDigitalSignature", args{&x509.Certificate{KeyUsage: x509.KeyUsageDigitalSignature}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DigitalSignature: true, }, })}, {"KeyUsageContentCommitment", args{&x509.Certificate{KeyUsage: x509.KeyUsageContentCommitment}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ ContentCommitment: true, }, })}, {"KeyUsageKeyEncipherment", args{&x509.Certificate{KeyUsage: x509.KeyUsageKeyEncipherment}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ KeyEncipherment: true, }, })}, {"KeyUsageDataEncipherment", args{&x509.Certificate{KeyUsage: x509.KeyUsageDataEncipherment}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DataEncipherment: true, }, })}, {"KeyUsageKeyAgreement", args{&x509.Certificate{KeyUsage: x509.KeyUsageKeyAgreement}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ KeyAgreement: true, }, })}, {"KeyUsageCertSign", 
args{&x509.Certificate{KeyUsage: x509.KeyUsageCertSign}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ CertSign: true, }, })}, {"KeyUsageCRLSign", args{&x509.Certificate{KeyUsage: x509.KeyUsageCRLSign}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ CrlSign: true, }, })}, {"KeyUsageEncipherOnly", args{&x509.Certificate{KeyUsage: x509.KeyUsageEncipherOnly}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ EncipherOnly: true, }, })}, {"KeyUsageDecipherOnly", args{&x509.Certificate{KeyUsage: x509.KeyUsageDecipherOnly}}, withKU(&pb.KeyUsage{ BaseKeyUsage: &pb.KeyUsage_KeyUsageOptions{ DecipherOnly: true, }, })}, // ExtKeyUsage {"ExtKeyUsageAny", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{2, 5, 29, 37, 0}}}, })}, {"ExtKeyUsageServerAuth", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}}}, withKU(&pb.KeyUsage{ ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ ServerAuth: true, }, })}, {"ExtKeyUsageClientAuth", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}}}, withKU(&pb.KeyUsage{ ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ ClientAuth: true, }, })}, {"ExtKeyUsageCodeSigning", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}}}, withKU(&pb.KeyUsage{ ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ CodeSigning: true, }, })}, {"ExtKeyUsageEmailProtection", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageEmailProtection}}}, withKU(&pb.KeyUsage{ ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ EmailProtection: true, }, })}, {"ExtKeyUsageIPSECEndSystem", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageIPSECEndSystem}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{1, 3, 6, 1, 5, 5, 7, 
3, 5}}}, })}, {"ExtKeyUsageIPSECTunnel", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageIPSECTunnel}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{1, 3, 6, 1, 5, 5, 7, 3, 6}}}, })}, {"ExtKeyUsageIPSECUser", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageIPSECUser}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{1, 3, 6, 1, 5, 5, 7, 3, 7}}}, })}, {"ExtKeyUsageTimeStamping", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageTimeStamping}}}, withKU(&pb.KeyUsage{ ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ TimeStamping: true, }, })}, {"ExtKeyUsageOCSPSigning", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageOCSPSigning}}}, withKU(&pb.KeyUsage{ ExtendedKeyUsage: &pb.KeyUsage_ExtendedKeyUsageOptions{ OcspSigning: true, }, })}, {"ExtKeyUsageMicrosoftServerGatedCrypto", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageMicrosoftServerGatedCrypto}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{1, 3, 6, 1, 4, 1, 311, 10, 3, 3}}}, })}, {"ExtKeyUsageNetscapeServerGatedCrypto", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageNetscapeServerGatedCrypto}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{2, 16, 840, 1, 113730, 4, 1}}}, })}, {"ExtKeyUsageMicrosoftCommercialCodeSigning", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageMicrosoftCommercialCodeSigning}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{1, 3, 6, 1, 4, 1, 311, 2, 1, 22}}}, })}, {"ExtKeyUsageMicrosoftKernelCodeSigning", args{&x509.Certificate{ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageMicrosoftKernelCodeSigning}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{{ObjectIdPath: []int32{1, 3, 6, 1, 4, 1, 311, 61, 1, 1}}}, })}, 
// UnknownExtendedKeyUsages {"UnknownExtKeyUsage", args{&x509.Certificate{UnknownExtKeyUsage: []asn1.ObjectIdentifier{{1, 2, 3, 4}, {4, 3, 2, 1}}}}, withKU(&pb.KeyUsage{ UnknownExtendedKeyUsages: []*pb.ObjectId{ {ObjectIdPath: []int32{1, 2, 3, 4}}, {ObjectIdPath: []int32{4, 3, 2, 1}}, }, })}, // BasicCre {"BasicConstraintsCAMax0", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: true}}, withRCV(&pb.X509Parameters{ CaOptions: &pb.X509Parameters_CaOptions{ IsCa: &vTrue, MaxIssuerPathLength: &vZero, }, })}, {"BasicConstraintsCAMax1", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 1, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ CaOptions: &pb.X509Parameters_CaOptions{ IsCa: &vTrue, MaxIssuerPathLength: &vOne, }, })}, {"BasicConstraintsCANoMax", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: -1, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ CaOptions: &pb.X509Parameters_CaOptions{ IsCa: &vTrue, MaxIssuerPathLength: nil, }, })}, {"BasicConstraintsCANoMax0", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ CaOptions: &pb.X509Parameters_CaOptions{ IsCa: &vTrue, MaxIssuerPathLength: nil, }, })}, {"BasicConstraintsNoCA", args{&x509.Certificate{BasicConstraintsValid: true, IsCA: false, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ CaOptions: &pb.X509Parameters_CaOptions{ IsCa: &vFalse, MaxIssuerPathLength: nil, }, })}, {"BasicConstraintsNoValid", args{&x509.Certificate{BasicConstraintsValid: false, IsCA: false, MaxPathLen: 0, MaxPathLenZero: false}}, withRCV(&pb.X509Parameters{ CaOptions: nil, })}, // PolicyIdentifiers {"PolicyIdentifiers", args{&x509.Certificate{PolicyIdentifiers: []asn1.ObjectIdentifier{{1, 2, 3, 4}, {4, 3, 2, 1}}}}, withRCV(&pb.X509Parameters{ PolicyIds: []*pb.ObjectId{ {ObjectIdPath: []int32{1, 2, 3, 4}}, {ObjectIdPath: []int32{4, 3, 
2, 1}}, }, })}, // OCSPServer {"OCPServers", args{&x509.Certificate{OCSPServer: []string{"https://oscp.doe.com", "https://doe.com/ocsp"}}}, withRCV(&pb.X509Parameters{ AiaOcspServers: []string{"https://oscp.doe.com", "https://doe.com/ocsp"}, })}, // Extensions {"Extensions", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ {Id: []int{1, 2, 3, 4}, Critical: true, Value: []byte("foobar")}, {Id: []int{2, 5, 29, 17}, Critical: true, Value: []byte("SANs")}, // {Id: []int{4, 3, 2, 1}, Critical: false, Value: []byte("zoobar")}, {Id: []int{2, 5, 29, 31}, Critical: false, Value: []byte("CRL Distribution points")}, }}}, withRCV(&pb.X509Parameters{ AdditionalExtensions: []*pb.X509Extension{ {ObjectId: &pb.ObjectId{ObjectIdPath: []int32{1, 2, 3, 4}}, Critical: true, Value: []byte("foobar")}, {ObjectId: &pb.ObjectId{ObjectIdPath: []int32{4, 3, 2, 1}}, Critical: false, Value: []byte("zoobar")}, }, })}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := createX509Parameters(tt.args.cert); !reflect.DeepEqual(got, tt.want) { t.Errorf("createX509Parameters() = %v, want %v", got, tt.want) } }) } } func Test_isExtraExtension(t *testing.T) { type args struct { oid asn1.ObjectIdentifier } tests := []struct { name string args args want bool }{ {"oidExtensionSubjectKeyID", args{oidExtensionSubjectKeyID}, false}, {"oidExtensionKeyUsage", args{oidExtensionKeyUsage}, false}, {"oidExtensionExtendedKeyUsage", args{oidExtensionExtendedKeyUsage}, false}, {"oidExtensionAuthorityKeyID", args{oidExtensionAuthorityKeyID}, false}, {"oidExtensionBasicConstraints", args{oidExtensionBasicConstraints}, false}, {"oidExtensionSubjectAltName", args{oidExtensionSubjectAltName}, false}, {"oidExtensionCRLDistributionPoints", args{oidExtensionCRLDistributionPoints}, false}, {"oidExtensionCertificatePolicies", args{oidExtensionCertificatePolicies}, false}, {"oidExtensionAuthorityInfoAccess", args{oidExtensionAuthorityInfoAccess}, false}, {"other", args{[]int{1, 2, 3, 4}}, true}, 
} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := isExtraExtension(tt.args.oid); got != tt.want { t.Errorf("isExtraExtension() = %v, want %v", got, tt.want) } }) } } func Test_createKeyVersionSpec(t *testing.T) { type args struct { alg kmsapi.SignatureAlgorithm bits int } tests := []struct { name string args args want *pb.CertificateAuthority_KeyVersionSpec wantErr bool }{ {"ok P256", args{0, 0}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_EC_P256_SHA256, }}, false}, {"ok P256", args{kmsapi.ECDSAWithSHA256, 0}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_EC_P256_SHA256, }}, false}, {"ok P384", args{kmsapi.ECDSAWithSHA384, 0}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_EC_P384_SHA384, }}, false}, {"ok RSA default", args{kmsapi.SHA256WithRSA, 0}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PKCS1_3072_SHA256, }}, false}, {"ok RSA 2048", args{kmsapi.SHA256WithRSA, 2048}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PKCS1_2048_SHA256, }}, false}, {"ok RSA 3072", args{kmsapi.SHA256WithRSA, 3072}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PKCS1_3072_SHA256, }}, false}, {"ok RSA 4096", args{kmsapi.SHA256WithRSA, 4096}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PKCS1_4096_SHA256, }}, false}, {"ok RSA-PSS default", args{kmsapi.SHA256WithRSAPSS, 0}, 
&pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PSS_3072_SHA256, }}, false}, {"ok RSA-PSS 2048", args{kmsapi.SHA256WithRSAPSS, 2048}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PSS_2048_SHA256, }}, false}, {"ok RSA-PSS 3072", args{kmsapi.SHA256WithRSAPSS, 3072}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PSS_3072_SHA256, }}, false}, {"ok RSA-PSS 4096", args{kmsapi.SHA256WithRSAPSS, 4096}, &pb.CertificateAuthority_KeyVersionSpec{ KeyVersion: &pb.CertificateAuthority_KeyVersionSpec_Algorithm{ Algorithm: pb.CertificateAuthority_RSA_PSS_4096_SHA256, }}, false}, {"fail Ed25519", args{kmsapi.PureEd25519, 0}, nil, true}, {"fail RSA size", args{kmsapi.SHA256WithRSA, 1024}, nil, true}, {"fail RSA-PSS size", args{kmsapi.SHA256WithRSAPSS, 1024}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := createKeyVersionSpec(tt.args.alg, tt.args.bits) if (err != nil) != tt.wantErr { t.Errorf("createKeyVersionSpec() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("createKeyVersionSpec() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/cloudcas/cloudcas.go ================================================ package cloudcas import ( "context" "crypto/rand" "crypto/x509" "encoding/asn1" "encoding/pem" "regexp" "strings" "time" privateca "cloud.google.com/go/security/privateca/apiv1" pb "cloud.google.com/go/security/privateca/apiv1/privatecapb" "github.com/google/uuid" gax "github.com/googleapis/gax-go/v2" "github.com/pkg/errors" "google.golang.org/api/option" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" durationpb 
"google.golang.org/protobuf/types/known/durationpb" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/cas/apiv1" ) func init() { apiv1.Register(apiv1.CloudCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) { return New(ctx, opts) }) } var now = time.Now // The actual regular expression that matches a certificate authority is: // // ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$ // // But we will allow a more flexible one to fail if this changes. var caRegexp = regexp.MustCompile("^projects/[^/]+/locations/[^/]+/caPools/[^/]+/certificateAuthorities/[^/]+$") // CertificateAuthorityClient is the interface implemented by the Google CAS // client. type CertificateAuthorityClient interface { CreateCertificate(ctx context.Context, req *pb.CreateCertificateRequest, opts ...gax.CallOption) (*pb.Certificate, error) RevokeCertificate(ctx context.Context, req *pb.RevokeCertificateRequest, opts ...gax.CallOption) (*pb.Certificate, error) GetCertificateAuthority(ctx context.Context, req *pb.GetCertificateAuthorityRequest, opts ...gax.CallOption) (*pb.CertificateAuthority, error) CreateCertificateAuthority(ctx context.Context, req *pb.CreateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.CreateCertificateAuthorityOperation, error) FetchCertificateAuthorityCsr(ctx context.Context, req *pb.FetchCertificateAuthorityCsrRequest, opts ...gax.CallOption) (*pb.FetchCertificateAuthorityCsrResponse, error) ActivateCertificateAuthority(ctx context.Context, req *pb.ActivateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.ActivateCertificateAuthorityOperation, error) EnableCertificateAuthority(ctx context.Context, req *pb.EnableCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.EnableCertificateAuthorityOperation, error) GetCaPool(ctx context.Context, req *pb.GetCaPoolRequest, opts ...gax.CallOption) (*pb.CaPool, 
error) CreateCaPool(ctx context.Context, req *pb.CreateCaPoolRequest, opts ...gax.CallOption) (*privateca.CreateCaPoolOperation, error) } // recocationCodeMap maps revocation reason codes from RFC 5280, to Google CAS // revocation reasons. Revocation reason 7 is not used, and revocation reason 8 // (removeFromCRL) is not supported by Google CAS. var revocationCodeMap = map[int]pb.RevocationReason{ 0: pb.RevocationReason_REVOCATION_REASON_UNSPECIFIED, 1: pb.RevocationReason_KEY_COMPROMISE, 2: pb.RevocationReason_CERTIFICATE_AUTHORITY_COMPROMISE, 3: pb.RevocationReason_AFFILIATION_CHANGED, 4: pb.RevocationReason_SUPERSEDED, 5: pb.RevocationReason_CESSATION_OF_OPERATION, 6: pb.RevocationReason_CERTIFICATE_HOLD, 9: pb.RevocationReason_PRIVILEGE_WITHDRAWN, 10: pb.RevocationReason_ATTRIBUTE_AUTHORITY_COMPROMISE, } // caPoolTierMap contains the map between apiv1.Options.Tier and the pb type. var caPoolTierMap = map[string]pb.CaPool_Tier{ "": pb.CaPool_DEVOPS, "ENTERPRISE": pb.CaPool_ENTERPRISE, "DEVOPS": pb.CaPool_DEVOPS, } // CloudCAS implements a Certificate Authority Service using Google Cloud CAS. type CloudCAS struct { client CertificateAuthorityClient certificateAuthority string project string location string caPool string caPoolTier pb.CaPool_Tier gcsBucket string } // newCertificateAuthorityClient creates the certificate authority client. This // function is used for testing purposes. var newCertificateAuthorityClient = func(ctx context.Context, credentialsFile string) (CertificateAuthorityClient, error) { var cloudOpts []option.ClientOption if credentialsFile != "" { cloudOpts = append(cloudOpts, option.WithAuthCredentialsFile(option.ServiceAccount, credentialsFile)) } client, err := privateca.NewCertificateAuthorityClient(ctx, cloudOpts...) if err != nil { return nil, errors.Wrap(err, "error creating client") } return client, nil } // New creates a new CertificateAuthorityService implementation using Google // Cloud CAS. 
func New(ctx context.Context, opts apiv1.Options) (*CloudCAS, error) { var caPoolTier pb.CaPool_Tier if opts.IsCreator && opts.CertificateAuthority == "" { switch { case opts.Project == "": return nil, errors.New("cloudCAS 'project' cannot be empty") case opts.Location == "": return nil, errors.New("cloudCAS 'location' cannot be empty") case opts.CaPool == "": return nil, errors.New("cloudCAS 'caPool' cannot be empty") } var ok bool if caPoolTier, ok = caPoolTierMap[strings.ToUpper(opts.CaPoolTier)]; !ok { return nil, errors.New("cloudCAS 'caPoolTier' is not a valid tier") } } else { if opts.CertificateAuthority == "" { return nil, errors.New("cloudCAS 'certificateAuthority' cannot be empty") } if !caRegexp.MatchString(opts.CertificateAuthority) { return nil, errors.New("cloudCAS 'certificateAuthority' is not valid certificate authority resource") } // Extract project and location from CertificateAuthority if parts := strings.Split(opts.CertificateAuthority, "/"); len(parts) == 8 { if opts.Project == "" { opts.Project = parts[1] } if opts.Location == "" { opts.Location = parts[3] } if opts.CaPool == "" { opts.CaPool = parts[5] } } } client, err := newCertificateAuthorityClient(ctx, opts.CredentialsFile) if err != nil { return nil, err } // GCSBucket is the the bucket name or empty for a managed bucket. return &CloudCAS{ client: client, certificateAuthority: opts.CertificateAuthority, project: opts.Project, location: opts.Location, caPool: opts.CaPool, gcsBucket: opts.GCSBucket, caPoolTier: caPoolTier, }, nil } // Type returns the type of this CertificateAuthorityService. func (c *CloudCAS) Type() apiv1.Type { return apiv1.CloudCAS } // GetCertificateAuthority returns the root certificate for the given // certificate authority. It implements apiv1.CertificateAuthorityGetter // interface. 
func (c *CloudCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { name := req.Name if name == "" { name = c.certificateAuthority } ctx, cancel := defaultContext() defer cancel() resp, err := c.client.GetCertificateAuthority(ctx, &pb.GetCertificateAuthorityRequest{ Name: name, }) if err != nil { return nil, errors.Wrap(err, "cloudCAS GetCertificateAuthority failed") } if len(resp.PemCaCertificates) == 0 { return nil, errors.New("cloudCAS GetCertificateAuthority: PemCACertificate should not be empty") } // Parse intermediate certificates var intermediates = make([]*x509.Certificate, len(resp.PemCaCertificates)-1) for i := 0; i < len(resp.PemCaCertificates)-1; i++ { intermediate, err := parseCertificate(resp.PemCaCertificates[i]) if err != nil { return nil, errors.Wrap(err, "parsing cloudCAS intermediates failed") } intermediates[i] = intermediate } // Last certificate in the chain is the root root, err := parseCertificate(resp.PemCaCertificates[len(resp.PemCaCertificates)-1]) if err != nil { return nil, errors.Wrap(err, "parsing cloudCAS root failed") } return &apiv1.GetCertificateAuthorityResponse{ RootCertificate: root, IntermediateCertificates: intermediates, }, nil } // CreateCertificate signs a new certificate using Google Cloud CAS. func (c *CloudCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { switch { case req.Template == nil: return nil, errors.New("createCertificateRequest `template` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificateRequest `lifetime` cannot be 0") } cert, chain, err := c.createCertificate(req.Template, req.Lifetime, req.RequestID) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // RenewCertificate renews the given certificate using Google Cloud CAS. 
// Google's CAS does not support the renew operation, so this method uses // CreateCertificate. func (c *CloudCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { switch { case req.Template == nil: return nil, errors.New("renewCertificateRequest `template` cannot be nil") case req.Lifetime == 0: return nil, errors.New("renewCertificateRequest `lifetime` cannot be 0") } cert, chain, err := c.createCertificate(req.Template, req.Lifetime, req.RequestID) if err != nil { return nil, err } return &apiv1.RenewCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // RevokeCertificate revokes a certificate using Google Cloud CAS. func (c *CloudCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { reason, ok := revocationCodeMap[req.ReasonCode] switch { case !ok: return nil, errors.Errorf("revokeCertificate 'reasonCode=%d' is invalid or not supported", req.ReasonCode) case req.Certificate == nil: return nil, errors.New("revokeCertificateRequest `certificate` cannot be nil") } ext, ok := apiv1.FindCertificateAuthorityExtension(req.Certificate) if !ok { return nil, errors.New("error revoking certificate: certificate authority extension was not found") } var cae apiv1.CertificateAuthorityExtension if _, err := asn1.Unmarshal(ext.Value, &cae); err != nil { return nil, errors.Wrap(err, "error unmarshaling certificate authority extension") } ctx, cancel := defaultContext() defer cancel() certpb, err := c.client.RevokeCertificate(ctx, &pb.RevokeCertificateRequest{ Name: c.certificateAuthority + "/certificates/" + cae.CertificateID, Reason: reason, RequestId: req.RequestID, }) if err != nil { return nil, errors.Wrap(err, "cloudCAS RevokeCertificate failed") } cert, chain, err := getCertificateAndChain(certpb) if err != nil { return nil, err } return &apiv1.RevokeCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // CreateCertificateAuthority 
creates a new root or intermediate certificate
// using Google Cloud CAS.
//
// The CA pool is looked up and created if it does not exist, the authority is
// created inside it, intermediates are signed and activated with the parent,
// and the authority is enabled before its certificate and chain are returned.
func (c *CloudCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthorityRequest) (*apiv1.CreateCertificateAuthorityResponse, error) {
	switch {
	case c.project == "":
		return nil, errors.New("cloudCAS `project` cannot be empty")
	case c.location == "":
		return nil, errors.New("cloudCAS `location` cannot be empty")
	case c.caPool == "":
		return nil, errors.New("cloudCAS `caPool` cannot be empty")
	case c.caPoolTier == 0:
		return nil, errors.New("cloudCAS `caPoolTier` cannot be empty")
	case req.Template == nil:
		return nil, errors.New("createCertificateAuthorityRequest `template` cannot be nil")
	case req.Lifetime == 0:
		return nil, errors.New("createCertificateAuthorityRequest `lifetime` cannot be 0")
	case req.Type == apiv1.IntermediateCA && req.Parent == nil:
		return nil, errors.New("createCertificateAuthorityRequest `parent` cannot be nil")
	case req.Type == apiv1.IntermediateCA && req.Parent.Name == "" && (req.Parent.Certificate == nil || req.Parent.Signer == nil):
		// An intermediate needs either a named parent in CloudCAS or a
		// local parent certificate together with its signer.
		return nil, errors.New("createCertificateAuthorityRequest `parent.name` cannot be empty")
	}

	var caType pb.CertificateAuthority_Type
	switch req.Type {
	case apiv1.RootCA:
		caType = pb.CertificateAuthority_SELF_SIGNED
	case apiv1.IntermediateCA:
		caType = pb.CertificateAuthority_SUBORDINATE
	default:
		return nil, errors.Errorf("createCertificateAuthorityRequest `type=%d` is invalid or not supported", req.Type)
	}

	// Select key and signature algorithm to use
	var err error
	var keySpec *pb.CertificateAuthority_KeyVersionSpec
	if req.CreateKey == nil {
		if keySpec, err = createKeyVersionSpec(0, 0); err != nil {
			return nil, errors.Wrap(err, "createCertificateAuthorityRequest `createKey` is not valid")
		}
	} else {
		if keySpec, err = createKeyVersionSpec(req.CreateKey.SignatureAlgorithm, req.CreateKey.Bits); err != nil {
			return nil, errors.Wrap(err, "createCertificateAuthorityRequest `createKey` is not valid")
		}
	}

	// Normalize or generate id.
	caID := normalizeCertificateAuthorityName(req.Name)
	if caID == "" {
		id, err := createCertificateID()
		if err != nil {
			return nil, err
		}
		caID = id
	}

	// Add CertificateAuthority extension
	casExtension, err := apiv1.CreateCertificateAuthorityExtension(apiv1.CloudCAS, caID)
	if err != nil {
		return nil, err
	}
	req.Template.ExtraExtensions = append(req.Template.ExtraExtensions, casExtension)

	// Create the caPool if necessary
	parent, err := c.createCaPoolIfNecessary()
	if err != nil {
		return nil, err
	}

	// Prepare CreateCertificateAuthorityRequest
	pbReq := &pb.CreateCertificateAuthorityRequest{
		Parent:                 parent,
		CertificateAuthorityId: caID,
		RequestId:              req.RequestID,
		CertificateAuthority: &pb.CertificateAuthority{
			Type: caType,
			Config: &pb.CertificateConfig{
				SubjectConfig: &pb.CertificateConfig_SubjectConfig{
					Subject:        createSubject(req.Template),
					SubjectAltName: createSubjectAlternativeNames(req.Template),
				},
				X509Config: createX509Parameters(req.Template),
			},
			Lifetime:  durationpb.New(req.Lifetime),
			KeySpec:   keySpec,
			GcsBucket: c.gcsBucket,
			Labels:    map[string]string{},
		},
	}

	// Create certificate authority.
	ctx, cancel := defaultContext()
	defer cancel()

	resp, err := c.client.CreateCertificateAuthority(ctx, pbReq)
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS CreateCertificateAuthority failed")
	}

	// Wait for the long-running operation.
	ctx, cancel = defaultInitiatorContext()
	defer cancel()

	ca, err := resp.Wait(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS CreateCertificateAuthority failed")
	}

	// Sign Intermediate CAs with the parent.
	if req.Type == apiv1.IntermediateCA {
		ca, err = c.signIntermediateCA(parent, ca.Name, req)
		if err != nil {
			return nil, err
		}
	}

	// Enable Certificate Authority.
	ca, err = c.enableCertificateAuthority(ca)
	if err != nil {
		return nil, err
	}

	if len(ca.PemCaCertificates) == 0 {
		return nil, errors.New("cloudCAS CreateCertificateAuthority failed: PemCaCertificates is empty")
	}

	// The first PEM entry is the new authority; the remainder is its chain.
	cert, err := parseCertificate(ca.PemCaCertificates[0])
	if err != nil {
		return nil, err
	}

	var chain []*x509.Certificate
	if pemChain := ca.PemCaCertificates[1:]; len(pemChain) > 0 {
		chain = make([]*x509.Certificate, len(pemChain))
		for i, s := range pemChain {
			if chain[i], err = parseCertificate(s); err != nil {
				return nil, err
			}
		}
	}

	return &apiv1.CreateCertificateAuthorityResponse{
		Name:             ca.Name,
		Certificate:      cert,
		CertificateChain: chain,
	}, nil
}

// createCaPoolIfNecessary returns the resource name of the configured CA
// pool, creating the pool first when the lookup reports codes.NotFound. Any
// other lookup error is returned to the caller.
func (c *CloudCAS) createCaPoolIfNecessary() (string, error) {
	ctx, cancel := defaultContext()
	defer cancel()

	pool, err := c.client.GetCaPool(ctx, &pb.GetCaPoolRequest{
		Name: "projects/" + c.project + "/locations/" + c.location + "/caPools/" + c.caPool,
	})
	if err == nil {
		return pool.Name, nil
	}

	if status.Code(err) != codes.NotFound {
		return "", errors.Wrap(err, "cloudCAS GetCaPool failed")
	}

	// PublishCrl is only supported by the enterprise tier
	publishCrl := c.caPoolTier == pb.CaPool_ENTERPRISE

	ctx, cancel = defaultContext()
	defer cancel()

	op, err := c.client.CreateCaPool(ctx, &pb.CreateCaPoolRequest{
		Parent:   "projects/" + c.project + "/locations/" + c.location,
		CaPoolId: c.caPool,
		CaPool: &pb.CaPool{
			Tier:           c.caPoolTier,
			IssuancePolicy: nil,
			PublishingOptions: &pb.CaPool_PublishingOptions{
				PublishCaCert: true,
				PublishCrl:    publishCrl,
			},
		},
	})
	if err != nil {
		return "", errors.Wrap(err, "cloudCAS CreateCaPool failed")
	}

	// Wait for the long-running operation.
	ctx, cancel = defaultInitiatorContext()
	defer cancel()

	pool, err = op.Wait(ctx)
	if err != nil {
		return "", errors.Wrap(err, "cloudCAS CreateCaPool failed")
	}

	return pool.Name, nil
}

// enableCertificateAuthority enables the given authority and waits for the
// operation to finish. An authority that is already in the ENABLED state is
// returned unchanged.
func (c *CloudCAS) enableCertificateAuthority(ca *pb.CertificateAuthority) (*pb.CertificateAuthority, error) {
	if ca.State == pb.CertificateAuthority_ENABLED {
		return ca, nil
	}

	ctx, cancel := defaultContext()
	defer cancel()

	resp, err := c.client.EnableCertificateAuthority(ctx, &pb.EnableCertificateAuthorityRequest{
		Name: ca.Name,
	})
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS EnableCertificateAuthority failed")
	}

	ctx, cancel = defaultInitiatorContext()
	defer cancel()

	ca, err = resp.Wait(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS EnableCertificateAuthority failed")
	}

	return ca, nil
}

// createCertificate issues a certificate from the template tpl with the given
// lifetime. Any existing CAS extension in tpl is replaced by a new one that
// carries a freshly generated certificate id. It returns the certificate and
// its chain.
func (c *CloudCAS) createCertificate(tpl *x509.Certificate, lifetime time.Duration, requestID string) (*x509.Certificate, []*x509.Certificate, error) {
	// Removes the CAS extension if it exists.
	apiv1.RemoveCertificateAuthorityExtension(tpl)

	// Create new CAS extension with the certificate id.
	id, err := createCertificateID()
	if err != nil {
		return nil, nil, err
	}
	casExtension, err := apiv1.CreateCertificateAuthorityExtension(apiv1.CloudCAS, id)
	if err != nil {
		return nil, nil, err
	}
	tpl.ExtraExtensions = append(tpl.ExtraExtensions, casExtension)

	// Create and submit certificate
	certConfig, err := createCertificateConfig(tpl)
	if err != nil {
		return nil, nil, err
	}

	ctx, cancel := defaultContext()
	defer cancel()

	cert, err := c.client.CreateCertificate(ctx, &pb.CreateCertificateRequest{
		Parent:        "projects/" + c.project + "/locations/" + c.location + "/caPools/" + c.caPool,
		CertificateId: id,
		Certificate: &pb.Certificate{
			CertificateConfig: certConfig,
			Lifetime:          durationpb.New(lifetime),
			Labels:            map[string]string{},
		},
		IssuingCertificateAuthorityId: getResourceName(c.certificateAuthority),
		RequestId:                     requestID,
	})
	if err != nil {
		return nil, nil, errors.Wrap(err, "cloudCAS CreateCertificate failed")
	}

	// Return certificate and certificate chain
	return getCertificateAndChain(cert)
}

// signIntermediateCA fetches the CSR of the pending authority name, signs it,
// and activates the authority with the resulting certificate. The CSR is
// signed locally when req.Parent provides both a certificate and a signer;
// otherwise it is signed by the CloudCAS authority named by req.Parent.Name
// in the pool parent.
func (c *CloudCAS) signIntermediateCA(parent, name string, req *apiv1.CreateCertificateAuthorityRequest) (*pb.CertificateAuthority, error) {
	id, err := createCertificateID()
	if err != nil {
		return nil, err
	}

	// Fetch intermediate CSR
	ctx, cancel := defaultInitiatorContext()
	defer cancel()

	csr, err := c.client.FetchCertificateAuthorityCsr(ctx, &pb.FetchCertificateAuthorityCsrRequest{
		Name: name,
	})
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS FetchCertificateAuthorityCsr failed")
	}

	// Sign the CSR with the ca.
	var cert *pb.Certificate
	if req.Parent.Certificate != nil && req.Parent.Signer != nil {
		// Using a local certificate and key.
		cr, err := parseCertificateRequest(csr.PemCsr)
		if err != nil {
			return nil, err
		}
		template, err := x509util.CreateCertificateTemplate(cr)
		if err != nil {
			return nil, err
		}

		t := now()
		template.NotBefore = t.Add(-1 * req.Backdate)
		template.NotAfter = t.Add(req.Lifetime)

		// Sign certificate
		crt, err := x509util.CreateCertificate(template, req.Parent.Certificate, template.PublicKey, req.Parent.Signer)
		if err != nil {
			return nil, err
		}

		// Build pb.Certificate for activation. The issuer chain starts with
		// the parent certificate, followed by the parent's own chain.
		chain := []string{
			encodeCertificate(req.Parent.Certificate),
		}
		for _, parentCert := range req.Parent.CertificateChain {
			chain = append(chain, encodeCertificate(parentCert))
		}
		cert = &pb.Certificate{
			PemCertificate:      encodeCertificate(crt),
			PemCertificateChain: chain,
		}
	} else {
		// Using the parent in CloudCAS.
		ctx, cancel = defaultInitiatorContext()
		defer cancel()

		cert, err = c.client.CreateCertificate(ctx, &pb.CreateCertificateRequest{
			Parent:        parent,
			CertificateId: id,
			Certificate: &pb.Certificate{
				CertificateConfig: &pb.Certificate_PemCsr{
					PemCsr: csr.PemCsr,
				},
				Lifetime: durationpb.New(req.Lifetime),
				Labels:   map[string]string{},
			},
			IssuingCertificateAuthorityId: getResourceName(req.Parent.Name),
			RequestId:                     req.RequestID,
		})
		if err != nil {
			return nil, errors.Wrap(err, "cloudCAS CreateCertificate failed")
		}
	}

	// Activate the intermediate certificate.
	ctx, cancel = defaultInitiatorContext()
	defer cancel()

	resp, err := c.client.ActivateCertificateAuthority(ctx, &pb.ActivateCertificateAuthorityRequest{
		Name:             name,
		PemCaCertificate: cert.PemCertificate,
		SubordinateConfig: &pb.SubordinateConfig{
			SubordinateConfig: &pb.SubordinateConfig_PemIssuerChain{
				PemIssuerChain: &pb.SubordinateConfig_SubordinateConfigChain{
					PemCertificates: cert.PemCertificateChain,
				},
			},
		},
		RequestId: req.RequestID,
	})
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS ActivateCertificateAuthority failed")
	}

	// Wait for the long-running operation.
	ctx, cancel = defaultInitiatorContext()
	defer cancel()

	ca, err := resp.Wait(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "cloudCAS ActivateCertificateAuthority failed")
	}

	return ca, nil
}

// defaultContext returns a background context with a 15 second timeout.
func defaultContext() (context.Context, context.CancelFunc) {
	return context.WithTimeout(context.Background(), 15*time.Second)
}

// defaultInitiatorContext returns a background context with a 60 second
// timeout; in this file it is used for waiting on long-running operations.
func defaultInitiatorContext() (context.Context, context.CancelFunc) {
	return context.WithTimeout(context.Background(), 60*time.Second)
}

// createCertificateID returns a random UUID string read from crypto/rand.
func createCertificateID() (string, error) {
	id, err := uuid.NewRandomFromReader(rand.Reader)
	if err != nil {
		return "", errors.Wrap(err, "error creating certificate id")
	}
	return id.String(), nil
}

// parseCertificate decodes the first PEM block of pemCert and parses it as an
// X.509 certificate.
func parseCertificate(pemCert string) (*x509.Certificate, error) {
	block, _ := pem.Decode([]byte(pemCert))
	if block == nil {
		return nil, errors.New("error decoding certificate: not a valid PEM encoded block")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, errors.Wrap(err, "error parsing certificate")
	}
	return cert, nil
}

// parseCertificateRequest decodes the first PEM block of pemCsr and parses it
// as an X.509 certificate request.
func parseCertificateRequest(pemCsr string) (*x509.CertificateRequest, error) {
	block, _ := pem.Decode([]byte(pemCsr))
	if block == nil {
		return nil, errors.New("error decoding certificate request: not a valid PEM encoded block")
	}
	cr, err := x509.ParseCertificateRequest(block.Bytes)
	if err != nil {
		return nil, errors.Wrap(err, "error parsing certificate request")
	}
	return cr, nil
}

// encodeCertificate returns the PEM encoding of the certificate's raw bytes.
func
encodeCertificate(cert *x509.Certificate) string { return string(pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, })) } func getCertificateAndChain(certpb *pb.Certificate) (*x509.Certificate, []*x509.Certificate, error) { cert, err := parseCertificate(certpb.PemCertificate) if err != nil { return nil, nil, err } pemChain := certpb.PemCertificateChain[:len(certpb.PemCertificateChain)-1] chain := make([]*x509.Certificate, len(pemChain)) for i := range pemChain { chain[i], err = parseCertificate(pemChain[i]) if err != nil { return nil, nil, err } } return cert, chain, nil } // getResourceName returns the last part of a resource. func getResourceName(name string) string { parts := strings.Split(name, "/") return parts[len(parts)-1] } // Normalize a certificate authority name to comply with [a-zA-Z0-9-_]. func normalizeCertificateAuthorityName(name string) string { return strings.Map(func(r rune) rune { switch { case r >= 'a' && r <= 'z': return r case r >= 'A' && r <= 'Z': return r case r >= '0' && r <= '9': return r case r == '-': return r case r == '_': return r default: return '-' } }, name) } ================================================ FILE: cas/cloudcas/cloudcas_test.go ================================================ //go:generate go run go.uber.org/mock/mockgen -package cloudcas -mock_names=CertificateAuthorityClient=MockCertificateAuthorityClient -destination mock_client_test.go github.com/smallstep/certificates/cas/cloudcas CertificateAuthorityClient //go:generate go run go.uber.org/mock/mockgen -package cloudcas -mock_names=OperationsServer=MockOperationsServer -destination mock_operation_server_test.go cloud.google.com/go/longrunning/autogen/longrunningpb OperationsServer package cloudcas import ( "bytes" "context" "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rand" "crypto/x509" "encoding/asn1" "encoding/pem" "fmt" "io" "net" "os" "path/filepath" "reflect" "testing" "time" lroauto "cloud.google.com/go/longrunning/autogen" 
"cloud.google.com/go/longrunning/autogen/longrunningpb" privateca "cloud.google.com/go/security/privateca/apiv1" pb "cloud.google.com/go/security/privateca/apiv1/privatecapb" "github.com/google/uuid" "github.com/googleapis/gax-go/v2" "github.com/pkg/errors" "go.uber.org/mock/gomock" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/anypb" kmsapi "go.step.sm/crypto/kms/apiv1" "github.com/smallstep/certificates/cas/apiv1" ) var ( errTest = errors.New("test error") testCaPoolName = "projects/test-project/locations/us-west1/caPools/test-capool" testAuthorityName = "projects/test-project/locations/us-west1/caPools/test-capool/certificateAuthorities/test-ca" testCertificateName = "projects/test-project/locations/us-west1/caPools/test-capool/certificateAuthorities/test-ca/certificates/test-certificate" testProject = "test-project" testLocation = "us-west1" testCaPool = "test-capool" testRootCertificate = `-----BEGIN CERTIFICATE----- MIIBeDCCAR+gAwIBAgIQcXWWjtSZ/PAyH8D1Ou4L9jAKBggqhkjOPQQDAjAbMRkw FwYDVQQDExBDbG91ZENBUyBSb290IENBMB4XDTIwMTAyNzIyNTM1NFoXDTMwMTAy NzIyNTM1NFowGzEZMBcGA1UEAxMQQ2xvdWRDQVMgUm9vdCBDQTBZMBMGByqGSM49 AgEGCCqGSM49AwEHA0IABIySHA4b78Yu4LuGhZIlv/PhNwXz4ZoV1OUZQ0LrK3vj B13O12DLZC5uj1z3kxdQzXUttSbtRv49clMpBiTpsZKjRTBDMA4GA1UdDwEB/wQE AwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSZ+t9RMHbFTl5BatM3 5bJlHPOu3DAKBggqhkjOPQQDAgNHADBEAiASah6gg0tVM3WI0meCQ4SEKk7Mjhbv +SmhuZHWV1QlXQIgRXNyWcpVUrAoG6Uy1KQg07LDpF5dFeK9InrDxSJAkVo= -----END CERTIFICATE-----` testIntermediateCertificate = `-----BEGIN CERTIFICATE----- MIIBpDCCAUmgAwIBAgIRALLKxnxyl0GBeKevIcbx02wwCgYIKoZIzj0EAwIwGzEZ MBcGA1UEAxMQQ2xvdWRDQVMgUm9vdCBDQTAeFw0yMDEwMjcyMjUzNTRaFw0zMDEw MjcyMjUzNTRaMCMxITAfBgNVBAMTGENsb3VkQ0FTIEludGVybWVkaWF0ZSBDQTBZ MBMGByqGSM49AgEGCCqGSM49AwEHA0IABPLuqxgBY+QmaXc8zKIC8FMgjJ6dF/cL 
b+Dig0XKc5GH/T1ORrhgOkRayrQcjPMu+jkjg25qn6vvp43LRtUKPXOjZjBkMA4G A1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMB0GA1UdDgQWBBQ8RVQI VgXAmRNDX8qItalVpSBEGjAfBgNVHSMEGDAWgBSZ+t9RMHbFTl5BatM35bJlHPOu 3DAKBggqhkjOPQQDAgNJADBGAiEA70MVYVqjm8SBHJf5cOlWfiXXOfHUsctTJ+/F pLsKBogCIQDJJkoQqYl9B59Dq3zydl8bpJevQxsoaa4Wqg+ZBMkvbQ== -----END CERTIFICATE-----` testLeafCertificate = `-----BEGIN CERTIFICATE----- MIIB1jCCAX2gAwIBAgIQQfOn+COMeuD8VYF1TiDkEzAKBggqhkjOPQQDAjAqMSgw JgYDVQQDEx9Hb29nbGUgQ0FTIFRlc3QgSW50ZXJtZWRpYXRlIENBMB4XDTIwMDkx NDIyNTE1NVoXDTMwMDkxMjIyNTE1MlowHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0 ZXAuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEAdUSRBrpgHFilN4eaGlN nX2+xfjXa1Iwk2/+AensjFTXJi1UAIB0e+4pqi7Sen5E2QVBhntEHCrA3xOf7czg P6OBkTCBjjAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG AQUFBwMCMB0GA1UdDgQWBBSYPbu4Tmm7Zze/hCePeZH1Avoj+jAfBgNVHSMEGDAW gBRIOVqyLDSlErJLuWWEvRm5UU1r1TAdBgNVHREEFjAUghJ0ZXN0LnNtYWxsc3Rl cC5jb20wCgYIKoZIzj0EAwIDRwAwRAIgY+nTc+RHn31/BOhht4JpxCmJPHxqFT3S ojnictBudV0CIB87ipY5HV3c8FLVEzTA0wFwdDZvQraQYsthwbg2kQFb -----END CERTIFICATE-----` testSignedCertificate = `-----BEGIN CERTIFICATE----- MIIB/DCCAaKgAwIBAgIQHHFuGMz0cClfde5kqP5prTAKBggqhkjOPQQDAjAqMSgw JgYDVQQDEx9Hb29nbGUgQ0FTIFRlc3QgSW50ZXJtZWRpYXRlIENBMB4XDTIwMDkx NTAwMDQ0M1oXDTMwMDkxMzAwMDQ0MFowHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0 ZXAuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMqNCiXMvbn74LsHzRv+8 17m9vEzH6RHrg3m82e0uEc36+fZWV/zJ9SKuONmnl5VP79LsjL5SVH0RDj73U2XO DKOBtjCBszAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG AQUFBwMCMB0GA1UdDgQWBBRTA2cTs7PCNjnps/+T0dS8diqv0DAfBgNVHSMEGDAW gBRIOVqyLDSlErJLuWWEvRm5UU1r1TBCBgwrBgEEAYKkZMYoQAIEMjAwEwhjbG91 ZGNhcxMkZDhkMThhNjgtNTI5Ni00YWYzLWFlNGItMmY4NzdkYTNmYmQ5MAoGCCqG SM49BAMCA0gAMEUCIGxl+pqJ50WYWUqK2l4V1FHoXSi0Nht5kwTxFxnWZu1xAiEA zemu3bhWLFaGg3s8i+HTEhw4RqkHP74vF7AVYp88bAw= -----END CERTIFICATE-----` testIntermediateCsr = `-----BEGIN CERTIFICATE REQUEST----- MIHeMIGFAgEAMCMxITAfBgNVBAMTGENsb3VkQ0FTIEludGVybWVkaWF0ZSBDQTBZ 
MBMGByqGSM49AgEGCCqGSM49AwEHA0IABPLuqxgBY+QmaXc8zKIC8FMgjJ6dF/cL b+Dig0XKc5GH/T1ORrhgOkRayrQcjPMu+jkjg25qn6vvp43LRtUKPXOgADAKBggq hkjOPQQDAgNIADBFAiEAn3pkYXb2OzoQZ+AExFqd7qZ7pg2nyP2kBZZ01Pl8KfcC IHKplBXDR79/i7kjOtv1iWfgf5S/XQHrz178gXA0YQe7 -----END CERTIFICATE REQUEST-----` testRootKey = `-----BEGIN EC PRIVATE KEY----- MHcCAQEEIN51Rgg6YcQVLeCRzumdw4pjM3VWqFIdCbnsV3Up1e/goAoGCCqGSM49 AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg== -----END EC PRIVATE KEY-----` //nolint:unused,gocritic,varcheck testIntermediateKey = `-----BEGIN EC PRIVATE KEY----- MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49 AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6 RFrKtByM8y76OSODbmqfq++njctG1Qo9cw== -----END EC PRIVATE KEY-----` ) type testClient struct { credentialsFile string certificate *pb.Certificate certificateAuthority *pb.CertificateAuthority err error } func newTestClient(credentialsFile string) (CertificateAuthorityClient, error) { if credentialsFile == "testdata/error.json" { return nil, errTest } return &testClient{ credentialsFile: credentialsFile, }, nil } func okTestClient() *testClient { return &testClient{ credentialsFile: "testdata/credentials.json", certificate: &pb.Certificate{ Name: testCertificateName, PemCertificate: testSignedCertificate, PemCertificateChain: []string{testIntermediateCertificate, testRootCertificate}, }, certificateAuthority: &pb.CertificateAuthority{ PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, }, } } func okTestClientRootOnly() *testClient { return &testClient{ credentialsFile: "testdata/credentials.json", certificate: &pb.Certificate{ Name: testCertificateName, PemCertificate: testSignedCertificate, PemCertificateChain: []string{testRootCertificate}, }, certificateAuthority: &pb.CertificateAuthority{ PemCaCertificates: []string{testRootCertificate}, }, } } func okTestClientWithMultipleIntermediates() *testClient { return 
// setTeeReader swaps the global rand.Reader for one that copies every byte
// it hands out into w. The previous reader is put back when the test and
// its subtests complete.
func setTeeReader(t *testing.T, w *bytes.Buffer) {
	t.Helper()

	original := rand.Reader
	rand.Reader = io.TeeReader(original, w)

	t.Cleanup(func() {
		rand.Reader = original
	})
}
// CreateCertificate ignores all of its arguments and returns the canned
// certificate and error stored on the test client.
func (c *testClient) CreateCertificate(context.Context, *pb.CreateCertificateRequest, ...gax.CallOption) (*pb.Certificate, error) {
	return c.certificate, c.err
}
// mustParseECKey decodes the first PEM block in pemKey and parses it as a
// SEC 1 EC private key. Any failure aborts the test via t.Fatal.
func mustParseECKey(t *testing.T, pemKey string) *ecdsa.PrivateKey {
	t.Helper()

	if decoded, _ := pem.Decode([]byte(pemKey)); decoded != nil {
		parsed, parseErr := x509.ParseECPrivateKey(decoded.Bytes)
		if parseErr != nil {
			t.Fatal(parseErr)
		}
		return parsed
	}

	t.Fatal("failed to parse key")
	return nil
}
devops", args{context.Background(), apiv1.Options{ IsCreator: true, Project: testProject, Location: testLocation, CaPool: testCaPool, CaPoolTier: "DevOps", }}, &CloudCAS{ client: &testClient{}, project: testProject, location: testLocation, caPool: testCaPool, caPoolTier: pb.CaPool_DEVOPS, }, false}, {"ok creator enterprise", args{context.Background(), apiv1.Options{ IsCreator: true, Project: testProject, Location: testLocation, CaPool: testCaPool, CaPoolTier: "ENTERPRISE", }}, &CloudCAS{ client: &testClient{}, project: testProject, location: testLocation, caPool: testCaPool, caPoolTier: pb.CaPool_ENTERPRISE, }, false}, {"fail certificate authority", args{context.Background(), apiv1.Options{ CertificateAuthority: "projects/ok1234/locations/ok1234/caPools/ok1234/certificateAuthorities/ok1234/bad", }}, nil, true}, {"fail certificate authority regex", args{context.Background(), apiv1.Options{}}, nil, true}, {"fail with credentials", args{context.Background(), apiv1.Options{ CertificateAuthority: testAuthorityName, CredentialsFile: "testdata/error.json", }}, nil, true}, {"fail creator project", args{context.Background(), apiv1.Options{ IsCreator: true, Project: "", Location: testLocation, }}, nil, true}, {"fail creator location", args{context.Background(), apiv1.Options{ IsCreator: true, Project: testProject, Location: "", }}, nil, true}, {"fail caPool", args{context.Background(), apiv1.Options{ IsCreator: true, Project: testProject, Location: testLocation, CaPool: "", }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) } } func TestNew_register(t *testing.T) { tmp := newCertificateAuthorityClient newCertificateAuthorityClient = func(ctx context.Context, credentialsFile string) (CertificateAuthorityClient, error) { 
return newTestClient(credentialsFile) } t.Cleanup(func() { newCertificateAuthorityClient = tmp }) want := &CloudCAS{ client: &testClient{credentialsFile: "testdata/credentials.json"}, certificateAuthority: testAuthorityName, project: testProject, location: testLocation, caPool: testCaPool, } newFn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.CloudCAS) if !ok { t.Error("apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.CloudCAS) was not found") return } got, err := newFn(context.Background(), apiv1.Options{ CertificateAuthority: testAuthorityName, CredentialsFile: "testdata/credentials.json", }) if err != nil { t.Errorf("New() error = %v", err) return } if !reflect.DeepEqual(got, want) { t.Errorf("New() = %v, want %v", got, want) } } func TestNew_real(t *testing.T) { if v, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS"); ok { os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS") t.Cleanup(func() { t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", v) }) } failDefaultCredentials := true if home, err := os.UserHomeDir(); err == nil { file := filepath.Join(home, ".config", "gcloud", "application_default_credentials.json") if _, err := os.Stat(file); err == nil { failDefaultCredentials = false } } type args struct { ctx context.Context opts apiv1.Options } tests := []struct { name string skipOnCI bool args args wantErr bool }{ {"fail default credentials", true, args{context.Background(), apiv1.Options{CertificateAuthority: testAuthorityName}}, failDefaultCredentials}, {"fail certificate authority", false, args{context.Background(), apiv1.Options{}}, true}, {"fail with credentials", false, args{context.Background(), apiv1.Options{ CertificateAuthority: testAuthorityName, CredentialsFile: "testdata/missing.json", }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.skipOnCI && os.Getenv("CI") == "true" { t.SkipNow() } _, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, 
tt.wantErr) } }) } } func TestCloudCAS_Type(t *testing.T) { tests := []struct { name string want apiv1.Type }{ {"ok", apiv1.CloudCAS}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &CloudCAS{} if got := c.Type(); got != tt.want { t.Errorf("CloudCAS.Type() = %v, want %v", got, tt.want) } }) } } func TestCloudCAS_GetCertificateAuthority(t *testing.T) { intermediate := mustParseCertificate(t, testIntermediateCertificate) root := mustParseCertificate(t, testRootCertificate) type fields struct { client CertificateAuthorityClient certificateAuthority string } type args struct { req *apiv1.GetCertificateAuthorityRequest } tests := []struct { name string fields fields args args want *apiv1.GetCertificateAuthorityResponse wantErr bool }{ {"ok", fields{okTestClient(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: root, IntermediateCertificates: []*x509.Certificate{intermediate}, }, false}, {"ok with name", fields{okTestClient(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{ Name: testCertificateName, }}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: root, IntermediateCertificates: []*x509.Certificate{intermediate}, }, false}, {"ok with root only", fields{okTestClientRootOnly(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: root, IntermediateCertificates: []*x509.Certificate{}, }, false}, {"ok with multiple intermediates", fields{okTestClientWithMultipleIntermediates(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: root, IntermediateCertificates: []*x509.Certificate{intermediate, intermediate, intermediate}, }, false}, {"fail GetCertificateAuthority", fields{failTestClient(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, nil, true}, {"fail bad root", 
fields{badRootTestClient(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, nil, true}, {"fail bad intermediate", fields{badIntermediateTestClient(), testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, nil, true}, {"fail no pems", fields{&testClient{certificateAuthority: &pb.CertificateAuthority{}}, testCertificateName}, args{&apiv1.GetCertificateAuthorityRequest{}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &CloudCAS{ client: tt.fields.client, certificateAuthority: tt.fields.certificateAuthority, } got, err := c.GetCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("CloudCAS.GetCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CloudCAS.GetCertificateAuthority() = %v, want %v", got, tt.want) } }) } } func TestCloudCAS_CreateCertificate(t *testing.T) { type fields struct { client CertificateAuthorityClient certificateAuthority string } type args struct { req *apiv1.CreateCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.CreateCertificateResponse wantErr bool }{ {"ok", fields{okTestClient(), testCertificateName}, args{&apiv1.CreateCertificateRequest{ Template: mustParseCertificate(t, testLeafCertificate), Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testSignedCertificate), CertificateChain: []*x509.Certificate{mustParseCertificate(t, testIntermediateCertificate)}, }, false}, {"fail Template", fields{okTestClient(), testCertificateName}, args{&apiv1.CreateCertificateRequest{ Lifetime: 24 * time.Hour, }}, nil, true}, {"fail Lifetime", fields{okTestClient(), testCertificateName}, args{&apiv1.CreateCertificateRequest{ Template: mustParseCertificate(t, testLeafCertificate), }}, nil, true}, {"fail CreateCertificate", fields{failTestClient(), testCertificateName}, args{&apiv1.CreateCertificateRequest{ 
// TestCloudCAS_createCertificate checks the unexported createCertificate
// method against canned clients. The order of the table matters: the random
// source is replaced by a finite buffer below, and the final case relies on
// that buffer having been used up by then.
func TestCloudCAS_createCertificate(t *testing.T) {
	leaf := mustParseCertificate(t, testLeafCertificate)
	signed := mustParseCertificate(t, testSignedCertificate)
	chain := []*x509.Certificate{mustParseCertificate(t, testIntermediateCertificate)}

	type fields struct {
		client               CertificateAuthorityClient
		certificateAuthority string
	}
	type args struct {
		tpl       *x509.Certificate
		lifetime  time.Duration
		requestID string
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *x509.Certificate
		want1   []*x509.Certificate
		wantErr bool
	}{
		{"ok", fields{okTestClient(), testAuthorityName}, args{leaf, 24 * time.Hour, "request-id"}, signed, chain, false},
		{"fail CertificateConfig", fields{okTestClient(), testAuthorityName}, args{&x509.Certificate{}, 24 * time.Hour, "request-id"}, nil, nil, true},
		{"fail CreateCertificate", fields{failTestClient(), testAuthorityName}, args{leaf, 24 * time.Hour, "request-id"}, nil, nil, true},
		{"fail ParseCertificates", fields{badRootTestClient(), testAuthorityName}, args{leaf, 24 * time.Hour, "request-id"}, nil, nil, true},
		{"fail create id", fields{okTestClient(), testAuthorityName}, args{leaf, 24 * time.Hour, "request-id"}, nil, nil, true},
	}

	// Record the random bytes consumed by len(tests)-1 UUID generations into
	// buf, then make buf itself the random source. Reads from rand.Reader are
	// then limited to the recorded bytes; presumably each createCertificate
	// call that reaches the ID step draws one UUID from it, so the last case
	// ("fail create id") finds the buffer exhausted — TODO confirm against
	// createCertificateID. setTeeReader restores the original rand.Reader
	// when the test finishes.
	buf := new(bytes.Buffer)
	setTeeReader(t, buf)
	for i := 0; i < len(tests)-1; i++ {
		_, err := uuid.NewRandomFromReader(rand.Reader)
		if err != nil {
			t.Fatal(err)
		}
	}
	rand.Reader = buf

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &CloudCAS{
				client:               tt.fields.client,
				certificateAuthority: tt.fields.certificateAuthority,
			}
			got, got1, err := c.createCertificate(tt.args.tpl, tt.args.lifetime, tt.args.requestID)
			if (err != nil) != tt.wantErr {
				t.Errorf("CloudCAS.createCertificate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("CloudCAS.createCertificate() got = %v, want %v", got, tt.want)
			}
			if !reflect.DeepEqual(got1, tt.want1) {
				t.Errorf("CloudCAS.createCertificate() got1 = %v, want %v", got1, tt.want1)
			}
		})
	}
}
// TestCloudCAS_RevokeCertificate runs RevokeCertificate against canned
// clients, covering a successful revocation plus failures caused by the
// request's certificate, its extension, the reason code, and the client.
func TestCloudCAS_RevokeCertificate(t *testing.T) {
	// Start from the signed test certificate and overwrite the value of the
	// extension with OID 1.3.6.1.4.1.37476.9000.64.2 with bytes that are not
	// valid for it. Presumably this is the extension RevokeCertificate reads
	// to identify the certificate — verify against the implementation.
	badExtensionCert := mustParseCertificate(t, testSignedCertificate)
	for i, ext := range badExtensionCert.Extensions {
		if ext.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 2}) {
			badExtensionCert.Extensions[i].Value = []byte("bad-data")
		}
	}

	type fields struct {
		client               CertificateAuthorityClient
		certificateAuthority string
	}
	type args struct {
		req *apiv1.RevokeCertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *apiv1.RevokeCertificateResponse
		wantErr bool
	}{
		{"ok", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testSignedCertificate),
			ReasonCode:  1,
		}}, &apiv1.RevokeCertificateResponse{
			Certificate:      mustParseCertificate(t, testSignedCertificate),
			CertificateChain: []*x509.Certificate{mustParseCertificate(t, testIntermediateCertificate)},
		}, false},
		// The leaf test certificate is used here instead of the signed one.
		{"fail Extension", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testLeafCertificate),
			ReasonCode:  1,
		}}, nil, true},
		{"fail Extension Value", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: badExtensionCert,
			ReasonCode:  1,
		}}, nil, true},
		// No certificate at all in the request.
		{"fail Certificate", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			ReasonCode: 2,
		}}, nil, true},
		{"fail ReasonCode", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testSignedCertificate),
			ReasonCode:  100,
		}}, nil, true},
		{"fail ReasonCode 7", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testSignedCertificate),
			ReasonCode:  7,
		}}, nil, true},
		{"fail ReasonCode 8", fields{okTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testSignedCertificate),
			ReasonCode:  8,
		}}, nil, true},
		// The client itself returns errTest.
		{"fail RevokeCertificate", fields{failTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testSignedCertificate),
			ReasonCode:  1,
		}}, nil, true},
		// The client answers with a certificate whose PEM cannot be parsed.
		{"fail ParseCertificate", fields{badRootTestClient(), testCertificateName}, args{&apiv1.RevokeCertificateRequest{
			Certificate: mustParseCertificate(t, testSignedCertificate),
			ReasonCode:  1,
		}}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &CloudCAS{
				client:               tt.fields.client,
				certificateAuthority: tt.fields.certificateAuthority,
			}
			got, err := c.RevokeCertificate(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("CloudCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("CloudCAS.RevokeCertificate() = %v, want %v", got, tt.want)
			}
		})
	}
}
%v", err, tt.wantErr) return } if got != tt.want { t.Errorf("createCertificateID() = %v, want %v", got, tt.want) } }) } } func Test_parseCertificate(t *testing.T) { type args struct { pemCert string } tests := []struct { name string args args want *x509.Certificate wantErr bool }{ {"ok", args{testLeafCertificate}, mustParseCertificate(t, testLeafCertificate), false}, {"ok intermediate", args{testIntermediateCertificate}, mustParseCertificate(t, testIntermediateCertificate), false}, {"fail pem", args{"not pem"}, nil, true}, {"fail parseCertificate", args{"-----BEGIN CERTIFICATE-----\nZm9vYmFyCg==\n-----END CERTIFICATE-----\n"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := parseCertificate(tt.args.pemCert) if (err != nil) != tt.wantErr { t.Errorf("parseCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("parseCertificate() = %v, want %v", got, tt.want) } }) } } func Test_getCertificateAndChain(t *testing.T) { type args struct { certpb *pb.Certificate } tests := []struct { name string args args want *x509.Certificate want1 []*x509.Certificate wantErr bool }{ {"ok", args{&pb.Certificate{ Name: testCertificateName, PemCertificate: testSignedCertificate, PemCertificateChain: []string{testIntermediateCertificate, testRootCertificate}, }}, mustParseCertificate(t, testSignedCertificate), []*x509.Certificate{mustParseCertificate(t, testIntermediateCertificate)}, false}, {"fail PemCertificate", args{&pb.Certificate{ Name: testCertificateName, PemCertificate: "foobar", PemCertificateChain: []string{testIntermediateCertificate, testRootCertificate}, }}, nil, nil, true}, {"fail PemCertificateChain", args{&pb.Certificate{ Name: testCertificateName, PemCertificate: testSignedCertificate, PemCertificateChain: []string{"foobar", testRootCertificate}, }}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, err := 
getCertificateAndChain(tt.args.certpb) if (err != nil) != tt.wantErr { t.Errorf("getCertificateAndChain() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("getCertificateAndChain() got = %v, want %v", got, tt.want) } if !reflect.DeepEqual(got1, tt.want1) { t.Errorf("getCertificateAndChain() got1 = %v, want %v", got1, tt.want1) } }) } } func TestCloudCAS_CreateCertificateAuthority(t *testing.T) { must := func(a, _ any) any { return a } ctrl := gomock.NewController(t) defer ctrl.Finish() mosCtrl := gomock.NewController(t) defer mosCtrl.Finish() m := NewMockCertificateAuthorityClient(ctrl) mos := NewMockOperationsServer(mosCtrl) // Create operation server srv := grpc.NewServer() longrunningpb.RegisterOperationsServer(srv, mos) lis := bufconn.Listen(2) go srv.Serve(lis) defer srv.Stop() // Create fake privateca client conn, err := grpc.NewClient("localhost", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() })) if err != nil { t.Fatal(err) } client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn)) if err != nil { t.Fatal(err) } fake, err := privateca.NewCertificateAuthorityClient(context.Background(), option.WithGRPCConn(conn)) if err != nil { t.Fatal(err) } fake.LROClient = client // Configure mocks anee := gomock.Any() // ok root m.EXPECT().GetCaPool(anee, anee).Return(nil, status.Error(codes.NotFound, "not found")) m.EXPECT().CreateCaPool(anee, anee).Return(fake.CreateCaPoolOperation("CreateCaPool"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCaPool", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CaPool{ Name: testCaPoolName, })).(*anypb.Any), }, }, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) 
mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "EnableCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) // ok intermediate m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().CreateCertificate(anee, anee).Return(&pb.Certificate{ PemCertificate: testIntermediateCertificate, PemCertificateChain: []string{testRootCertificate}, }, nil) m.EXPECT().ActivateCertificateAuthority(anee, anee).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "ActivateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: 
[]string{testIntermediateCertificate, testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "EnableCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, })).(*anypb.Any), }, }, nil) // ok intermediate local signer m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().ActivateCertificateAuthority(anee, anee).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "ActivateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "EnableCertificateAuthority", Done: true, Result: 
&longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, })).(*anypb.Any), }, }, nil) // ok create key m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "EnableCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) // fail GetCaPool m.EXPECT().GetCaPool(anee, anee).Return(nil, errTest) // fail CreateCaPool m.EXPECT().GetCaPool(anee, anee).Return(nil, status.Error(codes.NotFound, "not found")) m.EXPECT().CreateCaPool(anee, anee).Return(nil, errTest) // fail CreateCaPool.Wait m.EXPECT().GetCaPool(anee, anee).Return(nil, status.Error(codes.NotFound, "not found")) m.EXPECT().CreateCaPool(anee, anee).Return(fake.CreateCaPoolOperation("CreateCaPool"), nil) mos.EXPECT().GetOperation(anee, anee).Return(nil, errTest) // fail CreateCertificateAuthority m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(nil, errTest) // fail CreateCertificateAuthority.Wait m.EXPECT().GetCaPool(anee, 
anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(nil, errTest) // fail EnableCertificateAuthority m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(nil, errTest) // fail EnableCertificateAuthority.Wait m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(anee, anee).Return(nil, errTest) // fail EnableCertificateAuthority intermediate m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: 
"CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().CreateCertificate(anee, anee).Return(&pb.Certificate{ PemCertificate: testIntermediateCertificate, PemCertificateChain: []string{testRootCertificate}, }, nil) m.EXPECT().ActivateCertificateAuthority(anee, anee).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "ActivateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(nil, errTest) // fail EnableCertificateAuthority.Wait intermediate m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().CreateCertificate(anee, anee).Return(&pb.Certificate{ PemCertificate: testIntermediateCertificate, PemCertificateChain: []string{testRootCertificate}, }, nil) m.EXPECT().ActivateCertificateAuthority(anee, 
anee).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "ActivateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, PemCaCertificates: []string{testIntermediateCertificate, testRootCertificate}, })).(*anypb.Any), }, }, nil) m.EXPECT().EnableCertificateAuthority(anee, anee).Return(fake.EnableCertificateAuthorityOperation("EnableCertificateAuthorityOperation"), nil) mos.EXPECT().GetOperation(anee, anee).Return(nil, errTest) // fail FetchCertificateAuthorityCsr m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(nil, errTest) // fail CreateCertificate m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().CreateCertificate(anee, anee).Return(nil, errTest) // fail ActivateCertificateAuthority 
m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().CreateCertificate(anee, anee).Return(&pb.Certificate{ PemCertificate: testIntermediateCertificate, PemCertificateChain: []string{testRootCertificate}, }, nil) m.EXPECT().ActivateCertificateAuthority(anee, anee).Return(nil, errTest) // fail ActivateCertificateAuthority.Wait m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) m.EXPECT().CreateCertificate(anee, anee).Return(&pb.Certificate{ PemCertificate: testIntermediateCertificate, PemCertificateChain: []string{testRootCertificate}, }, nil) m.EXPECT().ActivateCertificateAuthority(anee, anee).Return(fake.ActivateCertificateAuthorityOperation("ActivateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(nil, errTest) // fail x509util.CreateCertificate m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, 
nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: testIntermediateCsr, }, nil) // fail parseCertificateRequest m.EXPECT().GetCaPool(anee, anee).Return(&pb.CaPool{Name: testCaPoolName}, nil) m.EXPECT().CreateCertificateAuthority(anee, anee).Return(fake.CreateCertificateAuthorityOperation("CreateCertificateAuthority"), nil) mos.EXPECT().GetOperation(anee, anee).Return(&longrunningpb.Operation{ Name: "CreateCertificateAuthority", Done: true, Result: &longrunningpb.Operation_Response{ Response: must(anypb.New(&pb.CertificateAuthority{ Name: testAuthorityName, })).(*anypb.Any), }, }, nil) m.EXPECT().FetchCertificateAuthorityCsr(anee, anee).Return(&pb.FetchCertificateAuthorityCsrResponse{ PemCsr: "Not a CSR", }, nil) rootCrt := mustParseCertificate(t, testRootCertificate) intCrt := mustParseCertificate(t, testIntermediateCertificate) type fields struct { client CertificateAuthorityClient certificateAuthority string project string location string caPool string caPoolTier pb.CaPool_Tier } type args struct { req *apiv1.CreateCertificateAuthorityRequest } tests := []struct { name string fields fields args args want *apiv1.CreateCertificateAuthorityResponse wantErr bool }{ {"ok root", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_ENTERPRISE}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, false}, {"ok 
intermediate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: intCrt, CertificateChain: []*x509.Certificate{rootCrt}, }, false}, {"ok intermediate local signer", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_ENTERPRISE}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: rootCrt, Signer: mustParseECKey(t, testRootKey), }, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: intCrt, CertificateChain: []*x509.Certificate{rootCrt}, }, false}, {"ok create key", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, CreateKey: &kmsapi.CreateKeyRequest{ SignatureAlgorithm: kmsapi.ECDSAWithSHA256, }, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, false}, {"fail project", fields{m, "", "", testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail location", fields{m, "", testProject, "", testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail caPool", fields{m, "", 
testProject, testLocation, "", pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail template", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail lifetime", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), }}, nil, true}, {"fail parent", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail parent name", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{}, }}, nil, true}, {"fail type", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: 0, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail create key", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, CreateKey: &kmsapi.CreateKeyRequest{ SignatureAlgorithm: kmsapi.PureEd25519, }, }}, nil, true}, {"fail GetCaPool", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, 
testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail CreateCaPool", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail CreateCaPool.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail CreateCertificateAuthority", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail CreateCertificateAuthority.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail EnableCertificateAuthority", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail EnableCertificateAuthority.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: mustParseCertificate(t, testRootCertificate), Lifetime: 24 * time.Hour, }}, nil, true}, {"fail EnableCertificateAuthority intermediate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: 
&apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, nil, true}, {"fail EnableCertificateAuthority.Wait intermediate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, nil, true}, {"fail FetchCertificateAuthorityCsr", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, nil, true}, {"fail CreateCertificate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, nil, true}, {"fail ActivateCertificateAuthority", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, nil, true}, {"fail ActivateCertificateAuthority.Wait", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: 
&apiv1.CreateCertificateAuthorityResponse{ Name: testAuthorityName, Certificate: rootCrt, }, }}, nil, true}, {"fail x509util.CreateCertificate", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: rootCrt, Signer: createBadSigner(t), }, }}, nil, true}, {"fail parseCertificateRequest", fields{m, "", testProject, testLocation, testCaPool, pb.CaPool_DEVOPS}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: mustParseCertificate(t, testIntermediateCertificate), Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: rootCrt, Signer: createBadSigner(t), }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &CloudCAS{ client: tt.fields.client, certificateAuthority: tt.fields.certificateAuthority, project: tt.fields.project, location: tt.fields.location, caPool: tt.fields.caPool, caPoolTier: tt.fields.caPoolTier, } got, err := c.CreateCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("CloudCAS.CreateCertificateAuthority() error = %+v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CloudCAS.CreateCertificateAuthority() = %v, want %v", got, tt.want) } }) } }

// Test_normalizeCertificateAuthorityName runs normalizeCertificateAuthorityName
// against two fixed inputs: "Test-CA-Name_1234" must come back unchanged, and
// "💥 CA" must come back as "--CA".
func Test_normalizeCertificateAuthorityName(t *testing.T) {
	cases := []struct {
		label string
		input string
		want  string
	}{
		{label: "ok", input: "Test-CA-Name_1234", want: "Test-CA-Name_1234"},
		{label: "change", input: "💥 CA", want: "--CA"},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := normalizeCertificateAuthorityName(tc.input)
			if got == tc.want {
				return
			}
			t.Errorf("normalizeCertificateAuthorityName() = %v, want %v", got, tc.want)
		})
	}
}

================================================
FILE: cas/cloudcas/mock_client_test.go ================================================ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/smallstep/certificates/cas/cloudcas (interfaces: CertificateAuthorityClient) // // Generated by this command: // // mockgen -package cloudcas -mock_names=CertificateAuthorityClient=MockCertificateAuthorityClient -destination mock_client_test.go github.com/smallstep/certificates/cas/cloudcas CertificateAuthorityClient // // Package cloudcas is a generated GoMock package. package cloudcas import ( context "context" reflect "reflect" privateca "cloud.google.com/go/security/privateca/apiv1" privatecapb "cloud.google.com/go/security/privateca/apiv1/privatecapb" gax "github.com/googleapis/gax-go/v2" gomock "go.uber.org/mock/gomock" ) // MockCertificateAuthorityClient is a mock of CertificateAuthorityClient interface. type MockCertificateAuthorityClient struct { ctrl *gomock.Controller recorder *MockCertificateAuthorityClientMockRecorder isgomock struct{} } // MockCertificateAuthorityClientMockRecorder is the mock recorder for MockCertificateAuthorityClient. type MockCertificateAuthorityClientMockRecorder struct { mock *MockCertificateAuthorityClient } // NewMockCertificateAuthorityClient creates a new mock instance. func NewMockCertificateAuthorityClient(ctrl *gomock.Controller) *MockCertificateAuthorityClient { mock := &MockCertificateAuthorityClient{ctrl: ctrl} mock.recorder = &MockCertificateAuthorityClientMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockCertificateAuthorityClient) EXPECT() *MockCertificateAuthorityClientMockRecorder { return m.recorder } // ActivateCertificateAuthority mocks base method. 
func (m *MockCertificateAuthorityClient) ActivateCertificateAuthority(ctx context.Context, req *privatecapb.ActivateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.ActivateCertificateAuthorityOperation, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ActivateCertificateAuthority", varargs...) ret0, _ := ret[0].(*privateca.ActivateCertificateAuthorityOperation) ret1, _ := ret[1].(error) return ret0, ret1 } // ActivateCertificateAuthority indicates an expected call of ActivateCertificateAuthority. func (mr *MockCertificateAuthorityClientMockRecorder) ActivateCertificateAuthority(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivateCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).ActivateCertificateAuthority), varargs...) } // CreateCaPool mocks base method. func (m *MockCertificateAuthorityClient) CreateCaPool(ctx context.Context, req *privatecapb.CreateCaPoolRequest, opts ...gax.CallOption) (*privateca.CreateCaPoolOperation, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateCaPool", varargs...) ret0, _ := ret[0].(*privateca.CreateCaPoolOperation) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateCaPool indicates an expected call of CreateCaPool. func (mr *MockCertificateAuthorityClientMockRecorder) CreateCaPool(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCaPool", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCaPool), varargs...) } // CreateCertificate mocks base method. 
func (m *MockCertificateAuthorityClient) CreateCertificate(ctx context.Context, req *privatecapb.CreateCertificateRequest, opts ...gax.CallOption) (*privatecapb.Certificate, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateCertificate", varargs...) ret0, _ := ret[0].(*privatecapb.Certificate) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateCertificate indicates an expected call of CreateCertificate. func (mr *MockCertificateAuthorityClientMockRecorder) CreateCertificate(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCertificate", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCertificate), varargs...) } // CreateCertificateAuthority mocks base method. func (m *MockCertificateAuthorityClient) CreateCertificateAuthority(ctx context.Context, req *privatecapb.CreateCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.CreateCertificateAuthorityOperation, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "CreateCertificateAuthority", varargs...) ret0, _ := ret[0].(*privateca.CreateCertificateAuthorityOperation) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateCertificateAuthority indicates an expected call of CreateCertificateAuthority. func (mr *MockCertificateAuthorityClientMockRecorder) CreateCertificateAuthority(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).CreateCertificateAuthority), varargs...) } // EnableCertificateAuthority mocks base method. 
func (m *MockCertificateAuthorityClient) EnableCertificateAuthority(ctx context.Context, req *privatecapb.EnableCertificateAuthorityRequest, opts ...gax.CallOption) (*privateca.EnableCertificateAuthorityOperation, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "EnableCertificateAuthority", varargs...) ret0, _ := ret[0].(*privateca.EnableCertificateAuthorityOperation) ret1, _ := ret[1].(error) return ret0, ret1 } // EnableCertificateAuthority indicates an expected call of EnableCertificateAuthority. func (mr *MockCertificateAuthorityClientMockRecorder) EnableCertificateAuthority(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnableCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).EnableCertificateAuthority), varargs...) } // FetchCertificateAuthorityCsr mocks base method. func (m *MockCertificateAuthorityClient) FetchCertificateAuthorityCsr(ctx context.Context, req *privatecapb.FetchCertificateAuthorityCsrRequest, opts ...gax.CallOption) (*privatecapb.FetchCertificateAuthorityCsrResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "FetchCertificateAuthorityCsr", varargs...) ret0, _ := ret[0].(*privatecapb.FetchCertificateAuthorityCsrResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // FetchCertificateAuthorityCsr indicates an expected call of FetchCertificateAuthorityCsr. func (mr *MockCertificateAuthorityClientMockRecorder) FetchCertificateAuthorityCsr(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchCertificateAuthorityCsr", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).FetchCertificateAuthorityCsr), varargs...) 
} // GetCaPool mocks base method. func (m *MockCertificateAuthorityClient) GetCaPool(ctx context.Context, req *privatecapb.GetCaPoolRequest, opts ...gax.CallOption) (*privatecapb.CaPool, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetCaPool", varargs...) ret0, _ := ret[0].(*privatecapb.CaPool) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCaPool indicates an expected call of GetCaPool. func (mr *MockCertificateAuthorityClientMockRecorder) GetCaPool(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCaPool", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).GetCaPool), varargs...) } // GetCertificateAuthority mocks base method. func (m *MockCertificateAuthorityClient) GetCertificateAuthority(ctx context.Context, req *privatecapb.GetCertificateAuthorityRequest, opts ...gax.CallOption) (*privatecapb.CertificateAuthority, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "GetCertificateAuthority", varargs...) ret0, _ := ret[0].(*privatecapb.CertificateAuthority) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCertificateAuthority indicates an expected call of GetCertificateAuthority. func (mr *MockCertificateAuthorityClientMockRecorder) GetCertificateAuthority(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCertificateAuthority", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).GetCertificateAuthority), varargs...) } // RevokeCertificate mocks base method. 
func (m *MockCertificateAuthorityClient) RevokeCertificate(ctx context.Context, req *privatecapb.RevokeCertificateRequest, opts ...gax.CallOption) (*privatecapb.Certificate, error) { m.ctrl.T.Helper() varargs := []any{ctx, req} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "RevokeCertificate", varargs...) ret0, _ := ret[0].(*privatecapb.Certificate) ret1, _ := ret[1].(error) return ret0, ret1 } // RevokeCertificate indicates an expected call of RevokeCertificate. func (mr *MockCertificateAuthorityClientMockRecorder) RevokeCertificate(ctx, req any, opts ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, req}, opts...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeCertificate", reflect.TypeOf((*MockCertificateAuthorityClient)(nil).RevokeCertificate), varargs...) } ================================================ FILE: cas/cloudcas/mock_operation_server_test.go ================================================ // Code generated by MockGen. DO NOT EDIT. // Source: cloud.google.com/go/longrunning/autogen/longrunningpb (interfaces: OperationsServer) // // Generated by this command: // // mockgen -package cloudcas -mock_names=OperationsServer=MockOperationsServer -destination mock_operation_server_test.go cloud.google.com/go/longrunning/autogen/longrunningpb OperationsServer // // Package cloudcas is a generated GoMock package. package cloudcas import ( context "context" reflect "reflect" longrunningpb "cloud.google.com/go/longrunning/autogen/longrunningpb" gomock "go.uber.org/mock/gomock" emptypb "google.golang.org/protobuf/types/known/emptypb" ) // MockOperationsServer is a mock of OperationsServer interface. type MockOperationsServer struct { ctrl *gomock.Controller recorder *MockOperationsServerMockRecorder isgomock struct{} } // MockOperationsServerMockRecorder is the mock recorder for MockOperationsServer. 
type MockOperationsServerMockRecorder struct {
	// mock is the MockOperationsServer this recorder belongs to.
	mock *MockOperationsServer
}

// NewMockOperationsServer creates a new mock instance.
//
// The returned mock dispatches calls to ctrl and owns a recorder that points
// back at it.
func NewMockOperationsServer(ctrl *gomock.Controller) *MockOperationsServer {
	mock := &MockOperationsServer{ctrl: ctrl}
	mock.recorder = &MockOperationsServerMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockOperationsServer) EXPECT() *MockOperationsServerMockRecorder {
	return m.recorder
}

// CancelOperation mocks base method.
//
// It forwards both arguments to the gomock controller and returns the first
// two values the controller's Call produced.
func (m *MockOperationsServer) CancelOperation(arg0 context.Context, arg1 *longrunningpb.CancelOperationRequest) (*emptypb.Empty, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "CancelOperation", arg0, arg1)
	// Comma-ok assertions: a nil or differently typed value leaves the zero
	// value in place instead of panicking.
	ret0, _ := ret[0].(*emptypb.Empty)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// CancelOperation indicates an expected call of CancelOperation.
func (mr *MockOperationsServerMockRecorder) CancelOperation(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelOperation", reflect.TypeOf((*MockOperationsServer)(nil).CancelOperation), arg0, arg1)
}

// DeleteOperation mocks base method.
//
// It forwards both arguments to the gomock controller and returns the first
// two values the controller's Call produced.
func (m *MockOperationsServer) DeleteOperation(arg0 context.Context, arg1 *longrunningpb.DeleteOperationRequest) (*emptypb.Empty, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteOperation", arg0, arg1)
	// Comma-ok assertions: a nil or differently typed value leaves the zero
	// value in place instead of panicking.
	ret0, _ := ret[0].(*emptypb.Empty)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// DeleteOperation indicates an expected call of DeleteOperation.
func (mr *MockOperationsServerMockRecorder) DeleteOperation(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOperation", reflect.TypeOf((*MockOperationsServer)(nil).DeleteOperation), arg0, arg1)
}

// GetOperation mocks base method.
func (m *MockOperationsServer) GetOperation(arg0 context.Context, arg1 *longrunningpb.GetOperationRequest) (*longrunningpb.Operation, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetOperation", arg0, arg1)
	// Comma-ok assertions: a nil or differently typed value leaves the zero
	// value in place instead of panicking.
	ret0, _ := ret[0].(*longrunningpb.Operation)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetOperation indicates an expected call of GetOperation.
func (mr *MockOperationsServerMockRecorder) GetOperation(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOperation", reflect.TypeOf((*MockOperationsServer)(nil).GetOperation), arg0, arg1)
}

// ListOperations mocks base method.
//
// It forwards both arguments to the gomock controller and returns the first
// two values the controller's Call produced.
func (m *MockOperationsServer) ListOperations(arg0 context.Context, arg1 *longrunningpb.ListOperationsRequest) (*longrunningpb.ListOperationsResponse, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "ListOperations", arg0, arg1)
	// Comma-ok assertions: a nil or differently typed value leaves the zero
	// value in place instead of panicking.
	ret0, _ := ret[0].(*longrunningpb.ListOperationsResponse)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// ListOperations indicates an expected call of ListOperations.
func (mr *MockOperationsServerMockRecorder) ListOperations(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListOperations", reflect.TypeOf((*MockOperationsServer)(nil).ListOperations), arg0, arg1)
}

// WaitOperation mocks base method.
//
// It forwards both arguments to the gomock controller and returns the first
// two values the controller's Call produced.
func (m *MockOperationsServer) WaitOperation(arg0 context.Context, arg1 *longrunningpb.WaitOperationRequest) (*longrunningpb.Operation, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "WaitOperation", arg0, arg1)
	// Comma-ok assertions: a nil or differently typed value leaves the zero
	// value in place instead of panicking.
	ret0, _ := ret[0].(*longrunningpb.Operation)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// WaitOperation indicates an expected call of WaitOperation.
func (mr *MockOperationsServerMockRecorder) WaitOperation(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitOperation", reflect.TypeOf((*MockOperationsServer)(nil).WaitOperation), arg0, arg1) } ================================================ FILE: cas/softcas/softcas.go ================================================ package softcas import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "time" "github.com/pkg/errors" "go.step.sm/crypto/kms" kmsapi "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/x509util" "github.com/smallstep/certificates/cas/apiv1" ) func init() { apiv1.Register(apiv1.SoftCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) { return New(ctx, opts) }) } var now = time.Now // SoftCAS implements a Certificate Authority Service using Golang or KMS // crypto. This is the default CAS used in step-ca. type SoftCAS struct { CertificateChain []*x509.Certificate Signer crypto.Signer CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) KeyManager kms.KeyManager } // New creates a new CertificateAuthorityService implementation using Golang or KMS // crypto. func New(_ context.Context, opts apiv1.Options) (*SoftCAS, error) { if !opts.IsCreator { switch { case len(opts.CertificateChain) == 0 && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'CertificateChain' cannot be nil") case opts.Signer == nil && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'signer' cannot be nil") } } return &SoftCAS{ CertificateChain: opts.CertificateChain, Signer: opts.Signer, CertificateSigner: opts.CertificateSigner, KeyManager: opts.KeyManager, }, nil } // Type returns the type of this CertificateAuthorityService. func (c *SoftCAS) Type() apiv1.Type { return apiv1.SoftCAS } // GetSigner implements [apiv1.CertificateAuthoritySigner] and returns a // [crypto.Signer] with the intermediate key. 
func (c *SoftCAS) GetSigner() (crypto.Signer, error) { _, signer, err := c.getCertSigner() return signer, err } // CreateCertificate signs a new certificate using Golang or KMS crypto. func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { switch { case req.Template == nil: return nil, errors.New("createCertificateRequest `template` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificateRequest `lifetime` cannot be 0") } t := now() // Provisioners can also set specific values. if req.Template.NotBefore.IsZero() { req.Template.NotBefore = t.Add(-1 * req.Backdate) } if req.Template.NotAfter.IsZero() { req.Template.NotAfter = t.Add(req.Lifetime) } chain, signer, err := c.getCertSigner() if err != nil { return nil, err } req.Template.Issuer = chain[0].Subject cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // RenewCertificate signs the given certificate template using Golang or KMS crypto. func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { switch { case req.Template == nil: return nil, errors.New("createCertificateRequest `template` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificateRequest `lifetime` cannot be 0") } t := now() req.Template.NotBefore = t.Add(-1 * req.Backdate) req.Template.NotAfter = t.Add(req.Lifetime) chain, signer, err := c.getCertSigner() if err != nil { return nil, err } req.Template.Issuer = chain[0].Subject cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.RenewCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // RevokeCertificate revokes the given certificate in step-ca. 
In SoftCAS this // operation is a no-op as the actual revoke will happen when we store the entry // in the db. func (c *SoftCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { chain, _, err := c.getCertSigner() if err != nil { return nil, err } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, CertificateChain: chain, }, nil } // CreateCRL will create a new CRL based on the RevocationList passed to it func (c *SoftCAS) CreateCRL(req *apiv1.CreateCRLRequest) (*apiv1.CreateCRLResponse, error) { certChain, signer, err := c.getCertSigner() if err != nil { return nil, err } revocationListBytes, err := x509.CreateRevocationList(rand.Reader, req.RevocationList, certChain[0], signer) if err != nil { return nil, err } return &apiv1.CreateCRLResponse{CRL: revocationListBytes}, nil } // CreateCertificateAuthority creates a root or an intermediate certificate. func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthorityRequest) (*apiv1.CreateCertificateAuthorityResponse, error) { switch { case req.Template == nil: return nil, errors.New("createCertificateAuthorityRequest `template` cannot be nil") case req.Lifetime == 0: return nil, errors.New("createCertificateAuthorityRequest `lifetime` cannot be 0") case req.Type == apiv1.IntermediateCA && req.Parent == nil: return nil, errors.New("createCertificateAuthorityRequest `parent` cannot be nil") case req.Type == apiv1.IntermediateCA && req.Parent.Certificate == nil: return nil, errors.New("createCertificateAuthorityRequest `parent.template` cannot be nil") case req.Type == apiv1.IntermediateCA && req.Parent.Signer == nil: return nil, errors.New("createCertificateAuthorityRequest `parent.signer` cannot be nil") } key, err := c.createKey(req.CreateKey) if err != nil { return nil, err } signer, err := c.createSigner(&key.CreateSignerRequest) if err != nil { return nil, err } t := now() if req.Template.NotBefore.IsZero() { 
req.Template.NotBefore = t.Add(-1 * req.Backdate) } if req.Template.NotAfter.IsZero() { req.Template.NotAfter = t.Add(req.Lifetime) } var cert *x509.Certificate switch req.Type { case apiv1.RootCA: cert, err = createCertificate(req.Template, req.Template, signer.Public(), signer) if err != nil { return nil, err } case apiv1.IntermediateCA: cert, err = createCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer) if err != nil { return nil, err } default: return nil, errors.Errorf("createCertificateAuthorityRequest `type=%d' is invalid or not supported", req.Type) } // Add the parent var chain []*x509.Certificate if req.Parent != nil { chain = append(chain, req.Parent.Certificate) chain = append(chain, req.Parent.CertificateChain...) } return &apiv1.CreateCertificateAuthorityResponse{ Name: cert.Subject.CommonName, Certificate: cert, CertificateChain: chain, KeyName: key.Name, PublicKey: key.PublicKey, PrivateKey: key.PrivateKey, Signer: signer, }, nil } // initializeKeyManager initializes the default key manager if was not given. func (c *SoftCAS) initializeKeyManager() (err error) { if c.KeyManager == nil { c.KeyManager, err = kms.New(context.Background(), kmsapi.Options{ Type: kmsapi.DefaultKMS, }) } return } // getCertSigner returns the certificate chain and signer to use. func (c *SoftCAS) getCertSigner() ([]*x509.Certificate, crypto.Signer, error) { if c.CertificateSigner != nil { return c.CertificateSigner() } return c.CertificateChain, c.Signer, nil } // createKey uses the configured kms to create a key. 
func (c *SoftCAS) createKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { if err := c.initializeKeyManager(); err != nil { return nil, err } if req == nil { req = &kmsapi.CreateKeyRequest{ SignatureAlgorithm: kmsapi.ECDSAWithSHA256, } } return c.KeyManager.CreateKey(req) } // createSigner uses the configured kms to create a singer func (c *SoftCAS) createSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer, error) { if err := c.initializeKeyManager(); err != nil { return nil, err } return c.KeyManager.CreateSigner(req) } // createCertificate sets the SignatureAlgorithm of the template if necessary // and calls x509util.CreateCertificate. func createCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) { // Signers can specify the signature algorithm. This is especially important // when x509.CreateCertificate attempts to validate a RSAPSS signature. if template.SignatureAlgorithm == 0 { if sa, ok := signer.(apiv1.SignatureAlgorithmGetter); ok { template.SignatureAlgorithm = sa.SignatureAlgorithm() } else if _, ok := parent.PublicKey.(*rsa.PublicKey); ok { // For RSA issuers, only overwrite the default algorithm is the // intermediate is signed with an RSA signature scheme. 
if isRSA(parent.SignatureAlgorithm) { template.SignatureAlgorithm = parent.SignatureAlgorithm } } } return x509util.CreateCertificate(template, parent, pub, signer) } func isRSA(sa x509.SignatureAlgorithm) bool { switch sa { case x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: return true case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS: return true default: return false } } ================================================ FILE: cas/softcas/softcas_test.go ================================================ package softcas import ( "bytes" "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "fmt" "io" "math/big" "reflect" "testing" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/cas/apiv1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/kms" kmsapi "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/minica" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" ) var ( testIntermediatePem = `-----BEGIN CERTIFICATE----- MIIBPjCB8aADAgECAhAk4aPIlsVvQg3gveApc3mIMAUGAytlcDAeMRwwGgYDVQQD ExNTbWFsbHN0ZXAgVW5pdCBUZXN0MB4XDTIwMDkxNjAyMDgwMloXDTMwMDkxNDAy MDgwMlowHjEcMBoGA1UEAxMTU21hbGxzdGVwIFVuaXQgVGVzdDAqMAUGAytlcAMh ANLs3JCzECR29biut0NDsaLnh0BGij5eJx6VkdJPfS/ko0UwQzAOBgNVHQ8BAf8E BAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQUup5qpZFMAFdgK7RB xNzmUaQM8YwwBQYDK2VwA0EAAwcW25E/6bchyKwp3RRK1GXiPMDCc+hsTJxuOLWy YM7ga829dU8X4pRcEEAcBndqCED/502excjEK7U9vCkFCg== -----END CERTIFICATE-----` testIntermediateKeyPem = `-----BEGIN PRIVATE KEY----- MC4CAQAwBQYDK2VwBCIEII9ZckcrDKlbhZKR0jp820Uz6mOMLFsq2JhI+Tl7WJwH -----END PRIVATE KEY-----` ) var ( errTest = errors.New("test error") testIssuer = mustIssuer() testSigner = mustSigner() testTemplate = &x509.Certificate{ Subject: pkix.Name{CommonName: "test.smallstep.com"}, DNSNames: []string{"test.smallstep.com"}, KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: 
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, PublicKey: testSigner.Public(), SerialNumber: big.NewInt(1234), } testRootTemplate = &x509.Certificate{ Subject: pkix.Name{CommonName: "Test Root CA"}, KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, PublicKey: testSigner.Public(), BasicConstraintsValid: true, IsCA: true, MaxPathLen: 1, SerialNumber: big.NewInt(1234), } testIntermediateTemplate = &x509.Certificate{ Subject: pkix.Name{CommonName: "Test Intermediate CA"}, KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, PublicKey: testSigner.Public(), BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, MaxPathLenZero: true, SerialNumber: big.NewInt(1234), } testNow = time.Now() testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour)) testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) testCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { return []*x509.Certificate{testIssuer}, testSigner, nil } testFailCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { return nil, nil, errTest } ) type signatureAlgorithmSigner struct { crypto.Signer algorithm x509.SignatureAlgorithm } func (s *signatureAlgorithmSigner) SignatureAlgorithm() x509.SignatureAlgorithm { return s.algorithm } type mockKeyManager struct { signer crypto.Signer errGetPublicKey error errCreateKey error errCreatesigner error errClose error } func (m *mockKeyManager) GetPublicKey(*kmsapi.GetPublicKeyRequest) (crypto.PublicKey, error) { signer := testSigner if m.signer != nil { signer = m.signer } return signer.Public(), m.errGetPublicKey } func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { signer := testSigner if m.signer != nil { signer = m.signer } return 
&kmsapi.CreateKeyResponse{ Name: req.Name, PrivateKey: signer, PublicKey: signer.Public(), }, m.errCreateKey } func (m *mockKeyManager) CreateSigner(*kmsapi.CreateSignerRequest) (crypto.Signer, error) { signer := testSigner if m.signer != nil { signer = m.signer } return signer, m.errCreatesigner } func (m *mockKeyManager) CreateDecrypter(*kmsapi.CreateDecrypterRequest) (crypto.Decrypter, error) { return nil, nil } func (m *mockKeyManager) Close() error { return m.errClose } type badSigner struct{} func (b *badSigner) Public() crypto.PublicKey { return testSigner.Public() } func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) { return nil, fmt.Errorf("💥") } //nolint:gocritic // ignore sloppy test func name func mockNow(t *testing.T) { tmp := now now = func() time.Time { return testNow } t.Cleanup(func() { now = tmp }) } func mustIssuer() *x509.Certificate { v, err := pemutil.Parse([]byte(testIntermediatePem)) if err != nil { panic(err) } return v.(*x509.Certificate) } func mustSigner() crypto.Signer { v, err := pemutil.Parse([]byte(testIntermediateKeyPem)) if err != nil { panic(err) } return v.(crypto.Signer) } func mustSign(template, parent *x509.Certificate, notBefore, notAfter time.Time) *x509.Certificate { tmpl := *template tmpl.NotBefore = notBefore tmpl.NotAfter = notAfter tmpl.Issuer = parent.Subject cert, err := x509util.CreateCertificate(&tmpl, parent, tmpl.PublicKey, testSigner) if err != nil { panic(err) } return cert } func setTeeReader(t *testing.T, w *bytes.Buffer) { t.Helper() reader := rand.Reader t.Cleanup(func() { rand.Reader = reader }) rand.Reader = io.TeeReader(reader, w) } func TestNew(t *testing.T) { assertEqual := func(x, y interface{}) bool { return reflect.DeepEqual(x, y) || fmt.Sprintf("%#v", x) == fmt.Sprintf("%#v", y) } type args struct { ctx context.Context opts apiv1.Options } tests := []struct { name string args args want *SoftCAS wantErr bool }{ {"ok", args{context.Background(), 
apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false}, {"ok with callback", args{context.Background(), apiv1.Options{CertificateSigner: testCertificateSigner}}, &SoftCAS{CertificateSigner: testCertificateSigner}, false}, {"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true}, {"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } if !assertEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) } } func TestNew_register(t *testing.T) { newFn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.SoftCAS) if !ok { t.Error("apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.SoftCAS) was not found") return } want := &SoftCAS{ CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner, } got, err := newFn(context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}) if err != nil { t.Errorf("New() error = %v", err) return } if !reflect.DeepEqual(got, want) { t.Errorf("New() = %v, want %v", got, want) } } func TestSoftCAS_Type(t *testing.T) { tests := []struct { name string want apiv1.Type }{ {"ok", apiv1.SoftCAS}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{} if got := c.Type(); got != tt.want { t.Errorf("SoftCAS.Type() = %v, want %v", got, tt.want) } }) } } func TestSoftCAS_GetSigner(t *testing.T) { ca, err := minica.New() require.NoError(t, err) type fields struct { CertificateChain []*x509.Certificate Signer crypto.Signer CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) KeyManager kms.KeyManager } tests := 
[]struct { name string fields fields want crypto.Signer assertion assert.ErrorAssertionFunc }{ {"ok signer", fields{[]*x509.Certificate{ca.Intermediate}, ca.Signer, nil, nil}, ca.Signer, assert.NoError}, {"ok certificateSigner", fields{[]*x509.Certificate{ca.Intermediate}, nil, func() ([]*x509.Certificate, crypto.Signer, error) { return []*x509.Certificate{ca.Intermediate}, ca.Signer, nil }, nil}, ca.Signer, assert.NoError}, {"fail certificateSigner", fields{[]*x509.Certificate{ca.Intermediate}, nil, func() ([]*x509.Certificate, crypto.Signer, error) { return nil, nil, apiv1.NotImplementedError{} }, nil}, nil, assert.Error}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ CertificateChain: tt.fields.CertificateChain, Signer: tt.fields.Signer, CertificateSigner: tt.fields.CertificateSigner, KeyManager: tt.fields.KeyManager, } got, err := c.GetSigner() tt.assertion(t, err) assert.Equal(t, tt.want, got) }) } } func TestSoftCAS_CreateCertificate(t *testing.T) { mockNow(t) // Set rand.Reader to EOF buf := new(bytes.Buffer) setTeeReader(t, buf) rand.Reader = buf tmplNotBefore := *testTemplate tmplNotBefore.NotBefore = testNow tmplWithLifetime := *testTemplate tmplWithLifetime.NotBefore = testNow tmplWithLifetime.NotAfter = testNow.Add(24 * time.Hour) tmplNoSerial := *testTemplate tmplNoSerial.SerialNumber = nil saTemplate := *testSignedTemplate saTemplate.SignatureAlgorithm = 0 saSigner := &signatureAlgorithmSigner{ Signer: testSigner, algorithm: x509.PureEd25519, } type fields struct { Issuer *x509.Certificate Signer crypto.Signer CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.CreateCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.CreateCertificateResponse wantErr bool }{ {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ 
Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &saTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok with notBefore", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNotBefore, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok with notBefore+notAfter", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplWithLifetime, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.CreateCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.CreateCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ CertificateChain: []*x509.Certificate{tt.fields.Issuer}, Signer: 
tt.fields.Signer, CertificateSigner: tt.fields.CertificateSigner, } got, err := c.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("SoftCAS.CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("SoftCAS.CreateCertificate() = %v, want %v", got, tt.want) } }) } } func TestSoftCAS_CreateCertificate_pss(t *testing.T) { signer, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } now := time.Now() template := &x509.Certificate{ Subject: pkix.Name{CommonName: "Test Root CA"}, KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, PublicKey: signer.Public(), BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, SerialNumber: big.NewInt(1234), SignatureAlgorithm: x509.SHA256WithRSAPSS, NotBefore: now, NotAfter: now.Add(24 * time.Hour), } iss, err := x509util.CreateCertificate(template, template, signer.Public(), signer) if err != nil { t.Fatal(err) } if iss.SignatureAlgorithm != x509.SHA256WithRSAPSS { t.Errorf("Certificate.SignatureAlgorithm = %v, want %v", iss.SignatureAlgorithm, x509.SHA256WithRSAPSS) } c := &SoftCAS{ CertificateChain: []*x509.Certificate{iss}, Signer: signer, } cert, err := c.CreateCertificate(&apiv1.CreateCertificateRequest{ Template: &x509.Certificate{ Subject: pkix.Name{CommonName: "test.smallstep.com"}, DNSNames: []string{"test.smallstep.com"}, KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, PublicKey: testSigner.Public(), SerialNumber: big.NewInt(1234), }, Lifetime: time.Hour, Backdate: time.Minute, }) if err != nil { t.Fatalf("SoftCAS.CreateCertificate() error = %v", err) } if cert.Certificate.SignatureAlgorithm != x509.SHA256WithRSAPSS { t.Errorf("Certificate.SignatureAlgorithm = %v, want %v", iss.SignatureAlgorithm, x509.SHA256WithRSAPSS) } pool := x509.NewCertPool() pool.AddCert(iss) if _, err = cert.Certificate.Verify(x509.VerifyOptions{ CurrentTime: 
time.Now(), Roots: pool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, }); err != nil { t.Errorf("Certificate.Verify() error = %v", err) } } func TestSoftCAS_CreateCertificate_ec_rsa(t *testing.T) { rootSigner, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } intSigner, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } now := time.Now() // Root template template := &x509.Certificate{ Subject: pkix.Name{CommonName: "Test Root CA"}, KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, PublicKey: rootSigner.Public(), BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, SerialNumber: big.NewInt(1234), NotBefore: now, NotAfter: now.Add(24 * time.Hour), } root, err := x509util.CreateCertificate(template, template, rootSigner.Public(), rootSigner) if err != nil { t.Fatal(err) } // Intermediate template template = &x509.Certificate{ Subject: pkix.Name{CommonName: "Test Intermediate CA"}, KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, PublicKey: intSigner.Public(), BasicConstraintsValid: true, IsCA: true, MaxPathLen: 0, SerialNumber: big.NewInt(1234), NotBefore: now, NotAfter: now.Add(24 * time.Hour), } iss, err := x509util.CreateCertificate(template, root, intSigner.Public(), rootSigner) if err != nil { t.Fatal(err) } if iss.SignatureAlgorithm != x509.ECDSAWithSHA256 { t.Errorf("Certificate.SignatureAlgorithm = %v, want %v", iss.SignatureAlgorithm, x509.ECDSAWithSHA256) } c := &SoftCAS{ CertificateChain: []*x509.Certificate{iss}, Signer: intSigner, } cert, err := c.CreateCertificate(&apiv1.CreateCertificateRequest{ Template: &x509.Certificate{ Subject: pkix.Name{CommonName: "test.smallstep.com"}, DNSNames: []string{"test.smallstep.com"}, KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, PublicKey: testSigner.Public(), SerialNumber: big.NewInt(1234), }, Lifetime: time.Hour, 
Backdate: time.Minute, }) if err != nil { t.Fatalf("SoftCAS.CreateCertificate() error = %v", err) } if cert.Certificate.SignatureAlgorithm != x509.SHA256WithRSA { t.Errorf("Certificate.SignatureAlgorithm = %v, want %v", iss.SignatureAlgorithm, x509.SHA256WithRSAPSS) } roots := x509.NewCertPool() roots.AddCert(root) intermediates := x509.NewCertPool() intermediates.AddCert(iss) if _, err = cert.Certificate.Verify(x509.VerifyOptions{ CurrentTime: time.Now(), Roots: roots, Intermediates: intermediates, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, }); err != nil { t.Errorf("Certificate.Verify() error = %v", err) } } func TestSoftCAS_RenewCertificate(t *testing.T) { mockNow(t) // Set rand.Reader to EOF buf := new(bytes.Buffer) setTeeReader(t, buf) rand.Reader = buf tmplNoSerial := *testTemplate tmplNoSerial.SerialNumber = nil saSigner := &signatureAlgorithmSigner{ Signer: testSigner, algorithm: x509.PureEd25519, } type fields struct { Issuer *x509.Certificate Signer crypto.Signer CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RenewCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.RenewCertificateResponse wantErr bool }{ {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ 
Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ CertificateChain: []*x509.Certificate{tt.fields.Issuer}, Signer: tt.fields.Signer, CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("SoftCAS.RenewCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("SoftCAS.RenewCertificate() = %v, want %v", got, tt.want) } }) } } func TestSoftCAS_RevokeCertificate(t *testing.T) { type fields struct { Issuer *x509.Certificate Signer crypto.Signer CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RevokeCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.RevokeCertificateResponse wantErr bool }{ {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Reason: "test reason", ReasonCode: 1, }}, &apiv1.RevokeCertificateResponse{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok no cert", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ 
Reason: "test reason", ReasonCode: 1, }}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok empty", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Reason: "test reason", ReasonCode: 1, }}, &apiv1.RevokeCertificateResponse{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Reason: "test reason", ReasonCode: 1, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ CertificateChain: []*x509.Certificate{tt.fields.Issuer}, Signer: tt.fields.Signer, CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("SoftCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("SoftCAS.RevokeCertificate() = %v, want %v", got, tt.want) } }) } } func Test_now(t *testing.T) { t0 := time.Now() t1 := now() if t1.Sub(t0) > time.Second { t.Errorf("now() = %s, want ~%s", t1, t0) } } func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { mockNow(t) saSigner := &signatureAlgorithmSigner{ Signer: testSigner, algorithm: x509.PureEd25519, } type fields struct { Issuer *x509.Certificate Signer crypto.Signer KeyManager kms.KeyManager } type args struct { req *apiv1.CreateCertificateAuthorityRequest } tests := []struct { name string fields fields args args want 
*apiv1.CreateCertificateAuthorityResponse wantErr bool }{ {"ok root", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testRootTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: "Test Root CA", Certificate: testSignedRootTemplate, PublicKey: testSignedRootTemplate.PublicKey, PrivateKey: testSigner, Signer: testSigner, }, false}, {"ok intermediate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: testSignedRootTemplate, Signer: testSigner, }, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: "Test Intermediate CA", Certificate: testSignedIntermediateTemplate, CertificateChain: []*x509.Certificate{testSignedRootTemplate}, PublicKey: testSignedIntermediateTemplate.PublicKey, PrivateKey: testSigner, Signer: testSigner, }, false}, {"ok signature algorithm", fields{nil, nil, &mockKeyManager{signer: saSigner}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testRootTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: "Test Root CA", Certificate: testSignedRootTemplate, PublicKey: testSignedRootTemplate.PublicKey, PrivateKey: saSigner, Signer: saSigner, }, false}, {"ok createKey", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testRootTemplate, Lifetime: 24 * time.Hour, CreateKey: &kmsapi.CreateKeyRequest{ Name: "root_ca.crt", SignatureAlgorithm: kmsapi.ECDSAWithSHA256, }, }}, &apiv1.CreateCertificateAuthorityResponse{ Name: "Test Root CA", Certificate: testSignedRootTemplate, PublicKey: testSignedRootTemplate.PublicKey, KeyName: "root_ca.crt", PrivateKey: testSigner, Signer: testSigner, }, false}, {"fail template", fields{nil, nil, &mockKeyManager{}}, 
args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail lifetime", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testIntermediateTemplate, }}, nil, true}, {"fail type", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail parent", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail parent.certificate", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Signer: testSigner, }, }}, nil, true}, {"fail parent.signer", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: testSignedRootTemplate, }, }}, nil, true}, {"fail createKey", fields{nil, nil, &mockKeyManager{errCreateKey: errTest}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail createSigner", fields{nil, nil, &mockKeyManager{errCreatesigner: errTest}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail sign root", fields{nil, nil, &mockKeyManager{signer: &badSigner{}}}, args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, }}, nil, true}, {"fail sign intermediate", fields{nil, nil, &mockKeyManager{}}, 
args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: testSignedRootTemplate, Signer: &badSigner{}, }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ CertificateChain: []*x509.Certificate{tt.fields.Issuer}, Signer: tt.fields.Signer, KeyManager: tt.fields.KeyManager, } got, err := c.CreateCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("SoftCAS.CreateCertificateAuthority() = \n%#v, want \n%#v", got, tt.want) } }) } } func TestSoftCAS_defaultKeyManager(t *testing.T) { mockNow(t) type args struct { req *apiv1.CreateCertificateAuthorityRequest } tests := []struct { name string args args wantErr bool }{ {"ok root", args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.RootCA, Template: &x509.Certificate{ Subject: pkix.Name{CommonName: "Test Root CA"}, KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, BasicConstraintsValid: true, IsCA: true, MaxPathLen: 1, SerialNumber: big.NewInt(1234), }, Lifetime: 24 * time.Hour, }}, false}, {"ok intermediate", args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: testSignedRootTemplate, Signer: testSigner, }, }}, false}, {"fail with default key manager", args{&apiv1.CreateCertificateAuthorityRequest{ Type: apiv1.IntermediateCA, Template: testIntermediateTemplate, Lifetime: 24 * time.Hour, Parent: &apiv1.CreateCertificateAuthorityResponse{ Certificate: testSignedRootTemplate, Signer: &badSigner{}, }, }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{} _, err := 
c.CreateCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) return } }) } } func Test_isRSA(t *testing.T) { type args struct { sa x509.SignatureAlgorithm } tests := []struct { name string args args want bool }{ {"SHA256WithRSA", args{x509.SHA256WithRSA}, true}, {"SHA384WithRSA", args{x509.SHA384WithRSA}, true}, {"SHA512WithRSA", args{x509.SHA512WithRSA}, true}, {"SHA256WithRSAPSS", args{x509.SHA256WithRSAPSS}, true}, {"SHA384WithRSAPSS", args{x509.SHA384WithRSAPSS}, true}, {"SHA512WithRSAPSS", args{x509.SHA512WithRSAPSS}, true}, {"ECDSAWithSHA256", args{x509.ECDSAWithSHA256}, false}, {"ECDSAWithSHA384", args{x509.ECDSAWithSHA384}, false}, {"ECDSAWithSHA512", args{x509.ECDSAWithSHA512}, false}, {"PureEd25519", args{x509.PureEd25519}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := isRSA(tt.args.sa); got != tt.want { t.Errorf("isRSA() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/stepcas/issuer.go ================================================ package stepcas import ( "context" "net/url" "strings" "time" "github.com/google/uuid" "github.com/pkg/errors" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" ) // raAuthorityNS is a custom namespace used to generate endpoint ids based on // the authority id. var raAuthorityNS = uuid.MustParse("d6f14c1f-2f92-47bf-a04f-7b2c11382edd") // newServerEndpointID returns a uuid v5 using raAuthorityNS as the namespace. // The return uuid will be used as the server endpoint id, it will be unique per // authority. 
func newServerEndpointID(data string) uuid.UUID { return uuid.NewSHA1(raAuthorityNS, []byte(data)) } type raInfo struct { AuthorityID string `json:"authorityId,omitempty"` EndpointID string `json:"endpointId,omitempty"` ProvisionerID string `json:"provisionerId,omitempty"` ProvisionerType string `json:"provisionerType,omitempty"` ProvisionerName string `json:"provisionerName,omitempty"` } type stepIssuer interface { SignToken(subject string, sans []string, info *raInfo) (string, error) RevokeToken(subject string) (string, error) Lifetime(d time.Duration) time.Duration } // newStepIssuer returns the configured step issuer. func newStepIssuer(ctx context.Context, caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssuer) (stepIssuer, error) { if err := validateCertificateIssuer(iss); err != nil { return nil, err } switch strings.ToLower(iss.Type) { case "x5c": return newX5CIssuer(caURL, iss) case "jwk": return newJWKIssuer(ctx, caURL, client, iss) default: return nil, errors.Errorf("stepCAS `certificateIssuer.type` %s is not supported", iss.Type) } } // validateCertificateIssuer validates the configuration of the certificate // issuer. func validateCertificateIssuer(iss *apiv1.CertificateIssuer) error { switch { case iss == nil: return errors.New("stepCAS 'certificateIssuer' cannot be nil") case iss.Type == "": return errors.New("stepCAS `certificateIssuer.type` cannot be empty") } switch strings.ToLower(iss.Type) { case "x5c": return validateX5CIssuer(iss) case "jwk": return validateJWKIssuer(iss) default: return errors.Errorf("stepCAS `certificateIssuer.type` %s is not supported", iss.Type) } } // validateX5CIssuer validates the configuration of x5c issuer. 
func validateX5CIssuer(iss *apiv1.CertificateIssuer) error { switch { case iss.Certificate == "": return errors.New("stepCAS `certificateIssuer.crt` cannot be empty") case iss.Key == "": return errors.New("stepCAS `certificateIssuer.key` cannot be empty") case iss.Provisioner == "": return errors.New("stepCAS `certificateIssuer.provisioner` cannot be empty") default: return nil } } // validateJWKIssuer validates the configuration of jwk issuer. If the key is // not given, then it will download it from the CA. If the password is not set // it will be prompted. func validateJWKIssuer(iss *apiv1.CertificateIssuer) error { switch iss.Provisioner { case "": return errors.New("stepCAS `certificateIssuer.provisioner` cannot be empty") default: return nil } } ================================================ FILE: cas/stepcas/issuer_test.go ================================================ package stepcas import ( "context" "net/url" "reflect" "testing" "time" "github.com/google/uuid" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" "go.step.sm/crypto/jose" ) type mockErrIssuer struct{} func (m mockErrIssuer) SignToken(string, []string, *raInfo) (string, error) { return "", apiv1.NotImplementedError{} } func (m mockErrIssuer) RevokeToken(string) (string, error) { return "", apiv1.NotImplementedError{} } func (m mockErrIssuer) Lifetime(d time.Duration) time.Duration { return d } type mockErrSigner struct{} func (s *mockErrSigner) Sign([]byte) (*jose.JSONWebSignature, error) { return nil, apiv1.NotImplementedError{} } func (s *mockErrSigner) Options() jose.SignerOptions { return jose.SignerOptions{} } func Test_newServerEndpointID(t *testing.T) { type args struct { name string } tests := []struct { name string args args want []byte }{ {"ok", args{"foo"}, []byte{ 0x8f, 0x63, 0x69, 0x20, 0x8a, 0x7a, 0x57, 0x0c, 0xbe, 0x4c, 0x46, 0x66, 0x77, 0xf8, 0x54, 0xe7, }}, {"ok uuid", args{"e4fa6d2d-fa9c-4fdc-913e-7484cc9516e4"}, []byte{ 0x8d, 0x8d, 
0x7f, 0x04, 0x73, 0xd4, 0x5f, 0x2f, 0xa8, 0xe1, 0x28, 0x9a, 0xd1, 0xa8, 0xcf, 0x7e, }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var want uuid.UUID copy(want[:], tt.want) got := newServerEndpointID(tt.args.name) if !reflect.DeepEqual(got, want) { t.Errorf("newServerEndpointID() = %v, want %v", got, tt.want) } // Check version if v := (got[6] & 0xf0) >> 4; v != 5 { t.Errorf("newServerEndpointID() version = %d, want 5", v) } // Check variant if v := (got[8] & 0x80) >> 6; v != 2 { t.Errorf("newServerEndpointID() variant = %d, want 2", v) } }) } } func Test_newStepIssuer(t *testing.T) { caURL, client := testCAHelper(t) signer, err := newJWKSignerFromEncryptedKey(testKeyID, testEncryptedJWKKey, testPassword) if err != nil { t.Fatal(err) } type args struct { caURL *url.URL client *ca.Client iss *apiv1.CertificateIssuer } tests := []struct { name string args args want stepIssuer wantErr bool }{ {"x5c", args{caURL, client, &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }}, &x5cIssuer{ caURL: caURL, certFile: testX5CPath, keyFile: testX5CKeyPath, issuer: "X5C", }, false}, {"jwk", args{caURL, client, &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Key: testX5CKeyPath, }}, &jwkIssuer{ caURL: caURL, issuer: "ra@doe.org", signer: signer, }, false}, {"fail", args{caURL, client, &apiv1.CertificateIssuer{ Type: "unknown", Provisioner: "ra@doe.org", Key: testX5CKeyPath, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := newStepIssuer(context.TODO(), tt.args.caURL, tt.args.client, tt.args.iss) if (err != nil) != tt.wantErr { t.Errorf("newStepIssuer() error = %v, wantErr %v", err, tt.wantErr) return } if tt.args.iss.Type == "jwk" && got != nil && tt.want != nil { got.(*jwkIssuer).signer = tt.want.(*jwkIssuer).signer } if !reflect.DeepEqual(got, tt.want) { t.Errorf("newStepIssuer() = %v, want %v", got, tt.want) } }) } } 
================================================ FILE: cas/stepcas/jwk_issuer.go ================================================ package stepcas import ( "context" "crypto" "encoding/json" "net/url" "time" "github.com/pkg/errors" "github.com/smallstep/cli-utils/ui" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" ) type jwkIssuer struct { caURL *url.URL issuer string signer jose.Signer } func newJWKIssuer(ctx context.Context, caURL *url.URL, client *ca.Client, cfg *apiv1.CertificateIssuer) (*jwkIssuer, error) { var err error var signer jose.Signer // Read the key from the CA if not provided. // Or read it from a PEM file. if cfg.Key == "" { p, err := findProvisioner(ctx, client, provisioner.TypeJWK, cfg.Provisioner) if err != nil { return nil, err } kid, key, ok := p.GetEncryptedKey() if !ok { return nil, errors.Errorf("provisioner with name %s does not have an encrypted key", cfg.Provisioner) } signer, err = newJWKSignerFromEncryptedKey(kid, key, cfg.Password) if err != nil { return nil, err } } else { signer, err = newJWKSigner(cfg.Key, cfg.Password) if err != nil { return nil, err } } return &jwkIssuer{ caURL: caURL, issuer: cfg.Provisioner, signer: signer, }, nil } func (i *jwkIssuer) SignToken(subject string, sans []string, info *raInfo) (string, error) { aud := i.caURL.ResolveReference(&url.URL{ Path: "/1.0/sign", }).String() return i.createToken(aud, subject, sans, info) } func (i *jwkIssuer) RevokeToken(subject string) (string, error) { aud := i.caURL.ResolveReference(&url.URL{ Path: "/1.0/revoke", }).String() return i.createToken(aud, subject, nil, nil) } func (i *jwkIssuer) Lifetime(d time.Duration) time.Duration { return d } func (i *jwkIssuer) createToken(aud, sub string, sans []string, info *raInfo) (string, error) { id, err := randutil.Hex(64) // 256 bits if err != nil { return "", err } claims := 
defaultClaims(i.issuer, sub, aud, id) builder := jose.Signed(i.signer).Claims(claims) if len(sans) > 0 { builder = builder.Claims(map[string]interface{}{ "sans": sans, }) } if info != nil { builder = builder.Claims(map[string]interface{}{ "step": map[string]interface{}{ "ra": info, }, }) } tok, err := builder.CompactSerialize() if err != nil { return "", errors.Wrap(err, "error signing token") } return tok, nil } func newJWKSigner(keyFile, password string) (jose.Signer, error) { signer, err := readKey(keyFile, password) if err != nil { return nil, err } kid, err := jose.Thumbprint(&jose.JSONWebKey{Key: signer.Public()}) if err != nil { return nil, err } so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", kid) return newJoseSigner(signer, so) } func newJWKSignerFromEncryptedKey(kid, key, password string) (jose.Signer, error) { var jwk jose.JSONWebKey // If the password is empty it will use the password prompter. b, err := jose.Decrypt([]byte(key), jose.WithPassword([]byte(password)), jose.WithPasswordPrompter("Please enter the password to decrypt the provisioner key", func(msg string) ([]byte, error) { return ui.PromptPassword(msg) })) if err != nil { return nil, err } // Decrypt returns the JSON representation of the JWK. 
if err := json.Unmarshal(b, &jwk); err != nil { return nil, errors.Wrap(err, "error parsing provisioner key") } signer, ok := jwk.Key.(crypto.Signer) if !ok { return nil, errors.New("error parsing provisioner key: key is not a crypto.Signer") } so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", kid) return newJoseSigner(signer, so) } func findProvisioner(ctx context.Context, client *ca.Client, typ provisioner.Type, name string) (provisioner.Interface, error) { cursor := "" for { ps, err := client.ProvisionersWithContext(ctx, ca.WithProvisionerCursor(cursor)) if err != nil { return nil, err } for _, p := range ps.Provisioners { if p.GetType() == typ && p.GetName() == name { return p, nil } } if ps.NextCursor == "" { return nil, errors.Errorf("provisioner with name %s was not found", name) } cursor = ps.NextCursor } } ================================================ FILE: cas/stepcas/jwk_issuer_test.go ================================================ package stepcas import ( "net/url" "reflect" "testing" "time" "go.step.sm/crypto/jose" ) func Test_jwkIssuer_SignToken(t *testing.T) { caURL, err := url.Parse("https://ca.smallstep.com") if err != nil { t.Fatal(err) } signer, err := newJWKSignerFromEncryptedKey(testKeyID, testEncryptedJWKKey, testPassword) if err != nil { t.Fatal(err) } type fields struct { caURL *url.URL issuer string signer jose.Signer } type args struct { subject string sans []string info *raInfo } type stepClaims struct { RA *raInfo `json:"ra"` } type claims struct { Aud jose.Audience `json:"aud"` Sub string `json:"sub"` Sans []string `json:"sans"` Step stepClaims `json:"step"` } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{caURL, "ra@doe.org", signer}, args{"doe", []string{"doe.org"}, nil}, false}, {"ok ra", fields{caURL, "ra@doe.org", signer}, args{"doe", []string{"doe.org"}, &raInfo{ AuthorityID: "authority-id", ProvisionerID: "provisioner-id", ProvisionerType: "provisioner-type", }}, 
false}, {"ok ra endpoint id", fields{caURL, "ra@doe.org", signer}, args{"doe", []string{"doe.org"}, &raInfo{ AuthorityID: "authority-id", EndpointID: "endpoint-id", }}, false}, {"fail", fields{caURL, "ra@doe.org", &mockErrSigner{}}, args{"doe", []string{"doe.org"}, nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &jwkIssuer{ caURL: tt.fields.caURL, issuer: tt.fields.issuer, signer: tt.fields.signer, } got, err := i.SignToken(tt.args.subject, tt.args.sans, tt.args.info) if (err != nil) != tt.wantErr { t.Errorf("jwkIssuer.SignToken() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr { jwt, err := jose.ParseSigned(got) if err != nil { t.Errorf("jose.ParseSigned() error = %v", err) } var c claims want := claims{ Aud: jose.Audience{tt.fields.caURL.String() + "/1.0/sign"}, Sub: tt.args.subject, Sans: tt.args.sans, } if tt.args.info != nil { want.Step.RA = tt.args.info } if err := jwt.Claims(testX5CKey.Public(), &c); err != nil { t.Log(got) t.Errorf("jwt.Claims() error = %v", err) } if !reflect.DeepEqual(c, want) { t.Errorf("jwt.Claims() claims = %#v, want %#v", c, want) } } }) } } func Test_jwkIssuer_RevokeToken(t *testing.T) { caURL, err := url.Parse("https://ca.smallstep.com") if err != nil { t.Fatal(err) } signer, err := newJWKSignerFromEncryptedKey(testKeyID, testEncryptedJWKKey, testPassword) if err != nil { t.Fatal(err) } type fields struct { caURL *url.URL issuer string signer jose.Signer } type args struct { subject string } type claims struct { Aud jose.Audience `json:"aud"` Sub string `json:"sub"` Sans []string `json:"sans"` } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{caURL, "ra@doe.org", signer}, args{"doe"}, false}, {"ok", fields{caURL, "ra@doe.org", &mockErrSigner{}}, args{"doe"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &jwkIssuer{ caURL: tt.fields.caURL, issuer: tt.fields.issuer, signer: tt.fields.signer, } got, err := 
i.RevokeToken(tt.args.subject) if (err != nil) != tt.wantErr { t.Errorf("jwkIssuer.RevokeToken() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr { jwt, err := jose.ParseSigned(got) if err != nil { t.Errorf("jose.ParseSigned() error = %v", err) } var c claims want := claims{ Aud: []string{tt.fields.caURL.String() + "/1.0/revoke"}, Sub: tt.args.subject, } if err := jwt.Claims(testX5CKey.Public(), &c); err != nil { t.Errorf("jwt.Claims() error = %v", err) } if !reflect.DeepEqual(c, want) { t.Errorf("jwt.Claims() claims = %#v, want %#v", c, want) } } }) } } func Test_jwkIssuer_Lifetime(t *testing.T) { caURL, err := url.Parse("https://ca.smallstep.com") if err != nil { t.Fatal(err) } signer, err := newJWKSignerFromEncryptedKey(testKeyID, testEncryptedJWKKey, testPassword) if err != nil { t.Fatal(err) } type fields struct { caURL *url.URL issuer string signer jose.Signer } type args struct { d time.Duration } tests := []struct { name string fields fields args args want time.Duration }{ {"ok", fields{caURL, "ra@smallstep.com", signer}, args{time.Second}, time.Second}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &jwkIssuer{ caURL: tt.fields.caURL, issuer: tt.fields.issuer, signer: tt.fields.signer, } if got := i.Lifetime(tt.args.d); got != tt.want { t.Errorf("jwkIssuer.Lifetime() = %v, want %v", got, tt.want) } }) } } func Test_newJWKSignerFromEncryptedKey(t *testing.T) { encrypt := func(plaintext string) string { recipient := jose.Recipient{ Algorithm: jose.PBES2_HS256_A128KW, Key: testPassword, PBES2Count: jose.PBKDF2Iterations, PBES2Salt: []byte{0x01, 0x02}, } opts := new(jose.EncrypterOptions) opts.WithContentType(jose.ContentType("jwk+json")) encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts) if err != nil { t.Fatal(err) } jwe, err := encrypter.Encrypt([]byte(plaintext)) if err != nil { t.Fatal(err) } ret, err := jwe.CompactSerialize() if err != nil { t.Fatal(err) } return ret } type args struct { 
kid string key string password string } tests := []struct { name string args args wantErr bool }{ {"ok", args{testKeyID, testEncryptedJWKKey, testPassword}, false}, {"fail decrypt", args{testKeyID, testEncryptedJWKKey, "bad-password"}, true}, {"fail unmarshal", args{testKeyID, encrypt(`{not a json}`), testPassword}, true}, {"fail not signer", args{testKeyID, encrypt(`{"kty":"oct","k":"password"}`), testPassword}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := newJWKSignerFromEncryptedKey(tt.args.kid, tt.args.key, tt.args.password) if (err != nil) != tt.wantErr { t.Errorf("newJWKSignerFromEncryptedKey() error = %v, wantErr %v", err, tt.wantErr) } }) } } ================================================ FILE: cas/stepcas/stepcas.go ================================================ package stepcas import ( "context" "crypto/x509" "net/url" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" ) func init() { apiv1.Register(apiv1.StepCAS, func(ctx context.Context, opts apiv1.Options) (apiv1.CertificateAuthorityService, error) { return New(ctx, opts) }) } // StepCAS implements the cas.CertificateAuthorityService interface using // another step-ca instance. type StepCAS struct { iss stepIssuer client *ca.Client authorityID string fingerprint string } // New creates a new CertificateAuthorityService implementation using another // step-ca instance. func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) { switch { case opts.CertificateAuthority == "": return nil, errors.New("stepCAS 'certificateAuthority' cannot be empty") case opts.CertificateAuthorityFingerprint == "": return nil, errors.New("stepCAS 'certificateAuthorityFingerprint' cannot be empty") } caURL, err := url.Parse(opts.CertificateAuthority) if err != nil { return nil, errors.Wrap(err, "stepCAS `certificateAuthority` is not valid") } // Create client. 
client, err := ca.NewClient(opts.CertificateAuthority, ca.WithRootSHA256(opts.CertificateAuthorityFingerprint)) //nolint:contextcheck // deeply nested context if err != nil { return nil, err } var iss stepIssuer // Create configured issuer unless we only want to use GetCertificateAuthority. // This avoid the request for the password if not provided. if !opts.IsCAGetter { if iss, err = newStepIssuer(ctx, caURL, client, opts.CertificateIssuer); err != nil { return nil, err } } return &StepCAS{ iss: iss, client: client, authorityID: opts.AuthorityID, fingerprint: opts.CertificateAuthorityFingerprint, }, nil } // Type returns the type of this CertificateAuthorityService. func (s *StepCAS) Type() apiv1.Type { return apiv1.StepCAS } // CreateCertificate uses the step-ca sign request with the configured // provisioner to get a new certificate from the certificate authority. func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1.CreateCertificateResponse, error) { switch { case req.CSR == nil: return nil, errors.New("createCertificateRequest `csr` cannot be nil") case req.Template == nil: return nil, errors.New("createCertificateRequest `template` cannot be nil") case req.Lifetime < 0: return nil, errors.New("createCertificateRequest `lifetime` cannot be less than 0") } info := &raInfo{ AuthorityID: s.authorityID, } if req.IsCAServerCert { info.EndpointID = newServerEndpointID(s.authorityID).String() } if p := req.Provisioner; p != nil { info.ProvisionerID = p.ID info.ProvisionerType = p.Type info.ProvisionerName = p.Name } cert, chain, err := s.createCertificate(req.CSR, req.Template, req.Lifetime, info) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // RenewCertificate will always return a non-implemented error as mTLS renewals // are not supported yet. 
func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { if req.Token == "" { return nil, apiv1.ValidationError{Message: "renewCertificateRequest `token` cannot be empty"} } resp, err := s.client.RenewWithToken(req.Token) if err != nil { return nil, err } var chain []*x509.Certificate cert := resp.CertChainPEM[0].Certificate for _, c := range resp.CertChainPEM[1:] { chain = append(chain, c.Certificate) } return &apiv1.RenewCertificateResponse{ Certificate: cert, CertificateChain: chain, }, nil } // RevokeCertificate revokes a certificate. func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { if req.SerialNumber == "" && req.Certificate == nil { return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required") } serialNumber := req.SerialNumber if req.Certificate != nil { serialNumber = req.Certificate.SerialNumber.String() } token, err := s.iss.RevokeToken(serialNumber) if err != nil { return nil, err } _, err = s.client.Revoke(&api.RevokeRequest{ Serial: serialNumber, ReasonCode: req.ReasonCode, Reason: req.Reason, OTT: token, Passive: req.PassiveOnly, }, nil) if err != nil { return nil, err } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, CertificateChain: nil, }, nil } // GetCertificateAuthority returns the root certificate of the certificate // authority using the configured fingerprint. 
func (s *StepCAS) GetCertificateAuthority(*apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { resp, err := s.client.Root(s.fingerprint) if err != nil { return nil, err } return &apiv1.GetCertificateAuthorityResponse{ RootCertificate: resp.RootPEM.Certificate, }, nil } func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, template *x509.Certificate, lifetime time.Duration, raInfo *raInfo) (*x509.Certificate, []*x509.Certificate, error) { sans := make([]string, 0, len(template.DNSNames)+len(template.EmailAddresses)+len(template.IPAddresses)+len(template.URIs)) sans = append(sans, template.DNSNames...) sans = append(sans, template.EmailAddresses...) for _, ip := range template.IPAddresses { sans = append(sans, ip.String()) } for _, u := range template.URIs { sans = append(sans, u.String()) } commonName := template.Subject.CommonName if commonName == "" && len(sans) > 0 { commonName = sans[0] } token, err := s.iss.SignToken(commonName, sans, raInfo) if err != nil { return nil, nil, err } resp, err := s.client.Sign(&api.SignRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: cr}, OTT: token, NotAfter: s.lifetime(lifetime), }) if err != nil { return nil, nil, err } var chain []*x509.Certificate cert := resp.CertChainPEM[0].Certificate for _, c := range resp.CertChainPEM[1:] { chain = append(chain, c.Certificate) } return cert, chain, nil } func (s *StepCAS) lifetime(d time.Duration) api.TimeDuration { var td api.TimeDuration td.SetDuration(s.iss.Lifetime(d)) return td } ================================================ FILE: cas/stepcas/stepcas_test.go ================================================ package stepcas import ( "bytes" "context" "crypto" "crypto/ed25519" "crypto/rand" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "reflect" "testing" "time" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" 
"github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" "github.com/stretchr/testify/require" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" "go.step.sm/crypto/x509util" ) var ( testRootCrt *x509.Certificate testRootKey crypto.Signer testRootPath, testRootKeyPath string testRootFingerprint string testIssCrt *x509.Certificate testIssKey crypto.Signer testIssPath, testIssKeyPath string testX5CCrt *x509.Certificate testX5CKey crypto.Signer testX5CPath, testX5CKeyPath string testPassword, testEncryptedKeyPath string testKeyID, testEncryptedJWKKey string testCR *x509.CertificateRequest testCrt *x509.Certificate testKey crypto.Signer testFailCR *x509.CertificateRequest ) func mustSignCertificate(subject string, sans []string, template string, parent *x509.Certificate, signer crypto.Signer) (*x509.Certificate, crypto.Signer) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { panic(err) } cr, err := x509util.CreateCertificateRequest(subject, sans, priv) if err != nil { panic(err) } cert, err := x509util.NewCertificate(cr, x509util.WithTemplate(template, x509util.CreateTemplateData(subject, sans))) if err != nil { panic(err) } crt := cert.GetCertificate() crt.NotBefore = time.Now() crt.NotAfter = crt.NotBefore.Add(time.Hour) if parent == nil { parent = crt } if signer == nil { signer = priv } if crt, err = x509util.CreateCertificate(crt, parent, pub, signer); err != nil { panic(err) } return crt, priv } func mustSerializeCrt(filename string, certs ...*x509.Certificate) { buf := new(bytes.Buffer) for _, c := range certs { if err := pem.Encode(buf, &pem.Block{ Type: "CERTIFICATE", Bytes: c.Raw, }); err != nil { panic(err) } } if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil { panic(err) } } func mustSerializeKey(filename string, key crypto.Signer) { b, err := x509.MarshalPKCS8PrivateKey(key) if err != nil { panic(err) } b = pem.EncodeToMemory(&pem.Block{ Type: "PRIVATE KEY", 
Bytes: b, }) if err := os.WriteFile(filename, b, 0600); err != nil { panic(err) } } func mustEncryptKey(filename string, key crypto.Signer) { _, err := pemutil.Serialize(key, pemutil.ToFile(filename, 0600), pemutil.WithPKCS8(true), pemutil.WithPassword([]byte(testPassword))) if err != nil { panic(err) } } func testCAHelper(t *testing.T) (*url.URL, *ca.Client) { t.Helper() writeJSON := func(w http.ResponseWriter, v interface{}) { _ = json.NewEncoder(w).Encode(v) } parseJSON := func(r *http.Request, v interface{}) { _ = json.NewDecoder(r.Body).Decode(v) } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.RequestURI { case "/root/" + testRootFingerprint: w.WriteHeader(http.StatusOK) writeJSON(w, api.RootResponse{ RootPEM: api.NewCertificate(testRootCrt), }) case "/sign": var msg api.SignRequest parseJSON(r, &msg) if msg.CsrPEM.DNSNames[0] == "fail.doe.org" { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, `{"error":"fail","message":"fail"}`) return } w.WriteHeader(http.StatusOK) writeJSON(w, api.SignResponse{ CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)}, }) case "/renew": if r.Header.Get("Authorization") == "Bearer fail" { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, `{"error":"fail","message":"fail"}`) return } w.WriteHeader(http.StatusOK) writeJSON(w, api.SignResponse{ CertChainPEM: []api.Certificate{api.NewCertificate(testCrt), api.NewCertificate(testIssCrt)}, }) case "/revoke": var msg api.RevokeRequest parseJSON(r, &msg) if msg.Serial == "fail" { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, `{"error":"fail","message":"fail"}`) return } w.WriteHeader(http.StatusOK) writeJSON(w, api.RevokeResponse{ Status: "ok", }) case "/provisioners": w.WriteHeader(http.StatusOK) writeJSON(w, api.ProvisionersResponse{ NextCursor: "cursor", Provisioners: []provisioner.Interface{ &provisioner.JWK{ Type: "JWK", Name: "ra@doe.org", Key: &jose.JSONWebKey{KeyID: testKeyID, 
Key: testX5CKey.Public()}, EncryptedKey: testEncryptedJWKKey, }, &provisioner.JWK{ Type: "JWK", Name: "empty@doe.org", Key: &jose.JSONWebKey{KeyID: testKeyID, Key: testX5CKey.Public()}, }, }, }) case "/provisioners?cursor=cursor": w.WriteHeader(http.StatusOK) writeJSON(w, api.ProvisionersResponse{}) default: w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"not found"}`) } })) t.Cleanup(func() { srv.Close() }) u, err := url.Parse(srv.URL) if err != nil { srv.Close() t.Fatal(err) } client, err := ca.NewClient(srv.URL, ca.WithTransport(http.DefaultTransport)) if err != nil { srv.Close() t.Fatal(err) } return u, client } func testX5CIssuer(t *testing.T, caURL *url.URL, password string) *x5cIssuer { t.Helper() key, givenPassword := testX5CKeyPath, password if password != "" { key = testEncryptedKeyPath password = testPassword } x5c, err := newX5CIssuer(caURL, &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: key, Password: password, }) if err != nil { t.Fatal(err) } x5c.password = givenPassword return x5c } func testJWKIssuer(t *testing.T, caURL *url.URL, password string) *jwkIssuer { t.Helper() client, err := ca.NewClient(caURL.String(), ca.WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } key := testX5CKeyPath if password != "" { key = testEncryptedKeyPath password = testPassword } jwk, err := newJWKIssuer(context.TODO(), caURL, client, &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Key: key, Password: password, }) if err != nil { t.Fatal(err) } return jwk } func TestMain(m *testing.M) { testRootCrt, testRootKey = mustSignCertificate("Test Root Certificate", nil, x509util.DefaultRootTemplate, nil, nil) testIssCrt, testIssKey = mustSignCertificate("Test Intermediate Certificate", nil, x509util.DefaultIntermediateTemplate, testRootCrt, testRootKey) testX5CCrt, testX5CKey = mustSignCertificate("Test X5C Certificate", nil, x509util.DefaultLeafTemplate, testIssCrt, testIssKey) 
testRootFingerprint = x509util.Fingerprint(testRootCrt) // Final certificate. var err error sans := []string{"doe.org", "jane@doe.org", "127.0.0.1", "::1", "localhost", "uuid:f81d4fae-7dec-11d0-a765-00a0c91e6bf6;name=value"} testCrt, testKey = mustSignCertificate("Test Certificate", sans, x509util.DefaultLeafTemplate, testIssCrt, testIssKey) testCR, err = x509util.CreateCertificateRequest("Test Certificate", sans, testKey) if err != nil { panic(err) } // CR used in errors. testFailCR, err = x509util.CreateCertificateRequest("", []string{"fail.doe.org"}, testKey) if err != nil { panic(err) } // Password used to encrypt the key. testPassword, err = randutil.Hex(32) if err != nil { panic(err) } // Encrypted JWK key used when the key is downloaded from the CA. jwe, err := jose.EncryptJWK(&jose.JSONWebKey{Key: testX5CKey}, []byte(testPassword)) if err != nil { panic(err) } testEncryptedJWKKey, err = jwe.CompactSerialize() if err != nil { panic(err) } testKeyID, err = jose.Thumbprint(&jose.JSONWebKey{Key: testX5CKey}) if err != nil { panic(err) } // Create test files. 
path, err := os.MkdirTemp(os.TempDir(), "stepcas") if err != nil { panic(err) } testRootPath = filepath.Join(path, "root_ca.crt") testRootKeyPath = filepath.Join(path, "root_ca.key") mustSerializeCrt(testRootPath, testRootCrt) mustSerializeKey(testRootKeyPath, testRootKey) testIssPath = filepath.Join(path, "intermediate_ca.crt") testIssKeyPath = filepath.Join(path, "intermediate_ca.key") mustSerializeCrt(testIssPath, testIssCrt) mustSerializeKey(testIssKeyPath, testIssKey) testX5CPath = filepath.Join(path, "x5c.crt") testX5CKeyPath = filepath.Join(path, "x5c.key") mustSerializeCrt(testX5CPath, testX5CCrt, testIssCrt) mustSerializeKey(testX5CKeyPath, testX5CKey) testEncryptedKeyPath = filepath.Join(path, "x5c.enc.key") mustEncryptKey(testEncryptedKeyPath, testX5CKey) code := m.Run() if err := os.RemoveAll(path); err != nil { panic(err) } os.Exit(code) } func Test_init(t *testing.T) { caURL, _ := testCAHelper(t) fn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.StepCAS) if !ok { t.Errorf("apiv1.Register() ok = %v, want true", ok) return } fn(context.Background(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }) } func TestNew(t *testing.T) { caURL, client := testCAHelper(t) signer, err := newJWKSignerFromEncryptedKey(testKeyID, testEncryptedJWKKey, testPassword) if err != nil { t.Fatal(err) } type args struct { ctx context.Context opts apiv1.Options } tests := []struct { name string args args want *StepCAS wantErr bool }{ {"ok", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, &StepCAS{ iss: &x5cIssuer{ caURL: caURL, certFile: testX5CPath, keyFile: testX5CKeyPath, 
issuer: "X5C", }, client: client, fingerprint: testRootFingerprint, }, false}, {"ok jwk", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Key: testX5CKeyPath, }, }}, &StepCAS{ iss: &jwkIssuer{ caURL: caURL, issuer: "ra@doe.org", signer: signer, }, client: client, fingerprint: testRootFingerprint, }, false}, {"ok jwk provisioners", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Password: testPassword, }, }}, &StepCAS{ iss: &jwkIssuer{ caURL: caURL, issuer: "ra@doe.org", signer: signer, }, client: client, fingerprint: testRootFingerprint, }, false}, {"ok ca getter", args{context.TODO(), apiv1.Options{ IsCAGetter: true, CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", }, }}, &StepCAS{ iss: nil, client: client, fingerprint: testRootFingerprint, }, false}, {"fail authority", args{context.TODO(), apiv1.Options{ CertificateAuthority: "", CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail fingerprint", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: "", CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail type", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "", Provisioner: "X5C", 
Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail provisioner", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail provisioner jwk", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "", Key: testX5CKeyPath, }, }}, nil, true}, {"fail provisioner not found", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "notfound@doe.org", Password: testPassword, }, }}, nil, true}, {"fail invalid password", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Password: "bad-password", }, }}, nil, true}, {"fail no key", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "empty@doe.org", Password: testPassword, }, }}, nil, true}, {"fail certificate", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: "", Key: testX5CKeyPath, }, }}, nil, true}, {"fail key", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: 
testX5CPath, Key: "", }, }}, nil, true}, {"fail key jwk", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@smallstep.com", Key: "", }, }}, nil, true}, {"bad authority", args{context.TODO(), apiv1.Options{ CertificateAuthority: "https://foobar", CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail parse url", args{context.TODO(), apiv1.Options{ CertificateAuthority: "::failparse", CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail new client", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: "foobar", CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, {"fail new x5c issuer", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "x5c", Provisioner: "X5C", Certificate: testX5CPath + ".missing", Key: testX5CKeyPath, }, }}, nil, true}, {"fail new jwk issuer", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Key: testX5CKeyPath + ".missing", }, }}, nil, true}, {"bad issuer", args{context.TODO(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: nil}}, nil, true}, {"bad issuer type", args{context.TODO(), 
apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, CertificateIssuer: &apiv1.CertificateIssuer{ Type: "fail", Provisioner: "X5C", Certificate: testX5CPath, Key: testX5CKeyPath, }, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } // We cannot compare neither the client nor the signer. if got != nil && tt.want != nil { got.client = tt.want.client if jwk, ok := got.iss.(*jwkIssuer); ok { jwk.signer = signer } } if !reflect.DeepEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) } } func TestStepCAS_Type(t *testing.T) { tests := []struct { name string want apiv1.Type }{ {"ok", apiv1.StepCAS}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &StepCAS{} if got := c.Type(); got != tt.want { t.Errorf("StepCAS.Type() = %v, want %v", got, tt.want) } }) } } func TestStepCAS_CreateCertificate(t *testing.T) { caURL, client := testCAHelper(t) x5c := testX5CIssuer(t, caURL, "") jwk := testJWKIssuer(t, caURL, "") x5cEnc := testX5CIssuer(t, caURL, testPassword) jwkEnc := testJWKIssuer(t, caURL, testPassword) x5cBad := testX5CIssuer(t, caURL, "bad-password") testTemplate := &x509.Certificate{ Subject: testCR.Subject, DNSNames: testCR.DNSNames, EmailAddresses: testCR.EmailAddresses, IPAddresses: testCR.IPAddresses, URIs: testCR.URIs, } testOtherCR, err := x509util.CreateCertificateRequest("Test Certificate", []string{"test.example.com"}, testKey) require.NoError(t, err) type fields struct { iss stepIssuer client *ca.Client fingerprint string } type args struct { req *apiv1.CreateCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.CreateCertificateResponse wantErr bool }{ {"ok", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: 
testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"ok with different CSR", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testOtherCR, Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"ok with password", fields{x5cEnc, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"ok jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"ok jwk with password", fields{jwkEnc, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: testTemplate, Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"ok with provisioner", fields{jwk, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: testTemplate, Lifetime: time.Hour, Provisioner: &apiv1.ProvisionerInfo{ID: "provisioner-id", Type: "ACME"}, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"ok with server cert", fields{jwk, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: testTemplate, Lifetime: time.Hour, IsCAServerCert: true, }}, &apiv1.CreateCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"fail CSR", fields{x5c, client, 
testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: nil, Template: testTemplate, Lifetime: time.Hour, }}, nil, true}, {"fail Template", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Template: nil, Lifetime: time.Hour, }}, nil, true}, {"fail lifetime", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Lifetime: 0, }}, nil, true}, {"fail sign token", fields{mockErrIssuer{}, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Lifetime: time.Hour, }}, nil, true}, {"fail client sign", fields{x5c, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testFailCR, Lifetime: time.Hour, }}, nil, true}, {"fail password", fields{x5cBad, client, testRootFingerprint}, args{&apiv1.CreateCertificateRequest{ CSR: testCR, Lifetime: time.Hour, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &StepCAS{ iss: tt.fields.iss, client: tt.fields.client, authorityID: "authority-id", fingerprint: tt.fields.fingerprint, } got, err := s.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("StepCAS.CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("StepCAS.CreateCertificate() = %v, want %v", got, tt.want) } }) } } func TestStepCAS_RenewCertificate(t *testing.T) { caURL, client := testCAHelper(t) jwk := testJWKIssuer(t, caURL, "") tokenIssuer := testX5CIssuer(t, caURL, "") token, err := tokenIssuer.SignToken("test", []string{"test.example.com"}, nil) if err != nil { t.Fatal(err) } type fields struct { iss stepIssuer client *ca.Client fingerprint string } type args struct { req *apiv1.RenewCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.RenewCertificateResponse wantErr bool }{ {"ok", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ Template: 
&x509.Certificate{}, Backdate: time.Minute, Lifetime: time.Hour, Token: token, }}, &apiv1.RenewCertificateResponse{ Certificate: testCrt, CertificateChain: []*x509.Certificate{testIssCrt}, }, false}, {"fail no token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ Template: &x509.Certificate{}, Backdate: time.Minute, Lifetime: time.Hour, }}, nil, true}, {"fail bad token", fields{jwk, client, testRootFingerprint}, args{&apiv1.RenewCertificateRequest{ Template: &x509.Certificate{}, Backdate: time.Minute, Lifetime: time.Hour, Token: "fail", }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &StepCAS{ iss: tt.fields.iss, client: tt.fields.client, fingerprint: tt.fields.fingerprint, } got, err := s.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("StepCAS.RenewCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Error(reflect.DeepEqual(got.Certificate, tt.want.Certificate)) t.Error(reflect.DeepEqual(got.CertificateChain, tt.want.CertificateChain)) t.Errorf("StepCAS.RenewCertificate() = %v, want %v", got.Certificate.Subject, tt.want.Certificate.Subject) } }) } } func TestStepCAS_RevokeCertificate(t *testing.T) { caURL, client := testCAHelper(t) x5c := testX5CIssuer(t, caURL, "") jwk := testJWKIssuer(t, caURL, "") x5cEnc := testX5CIssuer(t, caURL, testPassword) jwkEnc := testJWKIssuer(t, caURL, testPassword) x5cBad := testX5CIssuer(t, caURL, "bad-password") type fields struct { iss stepIssuer client *ca.Client fingerprint string } type args struct { req *apiv1.RevokeCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.RevokeCertificateResponse wantErr bool }{ {"ok serial number", fields{x5c, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: nil, }}, &apiv1.RevokeCertificateResponse{}, false}, {"ok certificate", fields{x5c, client, 
testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"ok both", fields{x5c, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"ok with password", fields{x5cEnc, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: nil, }}, &apiv1.RevokeCertificateResponse{}, false}, {"ok serial number jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: nil, }}, &apiv1.RevokeCertificateResponse{}, false}, {"ok certificate jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"ok both jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"ok jwk with password", fields{jwkEnc, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: nil, }}, &apiv1.RevokeCertificateResponse{}, false}, {"fail request", fields{x5c, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "", Certificate: nil, }}, nil, true}, {"fail revoke token", fields{mockErrIssuer{}, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", }}, nil, true}, {"fail client revoke", fields{x5c, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "fail", }}, nil, true}, {"fail password", fields{x5cBad, client, testRootFingerprint}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "ok", Certificate: nil, }}, nil, true}, } for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { s := &StepCAS{ iss: tt.fields.iss, client: tt.fields.client, fingerprint: tt.fields.fingerprint, } got, err := s.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("StepCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("StepCAS.RevokeCertificate() = %v, want %v", got, tt.want) } }) } } func TestStepCAS_GetCertificateAuthority(t *testing.T) { caURL, client := testCAHelper(t) x5c := testX5CIssuer(t, caURL, "") jwk := testJWKIssuer(t, caURL, "") type fields struct { iss stepIssuer client *ca.Client fingerprint string } type args struct { req *apiv1.GetCertificateAuthorityRequest } tests := []struct { name string fields fields args args want *apiv1.GetCertificateAuthorityResponse wantErr bool }{ {"ok", fields{x5c, client, testRootFingerprint}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: testRootCrt, }, false}, {"ok jwk", fields{jwk, client, testRootFingerprint}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: testRootCrt, }, false}, {"fail fingerprint", fields{x5c, client, "fail"}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &StepCAS{ iss: tt.fields.iss, client: tt.fields.client, fingerprint: tt.fields.fingerprint, } got, err := s.GetCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("StepCAS.GetCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("StepCAS.GetCertificateAuthority() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cas/stepcas/x5c_issuer.go ================================================ package stepcas import ( "crypto" 
"crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "net/url" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/cas/apiv1" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" ) const defaultValidity = 5 * time.Minute // timeNow returns the current time. // This method is used for unit testing purposes. var timeNow = time.Now type x5cIssuer struct { caURL *url.URL issuer string certFile string keyFile string password string } // newX5CIssuer create a new x5c token issuer. The given configuration should be // already validate. func newX5CIssuer(caURL *url.URL, cfg *apiv1.CertificateIssuer) (*x5cIssuer, error) { _, err := newX5CSigner(cfg.Certificate, cfg.Key, cfg.Password) if err != nil { return nil, err } return &x5cIssuer{ caURL: caURL, issuer: cfg.Provisioner, certFile: cfg.Certificate, keyFile: cfg.Key, password: cfg.Password, }, nil } func (i *x5cIssuer) SignToken(subject string, sans []string, info *raInfo) (string, error) { aud := i.caURL.ResolveReference(&url.URL{ Path: "/1.0/sign", Fragment: "x5c/" + i.issuer, }).String() return i.createToken(aud, subject, sans, info) } func (i *x5cIssuer) RevokeToken(subject string) (string, error) { aud := i.caURL.ResolveReference(&url.URL{ Path: "/1.0/revoke", Fragment: "x5c/" + i.issuer, }).String() return i.createToken(aud, subject, nil, nil) } func (i *x5cIssuer) Lifetime(d time.Duration) time.Duration { cert, err := pemutil.ReadCertificate(i.certFile, pemutil.WithFirstBlock()) if err != nil { return d } now := timeNow() if now.Add(d + time.Minute).After(cert.NotAfter) { return cert.NotAfter.Sub(now) - time.Minute } return d } func (i *x5cIssuer) createToken(aud, sub string, sans []string, info *raInfo) (string, error) { signer, err := newX5CSigner(i.certFile, i.keyFile, i.password) if err != nil { return "", err } id, err := randutil.Hex(64) // 256 bits if err != nil { return "", err } claims := defaultClaims(i.issuer, sub, aud, id) builder := jose.Signed(signer).Claims(claims) 
if len(sans) > 0 { builder = builder.Claims(map[string]interface{}{ "sans": sans, }) } if info != nil { builder = builder.Claims(map[string]interface{}{ "step": map[string]interface{}{ "ra": info, }, }) } tok, err := builder.CompactSerialize() if err != nil { return "", errors.Wrap(err, "error signing token") } return tok, nil } func defaultClaims(iss, sub, aud, id string) jose.Claims { now := timeNow() return jose.Claims{ ID: id, Issuer: iss, Subject: sub, Audience: jose.Audience{aud}, Expiry: jose.NewNumericDate(now.Add(defaultValidity)), NotBefore: jose.NewNumericDate(now), IssuedAt: jose.NewNumericDate(now), } } func readKey(keyFile, password string) (crypto.Signer, error) { var opts []pemutil.Options if password != "" { opts = append(opts, pemutil.WithPassword([]byte(password))) } key, err := pemutil.Read(keyFile, opts...) if err != nil { return nil, err } signer, ok := key.(crypto.Signer) if !ok { return nil, errors.New("key is not a crypto.Signer") } return signer, nil } func newX5CSigner(certFile, keyFile, password string) (jose.Signer, error) { signer, err := readKey(keyFile, password) if err != nil { return nil, err } kid, err := jose.Thumbprint(&jose.JSONWebKey{Key: signer.Public()}) if err != nil { return nil, err } certs, err := pemutil.ReadCertificateBundle(certFile) if err != nil { return nil, errors.Wrap(err, "error reading x5c certificate chain") } certStrs, err := jose.ValidateX5C(certs, signer) if err != nil { return nil, errors.Wrap(err, "error validating x5c certificate chain and key") } so := new(jose.SignerOptions) so.WithType("JWT") so.WithHeader("kid", kid) so.WithHeader("x5c", certStrs) return newJoseSigner(signer, so) } func newJoseSigner(key crypto.Signer, so *jose.SignerOptions) (jose.Signer, error) { var alg jose.SignatureAlgorithm switch k := key.Public().(type) { case *ecdsa.PublicKey: switch k.Curve.Params().Name { case "P-256": alg = jose.ES256 case "P-384": alg = jose.ES384 case "P-521": alg = jose.ES512 default: return nil, 
errors.Errorf("unsupported elliptic curve %s", k.Curve.Params().Name) } case ed25519.PublicKey: alg = jose.EdDSA case *rsa.PublicKey: alg = jose.DefaultRSASigAlgorithm default: return nil, errors.Errorf("unsupported key type %T", k) } signer, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: key}, so) if err != nil { return nil, errors.Wrap(err, "error creating jose.Signer") } return signer, nil } ================================================ FILE: cas/stepcas/x5c_issuer_test.go ================================================ package stepcas import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" "io" "net/url" "reflect" "testing" "time" "go.step.sm/crypto/jose" ) type noneSigner []byte func (b noneSigner) Public() crypto.PublicKey { return []byte(b) } func (b noneSigner) Sign(_ io.Reader, digest []byte, _ crypto.SignerOpts) (signature []byte, err error) { return digest, nil } //nolint:gocritic // ignore sloppy test func name func fakeTime(t *testing.T) { t.Helper() tmp := timeNow t.Cleanup(func() { timeNow = tmp }) timeNow = func() time.Time { return testX5CCrt.NotBefore } } func Test_x5cIssuer_SignToken(t *testing.T) { caURL, err := url.Parse("https://ca.smallstep.com") if err != nil { t.Fatal(err) } type fields struct { caURL *url.URL certFile string keyFile string issuer string } type args struct { subject string sans []string info *raInfo } type stepClaims struct { RA *raInfo `json:"ra"` } type claims struct { Aud jose.Audience `json:"aud"` Sub string `json:"sub"` Sans []string `json:"sans"` Step stepClaims `json:"step"` } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{"doe", []string{"doe.org"}, nil}, false}, {"ok ra", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{"doe", []string{"doe.org"}, &raInfo{ AuthorityID: "authority-id", ProvisionerID: "provisioner-id", ProvisionerType: "provisioner-type", }}, false}, 
{"ok ra endpoint id", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{"doe", []string{"doe.org"}, &raInfo{ AuthorityID: "authority-id", EndpointID: "endpoint-id", }}, false}, {"fail crt", fields{caURL, "", testX5CKeyPath, "X5C"}, args{"doe", []string{"doe.org"}, nil}, true}, {"fail key", fields{caURL, testX5CPath, "", "X5C"}, args{"doe", []string{"doe.org"}, nil}, true}, {"fail no signer", fields{caURL, testIssKeyPath, testIssPath, "X5C"}, args{"doe", []string{"doe.org"}, nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &x5cIssuer{ caURL: tt.fields.caURL, certFile: tt.fields.certFile, keyFile: tt.fields.keyFile, issuer: tt.fields.issuer, } got, err := i.SignToken(tt.args.subject, tt.args.sans, tt.args.info) if (err != nil) != tt.wantErr { t.Errorf("x5cIssuer.SignToken() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr { jwt, err := jose.ParseSigned(got) if err != nil { t.Errorf("jose.ParseSigned() error = %v", err) } var c claims want := claims{ Aud: []string{tt.fields.caURL.String() + "/1.0/sign#x5c/X5C"}, Sub: tt.args.subject, Sans: tt.args.sans, } if tt.args.info != nil { want.Step.RA = tt.args.info } if err := jwt.Claims(testX5CKey.Public(), &c); err != nil { t.Errorf("jwt.Claims() error = %v", err) } if !reflect.DeepEqual(c, want) { t.Errorf("jwt.Claims() claims = %#v, want %#v", c, want) } } }) } } func Test_x5cIssuer_RevokeToken(t *testing.T) { caURL, err := url.Parse("https://ca.smallstep.com") if err != nil { t.Fatal(err) } type fields struct { caURL *url.URL certFile string keyFile string issuer string } type args struct { subject string } type claims struct { Aud jose.Audience `json:"aud"` Sub string `json:"sub"` Sans []string `json:"sans"` } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{"doe"}, false}, {"fail crt", fields{caURL, "", testX5CKeyPath, "X5C"}, args{"doe"}, true}, {"fail key", fields{caURL, testX5CPath, 
"", "X5C"}, args{"doe"}, true}, {"fail no signer", fields{caURL, testIssKeyPath, testIssPath, "X5C"}, args{"doe"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &x5cIssuer{ caURL: tt.fields.caURL, certFile: tt.fields.certFile, keyFile: tt.fields.keyFile, issuer: tt.fields.issuer, } got, err := i.RevokeToken(tt.args.subject) if (err != nil) != tt.wantErr { t.Errorf("x5cIssuer.RevokeToken() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr { jwt, err := jose.ParseSigned(got) if err != nil { t.Errorf("jose.ParseSigned() error = %v", err) } var c claims want := claims{ Aud: []string{tt.fields.caURL.String() + "/1.0/revoke#x5c/X5C"}, Sub: tt.args.subject, } if err := jwt.Claims(testX5CKey.Public(), &c); err != nil { t.Errorf("jwt.Claims() error = %v", err) } if !reflect.DeepEqual(c, want) { t.Errorf("jwt.Claims() claims = %#v, want %#v", c, want) } } }) } } func Test_x5cIssuer_Lifetime(t *testing.T) { fakeTime(t) caURL, err := url.Parse("https://ca.smallstep.com") if err != nil { t.Fatal(err) } // With a leeway of 1m the max duration will be 59m. 
maxDuration := testX5CCrt.NotAfter.Sub(timeNow()) - time.Minute type fields struct { caURL *url.URL certFile string keyFile string issuer string } type args struct { d time.Duration } tests := []struct { name string fields fields args args want time.Duration }{ {"ok 0s", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{0}, 0}, {"ok 1m", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{time.Minute}, time.Minute}, {"ok max-1m", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{maxDuration - time.Minute}, maxDuration - time.Minute}, {"ok max", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{maxDuration}, maxDuration}, {"ok max+1m", fields{caURL, testX5CPath, testX5CKeyPath, "X5C"}, args{maxDuration + time.Minute}, maxDuration}, {"ok fail", fields{caURL, testX5CPath + ".missing", testX5CKeyPath, "X5C"}, args{maxDuration + time.Minute}, maxDuration + time.Minute}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { i := &x5cIssuer{ caURL: tt.fields.caURL, certFile: tt.fields.certFile, keyFile: tt.fields.keyFile, issuer: tt.fields.issuer, } if got := i.Lifetime(tt.args.d); got != tt.want { t.Errorf("x5cIssuer.Lifetime() = %v, want %v", got, tt.want) } }) } } func Test_newJoseSigner(t *testing.T) { mustSigner := func(args ...interface{}) crypto.Signer { if err := args[len(args)-1]; err != nil { t.Fatal(err) } for _, a := range args { if s, ok := a.(crypto.Signer); ok { return s } } t.Fatal("signer not found") return nil } p224 := mustSigner(ecdsa.GenerateKey(elliptic.P224(), rand.Reader)) p256 := mustSigner(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) p384 := mustSigner(ecdsa.GenerateKey(elliptic.P384(), rand.Reader)) p521 := mustSigner(ecdsa.GenerateKey(elliptic.P521(), rand.Reader)) edKey := mustSigner(ed25519.GenerateKey(rand.Reader)) rsaKey := mustSigner(rsa.GenerateKey(rand.Reader, 2048)) type args struct { key crypto.Signer so *jose.SignerOptions } tests := []struct { name string args args want []jose.Header wantErr bool 
}{ {"p256", args{p256, nil}, []jose.Header{{Algorithm: "ES256"}}, false}, {"p384", args{p384, new(jose.SignerOptions).WithType("JWT")}, []jose.Header{{Algorithm: "ES384", ExtraHeaders: map[jose.HeaderKey]interface{}{"typ": "JWT"}}}, false}, {"p521", args{p521, new(jose.SignerOptions).WithHeader("kid", "the-kid")}, []jose.Header{{Algorithm: "ES512", KeyID: "the-kid"}}, false}, {"ed25519", args{edKey, nil}, []jose.Header{{Algorithm: "EdDSA"}}, false}, {"rsa", args{rsaKey, nil}, []jose.Header{{Algorithm: "RS256"}}, false}, {"fail p224", args{p224, nil}, nil, true}, {"fail signer", args{noneSigner{1, 2, 3}, nil}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := newJoseSigner(tt.args.key, tt.args.so) if (err != nil) != tt.wantErr { t.Errorf("newJoseSigner() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr { jws, err := got.Sign([]byte("{}")) if err != nil { t.Errorf("jose.Signer.Sign() err = %v", err) } jwt, err := jose.ParseSigned(jws.FullSerialize()) if err != nil { t.Errorf("jose.ParseSigned() err = %v", err) } if !reflect.DeepEqual(jwt.Headers, tt.want) { t.Errorf("jose.Header got = %v, want = %v", jwt.Headers, tt.want) } } }) } } ================================================ FILE: cas/vaultcas/auth/approle/approle.go ================================================ package approle import ( "encoding/json" "errors" "fmt" "github.com/hashicorp/vault/api/auth/approle" ) // AuthOptions defines the configuration options added using the // VaultOptions.AuthOptions field when AuthType is approle type AuthOptions struct { RoleID string `json:"roleID,omitempty"` SecretID string `json:"secretID,omitempty"` SecretIDFile string `json:"secretIDFile,omitempty"` SecretIDEnv string `json:"secretIDEnv,omitempty"` IsWrappingToken bool `json:"isWrappingToken,omitempty"` } func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.AppRoleAuth, error) { var opts *AuthOptions err := json.Unmarshal(options, 
&opts) if err != nil { return nil, fmt.Errorf("error decoding AppRole auth options: %w", err) } var approleAuth *approle.AppRoleAuth var loginOptions []approle.LoginOption if mountPath != "" { loginOptions = append(loginOptions, approle.WithMountPath(mountPath)) } if opts.IsWrappingToken { loginOptions = append(loginOptions, approle.WithWrappingToken()) } if opts.RoleID == "" { return nil, errors.New("you must set roleID") } var sid approle.SecretID switch { case opts.SecretID != "" && opts.SecretIDFile == "" && opts.SecretIDEnv == "": sid = approle.SecretID{ FromString: opts.SecretID, } case opts.SecretIDFile != "" && opts.SecretID == "" && opts.SecretIDEnv == "": sid = approle.SecretID{ FromFile: opts.SecretIDFile, } case opts.SecretIDEnv != "" && opts.SecretIDFile == "" && opts.SecretID == "": sid = approle.SecretID{ FromEnv: opts.SecretIDEnv, } default: return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv") } approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...) 
if err != nil { return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) } return approleAuth, nil } ================================================ FILE: cas/vaultcas/auth/approle/approle_test.go ================================================ package approle import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "testing" vault "github.com/hashicorp/vault/api" ) func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.RequestURI { case "/v1/auth/approle/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "hvs.0000" } }`) case "/v1/auth/custom-approle/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "hvs.9999" } }`) default: w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"not found"}`) } })) t.Cleanup(func() { srv.Close() }) u, err := url.Parse(srv.URL) if err != nil { srv.Close() t.Fatal(err) } config := vault.DefaultConfig() config.Address = srv.URL client, err := vault.NewClient(config) if err != nil { srv.Close() t.Fatal(err) } return u, client } func TestApprole_LoginMountPaths(t *testing.T) { caURL, _ := testCAHelper(t) config := vault.DefaultConfig() config.Address = caURL.String() client, _ := vault.NewClient(config) tests := []struct { name string mountPath string token string }{ { name: "ok default mount path", mountPath: "", token: "hvs.0000", }, { name: "ok explicit mount path", mountPath: "approle", token: "hvs.0000", }, { name: "ok custom mount path", mountPath: "custom-approle", token: "hvs.9999", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { method, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`)) if err != nil { t.Errorf("NewApproleAuthMethod() error = %v", err) return } secret, err := 
client.Auth().Login(context.Background(), method) if err != nil { t.Errorf("Login() error = %v", err) return } token, _ := secret.TokenID() if token != tt.token { t.Errorf("Token error got %v, expected %v", token, tt.token) return } }) } } func TestApprole_NewApproleAuthMethod(t *testing.T) { tests := []struct { name string mountPath string raw string wantErr bool }{ { "ok secret-id string", "", `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000"}`, false, }, { "ok secret-id string and wrapped", "", `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, false, }, { "ok secret-id string and wrapped with custom mountPath", "approle2", `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, false, }, { "ok secret-id file", "", `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, false, }, { "ok secret-id env", "", `{"RoleID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, false, }, { "fail mandatory role-id", "", `{}`, true, }, { "fail mandatory secret-id any", "", `{"RoleID": "0000-0000-0000-0000"}`, true, }, { "fail multiple secret-id types id and env", "", `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, true, }, { "fail multiple secret-id types id and file", "", `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, true, }, { "fail multiple secret-id types env and file", "", `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, true, }, { "fail multiple secret-id types all", "", `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) if (err 
!= nil) != tt.wantErr { t.Errorf("Approle.NewApproleAuthMethod() error = %v, wantErr %v", err, tt.wantErr) return } }) } } ================================================ FILE: cas/vaultcas/auth/aws/aws.go ================================================ package aws import ( "encoding/json" "fmt" "github.com/hashicorp/vault/api/auth/aws" ) // AuthOptions defines the configuration options added using the // VaultOptions.AuthOptions field when AuthType is aws. // This maps directly to Vault's AWS Login options, // see: https://developer.hashicorp.com/vault/api-docs/auth/aws#login type AuthOptions struct { Role string `json:"role,omitempty"` Region string `json:"region,omitempty"` AwsAuthType string `json:"awsAuthType,omitempty"` // options specific to 'iam' auth type IamServerIDHeader string `json:"iamServerIdHeader"` // options specific to 'ec2' auth type SignatureType string `json:"signatureType,omitempty"` Nonce string `json:"nonce,omitempty"` } func NewAwsAuthMethod(mountPath string, options json.RawMessage) (*aws.AWSAuth, error) { var opts *AuthOptions err := json.Unmarshal(options, &opts) if err != nil { return nil, fmt.Errorf("error decoding AWS auth options: %w", err) } var awsAuth *aws.AWSAuth var loginOptions []aws.LoginOption if mountPath != "" { loginOptions = append(loginOptions, aws.WithMountPath(mountPath)) } if opts.Role != "" { loginOptions = append(loginOptions, aws.WithRole(opts.Role)) } if opts.Region != "" { loginOptions = append(loginOptions, aws.WithRegion(opts.Region)) } switch opts.AwsAuthType { case "iam": loginOptions = append(loginOptions, aws.WithIAMAuth()) if opts.IamServerIDHeader != "" { loginOptions = append(loginOptions, aws.WithIAMServerIDHeader(opts.IamServerIDHeader)) } case "ec2": loginOptions = append(loginOptions, aws.WithEC2Auth()) switch opts.SignatureType { case "pkcs7": loginOptions = append(loginOptions, aws.WithPKCS7Signature()) case "identity": loginOptions = append(loginOptions, aws.WithIdentitySignature()) case 
"rsa2048": loginOptions = append(loginOptions, aws.WithRSA2048Signature()) case "": // no-op default: return nil, fmt.Errorf("unknown SignatureType type %q; valid options are 'pkcs7', 'identity' and 'rsa2048'", opts.SignatureType) } if opts.Nonce != "" { loginOptions = append(loginOptions, aws.WithNonce(opts.Nonce)) } default: return nil, fmt.Errorf("unknown awsAuthType %q; valid options are 'iam' and 'ec2'", opts.AwsAuthType) } awsAuth, err = aws.NewAWSAuth(loginOptions...) if err != nil { return nil, fmt.Errorf("unable to initialize AWS auth method: %w", err) } return awsAuth, nil } ================================================ FILE: cas/vaultcas/auth/aws/aws_test.go ================================================ package aws import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "testing" vault "github.com/hashicorp/vault/api" ) func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.RequestURI { case "/v1/auth/aws/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "hvs.0000" } }`) case "/v1/auth/custom-aws/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "hvs.9999" } }`) default: w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"not found"}`) } })) t.Cleanup(func() { srv.Close() }) u, err := url.Parse(srv.URL) if err != nil { srv.Close() t.Fatal(err) } config := vault.DefaultConfig() config.Address = srv.URL client, err := vault.NewClient(config) if err != nil { srv.Close() t.Fatal(err) } return u, client } func TestAws_LoginMountPaths(t *testing.T) { _, client := testCAHelper(t) // Dummy AWS credentials is needed for Vault client to sign the STS request t.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") t.Setenv("AWS_SECRET_ACCESS_KEY", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY") tests := []struct { name string mountPath string token string }{ 
{ name: "ok default mount path", mountPath: "", token: "hvs.0000", }, { name: "ok explicit mount path", mountPath: "aws", token: "hvs.0000", }, { name: "ok custom mount path", mountPath: "custom-aws", token: "hvs.9999", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { method, err := NewAwsAuthMethod(tt.mountPath, json.RawMessage(`{"role":"test-role","awsAuthType":"iam"}`)) if err != nil { t.Errorf("NewAwsAuthMethod() error = %v", err) return } secret, err := client.Auth().Login(context.Background(), method) if err != nil { t.Errorf("Login() error = %v", err) return } token, _ := secret.TokenID() if token != tt.token { t.Errorf("Token error got %v, expected %v", token, tt.token) return } }) } } func TestAws_NewAwsAuthMethod(t *testing.T) { tests := []struct { name string mountPath string raw string wantErr bool }{ { "ok iam", "", `{"role":"test-role","awsAuthType":"iam"}`, false, }, { "ok iam with region", "", `{"role":"test-role","awsAuthType":"iam","region":"us-east-1"}`, false, }, { "ok iam with header", "", `{"role":"test-role","awsAuthType":"iam","iamServerIdHeader":"vault.example.com"}`, false, }, { "ok ec2", "", `{"role":"test-role","awsAuthType":"ec2"}`, false, }, { "ok ec2 with nonce", "", `{"role":"test-role","awsAuthType":"ec2","nonce": "0000-0000-0000-0000"}`, false, }, { "ok ec2 with signature type", "", `{"role":"test-role","awsAuthType":"ec2","signatureType":"rsa2048"}`, false, }, { "fail mandatory role", "", `{}`, true, }, { "fail mandatory auth type", "", `{"role":"test-role"}`, true, }, { "fail invalid auth type", "", `{"role":"test-role","awsAuthType":"test"}`, true, }, { "fail invalid ec2 signature type", "", `{"role":"test-role","awsAuthType":"test","signatureType":"test"}`, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := NewAwsAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) if (err != nil) != tt.wantErr { t.Errorf("Aws.NewAwsAuthMethod() error = %v, wantErr %v", err, tt.wantErr) return } 
}) } } ================================================ FILE: cas/vaultcas/auth/kubernetes/kubernetes.go ================================================ package kubernetes import ( "encoding/json" "errors" "fmt" "github.com/hashicorp/vault/api/auth/kubernetes" ) // AuthOptions defines the configuration options added using the // VaultOptions.AuthOptions field when AuthType is kubernetes type AuthOptions struct { Role string `json:"role,omitempty"` TokenPath string `json:"tokenPath,omitempty"` } func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubernetes.KubernetesAuth, error) { var opts *AuthOptions err := json.Unmarshal(options, &opts) if err != nil { return nil, fmt.Errorf("error decoding Kubernetes auth options: %w", err) } var kubernetesAuth *kubernetes.KubernetesAuth var loginOptions []kubernetes.LoginOption if mountPath != "" { loginOptions = append(loginOptions, kubernetes.WithMountPath(mountPath)) } if opts.TokenPath != "" { loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath)) } if opts.Role == "" { return nil, errors.New("you must set role") } kubernetesAuth, err = kubernetes.NewKubernetesAuth( opts.Role, loginOptions..., ) if err != nil { return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) } return kubernetesAuth, nil } ================================================ FILE: cas/vaultcas/auth/kubernetes/kubernetes_test.go ================================================ package kubernetes import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "path" "path/filepath" "runtime" "testing" vault "github.com/hashicorp/vault/api" ) func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.RequestURI { case "/v1/auth/kubernetes/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "hvs.0000" } }`) case 
"/v1/auth/custom-kubernetes/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "hvs.9999" } }`) default: w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"not found"}`) } })) t.Cleanup(func() { srv.Close() }) u, err := url.Parse(srv.URL) if err != nil { srv.Close() t.Fatal(err) } config := vault.DefaultConfig() config.Address = srv.URL client, err := vault.NewClient(config) if err != nil { srv.Close() t.Fatal(err) } return u, client } func TestApprole_LoginMountPaths(t *testing.T) { caURL, _ := testCAHelper(t) _, filename, _, _ := runtime.Caller(0) tokenPath := filepath.Join(path.Dir(filename), "token") config := vault.DefaultConfig() config.Address = caURL.String() client, _ := vault.NewClient(config) tests := []struct { name string mountPath string token string }{ { name: "ok default mount path", mountPath: "", token: "hvs.0000", }, { name: "ok explicit mount path", mountPath: "kubernetes", token: "hvs.0000", }, { name: "ok custom mount path", mountPath: "custom-kubernetes", token: "hvs.9999", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { method, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(`{"role": "SomeRoleName", "tokenPath": "`+tokenPath+`"}`)) if err != nil { t.Errorf("NewApproleAuthMethod() error = %v", err) return } secret, err := client.Auth().Login(context.Background(), method) if err != nil { t.Errorf("Login() error = %v", err) return } token, _ := secret.TokenID() if token != tt.token { t.Errorf("Token error got %v, expected %v", token, tt.token) return } }) } } func TestApprole_NewApproleAuthMethod(t *testing.T) { _, filename, _, _ := runtime.Caller(0) tokenPath := filepath.Join(path.Dir(filename), "token") tests := []struct { name string mountPath string raw string wantErr bool }{ { "ok secret-id string", "", `{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}`, false, }, { "fail mandatory role", "", `{}`, true, }, } for _, tt := range tests { t.Run(tt.name, func(t 
// VaultOptions defines the configuration options added using the
// apiv1.Options.Config field.
type VaultOptions struct {
	// PKIMountPath is the path prefix used for the PKI requests (reading the
	// CA chain, signing and revoking). loadOptions defaults it to "pki".
	PKIMountPath string `json:"pkiMountPath,omitempty"`
	// PKIRoleDefault is the role used for any key type without a specific
	// role. loadOptions defaults it to "default".
	PKIRoleDefault string `json:"pkiRoleDefault,omitempty"`
	// PKIRoleRSA is the role used to sign CSRs with RSA keys. It falls back to
	// PKIRoleDefault when empty.
	PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
	// PKIRoleEC is the role used to sign CSRs with ECDSA keys. It falls back
	// to PKIRoleDefault when empty.
	PKIRoleEC string `json:"pkiRoleEC,omitempty"`
	// PKIRoleEd25519 is the role used to sign CSRs with Ed25519 keys. It falls
	// back to PKIRoleDefault when empty.
	PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
	// AuthType selects the auth method used to log in; New accepts
	// "kubernetes", "approle" and "aws".
	AuthType string `json:"authType,omitempty"`
	// AuthMountPath is passed to the selected auth method, which only
	// overrides its mount path when this value is not empty.
	AuthMountPath string `json:"authMountPath,omitempty"`
	// Namespace is set on the Vault client before logging in when not empty.
	Namespace string `json:"namespace,omitempty"`
	// AuthOptions holds the raw JSON options decoded by the selected auth
	// method.
	AuthOptions json.RawMessage `json:"authOptions,omitempty"`
}
type VaultCAS struct { client *vault.Client config VaultOptions fingerprint string } type certBundle struct { leaf *x509.Certificate intermediates []*x509.Certificate root *x509.Certificate } // New creates a new CertificateAuthorityService implementation // using Hashicorp Vault func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { if opts.CertificateAuthority == "" { return nil, errors.New("vaultCAS 'certificateAuthority' cannot be empty") } if opts.CertificateAuthorityFingerprint == "" { return nil, errors.New("vaultCAS 'certificateAuthorityFingerprint' cannot be empty") } vc, err := loadOptions(opts.Config) if err != nil { return nil, err } config := vault.DefaultConfig() config.Address = opts.CertificateAuthority client, err := vault.NewClient(config) if err != nil { return nil, fmt.Errorf("unable to initialize vault client: %w", err) } var method vault.AuthMethod switch vc.AuthType { case "kubernetes": method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions) case "approle": method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions) case "aws": method, err = aws.NewAwsAuthMethod(vc.AuthMountPath, vc.AuthOptions) default: return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType) } if err != nil { return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err) } if vc.Namespace != "" { client.SetNamespace(vc.Namespace) } authInfo, err := client.Auth().Login(ctx, method) if err != nil { return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err) } if authInfo == nil { return nil, errors.New("no auth info was returned after login") } return &VaultCAS{ client: client, config: *vc, fingerprint: opts.CertificateAuthorityFingerprint, }, nil } // Type returns the type of this CertificateAuthorityService. 
// GetCertificateAuthority reads the CA chain from Vault and returns its root
// and intermediate certificates. It fails if the response is empty, has no
// certificate, contains no root certificate, or if the SHA-256 fingerprint of
// the root does not match the configured one (compared case-insensitively).
func (v *VaultCAS) GetCertificateAuthority(*apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
	chainPath := v.config.PKIMountPath + "/cert/ca_chain"
	resp, err := v.client.Logical().Read(chainPath)
	switch {
	case err != nil:
		return nil, fmt.Errorf("error reading ca chain: %w", err)
	case resp == nil:
		return nil, errors.New("error reading ca chain: response is empty")
	}

	pemChain, isString := resp.Data["certificate"].(string)
	if !isString {
		return nil, errors.New("error unmarshaling vault response: certificate not found")
	}

	bundle, err := getCertificateBundle(pemChain)
	if err != nil {
		return nil, err
	}
	if bundle.root == nil {
		return nil, errors.New("error unmarshaling vault response: root certificate not found")
	}

	// EqualFold already ignores case, so the hex digest needs no normalization.
	digest := sha256.Sum256(bundle.root.Raw)
	actual := hex.EncodeToString(digest[:])
	if !strings.EqualFold(v.fingerprint, actual) {
		return nil, errors.New("error verifying vault root: fingerprint does not match")
	}

	return &apiv1.GetCertificateAuthorityResponse{
		RootCertificate:          bundle.root,
		IntermediateCertificates: bundle.intermediates,
	}, nil
}
// RevokeCertificate revokes a certificate in Vault. The serial number is taken
// from req.SerialNumber (a base-10 string) when set, otherwise from
// req.Certificate; at least one of them is required. The response echoes the
// certificate given in the request.
func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
	var serial *big.Int
	switch {
	case req.SerialNumber != "":
		parsed, ok := new(big.Int).SetString(req.SerialNumber, 10)
		if !ok {
			return nil, fmt.Errorf("error parsing serialNumber: %v cannot be converted to big.Int", req.SerialNumber)
		}
		serial = parsed
	case req.Certificate != nil:
		serial = req.Certificate.SerialNumber
	default:
		return nil, errors.New("revokeCertificate `serialNumber` or `certificate` are required")
	}

	payload := map[string]interface{}{
		"serial_number": formatSerialNumber(serial),
	}
	if _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", payload); err != nil {
		return nil, fmt.Errorf("error revoking certificate: %w", err)
	}

	return &apiv1.RevokeCertificateResponse{
		Certificate:      req.Certificate,
		CertificateChain: nil,
	}, nil
}
// parseCertificates decodes the PEM blocks in pemCert and returns the X.509
// certificates found, in order.
//
// PEM blocks whose type is not "CERTIFICATE" are skipped, so other material in
// a bundle does not hide the certificates that follow it. Parsing stops at the
// first certificate block that cannot be parsed, or when no more PEM data is
// found; the certificates collected so far are returned. The result is nil
// when no certificate is found.
func parseCertificates(pemCert string) []*x509.Certificate {
	var certs []*x509.Certificate
	rest := []byte(pemCert)
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			return certs
		}
		if block.Type != "CERTIFICATE" {
			// Not a certificate: ignore it and keep scanning.
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			// A malformed certificate ends the scan; there is no error return
			// to report it through, so return what was parsed so far.
			return certs
		}
		certs = append(certs, cert)
	}
}
func formatSerialNumber(sn *big.Int) string { var ret bytes.Buffer for _, b := range sn.Bytes() { if ret.Len() > 0 { ret.WriteString("-") } ret.WriteString(hex.EncodeToString([]byte{b})) } return ret.String() } ================================================ FILE: cas/vaultcas/vaultcas_test.go ================================================ package vaultcas import ( "bytes" "context" "crypto/x509" "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" "reflect" "testing" "time" vault "github.com/hashicorp/vault/api" "github.com/smallstep/certificates/cas/apiv1" "go.step.sm/crypto/pemutil" ) var ( testCertificateSigned = `-----BEGIN CERTIFICATE----- MIIB/DCCAaKgAwIBAgIQHHFuGMz0cClfde5kqP5prTAKBggqhkjOPQQDAjAqMSgw JgYDVQQDEx9Hb29nbGUgQ0FTIFRlc3QgSW50ZXJtZWRpYXRlIENBMB4XDTIwMDkx NTAwMDQ0M1oXDTMwMDkxMzAwMDQ0MFowHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0 ZXAuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMqNCiXMvbn74LsHzRv+8 17m9vEzH6RHrg3m82e0uEc36+fZWV/zJ9SKuONmnl5VP79LsjL5SVH0RDj73U2XO DKOBtjCBszAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG AQUFBwMCMB0GA1UdDgQWBBRTA2cTs7PCNjnps/+T0dS8diqv0DAfBgNVHSMEGDAW gBRIOVqyLDSlErJLuWWEvRm5UU1r1TBCBgwrBgEEAYKkZMYoQAIEMjAwEwhjbG91 ZGNhcxMkZDhkMThhNjgtNTI5Ni00YWYzLWFlNGItMmY4NzdkYTNmYmQ5MAoGCCqG SM49BAMCA0gAMEUCIGxl+pqJ50WYWUqK2l4V1FHoXSi0Nht5kwTxFxnWZu1xAiEA zemu3bhWLFaGg3s8i+HTEhw4RqkHP74vF7AVYp88bAw= -----END CERTIFICATE-----` testCertificateCsrEc = `-----BEGIN CERTIFICATE REQUEST----- MIHoMIGPAgEAMA0xCzAJBgNVBAMTAkVDMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD QgAEUVVVZGD6eUrB20T/qrjKZoYzseQ18AIm9jtUNpQn5hIClpdk2zKy5bja3iUa nmqRKCIz/B/MU55zuNDeckqqX6AgMB4GCSqGSIb3DQEJDjERMA8wDQYDVR0RBAYw BIICRUMwCgYIKoZIzj0EAwIDSAAwRQIhAJxpWyH7cctbzcnK1JBWDAmc/G61bq9y otHrQDfYvS8bAiBVGQz2cfO2SqhvkkQbOqWUFjk1wHzISvlTjyc3IJ7FLw== -----END CERTIFICATE REQUEST-----` testCertificateCsrRsa = `-----BEGIN CERTIFICATE REQUEST----- MIICdDCCAVwCAQAwDjEMMAoGA1UEAxMDUlNBMIIBIjANBgkqhkiG9w0BAQEFAAOC AQ8AMIIBCgKCAQEAxe5XLSZrTCzzH0FJCXvZwghAY5XztzjseSRcm0jL8Q7nvNWi 
Vpu1n7EmfVU9b8sbvtVYqMQV+hMdj2C/NIw4Yal4Wg+BgunYOrRqfY7oDm4csG0R g5v0h2yQw14kqVrftNyojX0Nv/CPboCGl64PA9zsEXQTB3Y1AUWrUGPiBWNACYIH mjv70Ay9JKBBAqov38I7nka/RgYAl5DCHzU2vvODriBYFWagnzycA4Ni5EKTz93W SPdDEhkWi3ugUqal3SvgHl8re+8d7ghLn85Y3TFuyU2nSMDPHaymsiNFw1mRwOw3 lAseidHJkPQs7q6FiYXaeqetf1j/gw0n23ZogwIDAQABoCEwHwYJKoZIhvcNAQkO MRIwEDAOBgNVHREEBzAFggNSU0EwDQYJKoZIhvcNAQELBQADggEBALnO5vcDkgGO GQoSINa2NmNFxAtYQGYHok5KXYX+S+etmOmDrmrhsl/pSjN3GPCPlThFlbLStB70 oJw67nEjGf0hPEBVlm+qFUsYQ1KGRZFAWDSMQ//pU225XFDCmlzHfV7gZjSkP9GN Gc5VECOzx6hAFR+IEL/l/1GG5HHkPPrr/8OvuIfm2V5ofYmhsXMVVYH52qPofMAV B8UdNnZK3nyLdUqVd+PYUUJmN4bJ8YfxofKKgbLkhvkKp4OZ9vkwUi2+61NdHTf2 wIauOyxEoTlJpU6oA/sxu/2Ht2DP+8y6mognLBuKklE/VH3/2iqQWyg1NV5hyg3b loVSdLsIh5Y= -----END CERTIFICATE REQUEST-----` testCertificateCsrEd25519 = `-----BEGIN CERTIFICATE REQUEST----- MIGuMGICAQAwDjEMMAoGA1UEAxMDT0tQMCowBQYDK2VwAyEAopc6daK4zYR6BDAM pV/v53oR/ewbtrkHZQkN/amFMLagITAfBgkqhkiG9w0BCQ4xEjAQMA4GA1UdEQQH MAWCA09LUDAFBgMrZXADQQDJi47MAgl/WKAz+V/kDu1k/zbKk1nrHHAUonbofHUW M6ihSD43+awq3BPeyPbToeH5orSH9l3MuTfbxPb5BVEH -----END CERTIFICATE REQUEST-----` testRootCertificate = `-----BEGIN CERTIFICATE----- MIIBeDCCAR+gAwIBAgIQcXWWjtSZ/PAyH8D1Ou4L9jAKBggqhkjOPQQDAjAbMRkw FwYDVQQDExBDbG91ZENBUyBSb290IENBMB4XDTIwMTAyNzIyNTM1NFoXDTMwMTAy NzIyNTM1NFowGzEZMBcGA1UEAxMQQ2xvdWRDQVMgUm9vdCBDQTBZMBMGByqGSM49 AgEGCCqGSM49AwEHA0IABIySHA4b78Yu4LuGhZIlv/PhNwXz4ZoV1OUZQ0LrK3vj B13O12DLZC5uj1z3kxdQzXUttSbtRv49clMpBiTpsZKjRTBDMA4GA1UdDwEB/wQE AwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSZ+t9RMHbFTl5BatM3 5bJlHPOu3DAKBggqhkjOPQQDAgNHADBEAiASah6gg0tVM3WI0meCQ4SEKk7Mjhbv +SmhuZHWV1QlXQIgRXNyWcpVUrAoG6Uy1KQg07LDpF5dFeK9InrDxSJAkVo= -----END CERTIFICATE-----` testRootFingerprint = `62e816cbac5c501b7705e18415503852798dfbcd67062f06bcb4af67c290e3c8` ) func mustParseCertificate(t *testing.T, pemCert string) *x509.Certificate { t.Helper() crt := parseCertificates(pemCert)[0] return crt } func mustParseCertificateRequest(t *testing.T, pemData string) 
*x509.CertificateRequest { t.Helper() csr, err := pemutil.ParseCertificateRequest([]byte(pemData)) if err != nil { t.Fatal(err) } return csr } func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { t.Helper() writeJSON := func(w http.ResponseWriter, v interface{}) { _ = json.NewEncoder(w).Encode(v) } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.RequestURI { case "/v1/auth/approle/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { "client_token": "98a4c7ab-b1fe-361b-ba0b-e307aacfd587" } }`) case "/v1/pki/sign/ec": w.WriteHeader(http.StatusOK) cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case "/v1/pki/sign/rsa": w.WriteHeader(http.StatusOK) cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case "/v1/pki/sign/ed25519": w.WriteHeader(http.StatusOK) cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case "/v1/pki/cert/ca_chain": w.WriteHeader(http.StatusOK) cert := map[string]interface{}{"data": map[string]interface{}{"certificate": testCertificateSigned + "\n" + testRootCertificate}} writeJSON(w, cert) return case "/v1/pki/revoke": buf := new(bytes.Buffer) buf.ReadFrom(r.Body) m := make(map[string]string) json.Unmarshal(buf.Bytes(), &m) switch m["serial_number"] { case "1c-71-6e-18-cc-f4-70-29-5f-75-ee-64-a8-fe-69-ad": w.WriteHeader(http.StatusOK) return case "01-e2-40": w.WriteHeader(http.StatusOK) return // both case "01-34-3e": w.WriteHeader(http.StatusOK) return default: w.WriteHeader(http.StatusNotFound) } default: w.WriteHeader(http.StatusNotFound) fmt.Fprintf(w, `{"error":"not found"}`) } })) t.Cleanup(func() { srv.Close() }) u, err := url.Parse(srv.URL) if err != nil { 
srv.Close() t.Fatal(err) } config := vault.DefaultConfig() config.Address = srv.URL client, err := vault.NewClient(config) if err != nil { srv.Close() t.Fatal(err) } return u, client } func TestNew_register(t *testing.T) { caURL, _ := testCAHelper(t) fn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.VaultCAS) if !ok { t.Errorf("apiv1.Register() ok = %v, want true", ok) return } _, err := fn(context.Background(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, Config: json.RawMessage(`{ "AuthType": "approle", "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false} }`), }) if err != nil { t.Errorf("New() error = %v", err) return } } func TestVaultCAS_Type(t *testing.T) { tests := []struct { name string want apiv1.Type }{ {"ok", apiv1.VaultCAS}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &VaultCAS{} if got := c.Type(); got != tt.want { t.Errorf("VaultCAS.Type() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_CreateCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", } type fields struct { client *vault.Client options VaultOptions } type args struct { req *apiv1.CreateCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.CreateCertificateResponse wantErr bool }{ {"ok ec", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEc), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), CertificateChain: nil, }, false}, {"ok rsa", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrRsa), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, 
testCertificateSigned), CertificateChain: nil, }, false}, {"ok ed25519", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEd25519), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), CertificateChain: nil, }, false}, {"fail CSR", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: nil, Lifetime: time.Hour, }}, nil, true}, {"fail lifetime", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEc), Lifetime: 0, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &VaultCAS{ client: tt.fields.client, config: tt.fields.options, } got, err := c.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.CreateCertificate() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_GetCertificateAuthority(t *testing.T) { caURL, client := testCAHelper(t) type fields struct { client *vault.Client options VaultOptions fingerprint string } type args struct { req *apiv1.GetCertificateAuthorityRequest } options := VaultOptions{ PKIMountPath: "pki", } rootCert := parseCertificates(testRootCertificate)[0] tests := []struct { name string fields fields args args want *apiv1.GetCertificateAuthorityResponse wantErr bool }{ {"ok", fields{client, options, testRootFingerprint}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: rootCert, }, false}, {"fail fingerprint", fields{client, options, "fail"}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &VaultCAS{ client: tt.fields.client, fingerprint: 
tt.fields.fingerprint, config: tt.fields.options, } got, err := s.GetCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.GetCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.GetCertificateAuthority() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_RevokeCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", } type fields struct { client *vault.Client options VaultOptions } type args struct { req *apiv1.RevokeCertificateRequest } testCrt := parseCertificates(testCertificateSigned)[0] tests := []struct { name string fields fields args args want *apiv1.RevokeCertificateResponse wantErr bool }{ {"ok serial number", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "123456", Certificate: nil, }}, &apiv1.RevokeCertificateResponse{}, false}, {"ok certificate", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"ok both", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "78910", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"fail serial string", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "fail", Certificate: nil, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &VaultCAS{ client: tt.fields.client, config: tt.fields.options, } got, err := s.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.RevokeCertificate() = %v, want %v", got, tt.want) } }) } } func 
TestVaultCAS_RenewCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", } type fields struct { client *vault.Client options VaultOptions } type args struct { req *apiv1.RenewCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.RenewCertificateResponse wantErr bool }{ {"not implemented", fields{client, options}, args{&apiv1.RenewCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEc), Lifetime: time.Hour, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &VaultCAS{ client: tt.fields.client, config: tt.fields.options, } got, err := s.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.RenewCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.RenewCertificate() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_loadOptions(t *testing.T) { tests := []struct { name string raw string want *VaultOptions wantErr bool }{ { "ok mandatory PKIRole PKIRoleEd25519", `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "role", PKIRoleEC: "role", PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRole PKIRoleEC", `{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "role", PKIRoleEC: "ec", PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRole PKIRoleRSA", `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "role", PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519", `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ PKIMountPath: 
"pki", PKIRoleDefault: "default", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault", `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := loadOptions(json.RawMessage(tt.raw)) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.loadOptions() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.loadOptions() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cmd/step-ca/main.go ================================================ package main import ( "flag" "fmt" "html" "log" "net/http" "os" "reflect" "regexp" "strconv" "time" // Server profiler //nolint:gosec // profile server, if enabled runs on a different port _ "net/http/pprof" "github.com/urfave/cli" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/commands" "github.com/smallstep/cli-utils/command" "github.com/smallstep/cli-utils/command/version" "github.com/smallstep/cli-utils/step" "github.com/smallstep/cli-utils/ui" "github.com/smallstep/cli-utils/usage" "go.step.sm/crypto/pemutil" // Enabled kms interfaces. _ "go.step.sm/crypto/kms/awskms" _ "go.step.sm/crypto/kms/azurekms" _ "go.step.sm/crypto/kms/cloudkms" _ "go.step.sm/crypto/kms/pkcs11" _ "go.step.sm/crypto/kms/softkms" _ "go.step.sm/crypto/kms/sshagentkms" _ "go.step.sm/crypto/kms/tpmkms" _ "go.step.sm/crypto/kms/yubikey" // Enabled cas interfaces. 
_ "github.com/smallstep/certificates/cas/cloudcas" _ "github.com/smallstep/certificates/cas/softcas" _ "github.com/smallstep/certificates/cas/stepcas" _ "github.com/smallstep/certificates/cas/vaultcas" ) // commit and buildTime are filled in during build by the Makefile var ( BuildTime = "N/A" Version = "N/A" ) func init() { step.Set("Smallstep CA", Version, BuildTime) authority.GlobalVersion.Version = Version // Add support for asking passwords pemutil.PromptPassword = func(msg string) ([]byte, error) { return ui.PromptPassword(msg) } } func exit(code int) { ui.Reset() os.Exit(code) } // appHelpTemplate contains the modified template for the main app var appHelpTemplate = `## NAME **{{.HelpName}}** -- {{.Usage}} ## USAGE {{if .UsageText}}{{.UsageText}}{{else}}**{{.HelpName}}**{{if .Commands}} {{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}_[arguments]_{{end}}{{end}}{{if .Description}} ## DESCRIPTION {{.Description}}{{end}}{{if .VisibleCommands}} ## COMMANDS {{range .VisibleCategories}}{{if .Name}}{{.Name}}:{{end}} ||| |---|---|{{range .VisibleCommands}} | **{{join .Names ", "}}** | {{.Usage}} |{{end}} {{end}}{{if .VisibleFlags}}{{end}} ## OPTIONS {{range $index, $option := .VisibleFlags}}{{if $index}} {{end}}{{$option}} {{end}}{{end}}{{if .Copyright}}{{if len .Authors}} ## AUTHOR{{with $length := len .Authors}}{{if ne 1 $length}}S{{end}}{{end}}: {{range $index, $author := .Authors}}{{if $index}} {{end}}{{$author}}{{end}}{{end}}{{if .Version}}{{if not .HideVersion}} ## ONLINE This documentation is available online at https://smallstep.com/docs/certificates ## VERSION {{.Version}}{{end}}{{end}} ## COPYRIGHT {{.Copyright}} ## FEEDBACK ` + html.UnescapeString("&#"+strconv.Itoa(128525)+";") + " " + html.UnescapeString("&#"+strconv.Itoa(127867)+";") + ` The **step-ca** utility is not instrumented for usage statistics. It does not phone home. But your feedback is extremely valuable. Any information you can provide regarding how you’re using **step-ca** helps. 
Please send us a sentence or two, good or bad: **feedback@smallstep.com** or https://github.com/smallstep/certificates/discussions. {{end}} ` func main() { // initialize step environment. if err := step.Init(); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) } // Initialize windows terminal ui.Init() // Override global framework components cli.VersionPrinter = func(c *cli.Context) { version.Command(c) } cli.AppHelpTemplate = appHelpTemplate cli.SubcommandHelpTemplate = usage.SubcommandHelpTemplate cli.CommandHelpTemplate = usage.CommandHelpTemplate cli.HelpPrinter = usage.HelpPrinter cli.FlagNamePrefixer = usage.FlagNamePrefixer cli.FlagStringer = stringifyFlag // Configure cli app app := cli.NewApp() app.Name = "step-ca" app.HelpName = "step-ca" app.Version = step.Version() app.Usage = "an online certificate authority for secure automated certificate management" app.UsageText = `**step-ca** [config] [**--context**=] [**--password-file**=] [**--ssh-host-password-file**=] [**--ssh-user-password-file**=] [**--issuer-password-file**=] [**--resolver**=] [**--help**] [**--version**]` app.Description = `**step-ca** runs the Step Online Certificate Authority (Step CA) using the given configuration. See the README.md for more detailed configuration documentation. ## POSITIONAL ARGUMENTS : File that configures the operation of the Step CA; this file is generated when you initialize the Step CA using 'step ca init' ## EXIT CODES This command will run indefinitely on success and return \>0 if any error occurs. ## EXAMPLES These examples assume that you have already initialized your PKI by running 'step ca init'. If you have not completed this step please see the 'Getting Started' section of the README. 
Run the Step CA and prompt for password: ''' $ step-ca $STEPPATH/config/ca.json ''' Run the Step CA and read the password from a file - this is useful for automating deployment: ''' $ step-ca $STEPPATH/config/ca.json --password-file ./password.txt ''' Run the Step CA for the context selected with step and a custom password file: ''' $ step context select ssh $ step-ca --password-file ./password.txt ''' Run the Step CA for the context named _mybiz_ and prompt for password: ''' $ step-ca --context=mybiz ''' Run the Step CA for the context named _mybiz_ and an alternate ca.json file: ''' $ step-ca --context=mybiz other-ca.json ''' Run the Step CA for the context named _mybiz_ and read the password from a file - this is useful for automating deployment: ''' $ step-ca --context=mybiz --password-file ./password.txt ''' ` app.Flags = append(app.Flags, commands.AppCommand.Flags...) app.Flags = append(app.Flags, cli.HelpFlag) app.Copyright = fmt.Sprintf("(c) 2018-%d Smallstep Labs, Inc.", time.Now().Year()) // All non-successful output should be written to stderr app.Writer = os.Stdout app.ErrWriter = os.Stderr app.Commands = command.Retrieve() // Start the golang debug logger if environment variable is set. 
// See https://golang.org/pkg/net/http/pprof/ debugProfAddr := os.Getenv("STEP_PROF_ADDR") if debugProfAddr != "" { go func() { srv := http.Server{ Addr: debugProfAddr, ReadHeaderTimeout: 15 * time.Second, } log.Println(srv.ListenAndServe()) }() } app.Action = func(_ *cli.Context) error { // Hack to be able to run a the top action as a subcommand set := flag.NewFlagSet(app.Name, flag.ContinueOnError) set.Parse(os.Args) ctx := cli.NewContext(app, set, nil) return commands.AppCommand.Run(ctx) } if err := app.Run(os.Args); err != nil { if os.Getenv("STEPDEBUG") == "1" { fmt.Fprintf(os.Stderr, "%+v\n", err) } else { fmt.Fprintln(os.Stderr, err) } exit(1) } exit(0) } func flagValue(f cli.Flag) reflect.Value { fv := reflect.ValueOf(f) for fv.Kind() == reflect.Ptr { fv = reflect.Indirect(fv) } return fv } var placeholderString = regexp.MustCompile(`<.*?>`) func stringifyFlag(f cli.Flag) string { fv := flagValue(f) usg := fv.FieldByName("Usage").String() placeholder := placeholderString.FindString(usg) if placeholder == "" { switch f.(type) { case cli.BoolFlag, cli.BoolTFlag: default: placeholder = "" } } return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usg } ================================================ FILE: commands/app.go ================================================ package commands import ( "bytes" "context" "fmt" "net" "net/http" "os" "path/filepath" "strconv" "strings" "unicode" "github.com/pkg/errors" "github.com/urfave/cli" "github.com/smallstep/cli-utils/errs" "github.com/smallstep/cli-utils/step" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/pki" ) // AppCommand is the action used as the top action. 
var AppCommand = cli.Command{ Name: "start", Action: appAction, UsageText: `**step-ca** [**--password-file**=] [**--ssh-host-password-file**=] [**--ssh-user-password-file**=] [**--issuer-password-file**=] [**--pidfile**=] [**--resolver**=]`, Flags: []cli.Flag{ cli.StringFlag{ Name: "password-file", Usage: `path to the containing the password to decrypt the intermediate private key.`, }, cli.StringFlag{ Name: "ssh-host-password-file", Usage: `path to the containing the password to decrypt the private key used to sign SSH host certificates. If the flag is not passed it will default to --password-file.`, }, cli.StringFlag{ Name: "ssh-user-password-file", Usage: `path to the containing the password to decrypt the private key used to sign SSH user certificates. If the flag is not passed it will default to --password-file.`, }, cli.StringFlag{ Name: "issuer-password-file", Usage: `path to the containing the password to decrypt the certificate issuer private key used in the RA mode.`, }, cli.StringFlag{ Name: "resolver", Usage: "address of a DNS resolver to be used instead of the default.", }, cli.StringFlag{ Name: "token", Usage: "token used to enable the linked ca.", EnvVar: "STEP_CA_TOKEN", }, cli.BoolFlag{ Name: "quiet", Usage: "disable startup information", EnvVar: "STEP_CA_QUIET", }, cli.StringFlag{ Name: "context", Usage: "the of the authority's context.", EnvVar: "STEP_CA_CONTEXT", }, cli.IntFlag{ Name: "acme-http-port", Usage: `the used on http-01 challenges. It can be changed for testing purposes. Requires **--insecure** flag.`, }, cli.IntFlag{ Name: "acme-tls-port", Usage: `the used on tls-alpn-01 challenges. It can be changed for testing purposes. 
Requires **--insecure** flag.`, }, cli.BoolFlag{ Name: "acme-strict-fqdn", Usage: `enable strict DNS resolution using a fully qualified domain name.`, }, cli.StringFlag{ Name: "pidfile", Usage: "the path to the to write the process ID.", }, cli.BoolFlag{ Name: "insecure", Usage: "enable insecure flags.", }, }, } var pidfile string // AppAction is the action used when the top command runs. func appAction(ctx *cli.Context) error { passFile := ctx.String("password-file") sshHostPassFile := ctx.String("ssh-host-password-file") sshUserPassFile := ctx.String("ssh-user-password-file") issuerPassFile := ctx.String("issuer-password-file") resolver := ctx.String("resolver") token := ctx.String("token") quiet := ctx.Bool("quiet") if ctx.NArg() > 1 { return errs.TooManyArguments(ctx) } // Allow custom ACME ports with insecure if acmePort := ctx.Int("acme-http-port"); acmePort != 0 { if ctx.Bool("insecure") { acme.InsecurePortHTTP01 = acmePort } else { return fmt.Errorf("flag '--acme-http-port' requires the '--insecure' flag") } } if acmePort := ctx.Int("acme-tls-port"); acmePort != 0 { if ctx.Bool("insecure") { acme.InsecurePortTLSALPN01 = acmePort } else { return fmt.Errorf("flag '--acme-tls-port' requires the '--insecure' flag") } } // Set the strict DNS resolution on ACME challenges. Defaults to false. acme.StrictFQDN = ctx.Bool("acme-strict-fqdn") // Allow custom contexts. 
if caCtx := ctx.String("context"); caCtx != "" { if _, ok := step.Contexts().Get(caCtx); ok { if err := step.Contexts().SetCurrent(caCtx); err != nil { return err } } else if token == "" { return fmt.Errorf("context %q not found", caCtx) } else if err := createContext(caCtx); err != nil { return err } } var configFile string if ctx.NArg() > 0 { configFile = ctx.Args().Get(0) } else { configFile = step.CaConfigFile() } cfg, err := config.LoadConfiguration(configFile) if err != nil && token == "" { var pathErr *os.PathError if errors.As(err, &pathErr) { fmt.Println("step-ca can't find or open the configuration file for your CA.") fmt.Println("You may need to create a CA first by running `step ca init`.") fmt.Println("Documentation: https://u.step.sm/docs/ca") os.Exit(1) } fatal(err) } // Initialize a basic configuration to be used with an automatically // configured linked RA. Default configuration includes: // * badgerv2 on $(step path)/db // * JSON logger // * Default TLS options if cfg == nil { cfg = &config.Config{ SkipValidation: true, Logger: []byte(`{"format":"json"}`), DB: &db.Config{ Type: "badgerv2", DataSource: filepath.Join(step.Path(), "db"), }, AuthorityConfig: &config.AuthConfig{ DeploymentType: pki.LinkedDeployment.String(), Provisioners: provisioner.List{}, Template: &config.ASN1DN{}, Backdate: &provisioner.Duration{ Duration: config.DefaultBackdate, }, }, TLS: &config.DefaultTLSOptions, } } if cfg.AuthorityConfig != nil { if token == "" && strings.EqualFold(cfg.AuthorityConfig.DeploymentType, pki.LinkedDeployment.String()) { return errors.New(`'step-ca' requires the '--token' flag for linked deploy type. To get a linked authority token: 1. Contact us at ` + "\033[1mhttps://u.step.sm/cm\033[0m" + ` to create a new Certificate Manager account 2. Add a new authority and select "Link a step-ca instance" 3. 
Follow instructions in browser to start 'step-ca' using the '--token' flag `) } } var password []byte if passFile != "" { if password, err = os.ReadFile(passFile); err != nil { fatal(errors.Wrapf(err, "error reading %s", passFile)) } password = bytes.TrimRightFunc(password, unicode.IsSpace) } var sshHostPassword []byte if sshHostPassFile != "" { if sshHostPassword, err = os.ReadFile(sshHostPassFile); err != nil { fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile)) } sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace) } var sshUserPassword []byte if sshUserPassFile != "" { if sshUserPassword, err = os.ReadFile(sshUserPassFile); err != nil { fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile)) } sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace) } var issuerPassword []byte if issuerPassFile != "" { if issuerPassword, err = os.ReadFile(issuerPassFile); err != nil { fatal(errors.Wrapf(err, "error reading %s", issuerPassFile)) } issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace) } if filename := ctx.String("pidfile"); filename != "" { pid := []byte(strconv.Itoa(os.Getpid()) + "\n") //nolint:gosec // 0644 (-rw-r--r--) are common permissions for a pid file if err := os.WriteFile(filename, pid, 0644); err != nil { fatal(errors.Wrap(err, "error writing pidfile")) } pidfile = filename } // replace resolver if requested if resolver != "" { net.DefaultResolver.PreferGo = true net.DefaultResolver.Dial = func(_ context.Context, network, _ string) (net.Conn, error) { return net.Dial(network, resolver) } } srv, err := ca.New(cfg, ca.WithConfigFile(configFile), ca.WithPassword(password), ca.WithSSHHostPassword(sshHostPassword), ca.WithSSHUserPassword(sshUserPassword), ca.WithIssuerPassword(issuerPassword), ca.WithLinkedCAToken(token), ca.WithQuiet(quiet), ) if err != nil { fatal(err) } go ca.StopReloaderHandler(srv) if err = srv.Run(); err != nil && !errors.Is(err, http.ErrServerClosed) { fatal(err) } if 
pidfile != "" { os.Remove(pidfile) } return nil } // createContext creates a new context using the given name for the context, // authority and profile. func createContext(name string) error { if err := step.Contexts().Add(&step.Context{ Name: name, Authority: name, Profile: name, }); err != nil { return fmt.Errorf("error adding context: %w", err) } if err := step.Contexts().SaveCurrent(name); err != nil { return fmt.Errorf("error saving context: %w", err) } if err := step.Contexts().SetCurrent(name); err != nil { return fmt.Errorf("error setting context: %w", err) } if err := os.MkdirAll(step.Path(), 0700); err != nil { return fmt.Errorf("error creating directory: %w", err) } return nil } // fatal writes the passed error on the standard error and exits with the exit // code 1. If the environment variable STEPDEBUG is set to 1 it shows the // stack trace of the error. func fatal(err error) { if os.Getenv("STEPDEBUG") == "1" { fmt.Fprintf(os.Stderr, "%+v\n", err) } else { fmt.Fprintln(os.Stderr, err) } if pidfile != "" { os.Remove(pidfile) } os.Exit(2) } ================================================ FILE: commands/export.go ================================================ package commands import ( "bytes" "encoding/json" "fmt" "os" "unicode" "github.com/pkg/errors" "github.com/urfave/cli" "google.golang.org/protobuf/encoding/protojson" "github.com/smallstep/cli-utils/command" "github.com/smallstep/cli-utils/errs" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" ) func init() { command.Register(cli.Command{ Name: "export", Usage: "export the current configuration of step-ca", UsageText: "**step-ca export** ", Action: exportAction, Description: `**step-ca export** exports the current configuration of step-ca. Note that neither the PKI password nor the certificate issuer password will be included in the export file. ## POSITIONAL ARGUMENTS : The ca.json that contains the step-ca configuration. 
## EXAMPLES Export the current configuration: ''' $ step-ca export $(step path)/config/ca.json '''`, Flags: []cli.Flag{ cli.StringFlag{ Name: "password-file", Usage: `path to the containing the password to decrypt the intermediate private key.`, }, cli.StringFlag{ Name: "issuer-password-file", Usage: `path to the containing the password to decrypt the certificate issuer private key used in the RA mode.`, }, }, }) } func exportAction(ctx *cli.Context) error { if err := errs.NumberOfArguments(ctx, 1); err != nil { return err } configFile := ctx.Args().Get(0) passwordFile := ctx.String("password-file") issuerPasswordFile := ctx.String("issuer-password-file") cfg, err := config.LoadConfiguration(configFile) if err != nil { return err } if err := cfg.Validate(); err != nil { return err } if passwordFile != "" { b, err := os.ReadFile(passwordFile) if err != nil { return errors.Wrapf(err, "error reading %s", passwordFile) } cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) } if issuerPasswordFile != "" { b, err := os.ReadFile(issuerPasswordFile) if err != nil { return errors.Wrapf(err, "error reading %s", issuerPasswordFile) } if cfg.AuthorityConfig.CertificateIssuer != nil { cfg.AuthorityConfig.CertificateIssuer.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) } } auth, err := authority.New(cfg) if err != nil { return err } export, err := auth.Export() if err != nil { return err } b, err := protojson.Marshal(export) if err != nil { return errors.Wrap(err, "error marshaling export") } var buf bytes.Buffer if err := json.Indent(&buf, b, "", "\t"); err != nil { return errors.Wrap(err, "error indenting export") } fmt.Println(buf.String()) return nil } ================================================ FILE: commands/onboard.go ================================================ package commands import ( "bytes" "encoding/json" "io" "net/http" "net/url" "os" "github.com/pkg/errors" "github.com/urfave/cli" "github.com/smallstep/cli-utils/command" 
"github.com/smallstep/cli-utils/errs" "github.com/smallstep/cli-utils/fileutil" "github.com/smallstep/cli-utils/ui" "go.step.sm/crypto/randutil" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/ca" "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/pki" ) // defaultOnboardingURL is the production onboarding url, to use a development // url use: // // export STEP_CA_ONBOARDING_URL=http://localhost:3002/onboarding/ const defaultOnboardingURL = "https://api.smallstep.com/onboarding/" type onboardingConfiguration struct { Name string `json:"name"` DNS string `json:"dns"` Address string `json:"address"` password []byte } type onboardingPayload struct { Fingerprint string `json:"fingerprint"` } type onboardingError struct { StatusCode int `json:"statusCode"` Message string `json:"message"` } func (e onboardingError) Error() string { return e.Message } func init() { command.Register(cli.Command{ Name: "onboard", Usage: "configure and run step-ca from the onboarding guide", UsageText: "**step-ca onboard** ", Action: onboardAction, Description: `**step-ca onboard** configures step certificates using the onboarding guide. 
Open https://smallstep.com/onboarding in your browser and start the CA with the given token: ''' $ step-ca onboard ''' ## POSITIONAL ARGUMENTS : The token string provided by the onboarding guide.`, }) } func onboardAction(ctx *cli.Context) error { if ctx.NArg() == 0 { return cli.ShowCommandHelp(ctx, "onboard") } if err := errs.NumberOfArguments(ctx, 1); err != nil { return err } // Get onboarding url onboarding := defaultOnboardingURL if v := os.Getenv("STEP_CA_ONBOARDING_URL"); v != "" { onboarding = v } u, err := url.Parse(onboarding) if err != nil { return errors.Wrapf(err, "error parsing %s", onboarding) } ui.Println("Connecting to onboarding guide...") token := ctx.Args().Get(0) onboardingURL := u.ResolveReference(&url.URL{Path: token}).String() //nolint:gosec // onboarding url res, err := http.Get(onboardingURL) if err != nil { return errors.Wrap(err, "error connecting onboarding guide") } defer res.Body.Close() if res.StatusCode >= 400 { var msg onboardingError if err := readJSON(res.Body, &msg); err != nil { return errors.Wrap(err, "error unmarshaling response") } return errors.Wrap(msg, "error receiving onboarding guide") } var cfg onboardingConfiguration if err := readJSON(res.Body, &cfg); err != nil { return errors.Wrap(err, "error unmarshaling response") } password, err := randutil.ASCII(32) if err != nil { return err } cfg.password = []byte(password) ui.Println("Initializing step-ca with the following configuration:") ui.PrintSelected("Name", cfg.Name) ui.PrintSelected("DNS", cfg.DNS) ui.PrintSelected("Address", cfg.Address) ui.PrintSelected("Password", password) ui.Println() caConfig, fp, err := onboardPKI(cfg) if err != nil { return err } payload, err := json.Marshal(onboardingPayload{Fingerprint: fp}) if err != nil { return errors.Wrap(err, "error marshaling payload") } //nolint:gosec // onboarding url resp, err := http.Post(onboardingURL, "application/json", bytes.NewBuffer(payload)) if err != nil { return errors.Wrap(err, "error connecting 
onboarding guide") } if resp.StatusCode >= 400 { var msg onboardingError if err := readJSON(resp.Body, &msg); err != nil { ui.Printf("%s {{ \"error unmarshalling response: %v\" | yellow }}\n", ui.IconWarn, err) } else { ui.Printf("%s {{ \"error posting fingerprint: %s\" | yellow }}\n", ui.IconWarn, msg.Message) } } else { resp.Body.Close() } ui.Println("Initialized!") ui.Println("Step CA is starting. Please return to the onboarding guide in your browser to continue.") srv, err := ca.New(caConfig, ca.WithPassword(cfg.password)) if err != nil { fatal(err) } go ca.StopReloaderHandler(srv) if err := srv.Run(); err != nil && !errors.Is(err, http.ErrServerClosed) { fatal(err) } return nil } func onboardPKI(cfg onboardingConfiguration) (*config.Config, string, error) { var opts = []pki.Option{ pki.WithAddress(cfg.Address), pki.WithDNSNames([]string{cfg.DNS}), pki.WithProvisioner("admin"), } p, err := pki.New(apiv1.Options{ Type: apiv1.SoftCAS, IsCreator: true, }, opts...) if err != nil { return nil, "", err } // Generate pki ui.Println("Generating root certificate...") root, err := p.GenerateRootCertificate(cfg.Name, cfg.Name, cfg.Name, cfg.password) if err != nil { return nil, "", err } ui.Println("Generating intermediate certificate...") err = p.GenerateIntermediateCertificate(cfg.Name, cfg.Name, cfg.Name, root, cfg.password) if err != nil { return nil, "", err } // Write files to disk if err := p.WriteFiles(); err != nil { return nil, "", err } // Generate provisioner ui.Println("Generating admin provisioner...") if err := p.GenerateKeyPairs(cfg.password); err != nil { return nil, "", err } // Generate and write configuration caConfig, err := p.GenerateConfig() if err != nil { return nil, "", err } b, err := json.MarshalIndent(caConfig, "", " ") //nolint:gosec // config struct contains password field by design if err != nil { return nil, "", errors.Wrapf(err, "error marshaling %s", p.GetCAConfigPath()) } if err := fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err 
!= nil { return nil, "", errs.FileError(err, p.GetCAConfigPath()) } return caConfig, p.GetRootFingerprint(), nil } func readJSON(r io.ReadCloser, v interface{}) error { defer r.Close() return json.NewDecoder(r).Decode(v) } ================================================ FILE: cosign.pub ================================================ -----BEGIN PUBLIC KEY----- MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs+6THbAiXx4bja5ARQFNZmPwZjlD GRvt5H+9ZFDhrcFPR1E7eB2rt1B/DhobANdHGKjvEBZEf0v4X/7S+SHrIw== -----END PUBLIC KEY----- ================================================ FILE: db/db.go ================================================ package db import ( "context" "crypto/x509" "encoding/json" "strconv" "strings" "time" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/internal/cast" ) var ( certsTable = []byte("x509_certs") certsDataTable = []byte("x509_certs_data") revokedCertsTable = []byte("revoked_x509_certs") crlTable = []byte("x509_crl") revokedSSHCertsTable = []byte("revoked_ssh_certs") usedOTTTable = []byte("used_ott") sshCertsTable = []byte("ssh_certs") sshHostsTable = []byte("ssh_hosts") sshUsersTable = []byte("ssh_users") sshHostPrincipalsTable = []byte("ssh_host_principals") ) // TODO: at the moment we store a single CRL in the database, in a dedicated table. // is this acceptable? probably not.... var crlKey = []byte("crl") // ErrAlreadyExists can be returned if the DB attempts to set a key that has // been previously set. var ErrAlreadyExists = errors.New("already exists") // Config represents the JSON attributes used for configuring a step-ca DB. 
type Config struct { Type string `json:"type"` DataSource string `json:"dataSource"` ValueDir string `json:"valueDir,omitempty"` Database string `json:"database,omitempty"` // BadgerFileLoadingMode can be set to 'FileIO' (instead of the default // 'MemoryMap') to avoid memory-mapping log files. This can be useful // in environments with low RAM BadgerFileLoadingMode string `json:"badgerFileLoadingMode"` } // AuthDB is an interface over an Authority DB client that implements a nosql.DB interface. type AuthDB interface { IsRevoked(sn string) (bool, error) IsSSHRevoked(sn string) (bool, error) Revoke(rci *RevokedCertificateInfo) error RevokeSSH(rci *RevokedCertificateInfo) error GetCertificate(serialNumber string) (*x509.Certificate, error) UseToken(id, tok string) (bool, error) IsSSHHost(name string) (bool, error) GetSSHHostPrincipals() ([]string, error) Shutdown() error } type dbKey struct{} // NewContext adds the given authority database to the context. func NewContext(ctx context.Context, db AuthDB) context.Context { return context.WithValue(ctx, dbKey{}, db) } // FromContext returns the current authority database from the given context. func FromContext(ctx context.Context) (db AuthDB, ok bool) { db, ok = ctx.Value(dbKey{}).(AuthDB) return } // MustFromContext returns the current database from the given context. It // will panic if it's not in the context. func MustFromContext(ctx context.Context) AuthDB { var ( db AuthDB ok bool ) if db, ok = FromContext(ctx); !ok { panic("authority database is not in the context") } return db } // CertificateStorer is an extension of AuthDB that allows to store // certificates. 
type CertificateStorer interface { StoreCertificate(crt *x509.Certificate) error StoreSSHCertificate(crt *ssh.Certificate) error } // CertificateRevocationListDB is an interface to indicate whether the DB supports CRL generation type CertificateRevocationListDB interface { GetRevokedCertificates() (*[]RevokedCertificateInfo, error) GetCRL() (*CertificateRevocationListInfo, error) StoreCRL(*CertificateRevocationListInfo) error } // DB is a wrapper over the nosql.DB interface. type DB struct { nosql.DB isUp bool } // New returns a new database client that implements the AuthDB interface. func New(c *Config) (AuthDB, error) { if c == nil { return newSimpleDB(c) } opts := []nosql.Option{nosql.WithDatabase(c.Database), nosql.WithValueDir(c.ValueDir)} if c.BadgerFileLoadingMode != "" { opts = append(opts, nosql.WithBadgerFileLoadingMode(c.BadgerFileLoadingMode)) } db, err := nosql.New(c.Type, c.DataSource, opts...) if err != nil { return nil, errors.Wrapf(err, "Error opening database of Type %s", c.Type) } tables := [][]byte{ revokedCertsTable, certsTable, usedOTTTable, sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable, revokedSSHCertsTable, certsDataTable, crlTable, } for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", string(b)) } } return &DB{db, true}, nil } // RevokedCertificateInfo contains information regarding the certificate // revocation action. 
type RevokedCertificateInfo struct { Serial string ProvisionerID string ReasonCode int Reason string RevokedAt time.Time ExpiresAt time.Time TokenID string MTLS bool ACME bool } // CertificateRevocationListInfo contains a CRL in DER format and associated // metadata to allow a decision on whether to regenerate the CRL or not easier type CertificateRevocationListInfo struct { Number int64 ExpiresAt time.Time Duration time.Duration DER []byte } // IsRevoked returns whether or not a certificate with the given identifier // has been revoked. // In the case of an X509 Certificate the `id` should be the Serial Number of // the Certificate. func (db *DB) IsRevoked(sn string) (bool, error) { // If the DB is nil then act as pass through. if db == nil { return false, nil } // If the error is `Not Found` then the certificate has not been revoked. // Any other error should be propagated to the caller. if _, err := db.Get(revokedCertsTable, []byte(sn)); err != nil { if nosql.IsErrNotFound(err) { return false, nil } return false, errors.Wrap(err, "error checking revocation bucket") } // This certificate has been revoked. return true, nil } // IsSSHRevoked returns whether or not a certificate with the given identifier // has been revoked. // In the case of an X509 Certificate the `id` should be the Serial Number of // the Certificate. func (db *DB) IsSSHRevoked(sn string) (bool, error) { // If the DB is nil then act as pass through. if db == nil { return false, nil } // If the error is `Not Found` then the certificate has not been revoked. // Any other error should be propagated to the caller. if _, err := db.Get(revokedSSHCertsTable, []byte(sn)); err != nil { if nosql.IsErrNotFound(err) { return false, nil } return false, errors.Wrap(err, "error checking revocation bucket") } // This certificate has been revoked. return true, nil } // Revoke adds a certificate to the revocation table. 
func (db *DB) Revoke(rci *RevokedCertificateInfo) error { rcib, err := json.Marshal(rci) if err != nil { return errors.Wrap(err, "error marshaling revoked certificate info") } _, swapped, err := db.CmpAndSwap(revokedCertsTable, []byte(rci.Serial), nil, rcib) switch { case err != nil: return errors.Wrap(err, "error AuthDB CmpAndSwap") case !swapped: return ErrAlreadyExists default: return nil } } // RevokeSSH adds a SSH certificate to the revocation table. func (db *DB) RevokeSSH(rci *RevokedCertificateInfo) error { rcib, err := json.Marshal(rci) if err != nil { return errors.Wrap(err, "error marshaling revoked certificate info") } _, swapped, err := db.CmpAndSwap(revokedSSHCertsTable, []byte(rci.Serial), nil, rcib) switch { case err != nil: return errors.Wrap(err, "error AuthDB CmpAndSwap") case !swapped: return ErrAlreadyExists default: return nil } } // GetRevokedCertificates gets a list of all revoked certificates. func (db *DB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) { entries, err := db.List(revokedCertsTable) if err != nil { return nil, err } var revokedCerts []RevokedCertificateInfo for _, e := range entries { var data RevokedCertificateInfo if err := json.Unmarshal(e.Value, &data); err != nil { return nil, err } revokedCerts = append(revokedCerts, data) } return &revokedCerts, nil } // StoreCRL stores a CRL in the DB func (db *DB) StoreCRL(crlInfo *CertificateRevocationListInfo) error { crlInfoBytes, err := json.Marshal(crlInfo) if err != nil { return errors.Wrap(err, "json Marshal error") } if err := db.Set(crlTable, crlKey, crlInfoBytes); err != nil { return errors.Wrap(err, "database Set error") } return nil } // GetCRL gets the existing CRL from the database func (db *DB) GetCRL() (*CertificateRevocationListInfo, error) { crlInfoBytes, err := db.Get(crlTable, crlKey) if err != nil { return nil, errors.Wrap(err, "database Get error") } var crlInfo CertificateRevocationListInfo err = json.Unmarshal(crlInfoBytes, &crlInfo) if err != nil 
{ return nil, errors.Wrap(err, "json Unmarshal error") } return &crlInfo, err } // GetCertificate retrieves a certificate by the serial number. func (db *DB) GetCertificate(serialNumber string) (*x509.Certificate, error) { asn1Data, err := db.Get(certsTable, []byte(serialNumber)) if err != nil { return nil, errors.Wrap(err, "database Get error") } cert, err := x509.ParseCertificate(asn1Data) if err != nil { return nil, errors.Wrapf(err, "error parsing certificate with serial number %s", serialNumber) } return cert, nil } // GetCertificateData returns the data stored for a provisioner func (db *DB) GetCertificateData(serialNumber string) (*CertificateData, error) { b, err := db.Get(certsDataTable, []byte(serialNumber)) if err != nil { return nil, errors.Wrap(err, "database Get error") } var data CertificateData if err := json.Unmarshal(b, &data); err != nil { return nil, errors.Wrap(err, "error unmarshaling json") } return &data, nil } // StoreCertificate stores a certificate PEM. func (db *DB) StoreCertificate(crt *x509.Certificate) error { if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil { return errors.Wrap(err, "database Set error") } return nil } // CertificateData is the JSON representation of the data stored in // x509_certs_data table. type CertificateData struct { Provisioner *ProvisionerData `json:"provisioner,omitempty"` RaInfo *provisioner.RAInfo `json:"ra,omitempty"` } // ProvisionerData is the JSON representation of the provisioner stored in the // x509_certs_data table. type ProvisionerData struct { ID string `json:"id"` Name string `json:"name"` Type string `json:"type"` } type raProvisioner interface { RAInfo() *provisioner.RAInfo } // StoreCertificateChain stores the leaf certificate and the provisioner that // authorized the certificate. 
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { leaf := chain[0] serialNumber := []byte(leaf.SerialNumber.String()) data := &CertificateData{} if p != nil { data.Provisioner = &ProvisionerData{ ID: p.GetID(), Name: p.GetName(), Type: p.GetType().String(), } if rap, ok := p.(raProvisioner); ok { data.RaInfo = rap.RAInfo() } } b, err := json.Marshal(data) if err != nil { return errors.Wrap(err, "error marshaling json") } // Add certificate and certificate data in one transaction. tx := new(database.Tx) tx.Set(certsTable, serialNumber, leaf.Raw) tx.Set(certsDataTable, serialNumber, b) if err := db.Update(tx); err != nil { return errors.Wrap(err, "database Update error") } return nil } // StoreRenewedCertificate stores the leaf certificate and the provisioner that // authorized the old certificate if available. func (db *DB) StoreRenewedCertificate(oldCert *x509.Certificate, chain ...*x509.Certificate) error { var certificateData []byte if data, err := db.GetCertificateData(oldCert.SerialNumber.String()); err == nil { if b, err := json.Marshal(data); err == nil { certificateData = b } } leaf := chain[0] serialNumber := []byte(leaf.SerialNumber.String()) // Add certificate and certificate data in one transaction. tx := new(database.Tx) tx.Set(certsTable, serialNumber, leaf.Raw) if certificateData != nil { tx.Set(certsDataTable, serialNumber, certificateData) } if err := db.Update(tx); err != nil { return errors.Wrap(err, "database Update error") } return nil } // UseToken returns true if we were able to successfully store the token for // for the first time, false otherwise. func (db *DB) UseToken(id, tok string) (bool, error) { _, swapped, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok)) if err != nil { return false, errors.Wrapf(err, "error storing used token %s/%s", string(usedOTTTable), id) } return swapped, nil } // IsSSHHost returns if a principal is present in the ssh hosts table. 
func (db *DB) IsSSHHost(principal string) (bool, error) { if _, err := db.Get(sshHostsTable, []byte(strings.ToLower(principal))); err != nil { if database.IsErrNotFound(err) { return false, nil } return false, errors.Wrap(err, "database Get error") } return true, nil } type sshHostPrincipalData struct { Serial string Expiry uint64 } // StoreSSHCertificate stores an SSH certificate. func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error { serial := strconv.FormatUint(crt.Serial, 10) tx := new(database.Tx) tx.Set(sshCertsTable, []byte(serial), crt.Marshal()) if crt.CertType == ssh.HostCert { for _, p := range crt.ValidPrincipals { hostPrincipalData, err := json.Marshal(sshHostPrincipalData{ Serial: serial, Expiry: crt.ValidBefore, }) if err != nil { return err } tx.Set(sshHostsTable, []byte(strings.ToLower(p)), []byte(serial)) tx.Set(sshHostPrincipalsTable, []byte(strings.ToLower(p)), hostPrincipalData) } } else { for _, p := range crt.ValidPrincipals { tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial)) } } if err := db.Update(tx); err != nil { return errors.Wrap(err, "database Update error") } return nil } // GetSSHHostPrincipals gets a list of all valid host principals. func (db *DB) GetSSHHostPrincipals() ([]string, error) { entries, err := db.List(sshHostPrincipalsTable) if err != nil { return nil, err } var principals []string for _, e := range entries { var data sshHostPrincipalData if err := json.Unmarshal(e.Value, &data); err != nil { return nil, err } if time.Unix(cast.Int64(data.Expiry), 0).After(time.Now()) { principals = append(principals, string(e.Key)) } } return principals, nil } // Shutdown sends a shutdown message to the database. func (db *DB) Shutdown() error { if db.isUp { if err := db.Close(); err != nil { return errors.Wrap(err, "database shutdown error") } db.isUp = false } return nil } // MockAuthDB mocks the AuthDB interface. 
// type MockAuthDB struct { Err error Ret1 interface{} MIsRevoked func(string) (bool, error) MIsSSHRevoked func(string) (bool, error) MRevoke func(rci *RevokedCertificateInfo) error MRevokeSSH func(rci *RevokedCertificateInfo) error MGetCertificate func(serialNumber string) (*x509.Certificate, error) MGetCertificateData func(serialNumber string) (*CertificateData, error) MStoreCertificate func(crt *x509.Certificate) error MUseToken func(id, tok string) (bool, error) MIsSSHHost func(principal string) (bool, error) MStoreSSHCertificate func(crt *ssh.Certificate) error MGetSSHHostPrincipals func() ([]string, error) MShutdown func() error MGetRevokedCertificates func() (*[]RevokedCertificateInfo, error) MGetCRL func() (*CertificateRevocationListInfo, error) MStoreCRL func(*CertificateRevocationListInfo) error } func (m *MockAuthDB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) { if m.MGetRevokedCertificates != nil { return m.MGetRevokedCertificates() } return m.Ret1.(*[]RevokedCertificateInfo), m.Err } func (m *MockAuthDB) GetCRL() (*CertificateRevocationListInfo, error) { if m.MGetCRL != nil { return m.MGetCRL() } return m.Ret1.(*CertificateRevocationListInfo), m.Err } func (m *MockAuthDB) StoreCRL(info *CertificateRevocationListInfo) error { if m.MStoreCRL != nil { return m.MStoreCRL(info) } return m.Err } // IsRevoked mock. func (m *MockAuthDB) IsRevoked(sn string) (bool, error) { if m.MIsRevoked != nil { return m.MIsRevoked(sn) } return m.Ret1.(bool), m.Err } // IsSSHRevoked mock. func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) { if m.MIsSSHRevoked != nil { return m.MIsSSHRevoked(sn) } return m.Ret1.(bool), m.Err } // UseToken mock. func (m *MockAuthDB) UseToken(id, tok string) (bool, error) { if m.MUseToken != nil { return m.MUseToken(id, tok) } if m.Ret1 == nil { return false, m.Err } return m.Ret1.(bool), m.Err } // Revoke mock. 
func (m *MockAuthDB) Revoke(rci *RevokedCertificateInfo) error { if m.MRevoke != nil { return m.MRevoke(rci) } return m.Err } // RevokeSSH mock. func (m *MockAuthDB) RevokeSSH(rci *RevokedCertificateInfo) error { if m.MRevokeSSH != nil { return m.MRevokeSSH(rci) } return m.Err } // GetCertificate mock. func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, error) { if m.MGetCertificate != nil { return m.MGetCertificate(serialNumber) } return m.Ret1.(*x509.Certificate), m.Err } // GetCertificateData mock. func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) { if m.MGetCertificateData != nil { return m.MGetCertificateData(serialNumber) } if cd, ok := m.Ret1.(*CertificateData); ok { return cd, m.Err } return nil, m.Err } // StoreCertificate mock. func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { if m.MStoreCertificate != nil { return m.MStoreCertificate(crt) } return m.Err } // IsSSHHost mock. func (m *MockAuthDB) IsSSHHost(principal string) (bool, error) { if m.MIsSSHHost != nil { return m.MIsSSHHost(principal) } return m.Ret1.(bool), m.Err } // StoreSSHCertificate mock. func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error { if m.MStoreSSHCertificate != nil { return m.MStoreSSHCertificate(crt) } return m.Err } // GetSSHHostPrincipals mock. func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) { if m.MGetSSHHostPrincipals != nil { return m.MGetSSHHostPrincipals() } return m.Ret1.([]string), m.Err } // Shutdown mock. 
func (m *MockAuthDB) Shutdown() error { if m.MShutdown != nil { return m.MShutdown() } return m.Err } // MockNoSQLDB // type MockNoSQLDB struct { Err error Ret1, Ret2 interface{} MGet func(bucket, key []byte) ([]byte, error) MSet func(bucket, key, value []byte) error MOpen func(dataSourceName string, opt ...database.Option) error MClose func() error MCreateTable func(bucket []byte) error MDeleteTable func(bucket []byte) error MDel func(bucket, key []byte) error MList func(bucket []byte) ([]*database.Entry, error) MUpdate func(tx *database.Tx) error MCmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error) } // CmpAndSwap mock func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) { if m.MCmpAndSwap != nil { return m.MCmpAndSwap(bucket, key, old, newval) } if m.Ret1 == nil { return nil, false, m.Err } return m.Ret1.([]byte), m.Ret2.(bool), m.Err } // Get mock func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) { if m.MGet != nil { return m.MGet(bucket, key) } if m.Ret1 == nil { return nil, m.Err } return m.Ret1.([]byte), m.Err } // Set mock func (m *MockNoSQLDB) Set(bucket, key, value []byte) error { if m.MSet != nil { return m.MSet(bucket, key, value) } return m.Err } // Open mock func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error { if m.MOpen != nil { return m.MOpen(dataSourceName, opt...) 
} return m.Err } // Close mock func (m *MockNoSQLDB) Close() error { if m.MClose != nil { return m.MClose() } return m.Err } // CreateTable mock func (m *MockNoSQLDB) CreateTable(bucket []byte) error { if m.MCreateTable != nil { return m.MCreateTable(bucket) } return m.Err } // DeleteTable mock func (m *MockNoSQLDB) DeleteTable(bucket []byte) error { if m.MDeleteTable != nil { return m.MDeleteTable(bucket) } return m.Err } // Del mock func (m *MockNoSQLDB) Del(bucket, key []byte) error { if m.MDel != nil { return m.MDel(bucket, key) } return m.Err } // List mock func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) { if m.MList != nil { return m.MList(bucket) } return m.Ret1.([]*database.Entry), m.Err } // Update mock func (m *MockNoSQLDB) Update(tx *database.Tx) error { if m.MUpdate != nil { return m.MUpdate(tx) } return m.Err } ================================================ FILE: db/db_test.go ================================================ package db import ( "bytes" "crypto/x509" "errors" "math/big" "reflect" "testing" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/nosql" "github.com/smallstep/nosql/database" ) func TestIsRevoked(t *testing.T) { tests := map[string]struct { key string db *DB isRevoked bool err error }{ "false/nil db": { key: "sn", }, "false/ErrNotFound": { key: "sn", db: &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true}, }, "error/checking bucket": { key: "sn", db: &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true}, err: errors.New("error checking revocation bucket: force"), }, "true": { key: "sn", db: &DB{&MockNoSQLDB{Ret1: []byte("value")}, true}, isRevoked: true, }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { isRevoked, err := tc.db.IsRevoked(tc.key) if err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, tc.err.Error(), err.Error()) } } else { assert.Nil(t, tc.err) assert.Fatal(t, isRevoked == tc.isRevoked) } 
}) } } func TestRevoke(t *testing.T) { tests := map[string]struct { rci *RevokedCertificateInfo db *DB err error }{ "error/force isRevoked": { rci: &RevokedCertificateInfo{Serial: "sn"}, db: &DB{&MockNoSQLDB{ MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) { return nil, false, errors.New("force") }, }, true}, err: errors.New("error AuthDB CmpAndSwap: force"), }, "error/was already revoked": { rci: &RevokedCertificateInfo{Serial: "sn"}, db: &DB{&MockNoSQLDB{ MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), false, nil }, }, true}, err: ErrAlreadyExists, }, "ok": { rci: &RevokedCertificateInfo{Serial: "sn"}, db: &DB{&MockNoSQLDB{ MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), true, nil }, }, true}, }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { if err := tc.db.Revoke(tc.rci); err != nil { if assert.NotNil(t, tc.err) { assert.HasPrefix(t, tc.err.Error(), err.Error()) } } else { assert.Nil(t, tc.err) } }) } } func TestUseToken(t *testing.T) { type result struct { err error ok bool } tests := map[string]struct { id, tok string db *DB want result }{ "fail/force-CmpAndSwap-error": { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return nil, false, errors.New("force") }, }, true}, want: result{ ok: false, err: errors.New("error storing used token used_ott/id"), }, }, "fail/CmpAndSwap-already-exists": { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("foo"), false, nil }, }, true}, want: result{ ok: false, }, }, "ok/cmpAndSwap-success": { id: "id", tok: "token", db: &DB{&MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { return []byte("bar"), true, nil }, }, true}, want: result{ ok: true, }, }, } for name, tc := range tests { t.Run(name, 
func(t *testing.T) { switch ok, err := tc.db.UseToken(tc.id, tc.tok); { case err != nil: if assert.NotNil(t, tc.want.err) { assert.HasPrefix(t, err.Error(), tc.want.err.Error()) } assert.False(t, ok) case ok: assert.True(t, tc.want.ok) default: assert.False(t, tc.want.ok) } }) } } // wrappedProvisioner implements raProvisioner and attProvisioner. type wrappedProvisioner struct { provisioner.Interface raInfo *provisioner.RAInfo } func (p *wrappedProvisioner) RAInfo() *provisioner.RAInfo { return p.raInfo } func TestDB_StoreCertificateChain(t *testing.T) { p := &provisioner.JWK{ ID: "some-id", Name: "admin", Type: "JWK", } rap := &wrappedProvisioner{ Interface: p, raInfo: &provisioner.RAInfo{ ProvisionerID: "ra-id", ProvisionerType: "JWK", ProvisionerName: "ra", }, } chain := []*x509.Certificate{ {Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)}, } type fields struct { DB nosql.DB isUp bool } type args struct { p provisioner.Interface chain []*x509.Certificate } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{&MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 2 { t.Fatal("unexpected number of operations") } assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) assert.Equals(t, []byte("1234"), tx.Operations[0].Key) assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) assert.Equals(t, []byte("1234"), tx.Operations[1].Key) assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), tx.Operations[1].Value) return nil }, }, true}, args{p, chain}, false}, {"ok ra provisioner", fields{&MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 2 { t.Fatal("unexpected number of operations") } assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) assert.Equals(t, []byte("1234"), tx.Operations[0].Key) assert.Equals(t, []byte("the certificate"), 
tx.Operations[0].Value) assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) assert.Equals(t, []byte("1234"), tx.Operations[1].Key) assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`), tx.Operations[1].Value) assert.Equals(t, `{"provisioner":{"id":"some-id","name":"admin","type":"JWK"},"ra":{"provisionerId":"ra-id","provisionerType":"JWK","provisionerName":"ra"}}`, string(tx.Operations[1].Value)) return nil }, }, true}, args{rap, chain}, false}, {"ok no provisioner", fields{&MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 2 { t.Fatal("unexpected number of operations") } assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket) assert.Equals(t, []byte("1234"), tx.Operations[0].Key) assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value) assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket) assert.Equals(t, []byte("1234"), tx.Operations[1].Key) assert.Equals(t, []byte(`{}`), tx.Operations[1].Value) return nil }, }, true}, args{nil, chain}, false}, {"fail store certificate", fields{&MockNoSQLDB{ MUpdate: func(tx *database.Tx) error { return errors.New("test error") }, }, true}, args{p, chain}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := &DB{ DB: tt.fields.DB, isUp: tt.fields.isUp, } if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr { t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestDB_GetCertificateData(t *testing.T) { type fields struct { DB nosql.DB isUp bool } type args struct { serialNumber string } tests := []struct { name string fields fields args args want *CertificateData wantErr bool }{ {"ok", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { assert.Equals(t, bucket, []byte("x509_certs_data")) assert.Equals(t, key, []byte("1234")) 
return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil }, }, true}, args{"1234"}, &CertificateData{ Provisioner: &ProvisionerData{ ID: "some-id", Name: "admin", Type: "JWK", }, }, false}, {"fail not found", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { return nil, database.ErrNotFound }, }, true}, args{"1234"}, nil, true}, {"fail db", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { return nil, errors.New("an error") }, }, true}, args{"1234"}, nil, true}, {"fail unmarshal", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { return []byte(`{"bad-json"}`), nil }, }, true}, args{"1234"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := &DB{ DB: tt.fields.DB, isUp: tt.fields.isUp, } got, err := db.GetCertificateData(tt.args.serialNumber) if (err != nil) != tt.wantErr { t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want) } }) } } func TestDB_StoreRenewedCertificate(t *testing.T) { oldCert := &x509.Certificate{SerialNumber: big.NewInt(1)} chain := []*x509.Certificate{ &x509.Certificate{SerialNumber: big.NewInt(2), Raw: []byte("raw")}, &x509.Certificate{SerialNumber: big.NewInt(0)}, } testErr := errors.New("test error") certsData := []byte(`{"provisioner":{"id":"p","name":"name","type":"JWK"},"ra":{"provisionerId":"rap","provisionerType":"JWK","provisionerName":"rapname"}}`) matchOperation := func(op *database.TxEntry, bucket, key, value []byte) bool { return bytes.Equal(op.Bucket, bucket) && bytes.Equal(op.Key, key) && bytes.Equal(op.Value, value) } type fields struct { DB nosql.DB isUp bool } type args struct { oldCert *x509.Certificate chain []*x509.Certificate } tests := []struct { name string fields fields args args wantErr bool }{ {"ok", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { if 
bytes.Equal(bucket, certsDataTable) && bytes.Equal(key, []byte("1")) { return certsData, nil } t.Error("ok failed: unexpected get") return nil, testErr }, MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 2 { t.Error("ok failed: unexpected number of operations") return testErr } op0, op1 := tx.Operations[0], tx.Operations[1] if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) return testErr } if !matchOperation(op1, certsDataTable, []byte("2"), certsData) { t.Errorf("ok failed: unexpected entry 1, %s[%s]=%s", op1.Bucket, op1.Key, op1.Value) return testErr } return nil }, }, true}, args{oldCert, chain}, false}, {"ok no data", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { return nil, database.ErrNotFound }, MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 1 { t.Error("ok failed: unexpected number of operations") return testErr } op0 := tx.Operations[0] if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) return testErr } return nil }, }, true}, args{oldCert, chain}, false}, {"ok fail marshal", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { return []byte(`{"bad":"json"`), nil }, MUpdate: func(tx *database.Tx) error { if len(tx.Operations) != 1 { t.Error("ok failed: unexpected number of operations") return testErr } op0 := tx.Operations[0] if !matchOperation(op0, certsTable, []byte("2"), []byte("raw")) { t.Errorf("ok failed: unexpected entry 0, %s[%s]=%s", op0.Bucket, op0.Key, op0.Value) return testErr } return nil }, }, true}, args{oldCert, chain}, false}, {"fail", fields{&MockNoSQLDB{ MGet: func(bucket, key []byte) ([]byte, error) { return certsData, nil }, MUpdate: func(tx *database.Tx) error { return testErr }, }, true}, args{oldCert, chain}, true}, } for _, tt := range tests { t.Run(tt.name, 
func(t *testing.T) { db := &DB{ DB: tt.fields.DB, isUp: tt.fields.isUp, } if err := db.StoreRenewedCertificate(tt.args.oldCert, tt.args.chain...); (err != nil) != tt.wantErr { t.Errorf("DB.StoreRenewedCertificate() error = %v, wantErr %v", err, tt.wantErr) } }) } } ================================================ FILE: db/simple.go ================================================ package db import ( "crypto/x509" "sync" "time" "github.com/pkg/errors" "github.com/smallstep/nosql/database" "golang.org/x/crypto/ssh" ) // ErrNotImplemented is an error returned when an operation is Not Implemented. var ErrNotImplemented = errors.Errorf("not implemented") // SimpleDB is a barebones implementation of the DB interface. It is NOT an // in memory implementation of the DB, but rather the bare minimum of // functionality that the CA requires to operate securely. type SimpleDB struct { usedTokens *sync.Map } func newSimpleDB(*Config) (*SimpleDB, error) { db := &SimpleDB{} db.usedTokens = new(sync.Map) return db, nil } // IsRevoked noop func (s *SimpleDB) IsRevoked(string) (bool, error) { return false, nil } // IsSSHRevoked noop func (s *SimpleDB) IsSSHRevoked(string) (bool, error) { return false, nil } // Revoke returns a "NotImplemented" error. func (s *SimpleDB) Revoke(*RevokedCertificateInfo) error { return ErrNotImplemented } // GetRevokedCertificates returns a "NotImplemented" error. func (s *SimpleDB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) { return nil, ErrNotImplemented } // GetCRL returns a "NotImplemented" error. func (s *SimpleDB) GetCRL() (*CertificateRevocationListInfo, error) { return nil, ErrNotImplemented } // StoreCRL returns a "NotImplemented" error. func (s *SimpleDB) StoreCRL(*CertificateRevocationListInfo) error { return ErrNotImplemented } // RevokeSSH returns a "NotImplemented" error. func (s *SimpleDB) RevokeSSH(*RevokedCertificateInfo) error { return ErrNotImplemented } // GetCertificate returns a "NotImplemented" error. 
func (s *SimpleDB) GetCertificate(string) (*x509.Certificate, error) {
	return nil, ErrNotImplemented
}

// StoreCertificate is not supported by SimpleDB and always returns
// ErrNotImplemented.
func (s *SimpleDB) StoreCertificate(*x509.Certificate) error {
	return ErrNotImplemented
}

// usedToken is the record kept for a consumed token id: the moment it was
// used (Unix seconds) and the token value itself.
type usedToken struct {
	UsedAt int64  `json:"ua,omitempty"`
	Token  string `json:"tok,omitempty"`
}

// UseToken records tok under id and reports whether this is the first time
// the id has been seen. It returns true when the entry was stored and false
// when an entry for id already existed; the returned error is always nil.
func (s *SimpleDB) UseToken(id, tok string) (bool, error) {
	entry := &usedToken{
		UsedAt: time.Now().Unix(),
		Token:  tok,
	}
	// LoadOrStore reports loaded == true when the id was already present, in
	// which case the existing entry is left untouched.
	_, alreadyUsed := s.usedTokens.LoadOrStore(id, entry)
	return !alreadyUsed, nil
}

// IsSSHHost is not supported by SimpleDB and always returns
// ErrNotImplemented.
func (s *SimpleDB) IsSSHHost(string) (bool, error) {
	return false, ErrNotImplemented
}

// StoreSSHCertificate is not supported by SimpleDB and always returns
// ErrNotImplemented.
func (s *SimpleDB) StoreSSHCertificate(*ssh.Certificate) error {
	return ErrNotImplemented
}

// GetSSHHostPrincipals is not supported by SimpleDB and always returns
// ErrNotImplemented.
func (s *SimpleDB) GetSSHHostPrincipals() ([]string, error) {
	return nil, ErrNotImplemented
}

// Shutdown is a no-op and always returns nil.
func (s *SimpleDB) Shutdown() error {
	return nil
}

// nosql.DB interface implementation
//
// Open opens the database available with the given options. SimpleDB does not
// support it and returns ErrNotImplemented.
func (s *SimpleDB) Open(string, ...database.Option) error {
	return ErrNotImplemented
}

// Close closes the current database. SimpleDB does not support it and returns
// ErrNotImplemented.
func (s *SimpleDB) Close() error {
	return ErrNotImplemented
}

// Get returns the value stored in the given table/bucket and key. SimpleDB
// does not support it and returns ErrNotImplemented.
func (s *SimpleDB) Get([]byte, []byte) ([]byte, error) {
	return nil, ErrNotImplemented
}

// Set sets the given value in the given table/bucket and key. SimpleDB does
// not support it and returns ErrNotImplemented.
func (s *SimpleDB) Set([]byte, []byte, []byte) error {
	return ErrNotImplemented
}

// CmpAndSwap swaps the value at the given bucket and key if the current
// value is equivalent to the oldValue input. Returns 'true' if the
// swap was successful and 'false' otherwise.
func (s *SimpleDB) CmpAndSwap([]byte, []byte, []byte, []byte) ([]byte, bool, error) { return nil, false, ErrNotImplemented } // Del deletes the data in the given table/bucket and key. func (s *SimpleDB) Del([]byte, []byte) error { return ErrNotImplemented } // List returns a list of all the entries in a given table/bucket. func (s *SimpleDB) List([]byte) ([]*database.Entry, error) { return nil, ErrNotImplemented } // Update performs a transaction with multiple read-write commands. func (s *SimpleDB) Update(*database.Tx) error { return ErrNotImplemented } // CreateTable creates a table or a bucket in the database. func (s *SimpleDB) CreateTable([]byte) error { return ErrNotImplemented } // DeleteTable deletes a table or a bucket in the database. func (s *SimpleDB) DeleteTable([]byte) error { return ErrNotImplemented } ================================================ FILE: db/simple_test.go ================================================ package db import ( "testing" "github.com/smallstep/assert" ) func TestSimpleDB(t *testing.T) { db, err := newSimpleDB(nil) assert.FatalError(t, err) // Revoke assert.Equals(t, ErrNotImplemented, db.Revoke(nil)) // IsRevoked -- verify noop isRevoked, err := db.IsRevoked("foo") assert.False(t, isRevoked) assert.Nil(t, err) // StoreCertificate assert.Equals(t, ErrNotImplemented, db.StoreCertificate(nil)) // UseToken ok, err := db.UseToken("foo", "bar") assert.True(t, ok) assert.Nil(t, err) ok, err = db.UseToken("foo", "cat") assert.False(t, ok) assert.Nil(t, err) // Shutdown -- verify noop assert.FatalError(t, db.Shutdown()) ok, err = db.UseToken("foo", "cat") assert.False(t, ok) assert.Nil(t, err) } ================================================ FILE: debian/copyright ================================================ Format: http://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: step-ca Source: https://github.com/smallstep/certificates Files: * Copyright: 2021 Smallstep Labs, Inc. 
License: Apache 2.0 License: Apache 2.0 Copyright (c) 2021 Smallstep Labs, Inc. . Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at . http://www.apache.org/licenses/LICENSE-2.0 . Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: docker/Dockerfile ================================================ FROM golang:alpine AS builder WORKDIR /src COPY . . RUN apk add --no-cache curl git make libcap RUN make V=1 bin/step-ca RUN setcap CAP_NET_BIND_SERVICE=+eip bin/step-ca FROM smallstep/step-kms-plugin:cloud AS kms FROM smallstep/step-cli:latest COPY --from=builder /src/bin/step-ca /usr/local/bin/step-ca COPY --from=kms /usr/local/bin/step-kms-plugin /usr/local/bin/step-kms-plugin USER step ENV CONFIGPATH="/home/step/config/ca.json" ENV PWDPATH="/home/step/secrets/password" VOLUME ["/home/step"] STOPSIGNAL SIGTERM HEALTHCHECK CMD step ca health 2>/dev/null | grep "^ok" >/dev/null COPY docker/entrypoint.sh /entrypoint.sh ENTRYPOINT ["/bin/bash", "/entrypoint.sh"] CMD ["/usr/local/bin/step-ca", "--password-file", "/home/step/secrets/password", "/home/step/config/ca.json"] ================================================ FILE: docker/Dockerfile.hsm ================================================ FROM golang:trixie AS builder WORKDIR /src COPY . . 
RUN apt-get update RUN apt-get install -y --no-install-recommends \ gcc pkgconf libpcsclite-dev libcap2-bin RUN make V=1 GO_ENVS="CGO_ENABLED=1" bin/step-ca RUN setcap CAP_NET_BIND_SERVICE=+eip bin/step-ca FROM smallstep/step-kms-plugin:trixie AS kms FROM smallstep/step-cli:trixie COPY --from=builder /src/bin/step-ca /usr/local/bin/step-ca COPY --from=kms /usr/local/bin/step-kms-plugin /usr/local/bin/step-kms-plugin USER root RUN apt-get update RUN apt-get install -y --no-install-recommends opensc opensc-pkcs11 pcscd gnutls-bin libpcsclite1 p11-kit yubihsm-pkcs11 RUN mkdir -p /run/pcscd RUN chown step:step /run/pcscd USER step ENV CONFIGPATH="/home/step/config/ca.json" ENV PWDPATH="/home/step/secrets/password" VOLUME ["/home/step"] STOPSIGNAL SIGTERM HEALTHCHECK CMD step ca health 2>/dev/null | grep "^ok" >/dev/null COPY docker/entrypoint.sh /entrypoint.sh ENTRYPOINT ["/bin/bash", "/entrypoint.sh"] CMD ["/usr/local/bin/step-ca", "--password-file", "/home/step/secrets/password", "/home/step/config/ca.json"] ================================================ FILE: docker/entrypoint.sh ================================================ #!/bin/bash set -eo pipefail # Paraphrased from: # https://github.com/influxdata/influxdata-docker/blob/0d341f18067c4652dfa8df7dcb24d69bf707363d/influxdb/2.0/entrypoint.sh # (a repo with no LICENSE.md) export STEPPATH=$(step path) # List of env vars required for step ca init declare -ra REQUIRED_INIT_VARS=(DOCKER_STEPCA_INIT_NAME DOCKER_STEPCA_INIT_DNS_NAMES) # Ensure all env vars required to run step ca init are set. 
function init_if_possible () { local missing_vars=0 for var in "${REQUIRED_INIT_VARS[@]}"; do if [ -z "${!var}" ]; then missing_vars=1 fi done if [ ${missing_vars} = 1 ]; then >&2 echo "there is no ca.json config file; please run step ca init, or provide config parameters via DOCKER_STEPCA_INIT_ vars" else step_ca_init "${@}" fi } function generate_password () { set +o pipefail < /dev/urandom tr -dc A-Za-z0-9 | head -c40 echo set -o pipefail } # Initialize a CA if not already initialized function step_ca_init () { DOCKER_STEPCA_INIT_PROVISIONER_NAME="${DOCKER_STEPCA_INIT_PROVISIONER_NAME:-admin}" DOCKER_STEPCA_INIT_ADMIN_SUBJECT="${DOCKER_STEPCA_INIT_ADMIN_SUBJECT:-step}" DOCKER_STEPCA_INIT_ADDRESS="${DOCKER_STEPCA_INIT_ADDRESS:-:9000}" DOCKER_STEPCA_INIT_ROOT_FILE="${DOCKER_STEPCA_INIT_ROOT_FILE:-"/run/secrets/root_ca.crt"}" DOCKER_STEPCA_INIT_KEY_FILE="${DOCKER_STEPCA_INIT_KEY_FILE:-"/run/secrets/root_ca_key"}" DOCKER_STEPCA_INIT_KEY_PASSWORD_FILE="${DOCKER_STEPCA_INIT_KEY_PASSWORD_FILE:-"/run/secrets/root_ca_key_password"}" local -a setup_args=( --name "${DOCKER_STEPCA_INIT_NAME}" --dns "${DOCKER_STEPCA_INIT_DNS_NAMES}" --provisioner "${DOCKER_STEPCA_INIT_PROVISIONER_NAME}" --password-file "${STEPPATH}/password" --provisioner-password-file "${STEPPATH}/provisioner_password" --address "${DOCKER_STEPCA_INIT_ADDRESS}" ) if [ -n "${DOCKER_STEPCA_INIT_PASSWORD_FILE}" ]; then cat < "${DOCKER_STEPCA_INIT_PASSWORD_FILE}" > "${STEPPATH}/password" cat < "${DOCKER_STEPCA_INIT_PASSWORD_FILE}" > "${STEPPATH}/provisioner_password" elif [ -n "${DOCKER_STEPCA_INIT_PASSWORD}" ]; then echo "${DOCKER_STEPCA_INIT_PASSWORD}" > "${STEPPATH}/password" echo "${DOCKER_STEPCA_INIT_PASSWORD}" > "${STEPPATH}/provisioner_password" else generate_password > "${STEPPATH}/password" generate_password > "${STEPPATH}/provisioner_password" fi if [ -f "${DOCKER_STEPCA_INIT_ROOT_FILE}" ]; then setup_args=("${setup_args[@]}" --root "${DOCKER_STEPCA_INIT_ROOT_FILE}") fi if [ -f 
"${DOCKER_STEPCA_INIT_KEY_FILE}" ]; then setup_args=("${setup_args[@]}" --key "${DOCKER_STEPCA_INIT_KEY_FILE}") fi if [ -f "${DOCKER_STEPCA_INIT_KEY_PASSWORD_FILE}" ]; then setup_args=("${setup_args[@]}" --key-password-file "${DOCKER_STEPCA_INIT_KEY_PASSWORD_FILE}") fi if [ -n "${DOCKER_STEPCA_INIT_DEPLOYMENT_TYPE}" ]; then setup_args=("${setup_args[@]}" --deployment-type "${DOCKER_STEPCA_INIT_DEPLOYMENT_TYPE}") fi if [ -n "${DOCKER_STEPCA_INIT_WITH_CA_URL}" ]; then setup_args=("${setup_args[@]}" --with-ca-url "${DOCKER_STEPCA_INIT_WITH_CA_URL}") fi if [ "${DOCKER_STEPCA_INIT_SSH}" == "true" ]; then setup_args=("${setup_args[@]}" --ssh) fi if [ "${DOCKER_STEPCA_INIT_ACME}" == "true" ]; then setup_args=("${setup_args[@]}" --acme) fi if [ "${DOCKER_STEPCA_INIT_REMOTE_MANAGEMENT}" == "true" ]; then setup_args=("${setup_args[@]}" --remote-management --admin-subject "${DOCKER_STEPCA_INIT_ADMIN_SUBJECT}" ) fi step ca init "${setup_args[@]}" echo "" if [ "${DOCKER_STEPCA_INIT_REMOTE_MANAGEMENT}" == "true" ]; then echo "👉 Your CA administrative username is: ${DOCKER_STEPCA_INIT_ADMIN_SUBJECT}" fi echo "👉 Your CA administrative password is: $(< $STEPPATH/provisioner_password )" echo "🤫 This will only be displayed once." shred -u $STEPPATH/provisioner_password mv $STEPPATH/password $PWDPATH } if [ -f /usr/sbin/pcscd ]; then /usr/sbin/pcscd fi if [ ! -f "${STEPPATH}/config/ca.json" ]; then init_if_possible fi exec "${@}" ================================================ FILE: errs/error.go ================================================ package errs import ( "encoding/json" "fmt" "net/http" "github.com/pkg/errors" "github.com/smallstep/certificates/api/log" "github.com/smallstep/certificates/api/render" ) // Option modifies the Error type. type Option func(e *Error) error // withDefaultMessage returns an Option that modifies the error by overwriting // the message only if it is empty. 
Having withDefaultMessage and // withFormattedMessage avoid vet errors when the "format" passed to // "fmt.Sprintf" is not a constant. func withDefaultMessage(message string) Option { return func(e *Error) error { if e.Msg != "" { return e } e.Msg = message return e } } // withFormattedMessage returns an Option that modifies the error by overwriting // the formatted message only if it is empty. func withFormattedMessage(format string, args ...interface{}) Option { return func(e *Error) error { if e.Msg != "" { return e } e.Msg = fmt.Sprintf(format, args...) return e } } // WithMessage returns an Option that modifies the error by overwriting the // message with the formatted string. func WithMessage(format string, args ...interface{}) Option { return func(e *Error) error { e.Msg = fmt.Sprintf(format, args...) return e } } // WithErrorMessage returns an Option that modifies the error by overwriting the // message with the error string. func WithErrorMessage() Option { return func(e *Error) error { e.Msg = e.Error() return e } } // WithKeyVal returns an Option that adds the given key-value pair to the // Error details. This is helpful for debugging errors. func WithKeyVal(key string, val interface{}) Option { return func(e *Error) error { if e.Details == nil { e.Details = make(map[string]interface{}) } e.Details[key] = val return e } } // Error represents the CA API errors. type Error struct { Status int Err error Msg string Details map[string]interface{} RequestID string `json:"-"` } // ErrorResponse represents an error in JSON format. type ErrorResponse struct { Status int `json:"status"` Message string `json:"message"` } // Unwrap implements the Unwrap interface and returns the original error. func (e *Error) Unwrap() error { return e.Err } // Cause implements the errors.Causer interface and returns the original error. func (e *Error) Cause() error { return e.Err } // Error implements the error interface and returns the error string. 
func (e *Error) Error() string {
	return e.Err.Error()
}

// StatusCode implements the StatusCoder interface and returns the HTTP response
// code.
func (e *Error) StatusCode() int {
	return e.Status
}

// Message returns the user friendly message if one is set, and falls back to
// the string of the underlying error otherwise.
func (e *Error) Message() string {
	if e.Msg != "" {
		return e.Msg
	}
	return e.Err.Error()
}

// Wrap returns an error annotating e with a stack trace at the point Wrap is
// called, and the supplied message. If e is nil, Wrap returns nil.
//
// Only the Option values found in args are used; any other values in args are
// ignored because the message m is not a format string. If e already is (or
// wraps) an *Error, the message is added to that Error's underlying error in
// place and that Error is the one handed to StatusCodeError.
func Wrap(status int, e error, m string, args ...interface{}) error {
	if e == nil {
		return nil
	}
	_, opts := splitOptionArgs(args)
	var err *Error
	if errors.As(e, &err) {
		err.Err = errors.Wrap(err.Err, m)
		e = err
	} else {
		e = errors.Wrap(e, m)
	}
	return StatusCodeError(status, e, opts...)
}

// Wrapf returns an error annotating e with a stack trace at the point Wrapf is
// called, and the formatted message. If e is nil, Wrapf returns nil.
//
// args may mix format arguments and Options: the leading non-Option values
// are used to format the message and the Options are applied to the resulting
// error. If e already is (or wraps) an *Error, the message is added to that
// Error's underlying error in place and that Error is the one handed to
// StatusCodeError.
func Wrapf(status int, e error, format string, args ...interface{}) error {
	if e == nil {
		return nil
	}
	as, opts := splitOptionArgs(args)
	var err *Error
	if errors.As(e, &err) {
		// Format with the format-only arguments. Passing the full args slice
		// would hand the Option values to the formatter and pollute the
		// message with "%!(EXTRA ...)" output.
		err.Err = errors.Wrapf(err.Err, format, as...)
		e = err
	} else {
		e = errors.Wrapf(e, format, as...)
	}
	return StatusCodeError(status, e, opts...)
}

// MarshalJSON implements the json.Marshaler interface for the Error struct.
// The message is the user friendly message when set, or the standard HTTP
// status text for the status code otherwise; the underlying error is never
// serialized.
func (e *Error) MarshalJSON() ([]byte, error) {
	var msg string
	if e.Msg != "" {
		msg = e.Msg
	} else {
		msg = http.StatusText(e.Status)
	}
	return json.Marshal(&ErrorResponse{Status: e.Status, Message: msg})
}

// UnmarshalJSON implements the json.Unmarshaler interface for the Error
// struct. The decoded message becomes the underlying error; Msg is left
// untouched.
func (e *Error) UnmarshalJSON(data []byte) error {
	var er ErrorResponse
	if err := json.Unmarshal(data, &er); err != nil {
		return err
	}
	e.Status = er.Status
	e.Err = fmt.Errorf("%s", er.Message)
	return nil
}

// Format implements the fmt.Formatter interface.
func (e *Error) Format(f fmt.State, c rune) { var fe fmt.Formatter if errors.As(e.Err, &fe) { fe.Format(f, c) return } fmt.Fprint(f, e.Err.Error()) } // Messenger is a friendly message interface that errors can implement. type Messenger interface { Message() string } // StatusCodeError selects the proper error based on the status code. func StatusCodeError(code int, e error, opts ...Option) error { switch code { case http.StatusBadRequest: opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) return NewErr(http.StatusBadRequest, e, opts...) case http.StatusUnauthorized: return UnauthorizedErr(e, opts...) case http.StatusForbidden: opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) return NewErr(http.StatusForbidden, e, opts...) case http.StatusInternalServerError: return InternalServerErr(e, opts...) case http.StatusNotImplemented: return NotImplementedErr(e, opts...) default: return UnexpectedErr(code, e, opts...) } } const ( seeLogs = "Please see the certificate authority logs for more info." defaultMsg = "The requested could not be completed. " + seeLogs // BadRequestDefaultMsg 400 default msg BadRequestDefaultMsg = "The request could not be completed; malformed or missing data. " + seeLogs // UnauthorizedDefaultMsg 401 default msg UnauthorizedDefaultMsg = "The request lacked necessary authorization to be completed. " + seeLogs // ForbiddenDefaultMsg 403 default msg ForbiddenDefaultMsg = "The request was forbidden by the certificate authority. " + seeLogs // NotFoundDefaultMsg 404 default msg NotFoundDefaultMsg = "The requested resource could not be found. " + seeLogs // InternalServerErrorDefaultMsg 500 default msg InternalServerErrorDefaultMsg = "The certificate authority encountered an Internal Server Error. " + seeLogs // NotImplementedDefaultMsg 501 default msg NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. 
" + seeLogs ) func defaultMessage(status int) string { switch status { case http.StatusBadRequest: return BadRequestDefaultMsg case http.StatusUnauthorized: return UnauthorizedDefaultMsg case http.StatusForbidden: return ForbiddenDefaultMsg case http.StatusNotFound: return NotFoundDefaultMsg case http.StatusInternalServerError: return InternalServerErrorDefaultMsg case http.StatusNotImplemented: return NotImplementedDefaultMsg default: return defaultMsg } } const ( // BadRequestPrefix is the prefix added to the bad request messages that are // directly sent to the cli. BadRequestPrefix = "The request could not be completed: " // ForbiddenPrefix is the prefix added to the forbidden messates that are // sent to the cli. ForbiddenPrefix = "The request was forbidden by the certificate authority: " ) func formatMessage(status int, msg string) string { switch status { case http.StatusBadRequest: return BadRequestPrefix + msg + "." case http.StatusForbidden: return ForbiddenPrefix + msg + "." default: return msg } } // splitOptionArgs splits the variadic length args into string formatting args // and Option(s) to apply to an Error. func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { indexOptionStart := -1 for i, a := range args { if _, ok := a.(Option); ok { indexOptionStart = i break } } if indexOptionStart < 0 { return args, []Option{} } opts := []Option{} // Ignore any non-Option args that come after the first Option. for _, o := range args[indexOptionStart:] { if opt, ok := o.(Option); ok { opts = append(opts, opt) } } return args[:indexOptionStart], opts } // New creates a new http error with the given status and message. func New(status int, format string, args ...interface{}) error { msg := fmt.Sprintf(format, args...) return &Error{ Status: status, Msg: formatMessage(status, msg), Err: errors.New(msg), } } // NewError creates a new http error with the given error and message. 
func NewError(status int, err error, format string, args ...interface{}) error { var e *Error if errors.As(err, &e) { return err } msg := fmt.Sprintf(format, args...) var ste log.StackTracedError if !errors.As(err, &ste) { err = errors.Wrap(err, msg) } return &Error{ Status: status, Msg: formatMessage(status, msg), Err: err, } } // NewErr returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func NewErr(status int, err error, opts ...Option) error { var e *Error if !errors.As(err, &e) { var ste render.StatusCodedError if errors.As(err, &ste) { e = &Error{Status: ste.StatusCode(), Err: err} } else { e = &Error{Status: status, Err: err} } } for _, o := range opts { o(e) } return e } // Errorf creates a new error using the given format and status code. func Errorf(code int, format string, args ...interface{}) error { as, opts := splitOptionArgs(args) opts = append(opts, withDefaultMessage(defaultMessage(code))) e := &Error{Status: code, Err: fmt.Errorf(format, as...)} for _, o := range opts { o(e) } return e } // ApplyOptions applies the given options to the error if is the type *Error. // TODO(mariano): try to get rid of this. func ApplyOptions(err error, opts ...interface{}) error { var e *Error if errors.As(err, &e) { _, o := splitOptionArgs(opts) for _, fn := range o { fn(e) } } return err } // InternalServer creates a 500 error with the given format and arguments. func InternalServer(format string, args ...interface{}) error { args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) return Errorf(http.StatusInternalServerError, format, args...) } // InternalServerErr returns a 500 error with the given error. func InternalServerErr(err error, opts ...Option) error { opts = append(opts, withDefaultMessage(InternalServerErrorDefaultMsg)) return NewErr(http.StatusInternalServerError, err, opts...) } // NotImplemented creates a 501 error with the given format and arguments. 
func NotImplemented(format string, args ...interface{}) error {
	// Appended last so it only applies when no Option in args set a message.
	args = append(args, withDefaultMessage(NotImplementedDefaultMsg))
	return Errorf(http.StatusNotImplemented, format, args...)
}

// NotImplementedErr returns a 501 error with the given error. The not
// implemented default message is used unless an Option sets one.
func NotImplementedErr(err error, opts ...Option) error {
	opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
	return NewErr(http.StatusNotImplemented, err, opts...)
}

// BadRequest creates a 400 error with the given format and arguments. The
// formatted message is also used, with the bad request prefix, as the
// user-facing message.
func BadRequest(format string, args ...interface{}) error {
	return New(http.StatusBadRequest, format, args...)
}

// BadRequestErr returns a 400 error with the given error, using the formatted
// message as the user-facing message. If err already is (or wraps) an *Error
// it is returned unchanged.
func BadRequestErr(err error, format string, args ...interface{}) error {
	return NewError(http.StatusBadRequest, err, format, args...)
}

// Unauthorized creates a 401 error with the given format and arguments. The
// unauthorized default message is used unless an Option in args sets one.
func Unauthorized(format string, args ...interface{}) error {
	// Appended last so it only applies when no Option in args set a message.
	args = append(args, withDefaultMessage(UnauthorizedDefaultMsg))
	return Errorf(http.StatusUnauthorized, format, args...)
}

// UnauthorizedErr returns a 401 error with the given error. The unauthorized
// default message is used unless an Option sets one.
func UnauthorizedErr(err error, opts ...Option) error {
	opts = append(opts, withDefaultMessage(UnauthorizedDefaultMsg))
	return NewErr(http.StatusUnauthorized, err, opts...)
}

// Forbidden creates a 403 error with the given format and arguments. The
// formatted message is also used, with the forbidden prefix, as the
// user-facing message.
func Forbidden(format string, args ...interface{}) error {
	return New(http.StatusForbidden, format, args...)
}

// ForbiddenErr returns a 403 error with the given error, using the formatted
// message as the user-facing message. If err already is (or wraps) an *Error
// it is returned unchanged.
func ForbiddenErr(err error, format string, args ...interface{}) error {
	return NewError(http.StatusForbidden, err, format, args...)
}

// NotFound creates a 404 error with the given format and arguments. The not
// found default message is used unless an Option in args sets one.
func NotFound(format string, args ...interface{}) error {
	// Appended last so it only applies when no Option in args set a message.
	args = append(args, withDefaultMessage(NotFoundDefaultMsg))
	return Errorf(http.StatusNotFound, format, args...)
}

// NotFoundErr returns a 404 error with the given error.
func NotFoundErr(err error, opts ...Option) error { opts = append(opts, withDefaultMessage(NotFoundDefaultMsg)) return NewErr(http.StatusNotFound, err, opts...) } // UnexpectedErr will be used when the certificate authority makes an outgoing // request and receives an unhandled status code. func UnexpectedErr(code int, err error, opts ...Option) error { opts = append(opts, withFormattedMessage("The certificate authority received an unexpected HTTP status code - '%d'. "+seeLogs, code)) return NewErr(code, err, opts...) } ================================================ FILE: errs/errors_test.go ================================================ package errs import ( "errors" "fmt" "net/http" "testing" "github.com/stretchr/testify/assert" ) func TestError_MarshalJSON(t *testing.T) { type fields struct { Status int Err error } tests := []struct { name string fields fields want []byte wantErr bool }{ {"ok", fields{400, fmt.Errorf("bad request")}, []byte(`{"status":400,"message":"Bad Request"}`), false}, {"ok no error", fields{500, nil}, []byte(`{"status":500,"message":"Internal Server Error"}`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := &Error{ Status: tt.fields.Status, Err: tt.fields.Err, } got, err := e.MarshalJSON() if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) return } assert.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func TestError_UnmarshalJSON(t *testing.T) { type args struct { data []byte } tests := []struct { name string args args expected *Error wantErr bool }{ {"ok", args{[]byte(`{"status":400,"message":"bad request"}`)}, &Error{Status: 400, Err: fmt.Errorf("bad request")}, false}, {"fail", args{[]byte(`{"status":"400","message":"bad request"}`)}, &Error{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := new(Error) err := e.UnmarshalJSON(tt.args.data) if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) assert.Equal(t, tt.expected, e) }) } } func 
TestError_Unwrap(t *testing.T) { err := errors.New("wrapped error") tests := []struct { name string error error want string }{ {"ok New", New(http.StatusBadRequest, "some error"), "some error"}, {"ok New v-wrap", New(http.StatusBadRequest, "some error: %v", err), "some error: wrapped error"}, {"ok NewError", NewError(http.StatusBadRequest, err, "some error"), "some error: wrapped error"}, {"ok NewErr", NewErr(http.StatusBadRequest, err), "wrapped error"}, {"ok NewErr wit message", NewErr(http.StatusBadRequest, err, WithMessage("some message")), "wrapped error"}, {"ok Errorf", Errorf(http.StatusBadRequest, "some error: %w", err), "some error: wrapped error"}, {"ok Errorf v-wrap", Errorf(http.StatusBadRequest, "some error: %v", err), "some error: wrapped error"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := errors.Unwrap(tt.error) assert.EqualError(t, got, tt.want) }) } } type customError struct { Message string } func (e *customError) Error() string { return e.Message } func TestError_Unwrap_As(t *testing.T) { err := &customError{Message: "wrapped error"} tests := []struct { name string error error want bool wantErr *customError }{ {"ok NewError", NewError(http.StatusBadRequest, err, "some error"), true, err}, {"ok NewErr", NewErr(http.StatusBadRequest, err), true, err}, {"ok NewErr wit message", NewErr(http.StatusBadRequest, err, WithMessage("some message")), true, err}, {"ok Errorf", Errorf(http.StatusBadRequest, "some error: %w", err), true, err}, {"fail New", New(http.StatusBadRequest, "some error"), false, nil}, {"fail New v-wrap", New(http.StatusBadRequest, "some error: %v", err), false, nil}, {"fail Errorf", Errorf(http.StatusBadRequest, "some error"), false, nil}, {"fail Errorf v-wrap", Errorf(http.StatusBadRequest, "some error: %v", err), false, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var cerr *customError assert.Equal(t, tt.want, errors.As(tt.error, &cerr)) assert.Equal(t, tt.wantErr, cerr) }) } } 
func TestErrorf(t *testing.T) { tests := []struct { name string code int format string args []any want error }{ {"bad request", 400, "test error string", nil, &Error{ Status: 400, Err: errors.New("test error string"), Msg: BadRequestDefaultMsg, }}, {"unauthorized", 401, "test error string", nil, &Error{ Status: 401, Err: errors.New("test error string"), Msg: UnauthorizedDefaultMsg, }}, {"forbidden", 403, "test error string", nil, &Error{ Status: 403, Err: errors.New("test error string"), Msg: ForbiddenDefaultMsg, }}, {"not found", 404, "test error string", nil, &Error{ Status: 404, Err: errors.New("test error string"), Msg: NotFoundDefaultMsg, }}, {"internal server error", 500, "test error string", nil, &Error{ Status: 500, Err: errors.New("test error string"), Msg: InternalServerErrorDefaultMsg, }}, {"not implemented", 501, "test error string", nil, &Error{ Status: 501, Err: errors.New("test error string"), Msg: NotImplementedDefaultMsg, }}, {"other", 502, "test error string", nil, &Error{ Status: 502, Err: errors.New("test error string"), Msg: defaultMsg, }}, {"formatted args", 401, "test error string: %s", []any{"some reason"}, &Error{ Status: 401, Err: errors.New("test error string: some reason"), Msg: UnauthorizedDefaultMsg, }}, {"WithMessage", 403, "test error string", []any{WithMessage("%s failed", "something")}, &Error{ Status: 403, Err: errors.New("test error string"), Msg: "something failed", }}, {"WithErrorMessage", 404, "test error string", []any{WithErrorMessage()}, &Error{ Status: 404, Err: errors.New("test error string"), Msg: "test error string", }}, {"WithKeyValue", 500, "test error string", []any{WithKeyVal("foo", 1), WithKeyVal("bar", "zar")}, &Error{ Status: 500, Err: errors.New("test error string"), Msg: InternalServerErrorDefaultMsg, Details: map[string]interface{}{"foo": 1, "bar": "zar"}, }}, {"withDefaultMessage", 501, "test error string", []any{withDefaultMessage("some message")}, &Error{ Status: 501, Err: errors.New("test error string"), 
Msg: "some message", }}, {"withFormattedMessage", 502, "test error string", []any{withFormattedMessage("some message: %s", "the reason")}, &Error{ Status: 502, Err: errors.New("test error string"), Msg: "some message: the reason", }}, {"WithMessage and withDefaultMessage", 500, "test error string", []any{WithMessage("the message"), withDefaultMessage("some message")}, &Error{ Status: 500, Err: errors.New("test error string"), Msg: "the message", }}, {"WithErrorMessage and withFormattedMessage", 500, "test error string", []any{WithErrorMessage(), withFormattedMessage("some message: %s", "the reason")}, &Error{ Status: 500, Err: errors.New("test error string"), Msg: "test error string", }}, {"formatted args and withMessage", 500, "test error string: %s, code %d", []any{"reason", 1234, WithMessage("the message")}, &Error{ Status: 500, Err: errors.New("test error string: reason, code 1234"), Msg: "the message", }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotErr := Errorf(tt.code, tt.format, tt.args...) assert.Equal(t, tt.want, gotErr) }) } } ================================================ FILE: examples/README.md ================================================ # Examples ## Basic client usage The basic-client example shows the functionality of the `ca.Client` type. The methods work as an SDK for integrating services with the Certificate Authority (CA). In [basic-client/client.go](/examples/basic-client/client.go) we see the initialization of a client: ```go client, err := ca.NewClient("https://localhost:9000", ca.WithRootSHA256("84a033e84196f73bd593fad7a63e509e57fd982f02084359c4e8c5c864efc27d")) ``` The previous code uses the CA address and the root certificate fingerprint. The CA url will be present in the token, and the root fingerprint can be present too if the `--root root_ca.crt` option is used in the creation of the token. 
If the token does contain the root fingerprint then it is simpler to use: ```go client, err := ca.Bootstrap(token) ``` After the initialization, there are examples of all the client methods. These methods are a convenient way to use the CA API. The first method, `Health`, returns the status of the CA server. If the server is up it will return `{"status":"ok"}`. ```go health, err := client.Health() // Health is a struct created from the JSON response {"status": "ok"} ``` The next method `Root` is used to get and verify the root certificate. We pass a fingerprint and it downloads the root certificate from the CA and verifies that the fingerprint matches. This method uses an insecure HTTP client as it might be used in the initialization of the client, but the response is considered secure because we have compared against the expected digest. ```go root, err := client.Root("84a033e84196f73bd593fad7a63e509e57fd982f02084359c4e8c5c864efc27d") ``` Next we have the most important method: `Sign`. `Sign` will authorize and sign a CSR (Certificate Signing Request) that we provide. To authorize this request we use a provisioning token issued by an authorized provisioner. You can build your own certificate request and add it in the `*api.SignRequest`, but our CA SDK contains a method that will generate a secure random key and create a CSR - combining the key with the information provided in the provisioning token. ```go // Create a CSR from a token and return the SignRequest, the private key, and an // error if something failed. req, pk, err := ca.CreateSignRequest(token) if err != nil { ... } // Do the Sign request and return the signed certificate. sign, err := client.Sign(req) if err != nil { ... } ``` Next is the `Renew` method which is used to (you guessed it!) renew certificates. Certificate renewal relies on an mTLS connection that uses an existing certificate. So, as input we will need to pass a transport with the current certificate.
```go // Get a cancelable context to stop the renewal goroutines and timers. ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Create a transport with the sign response and the private key. tr, err := client.Transport(ctx, sign, pk) if err != nil { ... } // Renew the certificate. The return type is equivalent to the Sign method. renew, err := client.Renew(tr) if err != nil { ... } ``` The following methods are for inspecting provisioners: one returns a list of provisioners, and the other returns the encrypted key of a single provisioner. ```go // Without options it will return the first 20 provisioners. provisioners, err := client.Provisioners() // We can also set a limit up to 100. provisioners, err := client.Provisioners(ca.WithProvisionerLimit(100)) // With a pagination cursor. provisioners, err := client.Provisioners(ca.WithProvisionerCursor("1f18c1ecffe54770e9107ce7b39b39735")) // Or combine both. provisioners, err := client.Provisioners( ca.WithProvisionerCursor("1f18c1ecffe54770e9107ce7b39b39735"), ca.WithProvisionerLimit(100), ) // Return the encrypted key of one of the returned provisioners. The key // returned is an encrypted JWE with the private key used to sign tokens. key, err := client.ProvisionerKey("DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk") ``` The following example shows how to create a tls.Config object that can be injected into servers and clients. By default, these methods will spin off goroutines that auto-renew a certificate once (approximately) two thirds of the duration of the certificate has passed. ```go // Get a cancelable context to stop the renewal goroutines and timers. ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Get tls.Config for a server. tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk) // Get tls.Config for a client. tlsConfig, err := client.GetClientTLSConfig(ctx, sign, pk) // Get an http.Transport for a client; this can be used as a http.RoundTripper // in an http.Client.
tr, err := client.Transport(ctx, sign, pk) ``` To run the example you need to start the certificate authority: ```sh certificates $ bin/step-ca examples/pki/config/ca.json 2018/11/02 18:29:25 Serving HTTPS on :9000 ... ``` Then run client.go with a new token: ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 certificates $ go run examples/basic-client/client.go $(step ca token client.smallstep.com) ``` ## Bootstrap Client & Server In this example we are going to run the CA alongside a simple Server using TLS and a simple client making TLS requests to the server. The examples directory already contains a sample pki configuration with the password `password` hardcoded, but you can create your own using `step ca init`. These examples show the use of some other helper methods - simple ways to create TLS configured http.Server and http.Client objects. The methods are `BootstrapServer` and `BootstrapClient`. ```go // Get a cancelable context to stop the renewal goroutines and timers. ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Create an http.Server that requires a client certificate srv, err := ca.BootstrapServer(ctx, token, &http.Server{ Addr: ":8443", Handler: handler, }) if err != nil { panic(err) } srv.ListenAndServeTLS("", "") ``` ```go // Get a cancelable context to stop the renewal goroutines and timers. ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Create an http.Server that does not require a client certificate srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{ Addr: ":8443", Handler: handler, }, ca.VerifyClientCertIfGiven()) if err != nil { panic(err) } srv.ListenAndServeTLS("", "") ``` ```go // Get a cancelable context to stop the renewal goroutines and timers.
ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Create an http.Client client, err := ca.BootstrapClient(ctx, token) if err != nil { panic(err) } resp, err := client.Get("https://localhost:8443") ``` We will demonstrate the mTLS configuration in a different example. In this example we will configure the server to only verify client certificates if they are provided. To begin with, let's start the Step CA: ```sh certificates $ bin/step-ca examples/pki/config/ca.json 2018/11/02 18:29:25 Serving HTTPS on :9000 ... ``` Next we will start the bootstrap-tls-server and enter `password` when prompted for the provisioner password: ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 certificates $ go run examples/bootstrap-tls-server/server.go $(step ca token localhost) ✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com) Please enter the password to decrypt the provisioner key: Listening on :8443 ... ``` Let's try to cURL our new bootstrap server with the system certificates bundle as our root. It should fail. ``` certificates $ curl https://localhost:8443 curl: (60) SSL certificate problem: unable to get local issuer certificate More details here: https://curl.haxx.se/docs/sslcerts.html curl performs SSL certificate verification by default, using a "bundle" of Certificate Authority (CA) public keys (CA certs). If the default bundle file isn't adequate, you can specify an alternate file using the --cacert option. If this HTTPS server uses a certificate signed by a CA represented in the bundle, the certificate verification probably failed due to a problem with the certificate (it might be expired, or the name might not match the domain name in the URL). If you'd like to turn off curl's verification of the certificate, use the -k (or --insecure) option. HTTPS-proxy has similar options --proxy-cacert and --proxy-insecure.
``` Now let's use the root certificate generated for the Step PKI. It should work. ```sh certificates $ curl --cacert examples/pki/secrets/root_ca.crt https://localhost:8443 Hello nobody at 2018-11-03 01:49:25.66912 +0000 UTC!!! ``` Notice that in the response we see `nobody`. This is because the server did not detect a TLS client configuration. But if we create a client with its own certificate (generated by the Step CA), we should see the Common Name of the client certificate: ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 certificates $ go run examples/bootstrap-client/client.go $(step ca token Mike) ✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com) Please enter the password to decrypt the provisioner key: Server responded: Hello Mike at 2018-11-03 01:52:52.678215 +0000 UTC!!! Server responded: Hello Mike at 2018-11-03 01:52:53.681563 +0000 UTC!!! Server responded: Hello Mike at 2018-11-03 01:52:54.682787 +0000 UTC!!! ... ``` ## Bootstrap mTLS Client & Server This example demonstrates a stricter configuration of the bootstrap-server. Here we configure the server to require mTLS (mutual TLS) with a valid client certificate. As always, we begin by starting the CA: ```sh certificates $ bin/step-ca examples/pki/config/ca.json 2018/11/02 18:29:25 Serving HTTPS on :9000 ... ``` Next we start the mTLS server and we enter `password` when prompted for the provisioner password: ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 certificates $ go run examples/bootstrap-mtls-server/server.go $(step ca token localhost) ✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com) Please enter the password to decrypt the provisioner key: Listening on :8443 ... ``` Now that the server is configured to require mTLS, cURL-ing should fail even if we use the correct root certificate bundle.
```sh certificates $ curl --cacert examples/pki/secrets/root_ca.crt https://localhost:8443 curl: (35) error:1401E412:SSL routines:CONNECT_CR_FINISHED:sslv3 alert bad certificate ``` However, if we use our client (which requests a certificate from the Step CA when it starts): ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 certificates $ go run examples/bootstrap-client/client.go $(step ca token Mike) ✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com) Please enter the password to decrypt the provisioner key: Server responded: Hello Mike at 2018-11-07 21:54:00.140022 +0000 UTC!!! Server responded: Hello Mike at 2018-11-07 21:54:01.140827 +0000 UTC!!! Server responded: Hello Mike at 2018-11-07 21:54:02.141578 +0000 UTC!!! ... ``` ## Certificate rotation We can use the bootstrap-server to demonstrate certificate rotation. We've added a second provisioner, named `mike@smallstep.com`, to the CA configuration. This provisioner has a default certificate duration of 2 minutes. Let's run the server, and inspect the certificate. We should be able to see the certificate rotate once approximately 2/3rds of its lifespan has passed. ```sh certificates $ export STEPPATH=examples/pki certificates $ export STEP_CA_URL=https://localhost:9000 certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost) ✔ Key ID: YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs (mike@smallstep.com) Please enter the password to decrypt the provisioner key: Listening on :8443 ... ``` In this case, the certificate will rotate after 74-80 seconds. The exact formula is `duration-duration/3-rand(duration/20)` (`duration=120` in our example). We can use the following command to check the certificate expiration and to make sure the certificate changes after 74-80 seconds.
```sh certificates $ step certificate inspect --insecure https://localhost:8443 ``` ## NGINX with Step CA certificates The example under the `docker` directory shows how to combine the Step CA with NGINX to serve or proxy services using certificates created by the Step CA. This example creates 3 different docker images: * nginx-test: docker image with NGINX and a script using inotify-tools to watch for changes in the certificate to reload NGINX. * step-ca-test: docker image with the Step CA * step-renewer-test: docker image with the step cli tool - it creates the certificate and sets a cron that renews the certificate (the cron runs every minute for testing purposes). To run this test you need to have the docker daemon running. With docker running, switch to the `examples/docker` directory and run `make`: ``` certificates $ cd examples/docker/ docker $ make GOOS=linux go build -o ca/step-ca github.com/smallstep/certificates/cmd/step-ca GOOS=linux go build -o renewer/step github.com/smallstep/cli/cmd/step docker build -t nginx-test:latest nginx ... docker-compose up WARNING: The Docker Engine you're using is running in swarm mode. Compose does not use swarm mode to deploy services to multiple nodes in a swarm. All containers will be scheduled on the current node. To deploy your application across the swarm, use `docker stack deploy`. Creating network "docker_default" with the default driver Creating docker_ca_1 ... done Creating docker_renewer_1 ... done Creating docker_nginx_1 ... done Attaching to docker_ca_1, docker_renewer_1, docker_nginx_1 ca_1 | 2018/11/12 19:39:16 Serving HTTPS on :443 ... nginx_1 | Setting up watches. nginx_1 | Watches established. ... ``` Make will build the binaries for step and step-ca, create the images, create the containers and start them using docker compose.
NGINX will be listening on your local machine on https://localhost:4443, but to make sure the cert is right we need to add the following entry to `/etc/hosts`: ``` 127.0.0.1 nginx ``` Now we can use cURL to verify: ```sh docker $ curl --cacert ca/pki/secrets/root_ca.crt https://nginx:4443/ Welcome to nginx!

Welcome to nginx!

If you see this page, the nginx web server is successfully installed and working. Further configuration is required.

For online documentation and support please refer to nginx.org.
Commercial support is available at nginx.com.

Thank you for using nginx.

``` We can use `make inspect` to witness the certificate being rotated every minute. ```sh docker $ make inspect | head step certificate inspect https://localhost:4443 --insecure Certificate: Data: Version: 3 (0x2) Serial Number: 220353801925419530569669982276277771655 (0xa5c6993a7e110e6f009c83c79edc1d87) Signature Algorithm: ECDSA-SHA256 Issuer: CN=Smallstep Intermediate CA Validity Not Before: Nov 10 02:13:00 2018 UTC Not After : Nov 11 02:13:00 2018 UTC docker $ make inspect | head step certificate inspect https://localhost:4443 --insecure Certificate: Data: Version: 3 (0x2) Serial Number: 207756171799719353821615361892302471392 (0x9c4c621c04d3e8be401ff0d14c5440e0) Signature Algorithm: ECDSA-SHA256 Issuer: CN=Smallstep Intermediate CA Validity Not Before: Nov 10 02:14:00 2018 UTC Not After : Nov 11 02:14:00 2018 UTC ``` Finally, to cleanup the containers and volumes created in this demo use `make down`: ```sh docker $ make down docker-compose down Stopping docker_nginx_1 ... done Stopping docker_renewer_1 ... done Stopping docker_ca_1 ... done Removing docker_nginx_1 ... done Removing docker_renewer_1 ... done Removing docker_ca_1 ... done Removing network docker_default ``` ## Basic Federation The [basic-federation example](basic-federation) showcases how to securely facilitate communication between relying parties of multiple autonomous certificate authorities. Federation is what's required when services are spread between multiple independent Kubernetes clusters, public clouds, and/or serverless cloud functions to enable service communication across boundaries. This example uses a pre-generated PKI (public/private key material). Do not use pre-generated PKIs for dev, staging, or production purposes outside of this example. ### Launch Online CAs Bring up two online CAs; `Cloud CA` and `Kubernetes CA`. 
```bash $ step-ca ./pki/cloud/config/ca.federated.json Please enter the password to decrypt intermediate_ca_key: password 2019/01/22 13:38:52 Serving HTTPS on :1443 ... ``` ```bash $ step-ca ./pki/kubernetes/config/ca.federated.json Please enter the password to decrypt intermediate_ca_key: password 2019/01/22 13:39:44 Serving HTTPS on :2443 ... ``` Notice the difference between the two configuration options below. `Cloud CA` will list `Kubernetes CA` in the `federatedRoots` section and vice versa for the federated options. ```bash $ diff pki/cloud/config/ca.json pki/cloud/config/ca.federated.json 3c3 < "federatedRoots": [], --- > "federatedRoots": ["pki/cloud/certs/kubernetes_root_ca.crt"], ``` ### Bring up Demo Server This demo server leverages step's [SDK](https://godoc.org/github.com/smallstep/certificates/ca) to obtain certs, automatically renew them, and fetch a bundle of trusted roots. When it starts up it will report what root certificates it will use to authenticate client certs. ```bash go run server/main.go $(step ca token \ --ca-url https://localhost:1443 \ --root ./pki/cloud/certs/root_ca.crt \ 127.0.0.1) ✔ Key ID: EE1ZiqkMaxsUdpz8SCSkRBzwK9TWUoidQnMnJ8Eryn8 (sebastian@smallstep.com) ✔ Please enter the password to decrypt the provisioner key: password Server is using federated root certificates Accepting certs anchored in CN=Smallstep Public Cloud Root CA Accepting certs anchored in CN=Smallstep Kubernetes Root CA Listening on :8443 ... ``` ### Run Demo Client Similarly step's [SDK](https://godoc.org/github.com/smallstep/certificates/ca) provides a client option to mutually authenticate connections to servers. It automatically handles cert bootstrapping, renewal, and fetches a bundle of trusted roots. The demo client will send HTTP requests to the demo server periodically (every 5s). 
```bash $ go run client/main.go $(step ca token sdk_client \ --ca-url https://localhost:2443 \ --root ./pki/kubernetes/certs/root_ca.crt) ✔ Key ID: S5gYgpeqcIAgc1Zr4myZXpgJ_Ao4ryS6F6wqg9o8RYo (sebastian@smallstep.com) ✔ Please enter the password to decrypt the provisioner key: password Server responded: Hello sdk_client (cert issued by 'Smallstep Kubernetes Root CA') at 2019-01-23 00:51:38.576648 +0000 UTC ``` ### Curl as Client While the demo client provides a convenient way to periodically send requests to the demo server, curl in combination with a client cert from `Kubernetes CA` can be used to hit the server instead: ```bash $ step ca certificate kube_client kube_client.crt kube_client.key \ --ca-url https://localhost:2443 \ --root pki/kubernetes/certs/root_ca.crt ✔ Key ID: S5gYgpeqcIAgc1Zr4myZXpgJ_Ao4ryS6F6wqg9o8RYo (sebastian@smallstep.com) ✔ Please enter the password to decrypt the provisioner key: ✔ CA: https://localhost:2443/1.0/sign ✔ Certificate: kube_client.crt ✔ Private Key: kube_client.key ``` Federation relies on a bundle of multiple trusted roots which needs to be fetched before being passed to curl. ```bash $ step ca federation --ca-url https://localhost:1443 \ --root pki/cloud/certs/root_ca.crt \ federated.pem The federation certificate bundle has been saved in federated.pem. ``` Passing the cert (issued by `Kubernetes CA`) into curl using the appropriate command line flags: ```bash $ curl -i --cacert federated.pem \ --cert kube_client.crt \ --key kube_client.key \ https://127.0.0.1:8443 HTTP/2 200 content-type: text/plain; charset=utf-8 content-length: 105 date: Mon, 28 Jan 2019 15:24:54 GMT Hello kube_client (cert issued by 'Smallstep Kubernetes Root CA') at 2019-01-28 15:24:54.864373 +0000 UTC ``` Since the demo server is enrolled with the federated `Cloud CA` that trusts certs issued by the `Kubernetes CA` through federation the connection is successfully established.
## Custom certificate validity periods using Custom Claims Bring up the certificate authority with the example: ```sh certificates $ step-ca examples/pki/config/ca.json 2019/03/11 13:37:03 Serving HTTPS on :9000 ... ``` The example comes with multiple provisioner options, two of which have custom claims to expand the validity of certificates: ```sh $ step ca provisioner list | jq '.[] | "\(.name): \(.claims.defaultTLSCertDuration)"' # null means step default of 24h for cert validity "mariano@smallstep.com: null" "mike@smallstep.com: 2m0s" "decade: 87600h0m0s" "90days: 2160h0m0s" ``` A closer look at a duration-bound provisioner, `90days` for instance, reveals the custom configuration for certificate validity. ```sh $ step ca provisioner list | jq '.[3].claims' { "maxTLSCertDuration": "2160h0m0s", "defaultTLSCertDuration": "2160h0m0s" } ``` Certificates with different validity periods can be generated using the respective provisioners. The durations are strings which are a sequence of decimal numbers, each with optional fraction and a unit suffix, such as "300ms" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". Please see [Getting Started](https://github.com/smallstep/certificates/blob/master/docs/GETTING_STARTED.md) in the docs directory to learn what custom claims configuration options are available and how to use them. 
```sh $ step ca certificate decade decade.crt decade.key ✔ Key ID: iu7VZxKUcquv1BCWuvEUOyRy4zYyCmgt61OpRW5VbRE (decade) ✔ Please enter the password to decrypt the provisioner key: password ✔ CA: https://localhost:9000/1.0/sign ✔ Certificate: decade.crt ✔ Private Key: decade.key $ step certificate inspect --format json decade.crt | jq .validity { "start": "2019-03-11T22:34:30Z", "end": "2029-03-08T22:34:30Z", "length": 315360000 } $ step ca certificate 90days 90days.crt 90days.key ✔ Key ID: 2LgjIvfirblnFMC6FjUr8jYkO8nOqz4rKoarCc8kiGU (90days) ✔ Please enter the password to decrypt the provisioner key: password ✔ CA: https://localhost:9000/1.0/sign ✔ Certificate: 90days.crt ✔ Private Key: 90days.key $ step certificate inspect --format json 90days.crt | jq .validity { "start": "2019-03-11T22:35:39Z", "end": "2019-06-09T22:35:39Z", "length": 7776000 } ``` ## Configuration Management Tools Configuration management tools such as Puppet, Chef, Ansible, Salt, etc. make automation and deployment a whole lot easier and more manageable. Step CLI and CA are built with automation in mind and are easy to configure using your favorite tools # Puppet The following are snippets and files that users can add to their puppet manifests to easily instrument services with TLS. ** [step.pp](./puppet/step.pp) ** - Install `step` from source and configure the `step` user, group, and home directory for use by the Step CLI and CA. ** [step_ca.pp](./puppet/step_ca.pp) ** - Install `step-ca` from source. Configure certificates and secrets and run the Step CA. ** [tls_server.pp](./puppet/tls_server.pp) ** - This is your service, instrumented with the Step CA SDK to request, receive, and renew TLS certificates. See [the bootstrap-tls-server](./bootstrap-tls-server/server.go) for a simple integration example. **Note:** This is a significantly oversimplified example that will not work standalone. 
A complete Puppet configuration should use a service manager (like [systemctl](https://www.digitalocean.com/community/tutorials/how-to-use-systemctl-to-manage-systemd-services-and-units)) and a secret store (like [Hiera](https://puppet.com/docs/puppet/6.0/hiera_intro.html)). If you are interested in seeing a more complete example please let us know and we'll make one available.

================================================ FILE: examples/ansible/smallstep-certs/defaults/main.yml ================================================
# Root cert for each will be saved in /etc/ssl/smallstep/ca/{{ ca_name }}/certs/root_ca.crt
smallstep_root_certs: []
# -
#   ca_name: your_ca
#   ca_url: "https://certs.your_ca.ca.smallstep.com"
#   ca_fingerprint: "56092...2200"

# Each leaf cert will be saved in /etc/ssl/smallstep/leaf/{{ cert_subject }}/{{ cert_subject }}.crt|key
# NOTE(review): tasks/main.yml reads item.context, item.cert_path,
# item.key_path and item.cron_renew from these entries, none of which appear
# in the examples here; confirm the expected keys before use.
smallstep_leaf_certs: []
# -
#   ca_name: your_ca
#   cert_subject: "{{ inventory_hostname }}"
#   provisioner_name: "admin"
#   provisioner_password: "{{ smallstep_ssh_provisioner_password }}"

================================================ FILE: examples/ansible/smallstep-certs/tasks/main.yml ================================================
# Creates one directory per (context, provisioner) pair to hold the
# provisioner password file. A directory needs the execute bit to be
# traversed, so it is 0700 (owner only).
- name: "Ensure provisioners directories exist"
  file:
    path: "/etc/ssl/smallstep/provisioners/{{ item.context }}/{{ item.provisioner_name }}"
    state: directory
    mode: "0700"
    owner: root
    group: root
  with_items: "{{ smallstep_leaf_certs }}"
  no_log: true

# Writes the provisioner password next to its provisioner. The file holds a
# secret and is never executed, so it is 0600 (owner read/write only).
- name: "Ensure provisioner passwords are up to date"
  copy:
    dest: "/etc/ssl/smallstep/provisioners/{{ item.context }}/{{ item.provisioner_name }}/provisioner-pass.txt"
    content: "{{ item.provisioner_password }}"
    mode: "0600"
    owner: root
    group: root
  with_items: "{{ smallstep_leaf_certs }}"
  no_log: true

# Bootstraps each configured CA into its own step context.
- name: "Get root certs for CAs"
  command:
    cmd: "step ca bootstrap --context {{ item.context }} --ca-url {{ item.ca_url }} --fingerprint {{ item.ca_fingerprint }}"
  with_items: "{{ smallstep_root_certs }}"
  no_log: true

# Requests a leaf certificate for every entry, authenticating with the
# password file written above.
- name: "Get leaf certs"
  command:
    cmd: "step ca certificate --context {{ item.context }} {{ item.cert_subject }} {{ item.cert_path }} {{ item.key_path }} --force --console --provisioner {{ item.provisioner_name }} --provisioner-password-file /etc/ssl/smallstep/provisioners/{{ item.context }}/{{ item.provisioner_name }}/provisioner-pass.txt"
  with_items: "{{ smallstep_leaf_certs }}"
  no_log: true

# Installs a cron entry that runs `step ca renew` every 30 minutes for the
# entries that enable cron_renew.
- name: Ensure cron to renew leaf certs is up to date
  cron:
    user: "root"
    name: "renew leaf cert {{ item.cert_subject }}"
    cron_file: smallstep
    job: "step ca renew --context {{ item.context }} {{ item.cert_path }} {{ item.key_path }} --expires-in 6h --force >> /var/log/smallstep-{{ item.cert_subject }}.log 2>&1"
    state: present
    minute: "*/30"
  with_items: "{{ smallstep_leaf_certs }}"
  when: "{{ item.cron_renew }}"
  no_log: true

================================================ FILE: examples/ansible/smallstep-install/defaults/main.yml ================================================
smallstep_install_step_version: 0.15.3
smallstep_install_step_ssh_version: 0.19.1-1

================================================ FILE: examples/ansible/smallstep-install/tasks/main.yml ================================================
# These steps automate the installation guide here:
# https://smallstep.com/docs/sso-ssh/hosts/
- name: Download step binary
  get_url:
    url: "https://files.smallstep.com/step-linux-{{ smallstep_install_step_version }}"
    dest: "/usr/local/bin/step-{{ smallstep_install_step_version }}"
    mode: '0755'

- name: Link binaries to correct version
  file:
    src: "/usr/local/bin/step-{{ smallstep_install_step_version }}"
    dest: "{{ item }}"
    state: link
  with_items:
    - /usr/bin/step
    - /usr/local/bin/step

- name: Link /usr/local/bin/step to correct binary version
  file:
    src: "/usr/local/bin/step-{{ smallstep_install_step_version }}"
    dest: /usr/local/bin/step
    state: link

- name: Ensure step-ssh is installed
  apt:
    deb: "https://files.smallstep.com/step-ssh_{{ smallstep_install_step_ssh_version }}_amd64.deb"
    state: present
================================================ FILE: examples/ansible/smallstep-ssh/defaults/main.yml ================================================ # If this host is behind a bastion this variable should contain the hostname of the bastion smallstep_ssh_host_behind_bastion_name: "" smallstep_ssh_host_is_bastion: false smallstep_ssh_ca_url: "https://ssh.mycompany.ca.smallstep.com" smallstep_ssh_ca_fingerprint: "XXXXXXXXXXXXXXX" # Whether or not to reinitialize the host even if it's already been installed smallstep_ssh_force_reinit: true ================================================ FILE: examples/ansible/smallstep-ssh/tasks/main.yml ================================================ # These steps automate the installation guide here: # https://smallstep.com/docs/sso-ssh/hosts/ # TODO: Figure out how to make this idempotent instead of reinstalling on each run - name: Bootstrap node to connect to CA command: "step ca bootstrap --context ssh --ca-url {{ smallstep_ssh_ca_url }} --fingerprint {{ smallstep_ssh_ca_fingerprint }} --force" # when: smallstep_ssh_installed.changed or smallstep_ssh_force_reinit - name: Get a host SSH certificate command: "step ssh certificate --context ssh {{ inventory_hostname }} /etc/ssh/ssh_host_ecdsa_key.pub --host --sign --provisioner=\"Service Account\" --token=\"{{ smallstep_ssh_enrollment_token }}\" --force" # when: smallstep_ssh_installed.changed or smallstep_ssh_force_reinit - name: Configure SSHD (will be overwriten by the sshd template in Ansible later) command: "step ssh config --context ssh --host --set Certificate=ssh_host_ecdsa_key-cert.pub --set Key=ssh_host_ecdsa_key" # when: smallstep_ssh_installed.changed or smallstep_ssh_force_reinit - name: Activate SmallStep PAM/NSS modules and nohup sshd command: "step-ssh activate {{ inventory_hostname }}" # when: smallstep_ssh_installed.changed or smallstep_ssh_force_reinit - name: Generate host tags list set_fact: smallstep_ssh_host_tags_string: "{{ smallstep_ssh_host_tags | 
to_json | regex_replace('\\:\\ ','=') | regex_replace('\\{\\\"|,\\ \\\"', ' --tag \"') | regex_replace('[\\[\\]{}]') }}" - name: Generate command to register set_fact: smallstep_ssh_register_string: | step-ssh-ctl register --hostname {{ inventory_hostname }} {% if not smallstep_ssh_host_is_bastion %}--bastion '{{ smallstep_ssh_host_behind_bastion_name|default("") }}'{% endif %} {% if smallstep_ssh_host_is_bastion %}--is-bastion{% endif %} {{ smallstep_ssh_host_tags_string }} - debug: var=smallstep_ssh_register_string - name: Register host with smallstep command: "{{ smallstep_ssh_register_string }}" # when: smallstep_ssh_installed.changed or smallstep_ssh_force_reinit ================================================ FILE: examples/basic-client/client.go ================================================ //nolint:govet // example code; allow unused variables package main import ( "context" "encoding/json" "fmt" "net" "net/http" "os" "time" "github.com/smallstep/certificates/ca" ) func printResponse(name string, v interface{}) { b, err := json.MarshalIndent(v, "", " ") if err != nil { panic(err) } fmt.Printf("%s response:\n%s\n\n", name, b) } func main() { if len(os.Args) != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } token := os.Args[1] // To create the client using ca.NewClient we need: // * The CA address "https://localhost:9000" // * The root certificate fingerprint // 84a033e84196f73bd593fad7a63e509e57fd982f02084359c4e8c5c864efc27d to get // the root fingerprint we can use `step certificate fingerprint root_ca.crt` client, err := ca.NewClient("https://localhost:9000", ca.WithRootSHA256("84a033e84196f73bd593fad7a63e509e57fd982f02084359c4e8c5c864efc27d")) if err != nil { panic(err) } // Other ways to initialize the client would be: // * With the Bootstrap functionality (recommended): // client, err := ca.Bootstrap(token) // * Using the root certificate instead of the fingerprint: // client, err := ca.NewClient("https://localhost:9000", 
ca.WithRootFile("../pki/secrets/root_ca.crt")) // Get the health of the CA health, err := client.Health() if err != nil { panic(err) } printResponse("Health", health) // Get and verify a root CA root, err := client.Root("84a033e84196f73bd593fad7a63e509e57fd982f02084359c4e8c5c864efc27d") if err != nil { panic(err) } printResponse("Root", root) // We can use ca.CreateSignRequest to generate a new sign request with a // randomly generated key. req, pk, err := ca.CreateSignRequest(token) if err != nil { panic(err) } sign, err := client.Sign(req) if err != nil { panic(err) } printResponse("Sign", sign) // Renew a certificate with a transport that contains the previous // certificate. We should created a context that allows us to finish the // renewal goroutine.∑ ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Finish the renewal goroutine tr, err := client.Transport(ctx, sign, pk) if err != nil { panic(err) } renew, err := client.Renew(tr) if err != nil { panic(err) } printResponse("Renew", renew) // Get tls.Config for a server ctxServer, cancelServer := context.WithCancel(context.Background()) defer cancelServer() tlsConfig, err := client.GetServerTLSConfig(ctxServer, sign, pk) if err != nil { panic(err) } // An http server will use the tls.Config like: _ = &http.Server{ Addr: ":443", Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte("Hello world")) }), TLSConfig: tlsConfig, ReadHeaderTimeout: 30 * time.Second, } // Get tls.Config for a client ctxClient, cancelClient := context.WithCancel(context.Background()) defer cancelClient() tlsConfig, err = client.GetClientTLSConfig(ctxClient, sign, pk) if err != nil { panic(err) } // An http.Client will need to create a transport first _ = &http.Client{ Transport: &http.Transport{ TLSClientConfig: tlsConfig, // Options set in http.DefaultTransport Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, DualStack: true, }).DialContext, 
MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }, } // But we can just use client.Transport to get the default configuration ctxTransport, cancelTransport := context.WithCancel(context.Background()) defer cancelTransport() tr, err = client.Transport(ctxTransport, sign, pk) if err != nil { panic(err) } // And http.Client will use the transport like _ = &http.Client{ Transport: tr, } // Get provisioners and provisioner keys. In this example we add two // optional arguments with the initial cursor and a limit. // // A server or a client should not need this functionality, they are used to // sign (private key) and verify (public key) tokens. The step cli can be // used for this purpose. provisioners, err := client.Provisioners(ca.WithProvisionerCursor(""), ca.WithProvisionerLimit(100)) if err != nil { panic(err) } printResponse("Provisioners", provisioners) // Get encrypted key key, err := client.ProvisionerKey("DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk") if err != nil { panic(err) } printResponse("Provisioner Key", key) } ================================================ FILE: examples/basic-federation/client/main.go ================================================ package main import ( "context" "fmt" "io" "os" "time" "github.com/smallstep/certificates/ca" ) func main() { if len(os.Args) != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } token := os.Args[1] // make sure to cancel the renew goroutine ctx, cancel := context.WithCancel(context.Background()) defer cancel() client, err := ca.BootstrapClient(ctx, token, ca.AddFederationToRootCAs()) if err != nil { panic(err) } for { resp, err := client.Get("https://127.0.0.1:8443") if err != nil { panic(err) } b, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { panic(err) } fmt.Printf("Server responded: %s\n", b) time.Sleep(5 * time.Second) } } ================================================ FILE: 
examples/basic-federation/pki/cloud/certs/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBvjCCAWSgAwIBAgIQPDSG4MCReDzPu96+Cb5e0TAKBggqhkjOPQQDAjApMScw JQYDVQQDEx5TbWFsbHN0ZXAgUHVibGljIENsb3VkIFJvb3QgQ0EwHhcNMTkwMTE4 MjEwNjE2WhcNMjkwMTE1MjEwNjE2WjAxMS8wLQYDVQQDEyZTbWFsbHN0ZXAgUHVi bGljIENsb3VkIEludGVybWVkaWF0ZSBDQTBZMBMGByqGSM49AgEGCCqGSM49AwEH A0IABCQRZefT8U34OiW3o+R2/Ob2DUWrcL87jj9D7IGtAgOjulbfoJaH3rnJumG2 DtMImBJ1hPa2mXMjThnjOXSlGA+jZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMB Af8ECDAGAQH/AgEAMB0GA1UdDgQWBBSCHPvFHHqEJU82OtFPZVHhstA/rDAfBgNV HSMEGDAWgBSMOU1vYot1qfDaTG2uY9l2nQNSADAKBggqhkjOPQQDAgNIADBFAiEA hojptJQvmTlu9Ybyr9UCL6Akiks8U1RPF2NS+YKZm+8CIDjbipMuz5AXCez57/5r ZrEv0JxcWpK6AxfitwyYg34e -----END CERTIFICATE----- ================================================ FILE: examples/basic-federation/pki/cloud/certs/kubernetes_root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBkjCCATigAwIBAgIRAKqKZhDGVx8dcxTDGpowzNEwCgYIKoZIzj0EAwIwJzEl MCMGA1UEAxMcU21hbGxzdGVwIEt1YmVybmV0ZXMgUm9vdCBDQTAeFw0xOTAxMTgy MTA2NDdaFw0yOTAxMTUyMTA2NDdaMCcxJTAjBgNVBAMTHFNtYWxsc3RlcCBLdWJl cm5ldGVzIFJvb3QgQ0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASst212A8a9 h1DBFEzCgIaoZEWWf0JlBkSmnlHCHZLK2ookNKY6k8UAki4o1xpYjeLtlL4xn4WL mMEafC2tPQvxo0UwQzAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIB ATAdBgNVHQ4EFgQU1+ia0R8GNWYXgs7qkbXHzinghlEwCgYIKoZIzj0EAwIDSAAw RQIgDQlbDQxnNxRsR8d/lQiBSy6v0u6BOmftfbB3y0CcGI4CIQC2dxkUvi6GsfHs zRgU5ZPIT7sVEfNi9G3GZABj0vOnvQ== -----END CERTIFICATE----- ================================================ FILE: examples/basic-federation/pki/cloud/certs/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBlTCCATygAwIBAgIRAPykhdlDneUGU9rI1g+Y40MwCgYIKoZIzj0EAwIwKTEn MCUGA1UEAxMeU21hbGxzdGVwIFB1YmxpYyBDbG91ZCBSb290IENBMB4XDTE5MDEx ODIxMDYxM1oXDTI5MDExNTIxMDYxM1owKTEnMCUGA1UEAxMeU21hbGxzdGVwIFB1 YmxpYyBDbG91ZCBSb290IENBMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESc9z 
Y77gf4XhCCOzsAhvMThV3Wro6EVnfBaSlmmnq15VaONG6FP7kVyJEM+XD75Thu10 AsxwB0w4WxKIJ63TNaNFMEMwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYB Af8CAQEwHQYDVR0OBBYEFIw5TW9ii3Wp8NpMba5j2XadA1IAMAoGCCqGSM49BAMC A0cAMEQCICVylyEZfBBilwKN1nvS4j9Lbt6/nhF5DH9K/wjPIBhiAiBZojDvMZhj mreuuFfRC4kWE/OUG5Iz2qVUtlvL/NaXXQ== -----END CERTIFICATE----- ================================================ FILE: examples/basic-federation/pki/cloud/config/ca.federated.json ================================================ { "root": "pki/cloud/certs/root_ca.crt", "federatedRoots": ["pki/cloud/certs/kubernetes_root_ca.crt"], "crt": "pki/cloud/certs/intermediate_ca.crt", "key": "pki/cloud/secrets/intermediate_ca_key", "address": ":1443", "dnsNames": [ "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "name": "sebastian@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "EE1ZiqkMaxsUdpz8SCSkRBzwK9TWUoidQnMnJ8Eryn8", "crv": "P-256", "alg": "ES256", "x": "BcoXteWHdYxXUrckEQwEQDol2nM97J8KIg7GiXc3AMc", "y": "8QkL41tl7BZ5uIf_VTOEypp8vsKUnDGpNdrfk0FNt0Q" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoicTV2MkZ2bmRjNnF0ZWlEZkVBNWI5ZyJ9.MZCVUhU1yuYkhKbQqJDJX0Y6d1X6ranvIeGHIpWLc_STHAgta8c0tA.4o-sw0jTV064OtuL.QVqIo2l0Qf_MRXVghjFYUFkWlK-3VomqzskLLDfLz1norWQa-wEdV56_CIZ7gAPxiLj2N6VOlpEg-sKA2xL3w9b-2WovH_o93iN2MziiWajFf9uq-41LVyEeROd_6Gs4TQbxyz5rk_iMsZeRNRKTpYvW1E2lA4YlMTm4QLV7s7xkGaWsL_-pATfb24bnDMrjRyAVLR61rPHxUQ2wcK_hRG272xoSAsNOWRrUcnDYzjylj-YfmhQZy77Rf38Rxy3UlhB4iMB-y7wMoMuseRKTvBEncL-c0wrllKWUP_KjCl6VeanKWAGUilbmgIpEa1Y_QbNZTD9t0rw2TJSkjx0.CNtEZWZrfp542E65F2oi4w" } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: examples/basic-federation/pki/cloud/config/ca.json 
================================================ { "root": "pki/cloud/certs/root_ca.crt", "federatedRoots": [], "crt": "pki/cloud/certs/intermediate_ca.crt", "key": "pki/cloud/secrets/intermediate_ca_key", "address": ":1443", "dnsNames": [ "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "name": "sebastian@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "EE1ZiqkMaxsUdpz8SCSkRBzwK9TWUoidQnMnJ8Eryn8", "crv": "P-256", "alg": "ES256", "x": "BcoXteWHdYxXUrckEQwEQDol2nM97J8KIg7GiXc3AMc", "y": "8QkL41tl7BZ5uIf_VTOEypp8vsKUnDGpNdrfk0FNt0Q" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoicTV2MkZ2bmRjNnF0ZWlEZkVBNWI5ZyJ9.MZCVUhU1yuYkhKbQqJDJX0Y6d1X6ranvIeGHIpWLc_STHAgta8c0tA.4o-sw0jTV064OtuL.QVqIo2l0Qf_MRXVghjFYUFkWlK-3VomqzskLLDfLz1norWQa-wEdV56_CIZ7gAPxiLj2N6VOlpEg-sKA2xL3w9b-2WovH_o93iN2MziiWajFf9uq-41LVyEeROd_6Gs4TQbxyz5rk_iMsZeRNRKTpYvW1E2lA4YlMTm4QLV7s7xkGaWsL_-pATfb24bnDMrjRyAVLR61rPHxUQ2wcK_hRG272xoSAsNOWRrUcnDYzjylj-YfmhQZy77Rf38Rxy3UlhB4iMB-y7wMoMuseRKTvBEncL-c0wrllKWUP_KjCl6VeanKWAGUilbmgIpEa1Y_QbNZTD9t0rw2TJSkjx0.CNtEZWZrfp542E65F2oi4w" } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: examples/basic-federation/pki/cloud/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,61d1fda0c56e8b8fcf2eb8ba5e20903f pFdzP/NSsBCJ3LiVNd6Qg8BRZrNO9+MOcXeg93LAOFQVOswPRya+tfTojfv6t4Lm qJYJlENAAKCjM5+GuMeslTpyqlgDvUYi36v4FAxvLNp49r+nzKXqq0chIPhlLwGb Mr8rEtX97vosyYEtsYhbjjzgCJXMhBXzi9tbLkkiRxw= -----END EC PRIVATE KEY----- ================================================ FILE: examples/basic-federation/pki/cloud/secrets/root_ca_key 
================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,78223655d6ee4b458ff2d5b79dadbd06 +bHgP06x5ScVppDSysFN/1xVPnC5MrrUhG7zLG/eX6yJrj4YegthwnDrCuWSCc9U 0fv0jJ0QGGm8YL5NdH1EpOF7f3+KhOaCv3KLUFJqDWy7HC9anvLHRZPxEeGgndfZ 5EJz043Lg7xYSDDaH1a0ZVcHPpyiUJNhS9lxmOQP+Ec= -----END EC PRIVATE KEY----- ================================================ FILE: examples/basic-federation/pki/kubernetes/certs/cloud_root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBlTCCATygAwIBAgIRAPykhdlDneUGU9rI1g+Y40MwCgYIKoZIzj0EAwIwKTEn MCUGA1UEAxMeU21hbGxzdGVwIFB1YmxpYyBDbG91ZCBSb290IENBMB4XDTE5MDEx ODIxMDYxM1oXDTI5MDExNTIxMDYxM1owKTEnMCUGA1UEAxMeU21hbGxzdGVwIFB1 YmxpYyBDbG91ZCBSb290IENBMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESc9z Y77gf4XhCCOzsAhvMThV3Wro6EVnfBaSlmmnq15VaONG6FP7kVyJEM+XD75Thu10 AsxwB0w4WxKIJ63TNaNFMEMwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYB Af8CAQEwHQYDVR0OBBYEFIw5TW9ii3Wp8NpMba5j2XadA1IAMAoGCCqGSM49BAMC A0cAMEQCICVylyEZfBBilwKN1nvS4j9Lbt6/nhF5DH9K/wjPIBhiAiBZojDvMZhj mreuuFfRC4kWE/OUG5Iz2qVUtlvL/NaXXQ== -----END CERTIFICATE----- ================================================ FILE: examples/basic-federation/pki/kubernetes/certs/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBuzCCAWGgAwIBAgIRAK/si8hkdtyDLAjg6ZiRxIgwCgYIKoZIzj0EAwIwJzEl MCMGA1UEAxMcU21hbGxzdGVwIEt1YmVybmV0ZXMgUm9vdCBDQTAeFw0xOTAxMTgy MTA2NDlaFw0yOTAxMTUyMTA2NDlaMC8xLTArBgNVBAMTJFNtYWxsc3RlcCBLdWJl cm5ldGVzIEludGVybWVkaWF0ZSBDQTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA BBIuNImYDPC6BIGm0/C97su9GPsrGoS2uSHDuPIORbaGPtGI8A2SHInw7pHXRqnW 5TvrbCOcLI2Ao2RQvAd/9jCjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8E CDAGAQH/AgEAMB0GA1UdDgQWBBT3w/Rtaw3B4mjhqwDb+/tCFLQTozAfBgNVHSME GDAWgBTX6JrRHwY1ZheCzuqRtcfOKeCGUTAKBggqhkjOPQQDAgNIADBFAiEAz8Lo GPyrFRGvQ6Eie/qIjJByNEmN3FCOJOJr0J7csy0CIBLxKfMhsT719con+/ZMBNRt HEV9xTWpPhUvnlsTPIJl -----END CERTIFICATE----- 
================================================ FILE: examples/basic-federation/pki/kubernetes/certs/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBkjCCATigAwIBAgIRAKqKZhDGVx8dcxTDGpowzNEwCgYIKoZIzj0EAwIwJzEl MCMGA1UEAxMcU21hbGxzdGVwIEt1YmVybmV0ZXMgUm9vdCBDQTAeFw0xOTAxMTgy MTA2NDdaFw0yOTAxMTUyMTA2NDdaMCcxJTAjBgNVBAMTHFNtYWxsc3RlcCBLdWJl cm5ldGVzIFJvb3QgQ0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASst212A8a9 h1DBFEzCgIaoZEWWf0JlBkSmnlHCHZLK2ookNKY6k8UAki4o1xpYjeLtlL4xn4WL mMEafC2tPQvxo0UwQzAOBgNVHQ8BAf8EBAMCAQYwEgYDVR0TAQH/BAgwBgEB/wIB ATAdBgNVHQ4EFgQU1+ia0R8GNWYXgs7qkbXHzinghlEwCgYIKoZIzj0EAwIDSAAw RQIgDQlbDQxnNxRsR8d/lQiBSy6v0u6BOmftfbB3y0CcGI4CIQC2dxkUvi6GsfHs zRgU5ZPIT7sVEfNi9G3GZABj0vOnvQ== -----END CERTIFICATE----- ================================================ FILE: examples/basic-federation/pki/kubernetes/config/ca.federated.json ================================================ { "root": "pki/kubernetes/certs/root_ca.crt", "federatedRoots": ["pki/kubernetes/certs/cloud_root_ca.crt"], "crt": "pki/kubernetes/certs/intermediate_ca.crt", "key": "pki/kubernetes/secrets/intermediate_ca_key", "address": ":2443", "dnsNames": [ "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "name": "sebastian@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "S5gYgpeqcIAgc1Zr4myZXpgJ_Ao4ryS6F6wqg9o8RYo", "crv": "P-256", "alg": "ES256", "x": "uYecJyfa3pKHrO36zVsPKCHAcCUoYKOic2M7_qv9Jes", "y": "WXzgS-36_BcXSi-G86RcLmLaHJEjmcmkhzR9ajrhhGo" }, "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiRlBVMHBJclNYbUU5VzhFSzFaMlQ2USJ9.wTT9VpKIOKzoiaBJ2fIcOYwypkKTwuKauwkf4fcfmJ6_VVQ6TykIrA.y_YTjVhmpztcOqYq.9j_pppHHcJ3_VEbt-WfxSOks05QMXNI862uYWcFGc7EVCoD1qEiKIDKEuoUSTG-_WcgkaXrag9UUQPKfDpuYi6UJcyaLkehO2DBX26DJ1T-qEYlhPxPGpx8r7p84zcg_AftypD6PNheiCLe6HOQWWuPtuPrfyUpvyfMWkJC14NZjR4iJysKP5dndxFbSTI2XCw1X-zBDVD9xMnVPlRtezIIuDi2cLEYnNaIlr5NMNQBOrzQaSo1LQaOOoSa_OrZzTdjg7HaUU5DaAA6YEDQPfFxLIJdKshvj5sxDlH_LLY58mzsGECrUx396zjvN8FD-bFSrdbCNlJJ4xhtt34c.FhRAcFIg-5k2srdtdsq2VQ" } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: examples/basic-federation/pki/kubernetes/config/ca.json ================================================ { "root": "pki/kubernetes/certs/root_ca.crt", "federatedRoots": [], "crt": "pki/kubernetes/certs/intermediate_ca.crt", "key": "pki/kubernetes/secrets/intermediate_ca_key", "address": ":2443", "dnsNames": [ "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "name": "sebastian@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "S5gYgpeqcIAgc1Zr4myZXpgJ_Ao4ryS6F6wqg9o8RYo", "crv": "P-256", "alg": "ES256", "x": "uYecJyfa3pKHrO36zVsPKCHAcCUoYKOic2M7_qv9Jes", "y": "WXzgS-36_BcXSi-G86RcLmLaHJEjmcmkhzR9ajrhhGo" }, "encryptedKey": 
"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiRlBVMHBJclNYbUU5VzhFSzFaMlQ2USJ9.wTT9VpKIOKzoiaBJ2fIcOYwypkKTwuKauwkf4fcfmJ6_VVQ6TykIrA.y_YTjVhmpztcOqYq.9j_pppHHcJ3_VEbt-WfxSOks05QMXNI862uYWcFGc7EVCoD1qEiKIDKEuoUSTG-_WcgkaXrag9UUQPKfDpuYi6UJcyaLkehO2DBX26DJ1T-qEYlhPxPGpx8r7p84zcg_AftypD6PNheiCLe6HOQWWuPtuPrfyUpvyfMWkJC14NZjR4iJysKP5dndxFbSTI2XCw1X-zBDVD9xMnVPlRtezIIuDi2cLEYnNaIlr5NMNQBOrzQaSo1LQaOOoSa_OrZzTdjg7HaUU5DaAA6YEDQPfFxLIJdKshvj5sxDlH_LLY58mzsGECrUx396zjvN8FD-bFSrdbCNlJJ4xhtt34c.FhRAcFIg-5k2srdtdsq2VQ" } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: examples/basic-federation/pki/kubernetes/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,fb0860a45dbae164b2657f27c2a3f1cf 6nFswxVAct5zIqsThsJ1uYY6gzkhMzAmurih+mhlhANfl6SHURdzk8AtutIQrFfV jr/vTPyr3aR+dqldt/wg1L9pJFc/OoWVlOFbltmLTkWMPk+VHKxR5V6A7IVYoKvm EUFzVb+aHj6M9R7ecBf0IslGQ0nsqYa54WppGHmAA4A= -----END EC PRIVATE KEY----- ================================================ FILE: examples/basic-federation/pki/kubernetes/secrets/root_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,965ede7ef96d4640932c18bdf1795645 treDpMX0uYFlWMiPvjYfnv6K9jmT4f8pTG6AkzZB0eeaNj04tt4FuIgrabHoZFNx IC1mFIRZvhJaiOXNIvQbo/Wnweu8nVV/xn73xNKBramgfXDo9WCvIsffjRg1xtsq s3SuONddo4IdpmrG7iEZTkDe2IzSV5NYhGsKiwvDvaU= -----END EC PRIVATE KEY----- ================================================ FILE: examples/basic-federation/server/main.go ================================================ package main import ( "context" "fmt" "net/http" "os" "time" "github.com/smallstep/certificates/ca" ) func main() { if 
len(os.Args) != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } token := os.Args[1] // make sure to cancel the renew goroutine ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv, err := ca.BootstrapServer(ctx, token, &http.Server{ Addr: ":8443", Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := "nobody" issuer := "none" if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { name = r.TLS.PeerCertificates[0].Subject.CommonName issuer = r.TLS.PeerCertificates[len(r.TLS.PeerCertificates)-1].Issuer.CommonName } fmt.Fprintf(w, "Hello %s (cert issued by '%s') at %s", name, issuer, time.Now().UTC()) //nolint:gosec // example code for demonstration }), ReadHeaderTimeout: 30 * time.Second, }, ca.AddFederationToClientCAs(), ListTrustedRoots()) if err != nil { panic(err) } fmt.Println("Listening on :8443 ...") if err := srv.ListenAndServeTLS("", ""); err != nil { panic(err) } } // ListTrustedRoots prints list of trusted roots for illustration purposes func ListTrustedRoots() ca.TLSOption { fn := func(ctx *ca.TLSOptionCtx) error { certs, err := ctx.Client.Federation() if err != nil { return err } roots, err := ctx.Client.Roots() if err != nil { return err } if len(certs.Certificates) > len(roots.Certificates) { fmt.Println("Server is using federated root certificates") } for _, cert := range certs.Certificates { fmt.Printf("Accepting certs anchored in %s\n", cert.Certificate.Subject) } return nil } return func(ctx *ca.TLSOptionCtx) error { ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) return fn(ctx) } } ================================================ FILE: examples/bootstrap-client/client.go ================================================ package main import ( "context" "fmt" "io" "os" "time" "github.com/smallstep/certificates/ca" ) func main() { if len(os.Args) != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } token := os.Args[1] // make sure to cancel the renew goroutine ctx, 
cancel := context.WithCancel(context.Background()) defer cancel() client, err := ca.BootstrapClient(ctx, token) if err != nil { panic(err) } for { resp, err := client.Get("https://localhost:8443") if err != nil { panic(err) } b, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { panic(err) } fmt.Printf("Server responded: %s\n", b) time.Sleep(1 * time.Second) } } ================================================ FILE: examples/bootstrap-mtls-server/server.go ================================================ package main import ( "context" "fmt" "net/http" "os" "time" "github.com/smallstep/certificates/ca" ) func main() { if len(os.Args) != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } token := os.Args[1] // make sure to cancel the renew goroutine ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv, err := ca.BootstrapServer(ctx, token, &http.Server{ Addr: ":8443", Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := "nobody" if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { name = r.TLS.PeerCertificates[0].Subject.CommonName } fmt.Fprintf(w, "Hello %s at %s!!!", name, time.Now().UTC()) //nolint:gosec // example code for demonstration }), ReadHeaderTimeout: 30 * time.Second, }) if err != nil { panic(err) } fmt.Println("Listening on :8443 ...") if err := srv.ListenAndServeTLS("", ""); err != nil { panic(err) } } ================================================ FILE: examples/bootstrap-tls-server/server.go ================================================ package main import ( "context" "fmt" "net/http" "os" "time" "github.com/smallstep/certificates/ca" ) func main() { if len(os.Args) != 2 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } token := os.Args[1] // make sure to cancel the renew goroutine ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv, err := ca.BootstrapServer(ctx, token, &http.Server{ Addr: ":8443", Handler: 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := "nobody" if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { name = r.TLS.PeerCertificates[0].Subject.CommonName } fmt.Fprintf(w, "Hello %s at %s!!!", name, time.Now().UTC()) //nolint:gosec // example code for demonstration }), ReadHeaderTimeout: 30 * time.Second, }, ca.VerifyClientCertIfGiven()) if err != nil { panic(err) } fmt.Println("Listening on :8443 ...") if err := srv.ListenAndServeTLS("", ""); err != nil { panic(err) } } ================================================ FILE: examples/docker/Makefile ================================================ all: binaries build up binaries: CGO_ENABLED=0 GOOS=linux go build -o ca/step-ca github.com/smallstep/certificates/cmd/step-ca build: build-nginx build-ca build-renewer build-nginx: docker build -t nginx-test:latest nginx build-ca: docker build -t step-ca-test:latest ca build-renewer: docker build -t step-renewer-test:latest renewer up: docker-compose up down: docker-compose down inspect: step certificate inspect https://localhost:4443 --insecure .PHONY: all binaries up down inspect .PHONY: build build-nginx build-ca build-renewer ================================================ FILE: examples/docker/ca/Dockerfile ================================================ FROM alpine ADD step-ca /usr/local/bin/step-ca COPY pki /run # Smallstep CA CMD ["step-ca", "/run/config/ca.json"] ================================================ FILE: examples/docker/ca/pki/config/ca.json ================================================ { "root": "/run/secrets/root_ca.crt", "crt": "/run/secrets/intermediate_ca.crt", "key": "/run/secrets/intermediate_ca_key", "password": "password", "address": ":443", "dnsNames": [ "ca" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "name": "mariano@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk", "crv": "P-256", "alg": "ES256", "x": 
"jXoO1j4CXxoTC32pNzkVC8l6k2LfP0k5ndhJZmcdVbk", "y": "c3JDL4GTFxJWHa8EaHdMh4QgwMh64P2_AGWrD0ADXcI" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiOTFVWjdzRGw3RlNXcldfX1I1NUh3USJ9.FcWtrBDNgrkA33G9Ll9sXh1cPF-3jVXeYe1FLmSDc_Q2PmfLOPvJOA.0ZoN32ayaRWnufJb.WrkffMmDLWiq1-2kn-w7-kVBGW12gjNCBHNHB1hyEdED0rWH1YWpKd8FjoOACdJyLhSn4kAS3Lw5AH7fvO27A48zzvoxZU5EgSm5HG9IjkIH-LBJ-v79ShkpmPylchgjkFhxa5epD11OIK4rFmI7s-0BCjmJokLR_DZBhDMw2khGnsr_MEOfAz9UnqXaQ4MIy8eT52xUpx68gpWFlz2YP3EqiYyNEv0PpjMtyP5lO2i8-p8BqvuJdus9H3fO5Dg-1KVto1wuqh4BQ2JKTauv60QAnM_4sdxRHku3F_nV64SCrZfDvnN2ve21raFROtyXaqHZhN6lyoPxDncy8v4.biaOblEe0N-gMpJyFZ-3-A" }, { "name": "mike@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs", "crv": "P-256", "alg": "ES256", "x": "LsI8nHBflc-mrCbRqhl8d3hSl5sYuSM1AbXBmRfznyg", "y": "F99LoOvi7z-ZkumsgoHIhodP8q9brXe4bhF3szK-c_w" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiVERQS2dzcEItTUR4ZDJxTGo0VlpwdyJ9.2_j0cZgTm2eFkZ-hrtr1hBIvLxN0w3TZhbX0Jrrq7vBMaywhgFcGTA.mCasZCbZJ-JT7vjA.bW052WDKSf_ueEXq1dyxLq0n3qXWRO-LXr7OzBLdUKWKSBGQrzqS5KJWqdUCPoMIHTqpwYvm-iD6uFlcxKBYxnsAG_hoq_V3icvvwNQQSd_q7Thxr2_KtPIDJWNuX1t5qXp11hkgb-8d5HO93CmN7xNDG89pzSUepT6RYXOZ483mP5fre9qzkfnrjx3oPROCnf3SnIVUvqk7fwfXuniNsg3NrNqncHYUQNReiq3e9I1R60w0ZQTvIReY7-zfiq7iPgVqmu5I7XGgFK4iBv0L7UOEora65b4hRWeLxg5t7OCfUqrS9yxAk8FdjFb9sEfjopWViPRepB0dYPH8dVI.fb6-7XWqp0j6CR9Li0NI-Q", "claims": { "minTLSCertDuration": "60s", "defaultTLSCertDuration": "120s" } } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: examples/docker/ca/pki/secrets/intermediate_ca.crt ================================================ -----BEGIN 
CERTIFICATE----- MIIBxjCCAWugAwIBAgIQAYoOWhdChUmmKzlc0DWcWDAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODExMDIyMzU0MTNaFw0yODEw MzAyMzU0MTNaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASxvIWme8/yDAxkR63KgSYkpN7mHKBH k5c8S+uzba4xWbaxZtEZ9NNhEIAgYFZ9/3ThrzLOsuGwRCvPTaD5iycQo4GGMIGD MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU8dKIy5ZLH2h3ihWgqjcpoo5e q3YwHwYDVR0jBBgwFoAU0IpOvAyBnn9UhDqOQzXnfEU3aYMwCgYIKoZIzj0EAwID SQAwRgIhANXlcktuaEvORhgRvzQ6vVNgvpqCEXW3CcCHjUl1xSdaAiEAmakkpfFq VsT5PqPnTRgOWlFESRhQ9btl6nQ+2Lt/S5A= -----END CERTIFICATE----- ================================================ FILE: examples/docker/ca/pki/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,4c7758e66df1884f6560839de64d4dd3 S8Ha8uA+bA3IGPurYODwd9VaJZ6FHI2tlznHXCOxT1MlGqyEAc4aWS11QBUz0Ucp excwlqM8kfh5BcN5a+vvInHnv74ZiNPdpt/apzz2LIx52pApzASiKVXRsAUmR4Pv 3MsO1/cVHkilpee1uC+axL32d5YmyP0URpSNJK9BhZo= -----END EC PRIVATE KEY----- ================================================ FILE: examples/docker/ca/pki/secrets/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBfDCCASGgAwIBAgIQY0CXerxuM+EhTbpVxxLRKjAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODExMDIyMzU0MTNaFw0yODEw MzAyMzU0MTNaMBwxGjAYBgNVBAMTEVNtYWxsc3RlcCBSb290IENBMFkwEwYHKoZI zj0CAQYIKoZIzj0DAQcDQgAEEGa7ZeL4WVIfPFDS7glJkIVsITVQgjfyz+AhcYaS rkJZlWOGZ60br9uE/wEfUcX1zavrX1Wz+bSJzTvT0AVBNqNFMEMwDgYDVR0PAQH/ BAQDAgGmMBIGA1UdEwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFNCKTrwMgZ5/VIQ6 jkM153xFN2mDMAoGCCqGSM49BAMCA0kAMEYCIQCRA4EdlTTMhs2Zd1cT75ZgxeGa mjVPl1vqBxLkHqEO+QIhAPKVm7E452ZBe2o5rQRxGwa94MI+CyuEIH9md3nTgWWX -----END CERTIFICATE----- ================================================ FILE: examples/docker/ca/pki/secrets/root_ca_key 
================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,98fdc560ba714aebb9fd4b714395d8ce 2bFn8yRb8lMvDR6oh22PocfhXdaoVNt4QwHCJNy0K0fG8CMokwDfEec//LseP6rA 7/EV11+ZgoN9xyTNe1kB6zFv7/kzCpRm23sqtyio+8xXWnLZNYKBRYYEeJWBUqqd GAfazg4ZFzoIH5TEPWCEAp7M9CVvtiw1SeA/zjewp2k= -----END EC PRIVATE KEY----- ================================================ FILE: examples/docker/docker-compose.yml ================================================ version: '3.3' services: ca: image: step-ca-test:latest ports: - "8443:443" restart: always renewer: depends_on: - ca image: step-renewer-test:latest volumes: - certificates:/var/local/step secrets: - password environment: STEPPATH: /home/step STEP_CA_URL: https://ca STEP_FINGERPRINT: 84a033e84196f73bd593fad7a63e509e57fd982f02084359c4e8c5c864efc27d STEP_ROOT: /var/local/step/root_ca.crt STEP_KID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk STEP_PASSWORD_FILE: /run/secrets/password COMMON_NAME: nginx nginx: depends_on: - renewer image: nginx-test:latest ports: - "4443:443" volumes: - certificates:/var/local/step:ro restart: always volumes: certificates: secrets: password: file: ./password.txt ================================================ FILE: examples/docker/nginx/Dockerfile ================================================ FROM nginx:alpine RUN apk add inotify-tools RUN mkdir -p /var/local/step COPY site.conf /etc/nginx/conf.d/ COPY certwatch.sh / COPY entrypoint.sh / # Certificate watcher and nginx ENTRYPOINT ["/entrypoint.sh"] CMD ["nginx", "-g", "daemon off;"] ================================================ FILE: examples/docker/nginx/certwatch.sh ================================================ #!/bin/sh while true; do inotifywait -e modify /var/local/step/site.crt nginx -s reload done ================================================ FILE: examples/docker/nginx/entrypoint.sh ================================================ #!/bin/sh # Wait for renewer sleep 10 # 
watch for the update of the cert and reload nginx /certwatch.sh & # Run docker CMD exec "$@" ================================================ FILE: examples/docker/nginx/site.conf ================================================ server { listen 443 ssl; server_name localhost; ssl_certificate /var/local/step/site.crt; ssl_certificate_key /var/local/step/site.key; location / { root /usr/share/nginx/html; index index.html index.htm; } } ================================================ FILE: examples/docker/password.txt ================================================ password ================================================ FILE: examples/docker/renewer/Dockerfile ================================================ FROM smallstep/step-cli USER root RUN mkdir -p /var/local/step ADD crontab /var/spool/cron/crontabs/root RUN chmod 0644 /var/spool/cron/crontabs/root COPY entrypoint.sh / ENTRYPOINT ["/entrypoint.sh"] CMD ["/usr/sbin/crond", "-l", "2", "-f"] ================================================ FILE: examples/docker/renewer/crontab ================================================ # min hour day month weekday command * * * * * step ca renew --force /var/local/step/site.crt /var/local/step/site.key ================================================ FILE: examples/docker/renewer/entrypoint.sh ================================================ #!/bin/sh # Wait for CA sleep 5 # Clean old certificates rm -f /var/local/step/root_ca.crt rm -f /var/local/step/site.crt /var/local/step/site.key # Download the root certificate step ca root /var/local/step/root_ca.crt # Get token STEP_TOKEN=$(step ca token $COMMON_NAME) # Request the site certificate and key using the token step ca certificate --token $STEP_TOKEN $COMMON_NAME /var/local/step/site.crt /var/local/step/site.key exec "$@" ================================================ FILE: examples/pki/config/ca.json ================================================ { "root": "examples/pki/secrets/root_ca.crt", "federatedRoots": null, "crt": 
"examples/pki/secrets/intermediate_ca.crt", "key": "examples/pki/secrets/intermediate_ca_key", "address": ":9000", "dnsNames": [ "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "type": "jwk", "name": "mariano@smallstep.com", "key": { "use": "sig", "kty": "EC", "kid": "DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk", "crv": "P-256", "alg": "ES256", "x": "jXoO1j4CXxoTC32pNzkVC8l6k2LfP0k5ndhJZmcdVbk", "y": "c3JDL4GTFxJWHa8EaHdMh4QgwMh64P2_AGWrD0ADXcI" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiOTFVWjdzRGw3RlNXcldfX1I1NUh3USJ9.FcWtrBDNgrkA33G9Ll9sXh1cPF-3jVXeYe1FLmSDc_Q2PmfLOPvJOA.0ZoN32ayaRWnufJb.WrkffMmDLWiq1-2kn-w7-kVBGW12gjNCBHNHB1hyEdED0rWH1YWpKd8FjoOACdJyLhSn4kAS3Lw5AH7fvO27A48zzvoxZU5EgSm5HG9IjkIH-LBJ-v79ShkpmPylchgjkFhxa5epD11OIK4rFmI7s-0BCjmJokLR_DZBhDMw2khGnsr_MEOfAz9UnqXaQ4MIy8eT52xUpx68gpWFlz2YP3EqiYyNEv0PpjMtyP5lO2i8-p8BqvuJdus9H3fO5Dg-1KVto1wuqh4BQ2JKTauv60QAnM_4sdxRHku3F_nV64SCrZfDvnN2ve21raFROtyXaqHZhN6lyoPxDncy8v4.biaOblEe0N-gMpJyFZ-3-A" }, { "type": "jwk", "name": "mike@smallstep.com", "key": { "use": "sig", "kty": "EC", "kid": "YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs", "crv": "P-256", "alg": "ES256", "x": "LsI8nHBflc-mrCbRqhl8d3hSl5sYuSM1AbXBmRfznyg", "y": "F99LoOvi7z-ZkumsgoHIhodP8q9brXe4bhF3szK-c_w" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiVERQS2dzcEItTUR4ZDJxTGo0VlpwdyJ9.2_j0cZgTm2eFkZ-hrtr1hBIvLxN0w3TZhbX0Jrrq7vBMaywhgFcGTA.mCasZCbZJ-JT7vjA.bW052WDKSf_ueEXq1dyxLq0n3qXWRO-LXr7OzBLdUKWKSBGQrzqS5KJWqdUCPoMIHTqpwYvm-iD6uFlcxKBYxnsAG_hoq_V3icvvwNQQSd_q7Thxr2_KtPIDJWNuX1t5qXp11hkgb-8d5HO93CmN7xNDG89pzSUepT6RYXOZ483mP5fre9qzkfnrjx3oPROCnf3SnIVUvqk7fwfXuniNsg3NrNqncHYUQNReiq3e9I1R60w0ZQTvIReY7-zfiq7iPgVqmu5I7XGgFK4iBv0L7UOEora65b4hRWeLxg5t7OCfUqrS9yxAk8FdjFb9sEfjopWViPRepB0dYPH8dVI.fb6-7XWqp0j6CR9Li0NI-Q", "claims": { "minTLSCertDuration": "1m0s", 
"defaultTLSCertDuration": "2m0s" } }, { "type": "jwk", "name": "decade", "key": { "use": "sig", "kty": "EC", "kid": "iu7VZxKUcquv1BCWuvEUOyRy4zYyCmgt61OpRW5VbRE", "crv": "P-256", "alg": "ES256", "x": "PExnlmHxnnfpvp4bznMKbA6L_9Bk9ZhtsmvbOwh9Kys", "y": "rrMPGvxscRzDdOYtZ1wsxeQjuuFl0nSzkwTHV_P-K-Y" }, "claims": { "maxTLSCertDuration": "87600h", "defaultTLSCertDuration": "87600h" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiZS1OVzRaZlBUNjFCUmR1bjJyNk9OZyJ9.zjToJ_Od6RIzVmo0cnmLZ69am410ftfBW594qNt60KmKX6JEWUufhA.kSrC74fKK3CkqiNS.G-oUqQhYMFIKuSj8thg9B5TeiaIMsQ-o_PTxIZE-Qb8TDU15ehPAsuIQmnbM6dSpkSGCmZgHTscp3xgLyv6QEBBjUHBpLwciWyipj1KBZDKSgLKeV6G2NiVBMETOaD1DsX3DxrHM-K3T1chXJFMJfkDSx1OEtaVfzqVYLyvNb5y_26oeRNSNYuTLzOrk6Ebr6KJE6lSWpvu1dtOrDAhTErouC56EQu2fTeDCa9eN50iRs4OjmF6FtBlR63h6FkvbmjJWC3zbIOe2RXRQx0Po6_dnKXSIqs7JMZSBerlgw6jzHme8YvqBqc2Ccy4Y4gJ23nwLkcsOVuFNdk6Nb7s.SB296DDrS-Wi4a9x_TGv4A" }, { "type": "jwk", "name": "90days", "key": { "use": "sig", "kty": "EC", "kid": "2LgjIvfirblnFMC6FjUr8jYkO8nOqz4rKoarCc8kiGU", "crv": "P-256", "alg": "ES256", "x": "iHFHMN91iFUDLh2LweFj6o0gDJ-pdmBY4IFIBNfUqd4", "y": "Yfym7KtzZQaQc1gQoT81ggNBPvAdV_0CW0A5mQgOsOc" }, "claims": { "maxTLSCertDuration": "2160h", "defaultTLSCertDuration": "2160h" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiYk9XV0ZUN29uZldtZTdvbzdCMFZOdyJ9.p3gs2xd-Bdtwz1WGzQUZrcZeA8mpaMn_R_wTInpzZ9G1vIeRk-9T4g.RQNXmZP8uAzF1n8b.WpLqmNV_I0RIetdID2ag-igZryM8ekSimaHrXKoEpRAlBdBDZC-9qkbrJPNcTPRUi-29iZiBxKQ-0GX7ytiyulrQl7UfxUSrtT5vjhJEthSOGYXAOerUAnodGjpLCtIueTwVl6KJA2bXUapUd9xFn3DXfVgFagwqo1MrXKuIR0R5A4sjmEx8d2Kn_KQr0ZNnSOaAod2os4tmh3A87u9Jb51FMxhP-8Qbn7ff-RXwT_015C64Ux1zzS-ok89XbTgyfGxkah0-fVFAgS0zosHLI3C_pvumcglmFXZz7otH596BAU_QkqME6X-PGte6j6eldFobP_96tBxOhIRgVKw.Ky4xLbQZEGaBPjGJnKurng" } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", 
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false }, "password": "password" } ================================================ FILE: examples/pki/secrets/intermediate_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBxjCCAWugAwIBAgIQAYoOWhdChUmmKzlc0DWcWDAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODExMDIyMzU0MTNaFw0yODEw MzAyMzU0MTNaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASxvIWme8/yDAxkR63KgSYkpN7mHKBH k5c8S+uzba4xWbaxZtEZ9NNhEIAgYFZ9/3ThrzLOsuGwRCvPTaD5iycQo4GGMIGD MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQU8dKIy5ZLH2h3ihWgqjcpoo5e q3YwHwYDVR0jBBgwFoAU0IpOvAyBnn9UhDqOQzXnfEU3aYMwCgYIKoZIzj0EAwID SQAwRgIhANXlcktuaEvORhgRvzQ6vVNgvpqCEXW3CcCHjUl1xSdaAiEAmakkpfFq VsT5PqPnTRgOWlFESRhQ9btl6nQ+2Lt/S5A= -----END CERTIFICATE----- ================================================ FILE: examples/pki/secrets/intermediate_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,4c7758e66df1884f6560839de64d4dd3 S8Ha8uA+bA3IGPurYODwd9VaJZ6FHI2tlznHXCOxT1MlGqyEAc4aWS11QBUz0Ucp excwlqM8kfh5BcN5a+vvInHnv74ZiNPdpt/apzz2LIx52pApzASiKVXRsAUmR4Pv 3MsO1/cVHkilpee1uC+axL32d5YmyP0URpSNJK9BhZo= -----END EC PRIVATE KEY----- ================================================ FILE: examples/pki/secrets/root_ca.crt ================================================ -----BEGIN CERTIFICATE----- MIIBfDCCASGgAwIBAgIQY0CXerxuM+EhTbpVxxLRKjAKBggqhkjOPQQDAjAcMRow GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xODExMDIyMzU0MTNaFw0yODEw MzAyMzU0MTNaMBwxGjAYBgNVBAMTEVNtYWxsc3RlcCBSb290IENBMFkwEwYHKoZI zj0CAQYIKoZIzj0DAQcDQgAEEGa7ZeL4WVIfPFDS7glJkIVsITVQgjfyz+AhcYaS rkJZlWOGZ60br9uE/wEfUcX1zavrX1Wz+bSJzTvT0AVBNqNFMEMwDgYDVR0PAQH/ BAQDAgGmMBIGA1UdEwEB/wQIMAYBAf8CAQEwHQYDVR0OBBYEFNCKTrwMgZ5/VIQ6 
jkM153xFN2mDMAoGCCqGSM49BAMCA0kAMEYCIQCRA4EdlTTMhs2Zd1cT75ZgxeGa mjVPl1vqBxLkHqEO+QIhAPKVm7E452ZBe2o5rQRxGwa94MI+CyuEIH9md3nTgWWX -----END CERTIFICATE----- ================================================ FILE: examples/pki/secrets/root_ca_key ================================================ -----BEGIN EC PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,98fdc560ba714aebb9fd4b714395d8ce 2bFn8yRb8lMvDR6oh22PocfhXdaoVNt4QwHCJNy0K0fG8CMokwDfEec//LseP6rA 7/EV11+ZgoN9xyTNe1kB6zFv7/kzCpRm23sqtyio+8xXWnLZNYKBRYYEeJWBUqqd GAfazg4ZFzoIH5TEPWCEAp7M9CVvtiw1SeA/zjewp2k= -----END EC PRIVATE KEY----- ================================================ FILE: examples/puppet/ca.json.erb ================================================ { "root": "/usr/local/lib/step/.step/secrets/root_ca.crt", "crt": "/usr/local/lib/step/.step/secrets/intermediate_ca.crt", "key": "/usr/local/lib/step/.step/secrets/intermediate_ca_key", "password": "password", "address": ":9000", "dnsNames": [ "localhost" ], "logger": { "format": "text" }, "authority": { "provisioners": [ { "name": "mariano@smallstep.com", "type": "jwk", "key": { "use": "sig", "kty": "EC", "kid": "DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk", "crv": "P-256", "alg": "ES256", "x": "jXoO1j4CXxoTC32pNzkVC8l6k2LfP0k5ndhJZmcdVbk", "y": "c3JDL4GTFxJWHa8EaHdMh4QgwMh64P2_AGWrD0ADXcI" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiOTFVWjdzRGw3RlNXcldfX1I1NUh3USJ9.FcWtrBDNgrkA33G9Ll9sXh1cPF-3jVXeYe1FLmSDc_Q2PmfLOPvJOA.0ZoN32ayaRWnufJb.WrkffMmDLWiq1-2kn-w7-kVBGW12gjNCBHNHB1hyEdED0rWH1YWpKd8FjoOACdJyLhSn4kAS3Lw5AH7fvO27A48zzvoxZU5EgSm5HG9IjkIH-LBJ-v79ShkpmPylchgjkFhxa5epD11OIK4rFmI7s-0BCjmJokLR_DZBhDMw2khGnsr_MEOfAz9UnqXaQ4MIy8eT52xUpx68gpWFlz2YP3EqiYyNEv0PpjMtyP5lO2i8-p8BqvuJdus9H3fO5Dg-1KVto1wuqh4BQ2JKTauv60QAnM_4sdxRHku3F_nV64SCrZfDvnN2ve21raFROtyXaqHZhN6lyoPxDncy8v4.biaOblEe0N-gMpJyFZ-3-A" }, { "name": "mike@smallstep.com", "type": "jwk", "key": { 
"use": "sig", "kty": "EC", "kid": "YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs", "crv": "P-256", "alg": "ES256", "x": "LsI8nHBflc-mrCbRqhl8d3hSl5sYuSM1AbXBmRfznyg", "y": "F99LoOvi7z-ZkumsgoHIhodP8q9brXe4bhF3szK-c_w" }, "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiVERQS2dzcEItTUR4ZDJxTGo0VlpwdyJ9.2_j0cZgTm2eFkZ-hrtr1hBIvLxN0w3TZhbX0Jrrq7vBMaywhgFcGTA.mCasZCbZJ-JT7vjA.bW052WDKSf_ueEXq1dyxLq0n3qXWRO-LXr7OzBLdUKWKSBGQrzqS5KJWqdUCPoMIHTqpwYvm-iD6uFlcxKBYxnsAG_hoq_V3icvvwNQQSd_q7Thxr2_KtPIDJWNuX1t5qXp11hkgb-8d5HO93CmN7xNDG89pzSUepT6RYXOZ483mP5fre9qzkfnrjx3oPROCnf3SnIVUvqk7fwfXuniNsg3NrNqncHYUQNReiq3e9I1R60w0ZQTvIReY7-zfiq7iPgVqmu5I7XGgFK4iBv0L7UOEora65b4hRWeLxg5t7OCfUqrS9yxAk8FdjFb9sEfjopWViPRepB0dYPH8dVI.fb6-7XWqp0j6CR9Li0NI-Q", "claims": { "minTLSCertDuration": "60s", "defaultTLSCertDuration": "120s" } } ] }, "tls": { "cipherSuites": [ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" ], "minVersion": 1.2, "maxVersion": 1.2, "renegotiation": false } } ================================================ FILE: examples/puppet/defaults.json.erb ================================================ { "ca-url": "ca.smallstep.com:8080", "root": "/usr/local/lib/step/.step/secrets/root_ca.crt" } ================================================ FILE: examples/puppet/step.pp ================================================ # smallstep package configuration class step( $version = false, ) { if !$version { fail("class ${name}: version cannot be empty") } $pkg = "step_${version}_linux_amd64.tar.gz" $download_url = "https://github.com/smallstep/cli/releases/download/v${version}/step_${version}_linux_amd64.tar.gz" $step_exec = '/opt/smallstep/bin/step' exec { 'download/update smallstep': command => "/usr/bin/curl --fail -o /tmp/${pkg} ${download_url} && /bin/tar -xzvf /tmp/${pkg} -C /opt", unless => "/usr/bin/which ${step_exec} && ${step_exec} version | grep ${version}", user => 
'step', require => File['/opt/smallstep']; } file { '/opt/smallstep': ensure => directory, mode => '0755', owner => 'step'; '/usr/local/lib/step': ensure => directory, mode => '0755', owner => 'step'; '/usr/local/lib/step/.step': ensure => directory, mode => '0755', owner => 'step'; '/usr/local/lib/step/.step/secrets': ensure => directory, mode => '0644', owner => 'step'; '/usr/local/lib/step/.step/config': ensure => directory, mode => '0755', owner => 'step'; } group { 'step': ensure => present, gid => $::step_id, } user { 'step': ensure => present, gid => 'puppet', home => '/usr/local/lib/step', managehome => false, uid => $::step_id, } } ================================================ FILE: examples/puppet/step_ca.pp ================================================ # step_ca package configuration class step_ca( $version = false, ) { if !$version { fail("class ${name}: version cannot be empty") } $pkg = "step-certificates_${version}_linux_amd64.tar.gz" $download_url = "https://github.com/smallstep/certificates/releases/download/v${version}/step-certificates_${version}_linux_amd64.tar.gz" $step_ca_exec = '/opt/smallstep/bin/step-ca' # Download only when the installed step-ca does not already report the wanted version. exec { 'download/update smallstep': command => "/usr/bin/curl --fail -L -o /tmp/${pkg} ${download_url} && /bin/tar -xzvf /tmp/${pkg} -C /opt", unless => "/usr/bin/which ${step_ca_exec} && ${step_ca_exec} version | grep ${version}", user => 'step', require => File['/opt/smallstep']; } file { '/usr/local/lib/step/.step': ensure => directory, mode => '0755', owner => 'step'; '/usr/local/lib/step/.step/secrets': ensure => directory, mode => '0644', owner => 'step'; '/usr/local/lib/step/.step/secrets/root_ca.crt': # Get this from Hiera. ensure => file, mode => '0644', owner => 'step'; '/usr/local/lib/step/.step/secrets/intermediate_ca.crt': # Get this from Hiera. ensure => file, mode => '0644', owner => 'step'; '/usr/local/lib/step/.step/secrets/intermediate_ca_key': # Get this from Hiera. 
ensure => file, mode => '0600', owner => 'step'; '/usr/local/lib/step/.step/secrets/intermediate_pass': # Get this from Hiera. ensure => file, mode => '0600', owner => 'step'; '/usr/local/lib/step/.step/config': ensure => directory, mode => '0755', owner => 'step'; '/usr/local/lib/step/.step/config/ca.json': # Fill from template in repo. ensure => file, content => template('ca.json.erb'), mode => '0644', owner => 'step'; '/usr/local/lib/step/.step/config/defaults.json': # Fill from template in repo. ensure => file, content => template('defaults.json.erb'), mode => '0644', owner => 'step'; } service { $name: ensure => running, start => "${step_ca_exec} /usr/local/lib/step/.step/config/ca.json --password-file /usr/local/lib/step/.step/secrets/intermediate_pass", provider => 'systemd', } } ================================================ FILE: examples/puppet/tls_server.pp ================================================ # tls_server service configuration class tls_server( $version = false, ) { if !$version { fail("class ${name}: version cannot be empty") } file { '/usr/local/lib/step/.step/secrets/provisioner_pupppet_pass': # Get this from Hiera. 
ensure => file, mode => '0600', owner => 'step'; } $step = "/opt/smallstep/bin/step" $step_path = "/usr/local/lib/step/.step" $secrets = "${step_path}/secrets" service { $name: ensure => running, start => "/usr/local/bin/tls_server --token $(${step} token foo.com --ca-url=ca.smallstep.com --root=${secrets}/root_ca.crt --password-file=${secrets}/intermediate_pass)", provider => 'systemd', } } ================================================ FILE: go.mod ================================================ module github.com/smallstep/certificates go 1.25.0 require ( cloud.google.com/go/longrunning v0.8.0 cloud.google.com/go/security v1.19.2 github.com/Masterminds/sprig/v3 v3.3.0 github.com/ccoveille/go-safecast/v2 v2.0.0 github.com/coreos/go-oidc/v3 v3.17.0 github.com/coreos/go-systemd/v22 v22.7.0 github.com/dgraph-io/badger v1.6.2 github.com/dgraph-io/badger/v2 v2.2007.4 github.com/fxamacker/cbor/v2 v2.9.0 github.com/go-chi/chi/v5 v5.2.5 github.com/go-jose/go-jose/v3 v3.0.4 github.com/google/go-cmp v0.7.0 github.com/google/go-tpm v0.9.8 github.com/google/uuid v1.6.0 github.com/googleapis/gax-go/v2 v2.18.0 github.com/hashicorp/vault/api v1.22.0 github.com/hashicorp/vault/api/auth/approle v0.11.0 github.com/hashicorp/vault/api/auth/aws v0.11.0 github.com/hashicorp/vault/api/auth/kubernetes v0.10.0 github.com/newrelic/go-agent/v3 v3.42.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.23.2 github.com/rs/xid v1.6.0 github.com/sirupsen/logrus v1.9.4 github.com/slackhq/nebula v1.10.3 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/cli-utils v0.12.2 github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca github.com/smallstep/linkedca v0.25.0 github.com/smallstep/nosql v0.8.0 github.com/smallstep/pkcs7 v0.2.1 github.com/smallstep/scep v0.0.0-20250318231241-a25cabb69492 github.com/stretchr/testify v1.11.1 github.com/urfave/cli v1.22.17 go.step.sm/crypto v0.77.1 go.uber.org/mock v0.6.0 
golang.org/x/crypto v0.49.0 golang.org/x/net v0.52.0 golang.org/x/sync v0.20.0 google.golang.org/api v0.271.0 google.golang.org/grpc v1.79.3 google.golang.org/protobuf v1.36.11 ) require ( cloud.google.com/go v0.123.0 // indirect cloud.google.com/go/auth v0.18.2 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect cloud.google.com/go/kms v1.26.0 // indirect dario.cat/mergo v1.0.2 // indirect filippo.io/bigmod v0.1.0 // indirect filippo.io/edwards25519 v1.2.0 // indirect github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.3.1 // indirect github.com/ThalesIgnite/crypto11 v1.2.5 // indirect github.com/aws/aws-sdk-go v1.55.7 // indirect github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect github.com/aws/aws-sdk-go-v2/service/kms v1.50.3 // indirect 
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect github.com/aws/smithy-go v1.24.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v1.5.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.0 // indirect github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-piv/piv-go/v2 v2.5.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang/glog v1.2.5 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/certificate-transparency-go v1.1.7 // indirect github.com/google/go-tpm-tools v0.4.7 // indirect github.com/google/go-tspi v0.3.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-hclog v1.6.3 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.8 // indirect github.com/hashicorp/go-rootcerts v1.0.2 // indirect github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 // indirect 
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.7 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/hashicorp/hcl v1.0.1-vault-7 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.8.0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/manifoldco/promptui v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/miekg/pkcs11 v1.1.2 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/peterbourgon/diskv/v3 v3.0.1 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/schollz/jsonstore v1.1.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/thales-e-security/pool v0.0.2 // indirect github.com/x448/float16 v0.8.4 // indirect go.etcd.io/bbolt v1.4.3 // indirect 
go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect go.opentelemetry.io/otel/trace v1.39.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/term v0.41.0 // indirect golang.org/x/text v0.35.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.42.0 // indirect google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) ================================================ FILE: go.sum ================================================ cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM= cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= cloud.google.com/go/iam v1.5.3/go.mod 
h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= cloud.google.com/go/kms v1.26.0 h1:cK9mN2cf+9V63D3H1f6koxTatWy39aTI/hCjz1I+adU= cloud.google.com/go/kms v1.26.0/go.mod h1:pHKOdFJm63hxBsiPkYtowZPltu9dW0MWvBa6IA4HM58= cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= cloud.google.com/go/security v1.19.2 h1:cF3FkCRRbRC1oXuaGZFl3qU2sdu2gP3iOAHKzL5y04Y= cloud.google.com/go/security v1.19.2/go.mod h1:KXmf64mnOsLVKe8mk/bZpU1Rsvxqc0Ej0A6tgCeN93w= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8= filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 
h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0 h1:E4MgwLBGeVB5f2MdcIVD3ELVAWpr+WD6MUe1i+tM/PA= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0/go.mod h1:Y2b/1clN4zsAoUd/pgNAQHjLDnTis/6ROkUfyob6psM= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0/go.mod h1:ucUjca2JtSZboY8IoUqyQyuuXvwbMBVwFOm0vdQPNhA= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod 
h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/ThalesIgnite/crypto11 v1.2.5 h1:1IiIIEqYmBvUYFeMnHqRft4bwf/O36jryEUpY+9ef8E= github.com/ThalesIgnite/crypto11 v1.2.5/go.mod h1:ILDKtnCKiQ7zRoNxcp36Y1ZR8LBPmR2E23+wTQe/MlE= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0= github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g= github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8= github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod 
h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk= github.com/aws/aws-sdk-go-v2/service/kms v1.50.3 h1:s/zDSG/a/Su9aX+v0Ld9cimUCdkr5FWPmBV8owaEbZY= github.com/aws/aws-sdk-go-v2/service/kms v1.50.3/go.mod h1:/iSgiUor15ZuxFGQSTf3lA2FmKxFsQoc2tADOarQBSw= github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow= github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE= github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o= github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA= github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU= github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/ccoveille/go-safecast/v2 v2.0.0 h1:+5eyITXAUj3wMjad6cRVJKGnC7vDS55zk0INzJagub0= 
github.com/ccoveille/go-safecast/v2 v2.0.0/go.mod h1:JIYA4CAR33blIDuE6fSwCp2sz1oOBahXnvmdBhOAABs= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-oidc/v3 v3.17.0 
h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.2 h1:mNw0qs90GVgGGWylh0umH5iag1j6n/PeJtNvL6KY/x8= github.com/dgraph-io/badger v1.6.2/go.mod h1:JW2yswe3V058sS0kZ2h/AXeDSqFjxnZcRrVH//y2UQE= github.com/dgraph-io/badger/v2 v2.2007.4 h1:TRWBQg8UrlUhaFdco01nO2uXwzKS7zd+HVdwV/GHc4o= github.com/dgraph-io/badger/v2 v2.2007.4/go.mod h1:vSw/ax2qojzbN6eXHIx6KPKtCSHJN/Uz0X0VPruTIhk= github.com/dgraph-io/ristretto v0.0.2/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgraph-io/ristretto v0.1.0 h1:Jv3CGQHp9OjuMBSne1485aDpUkTKEcUqF+jm/LuerPI= github.com/dgraph-io/ristretto v0.1.0/go.mod h1:fux0lOrBhrVCJd3lcTHsIJhq1T2rokOu6v9Vcb3Q9ug= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 
h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= 
github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-piv/piv-go/v2 v2.5.0 h1:w4KZ3GytEGZt8zm+S7olcIHZk0giL23xVqCa2HgwuqA= github.com/go-piv/piv-go/v2 v2.5.0/go.mod h1:ShZi74nnrWNQEdWzRUd/3cSig3uNOcEZp+EWl0oewnI= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.5 h1:DrW6hGnjIhtvhOIiAKT6Psh/Kd/ldepEa81DKeiRJ5I= github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.4 
h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/certificate-transparency-go v1.0.21/go.mod h1:QeJfpSbVSfYc7RgB3gJFj9cbuQMMchQxrWXz8Ruopmg= github.com/google/certificate-transparency-go v1.1.7 h1:IASD+NtgSTJLPdzkthwvAG1ZVbF2WtFg4IvoA68XGSw= github.com/google/certificate-transparency-go v1.1.7/go.mod h1:FSSBo8fyMVgqptbfF6j5p/XNdgQftAhSmXcIxV9iphE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-configfs-tsm v0.3.3-0.20240919001351-b4b5b84fdcbc h1:SG12DWUUM5igxm+//YX5Yq4vhdoRnOG9HkCodkOn+YU= github.com/google/go-configfs-tsm v0.3.3-0.20240919001351-b4b5b84fdcbc/go.mod h1:EL1GTDFMb5PZQWDviGfZV9n87WeGTR/JUg13RfwkgRo= github.com/google/go-sev-guest v0.14.0 h1:dCb4F3YrHTtrDX3cYIPTifEDz7XagZmXQioxRBW4wOo= github.com/google/go-sev-guest v0.14.0/go.mod h1:SK9vW+uyfuzYdVN0m8BShL3OQCtXZe/JPF7ZkpD3760= github.com/google/go-tdx-guest v0.3.2-0.20241009005452-097ee70d0843 h1:+MoPobRN9HrDhGyn6HnF5NYo4uMBKaiFqAtf/D/OB4A= github.com/google/go-tdx-guest v0.3.2-0.20241009005452-097ee70d0843/go.mod h1:g/n8sKITIT9xRivBUbizo34DTsUm2nN2uU3A662h09g= github.com/google/go-tpm v0.9.8 
h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo= github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/go-tpm-tools v0.4.7 h1:J3ycC8umYxM9A4eF73EofRZu4BxY0jjQnUnkhIBbvws= github.com/google/go-tpm-tools v0.4.7/go.mod h1:gSyXTZHe3fgbzb6WEGd90QucmsnT1SRdlye82gH8QjQ= github.com/google/go-tspi v0.3.0 h1:ADtq8RKfP+jrTyIWIZDIYcKOMecRqNJFOew2IT0Inus= github.com/google/go-tspi v0.3.0/go.mod h1:xfMGI3G0PhxCdNVcYr1C4C+EizojDg/TXuX5by8CiHI= github.com/google/logger v1.1.1 h1:+6Z2geNxc9G+4D4oDO9njjjn2d0wN5d7uOo0vOIW1NQ= github.com/google/logger v1.1.1/go.mod h1:BkeJZ+1FhQ+/d087r4dzojEg1u2ZX+ZqG1jTUrLM+zQ= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= github.com/googleapis/gax-go/v2 v2.18.0 h1:jxP5Uuo3bxm3M6gGtV94P4lliVetoCB4Wk2x8QA86LI= github.com/googleapis/gax-go/v2 v2.18.0/go.mod h1:uSzZN4a356eRG985CzJ3WfbFSpqkLTjsnhWGJR6EwrE= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-hclog v1.6.3 
h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 h1:I8bynUKMh9I7JdwtW9voJ0xmHvBpxQtLjrMFDYmhOxY= github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg= github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/hashicorp/vault/api 
v1.22.0 h1:+HYFquE35/B74fHoIeXlZIP2YADVboaPjaSicHEZiH0= github.com/hashicorp/vault/api v1.22.0/go.mod h1:IUZA2cDvr4Ok3+NtK2Oq/r+lJeXkeCrHRmqdyWfpmGM= github.com/hashicorp/vault/api/auth/approle v0.11.0 h1:ViUvgqoSTqHkMi1L1Rr/LnQ+PWiRaGUBGvx4UPfmKOw= github.com/hashicorp/vault/api/auth/approle v0.11.0/go.mod h1:v8ZqBRw+GP264ikIw2sEBKF0VT72MEhLWnZqWt3xEG8= github.com/hashicorp/vault/api/auth/aws v0.11.0 h1:lWdUxrzvPotg6idNr62al4w97BgI9xTDdzMCTViNH2s= github.com/hashicorp/vault/api/auth/aws v0.11.0/go.mod h1:PWqdH/xqaudapmnnGP9ip2xbxT/kRW2qEgpqiQff6Gc= github.com/hashicorp/vault/api/auth/kubernetes v0.10.0 h1:5rqWmUFxnu3S7XYq9dafURwBgabYDFzo2Wv+AMopPHs= github.com/hashicorp/vault/api/auth/kubernetes v0.10.0/go.mod h1:cZZmhF6xboMDmDbMY52oj2DKW6gS0cQ9g0pJ5XIXQ5U= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath 
v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.1.2 h1:/VxmeAX5qU6Q3EwafypogwWbYryHFmF2RpkJmw3m4MQ= github.com/miekg/pkcs11 v1.1.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/munnerz/goautoneg 
v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/newrelic/go-agent/v3 v3.42.0 h1:aA2Ea1RT5eD59LtOS1KGFXSmaDs6kM3Jeqo7PpuQoFQ= github.com/newrelic/go-agent/v3 v3.42.0/go.mod h1:sCgxDCVydoKD/C4S8BFxDtmFHvdWHtaIz/a3kiyNB/k= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/peterbourgon/diskv/v3 v3.0.1 h1:x06SQA46+PKIUftmEujdwSEpIx8kR+M9eLYsUxeYveU= github.com/peterbourgon/diskv/v3 v3.0.1/go.mod h1:kJ5Ny7vLdARGU3WUuy6uzO6T0nb/2gWcT1JiBvRmb5o= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.67.5 
h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/schollz/jsonstore v1.1.0 h1:WZBDjgezFS34CHI+myb4s8GGpir3UMpy7vWoCeO0n6E= github.com/schollz/jsonstore v1.1.0/go.mod h1:15c6+9guw8vDRyozGjN3FoILt0wpruJk9Pi66vjaZfg= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/slackhq/nebula v1.10.3 
h1:EstYj8ODEcv6T0R9X5BVq1zgWZnyU5gtPzk99QF1PMU= github.com/slackhq/nebula v1.10.3/go.mod h1:IL5TUQm4x9IFx2kCKPYm1gP47pwd5b8QGnnBH2RHnvs= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/cli-utils v0.12.2 h1:lGzM9PJrH/qawbzMC/s2SvgLdJPKDWKwKzx9doCVO+k= github.com/smallstep/cli-utils v0.12.2/go.mod h1:uCPqefO29goHLGqFnwk0i8W7XJu18X3WHQFRtOm/00Y= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca h1:VX8L0r8vybH0bPeaIxh4NQzafKQiqvlOn8pmOXbFLO4= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca/go.mod h1:vNAduivU014fubg6ewygkAvQC0IQVXqdc8vaGl/0er4= github.com/smallstep/linkedca v0.25.0 h1:txT9QHGbCsJq0MhAghBq7qhurGY727tQuqUi+n4BVBo= github.com/smallstep/linkedca v0.25.0/go.mod h1:Q3jVAauFKNlF86W5/RFtgQeyDKz98GL/KN3KG4mJOvc= github.com/smallstep/nosql v0.8.0 h1:FBTCUfKPmWYbrozW+RBKu+fnvbn+zr5rVli/XB4Jp4A= github.com/smallstep/nosql v0.8.0/go.mod h1:5dUpNotHLHhOUapP0PLBVVfp3tG1DFC31VRccg+Cqwo= github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA= github.com/smallstep/pkcs7 v0.2.1/go.mod h1:RcXHsMfL+BzH8tRhmrF1NkkpebKpq3JEM66cOFxanf0= github.com/smallstep/scep v0.0.0-20250318231241-a25cabb69492 h1:k23+s51sgYix4Zgbvpmy+1ZgXLjr4ZTkBTqXmpnImwA= github.com/smallstep/scep v0.0.0-20250318231241-a25cabb69492/go.mod h1:QQhwLqCS13nhv8L5ov7NgusowENUtXdEzdytjmJHdZQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod 
h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= 
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/urfave/cli v1.22.17 h1:SYzXoiPfQjHBbkYxbew5prZHS1TOLT3ierW8SYLqtVQ= github.com/urfave/cli v1.22.17/go.mod h1:b0ht0aqgH/6pBYzzxURyrM4xXNgsoT/n2ZzwQiEhNVo= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= go.opentelemetry.io/otel/metric v1.39.0 
h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.step.sm/crypto v0.77.1 h1:4EEqfKdv0egQ1lqz2RhnU8Jv6QgXZfrgoxWMqJF9aDs= go.step.sm/crypto v0.77.1/go.mod h1:U/SsmEm80mNnfD5WIkbhuW/B1eFp3fgFvdXyDLpU1AQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod 
h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod 
h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.29.0/go.mod 
h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors 
v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY= google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q= google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d h1:vsOm753cOAMkt76efriTCDKjpCbK18XGHMJHo0JUKhc= google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:0oz9d7g9QLSdv9/lgbIjowW1JoxMbxmBVNe8i6tORJI= google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d h1:EocjzKLywydp5uZ5tJ79iP6Q0UjDnyiHkGRWxuPBP8s= google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:48U2I+QQUYhsFrg2SY6r+nJzeOtjey7j//WBESw+qyQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 h1:F29+wU6Ee6qgu9TddPgooOdaqsxTMunOoj8KA5yuS5A= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1/go.mod h1:5KF+wpkbTSbGcR9zteSqZV6fqFOWBl4Yde8En8MryZA= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: internal/cast/cast.go ================================================ package cast import ( "github.com/ccoveille/go-safecast/v2" ) type signed interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 } type unsigned interface { ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 } type number interface { signed | unsigned } func SafeUint(x int) (uint, error) { return safecast.Convert[uint](x) } func Uint(x int) uint { u, err := SafeUint(x) if err != nil { panic(err) } return u } func SafeInt64[T number](x T) (int64, error) { return safecast.Convert[int64](x) } func Int64[T number](x T) int64 { i64, err := SafeInt64(x) if err != nil { panic(err) } return i64 } func SafeUint64[T signed](x T) (uint64, error) { return safecast.Convert[uint64](x) } func Uint64[T signed](x T) uint64 { u64, err := SafeUint64(x) if err != nil { panic(err) } return u64 } func SafeInt32[T signed](x T) (int32, error) { return safecast.Convert[int32](x) } func Int32[T signed](x T) int32 { i32, err := SafeInt32(x) if err != nil { panic(err) } return i32 } func 
SafeUint32[T signed](x T) (uint32, error) { return safecast.Convert[uint32](x) } func Uint32[T signed](x T) uint32 { u32, err := SafeUint32(x) if err != nil { panic(err) } return u32 } func SafeUint16(x int) (uint16, error) { return safecast.Convert[uint16](x) } func Uint16(x int) uint16 { u16, err := SafeUint16(x) if err != nil { panic(err) } return u16 } func SafeUint8[T number](x T) (uint8, error) { return safecast.Convert[uint8](x) } func Uint8[T number](x T) uint8 { u8, err := SafeUint8(x) if err != nil { panic(err) } return u8 } ================================================ FILE: internal/cast/cast_test.go ================================================ package cast import ( "math" "testing" "github.com/stretchr/testify/require" ) func TestUintConvertsValues(t *testing.T) { require.Equal(t, uint(0), Uint(0)) require.Equal(t, uint(math.MaxInt), Uint(math.MaxInt)) require.Equal(t, uint(42), Uint(42)) } func TestUintPanicsOnNegativeValue(t *testing.T) { require.Panics(t, func() { Uint(-1) }) } func TestInt64ConvertsValues(t *testing.T) { require.Equal(t, int64(0), Int64(0)) require.Equal(t, int64(math.MaxInt), Int64(math.MaxInt)) require.Equal(t, int64(42), Int64(42)) } func TestInt64PanicsOnLargeValue(t *testing.T) { require.Panics(t, func() { Int64(uint64(math.MaxInt64 + 1)) }) } func TestUint64ConvertsValues(t *testing.T) { require.Equal(t, uint64(0), Uint64(0)) require.Equal(t, uint64(math.MaxInt), Uint64((math.MaxInt))) require.Equal(t, uint64(42), Uint64(42)) } func TestUint64PanicsOnNegativeValue(t *testing.T) { require.Panics(t, func() { Uint64(-1) }) } func TestInt32ConvertsValues(t *testing.T) { require.Equal(t, int32(0), Int32(0)) require.Equal(t, int32(math.MaxInt32), Int32(math.MaxInt32)) require.Equal(t, int32(42), Int32(42)) } func TestInt32PanicsOnTooSmallValue(t *testing.T) { require.Panics(t, func() { Int32(int64(math.MinInt32 - 1)) }) } func TestInt32PanicsOnLargeValue(t *testing.T) { require.Panics(t, func() { Int32(int64(math.MaxInt32 + 
1)) }) } func TestUint32ConvertsValues(t *testing.T) { require.Equal(t, uint32(0), Uint32(0)) require.Equal(t, uint32(math.MaxUint32), Uint32(int64(math.MaxUint32))) require.Equal(t, uint32(42), Uint32(42)) } func TestUint32PanicsOnNegativeValue(t *testing.T) { require.Panics(t, func() { Uint32(-1) }) } func TestUint32PanicsOnLargeValue(t *testing.T) { require.Panics(t, func() { Uint32(int64(math.MaxUint32 + 1)) }) } func TestUint16ConvertsValues(t *testing.T) { require.Equal(t, uint16(0), Uint16(0)) require.Equal(t, uint16(math.MaxUint16), Uint16(math.MaxUint16)) require.Equal(t, uint16(42), Uint16(42)) } func TestUint16PanicsOnNegativeValue(t *testing.T) { require.Panics(t, func() { Uint16(-1) }) } func TestUint16PanicsOnLargeValue(t *testing.T) { require.Panics(t, func() { Uint16(math.MaxUint16 + 1) }) } func TestUint8ConvertsValues(t *testing.T) { require.Equal(t, uint8(0), Uint8(0)) require.Equal(t, uint8(math.MaxUint8), Uint8(math.MaxUint8)) require.Equal(t, uint8(42), Uint8(42)) } func TestUint8PanicsOnNegativeValue(t *testing.T) { require.Panics(t, func() { Uint8(-1) }) } func TestUint8PanicsOnLargeValue(t *testing.T) { require.Panics(t, func() { Uint8(math.MaxUint8 + 1) }) } ================================================ FILE: internal/httptransport/httptransport.go ================================================ // Package httptransport implements initialization of [http.Transport] instances and related // functionality. package httptransport import ( "net" "net/http" "time" ) // Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper]. type Wrapper func(*http.Transport) http.RoundTripper // NoopWrapper returns a [Wrapper] that simply casts its provided [http.Transport] to an // [http.RoundTripper]. 
func NoopWrapper() Wrapper { return func(t *http.Transport) http.RoundTripper { return t } } // New returns a reference to an [http.Transport] that's initialized just like the // [http.DefaultTransport] is by the standard library. func New() *http.Transport { return &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } } ================================================ FILE: internal/metrix/meter.go ================================================ // Package metrix implements stats-related functionality. package metrix import ( "crypto/x509" "net/http" "strconv" "time" "github.com/smallstep/certificates/authority/provisioner" "golang.org/x/crypto/ssh" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" ) // New initializes and returns a new [Meter]. 
func New() (m *Meter) { initializedAt := time.Now() defaultLabels := []string{"provisioner", "success"} sshSignLabels := []string{"provisioner", "success", "type"} m = &Meter{ uptime: prometheus.NewGaugeFunc( prometheus.GaugeOpts(opts( "", "uptime_seconds", "Number of seconds since service start", )), func() float64 { return float64(time.Since(initializedAt) / time.Second) }, ), ssh: newProvisionerInstruments("ssh", sshSignLabels, defaultLabels), x509: newProvisionerInstruments("x509", defaultLabels, defaultLabels), kms: &kms{ signed: prometheus.NewCounter(prometheus.CounterOpts(opts("kms", "signed", "Number of KMS-backed signatures"))), errors: prometheus.NewCounter(prometheus.CounterOpts(opts("kms", "errors", "Number of KMS-related errors"))), }, } reg := prometheus.NewRegistry() reg.MustRegister( m.uptime, m.ssh.rekeyed, m.ssh.renewed, m.ssh.signed, m.ssh.webhookAuthorized, m.ssh.webhookEnriched, m.x509.rekeyed, m.x509.renewed, m.x509.signed, m.x509.webhookAuthorized, m.x509.webhookEnriched, m.kms.signed, m.kms.errors, ) h := promhttp.HandlerFor(reg, promhttp.HandlerOpts{ Registry: reg, Timeout: 5 * time.Second, MaxRequestsInFlight: 10, }) mux := http.NewServeMux() mux.Handle("/metrics", h) m.Handler = mux return } // Meter wraps the functionality of a Prometheus-compatible HTTP handler. type Meter struct { http.Handler uptime prometheus.GaugeFunc ssh *provisionerInstruments x509 *provisionerInstruments kms *kms } // SSHRekeyed implements [authority.Meter] for [Meter]. func (m *Meter) SSHRekeyed(cert *ssh.Certificate, p provisioner.Interface, err error) { incrProvisionerCounter(m.ssh.rekeyed, p, err, sshCertValues(cert)...) } // SSHRenewed implements [authority.Meter] for [Meter]. func (m *Meter) SSHRenewed(cert *ssh.Certificate, p provisioner.Interface, err error) { incrProvisionerCounter(m.ssh.renewed, p, err, sshCertValues(cert)...) } // SSHSigned implements [authority.Meter] for [Meter]. 
func (m *Meter) SSHSigned(cert *ssh.Certificate, p provisioner.Interface, err error) { incrProvisionerCounter(m.ssh.signed, p, err, sshCertValues(cert)...) } // SSHWebhookAuthorized implements [authority.Meter] for [Meter]. func (m *Meter) SSHWebhookAuthorized(p provisioner.Interface, err error) { incrProvisionerCounter(m.ssh.webhookAuthorized, p, err) } // SSHWebhookEnriched implements [authority.Meter] for [Meter]. func (m *Meter) SSHWebhookEnriched(p provisioner.Interface, err error) { incrProvisionerCounter(m.ssh.webhookEnriched, p, err) } // X509Rekeyed implements [authority.Meter] for [Meter]. func (m *Meter) X509Rekeyed(_ []*x509.Certificate, p provisioner.Interface, err error) { incrProvisionerCounter(m.x509.rekeyed, p, err) } // X509Renewed implements [authority.Meter] for [Meter]. func (m *Meter) X509Renewed(_ []*x509.Certificate, p provisioner.Interface, err error) { incrProvisionerCounter(m.x509.renewed, p, err) } // X509Signed implements [authority.Meter] for [Meter]. func (m *Meter) X509Signed(_ []*x509.Certificate, p provisioner.Interface, err error) { incrProvisionerCounter(m.x509.signed, p, err) } // X509WebhookAuthorized implements [authority.Meter] for [Meter]. func (m *Meter) X509WebhookAuthorized(p provisioner.Interface, err error) { incrProvisionerCounter(m.x509.webhookAuthorized, p, err) } // X509WebhookEnriched implements [authority.Meter] for [Meter]. func (m *Meter) X509WebhookEnriched(p provisioner.Interface, err error) { incrProvisionerCounter(m.x509.webhookEnriched, p, err) } func sshCertValues(cert *ssh.Certificate) []string { switch cert.CertType { case ssh.UserCert: return []string{"user"} case ssh.HostCert: return []string{"host"} default: return []string{"unknown"} } } func incrProvisionerCounter(cv *prometheus.CounterVec, p provisioner.Interface, err error, extraValues ...string) { var name string if p != nil { name = p.GetName() } values := append([]string{ name, strconv.FormatBool(err == nil), }, extraValues...) 
cv.WithLabelValues(values...).Inc() } // KMSSigned implements [authority.Meter] for [Meter]. func (m *Meter) KMSSigned(err error) { if err == nil { m.kms.signed.Inc() } else { m.kms.errors.Inc() } } // provisionerInstruments wraps the counters exported by provisioners. type provisionerInstruments struct { rekeyed *prometheus.CounterVec renewed *prometheus.CounterVec signed *prometheus.CounterVec webhookAuthorized *prometheus.CounterVec webhookEnriched *prometheus.CounterVec } func newProvisionerInstruments(subsystem string, signLabels, webhookLabels []string) *provisionerInstruments { return &provisionerInstruments{ rekeyed: newCounterVec(subsystem, "rekeyed_total", "Number of certificates rekeyed", signLabels...), renewed: newCounterVec(subsystem, "renewed_total", "Number of certificates renewed", signLabels...), signed: newCounterVec(subsystem, "signed_total", "Number of certificates signed", signLabels...), webhookAuthorized: newCounterVec(subsystem, "webhook_authorized_total", "Number of authorizing webhooks called", webhookLabels...), webhookEnriched: newCounterVec(subsystem, "webhook_enriched_total", "Number of enriching webhooks called", webhookLabels...), } } type kms struct { signed prometheus.Counter errors prometheus.Counter } func newCounterVec(subsystem, name, help string, labels ...string) *prometheus.CounterVec { opts := opts(subsystem, name, help) return prometheus.NewCounterVec(prometheus.CounterOpts(opts), labels) } func opts(subsystem, name, help string) prometheus.Opts { return prometheus.Opts{ Namespace: "step_ca", Subsystem: subsystem, Name: name, Help: help, } } ================================================ FILE: internal/userid/userid.go ================================================ package userid import "context" type contextKey struct{} // NewContext returns a new context with the given user ID added to the // context. // TODO(hs): this doesn't seem to be used / set currently; implement // when/where it makes sense. 
func NewContext(ctx context.Context, userID string) context.Context { return context.WithValue(ctx, contextKey{}, userID) } // FromContext returns the user ID from the context if it exists // and is not empty. func FromContext(ctx context.Context) (string, bool) { v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } ================================================ FILE: logging/clf.go ================================================ package logging import ( "bytes" "fmt" "strconv" "time" "github.com/sirupsen/logrus" ) var clfFields = [...]string{ "request-id", "remote-address", "name", "user-id", "time", "duration", "method", "path", "protocol", "status", "size", } // CommonLogFormat implements the logrus.Formatter interface it writes logrus // entries using a CLF format prepended by the request-id. type CommonLogFormat struct{} // Format implements the logrus.Formatter interface. It returns the given // logrus entry as a CLF line with the following format: // //