Repository: bitly/google_auth_proxy Branch: master Commit: fa2771998a98 Files: 56 Total size: 201.9 KB Directory structure: gitextract_1bjvy9ul/ ├── .gitignore ├── .travis.yml ├── Gopkg.toml ├── LICENSE ├── README.md ├── api/ │ ├── api.go │ └── api_test.go ├── contrib/ │ ├── oauth2_proxy.cfg.example │ └── oauth2_proxy.service.example ├── cookie/ │ ├── cookies.go │ ├── cookies_test.go │ └── nonce.go ├── dist.sh ├── env_options.go ├── env_options_test.go ├── htpasswd.go ├── htpasswd_test.go ├── http.go ├── logging_handler.go ├── logging_handler_test.go ├── main.go ├── oauthproxy.go ├── oauthproxy_test.go ├── options.go ├── options_test.go ├── providers/ │ ├── azure.go │ ├── azure_test.go │ ├── facebook.go │ ├── github.go │ ├── github_test.go │ ├── gitlab.go │ ├── gitlab_test.go │ ├── google.go │ ├── google_test.go │ ├── internal_util.go │ ├── internal_util_test.go │ ├── linkedin.go │ ├── linkedin_test.go │ ├── oidc.go │ ├── provider_data.go │ ├── provider_default.go │ ├── provider_default_test.go │ ├── providers.go │ ├── session_state.go │ └── session_state_test.go ├── string_array.go ├── templates.go ├── templates_test.go ├── test.sh ├── validator.go ├── validator_test.go ├── validator_watcher_copy_test.go ├── validator_watcher_test.go ├── version.go ├── watcher.go └── watcher_unsupported.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ oauth2_proxy vendor dist .godeps *.exe # Go.gitignore # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go # Editor swap/temp files .*.swp ================================================ FILE: .travis.yml ================================================ language: go go: 
- 1.8.x - 1.9.x script: - wget -O dep https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64 - chmod +x dep - ./dep ensure - ./test.sh sudo: false notifications: email: false ================================================ FILE: Gopkg.toml ================================================ # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md # for detailed Gopkg.toml documentation. # [[constraint]] name = "github.com/18F/hmacauth" version = "~1.0.1" [[constraint]] name = "github.com/BurntSushi/toml" version = "~0.3.0" [[constraint]] name = "github.com/bitly/go-simplejson" version = "~0.5.0" [[constraint]] branch = "v2" name = "github.com/coreos/go-oidc" [[constraint]] branch = "master" name = "github.com/mreiferson/go-options" [[constraint]] name = "github.com/stretchr/testify" version = "~1.1.4" [[constraint]] branch = "master" name = "golang.org/x/oauth2" [[constraint]] branch = "master" name = "google.golang.org/api" [[constraint]] name = "gopkg.in/fsnotify.v1" version = "~1.2.0" [[constraint]] branch = "master" name = "golang.org/x/crypto" ================================================ FILE: LICENSE ================================================ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ oauth2_proxy ================= A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others) to validate accounts by email, domain or group. [![Build Status](https://secure.travis-ci.org/bitly/oauth2_proxy.svg?branch=master)](http://travis-ci.org/bitly/oauth2_proxy) ![Sign In Page](https://cloud.githubusercontent.com/assets/45028/4970624/7feb7dd8-6886-11e4-93e0-c9904af44ea8.png) **NOTICE**: This project was officially archived by Bitly at the end of September 2018. Bitly will no longer be accepting PRs or helping on issues. There has been a [discussion](https://github.com/bitly/oauth2_proxy/issues/628) to find a new home for the project which has led to the following notable forks: - [pomerium](https://github.com/pomerium/pomerium) an identity-access proxy, inspired by BeyondCorp. - [buzzfeed/sso](https://github.com/buzzfeed/sso) a "double OAuth2" flow, where sso-auth is the OAuth2 provider for sso-proxy and Google is the OAuth2 provider for sso-auth. - [openshift/oauth_proxy](https://github.com/openshift/oauth-proxy) an openshift specific version of this project. - [pusher/oauth2_proxy](https://github.com/pusher/oauth2_proxy) official hard fork of this project. Please submit all future PRs and issues to [pusher/oauth2_proxy](https://github.com/pusher/oauth2_proxy). ## Architecture ![OAuth2 Proxy Architecture](https://cloud.githubusercontent.com/assets/45028/8027702/bd040b7a-0d6a-11e5-85b9-f8d953d04f39.png) ## Installation 1. 
Download [Prebuilt Binary](https://github.com/bitly/oauth2_proxy/releases) (current release is `v2.2`) or build with `$ go get github.com/bitly/oauth2_proxy`, which will put the binary in `$GOPATH/bin`. Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`. ``` sha256sum -c sha256sum.txt 2>&1 | grep OK oauth2_proxy-2.3.linux-amd64: OK ``` 2. Select a Provider and Register an OAuth Application with a Provider 3. Configure OAuth2 Proxy using config file, command line options, or environment variables 4. Configure SSL or deploy behind an SSL endpoint (example provided for Nginx) ## OAuth Provider Configuration You will need to register an OAuth application with a Provider (Google, GitHub or another provider), and configure it with Redirect URI(s) for the domain you intend to run `oauth2_proxy` on. Valid providers are: * [Google](#google-auth-provider) *default* * [Azure](#azure-auth-provider) * [Facebook](#facebook-auth-provider) * [GitHub](#github-auth-provider) * [GitLab](#gitlab-auth-provider) * [LinkedIn](#linkedin-auth-provider) * [OpenID Connect](#openid-connect-provider) The provider can be selected using the `provider` configuration value. ### Google Auth Provider For Google, the registration steps are: 1. Create a new project: https://console.developers.google.com/project 2. Choose the new project from the top right project dropdown (only if another project is selected) 3. In the project Dashboard center pane, choose **"API Manager"** 4. In the left Nav pane, choose **"Credentials"** 5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save. 6. In the center pane, choose **"Credentials"** tab.
* Open the **"New credentials"** drop-down * Choose **"OAuth client ID"** * Choose **"Web application"** * Application name is freeform, choose something appropriate * Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com` * Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback` * Choose **"Create"** 7. Take note of the **Client ID** and **Client Secret** It's recommended to refresh sessions on a short interval (1h) with the `cookie-refresh` setting, which validates that the account is still authorized. #### Restrict auth to specific Google groups on your domain. (optional) 1. Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file. 2. Make note of the Client ID for a future step. 3. Under "APIs & Auth", choose APIs. 4. Click on Admin SDK and then Enable API. 5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes: ``` https://www.googleapis.com/auth/admin.directory.group.readonly https://www.googleapis.com/auth/admin.directory.user.readonly ``` 6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access. 7. Create or choose an existing administrative email address on the Gmail domain to assign to the ```google-admin-email``` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why. 8. Create or choose an existing email group and set that email to the ```google-group``` flag. You can pass multiple instances of this flag with different groups and the user will be checked against all the provided groups. 9.
Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the ```google-service-account-json``` flag. 10. Restart oauth2_proxy. Note: The user is checked against the group members list on initial authentication and every time the token is refreshed (about once an hour). ### Azure Auth Provider 1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant. 2. On the App properties page provide the correct Sign-On URL, e.g. `https://internal.yourcompany.com/oauth2/callback` 3. If applicable, take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` command-line option. By default, the `common` tenant is used. The Azure AD auth provider uses `openid` as its default scope. It uses `https://graph.windows.net` as a default protected resource. It calls `https://graph.windows.net/me` to get the email address of the user that logs in. ### Facebook Auth Provider 1. Create a new FB App from https://developers.facebook.com/ 2. Under FB Login, set your Valid OAuth redirect URIs to `https://internal.yourcompany.com/oauth2/callback` ### GitHub Auth Provider 1. Create a new project: https://github.com/settings/developers 2. Under `Authorization callback URL` enter the correct url, e.g. `https://internal.yourcompany.com/oauth2/callback` The GitHub auth provider supports two additional parameters to restrict authentication to Organization or Team level access.
Restricting by org and team is normally accompanied with `--email-domain=*` -github-org="": restrict logins to members of this organisation -github-team="": restrict logins to members of any of these teams (slug), separated by a comma If you are using GitHub Enterprise, make sure you set the following to the appropriate URL: -login-url="http(s)://<enterprise.domain>/login/oauth/authorize" -redeem-url="http(s)://<enterprise.domain>/login/oauth/access_token" -validate-url="http(s)://<enterprise.domain>/api/v3" ### GitLab Auth Provider Whether you are using GitLab.com or self-hosting GitLab, follow [these steps to add an application](http://doc.gitlab.com/ce/integration/oauth_provider.html). If you are using self-hosted GitLab, make sure you set the following to the appropriate URL: -login-url="<your gitlab url>/oauth/authorize" -redeem-url="<your gitlab url>/oauth/token" -validate-url="<your gitlab url>/api/v4/user" ### LinkedIn Auth Provider For LinkedIn, the registration steps are: 1. Create a new project: https://www.linkedin.com/secure/developer 2. In the OAuth User Agreement section: * In default scope, select r_basicprofile and r_emailaddress. * In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback` 3. Fill in the remaining required fields and Save. 4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key** ### Microsoft Azure AD Provider For adding an application to the Microsoft Azure AD follow [these steps to add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/). Take note of your `TenantId` if applicable for your situation. The `TenantId` can be used to override the default `common` authorization server with a tenant-specific server. ### OpenID Connect Provider OpenID Connect is a spec for OAuth 2.0 + identity that is implemented by many major providers and several open source projects. This provider was originally built against CoreOS Dex and we will use it as an example. 1.
Launch a Dex instance using the [getting started guide](https://github.com/coreos/dex/blob/master/Documentation/getting-started.md). 2. Set up oauth2_proxy with the correct provider and using the default ports and callbacks. 3. Log in with the fixture used in the Dex guide and run oauth2_proxy with the following args: -provider oidc -client-id oauth2_proxy -client-secret proxy -redirect-url http://127.0.0.1:4180/oauth2/callback -oidc-issuer-url http://127.0.0.1:5556 -cookie-secure=false -email-domain example.com ## Email Authentication To authorize by email domain use `--email-domain=yourcompany.com`. To authorize individual email addresses use `--authenticated-emails-file=/path/to/file` with one email per line. To authorize all email addresses use `--email-domain=*`. ## Configuration `oauth2_proxy` can be configured via [config file](#config-file), [command line options](#command-line-options) or [environment variables](#environment-variables). To generate a strong cookie secret use `python -c 'import os,base64; print(base64.urlsafe_b64encode(os.urandom(16)).decode())'` ### Config File An example [oauth2_proxy.cfg](contrib/oauth2_proxy.cfg.example) config file is in the contrib directory. It can be used by specifying `-config=/etc/oauth2_proxy.cfg` ### Command Line Options ``` Usage of oauth2_proxy: -approval-prompt string: OAuth approval_prompt (default "force") -authenticated-emails-file string: authenticate against emails via file (one per line) -azure-tenant string: go to a tenant-specific or common (tenant-independent) endpoint.
(default "common") -basic-auth-password string: the password to set when passing the HTTP Basic Auth header -client-id string: the OAuth Client ID: ie: "123456.apps.googleusercontent.com" -client-secret string: the OAuth Client Secret -config string: path to config file -cookie-domain string: an optional cookie domain to force cookies to (ie: .yourcompany.com) -cookie-expire duration: expire timeframe for cookie (default 168h0m0s) -cookie-httponly: set HttpOnly cookie flag (default true) -cookie-name string: the name of the cookie that the oauth_proxy creates (default "_oauth2_proxy") -cookie-refresh duration: refresh the cookie after this duration; 0 to disable -cookie-secret string: the seed string for secure cookies (optionally base64 encoded) -cookie-secure: set secure (HTTPS) cookie flag (default true) -custom-templates-dir string: path to custom html templates -display-htpasswd-form: display username / password login form if an htpasswd file is provided (default true) -email-domain value: authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email -footer string: custom footer string. Use "-" to disable default footer. -github-org string: restrict logins to members of this organisation -github-team string: restrict logins to members of any of these teams (slug), separated by a comma -google-admin-email string: the google admin to impersonate for api calls -google-group value: restrict logins to members of this google group (may be given multiple times). -google-service-account-json string: the path to the service account json credentials -htpasswd-file string: additionally authenticate against a htpasswd file. 
Entries must be created with "htpasswd -s" for SHA encryption -http-address string: [http://]<addr>:<port> or unix://<path> to listen on for HTTP clients (default "127.0.0.1:4180") -https-address string: <addr>:<port> to listen on for HTTPS clients (default ":443") -login-url string: Authentication endpoint -pass-access-token: pass OAuth access_token to upstream via X-Forwarded-Access-Token header -pass-basic-auth: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream (default true) -pass-host-header: pass the request Host Header to upstream (default true) -pass-user-headers: pass X-Forwarded-User and X-Forwarded-Email information to upstream (default true) -profile-url string: Profile access endpoint -provider string: OAuth provider (default "google") -proxy-prefix string: the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in) (default "/oauth2") -redeem-url string: Token redemption endpoint -redirect-url string: the OAuth Redirect URL. ie: "https://internalapp.yourcompany.com/oauth2/callback" -request-logging: Log requests to stdout (default true) -request-logging-format: Template for request log lines (see "Logging Format" paragraph below) -resource string: The resource that is protected (Azure AD only) -scope string: OAuth scope specification -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) -signature-key string: GAP-Signature request signature key (algorithm:secretkey) -skip-auth-preflight: will skip authentication for OPTIONS requests -skip-auth-regex value: bypass authentication for requests path's that match (may be given multiple times) -skip-provider-button: will skip sign-in-page to directly reach the next step: oauth/start -ssl-insecure-skip-verify: skip validation of certificates presented when using HTTPS -tls-cert string: path to certificate file -tls-key string: path to private key file -upstream value: the http url(s) of the upstream endpoint or file://
paths for static files. Routing is based on the path -validate-url string: Access token validation endpoint -version: print version string ``` See below for provider-specific options ### Upstreams Configuration `oauth2_proxy` supports having multiple upstreams, and has the option to pass requests on to HTTP(S) servers or serve static files from the file system. HTTP and HTTPS upstreams are configured by providing a URL such as `http://127.0.0.1:8080/` for the upstream parameter, which will forward all authenticated requests to the upstream server. If you instead provide `http://127.0.0.1:8080/some/path/` then only requests that start with `/some/path/` are forwarded to the upstream. Static file paths are configured as a file:// URL. `file:///var/www/static/` will serve the files from that directory at `http://[oauth2_proxy url]/var/www/static/`, which may not be what you want. You can provide the path to where the files should be available by adding a fragment to the configured URL. The value of the fragment will then be used to specify which path the files are available at. For example, `file:///var/www/static/#/static/` will make `/var/www/static/` available at `http://[oauth2_proxy url]/static/`. Multiple upstreams can either be configured by supplying a comma-separated list to the `-upstream` parameter, supplying the parameter multiple times or providing a list in the [config file](#config-file). When multiple upstreams are used, routing to them will be based on the path they are set up with. ### Environment variables The following environment variables can be used in place of the corresponding command-line arguments: - `OAUTH2_PROXY_CLIENT_ID` - `OAUTH2_PROXY_CLIENT_SECRET` - `OAUTH2_PROXY_COOKIE_NAME` - `OAUTH2_PROXY_COOKIE_SECRET` - `OAUTH2_PROXY_COOKIE_DOMAIN` - `OAUTH2_PROXY_COOKIE_EXPIRE` - `OAUTH2_PROXY_COOKIE_REFRESH` - `OAUTH2_PROXY_SIGNATURE_KEY` ## SSL Configuration There are two recommended configurations.
1) Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`. The command line to run `oauth2_proxy` in this configuration would look like this: ```bash ./oauth2_proxy \ --email-domain="yourcompany.com" \ --upstream=http://127.0.0.1:8080/ \ --tls-cert=/path/to/cert.pem \ --tls-key=/path/to/cert.key \ --cookie-secret=... \ --cookie-secure=true \ --provider=... \ --client-id=... \ --client-secret=... ``` 2) Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or .... Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or `--http-address="http://:4180"`. Nginx will listen on port `443` and handle SSL connections while proxying to `oauth2_proxy` on port `4180`. `oauth2_proxy` will then authenticate requests for an upstream application. The external endpoint for this example would be `https://internal.yourcompany.com/`. An example Nginx config follows. Note the use of `Strict-Transport-Security` header to pin requests to SSL via [HSTS](http://en.wikipedia.org/wiki/HTTP_Strict_Transport_Security): ``` server { listen 443 default ssl; server_name internal.yourcompany.com; ssl_certificate /path/to/cert.pem; ssl_certificate_key /path/to/cert.key; add_header Strict-Transport-Security max-age=2592000; location / { proxy_pass http://127.0.0.1:4180; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Scheme $scheme; proxy_connect_timeout 1; proxy_send_timeout 30; proxy_read_timeout 30; } } ``` The command line to run `oauth2_proxy` in this configuration would look like this: ```bash ./oauth2_proxy \ --email-domain="yourcompany.com" \ --upstream=http://127.0.0.1:8080/ \ --cookie-secret=... \ --cookie-secure=true \ --provider=... 
\ --client-id=... \ --client-secret=... ``` ## Endpoint Documentation OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable. * /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info * /ping - returns a 200 OK response * /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) * /oauth2/start - a URL that will redirect to start the OAuth cycle * /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url. * /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#configuring-for-use-with-the-nginx-auth_request-directive) ## Request signatures If `signature_key` is defined, proxied requests will be signed with the `GAP-Signature` header, which is a [Hash-based Message Authentication Code (HMAC)](https://en.wikipedia.org/wiki/Hash-based_message_authentication_code) of selected request information and the request body [see `SIGNATURE_HEADERS` in `oauthproxy.go`](./oauthproxy.go). `signature_key` must be of the form `algorithm:secretkey` (e.g. `signature_key = "sha1:secret0"`). For more information about HMAC request signature validation, read the following: * [Amazon Web Services: Signing and Authenticating REST Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) * [rc3.org: Using HMAC to authenticate Web service requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/) ## Logging Format By default, OAuth2 Proxy logs requests to stdout in a format similar to Apache Combined Log.
``` <REMOTE_ADDRESS> - <user@domain.com> [19/Mar/2015:17:20:19 -0400] <HOST_HEADER> GET <UPSTREAM_HOST> "/path/" HTTP/1.1 "<USER_AGENT>" <RESPONSE_CODE> <RESPONSE_BYTES> <REQUEST_DURATION> ``` If you require a different format than that, you can configure it with the `-request-logging-format` flag. The default format is configured as follows: ``` {{.Client}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}} ``` [See `logMessageData` in `logging_handler.go`](./logging_handler.go) for all available variables. ## Adding a new Provider Follow the examples in the [`providers` package](providers/) to define a new `Provider` instance. Add a new `case` to [`providers.New()`](providers/providers.go) to allow `oauth2_proxy` to use the new `Provider`. ## Configuring for use with the Nginx `auth_request` directive The [Nginx `auth_request` directive](http://nginx.org/en/docs/http/ngx_http_auth_request_module.html) allows Nginx to authenticate requests via oauth2_proxy's `/oauth2/auth` endpoint, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the request through.
For example: ```nginx server { listen 443 ssl; server_name ...; include ssl/ssl.conf; location /oauth2/ { proxy_pass http://127.0.0.1:4180; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Scheme $scheme; proxy_set_header X-Auth-Request-Redirect $request_uri; } location = /oauth2/auth { proxy_pass http://127.0.0.1:4180; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Scheme $scheme; # nginx auth_request includes headers but not body proxy_set_header Content-Length ""; proxy_pass_request_body off; } location / { auth_request /oauth2/auth; error_page 401 = /oauth2/sign_in; # pass information via X-User and X-Email headers to backend, # requires running with --set-xauthrequest flag auth_request_set $user $upstream_http_x_auth_request_user; auth_request_set $email $upstream_http_x_auth_request_email; proxy_set_header X-User $user; proxy_set_header X-Email $email; # if you enabled --cookie-refresh, this is needed for it to work with auth_request auth_request_set $auth_cookie $upstream_http_set_cookie; add_header Set-Cookie $auth_cookie; proxy_pass http://backend/; # or "root /path/to/site;" or "fastcgi_pass ..." 
etc } } ``` ================================================ FILE: api/api.go ================================================ package api import ( "encoding/json" "fmt" "io/ioutil" "log" "net/http" "github.com/bitly/go-simplejson" ) func Request(req *http.Request) (*simplejson.Json, error) { resp, err := http.DefaultClient.Do(req) if err != nil { log.Printf("%s %s %s", req.Method, req.URL, err) return nil, err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) if err != nil { return nil, err } if resp.StatusCode != 200 { return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) } data, err := simplejson.NewJson(body) if err != nil { return nil, err } return data, nil } func RequestJson(req *http.Request, v interface{}) error { resp, err := http.DefaultClient.Do(req) if err != nil { log.Printf("%s %s %s", req.Method, req.URL, err) return err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) if err != nil { return err } if resp.StatusCode != 200 { return fmt.Errorf("got %d %s", resp.StatusCode, body) } return json.Unmarshal(body, v) } func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } req.Header = header return http.DefaultClient.Do(req) } ================================================ FILE: api/api_test.go ================================================ package api import ( "github.com/bitly/go-simplejson" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "github.com/stretchr/testify/assert" ) func testBackend(response_code int, payload string) *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(response_code) w.Write([]byte(payload)) })) } func TestRequest(t *testing.T) { backend := testBackend(200, 
"{\"foo\": \"bar\"}") defer backend.Close() req, _ := http.NewRequest("GET", backend.URL, nil) response, err := Request(req) assert.Equal(t, nil, err) result, err := response.Get("foo").String() assert.Equal(t, nil, err) assert.Equal(t, "bar", result) } func TestRequestFailure(t *testing.T) { // Create a backend to generate a test URL, then close it to cause a // connection error. backend := testBackend(200, "{\"foo\": \"bar\"}") backend.Close() req, err := http.NewRequest("GET", backend.URL, nil) assert.Equal(t, nil, err) resp, err := Request(req) assert.Equal(t, (*simplejson.Json)(nil), resp) assert.NotEqual(t, nil, err) if !strings.Contains(err.Error(), "refused") { t.Error("expected error when a connection fails: ", err) } } func TestHttpErrorCode(t *testing.T) { backend := testBackend(404, "{\"foo\": \"bar\"}") defer backend.Close() req, err := http.NewRequest("GET", backend.URL, nil) assert.Equal(t, nil, err) resp, err := Request(req) assert.Equal(t, (*simplejson.Json)(nil), resp) assert.NotEqual(t, nil, err) } func TestJsonParsingError(t *testing.T) { backend := testBackend(200, "not well-formed JSON") defer backend.Close() req, err := http.NewRequest("GET", backend.URL, nil) assert.Equal(t, nil, err) resp, err := Request(req) assert.Equal(t, (*simplejson.Json)(nil), resp) assert.NotEqual(t, nil, err) } // Parsing a URL practically never fails, so we won't cover that test case. 
func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { token := r.FormValue("access_token") if r.URL.Path == "/" && token == "my_token" { w.WriteHeader(200) w.Write([]byte("some payload")) } else { w.WriteHeader(403) } })) defer backend.Close() response, err := RequestUnparsedResponse( backend.URL+"?access_token=my_token", nil) assert.Equal(t, nil, err) assert.Equal(t, 200, response.StatusCode) body, err := ioutil.ReadAll(response.Body) assert.Equal(t, nil, err) response.Body.Close() assert.Equal(t, "some payload", string(body)) } func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { backend := testBackend(200, "some payload") // Close the backend now to force a request failure. backend.Close() response, err := RequestUnparsedResponse( backend.URL+"?access_token=my_token", nil) assert.NotEqual(t, nil, err) assert.Equal(t, (*http.Response)(nil), response) } func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { w.WriteHeader(200) w.Write([]byte("some payload")) } else { w.WriteHeader(403) } })) defer backend.Close() headers := make(http.Header) headers.Set("Auth", "my_token") response, err := RequestUnparsedResponse(backend.URL, headers) assert.Equal(t, nil, err) assert.Equal(t, 200, response.StatusCode) body, err := ioutil.ReadAll(response.Body) assert.Equal(t, nil, err) response.Body.Close() assert.Equal(t, "some payload", string(body)) } ================================================ FILE: contrib/oauth2_proxy.cfg.example ================================================ ## OAuth2 Proxy Config File ## https://github.com/bitly/oauth2_proxy ## : to listen on for HTTP/HTTPS clients # http_address = "127.0.0.1:4180" # https_address = ":443" ## TLS Settings # 
tls_cert_file = "" # tls_key_file = "" ## the OAuth Redirect URL. # defaults to the "https://" + requested host header + "/oauth2/callback" # redirect_url = "https://internalapp.yourcompany.com/oauth2/callback" ## the http url(s) of the upstream endpoint. If multiple, routing is based on path # upstreams = [ # "http://127.0.0.1:8080/" # ] ## Log requests to stdout # request_logging = true ## pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream # pass_basic_auth = true # pass_user_headers = true ## pass the request Host Header to upstream ## when disabled the upstream Host is used as the Host Header # pass_host_header = true ## Email Domains to allow authentication for (this authorizes any email on this domain) ## for more granular authorization use `authenticated_emails_file` ## To authorize any email addresses use "*" # email_domains = [ # "yourcompany.com" # ] ## The OAuth Client ID, Secret # client_id = "123456.apps.googleusercontent.com" # client_secret = "" ## Pass OAuth Access token to upstream via "X-Forwarded-Access-Token" # pass_access_token = false ## Authenticated Email Addresses File (one email per line) # authenticated_emails_file = "" ## Htpasswd File (optional) ## Additionally authenticate against a htpasswd file. 
Entries must be created with "htpasswd -s" for SHA encryption ## enabling exposes a username/login signin form # htpasswd_file = "" ## Templates ## optional directory with custom sign_in.html and error.html # custom_templates_dir = "" ## skip SSL checking for HTTPS requests # ssl_insecure_skip_verify = false ## Cookie Settings ## Name - the cookie name ## Secret - the seed string for secure cookies; should be 16, 24, or 32 bytes ## for use with an AES cipher when cookie_refresh or pass_access_token ## is set ## Domain - (optional) cookie domain to force cookies to (ie: .yourcompany.com) ## Expire - (duration) expire timeframe for cookie ## Refresh - (duration) refresh the cookie when duration has elapsed after cookie was initially set. ## Should be less than cookie_expire; set to 0 to disable. ## On refresh, OAuth token is re-validated. ## (ie: 1h means tokens are refreshed on request 1hr+ after it was set) ## Secure - secure cookies are only sent by the browser of a HTTPS connection (recommended) ## HttpOnly - httponly cookies are not readable by javascript (recommended) # cookie_name = "_oauth2_proxy" # cookie_secret = "" # cookie_domain = "" # cookie_expire = "168h" # cookie_refresh = "" # cookie_secure = true # cookie_httponly = true ================================================ FILE: contrib/oauth2_proxy.service.example ================================================ # Systemd service file for oauth2_proxy daemon # # Date: Feb 9, 2016 # Author: Srdjan Grubor [Unit] Description=oauth2_proxy daemon service After=syslog.target network.target [Service] # www-data group and user need to be created before using these lines User=www-data Group=www-data ExecStart=/usr/local/bin/oauth2_proxy -config=/etc/oauth2_proxy.cfg ExecReload=/bin/kill -HUP $MAINPID KillMode=process Restart=always [Install] WantedBy=multi-user.target ================================================ FILE: cookie/cookies.go ================================================ package cookie import 
( "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/rand" "crypto/sha1" "encoding/base64" "fmt" "io" "net/http" "strconv" "strings" "time" ) // cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set. // additionally, the 'value' is encrypted so it's opaque to the browser // Validate ensures a cookie is properly signed func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { // value, timestamp, sig parts := strings.Split(cookie.Value, "|") if len(parts) != 3 { return } sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) if checkHmac(parts[2], sig) { ts, err := strconv.Atoi(parts[1]) if err != nil { return } // The expiration timestamp set when the cookie was created // isn't sent back by the browser. Hence, we check whether the // creation timestamp stored in the cookie falls within the // window defined by (Now()-expiration, Now()]. t = time.Unix(int64(ts), 0) if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { // it's a valid cookie. 
now get the contents rawValue, err := base64.URLEncoding.DecodeString(parts[0]) if err == nil { value = string(rawValue) ok = true return } } } return } // SignedValue returns a cookie that is signed and can later be checked with Validate func SignedValue(seed string, key string, value string, now time.Time) string { encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) timeStr := fmt.Sprintf("%d", now.Unix()) sig := cookieSignature(seed, key, encodedValue, timeStr) cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) return cookieVal } func cookieSignature(args ...string) string { h := hmac.New(sha1.New, []byte(args[0])) for _, arg := range args[1:] { h.Write([]byte(arg)) } var b []byte b = h.Sum(b) return base64.URLEncoding.EncodeToString(b) } func checkHmac(input, expected string) bool { inputMAC, err1 := base64.URLEncoding.DecodeString(input) if err1 == nil { expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) if err2 == nil { return hmac.Equal(inputMAC, expectedMAC) } } return false } // Cipher provides methods to encrypt and decrypt cookie values type Cipher struct { cipher.Block } // NewCipher returns a new aes Cipher for encrypting cookie values func NewCipher(secret []byte) (*Cipher, error) { c, err := aes.NewCipher(secret) if err != nil { return nil, err } return &Cipher{Block: c}, err } // Encrypt a value for use in a cookie func (c *Cipher) Encrypt(value string) (string, error) { ciphertext := make([]byte, aes.BlockSize+len(value)) iv := ciphertext[:aes.BlockSize] if _, err := io.ReadFull(rand.Reader, iv); err != nil { return "", fmt.Errorf("failed to create initialization vector %s", err) } stream := cipher.NewCFBEncrypter(c.Block, iv) stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value)) return base64.StdEncoding.EncodeToString(ciphertext), nil } // Decrypt a value from a cookie to it's original string func (c *Cipher) Decrypt(s string) (string, error) { encrypted, err := base64.StdEncoding.DecodeString(s) if 
err != nil { return "", fmt.Errorf("failed to decrypt cookie value %s", err) } if len(encrypted) < aes.BlockSize { return "", fmt.Errorf("encrypted cookie value should be "+ "at least %d bytes, but is only %d bytes", aes.BlockSize, len(encrypted)) } iv := encrypted[:aes.BlockSize] encrypted = encrypted[aes.BlockSize:] stream := cipher.NewCFBDecrypter(c.Block, iv) stream.XORKeyStream(encrypted, encrypted) return string(encrypted), nil } ================================================ FILE: cookie/cookies_test.go ================================================ package cookie import ( "encoding/base64" "testing" "github.com/stretchr/testify/assert" ) func TestEncodeAndDecodeAccessToken(t *testing.T) { const secret = "0123456789abcdefghijklmnopqrstuv" const token = "my access token" c, err := NewCipher([]byte(secret)) assert.Equal(t, nil, err) encoded, err := c.Encrypt(token) assert.Equal(t, nil, err) decoded, err := c.Decrypt(encoded) assert.Equal(t, nil, err) assert.NotEqual(t, token, encoded) assert.Equal(t, token, decoded) } func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" const token = "my access token" secret, err := base64.URLEncoding.DecodeString(secret_b64) c, err := NewCipher([]byte(secret)) assert.Equal(t, nil, err) encoded, err := c.Encrypt(token) assert.Equal(t, nil, err) decoded, err := c.Decrypt(encoded) assert.Equal(t, nil, err) assert.NotEqual(t, token, encoded) assert.Equal(t, token, decoded) } ================================================ FILE: cookie/nonce.go ================================================ package cookie import ( "crypto/rand" "fmt" ) func Nonce() (nonce string, err error) { b := make([]byte, 16) _, err = rand.Read(b) if err != nil { return } nonce = fmt.Sprintf("%x", b) return } ================================================ FILE: dist.sh ================================================ #!/bin/bash # build binary distributions for linux/amd64 and 
darwin/amd64 set -e DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" echo "working dir $DIR" mkdir -p $DIR/dist dep ensure || exit 1 os=$(go env GOOS) arch=$(go env GOARCH) version=$(cat $DIR/version.go | grep "const VERSION" | awk '{print $NF}' | sed 's/"//g') goversion=$(go version | awk '{print $3}') sha256sum=() echo "... running tests" ./test.sh for os in windows linux darwin; do echo "... building v$version for $os/$arch" EXT= if [ $os = windows ]; then EXT=".exe" fi BUILD=$(mktemp -d ${TMPDIR:-/tmp}/oauth2_proxy.XXXXXX) TARGET="oauth2_proxy-$version.$os-$arch.$goversion" FILENAME="oauth2_proxy-$version.$os-$arch$EXT" GOOS=$os GOARCH=$arch CGO_ENABLED=0 \ go build -ldflags="-s -w" -o $BUILD/$TARGET/$FILENAME || exit 1 pushd $BUILD/$TARGET sha256sum+=("$(shasum -a 256 $FILENAME || exit 1)") cd .. && tar czvf $TARGET.tar.gz $TARGET mv $TARGET.tar.gz $DIR/dist popd done checksum_file="sha256sum.txt" cd $DIR/dist if [ -f $checksum_file ]; then rm $checksum_file fi touch $checksum_file for checksum in "${sha256sum[@]}"; do echo "$checksum" >> $checksum_file done ================================================ FILE: env_options.go ================================================ package main import ( "os" "reflect" "strings" ) type EnvOptions map[string]interface{} func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { val := reflect.ValueOf(options).Elem() typ := val.Type() for i := 0; i < typ.NumField(); i++ { // pull out the struct tags: // flag - the name of the command line flag // deprecated - (optional) the name of the deprecated command line flag // cfg - (optional, defaults to underscored flag) the name of the config file option field := typ.Field(i) flagName := field.Tag.Get("flag") envName := field.Tag.Get("env") cfgName := field.Tag.Get("cfg") if cfgName == "" && flagName != "" { cfgName = strings.Replace(flagName, "-", "_", -1) } if envName == "" || cfgName == "" { // resolvable fields must have the `env` and `cfg` struct tag continue } v := 
os.Getenv(envName) if v != "" { cfg[cfgName] = v } } } ================================================ FILE: env_options_test.go ================================================ package main import ( "os" "testing" "github.com/stretchr/testify/assert" ) type envTest struct { testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` } func TestLoadEnvForStruct(t *testing.T) { cfg := make(EnvOptions) cfg.LoadEnvForStruct(&envTest{}) _, ok := cfg["target_field"] assert.Equal(t, ok, false) os.Setenv("TEST_ENV_FIELD", "1234abcd") cfg.LoadEnvForStruct(&envTest{}) v := cfg["target_field"] assert.Equal(t, v, "1234abcd") } ================================================ FILE: htpasswd.go ================================================ package main import ( "crypto/sha1" "encoding/base64" "encoding/csv" "io" "log" "os" "golang.org/x/crypto/bcrypt" ) // Lookup passwords in a htpasswd file // Passwords must be generated with -B for bcrypt or -s for SHA1. type HtpasswdFile struct { Users map[string]string } func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { r, err := os.Open(path) if err != nil { return nil, err } defer r.Close() return NewHtpasswd(r) } func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { csv_reader := csv.NewReader(file) csv_reader.Comma = ':' csv_reader.Comment = '#' csv_reader.TrimLeadingSpace = true records, err := csv_reader.ReadAll() if err != nil { return nil, err } h := &HtpasswdFile{Users: make(map[string]string)} for _, record := range records { h.Users[record[0]] = record[1] } return h, nil } func (h *HtpasswdFile) Validate(user string, password string) bool { realPassword, exists := h.Users[user] if !exists { return false } shaPrefix := realPassword[:5] if shaPrefix == "{SHA}" { shaValue := realPassword[5:] d := sha1.New() d.Write([]byte(password)) return shaValue == base64.StdEncoding.EncodeToString(d.Sum(nil)) } bcryptPrefix := realPassword[:4] if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" || bcryptPrefix == "$2x$" || 
bcryptPrefix == "$2y$" { return bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password)) == nil } log.Printf("Invalid htpasswd entry for %s. Must be a SHA or bcrypt entry.", user) return false } ================================================ FILE: htpasswd_test.go ================================================ package main import ( "bytes" "fmt" "testing" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" ) func TestSHA(t *testing.T) { file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n")) h, err := NewHtpasswd(file) assert.Equal(t, err, nil) valid := h.Validate("testuser", "asdf") assert.Equal(t, valid, true) } func TestBcrypt(t *testing.T) { hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) assert.Equal(t, err, nil) contents := fmt.Sprintf("testuser1:%s\ntestuser2:%s\n", hash1, hash2) file := bytes.NewBuffer([]byte(contents)) h, err := NewHtpasswd(file) assert.Equal(t, err, nil) valid := h.Validate("testuser1", "password") assert.Equal(t, valid, true) valid = h.Validate("testuser2", "top-secret") assert.Equal(t, valid, true) } ================================================ FILE: http.go ================================================ package main import ( "crypto/tls" "log" "net" "net/http" "strings" "time" ) type Server struct { Handler http.Handler Opts *Options } func (s *Server) ListenAndServe() { if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { s.ServeHTTPS() } else { s.ServeHTTP() } } func (s *Server) ServeHTTP() { httpAddress := s.Opts.HttpAddress scheme := "" i := strings.Index(httpAddress, "://") if i > -1 { scheme = httpAddress[0:i] } var networkType string switch scheme { case "", "http": networkType = "tcp" default: networkType = scheme } slice := strings.SplitN(httpAddress, "//", 2) listenAddr := slice[len(slice)-1] listener, err := net.Listen(networkType, listenAddr) if err != nil { log.Fatalf("FATAL: 
listen (%s, %s) failed - %s", networkType, listenAddr, err) } log.Printf("HTTP: listening on %s", listenAddr) server := &http.Server{Handler: s.Handler} err = server.Serve(listener) if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { log.Printf("ERROR: http.Serve() - %s", err) } log.Printf("HTTP: closing %s", listener.Addr()) } func (s *Server) ServeHTTPS() { addr := s.Opts.HttpsAddress config := &tls.Config{ MinVersion: tls.VersionTLS12, MaxVersion: tls.VersionTLS12, } if config.NextProtos == nil { config.NextProtos = []string{"http/1.1"} } var err error config.Certificates = make([]tls.Certificate, 1) config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile) if err != nil { log.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err) } ln, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("FATAL: listen (%s) failed - %s", addr, err) } log.Printf("HTTPS: listening on %s", ln.Addr()) tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) srv := &http.Server{Handler: s.Handler} err = srv.Serve(tlsListener) if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { log.Printf("ERROR: https.Serve() - %s", err) } log.Printf("HTTPS: closing %s", tlsListener.Addr()) } // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by ListenAndServe and ListenAndServeTLS so // dead TCP connections (e.g. closing laptop mid-download) eventually // go away. 
type tcpKeepAliveListener struct { *net.TCPListener } func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { tc, err := ln.AcceptTCP() if err != nil { return } tc.SetKeepAlive(true) tc.SetKeepAlivePeriod(3 * time.Minute) return tc, nil } ================================================ FILE: logging_handler.go ================================================ // largely adapted from https://github.com/gorilla/handlers/blob/master/handlers.go // to add logging of request duration as last value (and drop referrer) package main import ( "fmt" "io" "net" "net/http" "net/url" "text/template" "time" ) const ( defaultRequestLoggingFormat = "{{.Client}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" ) // responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status // code and body size type responseLogger struct { w http.ResponseWriter status int size int upstream string authInfo string } func (l *responseLogger) Header() http.Header { return l.w.Header() } func (l *responseLogger) ExtractGAPMetadata() { upstream := l.w.Header().Get("GAP-Upstream-Address") if upstream != "" { l.upstream = upstream l.w.Header().Del("GAP-Upstream-Address") } authInfo := l.w.Header().Get("GAP-Auth") if authInfo != "" { l.authInfo = authInfo l.w.Header().Del("GAP-Auth") } } func (l *responseLogger) Write(b []byte) (int, error) { if l.status == 0 { // The status will be StatusOK if WriteHeader has not been called yet l.status = http.StatusOK } l.ExtractGAPMetadata() size, err := l.w.Write(b) l.size += size return size, err } func (l *responseLogger) WriteHeader(s int) { l.ExtractGAPMetadata() l.w.WriteHeader(s) l.status = s } func (l *responseLogger) Status() int { return l.status } func (l *responseLogger) Size() int { return l.size } // logMessageData is the container for all values that are available as variables in the request 
logging format. // All values are pre-formatted strings so it is easy to use them in the format string. type logMessageData struct { Client, Host, Protocol, RequestDuration, RequestMethod, RequestURI, ResponseSize, StatusCode, Timestamp, Upstream, UserAgent, Username string } // loggingHandler is the http.Handler implementation for LoggingHandlerTo and its friends type loggingHandler struct { writer io.Writer handler http.Handler enabled bool logTemplate *template.Template } func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler { return loggingHandler{ writer: out, handler: h, enabled: v, logTemplate: template.Must(template.New("request-log").Parse(requestLoggingTpl)), } } func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { t := time.Now() url := *req.URL logger := &responseLogger{w: w} h.handler.ServeHTTP(logger, req) if !h.enabled { return } h.writeLogLine(logger.authInfo, logger.upstream, req, url, t, logger.Status(), logger.Size()) } // Log entry for req similar to Apache Common Log Format. // ts is the timestamp with which the entry should be logged. // status, size are used to provide the response HTTP status and size. 
func (h loggingHandler) writeLogLine(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) { if username == "" { username = "-" } if upstream == "" { upstream = "-" } if url.User != nil && username == "-" { if name := url.User.Username(); name != "" { username = name } } client := req.Header.Get("X-Real-IP") if client == "" { client = req.RemoteAddr } if c, _, err := net.SplitHostPort(client); err == nil { client = c } duration := float64(time.Now().Sub(ts)) / float64(time.Second) h.logTemplate.Execute(h.writer, logMessageData{ Client: client, Host: req.Host, Protocol: req.Proto, RequestDuration: fmt.Sprintf("%0.3f", duration), RequestMethod: req.Method, RequestURI: fmt.Sprintf("%q", url.RequestURI()), ResponseSize: fmt.Sprintf("%d", size), StatusCode: fmt.Sprintf("%d", status), Timestamp: ts.Format("02/Jan/2006:15:04:05 -0700"), Upstream: upstream, UserAgent: fmt.Sprintf("%q", req.UserAgent()), Username: username, }) h.writer.Write([]byte("\n")) } ================================================ FILE: logging_handler_test.go ================================================ package main import ( "bytes" "fmt" "net/http" "net/http/httptest" "testing" "time" ) func TestLoggingHandler_ServeHTTP(t *testing.T) { ts := time.Now() tests := []struct { Format, ExpectedLogMessage string }{ {defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", ts.Format("02/Jan/2006:15:04:05 -0700"))}, {"{{.RequestMethod}}", "GET\n"}, } for _, test := range tests { buf := bytes.NewBuffer(nil) handler := func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("test")) } h := LoggingHandler(buf, http.HandlerFunc(handler), true, test.Format) r, _ := http.NewRequest("GET", "/foo/bar", nil) r.RemoteAddr = "127.0.0.1" r.Host = "test-server" h.ServeHTTP(httptest.NewRecorder(), r) actual := buf.String() if actual != test.ExpectedLogMessage { t.Errorf("Log message was\n%s\ninstead of 
expected \n%s", actual, test.ExpectedLogMessage) } } } ================================================ FILE: main.go ================================================ package main import ( "flag" "fmt" "log" "os" "runtime" "strings" "time" "github.com/BurntSushi/toml" "github.com/mreiferson/go-options" ) func main() { log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError) emailDomains := StringArray{} upstreams := StringArray{} skipAuthRegex := StringArray{} googleGroups := StringArray{} config := flagSet.String("config", "", "path to config file") showVersion := flagSet.Bool("version", false, "print version string") flagSet.String("http-address", "127.0.0.1:4180", "[http://]: or unix:// to listen on for HTTP clients") flagSet.String("https-address", ":443", ": to listen on for HTTPS clients") flagSet.String("tls-cert", "", "path to certificate file") flagSet.String("tls-key", "", "path to private key file") flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)") flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. 
Routing is based on the path") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS") flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). 
Use * to authenticate any email") flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") flagSet.String("github-org", "", "restrict logins to members of this organisation") flagSet.String("github-team", "", "restrict logins to members of this team") flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).") flagSet.String("google-admin-email", "", "the google admin to impersonate for api calls") flagSet.String("google-service-account-json", "", "the path to the service account json credentials") flagSet.String("client-id", "", "the OAuth Client ID: ie: \"123456.apps.googleusercontent.com\"") flagSet.String("client-secret", "", "the OAuth Client Secret") flagSet.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") flagSet.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -s\" for SHA encryption or \"htpasswd -B\" for bcrypt encryption") flagSet.Bool("display-htpasswd-form", true, "display username / password login form if an htpasswd file is provided") flagSet.String("custom-templates-dir", "", "path to custom html templates") flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. 
//sign_in)") flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)") flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") flagSet.Bool("request-logging", true, "Log requests to stdout") flagSet.String("request-logging-format", defaultRequestLoggingFormat, "Template for log lines") flagSet.String("provider", "google", "OAuth provider") flagSet.String("oidc-issuer-url", "", "OpenID Connect issuer URL (ie: https://accounts.google.com)") flagSet.String("login-url", "", "Authentication endpoint") flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("profile-url", "", "Profile access endpoint") flagSet.String("resource", "", "The resource that is protected (Azure AD only)") flagSet.String("validate-url", "", "Access token validation endpoint") flagSet.String("scope", "", "OAuth scope specification") flagSet.String("approval-prompt", "force", "OAuth approval_prompt") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Parse(os.Args[1:]) if *showVersion { fmt.Printf("oauth2_proxy v%s (built with %s)\n", VERSION, runtime.Version()) return } opts := NewOptions() cfg := make(EnvOptions) if *config != "" { _, err := toml.DecodeFile(*config, &cfg) if err != nil { log.Fatalf("ERROR: failed to load config file %s - %s", *config, err) } } cfg.LoadEnvForStruct(opts) options.Resolve(opts, flagSet, cfg) err := opts.Validate() if err != nil { log.Printf("%s", err) os.Exit(1) 
} validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) oauthproxy := NewOAuthProxy(opts, validator) if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { if len(opts.EmailDomains) > 1 { oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) } else if opts.EmailDomains[0] != "*" { oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) } } if opts.HtpasswdFile != "" { log.Printf("using htpasswd file %s", opts.HtpasswdFile) oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(opts.HtpasswdFile) oauthproxy.DisplayHtpasswdForm = opts.DisplayHtpasswdForm if err != nil { log.Fatalf("FATAL: unable to open %s %s", opts.HtpasswdFile, err) } } s := &Server{ Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat), Opts: opts, } s.ListenAndServe() } ================================================ FILE: oauthproxy.go ================================================ package main import ( b64 "encoding/base64" "errors" "fmt" "html/template" "log" "net" "net/http" "net/http/httputil" "net/url" "regexp" "strings" "time" "github.com/bitly/oauth2_proxy/cookie" "github.com/bitly/oauth2_proxy/providers" "github.com/mbland/hmacauth" ) const SignatureHeader = "GAP-Signature" var SignatureHeaders []string = []string{ "Content-Length", "Content-Md5", "Content-Type", "Date", "Authorization", "X-Forwarded-User", "X-Forwarded-Email", "X-Forwarded-Access-Token", "Cookie", "Gap-Auth", } type OAuthProxy struct { CookieSeed string CookieName string CSRFCookieName string CookieDomain string CookieSecure bool CookieHttpOnly bool CookieExpire time.Duration CookieRefresh time.Duration Validator func(string) bool RobotsPath string PingPath string SignInPath string SignOutPath string OAuthStartPath string OAuthCallbackPath string AuthOnlyPath string redirectURL *url.URL // the url to receive requests at provider 
providers.Provider ProxyPrefix string SignInMessage string HtpasswdFile *HtpasswdFile DisplayHtpasswdForm bool serveMux http.Handler SetXAuthRequest bool PassBasicAuth bool SkipProviderButton bool PassUserHeaders bool BasicAuthPassword string PassAccessToken bool CookieCipher *cookie.Cipher skipAuthRegex []string skipAuthPreflight bool compiledRegex []*regexp.Regexp templates *template.Template Footer string } type UpstreamProxy struct { upstream string handler http.Handler auth hmacauth.HmacAuth } func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("GAP-Upstream-Address", u.upstream) if u.auth != nil { r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) u.auth.SignRequest(r) } u.handler.ServeHTTP(w, r) } func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { return httputil.NewSingleHostReverseProxy(target) } func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { director := proxy.Director proxy.Director = func(req *http.Request) { director(req) // use RequestURI so that we aren't unescaping encoded slashes in the request path req.Host = target.Host req.URL.Opaque = req.RequestURI req.URL.RawQuery = "" } } func setProxyDirector(proxy *httputil.ReverseProxy) { director := proxy.Director proxy.Director = func(req *http.Request) { director(req) // use RequestURI so that we aren't unescaping encoded slashes in the request path req.URL.Opaque = req.RequestURI req.URL.RawQuery = "" } } func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) } func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { serveMux := http.NewServeMux() var auth hmacauth.HmacAuth if sigData := opts.signatureData; sigData != nil { auth = hmacauth.NewHmacAuth(sigData.hash, []byte(sigData.key), SignatureHeader, SignatureHeaders) } for _, u := range opts.proxyURLs { path := u.Path switch u.Scheme { case 
"http", "https": u.Path = "" log.Printf("mapping path %q => upstream %q", path, u) proxy := NewReverseProxy(u) if !opts.PassHostHeader { setProxyUpstreamHostHeader(proxy, u) } else { setProxyDirector(proxy) } serveMux.Handle(path, &UpstreamProxy{u.Host, proxy, auth}) case "file": if u.Fragment != "" { path = u.Fragment } log.Printf("mapping path %q => file system %q", path, u.Path) proxy := NewFileServer(path, u.Path) serveMux.Handle(path, &UpstreamProxy{path, proxy, nil}) default: panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) } } for _, u := range opts.CompiledRegex { log.Printf("compiled skip-auth-regex => %q", u) } redirectURL := opts.redirectURL redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) log.Printf("OAuthProxy configured for %s Client ID: %s", opts.provider.Data().ProviderName, opts.ClientID) refresh := "disabled" if opts.CookieRefresh != time.Duration(0) { refresh = fmt.Sprintf("after %s", opts.CookieRefresh) } log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) var cipher *cookie.Cipher if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { var err error cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) if err != nil { log.Fatal("cookie-secret error: ", err) } } return &OAuthProxy{ CookieName: opts.CookieName, CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), CookieSeed: opts.CookieSecret, CookieDomain: opts.CookieDomain, CookieSecure: opts.CookieSecure, CookieHttpOnly: opts.CookieHttpOnly, CookieExpire: opts.CookieExpire, CookieRefresh: opts.CookieRefresh, Validator: validator, RobotsPath: "/robots.txt", PingPath: "/ping", SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), OAuthCallbackPath: 
fmt.Sprintf("%s/callback", opts.ProxyPrefix), AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, provider: opts.provider, serveMux: serveMux, redirectURL: redirectURL, skipAuthRegex: opts.SkipAuthRegex, skipAuthPreflight: opts.SkipAuthPreflight, compiledRegex: opts.CompiledRegex, SetXAuthRequest: opts.SetXAuthRequest, PassBasicAuth: opts.PassBasicAuth, PassUserHeaders: opts.PassUserHeaders, BasicAuthPassword: opts.BasicAuthPassword, PassAccessToken: opts.PassAccessToken, SkipProviderButton: opts.SkipProviderButton, CookieCipher: cipher, templates: loadTemplates(opts.CustomTemplatesDir), Footer: opts.Footer, } } func (p *OAuthProxy) GetRedirectURI(host string) string { // default to the request Host if not set if p.redirectURL.Host != "" { return p.redirectURL.String() } var u url.URL u = *p.redirectURL if u.Scheme == "" { if p.CookieSecure { u.Scheme = "https" } else { u.Scheme = "http" } } u.Host = host return u.String() } func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } redirectURI := p.GetRedirectURI(host) s, err = p.provider.Redeem(redirectURI, code) if err != nil { return } if s.Email == "" { s.Email, err = p.provider.GetEmailAddress(s) } if s.User == "" { s.User, err = p.provider.GetUserName(s) if err != nil && err.Error() == "not implemented" { err = nil } } return } func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { if value != "" { value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) if len(value) > 4096 { // Cookies cannot be larger than 4kb log.Printf("WARNING - Cookie Size: %d bytes", len(value)) } } return p.makeCookie(req, p.CookieName, value, expiration, now) } func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value 
string, expiration time.Duration, now time.Time) *http.Cookie { return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) } func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { if p.CookieDomain != "" { domain := req.Host if h, _, err := net.SplitHostPort(domain); err == nil { domain = h } if !strings.HasSuffix(domain, p.CookieDomain) { log.Printf("Warning: request host is %q but using configured cookie domain of %q", domain, p.CookieDomain) } } return &http.Cookie{ Name: name, Value: value, Path: "/", Domain: p.CookieDomain, HttpOnly: p.CookieHttpOnly, Secure: p.CookieSecure, Expires: now.Add(expiration), } } func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) } func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) } func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) http.SetCookie(rw, clr) // ugly hack because default domain changed if p.CookieDomain == "" { clr2 := *clr clr2.Domain = req.Host http.SetCookie(rw, &clr2) } } func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) } func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { var age time.Duration c, err := req.Cookie(p.CookieName) if err != nil { // always http.ErrNoCookie return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) } val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) if !ok { return nil, age, errors.New("Cookie Signature not valid") } session, err := p.provider.SessionFromCookie(val, p.CookieCipher) 
if err != nil { return nil, age, err } age = time.Now().Truncate(time.Second).Sub(timestamp) return session, age, nil } func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { value, err := p.provider.CookieForSession(s, p.CookieCipher) if err != nil { return err } p.SetSessionCookie(rw, req, value) return nil } func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { rw.WriteHeader(http.StatusOK) fmt.Fprintf(rw, "User-agent: *\nDisallow: /") } func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { rw.WriteHeader(http.StatusOK) fmt.Fprintf(rw, "OK") } func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { log.Printf("ErrorPage %d %s %s", code, title, message) rw.WriteHeader(code) t := struct { Title string Message string ProxyPrefix string }{ Title: fmt.Sprintf("%d %s", code, title), Message: message, ProxyPrefix: p.ProxyPrefix, } p.templates.ExecuteTemplate(rw, "error.html", t) } func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { p.ClearSessionCookie(rw, req) rw.WriteHeader(code) redirect_url := req.URL.RequestURI() if req.Header.Get("X-Auth-Request-Redirect") != "" { redirect_url = req.Header.Get("X-Auth-Request-Redirect") } if redirect_url == p.SignInPath { redirect_url = "/" } t := struct { ProviderName string SignInMessage string CustomLogin bool Redirect string Version string ProxyPrefix string Footer template.HTML }{ ProviderName: p.provider.Data().ProviderName, SignInMessage: p.SignInMessage, CustomLogin: p.displayCustomLoginForm(), Redirect: redirect_url, Version: VERSION, ProxyPrefix: p.ProxyPrefix, Footer: template.HTML(p.Footer), } p.templates.ExecuteTemplate(rw, "sign_in.html", t) } func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { if req.Method != "POST" || p.HtpasswdFile == nil { return "", false } user := req.FormValue("username") passwd := req.FormValue("password") if user 
== "" { return "", false } // check auth if p.HtpasswdFile.Validate(user, passwd) { log.Printf("authenticated %q via HtpasswdFile", user) return user, true } return "", false } func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { err = req.ParseForm() if err != nil { return } redirect = req.Form.Get("rd") if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { redirect = "/" } return } func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) { isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path) } func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { for _, u := range p.compiledRegex { ok = u.MatchString(path) if ok { return } } return } func getRemoteAddr(req *http.Request) (s string) { s = req.RemoteAddr if req.Header.Get("X-Real-IP") != "" { s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) } return } func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch path := req.URL.Path; { case path == p.RobotsPath: p.RobotsTxt(rw) case path == p.PingPath: p.PingPage(rw) case p.IsWhitelistedRequest(req): p.serveMux.ServeHTTP(rw, req) case path == p.SignInPath: p.SignIn(rw, req) case path == p.SignOutPath: p.SignOut(rw, req) case path == p.OAuthStartPath: p.OAuthStart(rw, req) case path == p.OAuthCallbackPath: p.OAuthCallback(rw, req) case path == p.AuthOnlyPath: p.AuthenticateOnly(rw, req) default: p.Proxy(rw, req) } } func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } user, ok := p.ManualSignIn(rw, req) if ok { session := &providers.SessionState{User: user} p.SaveSession(rw, req, session) http.Redirect(rw, req, redirect, 302) } else { if p.SkipProviderButton { p.OAuthStart(rw, req) } else { p.SignInPage(rw, req, 
http.StatusOK) } } } func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { p.ClearSessionCookie(rw, req) http.Redirect(rw, req, "/", 302) } func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { nonce, err := cookie.Nonce() if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } p.SetCSRFCookie(rw, req, nonce) redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } redirectURI := p.GetRedirectURI(req.Host) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) } func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { remoteAddr := getRemoteAddr(req) // finish the oauth cycle err := req.ParseForm() if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } errorString := req.Form.Get("error") if errorString != "" { p.ErrorPage(rw, 403, "Permission Denied", errorString) return } session, err := p.redeemCode(req.Host, req.Form.Get("code")) if err != nil { log.Printf("%s error redeeming code %s", remoteAddr, err) p.ErrorPage(rw, 500, "Internal Error", "Internal Error") return } s := strings.SplitN(req.Form.Get("state"), ":", 2) if len(s) != 2 { p.ErrorPage(rw, 500, "Internal Error", "Invalid State") return } nonce := s[0] redirect := s[1] c, err := req.Cookie(p.CSRFCookieName) if err != nil { p.ErrorPage(rw, 403, "Permission Denied", err.Error()) return } p.ClearCSRFCookie(rw, req) if c.Value != nonce { log.Printf("%s csrf token mismatch, potential attack", remoteAddr) p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") return } if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { redirect = "/" } // set cookie, or deny if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { log.Printf("%s authentication complete %s", remoteAddr, session) err := p.SaveSession(rw, req, session) if err != nil { log.Printf("%s %s", 
remoteAddr, err) p.ErrorPage(rw, 500, "Internal Error", "Internal Error") return } http.Redirect(rw, req, redirect, 302) } else { log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email) p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") } } func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { status := p.Authenticate(rw, req) if status == http.StatusAccepted { rw.WriteHeader(http.StatusAccepted) } else { http.Error(rw, "unauthorized request", http.StatusUnauthorized) } } func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { status := p.Authenticate(rw, req) if status == http.StatusInternalServerError { p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error") } else if status == http.StatusForbidden { if p.SkipProviderButton { p.OAuthStart(rw, req) } else { p.SignInPage(rw, req, http.StatusForbidden) } } else { p.serveMux.ServeHTTP(rw, req) } } func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int { var saveSession, clearSession, revalidated bool remoteAddr := getRemoteAddr(req) session, sessionAge, err := p.LoadCookiedSession(req) if err != nil { log.Printf("%s %s", remoteAddr, err) } if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh) saveSession = true } if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) clearSession = true session = nil } else if ok { saveSession = true revalidated = true } if session != nil && session.IsExpired() { log.Printf("%s removing session. 
token expired %s", remoteAddr, session) session = nil saveSession = false clearSession = true } if saveSession && !revalidated && session != nil && session.AccessToken != "" { if !p.provider.ValidateSessionState(session) { log.Printf("%s removing session. error validating %s", remoteAddr, session) saveSession = false session = nil clearSession = true } } if session != nil && session.Email != "" && !p.Validator(session.Email) { log.Printf("%s Permission Denied: removing session %s", remoteAddr, session) session = nil saveSession = false clearSession = true } if saveSession && session != nil { err := p.SaveSession(rw, req, session) if err != nil { log.Printf("%s %s", remoteAddr, err) return http.StatusInternalServerError } } if clearSession { p.ClearSessionCookie(rw, req) } if session == nil { session, err = p.CheckBasicAuth(req) if err != nil { log.Printf("%s %s", remoteAddr, err) } } if session == nil { return http.StatusForbidden } // At this point, the user is authenticated. proxy normally if p.PassBasicAuth { req.SetBasicAuth(session.User, p.BasicAuthPassword) req.Header["X-Forwarded-User"] = []string{session.User} if session.Email != "" { req.Header["X-Forwarded-Email"] = []string{session.Email} } } if p.PassUserHeaders { req.Header["X-Forwarded-User"] = []string{session.User} if session.Email != "" { req.Header["X-Forwarded-Email"] = []string{session.Email} } } if p.SetXAuthRequest { rw.Header().Set("X-Auth-Request-User", session.User) if session.Email != "" { rw.Header().Set("X-Auth-Request-Email", session.Email) } } if p.PassAccessToken && session.AccessToken != "" { req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} } if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { rw.Header().Set("GAP-Auth", session.Email) } return http.StatusAccepted } func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { if p.HtpasswdFile == nil { return nil, nil } auth := req.Header.Get("Authorization") 
if auth == "" { return nil, nil } s := strings.SplitN(auth, " ", 2) if len(s) != 2 || s[0] != "Basic" { return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) } b, err := b64.StdEncoding.DecodeString(s[1]) if err != nil { return nil, err } pair := strings.SplitN(string(b), ":", 2) if len(pair) != 2 { return nil, fmt.Errorf("invalid format %s", b) } if p.HtpasswdFile.Validate(pair[0], pair[1]) { log.Printf("authenticated %q via basic auth", pair[0]) return &providers.SessionState{User: pair[0]}, nil } return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) } ================================================ FILE: oauthproxy_test.go ================================================ package main import ( "crypto" "encoding/base64" "io" "io/ioutil" "log" "net" "net/http" "net/http/httptest" "net/url" "regexp" "strings" "testing" "time" "github.com/bitly/oauth2_proxy/providers" "github.com/mbland/hmacauth" "github.com/stretchr/testify/assert" ) func init() { log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) } func TestNewReverseProxy(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) hostname, _, _ := net.SplitHostPort(r.Host) w.Write([]byte(hostname)) })) defer backend.Close() backendURL, _ := url.Parse(backend.URL) backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) backendHost := net.JoinHostPort(backendHostname, backendPort) proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") proxyHandler := NewReverseProxy(proxyURL) setProxyUpstreamHostHeader(proxyHandler, proxyURL) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() getReq, _ := http.NewRequest("GET", frontend.URL, nil) res, _ := http.DefaultClient.Do(getReq) bodyBytes, _ := ioutil.ReadAll(res.Body) if g, e := string(bodyBytes), backendHostname; g != e { t.Errorf("got body %q; expected %q", g, e) } } func TestEncodedSlashes(t *testing.T) { var 
seen string backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) seen = r.RequestURI })) defer backend.Close() b, _ := url.Parse(backend.URL) proxyHandler := NewReverseProxy(b) setProxyDirector(proxyHandler) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() f, _ := url.Parse(frontend.URL) encodedPath := "/a%2Fb/?c=1" getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} _, err := http.DefaultClient.Do(getReq) if err != nil { t.Fatalf("err %s", err) } if seen != encodedPath { t.Errorf("got bad request %q expected %q", seen, encodedPath) } }

// TestRobotsTxt checks that /robots.txt is served without authentication.
func TestRobotsTxt(t *testing.T) {
	opts := NewOptions()
	opts.ClientID = "bazquux"
	opts.ClientSecret = "foobar"
	opts.CookieSecret = "xyzzyplugh"
	opts.Validate()

	proxy := NewOAuthProxy(opts, func(string) bool { return true })
	rw := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/robots.txt", nil)
	proxy.ServeHTTP(rw, req)
	assert.Equal(t, 200, rw.Code)
	assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
}

// TestProvider is a stub provider. It returns a fixed email address and a
// configurable session-validation result.
type TestProvider struct {
	*providers.ProviderData
	EmailAddress string
	ValidToken   bool
}

// NewTestProvider returns a TestProvider whose endpoints live on provider_url's host.
func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider {
	return &TestProvider{
		ProviderData: &providers.ProviderData{
			ProviderName: "Test Provider",
			LoginURL: &url.URL{
				Scheme: "http",
				Host:   provider_url.Host,
				Path:   "/oauth/authorize",
			},
			RedeemURL: &url.URL{
				Scheme: "http",
				Host:   provider_url.Host,
				Path:   "/oauth/token",
			},
			ProfileURL: &url.URL{
				Scheme: "http",
				Host:   provider_url.Host,
				Path:   "/api/v1/profile",
			},
			Scope: "profile.email",
		},
		EmailAddress: email_address,
	}
}

func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
	return tp.EmailAddress, nil
}

func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
	return tp.ValidToken
}

// TestBasicAuthPassword checks that, after an OAuth login, the upstream
// receives an Authorization header built from the user name and the
// configured BasicAuthPassword.
func TestBasicAuthPassword(t *testing.T) {
	provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%#v", r)
		url := r.URL
		payload := ""
		switch url.Path {
		case "/oauth/token":
			payload = `{"access_token": "my_auth_token"}`
		default:
			payload = r.Header.Get("Authorization")
			if payload == "" {
				payload = "No Authorization header found."
			}
		}
		w.WriteHeader(200)
		w.Write([]byte(payload))
	}))
	// Deferred so that the server is also closed when a check below fails
	// through t.Fatalf.
	defer provider_server.Close()

	opts := NewOptions()
	opts.Upstreams = append(opts.Upstreams, provider_server.URL)
	// The CookieSecret must be 32 bytes in order to create the AES
	// cipher.
	opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
	opts.ClientID = "bazquux"
	opts.ClientSecret = "foobar"
	opts.CookieSecure = false
	opts.PassBasicAuth = true
	opts.PassUserHeaders = true
	opts.BasicAuthPassword = "This is a secure password"
	opts.Validate()

	provider_url, _ := url.Parse(provider_server.URL)
	const email_address = "michael.bland@gsa.gov"
	const user_name = "michael.bland"

	opts.provider = NewTestProvider(provider_url, email_address)
	proxy := NewOAuthProxy(opts, func(email string) bool {
		return email == email_address
	})

	rw := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
		strings.NewReader(""))
	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
	proxy.ServeHTTP(rw, req)
	if rw.Code >= 400 {
		t.Fatalf("expected 3xx got %d", rw.Code)
	}
	// Set-Cookie[0] clears the CSRF cookie; [1] is the session cookie.
	cookie := rw.HeaderMap["Set-Cookie"][1]

	cookieName := proxy.CookieName
	var value string
	key_prefix := cookieName + "="

	for _, field := range strings.Split(cookie, "; ") {
		value = strings.TrimPrefix(field, key_prefix)
		if value != field {
			break
		} else {
			value = ""
		}
	}

	req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
	req.AddCookie(&http.Cookie{
		Name:     cookieName,
		Value:    value,
		Path:     "/",
		Expires:  time.Now().Add(time.Duration(24)),
		HttpOnly: true,
	})
	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))

	rw = httptest.NewRecorder()
	proxy.ServeHTTP(rw, req)
	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
	assert.Equal(t, expectedHeader, rw.Body.String())
}

// PassAccessTokenTest wires a proxy to a fake provider/upstream server. The
// upstream echoes the X-Forwarded-Access-Token header it receives.
type PassAccessTokenTest struct {
	provider_server *httptest.Server
	proxy           *OAuthProxy
	opts            *Options
}

type PassAccessTokenTestOptions struct {
	PassAccessToken bool
}

func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
	t := &PassAccessTokenTest{}

	t.provider_server = httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			log.Printf("%#v", r)
			url := r.URL
			payload := ""
			switch url.Path {
			case "/oauth/token":
				payload = `{"access_token": "my_auth_token"}`
			default:
				payload = r.Header.Get("X-Forwarded-Access-Token")
				if payload == "" {
					payload = "No access token found."
				}
			}
			w.WriteHeader(200)
			w.Write([]byte(payload))
		}))

	t.opts = NewOptions()
	t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL)
	// The CookieSecret must be 32 bytes in order to create the AES
	// cipher.
	t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
	t.opts.ClientID = "bazquux"
	t.opts.ClientSecret = "foobar"
	t.opts.CookieSecure = false
	t.opts.PassAccessToken = opts.PassAccessToken
	t.opts.Validate()

	provider_url, _ := url.Parse(t.provider_server.URL)
	const email_address = "michael.bland@gsa.gov"

	t.opts.provider = NewTestProvider(provider_url, email_address)
	t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
		return email == email_address
	})
	return t
}

func (pat_test *PassAccessTokenTest) Close() {
	pat_test.provider_server.Close()
}

// getCallbackEndpoint completes the OAuth callback. It returns the status code
// and the session Set-Cookie header (index 1; index 0 clears the CSRF cookie).
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, cookie string) {
	rw := httptest.NewRecorder()
	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
		strings.NewReader(""))
	if err != nil {
		return 0, ""
	}
	req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
	pat_test.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.HeaderMap["Set-Cookie"][1]
}

// getRootEndpoint requests "/" with the session cookie value taken from the
// cookie header string. It returns the status code and the body.
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
	cookieName := pat_test.proxy.CookieName
	var value string
	key_prefix := cookieName + "="

	for _, field := range strings.Split(cookie, "; ") {
		value = strings.TrimPrefix(field, key_prefix)
		if value != field {
			break
		} else {
			value = ""
		}
	}
	if value == "" {
		return 0, ""
	}

	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
	if err != nil {
		return 0, ""
	}
	req.AddCookie(&http.Cookie{
		Name:     cookieName,
		Value:    value,
		Path:     "/",
		Expires:  time.Now().Add(time.Duration(24)),
		HttpOnly: true,
	})

	rw := httptest.NewRecorder()
	pat_test.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.Body.String()
}

func TestForwardAccessTokenUpstream(t *testing.T) {
	pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: true,
	})
	defer pat_test.Close()

	// A successful validation will redirect and set the auth cookie.
	code, cookie := pat_test.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	// cookie is a string, so compare against "" (comparing it to nil could never fail).
	assert.NotEqual(t, "", cookie)

	// Now we make a regular request; the access_token from the cookie is
	// forwarded as the "X-Forwarded-Access-Token" header. The token is
	// read by the test provider server and written in the response body.
	code, payload := pat_test.getRootEndpoint(cookie)
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "my_auth_token", payload)
}

func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
	pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: false,
	})
	defer pat_test.Close()

	// A successful validation will redirect and set the auth cookie.
	code, cookie := pat_test.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	assert.NotEqual(t, "", cookie)

	// Now we make a regular request, but the access token header should
	// not be present.
	code, payload := pat_test.getRootEndpoint(cookie)
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "No access token found.", payload)
}

type SignInPageTest struct {
	opts                    *Options
	proxy                   *OAuthProxy
	sign_in_regexp          *regexp.Regexp
	sign_in_provider_regexp *regexp.Regexp
}

// signInRedirectPattern captures the redirect target from the sign-in form.
// NOTE(review): in this dump the pattern had been stripped to an empty
// string, which cannot satisfy the match[1] lookups below. It is restored
// from the upstream test; confirm it against the sign_in.html template.
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
const signInSkipProvider = `>Found<`

func NewSignInPageTest(skipProvider bool) *SignInPageTest {
	var sip_test SignInPageTest

	sip_test.opts = NewOptions()
	sip_test.opts.CookieSecret = "foobar"
	sip_test.opts.ClientID = "bazquux"
	sip_test.opts.ClientSecret = "xyzzyplugh"
	sip_test.opts.SkipProviderButton = skipProvider
	sip_test.opts.Validate()

	sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool {
		return true
	})
	sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
	sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider)

	return &sip_test
}

// GetEndpoint issues a GET for endpoint and returns the status code and body.
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
	rw := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
	sip_test.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.Body.String()
}

func TestSignInPageIncludesTargetRedirect(t *testing.T) {
	sip_test := NewSignInPageTest(false)
	const endpoint = "/some/random/endpoint"

	code, body := sip_test.GetEndpoint(endpoint)
	assert.Equal(t, 403, code)

	match := sip_test.sign_in_regexp.FindStringSubmatch(body)
	if match == nil {
		t.Fatal("Did not find pattern in body: " +
			signInRedirectPattern + "\nBody:\n" + body)
	}
	if match[1] != endpoint {
		t.Fatal(`expected redirect to "` + endpoint +
			`", but was "` + match[1] + `"`)
	}
}

func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
	sip_test := NewSignInPageTest(false)
	code, body := sip_test.GetEndpoint("/oauth2/sign_in")
	assert.Equal(t, 200, code)

	match := sip_test.sign_in_regexp.FindStringSubmatch(body)
	if match == nil {
		t.Fatal("Did not find pattern in body: " +
			signInRedirectPattern + "\nBody:\n" + body)
	}
	if match[1] != "/" {
		t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
	}
}

func TestSignInPageSkipProvider(t *testing.T) {
	sip_test := NewSignInPageTest(true)
	const endpoint = "/some/random/endpoint"

	code, body := sip_test.GetEndpoint(endpoint)
	assert.Equal(t, 302, code)

	match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
	if match == nil {
		t.Fatal("Did not find pattern in body: " +
			signInSkipProvider + "\nBody:\n" + body)
	}
}

func TestSignInPageSkipProviderDirect(t *testing.T) {
	sip_test := NewSignInPageTest(true)
	const endpoint = "/sign_in"

	code, body := sip_test.GetEndpoint(endpoint)
	assert.Equal(t, 302, code)

	match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
	if match == nil {
		t.Fatal("Did not find pattern in body: " +
			signInSkipProvider + "\nBody:\n" + body)
	}
}

type ProcessCookieTest struct {
	opts          *Options
	proxy         *OAuthProxy
	rw            *httptest.ResponseRecorder
	req           *http.Request
	provider      TestProvider
	response_code int
	validate_user bool
}

type ProcessCookieTestOpts struct {
provider_validate_cookie_response bool } func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { var pc_test ProcessCookieTest pc_test.opts = NewOptions() pc_test.opts.ClientID = "bazquux" pc_test.opts.ClientSecret = "xyzzyplugh" pc_test.opts.CookieSecret = "0123456789abcdefabcd" // First, set the CookieRefresh option so proxy.AesCipher is created, // needed to encrypt the access_token. pc_test.opts.CookieRefresh = time.Hour pc_test.opts.Validate() pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { return pc_test.validate_user }) pc_test.proxy.provider = &TestProvider{ ValidToken: opts.provider_validate_cookie_response, } // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. pc_test.proxy.CookieRefresh = time.Duration(0) pc_test.rw = httptest.NewRecorder() pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) pc_test.validate_user = true return &pc_test } func NewProcessCookieTestWithDefaults() *ProcessCookieTest { return NewProcessCookieTest(ProcessCookieTestOpts{ provider_validate_cookie_response: true, }) } func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie { return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) } func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) if err != nil { return err } p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref)) return nil } func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { return p.proxy.LoadCookiedSession(p.req) } func TestLoadCookiedSession(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pc_test.SaveSession(startSession, time.Now()) session, _, err := pc_test.LoadCookiedSession() 
assert.Equal(t, nil, err) assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "michael.bland", session.User) assert.Equal(t, startSession.AccessToken, session.AccessToken) } func TestProcessCookieNoCookieError(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() session, _, err := pc_test.LoadCookiedSession() assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) if session != nil { t.Errorf("expected nil session. got %#v", session) } } func TestProcessCookieRefreshNotSet(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour reference := time.Now().Add(time.Duration(-2) * time.Hour) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pc_test.SaveSession(startSession, reference) session, age, err := pc_test.LoadCookiedSession() assert.Equal(t, nil, err) if age < time.Duration(-2)*time.Hour { t.Errorf("cookie too young %v", age) } assert.Equal(t, startSession.Email, session.Email) } func TestProcessCookieFailIfCookieExpired(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pc_test.SaveSession(startSession, reference) session, _, err := pc_test.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) } } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pc_test.SaveSession(startSession, reference) pc_test.proxy.CookieRefresh = 
time.Hour session, _, err := pc_test.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) } } func NewAuthOnlyEndpointTest() *ProcessCookieTest { pc_test := NewProcessCookieTestWithDefaults() pc_test.req, _ = http.NewRequest("GET", pc_test.opts.ProxyPrefix+"/auth", nil) return pc_test } func TestAuthOnlyEndpointAccepted(t *testing.T) { test := NewAuthOnlyEndpointTest() startSession := &providers.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusAccepted, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) assert.Equal(t, "", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { test := NewAuthOnlyEndpointTest() test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) assert.Equal(t, "unauthorized request\n", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { test := NewAuthOnlyEndpointTest() test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &providers.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, reference) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) assert.Equal(t, "unauthorized request\n", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test := NewAuthOnlyEndpointTest() startSession := &providers.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) test.validate_user = false test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, 
http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) assert.Equal(t, "unauthorized request\n", string(bodyBytes)) } func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { var pc_test ProcessCookieTest pc_test.opts = NewOptions() pc_test.opts.SetXAuthRequest = true pc_test.opts.Validate() pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { return pc_test.validate_user }) pc_test.proxy.provider = &TestProvider{ ValidToken: true, } pc_test.validate_user = true pc_test.rw = httptest.NewRecorder() pc_test.req, _ = http.NewRequest("GET", pc_test.opts.ProxyPrefix+"/auth", nil) startSession := &providers.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} pc_test.SaveSession(startSession, time.Now()) pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req) assert.Equal(t, http.StatusAccepted, pc_test.rw.Code) assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0]) assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0]) } func TestAuthSkippedForPreflightRequests(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.Write([]byte("response")) })) defer upstream.Close() opts := NewOptions() opts.Upstreams = append(opts.Upstreams, upstream.URL) opts.ClientID = "bazquux" opts.ClientSecret = "foobar" opts.CookieSecret = "xyzzyplugh" opts.SkipAuthPreflight = true opts.Validate() upstream_url, _ := url.Parse(upstream.URL) opts.provider = NewTestProvider(upstream_url, "") proxy := NewOAuthProxy(opts, func(string) bool { return false }) rw := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) proxy.ServeHTTP(rw, req) assert.Equal(t, 200, rw.Code) assert.Equal(t, "response", rw.Body.String()) } type SignatureAuthenticator struct { auth hmacauth.HmacAuth } func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r 
*http.Request) { result, headerSig, computedSig := v.auth.AuthenticateRequest(r) if result == hmacauth.ResultNoSignature { w.Write([]byte("no signature received")) } else if result == hmacauth.ResultMatch { w.Write([]byte("signatures match")) } else if result == hmacauth.ResultMismatch { w.Write([]byte("signatures do not match:" + "\n received: " + headerSig + "\n computed: " + computedSig)) } else { panic("Unknown result value: " + result.String()) } } type SignatureTest struct { opts *Options upstream *httptest.Server upstream_host string provider *httptest.Server header http.Header rw *httptest.ResponseRecorder authenticator *SignatureAuthenticator } func NewSignatureTest() *SignatureTest { opts := NewOptions() opts.CookieSecret = "cookie secret" opts.ClientID = "client ID" opts.ClientSecret = "client secret" opts.EmailDomains = []string{"acm.org"} authenticator := &SignatureAuthenticator{} upstream := httptest.NewServer( http.HandlerFunc(authenticator.Authenticate)) upstream_url, _ := url.Parse(upstream.URL) opts.Upstreams = append(opts.Upstreams, upstream.URL) providerHandler := func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"access_token": "my_auth_token"}`)) } provider := httptest.NewServer(http.HandlerFunc(providerHandler)) provider_url, _ := url.Parse(provider.URL) opts.provider = NewTestProvider(provider_url, "mbland@acm.org") return &SignatureTest{ opts, upstream, upstream_url.Host, provider, make(http.Header), httptest.NewRecorder(), authenticator, } } func (st *SignatureTest) Close() { st.provider.Close() st.upstream.Close() } // fakeNetConn simulates an http.Request.Body buffer that will be consumed // when it is read by the hmacauth.HmacAuth if not handled properly. 
See: // https://github.com/18F/hmacauth/pull/4 type fakeNetConn struct { reqBody string } func (fnc *fakeNetConn) Read(p []byte) (n int, err error) { if bodyLen := len(fnc.reqBody); bodyLen != 0 { copy(p, fnc.reqBody) fnc.reqBody = "" return bodyLen, io.EOF } return 0, io.EOF } func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { err := st.opts.Validate() if err != nil { panic(err) } proxy := NewOAuthProxy(st.opts, func(email string) bool { return true }) var bodyBuf io.ReadCloser if body != "" { bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body}) } req := httptest.NewRequest(method, "/foo/bar", bodyBuf) req.Header = st.header state := &providers.SessionState{ Email: "mbland@acm.org", AccessToken: "my_access_token"} value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) if err != nil { panic(err) } cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) req.AddCookie(cookie) // This is used by the upstream to validate the signature. st.authenticator.auth = hmacauth.NewHmacAuth( crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) proxy.ServeHTTP(st.rw, req) } func TestNoRequestSignature(t *testing.T) { st := NewSignatureTest() defer st.Close() st.MakeRequestWithExpectedKey("GET", "", "") assert.Equal(t, 200, st.rw.Code) assert.Equal(t, st.rw.Body.String(), "no signature received") } func TestRequestSignatureGetRequest(t *testing.T) { st := NewSignatureTest() defer st.Close() st.opts.SignatureKey = "sha1:foobar" st.MakeRequestWithExpectedKey("GET", "", "foobar") assert.Equal(t, 200, st.rw.Code) assert.Equal(t, st.rw.Body.String(), "signatures match") } func TestRequestSignaturePostRequest(t *testing.T) { st := NewSignatureTest() defer st.Close() st.opts.SignatureKey = "sha1:foobar" payload := `{ "hello": "world!" 
}` st.MakeRequestWithExpectedKey("POST", payload, "foobar") assert.Equal(t, 200, st.rw.Code) assert.Equal(t, st.rw.Body.String(), "signatures match") } ================================================ FILE: options.go ================================================ package main import ( "context" "crypto" "crypto/tls" "encoding/base64" "fmt" "net/http" "net/url" "os" "regexp" "strings" "time" "github.com/bitly/oauth2_proxy/providers" oidc "github.com/coreos/go-oidc" "github.com/mbland/hmacauth" ) // Configuration Options that can be set by Command Line Flag, or Config File type Options struct { ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` HttpAddress string `flag:"http-address" cfg:"http_address"` HttpsAddress string `flag:"https-address" cfg:"https_address"` RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` EmailDomains []string `flag:"email-domain" cfg:"email_domains"` GitHubOrg string `flag:"github-org" cfg:"github_org"` GitHubTeam string `flag:"github-team" cfg:"github_team"` GoogleGroups []string `flag:"google-group" cfg:"google_group"` GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"` GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` Footer string `flag:"footer" cfg:"footer"` CookieName string 
`flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` Upstreams []string `flag:"upstream" cfg:"upstreams"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` // These options allow for other providers besides Google, with // potential overrides. 
Provider string `flag:"provider" cfg:"provider"` OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url"` LoginURL string `flag:"login-url" cfg:"login_url"` RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` ProfileURL string `flag:"profile-url" cfg:"profile_url"` ProtectedResource string `flag:"resource" cfg:"resource"` ValidateURL string `flag:"validate-url" cfg:"validate_url"` Scope string `flag:"scope" cfg:"scope"` ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` RequestLogging bool `flag:"request-logging" cfg:"request_logging"` RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format"` SignatureKey string `flag:"signature-key" cfg:"signature_key" env:"OAUTH2_PROXY_SIGNATURE_KEY"` // internal values that are set after config validation redirectURL *url.URL proxyURLs []*url.URL CompiledRegex []*regexp.Regexp provider providers.Provider signatureData *SignatureData oidcVerifier *oidc.IDTokenVerifier } type SignatureData struct { hash crypto.Hash key string } func NewOptions() *Options { return &Options{ ProxyPrefix: "/oauth2", HttpAddress: "127.0.0.1:4180", HttpsAddress: ":443", DisplayHtpasswdForm: true, CookieName: "_oauth2_proxy", CookieSecure: true, CookieHttpOnly: true, CookieExpire: time.Duration(168) * time.Hour, CookieRefresh: time.Duration(0), SetXAuthRequest: false, SkipAuthPreflight: false, PassBasicAuth: true, PassUserHeaders: true, PassAccessToken: false, PassHostHeader: true, ApprovalPrompt: "force", RequestLogging: true, RequestLoggingFormat: defaultRequestLoggingFormat, } } func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { parsed, err := url.Parse(to_parse) if err != nil { return nil, append(msgs, fmt.Sprintf( "error parsing %s-url=%q %s", urltype, to_parse, err)) } return parsed, msgs } func (o *Options) Validate() error { if o.SSLInsecureSkipVerify { // TODO: Accept a certificate bundle. 
insecureTransport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } http.DefaultClient = &http.Client{Transport: insecureTransport} } msgs := make([]string, 0) if o.CookieSecret == "" { msgs = append(msgs, "missing setting: cookie-secret") } if o.ClientID == "" { msgs = append(msgs, "missing setting: client-id") } if o.ClientSecret == "" { msgs = append(msgs, "missing setting: client-secret") } if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" { msgs = append(msgs, "missing setting for email validation: email-domain or authenticated-emails-file required."+ "\n use email-domain=* to authorize all email addresses") } if o.OIDCIssuerURL != "" { // Configure discoverable provider data. provider, err := oidc.NewProvider(context.Background(), o.OIDCIssuerURL) if err != nil { return err } o.oidcVerifier = provider.Verifier(&oidc.Config{ ClientID: o.ClientID, }) o.LoginURL = provider.Endpoint().AuthURL o.RedeemURL = provider.Endpoint().TokenURL if o.Scope == "" { o.Scope = "openid email profile" } } o.redirectURL, msgs = parseURL(o.RedirectURL, "redirect", msgs) for _, u := range o.Upstreams { upstreamURL, err := url.Parse(u) if err != nil { msgs = append(msgs, fmt.Sprintf("error parsing upstream: %s", err)) } else { if upstreamURL.Path == "" { upstreamURL.Path = "/" } o.proxyURLs = append(o.proxyURLs, upstreamURL) } } for _, u := range o.SkipAuthRegex { CompiledRegex, err := regexp.Compile(u) if err != nil { msgs = append(msgs, fmt.Sprintf("error compiling regex=%q %s", u, err)) continue } o.CompiledRegex = append(o.CompiledRegex, CompiledRegex) } msgs = parseProviderInfo(o, msgs) if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { valid_cookie_secret_size := false for _, i := range []int{16, 24, 32} { if len(secretBytes(o.CookieSecret)) == i { valid_cookie_secret_size = true } } var decoded bool if string(secretBytes(o.CookieSecret)) != o.CookieSecret { decoded = true } if 
valid_cookie_secret_size == false { var suffix string if decoded { suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) } msgs = append(msgs, fmt.Sprintf( "cookie_secret must be 16, 24, or 32 bytes "+ "to create an AES cipher when "+ "pass_access_token == true or "+ "cookie_refresh != 0, but is %d bytes.%s", len(secretBytes(o.CookieSecret)), suffix)) } } if o.CookieRefresh >= o.CookieExpire { msgs = append(msgs, fmt.Sprintf( "cookie_refresh (%s) must be less than "+ "cookie_expire (%s)", o.CookieRefresh.String(), o.CookieExpire.String())) } if len(o.GoogleGroups) > 0 || o.GoogleAdminEmail != "" || o.GoogleServiceAccountJSON != "" { if len(o.GoogleGroups) < 1 { msgs = append(msgs, "missing setting: google-group") } if o.GoogleAdminEmail == "" { msgs = append(msgs, "missing setting: google-admin-email") } if o.GoogleServiceAccountJSON == "" { msgs = append(msgs, "missing setting: google-service-account-json") } } msgs = parseSignatureKey(o, msgs) msgs = validateCookieName(o, msgs) if len(msgs) != 0 { return fmt.Errorf("Invalid configuration:\n %s", strings.Join(msgs, "\n ")) } return nil } func parseProviderInfo(o *Options, msgs []string) []string { p := &providers.ProviderData{ Scope: o.Scope, ClientID: o.ClientID, ClientSecret: o.ClientSecret, ApprovalPrompt: o.ApprovalPrompt, } p.LoginURL, msgs = parseURL(o.LoginURL, "login", msgs) p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs) p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) o.provider = providers.New(o.Provider, p) switch p := o.provider.(type) { case *providers.AzureProvider: p.Configure(o.AzureTenant) case *providers.GitHubProvider: p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) case *providers.GoogleProvider: if o.GoogleServiceAccountJSON != "" { file, err := os.Open(o.GoogleServiceAccountJSON) if err != nil { msgs = 
append(msgs, "invalid Google credentials file: "+o.GoogleServiceAccountJSON) } else { p.SetGroupRestriction(o.GoogleGroups, o.GoogleAdminEmail, file) } } case *providers.OIDCProvider: if o.oidcVerifier == nil { msgs = append(msgs, "oidc provider requires an oidc issuer URL") } else { p.Verifier = o.oidcVerifier } } return msgs } func parseSignatureKey(o *Options, msgs []string) []string { if o.SignatureKey == "" { return msgs } components := strings.Split(o.SignatureKey, ":") if len(components) != 2 { return append(msgs, "invalid signature hash:key spec: "+ o.SignatureKey) } algorithm, secretKey := components[0], components[1] if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { return append(msgs, "unsupported signature hash algorithm: "+ o.SignatureKey) } else { o.signatureData = &SignatureData{hash, secretKey} } return msgs } func validateCookieName(o *Options, msgs []string) []string { cookie := &http.Cookie{Name: o.CookieName} if cookie.String() == "" { return append(msgs, fmt.Sprintf("invalid cookie name: %q", o.CookieName)) } return msgs } func addPadding(secret string) string { padding := len(secret) % 4 switch padding { case 1: return secret + "===" case 2: return secret + "==" case 3: return secret + "=" default: return secret } } // secretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary func secretBytes(secret string) []byte { b, err := base64.URLEncoding.DecodeString(addPadding(secret)) if err == nil { return []byte(addPadding(string(b))) } return []byte(secret) } ================================================ FILE: options_test.go ================================================ package main import ( "crypto" "fmt" "net/url" "strings" "testing" "time" "github.com/stretchr/testify/assert" ) func testOptions() *Options { o := NewOptions() o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8080/") o.CookieSecret = "foobar" o.ClientID = "bazquux" o.ClientSecret = "xyzzyplugh" o.EmailDomains = 
[]string{"*"} return o } func errorMsg(msgs []string) string { result := make([]string, 0) result = append(result, "Invalid configuration:") result = append(result, msgs...) return strings.Join(result, "\n ") } func TestNewOptions(t *testing.T) { o := NewOptions() o.EmailDomains = []string{"*"} err := o.Validate() assert.NotEqual(t, nil, err) expected := errorMsg([]string{ "missing setting: cookie-secret", "missing setting: client-id", "missing setting: client-secret"}) assert.Equal(t, expected, err.Error()) } func TestGoogleGroupOptions(t *testing.T) { o := testOptions() o.GoogleGroups = []string{"googlegroup"} err := o.Validate() assert.NotEqual(t, nil, err) expected := errorMsg([]string{ "missing setting: google-admin-email", "missing setting: google-service-account-json"}) assert.Equal(t, expected, err.Error()) } func TestGoogleGroupInvalidFile(t *testing.T) { o := testOptions() o.GoogleGroups = []string{"test_group"} o.GoogleAdminEmail = "admin@example.com" o.GoogleServiceAccountJSON = "file_doesnt_exist.json" err := o.Validate() assert.NotEqual(t, nil, err) expected := errorMsg([]string{ "invalid Google credentials file: file_doesnt_exist.json", }) assert.Equal(t, expected, err.Error()) } func TestInitializedOptions(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) } // Note that it's not worth testing nonparseable URLs, since url.Parse() // seems to parse damn near anything. 
func TestRedirectURL(t *testing.T) { o := testOptions() o.RedirectURL = "https://myhost.com/oauth2/callback" assert.Equal(t, nil, o.Validate()) expected := &url.URL{ Scheme: "https", Host: "myhost.com", Path: "/oauth2/callback"} assert.Equal(t, expected, o.redirectURL) } func TestProxyURLs(t *testing.T) { o := testOptions() o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") assert.Equal(t, nil, o.Validate()) expected := []*url.URL{ &url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, // note the '/' was added &url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, } assert.Equal(t, expected, o.proxyURLs) } func TestProxyURLsError(t *testing.T) { o := testOptions() o.Upstreams = append(o.Upstreams, "127.0.0.1:8081") err := o.Validate() assert.NotEqual(t, nil, err) expected := errorMsg([]string{ "error parsing upstream: parse 127.0.0.1:8081: " + "first path segment in URL cannot contain colon"}) assert.Equal(t, expected, err.Error()) } func TestCompiledRegex(t *testing.T) { o := testOptions() regexps := []string{"/foo/.*", "/ba[rz]/quux"} o.SkipAuthRegex = regexps assert.Equal(t, nil, o.Validate()) actual := make([]string, 0) for _, regex := range o.CompiledRegex { actual = append(actual, regex.String()) } assert.Equal(t, regexps, actual) } func TestCompiledRegexError(t *testing.T) { o := testOptions() o.SkipAuthRegex = []string{"(foobaz", "barquux)"} err := o.Validate() assert.NotEqual(t, nil, err) expected := errorMsg([]string{ "error compiling regex=\"(foobaz\" error parsing regexp: " + "missing closing ): `(foobaz`", "error compiling regex=\"barquux)\" error parsing regexp: " + "unexpected ): `barquux)`"}) assert.Equal(t, expected, err.Error()) o.SkipAuthRegex = []string{"foobaz", "barquux)"} err = o.Validate() assert.NotEqual(t, nil, err) expected = errorMsg([]string{ "error compiling regex=\"barquux)\" error parsing regexp: " + "unexpected ): `barquux)`"}) assert.Equal(t, expected, err.Error()) } func TestDefaultProviderApiSettings(t 
*testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) p := o.provider.Data() assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", p.LoginURL.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", p.RedeemURL.String()) assert.Equal(t, "", p.ProfileURL.String()) assert.Equal(t, "profile email", p.Scope) } func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) assert.Equal(t, false, o.PassAccessToken) o.PassAccessToken = true o.CookieSecret = "cookie of invalid length-" assert.NotEqual(t, nil, o.Validate()) o.PassAccessToken = false o.CookieRefresh = time.Duration(24) * time.Hour assert.NotEqual(t, nil, o.Validate()) o.CookieSecret = "16 bytes AES-128" assert.Equal(t, nil, o.Validate()) o.CookieSecret = "24 byte secret AES-192--" assert.Equal(t, nil, o.Validate()) o.CookieSecret = "32 byte secret for AES-256------" assert.Equal(t, nil, o.Validate()) } func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) o.CookieSecret = "0123456789abcdefabcd" o.CookieRefresh = o.CookieExpire assert.NotEqual(t, nil, o.Validate()) o.CookieRefresh -= time.Duration(1) assert.Equal(t, nil, o.Validate()) } func TestBase64CookieSecret(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) // 32 byte, base64 (urlsafe) encoded key o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ=" assert.Equal(t, nil, o.Validate()) // 32 byte, base64 (urlsafe) encoded key, w/o padding o.CookieSecret = "yHBw2lh2Cvo6aI_jn_qMTr-pRAjtq0nzVgDJNb36jgQ" assert.Equal(t, nil, o.Validate()) // 24 byte, base64 (urlsafe) encoded key o.CookieSecret = "Kp33Gj-GQmYtz4zZUyUDdqQKx5_Hgkv3" assert.Equal(t, nil, o.Validate()) // 16 byte, base64 (urlsafe) encoded key o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA==" assert.Equal(t, nil, o.Validate()) // 16 byte, base64 (urlsafe) encoded key, w/o padding 
o.CookieSecret = "LFEqZYvYUwKwzn0tEuTpLA" assert.Equal(t, nil, o.Validate()) } func TestValidateSignatureKey(t *testing.T) { o := testOptions() o.SignatureKey = "sha1:secret" assert.Equal(t, nil, o.Validate()) assert.Equal(t, o.signatureData.hash, crypto.SHA1) assert.Equal(t, o.signatureData.key, "secret") } func TestValidateSignatureKeyInvalidSpec(t *testing.T) { o := testOptions() o.SignatureKey = "invalid spec" err := o.Validate() assert.Equal(t, err.Error(), "Invalid configuration:\n"+ " invalid signature hash:key spec: "+o.SignatureKey) } func TestValidateSignatureKeyUnsupportedAlgorithm(t *testing.T) { o := testOptions() o.SignatureKey = "unsupported:default secret" err := o.Validate() assert.Equal(t, err.Error(), "Invalid configuration:\n"+ " unsupported signature hash algorithm: "+o.SignatureKey) } func TestValidateCookie(t *testing.T) { o := testOptions() o.CookieName = "_valid_cookie_name" assert.Equal(t, nil, o.Validate()) } func TestValidateCookieBadName(t *testing.T) { o := testOptions() o.CookieName = "_bad_cookie_name{}" err := o.Validate() assert.Equal(t, err.Error(), "Invalid configuration:\n"+ fmt.Sprintf(" invalid cookie name: %q", o.CookieName)) } ================================================ FILE: providers/azure.go ================================================ package providers import ( "errors" "fmt" "github.com/bitly/go-simplejson" "github.com/bitly/oauth2_proxy/api" "log" "net/http" "net/url" ) type AzureProvider struct { *ProviderData Tenant string } func NewAzureProvider(p *ProviderData) *AzureProvider { p.ProviderName = "Azure" if p.ProfileURL == nil || p.ProfileURL.String() == "" { p.ProfileURL = &url.URL{ Scheme: "https", Host: "graph.windows.net", Path: "/me", RawQuery: "api-version=1.6", } } if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { p.ProtectedResource = &url.URL{ Scheme: "https", Host: "graph.windows.net", } } if p.Scope == "" { p.Scope = "openid" } return &AzureProvider{ProviderData: p} } func (p 
*AzureProvider) Configure(tenant string) { p.Tenant = tenant if tenant == "" { p.Tenant = "common" } if p.LoginURL == nil || p.LoginURL.String() == "" { p.LoginURL = &url.URL{ Scheme: "https", Host: "login.microsoftonline.com", Path: "/" + p.Tenant + "/oauth2/authorize"} } if p.RedeemURL == nil || p.RedeemURL.String() == "" { p.RedeemURL = &url.URL{ Scheme: "https", Host: "login.microsoftonline.com", Path: "/" + p.Tenant + "/oauth2/token", } } } func getAzureHeader(access_token string) http.Header { header := make(http.Header) header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) return header } func getEmailFromJSON(json *simplejson.Json) (string, error) { var email string var err error email, err = json.Get("mail").String() if err != nil || email == "" { otherMails, otherMailsErr := json.Get("otherMails").Array() if len(otherMails) > 0 { email = otherMails[0].(string) } err = otherMailsErr } return email, err } func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { var email string var err error if s.AccessToken == "" { return "", errors.New("missing access token") } req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) if err != nil { return "", err } req.Header = getAzureHeader(s.AccessToken) json, err := api.Request(req) if err != nil { return "", err } email, err = getEmailFromJSON(json) if err == nil && email != "" { return email, err } email, err = json.Get("userPrincipalName").String() if err != nil { log.Printf("failed making request %s", err) return "", err } if email == "" { log.Printf("failed to get email address") return "", err } return email, err } ================================================ FILE: providers/azure_test.go ================================================ package providers import ( "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" ) func testAzureProvider(hostname string) *AzureProvider { p := NewAzureProvider( &ProviderData{ ProviderName: "", 
LoginURL: &url.URL{}, RedeemURL: &url.URL{}, ProfileURL: &url.URL{}, ValidateURL: &url.URL{}, ProtectedResource: &url.URL{}, Scope: ""}) if hostname != "" { updateURL(p.Data().LoginURL, hostname) updateURL(p.Data().RedeemURL, hostname) updateURL(p.Data().ProfileURL, hostname) updateURL(p.Data().ValidateURL, hostname) updateURL(p.Data().ProtectedResource, hostname) } return p } func TestAzureProviderDefaults(t *testing.T) { p := testAzureProvider("") assert.NotEqual(t, nil, p) p.Configure("") assert.Equal(t, "Azure", p.Data().ProviderName) assert.Equal(t, "common", p.Tenant) assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/authorize", p.Data().LoginURL.String()) assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", p.Data().ProfileURL.String()) assert.Equal(t, "https://graph.windows.net", p.Data().ProtectedResource.String()) assert.Equal(t, "", p.Data().ValidateURL.String()) assert.Equal(t, "openid", p.Data().Scope) } func TestAzureProviderOverrides(t *testing.T) { p := NewAzureProvider( &ProviderData{ LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, ProfileURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, ProtectedResource: &url.URL{ Scheme: "https", Host: "example.com"}, Scope: "profile"}) assert.NotEqual(t, nil, p) assert.Equal(t, "Azure", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/oauth/profile", p.Data().ProfileURL.String()) assert.Equal(t, "https://example.com/oauth/tokeninfo", p.Data().ValidateURL.String()) assert.Equal(t, 
"https://example.com", p.Data().ProtectedResource.String()) assert.Equal(t, "profile", p.Data().Scope) } func TestAzureSetTenant(t *testing.T) { p := testAzureProvider("") p.Configure("example") assert.Equal(t, "Azure", p.Data().ProviderName) assert.Equal(t, "example", p.Tenant) assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/authorize", p.Data().LoginURL.String()) assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", p.Data().ProfileURL.String()) assert.Equal(t, "https://graph.windows.net", p.Data().ProtectedResource.String()) assert.Equal(t, "", p.Data().ValidateURL.String()) assert.Equal(t, "openid", p.Data().Scope) } func testAzureBackend(payload string) *httptest.Server { path := "/me" query := "api-version=1.6" return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { url := r.URL if url.Path != path || url.RawQuery != query { w.WriteHeader(404) } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { w.WriteHeader(403) } else { w.WriteHeader(200) w.Write([]byte(payload)) } })) } func TestAzureProviderGetEmailAddress(t *testing.T) { b := testAzureBackend(`{ "mail": "user@windows.net" }`) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) } func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) } func 
TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) } func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) } func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "", email) } func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`) defer b.Close() bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) } ================================================ FILE: providers/facebook.go ================================================ package providers import ( "errors" "fmt" "net/http" "net/url" "github.com/bitly/oauth2_proxy/api" ) type FacebookProvider struct { *ProviderData } func NewFacebookProvider(p 
*ProviderData) *FacebookProvider { p.ProviderName = "Facebook" if p.LoginURL.String() == "" { p.LoginURL = &url.URL{Scheme: "https", Host: "www.facebook.com", Path: "/v2.5/dialog/oauth", // ?granted_scopes=true } } if p.RedeemURL.String() == "" { p.RedeemURL = &url.URL{Scheme: "https", Host: "graph.facebook.com", Path: "/v2.5/oauth/access_token", } } if p.ProfileURL.String() == "" { p.ProfileURL = &url.URL{Scheme: "https", Host: "graph.facebook.com", Path: "/v2.5/me", } } if p.ValidateURL.String() == "" { p.ValidateURL = p.ProfileURL } if p.Scope == "" { p.Scope = "public_profile email" } return &FacebookProvider{ProviderData: p} } func getFacebookHeader(access_token string) http.Header { header := make(http.Header) header.Set("Accept", "application/json") header.Set("x-li-format", "json") header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) return header } func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) if err != nil { return "", err } req.Header = getFacebookHeader(s.AccessToken) type result struct { Email string } var r result err = api.RequestJson(req, &r) if err != nil { return "", err } if r.Email == "" { return "", errors.New("no email") } return r.Email, nil } func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) } ================================================ FILE: providers/github.go ================================================ package providers import ( "encoding/json" "fmt" "io/ioutil" "log" "net/http" "net/url" "path" "strconv" "strings" ) type GitHubProvider struct { *ProviderData Org string Team string } func NewGitHubProvider(p *ProviderData) *GitHubProvider { p.ProviderName = "GitHub" if p.LoginURL == nil || p.LoginURL.String() == "" { p.LoginURL = 
&url.URL{ Scheme: "https", Host: "github.com", Path: "/login/oauth/authorize", } } if p.RedeemURL == nil || p.RedeemURL.String() == "" { p.RedeemURL = &url.URL{ Scheme: "https", Host: "github.com", Path: "/login/oauth/access_token", } } // ValidationURL is the API Base URL if p.ValidateURL == nil || p.ValidateURL.String() == "" { p.ValidateURL = &url.URL{ Scheme: "https", Host: "api.github.com", Path: "/", } } if p.Scope == "" { p.Scope = "user:email" } return &GitHubProvider{ProviderData: p} } func (p *GitHubProvider) SetOrgTeam(org, team string) { p.Org = org p.Team = team if org != "" || team != "" { p.Scope += " read:org" } } func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { // https://developer.github.com/v3/orgs/#list-your-organizations var orgs []struct { Login string `json:"login"` } type orgsPage []struct { Login string `json:"login"` } pn := 1 for { params := url.Values{ "limit": {"200"}, "page": {strconv.Itoa(pn)}, } endpoint := &url.URL{ Scheme: p.ValidateURL.Scheme, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/orgs"), RawQuery: params.Encode(), } req, _ := http.NewRequest("GET", endpoint.String(), nil) req.Header.Set("Accept", "application/vnd.github.v3+json") req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { return false, err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return false, err } if resp.StatusCode != 200 { return false, fmt.Errorf( "got %d from %q %s", resp.StatusCode, endpoint.String(), body) } var op orgsPage if err := json.Unmarshal(body, &op); err != nil { return false, err } if len(op) == 0 { break } orgs = append(orgs, op...) 
pn += 1 } var presentOrgs []string for _, org := range orgs { if p.Org == org.Login { log.Printf("Found Github Organization: %q", org.Login) return true, nil } presentOrgs = append(presentOrgs, org.Login) } log.Printf("Missing Organization:%q in %v", p.Org, presentOrgs) return false, nil } func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { // https://developer.github.com/v3/orgs/teams/#list-user-teams var teams []struct { Name string `json:"name"` Slug string `json:"slug"` Org struct { Login string `json:"login"` } `json:"organization"` } params := url.Values{ "limit": {"200"}, } endpoint := &url.URL{ Scheme: p.ValidateURL.Scheme, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/teams"), RawQuery: params.Encode(), } req, _ := http.NewRequest("GET", endpoint.String(), nil) req.Header.Set("Accept", "application/vnd.github.v3+json") req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { return false, err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return false, err } if resp.StatusCode != 200 { return false, fmt.Errorf( "got %d from %q %s", resp.StatusCode, endpoint.String(), body) } if err := json.Unmarshal(body, &teams); err != nil { return false, fmt.Errorf("%s unmarshaling %s", err, body) } var hasOrg bool presentOrgs := make(map[string]bool) var presentTeams []string for _, team := range teams { presentOrgs[team.Org.Login] = true if p.Org == team.Org.Login { hasOrg = true ts := strings.Split(p.Team, ",") for _, t := range ts { if t == team.Slug { log.Printf("Found Github Organization:%q Team:%q (Name:%q)", team.Org.Login, team.Slug, team.Name) return true, nil } } presentTeams = append(presentTeams, team.Slug) } } if hasOrg { log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) } else { var allOrgs []string for org, _ := range presentOrgs { allOrgs = append(allOrgs, org) } 
log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) } return false, nil } func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { var emails []struct { Email string `json:"email"` Primary bool `json:"primary"` } // if we require an Org or Team, check that first if p.Org != "" { if p.Team != "" { if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { return "", err } } else { if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { return "", err } } } endpoint := &url.URL{ Scheme: p.ValidateURL.Scheme, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user/emails"), } req, _ := http.NewRequest("GET", endpoint.String(), nil) req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { return "", err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return "", err } if resp.StatusCode != 200 { return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) } log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) if err := json.Unmarshal(body, &emails); err != nil { return "", fmt.Errorf("%s unmarshaling %s", err, body) } for _, email := range emails { if email.Primary { return email.Email, nil } } return "", nil } func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { var user struct { Login string `json:"login"` Email string `json:"email"` } endpoint := &url.URL{ Scheme: p.ValidateURL.Scheme, Host: p.ValidateURL.Host, Path: path.Join(p.ValidateURL.Path, "/user"), } req, err := http.NewRequest("GET", endpoint.String(), nil) if err != nil { return "", fmt.Errorf("could not create new GET request: %v", err) } req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { return "", err } body, err := ioutil.ReadAll(resp.Body) defer resp.Body.Close() if err != nil { return "", err } if resp.StatusCode != 
200 { return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) } log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) if err := json.Unmarshal(body, &user); err != nil { return "", fmt.Errorf("%s unmarshaling %s", err, body) } return user.Login, nil } ================================================ FILE: providers/github_test.go ================================================ package providers import ( "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" ) func testGitHubProvider(hostname string) *GitHubProvider { p := NewGitHubProvider( &ProviderData{ ProviderName: "", LoginURL: &url.URL{}, RedeemURL: &url.URL{}, ProfileURL: &url.URL{}, ValidateURL: &url.URL{}, Scope: ""}) if hostname != "" { updateURL(p.Data().LoginURL, hostname) updateURL(p.Data().RedeemURL, hostname) updateURL(p.Data().ProfileURL, hostname) updateURL(p.Data().ValidateURL, hostname) } return p } func testGitHubBackend(payload []string) *httptest.Server { pathToQueryMap := map[string][]string{ "/user": []string{""}, "/user/emails": []string{""}, "/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, } return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { url := r.URL query, ok := pathToQueryMap[url.Path] validQuery := false index := 0 for i, q := range query { if q == url.RawQuery { validQuery = true index = i } } if !ok { w.WriteHeader(404) } else if !validQuery { w.WriteHeader(404) } else { w.WriteHeader(200) w.Write([]byte(payload[index])) } })) } func TestGitHubProviderDefaults(t *testing.T) { p := testGitHubProvider("") assert.NotEqual(t, nil, p) assert.Equal(t, "GitHub", p.Data().ProviderName) assert.Equal(t, "https://github.com/login/oauth/authorize", p.Data().LoginURL.String()) assert.Equal(t, "https://github.com/login/oauth/access_token", p.Data().RedeemURL.String()) assert.Equal(t, "https://api.github.com/", 
p.Data().ValidateURL.String()) assert.Equal(t, "user:email", p.Data().Scope) } func TestGitHubProviderOverrides(t *testing.T) { p := NewGitHubProvider( &ProviderData{ LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/login/oauth/authorize"}, RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/login/oauth/access_token"}, ValidateURL: &url.URL{ Scheme: "https", Host: "api.example.com", Path: "/"}, Scope: "profile"}) assert.NotEqual(t, nil, p) assert.Equal(t, "GitHub", p.Data().ProviderName) assert.Equal(t, "https://example.com/login/oauth/authorize", p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/login/oauth/access_token", p.Data().RedeemURL.String()) assert.Equal(t, "https://api.example.com/", p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } func TestGitHubProviderGetEmailAddress(t *testing.T) { b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`}) defer b.Close() bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { b := testGitHubBackend([]string{ `[ {"email": "michael.bland@gsa.gov", "primary": true, "login":"testorg"} ]`, `[ {"email": "michael.bland1@gsa.gov", "primary": true, "login":"testorg1"} ]`, `[ ]`, }) defer b.Close() bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) p.Org = "testorg1" session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } // Note that trying to trigger the "failed building request" case is not // practical, since the only way it can fail is if the URL fails to parse. 
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { b := testGitHubBackend([]string{"unused payload"}) defer b.Close() bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"}) defer b.Close() bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } func TestGitHubProviderGetUserName(t *testing.T) { b := testGitHubBackend([]string{`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}) defer b.Close() bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetUserName(session) assert.Equal(t, nil, err) assert.Equal(t, "mbland", email) } ================================================ FILE: providers/gitlab.go ================================================ package providers import ( "log" "net/http" "net/url" "github.com/bitly/oauth2_proxy/api" ) type GitLabProvider struct { *ProviderData } func NewGitLabProvider(p *ProviderData) *GitLabProvider { p.ProviderName = "GitLab" if p.LoginURL == nil || p.LoginURL.String() == "" { p.LoginURL = &url.URL{ Scheme: "https", Host: "gitlab.com", Path: "/oauth/authorize", } } if p.RedeemURL == nil || p.RedeemURL.String() == "" { p.RedeemURL = &url.URL{ Scheme: "https", Host: "gitlab.com", Path: "/oauth/token", } } if p.ValidateURL == nil || p.ValidateURL.String() == "" { p.ValidateURL = &url.URL{ Scheme: 
"https", Host: "gitlab.com", Path: "/api/v4/user", } } if p.Scope == "" { p.Scope = "read_user" } return &GitLabProvider{ProviderData: p} } func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { req, err := http.NewRequest("GET", p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) if err != nil { log.Printf("failed building request %s", err) return "", err } json, err := api.Request(req) if err != nil { log.Printf("failed making request %s", err) return "", err } return json.Get("email").String() } ================================================ FILE: providers/gitlab_test.go ================================================ package providers import ( "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" ) func testGitLabProvider(hostname string) *GitLabProvider { p := NewGitLabProvider( &ProviderData{ ProviderName: "", LoginURL: &url.URL{}, RedeemURL: &url.URL{}, ProfileURL: &url.URL{}, ValidateURL: &url.URL{}, Scope: ""}) if hostname != "" { updateURL(p.Data().LoginURL, hostname) updateURL(p.Data().RedeemURL, hostname) updateURL(p.Data().ProfileURL, hostname) updateURL(p.Data().ValidateURL, hostname) } return p } func testGitLabBackend(payload string) *httptest.Server { path := "/api/v4/user" query := "access_token=imaginary_access_token" return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { url := r.URL if url.Path != path || url.RawQuery != query { w.WriteHeader(404) } else { w.WriteHeader(200) w.Write([]byte(payload)) } })) } func TestGitLabProviderDefaults(t *testing.T) { p := testGitLabProvider("") assert.NotEqual(t, nil, p) assert.Equal(t, "GitLab", p.Data().ProviderName) assert.Equal(t, "https://gitlab.com/oauth/authorize", p.Data().LoginURL.String()) assert.Equal(t, "https://gitlab.com/oauth/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://gitlab.com/api/v4/user", p.Data().ValidateURL.String()) assert.Equal(t, "read_user", p.Data().Scope) } 
func TestGitLabProviderOverrides(t *testing.T) { p := NewGitLabProvider( &ProviderData{ LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/api/v4/user"}, Scope: "profile"}) assert.NotEqual(t, nil, p) assert.Equal(t, "GitLab", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/api/v4/user", p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } func TestGitLabProviderGetEmailAddress(t *testing.T) { b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") defer b.Close() b_url, _ := url.Parse(b.URL) p := testGitLabProvider(b_url.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } // Note that trying to trigger the "failed building request" case is not // practical, since the only way it can fail is if the URL fails to parse. func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { b := testGitLabBackend("unused payload") defer b.Close() b_url, _ := url.Parse(b.URL) p := testGitLabProvider(b_url.Host) // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. 
session := &SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b := testGitLabBackend("{\"foo\": \"bar\"}") defer b.Close() b_url, _ := url.Parse(b.URL) p := testGitLabProvider(b_url.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } ================================================ FILE: providers/google.go ================================================ package providers import ( "bytes" "encoding/base64" "encoding/json" "errors" "fmt" "io" "io/ioutil" "log" "net/http" "net/url" "strings" "time" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/admin/directory/v1" "google.golang.org/api/googleapi" ) type GoogleProvider struct { *ProviderData RedeemRefreshURL *url.URL // GroupValidator is a function that determines if the passed email is in // the configured Google group. GroupValidator func(string) bool } func NewGoogleProvider(p *ProviderData) *GoogleProvider { p.ProviderName = "Google" if p.LoginURL.String() == "" { p.LoginURL = &url.URL{Scheme: "https", Host: "accounts.google.com", Path: "/o/oauth2/auth", // to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline RawQuery: "access_type=offline", } } if p.RedeemURL.String() == "" { p.RedeemURL = &url.URL{Scheme: "https", Host: "www.googleapis.com", Path: "/oauth2/v3/token"} } if p.ValidateURL.String() == "" { p.ValidateURL = &url.URL{Scheme: "https", Host: "www.googleapis.com", Path: "/oauth2/v1/tokeninfo"} } if p.Scope == "" { p.Scope = "profile email" } return &GoogleProvider{ ProviderData: p, // Set a default GroupValidator to just always return valid (true), it will // be overwritten if we configured a Google group restriction. 
GroupValidator: func(email string) bool { return true }, } } func emailFromIdToken(idToken string) (string, error) { // id_token is a base64 encode ID token payload // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo jwt := strings.Split(idToken, ".") jwtData := strings.TrimSuffix(jwt[1], "=") b, err := base64.RawURLEncoding.DecodeString(jwtData) if err != nil { return "", err } var email struct { Email string `json:"email"` EmailVerified bool `json:"email_verified"` } err = json.Unmarshal(b, &email) if err != nil { return "", err } if email.Email == "" { return "", errors.New("missing email") } if !email.EmailVerified { return "", fmt.Errorf("email %s not listed as verified", email.Email) } return email.Email, nil } func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { if code == "" { err = errors.New("missing code") return } params := url.Values{} params.Add("redirect_uri", redirectURL) params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") var req *http.Request req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := http.DefaultClient.Do(req) if err != nil { return } var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return } if resp.StatusCode != 200 { err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) return } var jsonResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` IdToken string `json:"id_token"` } err = json.Unmarshal(body, &jsonResponse) if err != nil { return } var email string email, err = emailFromIdToken(jsonResponse.IdToken) if err != nil { return } s = &SessionState{ AccessToken: 
jsonResponse.AccessToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), RefreshToken: jsonResponse.RefreshToken, Email: email, } return } // SetGroupRestriction configures the GoogleProvider to restrict access to the // specified group(s). AdminEmail has to be an administrative email on the domain that is // checked. CredentialsFile is the path to a json file containing a Google service // account credentials. func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { adminService := getAdminService(adminEmail, credentialsReader) p.GroupValidator = func(email string) bool { return userInGroup(adminService, groups, email) } } func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Service { data, err := ioutil.ReadAll(credentialsReader) if err != nil { log.Fatal("can't read Google credentials file:", err) } conf, err := google.JWTConfigFromJSON(data, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) if err != nil { log.Fatal("can't load Google credentials file:", err) } conf.Subject = adminEmail client := conf.Client(oauth2.NoContext) adminService, err := admin.New(client) if err != nil { log.Fatal(err) } return adminService } func userInGroup(service *admin.Service, groups []string, email string) bool { user, err := fetchUser(service, email) if err != nil { log.Printf("error fetching user: %v", err) return false } id := user.Id custID := user.CustomerId for _, group := range groups { members, err := fetchGroupMembers(service, group) if err != nil { if err, ok := err.(*googleapi.Error); ok && err.Code == 404 { log.Printf("error fetching members for group %s: group does not exist", group) } else { log.Printf("error fetching group members: %v", err) return false } } for _, member := range members { switch member.Type { case "CUSTOMER": if member.Id == custID { return true } case "USER": if member.Id == id { 
return true } } } } return false } func fetchUser(service *admin.Service, email string) (*admin.User, error) { user, err := service.Users.Get(email).Do() return user, err } func fetchGroupMembers(service *admin.Service, group string) ([]*admin.Member, error) { members := []*admin.Member{} pageToken := "" for { req := service.Members.List(group) if pageToken != "" { req.PageToken(pageToken) } r, err := req.Do() if err != nil { return nil, err } for _, member := range r.Members { members = append(members, member) } if r.NextPageToken == "" { break } pageToken = r.NextPageToken } return members, nil } // ValidateGroup validates that the provided email exists in the configured Google // group(s). func (p *GoogleProvider) ValidateGroup(email string) bool { return p.GroupValidator(email) } func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } newToken, duration, err := p.redeemRefreshToken(s.RefreshToken) if err != nil { return false, err } // re-check that the user is in the proper google group(s) if !p.ValidateGroup(s.Email) { return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) } origExpiration := s.ExpiresOn s.AccessToken = newToken s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) log.Printf("refreshed access token %s (expired on %s)", s, origExpiration) return true, nil } func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh params := url.Values{} params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) params.Add("refresh_token", refreshToken) params.Add("grant_type", "refresh_token") var req *http.Request req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") resp, err := http.DefaultClient.Do(req) if err != nil { return } var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return } if resp.StatusCode != 200 { err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) return } var data struct { AccessToken string `json:"access_token"` ExpiresIn int64 `json:"expires_in"` } err = json.Unmarshal(body, &data) if err != nil { return } token = data.AccessToken expires = time.Duration(data.ExpiresIn) * time.Second return } ================================================ FILE: providers/google_test.go ================================================ package providers import ( "encoding/base64" "encoding/json" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" ) func newRedeemServer(body []byte) (*url.URL, *httptest.Server) { s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Write(body) })) u, _ := url.Parse(s.URL) return u, s } func newGoogleProvider() *GoogleProvider { return NewGoogleProvider( &ProviderData{ ProviderName: "", LoginURL: &url.URL{}, RedeemURL: &url.URL{}, ProfileURL: &url.URL{}, ValidateURL: &url.URL{}, Scope: ""}) } func TestGoogleProviderDefaults(t *testing.T) { p := newGoogleProvider() assert.NotEqual(t, nil, p) assert.Equal(t, "Google", p.Data().ProviderName) assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", p.Data().LoginURL.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v1/tokeninfo", p.Data().ValidateURL.String()) assert.Equal(t, "", p.Data().ProfileURL.String()) assert.Equal(t, "profile email", p.Data().Scope) } func TestGoogleProviderOverrides(t *testing.T) { p := NewGoogleProvider( &ProviderData{ LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, RedeemURL: 
&url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, ProfileURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, Scope: "profile"}) assert.NotEqual(t, nil, p) assert.Equal(t, "Google", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/oauth/profile", p.Data().ProfileURL.String()) assert.Equal(t, "https://example.com/oauth/tokeninfo", p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } type redeemResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` IdToken string `json:"id_token"` } func TestGoogleProviderGetEmailAddress(t *testing.T) { p := newGoogleProvider() body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", ExpiresIn: 10, RefreshToken: "refresh12345", IdToken: "ignored prefix." 
+ base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), }) assert.Equal(t, nil, err) var server *httptest.Server p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") assert.Equal(t, nil, err) assert.NotEqual(t, session, nil) assert.Equal(t, "michael.bland@gsa.gov", session.Email) assert.Equal(t, "a1234", session.AccessToken) assert.Equal(t, "refresh12345", session.RefreshToken) } func TestGoogleProviderValidateGroup(t *testing.T) { p := newGoogleProvider() p.GroupValidator = func(email string) bool { return email == "michael.bland@gsa.gov" } assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) p.GroupValidator = func(email string) bool { return email != "michael.bland@gsa.gov" } assert.Equal(t, false, p.ValidateGroup("michael.bland@gsa.gov")) } func TestGoogleProviderWithoutValidateGroup(t *testing.T) { p := newGoogleProvider() assert.Equal(t, true, p.ValidateGroup("michael.bland@gsa.gov")) } // func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p := newGoogleProvider() body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, }) assert.Equal(t, nil, err) var server *httptest.Server p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) } } func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { p := newGoogleProvider() body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", IdToken: "ignored prefix." 
+ base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), }) assert.Equal(t, nil, err) var server *httptest.Server p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) } } func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { p := newGoogleProvider() body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), }) assert.Equal(t, nil, err) var server *httptest.Server p.RedeemURL, server = newRedeemServer(body) defer server.Close() session, err := p.Redeem("http://redirect/", "code1234") assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expect nill session %#v", session) } } ================================================ FILE: providers/internal_util.go ================================================ package providers import ( "io/ioutil" "log" "net/http" "net/url" "github.com/bitly/oauth2_proxy/api" ) // stripToken is a helper function to obfuscate "access_token" // query parameters func stripToken(endpoint string) string { return stripParam("access_token", endpoint) } // stripParam generalizes the obfuscation of a particular // query parameter - typically 'access_token' or 'client_secret' // The parameter's second half is replaced by '...' and returned // as part of the encoded query parameters. // If the target parameter isn't found, the endpoint is returned // unmodified. 
func stripParam(param, endpoint string) string { u, err := url.Parse(endpoint) if err != nil { log.Printf("error attempting to strip %s: %s", param, err) return endpoint } if u.RawQuery != "" { values, err := url.ParseQuery(u.RawQuery) if err != nil { log.Printf("error attempting to strip %s: %s", param, err) return u.String() } if val := values.Get(param); val != "" { values.Set(param, val[:(len(val)/2)]+"...") u.RawQuery = values.Encode() return u.String() } } return endpoint } // validateToken returns true if token is valid func validateToken(p Provider, access_token string, header http.Header) bool { if access_token == "" || p.Data().ValidateURL == nil { return false } endpoint := p.Data().ValidateURL.String() if len(header) == 0 { params := url.Values{"access_token": {access_token}} endpoint = endpoint + "?" + params.Encode() } resp, err := api.RequestUnparsedResponse(endpoint, header) if err != nil { log.Printf("GET %s", stripToken(endpoint)) log.Printf("token validation request failed: %s", err) return false } body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() log.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) if resp.StatusCode == 200 { return true } log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) return false } func updateURL(url *url.URL, hostname string) { url.Scheme = "http" url.Host = hostname } ================================================ FILE: providers/internal_util_test.go ================================================ package providers import ( "errors" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" ) type ValidateSessionStateTestProvider struct { *ProviderData } func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { return "", errors.New("not implemented") } // Note that we're testing the internal validateToken() used to implement // several Provider's ValidateSessionState() implementations func (tp 
*ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { return false } type ValidateSessionStateTest struct { backend *httptest.Server response_code int provider *ValidateSessionStateTestProvider header http.Header } func NewValidateSessionStateTest() *ValidateSessionStateTest { var vt_test ValidateSessionStateTest vt_test.backend = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/oauth/tokeninfo" { w.WriteHeader(500) w.Write([]byte("unknown URL")) } token_param := r.FormValue("access_token") if token_param == "" { missing := false received_headers := r.Header for k, _ := range vt_test.header { received := received_headers.Get(k) expected := vt_test.header.Get(k) if received == "" || received != expected { missing = true } } if missing { w.WriteHeader(500) w.Write([]byte("no token param and missing or incorrect headers")) } } w.WriteHeader(vt_test.response_code) w.Write([]byte("only code matters; contents disregarded")) })) backend_url, _ := url.Parse(vt_test.backend.URL) vt_test.provider = &ValidateSessionStateTestProvider{ ProviderData: &ProviderData{ ValidateURL: &url.URL{ Scheme: "http", Host: backend_url.Host, Path: "/oauth/tokeninfo", }, }, } vt_test.response_code = 200 return &vt_test } func (vt_test *ValidateSessionStateTest) Close() { vt_test.backend.Close() } func TestValidateSessionStateValidToken(t *testing.T) { vt_test := NewValidateSessionStateTest() defer vt_test.Close() assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) } func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { vt_test := NewValidateSessionStateTest() defer vt_test.Close() vt_test.header = make(http.Header) vt_test.header.Set("Authorization", "Bearer foobar") assert.Equal(t, true, validateToken(vt_test.provider, "foobar", vt_test.header)) } func TestValidateSessionStateEmptyToken(t *testing.T) { vt_test := NewValidateSessionStateTest() defer vt_test.Close() assert.Equal(t, 
false, validateToken(vt_test.provider, "", nil)) } func TestValidateSessionStateEmptyValidateURL(t *testing.T) { vt_test := NewValidateSessionStateTest() defer vt_test.Close() vt_test.provider.Data().ValidateURL = nil assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) } func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { vt_test := NewValidateSessionStateTest() // Close immediately to simulate a network failure vt_test.Close() assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) } func TestValidateSessionStateExpiredToken(t *testing.T) { vt_test := NewValidateSessionStateTest() defer vt_test.Close() vt_test.response_code = 401 assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) } func TestStripTokenNotPresent(t *testing.T) { test := "http://local.test/api/test?a=1&b=2" assert.Equal(t, test, stripToken(test)) } func TestStripToken(t *testing.T) { test := "http://local.test/api/test?access_token=deadbeef&b=1&c=2" expected := "http://local.test/api/test?access_token=dead...&b=1&c=2" assert.Equal(t, expected, stripToken(test)) } ================================================ FILE: providers/linkedin.go ================================================ package providers import ( "errors" "fmt" "net/http" "net/url" "github.com/bitly/oauth2_proxy/api" ) type LinkedInProvider struct { *ProviderData } func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { p.ProviderName = "LinkedIn" if p.LoginURL.String() == "" { p.LoginURL = &url.URL{Scheme: "https", Host: "www.linkedin.com", Path: "/uas/oauth2/authorization"} } if p.RedeemURL.String() == "" { p.RedeemURL = &url.URL{Scheme: "https", Host: "www.linkedin.com", Path: "/uas/oauth2/accessToken"} } if p.ProfileURL.String() == "" { p.ProfileURL = &url.URL{Scheme: "https", Host: "www.linkedin.com", Path: "/v1/people/~/email-address"} } if p.ValidateURL.String() == "" { p.ValidateURL = p.ProfileURL } if p.Scope == "" { p.Scope = "r_emailaddress 
r_basicprofile" } return &LinkedInProvider{ProviderData: p} } func getLinkedInHeader(access_token string) http.Header { header := make(http.Header) header.Set("Accept", "application/json") header.Set("x-li-format", "json") header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) return header } func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) if err != nil { return "", err } req.Header = getLinkedInHeader(s.AccessToken) json, err := api.Request(req) if err != nil { return "", err } email, err := json.String() if err != nil { return "", err } return email, nil } func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) } ================================================ FILE: providers/linkedin_test.go ================================================ package providers import ( "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" ) func testLinkedInProvider(hostname string) *LinkedInProvider { p := NewLinkedInProvider( &ProviderData{ ProviderName: "", LoginURL: &url.URL{}, RedeemURL: &url.URL{}, ProfileURL: &url.URL{}, ValidateURL: &url.URL{}, Scope: ""}) if hostname != "" { updateURL(p.Data().LoginURL, hostname) updateURL(p.Data().RedeemURL, hostname) updateURL(p.Data().ProfileURL, hostname) } return p } func testLinkedInBackend(payload string) *httptest.Server { path := "/v1/people/~/email-address" return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { url := r.URL if url.Path != path { w.WriteHeader(404) } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { w.WriteHeader(403) } else { w.WriteHeader(200) w.Write([]byte(payload)) } })) } func TestLinkedInProviderDefaults(t *testing.T) { p := 
testLinkedInProvider("") assert.NotEqual(t, nil, p) assert.Equal(t, "LinkedIn", p.Data().ProviderName) assert.Equal(t, "https://www.linkedin.com/uas/oauth2/authorization", p.Data().LoginURL.String()) assert.Equal(t, "https://www.linkedin.com/uas/oauth2/accessToken", p.Data().RedeemURL.String()) assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", p.Data().ProfileURL.String()) assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", p.Data().ValidateURL.String()) assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope) } func TestLinkedInProviderOverrides(t *testing.T) { p := NewLinkedInProvider( &ProviderData{ LoginURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/auth"}, RedeemURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/token"}, ProfileURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", Path: "/oauth/tokeninfo"}, Scope: "profile"}) assert.NotEqual(t, nil, p) assert.Equal(t, "LinkedIn", p.Data().ProviderName) assert.Equal(t, "https://example.com/oauth/auth", p.Data().LoginURL.String()) assert.Equal(t, "https://example.com/oauth/token", p.Data().RedeemURL.String()) assert.Equal(t, "https://example.com/oauth/profile", p.Data().ProfileURL.String()) assert.Equal(t, "https://example.com/oauth/tokeninfo", p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } func TestLinkedInProviderGetEmailAddress(t *testing.T) { b := testLinkedInBackend(`"user@linkedin.com"`) defer b.Close() b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@linkedin.com", email) } func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { b := testLinkedInBackend("unused payload") defer b.Close() b_url, _ := url.Parse(b.URL) p := 
testLinkedInProvider(b_url.Host) // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b := testLinkedInBackend("{\"foo\": \"bar\"}") defer b.Close() b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } ================================================ FILE: providers/oidc.go ================================================ package providers import ( "context" "fmt" "time" "golang.org/x/oauth2" oidc "github.com/coreos/go-oidc" ) type OIDCProvider struct { *ProviderData Verifier *oidc.IDTokenVerifier } func NewOIDCProvider(p *ProviderData) *OIDCProvider { p.ProviderName = "OpenID Connect" return &OIDCProvider{ProviderData: p} } func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, ClientSecret: p.ClientSecret, Endpoint: oauth2.Endpoint{ TokenURL: p.RedeemURL.String(), }, RedirectURL: redirectURL, } token, err := c.Exchange(ctx, code) if err != nil { return nil, fmt.Errorf("token exchange: %v", err) } rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("token response did not contain an id_token") } // Parse and verify ID Token payload. idToken, err := p.Verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("could not verify id_token: %v", err) } // Extract custom claims. 
var claims struct { Email string `json:"email"` Verified *bool `json:"email_verified"` } if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse id_token claims: %v", err) } if claims.Email == "" { return nil, fmt.Errorf("id_token did not contain an email") } if claims.Verified != nil && !*claims.Verified { return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } s = &SessionState{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, ExpiresOn: token.Expiry, Email: claims.Email, } return } func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } origExpiration := s.ExpiresOn s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second) fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) return false, nil } ================================================ FILE: providers/provider_data.go ================================================ package providers import ( "net/url" ) type ProviderData struct { ProviderName string ClientID string ClientSecret string LoginURL *url.URL RedeemURL *url.URL ProfileURL *url.URL ProtectedResource *url.URL ValidateURL *url.URL Scope string ApprovalPrompt string } func (p *ProviderData) Data() *ProviderData { return p } ================================================ FILE: providers/provider_default.go ================================================ package providers import ( "bytes" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "net/url" "github.com/bitly/oauth2_proxy/cookie" ) func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { if code == "" { err = errors.New("missing code") return } params := url.Values{} params.Add("redirect_uri", redirectURL) params.Add("client_id", p.ClientID) params.Add("client_secret", p.ClientSecret) params.Add("code", code) params.Add("grant_type", 
"authorization_code") if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { params.Add("resource", p.ProtectedResource.String()) } var req *http.Request req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") var resp *http.Response resp, err = http.DefaultClient.Do(req) if err != nil { return nil, err } var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return } if resp.StatusCode != 200 { err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) return } // blindly try json and x-www-form-urlencoded var jsonResponse struct { AccessToken string `json:"access_token"` } err = json.Unmarshal(body, &jsonResponse) if err == nil { s = &SessionState{ AccessToken: jsonResponse.AccessToken, } return } var v url.Values v, err = url.ParseQuery(string(body)) if err != nil { return } if a := v.Get("access_token"); a != "" { s = &SessionState{AccessToken: a} } else { err = fmt.Errorf("no access token found %s", body) } return } // GetLoginURL with typical oauth parameters func (p *ProviderData) GetLoginURL(redirectURI, state string) string { var a url.URL a = *p.LoginURL params, _ := url.ParseQuery(a.RawQuery) params.Set("redirect_uri", redirectURI) params.Set("approval_prompt", p.ApprovalPrompt) params.Add("scope", p.Scope) params.Set("client_id", p.ClientID) params.Set("response_type", "code") params.Add("state", state) a.RawQuery = params.Encode() return a.String() } // CookieForSession serializes a session state for storage in a cookie func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { return s.EncodeSessionState(c) } // SessionFromCookie deserializes a session from a cookie value func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { return DecodeSessionState(v, c) } func (p 
*ProviderData) GetEmailAddress(s *SessionState) (string, error) { return "", errors.New("not implemented") } // GetUserName returns the Account username func (p *ProviderData) GetUserName(s *SessionState) (string, error) { return "", errors.New("not implemented") } // ValidateGroup validates that the provided email exists in the configured provider // email group(s). func (p *ProviderData) ValidateGroup(email string) bool { return true } func (p *ProviderData) ValidateSessionState(s *SessionState) bool { return validateToken(p, s.AccessToken, nil) } // RefreshSessionIfNeeded func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { return false, nil } ================================================ FILE: providers/provider_default_test.go ================================================ package providers import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestRefresh(t *testing.T) { p := &ProviderData{} refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), }) assert.Equal(t, false, refreshed) assert.Equal(t, nil, err) } ================================================ FILE: providers/providers.go ================================================ package providers import ( "github.com/bitly/oauth2_proxy/cookie" ) type Provider interface { Data() *ProviderData GetEmailAddress(*SessionState) (string, error) GetUserName(*SessionState) (string, error) Redeem(string, string) (*SessionState, error) ValidateGroup(string) bool ValidateSessionState(*SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string RefreshSessionIfNeeded(*SessionState) (bool, error) SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) CookieForSession(*SessionState, *cookie.Cipher) (string, error) } func New(provider string, p *ProviderData) Provider { switch provider { case "linkedin": return NewLinkedInProvider(p) case "facebook": return NewFacebookProvider(p) case 
"github": return NewGitHubProvider(p) case "azure": return NewAzureProvider(p) case "gitlab": return NewGitLabProvider(p) case "oidc": return NewOIDCProvider(p) default: return NewGoogleProvider(p) } } ================================================ FILE: providers/session_state.go ================================================ package providers import ( "fmt" "strconv" "strings" "time" "github.com/bitly/oauth2_proxy/cookie" ) type SessionState struct { AccessToken string ExpiresOn time.Time RefreshToken string Email string User string } func (s *SessionState) IsExpired() bool { if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { return true } return false } func (s *SessionState) String() string { o := fmt.Sprintf("Session{%s", s.accountInfo()) if s.AccessToken != "" { o += " token:true" } if !s.ExpiresOn.IsZero() { o += fmt.Sprintf(" expires:%s", s.ExpiresOn) } if s.RefreshToken != "" { o += " refresh_token:true" } return o + "}" } func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { if c == nil || s.AccessToken == "" { return s.accountInfo(), nil } return s.EncryptedString(c) } func (s *SessionState) accountInfo() string { return fmt.Sprintf("email:%s user:%s", s.Email, s.User) } func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { var err error if c == nil { panic("error. 
missing cipher") } a := s.AccessToken if a != "" { if a, err = c.Encrypt(a); err != nil { return "", err } } r := s.RefreshToken if r != "" { if r, err = c.Encrypt(r); err != nil { return "", err } } return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil } func decodeSessionStatePlain(v string) (s *SessionState, err error) { chunks := strings.Split(v, " ") if len(chunks) != 2 { return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) } email := strings.TrimPrefix(chunks[0], "email:") user := strings.TrimPrefix(chunks[1], "user:") if user == "" { user = strings.Split(email, "@")[0] } return &SessionState{User: user, Email: email}, nil } func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { if c == nil { return decodeSessionStatePlain(v) } chunks := strings.Split(v, "|") if len(chunks) != 4 { err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) return } sessionState, err := decodeSessionStatePlain(chunks[0]) if err != nil { return nil, err } if chunks[1] != "" { if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { return nil, err } } ts, _ := strconv.Atoi(chunks[2]) sessionState.ExpiresOn = time.Unix(int64(ts), 0) if chunks[3] != "" { if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { return nil, err } } return sessionState, nil } ================================================ FILE: providers/session_state_test.go ================================================ package providers import ( "fmt" "strings" "testing" "time" "github.com/bitly/oauth2_proxy/cookie" "github.com/stretchr/testify/assert" ) const secret = "0123456789abcdefghijklmnopqrstuv" const altSecret = "0000000000abcdefghijklmnopqrstuv" func TestSessionStateSerialization(t *testing.T) { c, err := cookie.NewCipher([]byte(secret)) assert.Equal(t, nil, err) c2, err := cookie.NewCipher([]byte(altSecret)) assert.Equal(t, nil, err) s := &SessionState{ 
Email: "user@domain.com", AccessToken: "token1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) assert.Equal(t, 3, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) } func TestSessionStateSerializationWithUser(t *testing.T) { c, err := cookie.NewCipher([]byte(secret)) assert.Equal(t, nil, err) c2, err := cookie.NewCipher([]byte(altSecret)) assert.Equal(t, nil, err) s := &SessionState{ User: "just-user", Email: "user@domain.com", AccessToken: "token1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) assert.Equal(t, 3, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, 
s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) } func TestSessionStateSerializationNoCipher(t *testing.T) { s := &SessionState{ Email: "user@domain.com", AccessToken: "token1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) expected := fmt.Sprintf("email:%s user:", s.Email) assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, "", ss.AccessToken) assert.Equal(t, "", ss.RefreshToken) } func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { s := &SessionState{ User: "just-user", Email: "user@domain.com", AccessToken: "token1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User) assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, "", ss.AccessToken) assert.Equal(t, "", ss.RefreshToken) } func TestSessionStateAccountInfo(t *testing.T) { s := &SessionState{ Email: "user@domain.com", User: "just-user", } expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User) assert.Equal(t, expected, s.accountInfo()) s.Email = "" expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User) assert.Equal(t, expected, s.accountInfo()) } func TestExpired(t *testing.T) { s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} assert.Equal(t, true, s.IsExpired()) s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * 
time.Minute)} assert.Equal(t, false, s.IsExpired()) s = &SessionState{} assert.Equal(t, false, s.IsExpired()) } ================================================ FILE: string_array.go ================================================ package main import ( "strings" ) type StringArray []string func (a *StringArray) Set(s string) error { *a = append(*a, s) return nil } func (a *StringArray) String() string { return strings.Join(*a, ",") } ================================================ FILE: templates.go ================================================ package main import ( "html/template" "log" "path" ) func loadTemplates(dir string) *template.Template { if dir == "" { return getTemplates() } log.Printf("using custom template directory %q", dir) t, err := template.New("").ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html")) if err != nil { log.Fatalf("failed parsing template %s", err) } return t } func getTemplates() *template.Template { t, err := template.New("foo").Parse(`{{define "sign_in.html"}} Sign In {{ if .CustomLogin }} {{ end }}
{{ if eq .Footer "-" }} {{ else if eq .Footer ""}} Secured with OAuth2 Proxy version {{.Version}} {{ else }} {{.Footer}} {{ end }}
{{end}}`) if err != nil { log.Fatalf("failed parsing template %s", err) } t, err = t.Parse(`{{define "error.html"}} {{.Title}}

{{.Title}}

{{.Message}}


Sign In

{{end}}`) if err != nil { log.Fatalf("failed parsing template %s", err) } return t } ================================================ FILE: templates_test.go ================================================ package main import ( "testing" "github.com/stretchr/testify/assert" ) func TestTemplatesCompile(t *testing.T) { templates := getTemplates() assert.NotEqual(t, templates, nil) } ================================================ FILE: test.sh ================================================ #!/bin/bash EXIT_CODE=0 echo "gofmt" diff -u <(echo -n) <(gofmt -d $(find . -type f -name '*.go' -not -path "./vendor/*")) || EXIT_CODE=1 for pkg in $(go list ./... | grep -v '/vendor/' ); do echo "testing $pkg" echo "go vet $pkg" go vet "$pkg" || EXIT_CODE=1 echo "go test -v $pkg" go test -v -timeout 90s "$pkg" || EXIT_CODE=1 echo "go test -v -race $pkg" GOMAXPROCS=4 go test -v -timeout 90s0s -race "$pkg" || EXIT_CODE=1 done exit $EXIT_CODE ================================================ FILE: validator.go ================================================ package main import ( "encoding/csv" "fmt" "log" "os" "strings" "sync/atomic" "unsafe" ) type UserMap struct { usersFile string m unsafe.Pointer } func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { um := &UserMap{usersFile: usersFile} m := make(map[string]bool) atomic.StorePointer(&um.m, unsafe.Pointer(&m)) if usersFile != "" { log.Printf("using authenticated emails file %s", usersFile) WatchForUpdates(usersFile, done, func() { um.LoadAuthenticatedEmailsFile() onUpdate() }) um.LoadAuthenticatedEmailsFile() } return um } func (um *UserMap) IsValid(email string) (result bool) { m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) _, result = m[email] return } func (um *UserMap) LoadAuthenticatedEmailsFile() { r, err := os.Open(um.usersFile) if err != nil { log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) } defer r.Close() csv_reader := csv.NewReader(r) csv_reader.Comma 
= ',' csv_reader.Comment = '#' csv_reader.TrimLeadingSpace = true records, err := csv_reader.ReadAll() if err != nil { log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) return } updated := make(map[string]bool) for _, r := range records { address := strings.ToLower(strings.TrimSpace(r[0])) updated[address] = true } atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) } func newValidatorImpl(domains []string, usersFile string, done <-chan bool, onUpdate func()) func(string) bool { validUsers := NewUserMap(usersFile, done, onUpdate) var allowAll bool for i, domain := range domains { if domain == "*" { allowAll = true continue } domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) } validator := func(email string) (valid bool) { if email == "" { return } email = strings.ToLower(email) for _, domain := range domains { valid = valid || strings.HasSuffix(email, domain) } if !valid { valid = validUsers.IsValid(email) } if allowAll { valid = true } return valid } return validator } func NewValidator(domains []string, usersFile string) func(string) bool { return newValidatorImpl(domains, usersFile, nil, func() {}) } ================================================ FILE: validator_test.go ================================================ package main import ( "io/ioutil" "os" "strings" "testing" ) type ValidatorTest struct { auth_email_file *os.File done chan bool update_seen bool } func NewValidatorTest(t *testing.T) *ValidatorTest { vt := &ValidatorTest{} var err error vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file: " + err.Error()) } vt.done = make(chan bool, 1) return vt } func (vt *ValidatorTest) TearDown() { vt.done <- true os.Remove(vt.auth_email_file.Name()) } func (vt *ValidatorTest) NewValidator(domains []string, updated chan<- bool) func(string) bool { return newValidatorImpl(domains, vt.auth_email_file.Name(), vt.done, func() { if vt.update_seen == false { 
updated <- true vt.update_seen = true } }) } // This will close vt.auth_email_file. func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { defer vt.auth_email_file.Close() vt.auth_email_file.WriteString(strings.Join(emails, "\n")) if err := vt.auth_email_file.Close(); err != nil { t.Fatal("failed to close temp file " + vt.auth_email_file.Name() + ": " + err.Error()) } } func TestValidatorEmpty(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string(nil)) domains := []string(nil) validator := vt.NewValidator(domains, nil) if validator("foo.bar@example.com") { t.Error("nothing should validate when the email and " + "domain lists are empty") } } func TestValidatorSingleEmail(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{"foo.bar@example.com"}) domains := []string(nil) validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") } if validator("baz.quux@example.com") { t.Error("email from same domain but not in list " + "should not validate when domain list is empty") } } func TestValidatorSingleDomain(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string(nil)) domains := []string{"example.com"} validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") } if !validator("baz.quux@example.com") { t.Error("email from same domain should validate") } } func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{ "xyzzy@example.com", "plugh@example.com", }) domains := []string{"example0.com", "example1.com"} validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example0.com") { t.Error("email from first domain should validate") } if !validator("baz.quux@example1.com") { t.Error("email from second domain should validate") } if 
!validator("xyzzy@example.com") { t.Error("first email in list should validate") } if !validator("plugh@example.com") { t.Error("second email in list should validate") } if validator("xyzzy.plugh@example.com") { t.Error("email not in list that matches no domains " + "should not validate") } } func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) domains := []string{"Frobozz.Com"} validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("loaded email addresses are not lower-cased") } if !validator("Foo.Bar@Example.Com") { t.Error("validated email addresses are not lower-cased") } if !validator("foo.bar@frobozz.com") { t.Error("loaded domains are not lower-cased") } if !validator("foo.bar@Frobozz.Com") { t.Error("validated domains are not lower-cased") } } func TestValidatorIgnoreSpacesInAuthEmails(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{" foo.bar@example.com "}) domains := []string(nil) validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") } } ================================================ FILE: validator_watcher_copy_test.go ================================================ // +build go1.3,!plan9,!solaris,!windows // Turns out you can't copy over an existing file on Windows. 
package main import ( "io/ioutil" "os" "testing" ) func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( t *testing.T, emails []string) { orig_file := vt.auth_email_file var err error vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file for copy: " + err.Error()) } vt.WriteEmails(t, emails) err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) if err != nil { t.Fatal("failed to copy over temp file: " + err.Error()) } vt.auth_email_file = orig_file } func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{"xyzzy@example.com"}) domains := []string(nil) updated := make(chan bool) validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("email in list should validate") } vt.UpdateEmailFileViaCopyingOver(t, []string{"plugh@example.com"}) <-updated if validator("xyzzy@example.com") { t.Error("email removed from list should not validate") } } ================================================ FILE: validator_watcher_test.go ================================================ // +build go1.3,!plan9,!solaris package main import ( "io/ioutil" "os" "testing" ) func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { var err error vt.auth_email_file, err = os.OpenFile( vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) if err != nil { t.Fatal("failed to re-open temp file for updates") } vt.WriteEmails(t, emails) } func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( t *testing.T, emails []string) { orig_file := vt.auth_email_file var err error vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file for rename and replace: " + err.Error()) } vt.WriteEmails(t, emails) moved_name := orig_file.Name() + "-moved" err = os.Rename(orig_file.Name(), moved_name) err = os.Rename(vt.auth_email_file.Name(), 
orig_file.Name()) if err != nil { t.Fatal("failed to rename and replace temp file: " + err.Error()) } vt.auth_email_file = orig_file os.Remove(moved_name) } func TestValidatorOverwriteEmailListDirectly(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{ "xyzzy@example.com", "plugh@example.com", }) domains := []string(nil) updated := make(chan bool) validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("first email in list should validate") } if !validator("plugh@example.com") { t.Error("second email in list should validate") } if validator("xyzzy.plugh@example.com") { t.Error("email not in list that matches no domains " + "should not validate") } vt.UpdateEmailFile(t, []string{ "xyzzy.plugh@example.com", "plugh@example.com", }) <-updated if validator("xyzzy@example.com") { t.Error("email removed from list should not validate") } if !validator("plugh@example.com") { t.Error("email retained in list should validate") } if !validator("xyzzy.plugh@example.com") { t.Error("email added to list should validate") } } func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) { vt := NewValidatorTest(t) defer vt.TearDown() vt.WriteEmails(t, []string{"xyzzy@example.com"}) domains := []string(nil) updated := make(chan bool, 1) validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("email in list should validate") } vt.UpdateEmailFileViaRenameAndReplace(t, []string{"plugh@example.com"}) <-updated if validator("xyzzy@example.com") { t.Error("email removed from list should not validate") } } ================================================ FILE: version.go ================================================ package main const VERSION = "2.2.1-alpha" ================================================ FILE: watcher.go ================================================ // +build go1.3,!plan9,!solaris package main import ( "log" "os" "path/filepath" "time" 
"gopkg.in/fsnotify.v1" ) func WaitForReplacement(filename string, op fsnotify.Op, watcher *fsnotify.Watcher) { const sleep_interval = 50 * time.Millisecond // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. if op&fsnotify.Chmod != 0 { time.Sleep(sleep_interval) } for { if _, err := os.Stat(filename); err == nil { if err := watcher.Add(filename); err == nil { log.Printf("watching resumed for %s", filename) return } } time.Sleep(sleep_interval) } } func WatchForUpdates(filename string, done <-chan bool, action func()) { filename = filepath.Clean(filename) watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatal("failed to create watcher for ", filename, ": ", err) } go func() { defer watcher.Close() for { select { case _ = <-done: log.Printf("Shutting down watcher for: %s", filename) break case event := <-watcher.Events: // On Arch Linux, it appears Chmod events precede Remove events, // which causes a race between action() and the coming Remove event. // If the Remove wins, the action() (which calls // UserMap.LoadAuthenticatedEmailsFile()) crashes when the file // can't be opened. if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { log.Printf("watching interrupted on event: %s", event) watcher.Remove(filename) WaitForReplacement(filename, event.Op, watcher) } log.Printf("reloading after event: %s", event) action() case err := <-watcher.Errors: log.Printf("error watching %s: %s", filename, err) } } }() if err = watcher.Add(filename); err != nil { log.Fatal("failed to add ", filename, " to watcher: ", err) } log.Printf("watching %s for updates", filename) } ================================================ FILE: watcher_unsupported.go ================================================ // +build !go1.3 plan9 solaris package main import ( "log" ) func WatchForUpdates(filename string, done <-chan bool, action func()) { log.Printf("file watching not implemented on this platform") go func() { <-done }() }