Repository: juju/utils Branch: master Commit: b92083fa0866 Files: 204 Total size: 543.2 KB Directory structure: gitextract_0jyt02x8/ ├── .gitignore ├── ISSUE_TEMPLATE.md ├── LICENSE ├── LICENSE.golang ├── Makefile ├── README.md ├── SECURITY.md ├── arch/ │ ├── arch.go │ ├── arch_test.go │ └── package_test.go ├── attempt.go ├── attempt_test.go ├── bzr/ │ ├── bzr.go │ ├── bzr_test.go │ ├── bzr_unix_test.go │ └── bzr_windows_test.go ├── cache/ │ ├── cache.go │ ├── cache_test.go │ ├── export_test.go │ └── package_test.go ├── cert/ │ ├── cert.go │ ├── cert_test.go │ └── exports_test.go ├── command.go ├── command_test.go ├── context.go ├── context_test.go ├── du/ │ ├── LICENSE.ricochet2200 │ ├── diskusage.go │ └── diskusage_windows.go ├── errors.go ├── exec/ │ ├── exec.go │ ├── exec_internal_test.go │ ├── exec_linux_test.go │ ├── exec_test.go │ ├── exec_unix.go │ ├── exec_windows.go │ ├── exec_windows_test.go │ └── package_test.go ├── export_test.go ├── file.go ├── file_test.go ├── file_unix.go ├── file_unix_test.go ├── file_windows.go ├── file_windows_test.go ├── filepath/ │ ├── common.go │ ├── common_test.go │ ├── export_test.go │ ├── filepath.go │ ├── filepath_test.go │ ├── interface_test.go │ ├── package_test.go │ ├── stdlib.go │ ├── stdlib_test.go │ ├── stdlibmatch.go │ ├── unix.go │ ├── unix_test.go │ ├── win.go │ └── win_test.go ├── filestorage/ │ ├── doc.go │ ├── export_test.go │ ├── fakes_test.go │ ├── interfaces.go │ ├── metadata.go │ ├── metadata_store.go │ ├── metadata_test.go │ ├── package_test.go │ ├── wrapper.go │ └── wrapper_test.go ├── fs/ │ ├── copy.go │ └── copy_test.go ├── go.mod ├── go.sum ├── gomaxprocs.go ├── gomaxprocs_test.go ├── hash/ │ ├── fingerprint.go │ ├── fingerprint_test.go │ ├── hash.go │ ├── hash_test.go │ ├── package_test.go │ ├── writer.go │ └── writer_test.go ├── home_unix.go ├── home_unix_test.go ├── home_windows.go ├── home_windows_test.go ├── isubuntu.go ├── isubuntu_test.go ├── jsonhttp/ │ ├── jsonhttp.go │ ├── jsonhttp_test.go │ 
└── package_test.go ├── keyvalues/ │ ├── keyvalues.go │ ├── keyvalues_test.go │ └── package_test.go ├── limiter.go ├── limiter_test.go ├── multireader.go ├── multireader_test.go ├── naturalsort.go ├── naturalsort_test.go ├── network.go ├── network_test.go ├── os.go ├── os_test.go ├── package_test.go ├── parallel/ │ ├── package_test.go │ ├── parallel.go │ ├── parallel_test.go │ ├── try.go │ └── try_test.go ├── password.go ├── password_test.go ├── proxy/ │ ├── package_test.go │ ├── proxy.go │ └── proxy_test.go ├── randomstring.go ├── randomstring_test.go ├── registry/ │ ├── export_test.go │ ├── package_test.go │ ├── registry.go │ └── registry_test.go ├── relativeurl.go ├── relativeurl_test.go ├── setenv.go ├── setenv_test.go ├── shell/ │ ├── bash.go │ ├── bash_test.go │ ├── command.go │ ├── interface_test.go │ ├── output.go │ ├── package_test.go │ ├── powershell.go │ ├── powershell_test.go │ ├── renderer.go │ ├── renderer_test.go │ ├── script.go │ ├── script_test.go │ ├── unix.go │ ├── win.go │ ├── wincmd.go │ └── wincmd_test.go ├── size.go ├── size_test.go ├── ssh/ │ ├── authorisedkeys.go │ ├── authorisedkeys_test.go │ ├── clientkeys.go │ ├── clientkeys_test.go │ ├── export_test.go │ ├── fakes_test.go │ ├── fingerprint.go │ ├── fingerprint_test.go │ ├── generate.go │ ├── generate_test.go │ ├── package_test.go │ ├── run.go │ ├── run_test.go │ ├── ssh.go │ ├── ssh_gocrypto.go │ ├── ssh_gocrypto_test.go │ ├── ssh_openssh.go │ ├── ssh_test.go │ ├── stream.go │ ├── stream_test.go │ ├── stream_wrapper_unix.go │ ├── stream_wrapper_windows.go │ └── testing/ │ └── keys.go ├── symlink/ │ ├── export_test.go │ ├── symlink.go │ ├── symlink_posix.go │ ├── symlink_test.go │ ├── symlink_windows.go │ ├── symlink_windows_test.go │ ├── zsymlink_windows_386.go │ └── zsymlink_windows_amd64.go ├── systemerrmessages_unix.go ├── systemerrmessages_windows.go ├── tailer/ │ ├── export_test.go │ ├── package_test.go │ ├── tailer.go │ └── tailer_test.go ├── tar/ │ ├── tar.go │ └── tar_test.go 
├── timer.go ├── timer_test.go ├── trivial.go ├── trivial_test.go ├── uptime/ │ ├── uptime_nix.go │ ├── uptime_windows.go │ ├── zuptime_windows_386.go │ └── zuptime_windows_amd64.go ├── username.go ├── username_test.go ├── uuid.go ├── uuid_test.go ├── voyeur/ │ ├── package_test.go │ ├── value.go │ └── value_test.go ├── yaml.go ├── yaml_test.go ├── zfile_windows.go └── zip/ ├── package_test.go ├── zip.go └── zip_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, built with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out # GoLand .idea/ # Dependency directories (remove the comment below to include it) # vendor/ ================================================ FILE: ISSUE_TEMPLATE.md ================================================ ## Issues tracked in Launchpad Please file an issue against https://bugs.launchpad.net/juju/+filebug ================================================ FILE: LICENSE ================================================ All files in this repository are licensed as follows. If you contribute to this repository, it is assumed that you license your contribution under the same license unless you state otherwise. All files Copyright (C) 2015 Canonical Ltd. unless otherwise specified in the file. This software is licensed under the LGPLv3, included below. 
As a special exception to the GNU Lesser General Public License version 3 ("LGPL3"), the copyright holders of this Library give you permission to convey to a third party a Combined Work that links statically or dynamically to this Library without providing any Minimal Corresponding Source or Minimal Application Code as set out in 4d or providing the installation information set out in section 4e, provided that you comply with the other provisions of LGPL3 and provided that you meet, for the Application the terms and conditions of the license(s) which apply to the Application. Except as stated in this special exception, the provisions of LGPL3 will continue to comply in full to this Library. If you modify this Library, you may apply this exception to your version of this Library, but you are not obliged to do so. If you do not wish to do so, delete this exception statement from your version. This exception does not (and cannot) modify any license terms which apply to the Application, with which you must still comply. GNU LESSER GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. This version of the GNU Lesser General Public License incorporates the terms and conditions of version 3 of the GNU General Public License, supplemented by the additional permissions listed below. 0. Additional Definitions. As used herein, "this License" refers to version 3 of the GNU Lesser General Public License, and the "GNU GPL" refers to version 3 of the GNU General Public License. "The Library" refers to a covered work governed by this License, other than an Application or a Combined Work as defined below. An "Application" is any work that makes use of an interface provided by the Library, but which is not otherwise based on the Library. 
Defining a subclass of a class defined by the Library is deemed a mode of using an interface provided by the Library. A "Combined Work" is a work produced by combining or linking an Application with the Library. The particular version of the Library with which the Combined Work was made is also called the "Linked Version". The "Minimal Corresponding Source" for a Combined Work means the Corresponding Source for the Combined Work, excluding any source code for portions of the Combined Work that, considered in isolation, are based on the Application, and not on the Linked Version. The "Corresponding Application Code" for a Combined Work means the object code and/or source code for the Application, including any data and utility programs needed for reproducing the Combined Work from the Application, but excluding the System Libraries of the Combined Work. 1. Exception to Section 3 of the GNU GPL. You may convey a covered work under sections 3 and 4 of this License without being bound by section 3 of the GNU GPL. 2. Conveying Modified Versions. If you modify a copy of the Library, and, in your modifications, a facility refers to a function or data to be supplied by an Application that uses the facility (other than as an argument passed when the facility is invoked), then you may convey a copy of the modified version: a) under this License, provided that you make a good faith effort to ensure that, in the event an Application does not supply the function or data, the facility still operates, and performs whatever part of its purpose remains meaningful, or b) under the GNU GPL, with none of the additional permissions of this License applicable to that copy. 3. Object Code Incorporating Material from Library Header Files. The object code form of an Application may incorporate material from a header file that is part of the Library. 
You may convey such object code under terms of your choice, provided that, if the incorporated material is not limited to numerical parameters, data structure layouts and accessors, or small macros, inline functions and templates (ten or fewer lines in length), you do both of the following: a) Give prominent notice with each copy of the object code that the Library is used in it and that the Library and its use are covered by this License. b) Accompany the object code with a copy of the GNU GPL and this license document. 4. Combined Works. You may convey a Combined Work under terms of your choice that, taken together, effectively do not restrict modification of the portions of the Library contained in the Combined Work and reverse engineering for debugging such modifications, if you also do each of the following: a) Give prominent notice with each copy of the Combined Work that the Library is used in it and that the Library and its use are covered by this License. b) Accompany the Combined Work with a copy of the GNU GPL and this license document. c) For a Combined Work that displays copyright notices during execution, include the copyright notice for the Library among these notices, as well as a reference directing the user to the copies of the GNU GPL and this license document. d) Do one of the following: 0) Convey the Minimal Corresponding Source under the terms of this License, and the Corresponding Application Code in a form suitable for, and under terms that permit, the user to recombine or relink the Application with a modified version of the Linked Version to produce a modified Combined Work, in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source. 1) Use a suitable shared library mechanism for linking with the Library. 
A suitable mechanism is one that (a) uses at run time a copy of the Library already present on the user's computer system, and (b) will operate properly with a modified version of the Library that is interface-compatible with the Linked Version. e) Provide Installation Information, but only if you would otherwise be required to provide such information under section 6 of the GNU GPL, and only to the extent that such information is necessary to install and execute a modified version of the Combined Work produced by recombining or relinking the Application with a modified version of the Linked Version. (If you use option 4d0, the Installation Information must accompany the Minimal Corresponding Source and Corresponding Application Code. If you use option 4d1, you must provide the Installation Information in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source.) 5. Combined Libraries. You may place library facilities that are a work based on the Library side by side in a single library together with other library facilities that are not Applications and are not covered by this License, and convey such a combined library under terms of your choice, if you do both of the following: a) Accompany the combined library with a copy of the same work based on the Library, uncombined with any other library facilities, conveyed under the terms of this License. b) Give prominent notice with the combined library that part of it is a work based on the Library, and explaining where to find the accompanying uncombined form of the same work. 6. Revised Versions of the GNU Lesser General Public License. The Free Software Foundation may publish revised and/or new versions of the GNU Lesser General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. 
If the Library as you received it specifies that a certain numbered version of the GNU Lesser General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that published version or of any later version published by the Free Software Foundation. If the Library as you received it does not specify a version number of the GNU Lesser General Public License, you may choose any version of the GNU Lesser General Public License ever published by the Free Software Foundation. If the Library as you received it specifies that a proxy can decide whether future versions of the GNU Lesser General Public License shall apply, that proxy's public statement of acceptance of any version is permanent authorization for you to choose that version for the Library. ================================================ FILE: LICENSE.golang ================================================ This licence applies to the following files: * filepath/stdlib.go * filepath/stdlibmatch.go Copyright (c) 2010 The Go Authors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: Makefile ================================================ PROJECT := github.com/juju/utils/v4 .PHONY: check-licence check-go check check: check-licence check-go go test -v $(PROJECT)/... check-licence: @(grep -rFl "Licensed under the LGPLv3" .;\ grep -rFl "MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT" .;\ grep -rFl "license that can be found in the LICENSE.ricochet2200 file" .; \ find . -name "*.go") | sed -e 's,\./,,' | sort | uniq -u | \ xargs -I {} echo FAIL: licence missed: {} check-go: $(eval GOFMT := $(strip $(shell gofmt -l .| sed -e "s/^/ /g"))) @(if [ x$(GOFMT) != x"" ]; then \ echo go fmt is sad: $(GOFMT); \ exit 1; \ fi ) @(go vet -all -composites=false -copylocks=false .) # Install packages required to develop in utils and run tests. 
install-dependencies: install-snap-dependencies install-mongo-dependencies @echo Installing dependencies @echo Installing bzr @sudo apt install bzr --yes @echo Installing zip @sudo apt install zip --yes install-snap-dependencies: ## install-snap-dependencies: Install the supported snap dependencies @echo Installing go-1.17 snap @sudo snap install go --channel=1.17/stable --classic install-mongo-dependencies: ## install-mongo-dependencies: Install Mongo and its dependencies @echo Adding juju PPA for mongodb @sudo apt-add-repository --yes ppa:juju/stable @sudo apt-get update @echo Installing mongodb @sudo apt-get --yes install \ $(strip $(DEPENDENCIES)) \ $(shell apt-cache madison mongodb-server-core juju-mongodb3.2 juju-mongodb mongodb-server | head -1 | cut -d '|' -f1) ================================================ FILE: README.md ================================================ juju/utils ============ This package provides general utility packages and functions. ================================================ FILE: SECURITY.md ================================================ # Security policy ## Reporting a vulnerability Please provide a description of the issue, the steps you took to create the issue, affected versions, and, if known, mitigations for the issue. The preferred way to report a security issue is through [GitHub's security advisory for this project](https://github.com/juju/utils/security/advisories/new). See [Privately reporting a security vulnerability](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing/privately-reporting-a-security-vulnerability) for instructions on reporting using GitHub's security advisory feature. The [Ubuntu Security disclosure and embargo policy](https://ubuntu.com/security/disclosure-policy) contains more information about how you can contact us, what you can expect when you contact us, and what we expect from you.
================================================ FILE: arch/arch.go ================================================ // Copyright 2014-2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package arch import ( "regexp" "runtime" "strings" ) // The following constants define the machine architectures supported by Juju. const ( AMD64 = "amd64" I386 = "i386" ARM = "armhf" ARM64 = "arm64" PPC64EL = "ppc64el" S390X = "s390x" RISCV64 = "riscv64" // Older versions of Juju used "ppc64" instead of ppc64el LEGACY_PPC64 = "ppc64" ) // AllSupportedArches records the machine architectures recognised by Juju. var AllSupportedArches = []string{ AMD64, I386, ARM, ARM64, PPC64EL, S390X, RISCV64, } // Info records the information regarding each architecture recognised by Juju. var Info = map[string]ArchInfo{ AMD64: {64}, I386: {32}, ARM: {32}, ARM64: {64}, PPC64EL: {64}, S390X: {64}, RISCV64: {64}, } // ArchInfo is a struct containing information about a supported architecture. type ArchInfo struct { // WordSize is the architecture's word size, in bits. WordSize int } // archREs maps regular expressions for matching // `uname -m` to architectures recognised by Juju. var archREs = []struct { *regexp.Regexp arch string }{ {regexp.MustCompile("amd64|x86_64"), AMD64}, {regexp.MustCompile("i?[3-9]86"), I386}, {regexp.MustCompile("(arm$)|(armv.*)"), ARM}, {regexp.MustCompile("aarch64"), ARM64}, {regexp.MustCompile("ppc64|ppc64el|ppc64le"), PPC64EL}, {regexp.MustCompile("s390x"), S390X}, {regexp.MustCompile("riscv64|risc$|risc-[vV]64"), RISCV64}, } // Override for testing. var HostArch = hostArch // hostArch returns the Juju architecture of the machine on which it is run. func hostArch() string { return NormaliseArch(runtime.GOARCH) } // NormaliseArch returns the Juju architecture corresponding to a machine's // reported architecture. The Juju architecture is used to filter simple // streams lookup of tools and images. 
func NormaliseArch(rawArch string) string { rawArch = strings.TrimSpace(rawArch) for _, re := range archREs { if re.Match([]byte(rawArch)) { return re.arch } } return rawArch } // IsSupportedArch returns true if arch is one supported by Juju. func IsSupportedArch(arch string) bool { for _, a := range AllSupportedArches { if a == arch { return true } } return false } ================================================ FILE: arch/arch_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package arch_test import ( jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/arch" ) type archSuite struct { } var _ = gc.Suite(&archSuite{}) func (s *archSuite) TestHostArch(c *gc.C) { a := arch.HostArch() c.Assert(arch.IsSupportedArch(a), jc.IsTrue) } func (s *archSuite) TestNormaliseArch(c *gc.C) { for _, test := range []struct { raw string arch string }{ {"windows", "windows"}, {"amd64", "amd64"}, {"x86_64", "amd64"}, {"386", "i386"}, {"i386", "i386"}, {"i486", "i386"}, {"arm", "armhf"}, {"armv", "armhf"}, {"armv7", "armhf"}, {"aarch64", "arm64"}, {"arm64", "arm64"}, {"ppc64el", "ppc64el"}, {"ppc64le", "ppc64el"}, {"ppc64", "ppc64el"}, {"s390x", "s390x"}, {"riscv64", "riscv64"}, {"risc", "riscv64"}, {"risc-v64", "riscv64"}, {"risc-V64", "riscv64"}, } { arch := arch.NormaliseArch(test.raw) c.Check(arch, gc.Equals, test.arch) } } func (s *archSuite) TestIsSupportedArch(c *gc.C) { for _, a := range arch.AllSupportedArches { c.Assert(arch.IsSupportedArch(a), jc.IsTrue) } c.Assert(arch.IsSupportedArch("invalid"), jc.IsFalse) } func (s *archSuite) TestArchInfo(c *gc.C) { for _, a := range arch.AllSupportedArches { _, ok := arch.Info[a] c.Assert(ok, jc.IsTrue) } } ================================================ FILE: arch/package_test.go ================================================ // Copyright 2013 Canonical Ltd. 
// Licensed under the LGPLv3, see LICENCE file for details. package arch_test import ( "testing" gc "gopkg.in/check.v1" ) func Test(t *testing.T) { gc.TestingT(t) } ================================================ FILE: attempt.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "time" ) // The Attempt and AttemptStrategy types are copied from those in launchpad.net/goamz/aws. // AttemptStrategy represents a strategy for waiting for an action // to complete successfully. type AttemptStrategy struct { Total time.Duration // total duration of attempt. Delay time.Duration // interval between each try in the burst. Min int // minimum number of retries; overrides Total } type Attempt struct { strategy AttemptStrategy last time.Time end time.Time force bool count int } // Start begins a new sequence of attempts for the given strategy. func (s AttemptStrategy) Start() *Attempt { now := time.Now() return &Attempt{ strategy: s, last: now, end: now.Add(s.Total), force: true, } } // Next waits until it is time to perform the next attempt or returns // false if it is time to stop trying. // It always returns true the first time it is called - we are guaranteed to // make at least one attempt. func (a *Attempt) Next() bool { now := time.Now() sleep := a.nextSleep(now) if !a.force && !now.Add(sleep).Before(a.end) && a.strategy.Min <= a.count { return false } a.force = false if sleep > 0 && a.count > 0 { time.Sleep(sleep) now = time.Now() } a.count++ a.last = now return true } func (a *Attempt) nextSleep(now time.Time) time.Duration { sleep := a.strategy.Delay - now.Sub(a.last) if sleep < 0 { return 0 } return sleep } // HasNext returns whether another attempt will be made if the current // one fails. If it returns true, the following call to Next is // guaranteed to return true. 
func (a *Attempt) HasNext() bool { if a.force || a.strategy.Min > a.count { return true } now := time.Now() if now.Add(a.nextSleep(now)).Before(a.end) { a.force = true return true } return false } ================================================ FILE: attempt_test.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "time" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) func doSomething() (int, error) { return 0, nil } func shouldRetry(error) bool { return false } func doSomethingWith(int) {} func ExampleAttempt_HasNext() { // This example shows how Attempt.HasNext can be used to help // structure an attempt loop. If the godoc example code allowed // us to make the example return an error, we would uncomment // the commented return statements. attempts := utils.AttemptStrategy{ Total: 1 * time.Second, Delay: 250 * time.Millisecond, } for attempt := attempts.Start(); attempt.Next(); { x, err := doSomething() if shouldRetry(err) && attempt.HasNext() { continue } if err != nil { // return err return } doSomethingWith(x) } // return ErrTimedOut return } func (*utilsSuite) TestAttemptTiming(c *gc.C) { testAttempt := utils.AttemptStrategy{ Total: 0.25e9, Delay: 0.1e9, } want := []time.Duration{0, 0.1e9, 0.2e9, 0.2e9} got := make([]time.Duration, 0, len(want)) // avoid allocation when testing timing t0 := time.Now() for a := testAttempt.Start(); a.Next(); { got = append(got, time.Now().Sub(t0)) } got = append(got, time.Now().Sub(t0)) c.Assert(got, gc.HasLen, len(want)) const margin = 0.01e9 for i, got := range want { lo := want[i] - margin hi := want[i] + margin if got < lo || got > hi { c.Errorf("attempt %d want %g got %g", i, want[i].Seconds(), got.Seconds()) } } } func (*utilsSuite) TestAttemptNextHasNext(c *gc.C) { a := utils.AttemptStrategy{}.Start() c.Assert(a.Next(), gc.Equals, true) c.Assert(a.Next(), gc.Equals, false) a = 
utils.AttemptStrategy{}.Start() c.Assert(a.Next(), gc.Equals, true) c.Assert(a.HasNext(), gc.Equals, false) c.Assert(a.Next(), gc.Equals, false) a = utils.AttemptStrategy{Total: 2e8}.Start() c.Assert(a.Next(), gc.Equals, true) c.Assert(a.HasNext(), gc.Equals, true) time.Sleep(2e8) c.Assert(a.HasNext(), gc.Equals, true) c.Assert(a.Next(), gc.Equals, true) c.Assert(a.Next(), gc.Equals, false) a = utils.AttemptStrategy{Total: 1e8, Min: 2}.Start() time.Sleep(1e8) c.Assert(a.Next(), gc.Equals, true) c.Assert(a.HasNext(), gc.Equals, true) c.Assert(a.Next(), gc.Equals, true) c.Assert(a.HasNext(), gc.Equals, false) c.Assert(a.Next(), gc.Equals, false) } ================================================ FILE: bzr/bzr.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // Package bzr offers an interface to manage branches of the Bazaar VCS. package bzr import ( "bytes" "fmt" "os" "os/exec" "path" "strings" ) // Branch represents a Bazaar branch. type Branch struct { location string env []string } // New returns a new Branch for the Bazaar branch at location. func New(location string) *Branch { b := &Branch{location, cenv()} if _, err := os.Stat(location); err == nil { stdout, _, err := b.bzr("root") if err == nil { // Need to trim \r as well as \n for Windows compatibility b.location = strings.TrimRight(string(stdout), "\r\n") } } return b } // cenv returns a copy of the current process environment with LC_ALL=C. func cenv() []string { env := os.Environ() for i, pair := range env { if strings.HasPrefix(pair, "LC_ALL=") { env[i] = "LC_ALL=C" return env } } return append(env, "LC_ALL=C") } // Location returns the location of branch b. func (b *Branch) Location() string { return b.location } // Join returns b's location with parts appended as path components. // In other words, if b's location is "lp:foo", and parts is {"bar, baz"}, // Join returns "lp:foo/bar/baz". 
func (b *Branch) Join(parts ...string) string { return path.Join(append([]string{b.location}, parts...)...) } func (b *Branch) bzr(subcommand string, args ...string) (stdout, stderr []byte, err error) { cmd := exec.Command("bzr", append([]string{subcommand}, args...)...) if _, err := os.Stat(b.location); err == nil { cmd.Dir = b.location } errbuf := &bytes.Buffer{} cmd.Stderr = errbuf cmd.Env = b.env stdout, err = cmd.Output() // Some commands fail with exit status 0 (e.g. bzr root). :-( if err != nil || bytes.Contains(errbuf.Bytes(), []byte("ERROR")) { var errmsg string if err != nil { errmsg = err.Error() } return nil, nil, fmt.Errorf(`error running "bzr %s": %s%s%s`, subcommand, stdout, errbuf.Bytes(), errmsg) } return stdout, errbuf.Bytes(), err } // Init intializes a new branch at b's location. func (b *Branch) Init() error { _, _, err := b.bzr("init", b.location) return err } // Add adds to b the path resultant from calling b.Join(parts...). func (b *Branch) Add(parts ...string) error { _, _, err := b.bzr("add", b.Join(parts...)) return err } // Commit commits pending changes into b. func (b *Branch) Commit(message string) error { _, _, err := b.bzr("commit", "-q", "-m", message) return err } // RevisionId returns the Bazaar revision id for the tip of b. func (b *Branch) RevisionId() (string, error) { stdout, stderr, err := b.bzr("revision-info", "-d", b.location) if err != nil { return "", err } pair := bytes.Fields(stdout) if len(pair) != 2 { return "", fmt.Errorf(`invalid output from "bzr revision-info": %s%s`, stdout, stderr) } id := string(pair[1]) if id == "null:" { return "", fmt.Errorf("branch has no content") } return id, nil } // PushLocation returns the default push location for b. 
func (b *Branch) PushLocation() (string, error) { stdout, _, err := b.bzr("info", b.location) if err != nil { return "", err } if i := bytes.Index(stdout, []byte("push branch:")); i >= 0 { return string(stdout[i+13 : i+bytes.IndexAny(stdout[i:], "\r\n")]), nil } return "", fmt.Errorf("no push branch location defined") } // PushAttr holds options for the Branch.Push method. type PushAttr struct { Location string // Location to push to. Use the default push location if empty. Remember bool // Whether to remember the location being pushed to as the default. } // Push pushes any new revisions in b to attr.Location if that's // provided, or to the default push location otherwise. // See PushAttr for other options. func (b *Branch) Push(attr *PushAttr) error { var args []string if attr != nil { if attr.Remember { args = append(args, "--remember") } if attr.Location != "" { args = append(args, attr.Location) } } _, _, err := b.bzr("push", args...) return err } // CheckClean returns an error if 'bzr status' is not clean. func (b *Branch) CheckClean() error { stdout, _, err := b.bzr("status", b.location) if err != nil { return err } if bytes.Count(stdout, []byte{'\n'}) == 1 && bytes.Contains(stdout, []byte(`See "bzr shelve --list" for details.`)) { return nil // Shelves are fine. } if len(stdout) > 0 { return fmt.Errorf("branch is not clean (bzr status)") } return nil } ================================================ FILE: bzr/bzr_test.go ================================================ // Copyright 2014 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. 
package bzr_test

import (
	"os"
	"os/exec"
	"path/filepath"
	stdtesting "testing"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/bzr"
)

func Test(t *stdtesting.T) {
	gc.TestingT(t)
}

var _ = gc.Suite(&BzrSuite{})

// BzrSuite runs each test against a freshly initialized branch in s.b,
// with BZR_HOME pointed at a private directory holding bzrConfig.
type BzrSuite struct {
	testing.CleanupSuite
	b *bzr.Branch
}

// bzrConfig is written to the test BZR_HOME so that bzr has a committer
// email configured.
const bzrConfig = `[DEFAULT]
email = testing
`

func (s *BzrSuite) SetUpTest(c *gc.C) {
	s.CleanupSuite.SetUpTest(c)
	bzrdir := c.MkDir()
	s.PatchEnvironment("BZR_HOME", bzrdir)
	err := os.MkdirAll(filepath.Join(bzrdir, bzrHome), 0755)
	c.Assert(err, jc.ErrorIsNil)
	err = os.WriteFile(
		filepath.Join(bzrdir, bzrHome, "bazaar.conf"),
		[]byte(bzrConfig), 0644)
	c.Assert(err, jc.ErrorIsNil)
	s.b = bzr.New(c.MkDir())
	c.Assert(s.b.Init(), gc.IsNil)
}

func (s *BzrSuite) TestNewFindsRoot(c *gc.C) {
	err := os.Mkdir(s.b.Join("dir"), 0755)
	c.Assert(err, jc.ErrorIsNil)
	b := bzr.New(s.b.Join("dir"))
	// When bzr has to search for the root, it will expand any symlinks it
	// found along the way.
	path, err := filepath.EvalSymlinks(s.b.Location())
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(b.Location(), jc.SamePath, path)
}

func (s *BzrSuite) TestJoin(c *gc.C) {
	path := bzr.New("lp:foo").Join("baz", "bar")
	c.Assert(path, gc.Equals, "lp:foo/baz/bar")
}

func (s *BzrSuite) TestErrorHandling(c *gc.C) {
	err := bzr.New("/non/existent/path").Init()
	c.Assert(err, gc.ErrorMatches, `(?s)error running "bzr init":.*does not exist.*`)
}

func (s *BzrSuite) TestInit(c *gc.C) {
	_, err := os.Stat(s.b.Join(".bzr"))
	c.Assert(err, jc.ErrorIsNil)
}

func (s *BzrSuite) TestRevisionIdOnEmpty(c *gc.C) {
	revid, err := s.b.RevisionId()
	c.Assert(err, gc.ErrorMatches, "branch has no content")
	c.Assert(revid, gc.Equals, "")
}

func (s *BzrSuite) TestCommit(c *gc.C) {
	f, err := os.Create(s.b.Join("myfile"))
	c.Assert(err, jc.ErrorIsNil)
	f.Close()
	err = s.b.Add("myfile")
	c.Assert(err, jc.ErrorIsNil)
	err = s.b.Commit("my log message")
	c.Assert(err, jc.ErrorIsNil)

	revid, err := s.b.RevisionId()
	c.Assert(err, jc.ErrorIsNil)

	cmd := exec.Command("bzr", "log", "--long", "--show-ids", "-v", s.b.Location())
	output, err := cmd.CombinedOutput()
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(string(output), gc.Matches, "(?s).*revision-id: "+revid+"\n.*message:\n.*my log message\n.*added:\n.*myfile .*")
}

func (s *BzrSuite) TestPush(c *gc.C) {
	b1 := bzr.New(c.MkDir())
	b2 := bzr.New(c.MkDir())
	b3 := bzr.New(c.MkDir())
	c.Assert(b1.Init(), gc.IsNil)
	c.Assert(b2.Init(), gc.IsNil)
	c.Assert(b3.Init(), gc.IsNil)

	// Create and add b1/file to the branch.
	f, err := os.Create(b1.Join("file"))
	c.Assert(err, jc.ErrorIsNil)
	f.Close()
	err = b1.Add("file")
	c.Assert(err, jc.ErrorIsNil)
	err = b1.Commit("added file")
	c.Assert(err, jc.ErrorIsNil)

	// Push file to b2.
	err = b1.Push(&bzr.PushAttr{Location: b2.Location()})
	c.Assert(err, jc.ErrorIsNil)

	// Push location should be set to b2.
	location, err := b1.PushLocation()
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(location, jc.SamePath, b2.Location())

	// Now push it to b3.
	err = b1.Push(&bzr.PushAttr{Location: b3.Location()})
	c.Assert(err, jc.ErrorIsNil)

	// Push location is still set to b2.
	location, err = b1.PushLocation()
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(location, jc.SamePath, b2.Location())

	// Push it again, this time with the remember flag set.
	err = b1.Push(&bzr.PushAttr{Location: b3.Location(), Remember: true})
	c.Assert(err, jc.ErrorIsNil)

	// Now the push location has shifted to b3.
	location, err = b1.PushLocation()
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(location, jc.SamePath, b3.Location())

	// Both b2 and b3 should have the file.
	_, err = os.Stat(b2.Join("file"))
	c.Assert(err, jc.ErrorIsNil)
	_, err = os.Stat(b3.Join("file"))
	c.Assert(err, jc.ErrorIsNil)
}

func (s *BzrSuite) TestCheckClean(c *gc.C) {
	err := s.b.CheckClean()
	c.Assert(err, jc.ErrorIsNil)

	// Create an untracked file in the branch; that alone must make
	// the branch unclean.
	f, err := os.Create(s.b.Join("file"))
	c.Assert(err, jc.ErrorIsNil)
	f.Close()

	err = s.b.CheckClean()
	c.Assert(err, gc.ErrorMatches, `branch is not clean \(bzr status\)`)
}



================================================
FILE: bzr/bzr_unix_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

//go:build !windows
// +build !windows

package bzr_test

const bzrHome = ".bazaar"



================================================
FILE: bzr/bzr_windows_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

//go:build windows
// +build windows

package bzr_test

const bzrHome = "Bazaar/2.0"



================================================
FILE: cache/cache.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package cache provides a simple caching mechanism // that limits the age of cache entries and tries to avoid large // repopulation events by staggering refresh times. package cache import ( "math/rand" "sync" "time" "github.com/juju/errors" ) // entry holds a cache entry. The expire field // holds the time after which the entry will be // considered invalid. type entry struct { value any expire time.Time } // Key represents a cache key. It must be a comparable type. type Key any // Cache holds a time-limited set of values for arbitrary keys. type Cache struct { maxAge time.Duration // mu guards the fields below it. mu sync.Mutex // expire holds when the cache is due to expire. expire time.Time // We hold two maps so that can avoid scanning through all the // items in the cache when the cache needs to be refreshed. // Instead, we move items from old to new when they're accessed // and throw away the old map at refresh time. old, new map[Key]entry inFlight map[Key]*fetchCall } // fetch represents an in-progress fetch call. If a cache Get request // is made for an item that is currently being fetched, this will // be used to avoid an extra call to the fetch function. type fetchCall struct { wg sync.WaitGroup val any err error } // New returns a new Cache that will cache items for // at most maxAge. If maxAge is zero, items will // never be cached. func New(maxAge time.Duration) *Cache { // The returned cache will have a zero-valued expire // time, so will expire immediately, causing the new // map to be created. return &Cache{ maxAge: maxAge, inFlight: make(map[Key]*fetchCall), } } // Len returns the total number of cached entries. func (c *Cache) Len() int { c.mu.Lock() defer c.mu.Unlock() return len(c.old) + len(c.new) } // Evict removes the entry with the given key from the cache if present. func (c *Cache) Evict(key Key) { c.mu.Lock() defer c.mu.Unlock() delete(c.new, key) delete(c.old, key) } // EvictAll removes all entries from the cache. 
func (c *Cache) EvictAll() { c.mu.Lock() defer c.mu.Unlock() c.new = make(map[Key]entry) c.old = nil } // Get returns the value for the given key, using fetch to fetch // the value if it is not found in the cache. // If fetch returns an error, the returned error from Get will have // the same cause. func (c *Cache) Get(key Key, fetch func() (any, error)) (any, error) { return c.getAtTime(key, fetch, time.Now()) } // getAtTime is the internal version of Get, useful for testing; now represents the current // time. func (c *Cache) getAtTime(key Key, fetch func() (any, error), now time.Time) (any, error) { if val, ok := c.cachedValue(key, now); ok { return val, nil } c.mu.Lock() if f, ok := c.inFlight[key]; ok { // There's already an in-flight request for the key, so wait // for that to complete and use its results. c.mu.Unlock() f.wg.Wait() // The value will have been added to the cache by the first fetch, // so no need to add it here. if f.err == nil { return f.val, nil } return nil, errors.Trace(f.err) } var f fetchCall f.wg.Add(1) c.inFlight[key] = &f // Mark the request as done when we return, and after // the value has been added to the cache. defer f.wg.Done() // Fetch the data without the mutex held // so that one slow fetch doesn't hold up // all the other cache accesses. c.mu.Unlock() val, err := fetch() c.mu.Lock() defer c.mu.Unlock() // Set the result in the fetchCall so that other calls can see it. f.val, f.err = val, err if err == nil && c.maxAge >= 2*time.Nanosecond { // If maxAge is < 2ns then the expiry code will panic because the // actual expiry time will be maxAge - a random value in the // interval [0, maxAge/2). If maxAge is < 2ns then this requires // a random interval in [0, 0) which causes a panic. // // This value is so small that there's no need to cache anyway, // which makes tests more obviously deterministic when using // a zero expiry time. 
c.new[key] = entry{ value: val, expire: now.Add(c.maxAge - time.Duration(rand.Int63n(int64(c.maxAge/2)))), } } delete(c.inFlight, key) if err == nil { return f.val, nil } return nil, errors.Trace(err) } // cachedValue returns any cached value for the given key // and whether it was found. func (c *Cache) cachedValue(key Key, now time.Time) (any, bool) { c.mu.Lock() defer c.mu.Unlock() if now.After(c.expire) { c.old = c.new c.new = make(map[Key]entry) c.expire = now.Add(c.maxAge) } if e, ok := c.entry(c.new, key, now); ok { return e.value, true } if e, ok := c.entry(c.old, key, now); ok { // An old entry has been accessed; move it to the new // map so that we only use a single map access for // subsequent lookups. Note that because we use the same // duration for cache refresh (c.expire) as for max // entry age, this is strictly speaking unnecessary // because any entries in old will have expired by the // time it is dropped. c.new[key] = e delete(c.old, key) return e.value, true } return nil, false } // entry returns an entry from the map and whether it // was found. If the entry has expired, it is deleted from the map. func (c *Cache) entry(m map[Key]entry, key Key, now time.Time) (entry, bool) { e, ok := m[key] if !ok { return entry{}, false } if now.After(e.expire) { // Delete expired entries. delete(m, key) return entry{}, false } return e, true } ================================================ FILE: cache/cache_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package cache_test

import (
	"fmt"
	"sync"
	"time"

	"github.com/juju/errors"
	"github.com/juju/utils/v4/cache"
	gc "gopkg.in/check.v1"
)

type suite struct{}

var _ = gc.Suite(&suite{})

func (*suite) TestSimpleGet(c *gc.C) {
	p := cache.New(time.Hour)
	v, err := p.Get("a", fetchValue(2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)
}

func (*suite) TestEvict(c *gc.C) {
	p := cache.New(time.Hour)
	v, err := p.Get("a", fetchValue(2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)

	v, err = p.Get("a", fetchValue(4))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)

	p.Evict("a")
	v, err = p.Get("a", fetchValue(3))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 3)

	v, err = p.Get("a", fetchValue(4))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 3)
}

func (*suite) TestEvictOld(c *gc.C) {
	// Test that evict removes entries even when they're
	// in the old map.

	now := time.Now()
	p := cache.New(time.Minute)

	// Populate the cache with an initial entry.
	v, err := cache.GetAtTime(p, "a", fetchValue("a"), now)
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "a")
	c.Assert(p.Len(), gc.Equals, 1)

	v, err = cache.GetAtTime(p, "b", fetchValue("b"), now.Add(time.Minute/2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "b")
	c.Assert(p.Len(), gc.Equals, 2)

	// Fetch an item after the expiry time,
	// causing current entries to be moved to old.
	v, err = cache.GetAtTime(p, "a", fetchValue("a1"), now.Add(time.Minute+1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "a1")
	c.Assert(p.Len(), gc.Equals, 2)
	c.Assert(cache.OldLen(p), gc.Equals, 1)

	p.Evict("b")
	v, err = cache.GetAtTime(p, "b", fetchValue("b1"), now.Add(time.Minute+2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "b1")
}

func (*suite) TestFetchError(c *gc.C) {
	p := cache.New(time.Hour)
	expectErr := errors.New("hello")
	v, err := p.Get("a", fetchError(expectErr))
	c.Assert(err, gc.ErrorMatches, "hello")
	c.Assert(errors.Cause(err), gc.Equals, expectErr)
	c.Assert(v, gc.Equals, nil)
}

func (*suite) TestFetchOnlyOnce(c *gc.C) {
	p := cache.New(time.Hour)
	v, err := p.Get("a", fetchValue(2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)

	v, err = p.Get("a", fetchError(errUnexpectedFetch))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)
}

func (*suite) TestEntryExpiresAfterMaxEntryAge(c *gc.C) {
	now := time.Now()
	p := cache.New(time.Minute)
	v, err := cache.GetAtTime(p, "a", fetchValue(2), now)
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)

	// Entry is definitely not expired before half the entry expiry time.
	v, err = cache.GetAtTime(p, "a", fetchError(errUnexpectedFetch), now.Add(time.Minute/2-1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)

	// Entry is definitely expired after the entry expiry time
	v, err = cache.GetAtTime(p, "a", fetchValue(3), now.Add(time.Minute+1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 3)
}

func (*suite) TestEntriesRemovedWhenNotRetrieved(c *gc.C) {
	now := time.Now()
	p := cache.New(time.Minute)

	// Populate the cache with an initial entry.
	v, err := cache.GetAtTime(p, "a", fetchValue("a"), now)
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "a")
	c.Assert(p.Len(), gc.Equals, 1)

	// Fetch another item after the expiry time,
	// causing current entries to be moved to old.
	v, err = cache.GetAtTime(p, "b", fetchValue("b"), now.Add(time.Minute+1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "b")
	c.Assert(p.Len(), gc.Equals, 2)
	c.Assert(cache.OldLen(p), gc.Equals, 1)

	// Fetch the other item after another expiry time
	// causing the old entries to be discarded because
	// nothing has fetched them.
	v, err = cache.GetAtTime(p, "b", fetchValue("b"), now.Add(time.Minute*2+2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "b")
	c.Assert(p.Len(), gc.Equals, 1)
}

// TestRefreshedEntry tests the code path where a value is moved
// from the old map to new.
func (*suite) TestRefreshedEntry(c *gc.C) {
	now := time.Now()
	p := cache.New(time.Minute)

	// Populate the cache with an initial entry.
	v, err := cache.GetAtTime(p, "a", fetchValue("a"), now)
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "a")
	c.Assert(p.Len(), gc.Equals, 1)

	// Fetch another item very close to the expiry time.
	v, err = cache.GetAtTime(p, "b", fetchValue("b"), now.Add(time.Minute-1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "b")
	c.Assert(p.Len(), gc.Equals, 2)

	// Fetch it again just after the expiry time,
	// which should move it into the new map.
	v, err = cache.GetAtTime(p, "b", fetchError(errUnexpectedFetch), now.Add(time.Minute+1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "b")
	c.Assert(p.Len(), gc.Equals, 2)

	// Fetch another item, causing "a" to be removed from the cache
	// and keeping "b" in there.
	v, err = cache.GetAtTime(p, "c", fetchValue("c"), now.Add(time.Minute*2+2))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, "c")
	c.Assert(p.Len(), gc.Equals, 2)
}

// TestConcurrentFetch checks that the cache is safe
// to use concurrently. It is designed to fail when
// tested with the race detector enabled.
func (*suite) TestConcurrentFetch(c *gc.C) {
	p := cache.New(time.Minute)
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		v, err := p.Get("a", fetchValue("a"))
		c.Check(err, gc.IsNil)
		c.Check(v, gc.Equals, "a")
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		v, err := p.Get("b", fetchValue("b"))
		c.Check(err, gc.IsNil)
		c.Check(v, gc.Equals, "b")
	}()
	wg.Wait()
}

func (*suite) TestRefreshSpread(c *gc.C) {
	now := time.Now()
	p := cache.New(time.Minute)

	// Get all values to start with.
	const N = 100
	for i := 0; i < N; i++ {
		v, err := cache.GetAtTime(p, fmt.Sprint(i), fetchValue(i), now)
		c.Assert(err, gc.IsNil)
		c.Assert(v, gc.Equals, i)
	}
	counts := make([]int, time.Minute/time.Millisecond/10+1)

	// Continually get values over the course of the
	// expiry time; the fetches should be spread out.
	slot := 0
	for t := now.Add(0); t.Before(now.Add(time.Minute + 1)); t = t.Add(time.Millisecond * 10) {
		for i := 0; i < N; i++ {
			cache.GetAtTime(p, fmt.Sprint(i), func() (any, error) {
				counts[slot]++
				return i, nil
			}, t)
		}
		slot++
	}

	// There should be no fetches in the first half of the cycle.
	for i := 0; i < len(counts)/2; i++ {
		c.Assert(counts[i], gc.Equals, 0, gc.Commentf("slot %d", i))
	}

	// Use maxCount rather than max to avoid shadowing the builtin.
	maxCount := 0
	total := 0
	for _, count := range counts {
		if count > maxCount {
			maxCount = count
		}
		total += count
	}
	if maxCount > 10 {
		c.Errorf("requests grouped too closely (max %d)", maxCount)
	}
	c.Assert(total, gc.Equals, N)
}

func (*suite) TestSingleFlight(c *gc.C) {
	p := cache.New(time.Minute)
	start := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		x, err := p.Get("x", func() (any, error) {
			start <- struct{}{}
			<-start
			return 99, nil
		})
		c.Check(x, gc.Equals, 99)
		c.Check(err, gc.Equals, nil)
	}()
	// Wait for the fetch to start.
	<-start

	wg.Add(1)
	go func() {
		defer wg.Done()
		x, err := p.Get("x", func() (any, error) {
			c.Errorf("fetch function unexpectedly called with inflight request")
			return 55, nil
		})
		c.Check(x, gc.Equals, 99)
		c.Check(err, gc.Equals, nil)
	}()

	// Check that we can still get other values while the
	// other fetches are in progress.
	y, err := p.Get("y", func() (any, error) {
		return 88, nil
	})
	c.Check(y, gc.Equals, 88)
	c.Check(err, gc.Equals, nil)

	// Let the original fetch proceed, which should let the other one
	// succeed too, but sleep for a little bit to let the second goroutine
	// actually initiate its request.
	time.Sleep(time.Millisecond)
	start <- struct{}{}
	wg.Wait()
}

var errUnexpectedFetch = errors.New("fetch called unexpectedly")

// fetchError returns a fetch function that always fails with err.
func fetchError(err error) func() (any, error) {
	return func() (any, error) {
		return nil, err
	}
}

// fetchValue returns a fetch function that always returns val.
func fetchValue(val any) func() (any, error) {
	return func() (any, error) {
		return val, nil
	}
}



================================================
FILE: cache/export_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package cache

var GetAtTime = (*Cache).getAtTime

func OldLen(c *Cache) int {
	return len(c.old)
}



================================================
FILE: cache/package_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package cache_test

import (
	"testing"

	gc "gopkg.in/check.v1"
)

func TestPackage(t *testing.T) {
	gc.TestingT(t)
}



================================================
FILE: cert/cert.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Copyright 2016 Cloudbase solutions
// Licensed under the LGPLv3, see LICENCE file for details.
package cert

import (
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"encoding/pem"
	"fmt"

	"github.com/juju/errors"
)

// OtherName type for asn1 encoding. Its A field is marshaled as an
// ASN.1 UTF8String.
type OtherName struct {
	A string `asn1:"utf8"`
}

// GeneralName type for asn1 encoding: an object identifier followed by
// an OtherName value with context-specific tag 0.
type GeneralName struct {
	OID       asn1.ObjectIdentifier
	OtherName `asn1:"tag:0"`
}

// GeneralNames type for asn1 encoding: a single GeneralName with
// context-specific tag 0.
type GeneralNames struct {
	GeneralName `asn1:"tag:0"`
}

var (
	// https://support.microsoft.com/en-us/kb/287547
	//  szOID_NT_PRINCIPAL_NAME 1.3.6.1.4.1.311.20.2.3
	szOID = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 20, 2, 3}
	// http://www.umich.edu/~x509/ssleay/asn1-oids.html
	// 2 5 29 17  subjectAltName
	subjAltName = asn1.ObjectIdentifier{2, 5, 29, 17}
)

// getUPNExtensionValue returns the ASN.1 encoding of a GeneralNames value
// that holds subject.CommonName as a user principal name (szOID).
func getUPNExtensionValue(subject pkix.Name) ([]byte, error) {
	// asn1.Marshal returns the ASN.1 encoding of its argument. Besides
	// the defaults, the struct tags used here are:
	//   utf8  => the string is marshaled as an ASN.1 UTF8String
	//   tag:x => sets the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
	return asn1.Marshal(GeneralNames{
		GeneralName: GeneralName{
			// init our ASN.1 object identifier
			OID: szOID,
			OtherName: OtherName{
				A: subject.CommonName,
			},
		},
	})
}

// ParseCert parses the given PEM-formatted X509 certificate.
// It returns the first PEM block of type "CERTIFICATE"; blocks of other
// types are skipped. It returns an error if no such block is found.
func ParseCert(certPEM string) (*x509.Certificate, error) {
	certPEMData := []byte(certPEM)
	for len(certPEMData) > 0 {
		var certBlock *pem.Block
		certBlock, certPEMData = pem.Decode(certPEMData)
		if certBlock == nil {
			break
		}
		if certBlock.Type == "CERTIFICATE" {
			return x509.ParseCertificate(certBlock.Bytes)
		}
	}
	return nil, errors.New("no certificates found")
}

// ParseCertAndKey parses the given PEM-formatted X509 certificate
// and private key. The key may be of any type that implements
// crypto.Signer.
func ParseCertAndKey(certPEM, keyPEM string) (*x509.Certificate, crypto.Signer, error) { tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) if err != nil { return nil, nil, err } cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) if err != nil { return nil, nil, err } key, ok := tlsCert.PrivateKey.(crypto.Signer) if !ok { return nil, nil, fmt.Errorf("private key with unexpected type %T", tlsCert.PrivateKey) } return cert, key, nil } ================================================ FILE: cert/cert_test.go ================================================ // Copyright 2012, 2013 Canonical Ltd. // Copyright 2016 Cloudbase solutions // Licensed under the LGPLv3, see LICENCE file for details. package cert_test import ( "testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/cert" ) func TestAll(t *testing.T) { gc.TestingT(t) } type certSuite struct{} var _ = gc.Suite(certSuite{}) func (certSuite) TestParseCertificate(c *gc.C) { xcert, err := cert.ParseCert(caCertPEM) c.Assert(err, jc.ErrorIsNil) c.Assert(xcert.Subject.CommonName, gc.Equals, `juju-generated CA for model "juju testing"`) xcert, err = cert.ParseCert(caKeyPEM) c.Check(xcert, gc.IsNil) c.Assert(err, gc.ErrorMatches, "no certificates found") xcert, err = cert.ParseCert("hello") c.Check(xcert, gc.IsNil) c.Assert(err, gc.ErrorMatches, "no certificates found") } func (certSuite) TestParseCertAndKey(c *gc.C) { xcert, key, err := cert.ParseCertAndKey(caCertPEM, caKeyPEM) c.Assert(err, jc.ErrorIsNil) c.Assert(xcert.Subject.CommonName, gc.Equals, `juju-generated CA for model "juju testing"`) c.Assert(key, gc.NotNil) c.Assert(xcert.PublicKey, gc.DeepEquals, key.Public()) } var ( caCertPEM = ` -----BEGIN CERTIFICATE----- MIICHDCCAcagAwIBAgIUfzWn5ktGMxD6OiTgfiZyvKdM+ZYwDQYJKoZIhvcNAQEL BQAwazENMAsGA1UEChMEanVqdTEzMDEGA1UEAwwqanVqdS1nZW5lcmF0ZWQgQ0Eg Zm9yIG1vZGVsICJqdWp1IHRlc3RpbmciMSUwIwYDVQQFExwxMjM0LUFCQ0QtSVMt 
Tk9ULUEtUkVBTC1VVUlEMB4XDTE2MDkyMTEwNDgyN1oXDTI2MDkyODEwNDgyN1ow azENMAsGA1UEChMEanVqdTEzMDEGA1UEAwwqanVqdS1nZW5lcmF0ZWQgQ0EgZm9y IG1vZGVsICJqdWp1IHRlc3RpbmciMSUwIwYDVQQFExwxMjM0LUFCQ0QtSVMtTk9U LUEtUkVBTC1VVUlEMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAL+0X+1zl2vt1wI4 1Q+RnlltJyaJmtwCbHRhREXVGU7t0kTMMNERxqLnuNUyWRz90Rg8s9XvOtCqNYW7 mypGrFECAwEAAaNCMEAwDgYDVR0PAQH/BAQDAgKkMA8GA1UdEwEB/wQFMAMBAf8w HQYDVR0OBBYEFHueMLZ1QJ/2sKiPIJ28TzjIMRENMA0GCSqGSIb3DQEBCwUAA0EA ovZN0RbUHrO8q9Eazh0qPO4mwW9jbGTDz126uNrLoz1g3TyWxIas1wRJ8IbCgxLy XUrBZO5UPZab66lJWXyseA== -----END CERTIFICATE----- ` caKeyPEM = ` -----BEGIN RSA PRIVATE KEY----- MIIBOgIBAAJBAL+0X+1zl2vt1wI41Q+RnlltJyaJmtwCbHRhREXVGU7t0kTMMNER xqLnuNUyWRz90Rg8s9XvOtCqNYW7mypGrFECAwEAAQJAMPa+JaUHgO6foxam/LIB 0u95N3OgFR+dWeBaEsgKDclpREdJ0rXNI+3C3kwqeEZR4omoPlBeSEewSkwHxpmI 0QIhAOjKiHZ5v6R8haleipbDzkGUnZW07hEwL5Ld4MNx/QQ1AiEA0tEzSSNAdM0C M/vY0x5mekIYai8/tFSEG9PJ3ZkpEy0CIQCo9B3YxwI1Un777vbs903iQQeiWP+U EAHnOQvhLgDxpQIgGkpml+9igW5zoOH+h02aQBLwEoXz7tw/YW0HFrCcE70CIGkS ve4WjiEqnQaHNAPy0hY/1DfIgBOSpOfnkFHOk9vX -----END RSA PRIVATE KEY----- ` ) ================================================ FILE: cert/exports_test.go ================================================ // Copyright 2016 Canonical ltd. // Copyright 2016 Cloudbase solutions // Licensed under the LGPLv3, see LICENCE file for details. package cert ================================================ FILE: command.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "os/exec" ) // RunCommand executes the command and return the combined output. func RunCommand(command string, args ...string) (output string, err error) { cmd := exec.Command(command, args...) 
out, err := cmd.CombinedOutput() output = string(out) if err != nil { return output, err } return output, nil } ================================================ FILE: command_test.go ================================================ // Copyright 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "io/ioutil" "path/filepath" "runtime" "github.com/juju/testing" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type EnvironmentPatcher interface { PatchEnvironment(name, value string) } func patchExecutable(patcher EnvironmentPatcher, dir, execName, script string) { patcher.PatchEnvironment("PATH", dir) filename := filepath.Join(dir, execName) ioutil.WriteFile(filename, []byte(script), 0755) } type commandSuite struct { testing.IsolationSuite } var _ = gc.Suite(&commandSuite{}) func (s *commandSuite) TestRunCommandCombinesOutput(c *gc.C) { var content string var cmdName string var expect string if runtime.GOOS != "windows" { content = `#!/bin/bash --norc echo stdout echo stderr 1>&2 ` cmdName = "test-output" expect = "stdout\nstderr\n" } else { content = `@echo off echo stdout echo stderr 1>&2 ` cmdName = "test-output.bat" expect = "stdout\r\nstderr \r\n" } patchExecutable(s, c.MkDir(), cmdName, content) output, err := utils.RunCommand("test-output") c.Assert(err, gc.IsNil) c.Assert(output, gc.Equals, expect) } func (s *commandSuite) TestRunCommandNonZeroExit(c *gc.C) { var content string var cmdName string var expect string if runtime.GOOS != "windows" { content = `#!/bin/bash --norc echo stdout exit 42 ` cmdName = "test-output" expect = "stdout\n" } else { content = `@echo off echo stdout exit 42 ` cmdName = "test-output.bat" expect = "stdout\r\n" } patchExecutable(s, c.MkDir(), cmdName, content) output, err := utils.RunCommand("test-output") c.Assert(err, gc.ErrorMatches, `exit status 42`) c.Assert(output, gc.Equals, expect) } ================================================ FILE: context.go 
================================================ // Copyright 2018 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "fmt" "sync" "time" "golang.org/x/net/context" "github.com/juju/clock" ) // timerCtx is an implementation of context.Context that // is done when a given deadline has passed // (as measured by the Clock in the clock field) type timerCtx struct { clock clock.Clock timer clock.Timer deadline time.Time parent context.Context done chan struct{} // mu guards err. mu sync.Mutex // err holds context.Canceled or context.DeadlineExceeded // after the context has been canceled. // If this is non-nil, then done will have been closed. err error } func (ctx *timerCtx) Deadline() (time.Time, bool) { return ctx.deadline, true } func (ctx *timerCtx) Err() error { ctx.mu.Lock() defer ctx.mu.Unlock() return ctx.err } func (ctx *timerCtx) Value(key any) any { return ctx.parent.Value(key) } func (ctx *timerCtx) Done() <-chan struct{} { return ctx.done } func (ctx *timerCtx) cancel(err error) { ctx.mu.Lock() defer ctx.mu.Unlock() if err == nil { panic("cancel with nil error!") } if ctx.err != nil { // Already canceled - no need to do anything. return } ctx.err = err if ctx.timer != nil { ctx.timer.Stop() } close(ctx.done) } func (ctx *timerCtx) String() string { return fmt.Sprintf("%v.WithDeadline(%s [%s])", ctx.parent, ctx.deadline, ctx.deadline.Sub(ctx.clock.Now())) } // ContextWithTimeout is like context.WithTimeout // except that it works with a clock.Clock rather than // wall-clock time. func ContextWithTimeout(parent context.Context, clk clock.Clock, timeout time.Duration) (context.Context, context.CancelFunc) { return ContextWithDeadline(parent, clk, clk.Now().Add(timeout)) } // ContextWithDeadline is like context.WithDeadline // except that it works with a clock.Clock rather than // wall-clock time. 
func ContextWithDeadline(parent context.Context, clk clock.Clock, deadline time.Time) (context.Context, context.CancelFunc) { d := deadline.Sub(clk.Now()) ctx := &timerCtx{ clock: clk, parent: parent, deadline: deadline, done: make(chan struct{}), } if d <= 0 { // deadline has already passed ctx.cancel(context.DeadlineExceeded) return ctx, func() {} } ctx.timer = clk.NewTimer(d) go func() { select { case <-ctx.timer.Chan(): ctx.cancel(context.DeadlineExceeded) case <-parent.Done(): ctx.cancel(parent.Err()) case <-ctx.done: } }() return ctx, func() { ctx.cancel(context.Canceled) } } ================================================ FILE: context_test.go ================================================ // Copyright 2018 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "fmt" "time" "golang.org/x/net/context" "github.com/juju/clock/testclock" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type contextSuite struct{} var _ = gc.Suite(&contextSuite{}) // Note: the logic in these tests was copied from the tests // in the Go standard library. 
// TestDeadline checks that contexts created by ContextWithDeadline become
// done with context.DeadlineExceeded once the test clock reaches the
// deadline. It covers wrapping in, and deriving from, a context of another
// type, and deadlines that are not in the future.
func (*contextSuite) TestDeadline(c *gc.C) {
	clk := testclock.NewClock(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC))
	ctx, cancel := utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(50*time.Millisecond))
	defer cancel()
	c.Assert(fmt.Sprint(ctx), gc.Equals, `context.Background.WithDeadline(2000-01-01 00:00:00.05 +0000 UTC [50ms])`)
	testContextDeadline(c, ctx, "WithDeadline", clk, 1, 50*time.Millisecond)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(50*time.Millisecond))
	defer cancel()
	o := otherContext{ctx}
	testContextDeadline(c, o, "WithDeadline+otherContext", clk, 1, 50*time.Millisecond)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(50*time.Millisecond))
	defer cancel()
	o = otherContext{ctx}
	// Two timers (50ms inner, 4s outer) are now pending on clk; the inner
	// one fires first and its cancellation propagates to the outer context.
	ctx, _ = utils.ContextWithDeadline(o, clk, clk.Now().Add(4*time.Second))
	testContextDeadline(c, ctx, "WithDeadline+otherContext+WithDeadline", clk, 2, 50*time.Millisecond)

	// Deadlines at or before "now" are canceled immediately, so no timer is
	// created and no clock advance is needed.
	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(-time.Millisecond))
	defer cancel()
	testContextDeadline(c, ctx, "WithDeadline+inthepast", clk, 0, 0)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now())
	testContextDeadline(c, ctx, "WithDeadline+now", clk, 0, 0)
}

// TestTimeout mirrors TestDeadline for ContextWithTimeout.
func (*contextSuite) TestTimeout(c *gc.C) {
	clk := testclock.NewClock(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC))
	ctx, _ := utils.ContextWithTimeout(context.Background(), clk, 50*time.Millisecond)
	c.Assert(fmt.Sprint(ctx), gc.Equals, `context.Background.WithDeadline(2000-01-01 00:00:00.05 +0000 UTC [50ms])`)
	testContextDeadline(c, ctx, "WithTimeout", clk, 1, 50*time.Millisecond)

	ctx, _ = utils.ContextWithTimeout(context.Background(), clk, 50*time.Millisecond)
	o := otherContext{ctx}
	testContextDeadline(c, o, "WithTimeout+otherContext", clk, 1, 50*time.Millisecond)

	ctx, _ = utils.ContextWithTimeout(context.Background(), clk, 50*time.Millisecond)
	o = otherContext{ctx}
	ctx, _ = utils.ContextWithTimeout(o, clk, 3*time.Second)
	testContextDeadline(c, ctx, "WithTimeout+otherContext+WithTimeout", clk, 2, 50*time.Millisecond)
}

// TestCanceledTimeout checks that calling the CancelFunc makes the context
// done with context.Canceled before its deadline.
func (*contextSuite) TestCanceledTimeout(c *gc.C) {
	clk := testclock.NewClock(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC))
	ctx, _ := utils.ContextWithTimeout(context.Background(), clk, time.Second)
	o := otherContext{ctx}
	ctx, cancel := utils.ContextWithTimeout(o, clk, 2*time.Second)
	cancel()
	// NOTE(review): the CancelFunc closes done synchronously, so this sleep
	// looks unnecessary.
	time.Sleep(100 * time.Millisecond) // let cancelation propagate
	select {
	case <-ctx.Done():
	default:
		c.Errorf("<-ctx.Done() blocked, but shouldn't have")
	}
	c.Assert(ctx.Err(), gc.Equals, context.Canceled)
}

// testContextDeadline advances clk by failAfter using WaitAdvance, which
// expects waiters timers to be pending on clk. It then asserts that ctx
// becomes done (within one second of real time) with
// context.DeadlineExceeded. name labels the case in failure messages.
func testContextDeadline(c *gc.C, ctx context.Context, name string, clk *testclock.Clock, waiters int, failAfter time.Duration) {
	err := clk.WaitAdvance(failAfter, 0, waiters)
	c.Assert(err, jc.ErrorIsNil)
	select {
	case <-time.After(time.Second):
		c.Fatalf("%s: context should have timed out", name)
	case <-ctx.Done():
	}
	c.Assert(ctx.Err(), gc.Equals, context.DeadlineExceeded)
}

// otherContext is a Context that's not one of the types defined in context.go.
// This lets us test code paths that differ based on the underlying type of the
// Context.
type otherContext struct {
	context.Context
}

================================================
FILE: du/LICENSE.ricochet2200
================================================
This is free and unencumbered software released into the public domain.

Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.

In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors.
We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to ================================================ FILE: du/diskusage.go ================================================ // Copied from https://github.com/ricochet2200/go-disk-usage // Copyright 2011 Rick Smith. // Use of this source code is governed by a public domain // license that can be found in the LICENSE.ricochet2200 file. // //go:build !windows // +build !windows package du import "syscall" type DiskUsage struct { stat *syscall.Statfs_t } // Returns an object holding the disk usage of volumePath // This function assumes volumePath is a valid path func NewDiskUsage(volumePath string) *DiskUsage { var stat syscall.Statfs_t syscall.Statfs(volumePath, &stat) return &DiskUsage{&stat} } // Total free bytes on file system func (this *DiskUsage) Free() uint64 { return this.stat.Bfree * uint64(this.stat.Bsize) } // Total available bytes on file system to an unpriveleged user func (this *DiskUsage) Available() uint64 { return this.stat.Bavail * uint64(this.stat.Bsize) } // Total size of the file system func (this *DiskUsage) Size() uint64 { return this.stat.Blocks * uint64(this.stat.Bsize) } // Total bytes used in file system func (this *DiskUsage) Used() uint64 { return this.Size() - this.Free() } // Percentage of use on the file system func (this *DiskUsage) Usage() float32 { return float32(this.Used()) / float32(this.Size()) } 
================================================ FILE: du/diskusage_windows.go ================================================ // Copied from https://github.com/ricochet2200/go-disk-usage // Copyright 2011 Rick Smith. // Use of this source code is governed by a public domain // license that can be found in the LICENSE.ricochet2200 file. // package du import ( "syscall" "unsafe" ) type DiskUsage struct { freeBytes int64 totalBytes int64 availBytes int64 } // Returns an object holding the disk usage of volumePath // This function assumes volumePath is a valid path func NewDiskUsage(volumePath string) *DiskUsage { h := syscall.MustLoadDLL("kernel32.dll") c := h.MustFindProc("GetDiskFreeSpaceExW") du := &DiskUsage{} c.Call( uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(volumePath))), uintptr(unsafe.Pointer(&du.freeBytes)), uintptr(unsafe.Pointer(&du.totalBytes)), uintptr(unsafe.Pointer(&du.availBytes))) return du } // Total free bytes on file system func (this *DiskUsage) Free() uint64 { return uint64(this.freeBytes) } // Total available bytes on file system to an unpriveleged user func (this *DiskUsage) Available() uint64 { return uint64(this.availBytes) } // Total size of the file system func (this *DiskUsage) Size() uint64 { return uint64(this.totalBytes) } // Total bytes used in file system func (this *DiskUsage) Used() uint64 { return this.Size() - this.Free() } // Percentage of use on the file system func (this *DiskUsage) Usage() float32 { return float32(this.Used()) / float32(this.Size()) } ================================================ FILE: errors.go ================================================ // Copyright 2024 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "fmt" ) // RcPassthroughError indicates that a Juju plugin command exited with a // non-zero exit code. This error is used to exit with the return code. type RcPassthroughError struct { Code int } // Error implements error. 
func (e *RcPassthroughError) Error() string { return fmt.Sprintf("subprocess encountered error code %v", e.Code) } // IsRcPassthroughError returns whether the error is an RcPassthroughError. func IsRcPassthroughError(err error) bool { _, ok := err.(*RcPassthroughError) return ok } // NewRcPassthroughError creates an error that will have the code used at the // return code from the cmd.Main function rather than the default of 1 if // there is an error. func NewRcPassthroughError(code int) error { return &RcPassthroughError{code} } ================================================ FILE: exec/exec.go ================================================ // Copyright 2016 Canonical Ltd. // Copyright 2016 Cloudbase Solutions // Licensed under the LGPLv3, see LICENCE file for details. package exec import ( "bytes" "fmt" "io/ioutil" "os" "os/exec" "path/filepath" "runtime" "strings" "syscall" "time" "github.com/juju/clock" "github.com/juju/errors" "github.com/juju/loggo/v2" ) var logger = loggo.GetLogger("juju.util.exec") // Parameters for RunCommands. Commands contains one or more commands to be // executed using bash or PowerShell. If WorkingDir is set, this is passed // through. Similarly if the Environment is specified, this is used // for executing the command. // TODO: refactor this to use a config struct and a constructor. Remove todo // and extra code from WaitWithCancel once this is done. type RunParams struct { Commands string WorkingDir string Environment []string Clock clock.Clock KillProcess func(*os.Process) error User string tempDir string stdout *bytes.Buffer stderr *bytes.Buffer ps *exec.Cmd } // ExecResponse contains the return code and output generated by executing a // command. type ExecResponse struct { Code int Stdout []byte Stderr []byte } // mergeEnvironment takes in a string array representing the desired environment // and merges it with the current environment. 
On Windows, clearing the environment, // or having missing environment variables, may lead to standard go packages not working // (os.TempDir relies on $env:TEMP), and powershell erroring out // Currently this function is only used for windows func mergeEnvironment(env []string) []string { if env == nil { return nil } m := make(map[string]string) var tmpEnv []string for _, val := range os.Environ() { varSplit := strings.SplitN(val, "=", 2) m[varSplit[0]] = varSplit[1] } for _, val := range env { varSplit := strings.SplitN(val, "=", 2) m[varSplit[0]] = varSplit[1] } for key, val := range m { tmpEnv = append(tmpEnv, key+"="+val) } return tmpEnv } // shellAndArgs returns the name of the shell command and arguments to run the // specified script. shellAndArgs may write into the provided temporary // directory, which will be maintained until the process exits. func shellAndArgs(tempDir, script, user string) (string, []string, error) { var scriptFile string var cmd string var args []string switch runtime.GOOS { case "windows": scriptFile = filepath.Join(tempDir, "script.ps1") cmd = "powershell.exe" args = []string{ "-NoProfile", "-NonInteractive", "-ExecutionPolicy", "RemoteSigned", "-File", scriptFile, } // Exceptions don't result in a non-zero exit code by default // when using -File. The exit code of an explicit "exit" when // using -Command is ignored and results in an exit code of 1. // We use -File and trap exceptions to cover both. script = "trap {Write-Error $_; exit 1}\n" + script default: scriptFile = filepath.Join(tempDir, "script.sh") if user == "" { cmd = "/bin/bash" args = []string{scriptFile} } else { // Need to make the tempDir readable by all so the user can see it. 
err := os.Chmod(tempDir, 0755) if err != nil { return "", nil, errors.Annotatef(err, "making tempdir readable by %q", user) } cmd = "/bin/su" args = []string{user, "--login", "--command", fmt.Sprintf("/bin/bash %s", scriptFile)} } } err := ioutil.WriteFile(scriptFile, []byte(script), 0644) if err != nil { return "", nil, err } return cmd, args, nil } // Run sets up the command environment (environment variables, working dir) // and starts the process. The commands are passed into bash on Linux machines // and to powershell on Windows machines. func (r *RunParams) Run() error { if runtime.GOOS == "windows" { r.Environment = mergeEnvironment(r.Environment) } tempDir, err := ioutil.TempDir("", "juju-exec") if err != nil { return err } shell, args, err := shellAndArgs(tempDir, r.Commands, r.User) if err != nil { if err := os.RemoveAll(tempDir); err != nil { logger.Warningf("failed to remove temporary directory: %v", err) } return err } r.ps = exec.Command(shell, args...) if r.Environment != nil { r.ps.Env = r.Environment } if r.WorkingDir != "" { r.ps.Dir = r.WorkingDir } r.populateSysProcAttr() // If there is no user provided KillProcess function we // use the default one. if r.KillProcess == nil { r.KillProcess = KillProcess } r.tempDir = tempDir r.stdout = &bytes.Buffer{} r.stderr = &bytes.Buffer{} r.ps.Stdout = r.stdout r.ps.Stderr = r.stderr return r.ps.Start() } // Process returns the *os.Process instance of the current running process // This will allow us to kill the process if needed, or get more information // on the process func (r *RunParams) Process() *os.Process { if r.ps != nil && r.ps.Process != nil { return r.ps.Process } return nil } // Wait blocks until the process exits, and returns an ExecResponse type // containing stdout, stderr and the return code of the process. If a non-zero // return code is returned, this is collected as the code for the response and // this does not classify as an error. 
func (r *RunParams) Wait() (*ExecResponse, error) { var err error if r.ps == nil { return nil, errors.New("No process has been started yet") } err = r.ps.Wait() if err := os.RemoveAll(r.tempDir); err != nil { logger.Warningf("failed to remove temporary directory: %v", err) } result := &ExecResponse{ Stdout: r.stdout.Bytes(), Stderr: r.stderr.Bytes(), } if ee, ok := err.(*exec.ExitError); ok && err != nil { status := ee.ProcessState.Sys().(syscall.WaitStatus) if status.Exited() { // A non-zero return code isn't considered an error here. result.Code = status.ExitStatus() err = nil } logger.Infof("run result: %v", ee) } return result, err } // ErrCancelled is returned by WaitWithCancel in case it successfully manages to kill // the running process. var ErrCancelled = errors.New("command cancelled") // timeWaitForKill reperesent the time we wait after attempting to kill a // process before bailing out and returning. const timeWaitForKill = 30 * time.Second type resultWithError struct { execResult *ExecResponse err error } // WaitWithCancel waits until the process exits or until a signal is sent on the // cancel channel. In case a signal is sent it first tries to kill the process and // return ErrCancelled. If it fails at killing the process it will return anyway // and report the problematic PID. func (r *RunParams) WaitWithCancel(cancel <-chan struct{}) (*ExecResponse, error) { // TODO: Remove this once we make Clock a required field _clock := r.Clock if _clock == nil { _clock = clock.WallClock } done := make(chan resultWithError, 1) go func() { defer close(done) waitResult, err := r.Wait() done <- resultWithError{waitResult, err} }() select { case resWithError := <-done: return resWithError.execResult, errors.Trace(resWithError.err) case <-cancel: logger.Debugf("attempting to kill process") err := r.KillProcess(r.ps.Process) if err != nil { logger.Debugf("kill returned: %s", err) } // After we issue a kill we expect the wait above to return within timeWaitForKill. 
// In case it doesn't we just go on and assume the process is stuck, but we don't block select { case resWithError := <-done: return resWithError.execResult, ErrCancelled case <-_clock.After(timeWaitForKill): return nil, errors.Errorf("tried to kill process %v, but timed out", r.ps.Process.Pid) } } } // RunCommands executes the Commands specified in the RunParams using // powershell on windows, and '/bin/bash -s' on everything else, // passing the commands through as stdin, and collecting // stdout and stderr. If a non-zero return code is returned, this is // collected as the code for the response and this does not classify as an // error. func RunCommands(run RunParams) (*ExecResponse, error) { err := run.Run() if err != nil { return nil, err } return run.Wait() } ================================================ FILE: exec/exec_internal_test.go ================================================ // Copyright 2017 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package exec import ( "os" "path/filepath" "runtime" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" ) type execSuite struct { testing.IsolationSuite } var _ = gc.Suite(&execSuite{}) func (*execSuite) TestShellAndArgsNoUserSpecified(c *gc.C) { if runtime.GOOS == "windows" { c.Skip("non-windows only test") } dir := c.MkDir() stat, err := os.Stat(dir) c.Assert(err, jc.ErrorIsNil) c.Assert(stat.Mode().Perm(), gc.Equals, os.FileMode(0700)) cmd, args, err := shellAndArgs(dir, "env", "") c.Assert(err, jc.ErrorIsNil) scriptFile := filepath.Join(dir, "script.sh") c.Assert(cmd, gc.Equals, "/bin/bash") c.Assert(args, jc.DeepEquals, []string{scriptFile}) } func (*execSuite) TestShellAndArgsAsUser(c *gc.C) { if runtime.GOOS == "windows" { c.Skip("non-windows only test") } dir := c.MkDir() stat, err := os.Stat(dir) c.Assert(err, jc.ErrorIsNil) c.Assert(stat.Mode().Perm(), gc.Equals, os.FileMode(0700)) cmd, args, err := shellAndArgs(dir, "env", "ubuntu") 
c.Assert(err, jc.ErrorIsNil) scriptFile := filepath.Join(dir, "script.sh") c.Assert(cmd, gc.Equals, "/bin/su") command := "/bin/bash " + scriptFile c.Assert(args, jc.DeepEquals, []string{"ubuntu", "--login", "--command", command}) // The directory is now readable by everyone. stat, err = os.Stat(dir) c.Assert(err, jc.ErrorIsNil) c.Assert(stat.Mode().Perm(), gc.Equals, os.FileMode(0755)) // And the file is world readable stat, err = os.Stat(scriptFile) c.Assert(err, jc.ErrorIsNil) c.Assert(stat.Mode().Perm(), gc.Equals, os.FileMode(0644)) } ================================================ FILE: exec/exec_linux_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package exec_test import ( jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/exec" ) // 0 is thrown by linux because RunParams.Wait // only sets the code if the process exits cleanly const cancelErrCode = 0 func (*execSuite) TestRunCommands(c *gc.C) { newDir := c.MkDir() for i, test := range []struct { message string commands string workingDir string environment []string stdout string stderr string code int }{ { message: "test stdout capture", commands: "echo testing stdout", stdout: "testing stdout\n", }, { message: "test stderr capture", commands: "echo testing stderr >&2", stderr: "testing stderr\n", }, { message: "test return code", commands: "exit 42", code: 42, }, { message: "test working dir", commands: "pwd", workingDir: newDir, stdout: newDir + "\n", }, { message: "test environment", commands: "echo $OMG_IT_WORKS", environment: []string{"OMG_IT_WORKS=like magic"}, stdout: "like magic\n", }, { message: "multiple commands", commands: "cat\necho 123", stdout: "123\n", }, } { c.Logf("%v: %s", i, test.message) params := exec.RunParams{ Commands: test.commands, WorkingDir: test.workingDir, Environment: test.environment, } result, err := exec.RunCommands(params) 
c.Assert(err, gc.IsNil) c.Assert(string(result.Stdout), gc.Equals, test.stdout) c.Assert(string(result.Stderr), gc.Equals, test.stderr) c.Assert(result.Code, gc.Equals, test.code) err = params.Run() c.Assert(err, gc.IsNil) c.Assert(params.Process(), gc.Not(gc.IsNil)) result, err = params.Wait() c.Assert(err, gc.IsNil) c.Assert(string(result.Stdout), gc.Equals, test.stdout) c.Assert(string(result.Stderr), gc.Equals, test.stderr) c.Assert(result.Code, gc.Equals, test.code) err = params.Run() c.Assert(err, gc.IsNil) c.Assert(params.Process(), gc.Not(gc.IsNil)) result, err = params.WaitWithCancel(nil) c.Assert(err, gc.IsNil) c.Assert(string(result.Stdout), gc.Equals, test.stdout) c.Assert(string(result.Stderr), gc.Equals, test.stderr) c.Assert(result.Code, gc.Equals, test.code) } } func (*execSuite) TestExecUnknownCommand(c *gc.C) { result, err := exec.RunCommands( exec.RunParams{ Commands: "unknown-command", }, ) c.Assert(err, gc.IsNil) c.Assert(result.Stdout, gc.HasLen, 0) c.Assert(string(result.Stderr), jc.Contains, "unknown-command: command not found") // 127 is a special bash return code meaning command not found. c.Assert(result.Code, gc.Equals, 127) } ================================================ FILE: exec/exec_test.go ================================================ // Copyright 2016 Canonical Ltd. // Copyright 2016 Cloudbase Solutions // Licensed under the LGPLv3, see LICENCE file for details. 
package exec_test

import (
	"fmt"
	"os"
	"time"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/clock"
	"github.com/juju/utils/v4/exec"
)

type execSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&execSuite{})

// TestWaitWithCancel checks that a pending signal on the cancel channel
// kills a long-running command and makes WaitWithCancel return
// ErrCancelled along with the (empty) captured output.
func (*execSuite) TestWaitWithCancel(c *gc.C) {
	params := exec.RunParams{
		Commands: "sleep 100",
		// Nothing is ever sent on this channel, so the kill timeout in
		// WaitWithCancel never fires.
		Clock: &mockClock{C: make(chan time.Time)},
	}
	err := params.Run()
	c.Assert(err, gc.IsNil)
	c.Assert(params.Process(), gc.Not(gc.IsNil))

	cancelChan := make(chan struct{}, 1)
	defer close(cancelChan)
	cancelChan <- struct{}{}
	result, err := params.WaitWithCancel(cancelChan)
	c.Assert(err, gc.Equals, exec.ErrCancelled)
	c.Assert(string(result.Stdout), gc.Equals, "")
	c.Assert(string(result.Stderr), gc.Equals, "")
	c.Assert(result.Code, gc.Equals, cancelErrCode)
}

// TestKillAbortedIfUnsuccessfull checks that WaitWithCancel gives up and
// reports the PID when the kill does not make the process exit. The kill
// function is a no-op and the mock clock's timeout channel is pre-filled,
// so the timeout branch is taken.
// NOTE(review): the "sleep 100" process is left running after this test.
func (s *execSuite) TestKillAbortedIfUnsuccessfull(c *gc.C) {
	killCalled := false
	mockChan := make(chan time.Time, 1)
	defer close(mockChan)
	params := exec.RunParams{
		Commands:    "sleep 100",
		WorkingDir:  "",
		Environment: []string{},
		Clock:       &mockClock{C: mockChan},
		KillProcess: func(*os.Process) error {
			killCalled = true
			return nil
		},
	}
	err := params.Run()
	c.Assert(err, gc.IsNil)
	c.Assert(params.Process(), gc.Not(gc.IsNil))

	cancelChan := make(chan struct{}, 1)
	defer close(cancelChan)
	cancelChan <- struct{}{}
	mockChan <- time.Now()
	res, err := params.WaitWithCancel(cancelChan)
	c.Assert(err, gc.ErrorMatches, fmt.Sprintf("tried to kill process %d, but timed out", params.Process().Pid))
	c.Assert(res, gc.IsNil)
	c.Assert(killCalled, jc.IsTrue)
}

// mockClock overrides After to return a test-controlled channel. Its other
// clock.Clock methods come from the embedded interface, which the tests
// leave nil, so calling them would panic.
type mockClock struct {
	clock.Clock
	C <-chan time.Time
}

func (m *mockClock) After(t time.Duration) <-chan time.Time {
	return m.C
}

================================================
FILE: exec/exec_unix.go
================================================
// Copyright 2016 Canonical Ltd.
// Copyright 2016 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows // +build !windows package exec import ( "os" "syscall" ) // KillProcess tries to kill the process being ran by RunParams // We need this convoluted implementation because everything // ran under the bash script is spawned as a different process // and doesn't get killed by a regular process.Kill() // For details see https://groups.google.com/forum/#!topic/golang-nuts/XoQ3RhFBJl8 func KillProcess(proc *os.Process) error { pgid, err := syscall.Getpgid(proc.Pid) if err == nil { return syscall.Kill(-pgid, 15) // note the minus sign } return nil } // populateSysProcAttr exists so that the method Kill on the same struct // can work correctly. For more information see Kill's comment. func (r *RunParams) populateSysProcAttr() { r.ps.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} } ================================================ FILE: exec/exec_windows.go ================================================ // Copyright 2016 Canonical Ltd. // Copyright 2016 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. //go:build windows // +build windows package exec import ( "os" ) // KillProcess tries to kill the process passed in. func KillProcess(proc *os.Process) error { return proc.Kill() } // populateSysProcAttr is a noop on windows func (r *RunParams) populateSysProcAttr() {} ================================================ FILE: exec/exec_windows_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package exec_test import ( "path/filepath" "strings" "syscall" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/exec" ) // 1 is thrown by powershell after the a command is cancelled const cancelErrCode = 1 // longPath is copied over from the symlink package. 
This should be removed // if we add it to gc or in some other convenience package func longPath(path string) ([]uint16, error) { pathp, err := syscall.UTF16FromString(path) if err != nil { return nil, err } longp := pathp n, err := syscall.GetLongPathName(&pathp[0], &longp[0], uint32(len(longp))) if err != nil { return nil, err } if n > uint32(len(longp)) { longp = make([]uint16, n) n, err = syscall.GetLongPathName(&pathp[0], &longp[0], uint32(len(longp))) if err != nil { return nil, err } } longp = longp[:n] return longp, nil } func longPathAsString(path string) (string, error) { longp, err := longPath(path) if err != nil { return "", err } return syscall.UTF16ToString(longp), nil } func (*execSuite) TestRunCommands(c *gc.C) { newDir, err := longPathAsString(c.MkDir()) c.Assert(err, gc.IsNil) for i, test := range []struct { message string commands string workingDir string environment []string stdout string stderr string code int }{ { message: "test stdout capture", commands: "echo 'testing stdout'", stdout: "testing stdout\r\n", }, { message: "test stderr capture", commands: "Write-Error 'testing stderr'", stderr: "testing stderr\r\n", }, { message: "test return code", commands: "exit 42", code: 42, }, { message: "test working dir", commands: "(pwd).Path", workingDir: newDir, stdout: filepath.FromSlash(newDir) + "\r\n", }, { message: "test environment", commands: "echo $env:OMG_IT_WORKS", environment: []string{"OMG_IT_WORKS=like magic"}, stdout: "like magic\r\n", }, } { c.Logf("%v: %s", i, test.message) params := exec.RunParams{ Commands: test.commands, WorkingDir: test.workingDir, Environment: test.environment, } result, err := exec.RunCommands(params) c.Assert(err, gc.IsNil) c.Assert(string(result.Stdout), gc.Equals, test.stdout) c.Assert(string(result.Stderr), jc.Contains, test.stderr) c.Assert(result.Code, gc.Equals, test.code) err = params.Run() c.Assert(err, gc.IsNil) c.Assert(params.Process(), gc.Not(gc.IsNil)) result, err = params.Wait() c.Assert(err, 
gc.IsNil) c.Assert(string(result.Stdout), gc.Equals, test.stdout) c.Assert(string(result.Stderr), jc.Contains, test.stderr) c.Assert(result.Code, gc.Equals, test.code) err = params.Run() c.Assert(err, gc.IsNil) c.Assert(params.Process(), gc.Not(gc.IsNil)) result, err = params.WaitWithCancel(nil) c.Assert(err, gc.IsNil) c.Assert(string(result.Stdout), gc.Equals, test.stdout) c.Assert(string(result.Stderr), jc.Contains, test.stderr) c.Assert(result.Code, gc.Equals, test.code) } } func (*execSuite) TestExecUnknownCommand(c *gc.C) { result, err := exec.RunCommands( exec.RunParams{ Commands: "unknown-command", }, ) c.Assert(err, gc.IsNil) c.Assert(result.Stdout, gc.HasLen, 0) stderr := strings.Replace(string(result.Stderr), "\r\n", "", -1) c.Assert(stderr, jc.Contains, "is not recognized as the name of a cmdlet") // 1 is returned by RunCommands when powershell commands throw exceptions c.Assert(result.Code, gc.Equals, 1) } ================================================ FILE: exec/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package exec_test import ( "testing" gc "gopkg.in/check.v1" ) func Test(t *testing.T) { gc.TestingT(t) } ================================================ FILE: export_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "time" ) var ( GOMAXPROCS = &gomaxprocs NumCPU = &numCPU ResolveSudoByFunc = resolveSudo ) func ExposeBackoffTimerDuration(bot *BackoffTimer) time.Duration { return bot.currentDuration } ================================================ FILE: file.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils import ( "fmt" "io" "io/ioutil" "os" "path" "path/filepath" "regexp" "github.com/juju/errors" ) // UserHomeDir returns the home directory for the specified user, or the // home directory for the current user if the specified user is empty. func UserHomeDir(userName string) (hDir string, err error) { if userName == "" { // TODO (wallyworld) - fix tests on Windows // Ordinarily, we'd always use user.Current() to get the current user // and then get the HomeDir from that. But our tests rely on poking // a value into $HOME in order to override the normal home dir for the // current user. So we're forced to use Home() to make the tests pass. // All of our tests currently construct paths with the default user in // mind eg "~/foo". return Home(), nil } hDir, err = homeDir(userName) if err != nil { return "", err } return hDir, nil } // Only match paths starting with ~ (~user/test, ~/test). This will prevent // accidental expansion on Windows when short form paths are present (C:\users\ADMINI~1\test) var userHomePathRegexp = regexp.MustCompile("(^~(?P[^/]*))(?P.*)") // NormalizePath expands a path containing ~ to its absolute form, // and removes any .. or . path elements. func NormalizePath(dir string) (string, error) { if userHomePathRegexp.MatchString(dir) { user := userHomePathRegexp.ReplaceAllString(dir, "$user") userHomeDir, err := UserHomeDir(user) if err != nil { return "", err } dir = userHomePathRegexp.ReplaceAllString(dir, fmt.Sprintf("%s$path", userHomeDir)) } return filepath.Clean(dir), nil } // ExpandPath normalises (via Normalize) a path returning an absolute path. func ExpandPath(path string) (string, error) { normPath, err := NormalizePath(path) if err != nil { return "", errors.Annotate(err, "unable to normalise file path") } return filepath.Abs(normPath) } // EnsureBaseDir ensures that path is always prefixed by baseDir, // allowing for the fact that path might have a Window drive letter in // it. 
func EnsureBaseDir(baseDir, path string) string { if baseDir == "" { return path } volume := filepath.VolumeName(path) return filepath.Join(baseDir, path[len(volume):]) } // JoinServerPath joins any number of path elements into a single path, adding // a path separator (based on the current juju server OS) if necessary. The // result is Cleaned; in particular, all empty strings are ignored. func JoinServerPath(elem ...string) string { return path.Join(elem...) } // UniqueDirectory returns "path/name" if that directory doesn't exist. If it // does, the method starts appending .1, .2, etc until a unique name is found. func UniqueDirectory(path, name string) (string, error) { dir := filepath.Join(path, name) _, err := os.Stat(dir) if os.IsNotExist(err) { return dir, nil } for i := 1; ; i++ { dir := filepath.Join(path, fmt.Sprintf("%s.%d", name, i)) _, err := os.Stat(dir) if os.IsNotExist(err) { return dir, nil } else if err != nil { return "", err } } } // CopyFile writes the contents of the given source file to dest. func CopyFile(dest, source string) error { df, err := os.Create(dest) if err != nil { return err } f, err := os.Open(source) if err != nil { return err } defer f.Close() _, err = io.Copy(df, f) return err } // AtomicWriteFileAndChange atomically writes the filename with the // given contents and calls the given function after the contents were // written, but before the file is renamed. func AtomicWriteFileAndChange(filename string, contents []byte, change func(string) error) (err error) { dir, file := filepath.Split(filename) f, err := ioutil.TempFile(dir, file) if err != nil { return fmt.Errorf("cannot create temp file: %v", err) } defer func() { _ = f.Close() }() defer func() { if err != nil { // Don't leave the temp file lying around on error. // Close the file before removing. Trying to remove an open file on // Windows will fail. 
_ = f.Close() _ = os.Remove(f.Name()) } }() if _, err := f.Write(contents); err != nil { return fmt.Errorf("cannot write %q contents: %v", filename, err) } if err := f.Sync(); err != nil { return err } if err := f.Close(); err != nil { return err } if err := change(f.Name()); err != nil { return err } if err := ReplaceFile(f.Name(), filename); err != nil { return fmt.Errorf("cannot replace %q with %q: %v", f.Name(), filename, err) } return nil } // AtomicWriteFile atomically writes the filename with the given // contents and permissions, replacing any existing file at the same // path. func AtomicWriteFile(filename string, contents []byte, perms os.FileMode) (err error) { return AtomicWriteFileAndChange(filename, contents, func(f string) error { // FileMod.Chmod() is not implemented on Windows, however, os.Chmod() is if err := os.Chmod(f, perms); err != nil { return fmt.Errorf("cannot set permissions: %v", err) } return nil }) } ================================================ FILE: file_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "fmt" "io/ioutil" "os" "os/user" "path/filepath" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type fileSuite struct { testing.IsolationSuite } var _ = gc.Suite(&fileSuite{}) func (*fileSuite) TestNormalizePath(c *gc.C) { home := filepath.FromSlash(c.MkDir()) err := utils.SetHome(home) c.Assert(err, gc.IsNil) // TODO (frankban) bug 1324841: improve the isolation of this suite. 
currentUser, err := user.Current() c.Assert(err, gc.IsNil) for i, test := range []struct { path string expected string err string }{{ path: filepath.FromSlash("/var/lib/juju"), expected: filepath.FromSlash("/var/lib/juju"), }, { path: "~/foo", expected: filepath.Join(home, "foo"), }, { path: "~/foo//../bar", expected: filepath.Join(home, "bar"), }, { path: "~", expected: home, }, { path: "~" + currentUser.Username, expected: currentUser.HomeDir, }, { path: "~" + currentUser.Username + "/foo", expected: filepath.Join(currentUser.HomeDir, "foo"), }, { path: "~" + currentUser.Username + "/foo//../bar", expected: filepath.Join(currentUser.HomeDir, "bar"), }, { path: filepath.FromSlash("foo~bar/baz"), expected: filepath.FromSlash("foo~bar/baz"), }, { path: "~foobar/path", err: ".*" + utils.NoSuchUserErrRegexp, }} { c.Logf("test %d: %s", i, test.path) actual, err := utils.NormalizePath(test.path) if test.err != "" { c.Check(err, gc.ErrorMatches, test.err) } else { c.Check(err, gc.IsNil) c.Check(actual, gc.Equals, test.expected) } } } func (*fileSuite) TestExpandPath(c *gc.C) { home := filepath.FromSlash(c.MkDir()) err := utils.SetHome(home) c.Assert(err, gc.IsNil) currentUser, err := user.Current() c.Assert(err, gc.IsNil) cwd, err := os.Getwd() c.Assert(err, gc.IsNil) for i, test := range []struct { path string expected string err string }{{ path: filepath.FromSlash("/var/lib/juju"), expected: filepath.FromSlash("/var/lib/juju"), }, { path: "~/foo", expected: filepath.Join(home, "foo"), }, { path: "~/foo//../bar", expected: filepath.Join(home, "bar"), }, { path: "~", expected: home, }, { path: "~" + currentUser.Username, expected: currentUser.HomeDir, }, { path: "~" + currentUser.Username + "/foo", expected: filepath.Join(currentUser.HomeDir, "foo"), }, { path: "~" + currentUser.Username + "/foo//../bar", expected: filepath.Join(currentUser.HomeDir, "bar"), }, { path: filepath.FromSlash("foo~bar/baz"), expected: filepath.Join(cwd, "foo~bar/baz"), }, { path: 
filepath.FromSlash("foo/bar"), expected: filepath.Join(cwd, "foo", "bar"), }, { path: filepath.FromSlash("foo/../bar"), expected: filepath.Join(cwd, "bar"), }, { path: filepath.FromSlash("foo/./bar"), expected: filepath.Join(cwd, "foo", "bar"), }, { path: "~foobar/path", err: ".*" + utils.NoSuchUserErrRegexp, }} { c.Logf("test %d: %s", i, test.path) actual, err := utils.ExpandPath(test.path) if test.err != "" { c.Check(err, gc.ErrorMatches, test.err) } else { c.Check(err, gc.IsNil) c.Check(actual, gc.Equals, test.expected) c.Check(filepath.IsAbs(actual), jc.IsTrue) } } } func (*fileSuite) TestCopyFile(c *gc.C) { dir := c.MkDir() f, err := ioutil.TempFile(dir, "source") c.Assert(err, gc.IsNil) defer f.Close() _, err = f.Write([]byte("hello world")) c.Assert(err, gc.IsNil) dest := filepath.Join(dir, "dest") err = utils.CopyFile(dest, f.Name()) c.Assert(err, gc.IsNil) data, err := ioutil.ReadFile(dest) c.Assert(err, gc.IsNil) c.Assert(string(data), gc.Equals, "hello world") } var atomicWriteFileTests = []struct { summary string change func(filename string, contents []byte) error check func(c *gc.C, fileInfo os.FileInfo) expectErr string }{{ summary: "atomic file write and chmod 0644", change: func(filename string, contents []byte) error { return utils.AtomicWriteFile(filename, contents, 0765) }, check: func(c *gc.C, fi os.FileInfo) { c.Assert(fi.Mode(), gc.Equals, 0765) }, }, { summary: "atomic file write and change", change: func(filename string, contents []byte) error { chmodChange := func(f string) error { // FileMod.Chmod() is not implemented on Windows, however, os.Chmod() is return os.Chmod(f, 0700) } return utils.AtomicWriteFileAndChange(filename, contents, chmodChange) }, check: func(c *gc.C, fi os.FileInfo) { c.Assert(fi.Mode(), gc.Equals, 0700) }, }, { summary: "atomic file write empty contents", change: func(filename string, contents []byte) error { nopChange := func(string) error { return nil } return utils.AtomicWriteFileAndChange(filename, contents, 
nopChange) }, }, { summary: "atomic file write and failing change func", change: func(filename string, contents []byte) error { errChange := func(string) error { return fmt.Errorf("pow!") } return utils.AtomicWriteFileAndChange(filename, contents, errChange) }, expectErr: "pow!", }} func (*fileSuite) TestAtomicWriteFile(c *gc.C) { dir := c.MkDir() name := "test.file" path := filepath.Join(dir, name) assertDirContents := func(names ...string) { fis, err := ioutil.ReadDir(dir) c.Assert(err, gc.IsNil) c.Assert(fis, gc.HasLen, len(names)) for i, name := range names { c.Assert(fis[i].Name(), gc.Equals, name) } } assertNotExist := func(path string) { _, err := os.Lstat(path) c.Assert(err, jc.Satisfies, os.IsNotExist) } for i, test := range atomicWriteFileTests { c.Logf("test %d: %s", i, test.summary) // First - test with file not already there. assertDirContents() assertNotExist(path) contents := []byte("some\ncontents") err := test.change(path, contents) if test.expectErr == "" { c.Assert(err, gc.IsNil) data, err := ioutil.ReadFile(path) c.Assert(err, gc.IsNil) c.Assert(data, jc.DeepEquals, contents) assertDirContents(name) } else { c.Assert(err, gc.ErrorMatches, test.expectErr) assertDirContents() continue } // Second - test with a file already there. contents = []byte("new\ncontents") err = test.change(path, contents) c.Assert(err, gc.IsNil) data, err := ioutil.ReadFile(path) c.Assert(err, gc.IsNil) c.Assert(data, jc.DeepEquals, contents) assertDirContents(name) // Remove the file to reset scenario. 
c.Assert(os.Remove(path), gc.IsNil) } } func (*fileSuite) TestMoveFile(c *gc.C) { d := c.MkDir() dest := filepath.Join(d, "foo") f1Name := filepath.Join(d, ".foo1") f2Name := filepath.Join(d, ".foo2") err := ioutil.WriteFile(f1Name, []byte("macaroni"), 0644) c.Assert(err, gc.IsNil) err = ioutil.WriteFile(f2Name, []byte("cheese"), 0644) c.Assert(err, gc.IsNil) ok, err := utils.MoveFile(f1Name, dest) c.Assert(ok, gc.Equals, true) c.Assert(err, gc.IsNil) ok, err = utils.MoveFile(f2Name, dest) c.Assert(ok, gc.Equals, false) c.Assert(err, gc.NotNil) contents, err := ioutil.ReadFile(dest) c.Assert(err, gc.IsNil) c.Assert(contents, gc.DeepEquals, []byte("macaroni")) } ================================================ FILE: file_unix.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. //go:build !windows // +build !windows package utils import ( "fmt" "os" "os/user" "strconv" "strings" "syscall" "github.com/juju/errors" ) func homeDir(userName string) (string, error) { u, err := user.Lookup(userName) if err != nil { return "", errors.NewUserNotFound(err, "no such user") } return u.HomeDir, nil } // MoveFile atomically moves the source file to the destination, returning // whether the file was moved successfully. If the destination already exists, // it returns an error rather than overwrite it. // // On unix systems, an error may occur with a successful move, if the source // file location cannot be unlinked. func MoveFile(source, destination string) (bool, error) { err := os.Link(source, destination) if err != nil { return false, err } err = os.Remove(source) if err != nil { return true, err } return true, nil } // ReplaceFile atomically replaces the destination file or directory // with the source. The errors that are returned are identical to // those returned by os.Rename. 
func ReplaceFile(source, destination string) error { return os.Rename(source, destination) } // MakeFileURL returns a file URL if a directory is passed in else it does nothing func MakeFileURL(in string) string { if strings.HasPrefix(in, "/") { return "file://" + in } return in } // ChownPath sets the uid and gid of path to match that of the user // specified. func ChownPath(path, username string) error { u, err := user.Lookup(username) if err != nil { return fmt.Errorf("cannot lookup %q user id: %v", username, err) } uid, err := strconv.Atoi(u.Uid) if err != nil { return fmt.Errorf("invalid user id %q: %v", u.Uid, err) } gid, err := strconv.Atoi(u.Gid) if err != nil { return fmt.Errorf("invalid group id %q: %v", u.Gid, err) } return os.Chown(path, uid, gid) } // IsFileOwner checks to see if the ownership of the file corresponds to // the same username func IsFileOwner(path, username string) (bool, error) { u, err := user.Lookup(username) if err != nil { return false, errors.Annotatef(err, "cannot lookup %q user id", username) } info, err := os.Stat(path) if err != nil { return false, errors.Trace(err) } stat, ok := info.Sys().(*syscall.Stat_t) if !ok { return false, fmt.Errorf("cannot lookup %q file", path) } return (strconv.Itoa(int(stat.Uid)) == u.Uid && strconv.Itoa(int(stat.Gid)) == u.Gid), nil } ================================================ FILE: file_unix_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
//go:build !windows // +build !windows package utils_test import ( "fmt" "os" "path/filepath" "time" gc "gopkg.in/check.v1" "github.com/juju/errors" "github.com/juju/utils/v4" ) type unixFileSuite struct { } var _ = gc.Suite(&unixFileSuite{}) func (s *unixFileSuite) TestEnsureBaseDir(c *gc.C) { c.Assert(utils.EnsureBaseDir(`/a`, `/b/c`), gc.Equals, `/a/b/c`) c.Assert(utils.EnsureBaseDir(`/`, `/b/c`), gc.Equals, `/b/c`) c.Assert(utils.EnsureBaseDir(``, `/b/c`), gc.Equals, `/b/c`) } func (s *unixFileSuite) TestFileOwner(c *gc.C) { username, err := utils.LocalUsername() c.Assert(err, gc.IsNil) path := filepath.Join(os.TempDir(), fmt.Sprintf("file-%d", time.Now().UnixNano())) _, err = os.Create(path) c.Assert(err, gc.IsNil) ok, err := utils.IsFileOwner(path, username) c.Assert(err, gc.IsNil) c.Assert(ok, gc.Equals, true) } func (s *unixFileSuite) TestFileOwnerUsingRoot(c *gc.C) { path := filepath.Join(os.TempDir(), fmt.Sprintf("file-%d", time.Now().UnixNano())) _, err := os.Create(path) c.Assert(err, gc.IsNil) ok, err := utils.IsFileOwner(path, "root") c.Assert(err, gc.IsNil) c.Assert(ok, gc.Equals, false) } func (s *unixFileSuite) TestFileOwnerWithInvalidPath(c *gc.C) { username, err := utils.LocalUsername() c.Assert(err, gc.IsNil) path := filepath.Join(os.TempDir(), "file-bad") ok, err := utils.IsFileOwner(path, username) c.Assert(errors.Cause(err), gc.ErrorMatches, "stat .*: no such file or directory") c.Assert(ok, gc.Equals, false) } func (s *unixFileSuite) TestFileOwnerWithInvalidUsername(c *gc.C) { path := filepath.Join(os.TempDir(), fmt.Sprintf("file-%d", time.Now().UnixNano())) _, err := os.Create(path) c.Assert(err, gc.IsNil) ok, err := utils.IsFileOwner(path, "invalid") c.Assert(errors.Cause(err), gc.ErrorMatches, "user: unknown user invalid") c.Assert(ok, gc.Equals, false) } ================================================ FILE: file_windows.go ================================================ // Copyright 2013 Canonical Ltd. 
// Licensed under the LGPLv3, see LICENCE file for details. //go:build windows // +build windows package utils import ( "fmt" "os" "path/filepath" "syscall" "unsafe" "github.com/juju/errors" ) const ( movefile_replace_existing = 0x1 movefile_write_through = 0x8 ) //sys moveFileEx(lpExistingFileName *uint16, lpNewFileName *uint16, dwFlags uint32) (err error) = MoveFileExW // MoveFile atomically moves the source file to the destination, returning // whether the file was moved successfully. If the destination already exists, // it returns an error rather than overwrite it. func MoveFile(source, destination string) (bool, error) { src, err := syscall.UTF16PtrFromString(source) if err != nil { return false, &os.LinkError{"move", source, destination, err} } dest, err := syscall.UTF16PtrFromString(destination) if err != nil { return false, &os.LinkError{"move", source, destination, err} } // see http://msdn.microsoft.com/en-us/library/windows/desktop/aa365240(v=vs.85).aspx if err := moveFileEx(src, dest, movefile_write_through); err != nil { return false, &os.LinkError{"move", source, destination, err} } return true, nil } // ReplaceFile atomically replaces the destination file or directory with the source. // The errors that are returned are identical to those returned by os.Rename. 
func ReplaceFile(source, destination string) error { src, err := syscall.UTF16PtrFromString(source) if err != nil { return &os.LinkError{"replace", source, destination, err} } dest, err := syscall.UTF16PtrFromString(destination) if err != nil { return &os.LinkError{"replace", source, destination, err} } // see http://msdn.microsoft.com/en-us/library/windows/desktop/aa365240(v=vs.85).aspx if err := moveFileEx(src, dest, movefile_replace_existing|movefile_write_through); err != nil { return &os.LinkError{"replace", source, destination, err} } return nil } // MakeFileURL returns a proper file URL for the given path/directory func MakeFileURL(in string) string { in = filepath.ToSlash(in) // for windows at least should be : to be considered valid // so we cant do anything with less than that. if len(in) < 2 { return in } if string(in[1]) != ":" { return in } // since go 1.6 http client will only take this format. return "file://" + in } func getUserSID(username string) (string, error) { sid, _, _, e := syscall.LookupSID("", username) if e != nil { return "", e } sidStr, err := sid.String() return sidStr, err } func readRegString(h syscall.Handle, key string) (value string, err error) { var typ uint32 var buf uint32 // Get size of registry key err = syscall.RegQueryValueEx(h, syscall.StringToUTF16Ptr(key), nil, &typ, nil, &buf) if err != nil { return value, err } n := make([]uint16, buf/2+1) err = syscall.RegQueryValueEx(h, syscall.StringToUTF16Ptr(key), nil, &typ, (*byte)(unsafe.Pointer(&n[0])), &buf) if err != nil { return value, err } return syscall.UTF16ToString(n[:]), err } func homeFromRegistry(sid string) (string, error) { var h syscall.Handle // This key will exist on all platforms we support the agent on (windows server 2008 and above) keyPath := fmt.Sprintf("Software\\Microsoft\\Windows NT\\CurrentVersion\\ProfileList\\%s", sid) err := syscall.RegOpenKeyEx(syscall.HKEY_LOCAL_MACHINE, syscall.StringToUTF16Ptr(keyPath), 0, syscall.KEY_READ, &h) if err != nil { 
return "", err } defer syscall.RegCloseKey(h) str, err := readRegString(h, "ProfileImagePath") if err != nil { return "", err } return str, nil } // homeDir returns a local user home dir on Windows // user.Lookup() does not populate Gid and HomeDir on Windows, // so we get it from the registry func homeDir(user string) (string, error) { u, err := getUserSID(user) if err != nil { return "", errors.NewUserNotFound(err, "no such user") } return homeFromRegistry(u) } // ChownPath is not implemented for Windows. func ChownPath(path, username string) error { // This only exists to allow building on Windows. User lookup and // file ownership needs to be handled in a completely different // way and hasn't yet been implemented. return nil } // IsFileOwner is not implemented for Windows. func IsFileOwner(path, username string) (bool, error) { return true, nil } ================================================ FILE: file_windows_test.go ================================================ // Copyright 2013 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. 
//go:build windows // +build windows package utils_test import ( gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type windowsFileSuite struct { } var _ = gc.Suite(&windowsFileSuite{}) func (s *windowsFileSuite) TestMakeFileURL(c *gc.C) { var makeFileURLTests = []struct { in string expected string }{{ in: "file://C:\\foo\\baz", expected: "file://C:/foo/baz", }, { in: "C:\\foo\\baz", expected: "file://C:/foo/baz", }, { in: "http://foo/baz", expected: "http://foo/baz", }, { in: "file://C:/foo/baz", expected: "file://C:/foo/baz", }} for i, t := range makeFileURLTests { c.Logf("Test %d", i) c.Assert(utils.MakeFileURL(t.in), gc.Equals, t.expected) } } func (s *windowsFileSuite) TestEnsureBaseDir(c *gc.C) { c.Assert(utils.EnsureBaseDir(`C:\r`, `C:\a\b`), gc.Equals, `C:\r\a\b`) c.Assert(utils.EnsureBaseDir(`C:\r`, `D:\a\b`), gc.Equals, `C:\r\a\b`) c.Assert(utils.EnsureBaseDir(`C:`, `D:\a\b`), gc.Equals, `C:\a\b`) c.Assert(utils.EnsureBaseDir(`C:`, `\a\b`), gc.Equals, `C:\a\b`) c.Assert(utils.EnsureBaseDir(``, `C:\a\b`), gc.Equals, `C:\a\b`) } func (s *windowsFileSuite) TestFileOwner(c *gc.C) { c.Assert(utils.IsFileOwner("file://C:\\foo\\baz", "timmy"), gc.Equals, true) } ================================================ FILE: filepath/common.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath func splitSuffix(path string) (string, string) { for i := len(path) - 1; i >= 0; i-- { if path[i] == '.' && i > 0 { return path[:i], path[i:] } } return path, "" } ================================================ FILE: filepath/common_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package filepath_test import ( "github.com/juju/testing" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filepath" ) var _ = gc.Suite(&commonSuite{}) type commonSuite struct { testing.IsolationSuite } func (s commonSuite) TestSplitSuffixHasSuffix(c *gc.C) { path, suffix := filepath.SplitSuffix("spam.ext") c.Check(path, gc.Equals, "spam") c.Check(suffix, gc.Equals, ".ext") } func (s commonSuite) TestSplitSuffixNoSuffix(c *gc.C) { path, suffix := filepath.SplitSuffix("spam") c.Check(path, gc.Equals, "spam") c.Check(suffix, gc.Equals, "") } func (s commonSuite) TestSplitSuffixEmpty(c *gc.C) { path, suffix := filepath.SplitSuffix("") c.Check(path, gc.Equals, "") c.Check(suffix, gc.Equals, "") } func (s commonSuite) TestSplitSuffixDotFilePlain(c *gc.C) { path, suffix := filepath.SplitSuffix(".spam") c.Check(path, gc.Equals, ".spam") c.Check(suffix, gc.Equals, "") } func (s commonSuite) TestSplitSuffixDofileWithSuffix(c *gc.C) { path, suffix := filepath.SplitSuffix(".spam.ext") c.Check(path, gc.Equals, ".spam") c.Check(suffix, gc.Equals, ".ext") } ================================================ FILE: filepath/export_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath var ( SplitSuffix = splitSuffix ) ================================================ FILE: filepath/filepath.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath import ( "runtime" "strings" "github.com/juju/errors" "github.com/juju/utils/v4" ) // Renderer provides methods for the different functions in // the stdlib path/filepath package that don't relate to a concrete // filesystem. So Abs, EvalSymlinks, Glob, Rel, and Walk are not // included. Also, while the functions in path/filepath relate to the // current host, the PathRenderer methods relate to the renderer's // target platform. 
So for example, a windows-oriented implementation // will give windows-specific results even when used on linux. type Renderer interface { // Base mimics path/filepath. Base(path string) string // Clean mimics path/filepath. Clean(path string) string // Dir mimics path/filepath. Dir(path string) string // Ext mimics path/filepath. Ext(path string) string // FromSlash mimics path/filepath. FromSlash(path string) string // IsAbs mimics path/filepath. IsAbs(path string) bool // Join mimics path/filepath. Join(path ...string) string // Match mimics path/filepath. Match(pattern, name string) (matched bool, err error) // NormCase normalizes the case of a pathname. On Unix and Mac OS X, // this returns the path unchanged; on case-insensitive filesystems, // it converts the path to lowercase. NormCase(path string) string // Split mimics path/filepath. Split(path string) (dir, file string) // SplitList mimics path/filepath. SplitList(path string) []string // SplitSuffix splits the pathname into a pair (root, suffix) such // that root + suffix == path, and ext is empty or begins with a // period and contains at most one period. Leading periods on the // basename are ignored; SplitSuffix('.cshrc') returns ('.cshrc', ''). SplitSuffix(path string) (string, string) // ToSlash mimics path/filepath. ToSlash(path string) string // VolumeName mimics path/filepath. VolumeName(path string) string } // NewRenderer returns a Renderer for the given os. func NewRenderer(os string) (Renderer, error) { if os == "" { os = runtime.GOOS } os = strings.ToLower(os) switch { case os == utils.OSWindows: return &WindowsRenderer{}, nil case utils.OSIsUnix(os): return &UnixRenderer{}, nil case os == "ubuntu": return &UnixRenderer{}, nil default: return nil, errors.NotFoundf("renderer for %q", os) } } ================================================ FILE: filepath/filepath_test.go ================================================ // Copyright 2015 Canonical Ltd. 
// Licensed under the LGPLv3, see LICENCE file for details. package filepath_test import ( "runtime" "github.com/juju/errors" "github.com/juju/testing" jc "github.com/juju/testing/checkers" "github.com/juju/utils/v4" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filepath" ) type filepathSuite struct { testing.IsolationSuite unix *filepath.UnixRenderer windows *filepath.WindowsRenderer } var _ = gc.Suite(&filepathSuite{}) func (s *filepathSuite) SetupTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.unix = &filepath.UnixRenderer{} s.windows = &filepath.WindowsRenderer{} } func (s filepathSuite) checkRenderer(c *gc.C, renderer filepath.Renderer, expected string) { switch expected { case "windows": c.Check(renderer, gc.FitsTypeOf, s.windows) case "unix": c.Check(renderer, gc.FitsTypeOf, s.unix) default: c.Errorf("unknown kind %q", expected) } } func (s filepathSuite) TestNewRendererDefault(c *gc.C) { // All possible values of runtime.GOOS should be supported. renderer, err := filepath.NewRenderer("") c.Assert(err, jc.ErrorIsNil) switch runtime.GOOS { case "windows": s.checkRenderer(c, renderer, "windows") default: s.checkRenderer(c, renderer, "unix") } } func (s filepathSuite) TestNewRendererGOOS(c *gc.C) { // All possible values of runtime.GOOS should be supported. 
renderer, err := filepath.NewRenderer(runtime.GOOS) c.Assert(err, jc.ErrorIsNil) switch runtime.GOOS { case "windows": s.checkRenderer(c, renderer, "windows") default: s.checkRenderer(c, renderer, "unix") } } func (s filepathSuite) TestNewRendererWindows(c *gc.C) { renderer, err := filepath.NewRenderer("windows") c.Assert(err, jc.ErrorIsNil) s.checkRenderer(c, renderer, "windows") } func (s filepathSuite) TestNewRendererUnix(c *gc.C) { for _, os := range utils.OSUnix { c.Logf("trying %q", os) renderer, err := filepath.NewRenderer(os) c.Assert(err, jc.ErrorIsNil) s.checkRenderer(c, renderer, "unix") } } func (s filepathSuite) TestNewRendererDistros(c *gc.C) { distros := []string{"ubuntu"} for _, distro := range distros { c.Logf("trying %q", distro) renderer, err := filepath.NewRenderer(distro) c.Assert(err, jc.ErrorIsNil) s.checkRenderer(c, renderer, "unix") } } func (s filepathSuite) TestNewRendererUnknown(c *gc.C) { _, err := filepath.NewRenderer("") c.Check(err, jc.Satisfies, errors.IsNotFound) } ================================================ FILE: filepath/interface_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath var _ Renderer = (*UnixRenderer)(nil) var _ Renderer = (*WindowsRenderer)(nil) ================================================ FILE: filepath/package_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath_test import ( "testing" gc "gopkg.in/check.v1" ) func Test(t *testing.T) { gc.TestingT(t) } ================================================ FILE: filepath/stdlib.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // Copyright 2009 The Go Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.golang file.

package filepath

import (
	"strings"
)

// The following functions are adapted from the GO stdlib source.
// Unlike their stdlib counterparts, each takes the path separator (sep)
// and, where needed, a volumeName function as explicit arguments instead
// of using the host's values.

// Base mimics path/filepath for the given path separator.
// It strips trailing separators and the volume name reported by
// volumeName, then returns the last element. An empty path yields ".",
// and a path consisting only of separators yields string(sep).
func Base(sep uint8, volumeName func(string) string, path string) string {
	if path == "" {
		return "."
	}
	// Strip trailing slashes.
	for len(path) > 0 && path[len(path)-1] == sep {
		path = path[0 : len(path)-1]
	}
	// Throw away volume name
	path = path[len(volumeName(path)):]
	// Find the last element
	i := len(path) - 1
	for i >= 0 && path[i] != sep {
		i--
	}
	if i >= 0 {
		path = path[i+1:]
	}
	// If empty now, it had only slashes.
	if path == "" {
		return string(sep)
	}
	return path
}

// A lazybuf is a lazily constructed path buffer.
// It supports append, reading previously appended bytes,
// and retrieving the final string. It does not allocate a buffer
// to hold the output until that output diverges from path.
type lazybuf struct {
	path       string // input being rewritten (Clean passes it with the volume name removed)
	buf        []byte // nil until the output first differs from path
	w          int    // number of output bytes so far; next write index
	volAndPath string // original input, including any volume name
	volLen     int    // length of the volume-name prefix of volAndPath
}

// index returns output byte i, reading from buf once it exists and from
// path otherwise.
func (b *lazybuf) index(i int) byte {
	if b.buf != nil {
		return b.buf[i]
	}
	return b.path[i]
}

// append writes c at position w. While no buffer exists and c equals the
// byte already at path[w], only w advances; otherwise buf is allocated
// (seeded with the first w bytes of path) and c is stored in it.
func (b *lazybuf) append(c byte) {
	if b.buf == nil {
		if b.w < len(b.path) && b.path[b.w] == c {
			b.w++
			return
		}
		b.buf = make([]byte, len(b.path))
		copy(b.buf, b.path[:b.w])
	}
	b.buf[b.w] = c
	b.w++
}

// string returns the volume name followed by the first w output bytes.
func (b *lazybuf) string() string {
	if b.buf == nil {
		return b.volAndPath[:b.volLen+b.w]
	}
	return b.volAndPath[:b.volLen] + string(b.buf[:b.w])
}

// Clean mimics path/filepath for the given path separator.
// Working lexically (no filesystem access), it collapses repeated
// separators, drops "." elements and resolves ".." against the preceding
// element where possible, then returns the result via FromSlash. An input
// that is empty after removing the volume name returns the original input
// plus "." (so "" becomes "."), except when the volume is longer than one
// byte and its second byte is not ':' (the UNC case), in which case the
// input is returned through FromSlash.
func Clean(sep uint8, volumeName func(string) string, path string) string {
	originalPath := path
	volLen := len(volumeName(path))
	path = path[volLen:]
	if path == "" {
		if volLen > 1 && originalPath[1] != ':' {
			// should be UNC
			return FromSlash(sep, originalPath)
		}
		return originalPath + "."
	}
	rooted := (path[0] == sep)

	// Invariants:
	//	reading from path; r is index of next byte to process.
	//	writing to buf; w is index of next byte to write.
	//	dotdot is index in buf where .. must stop, either because
	//		it is the leading slash or it is a leading ../../.. prefix.
	n := len(path)
	out := lazybuf{path: path, volAndPath: originalPath, volLen: volLen}
	r, dotdot := 0, 0
	if rooted {
		out.append(sep)
		r, dotdot = 1, 1
	}

	for r < n {
		switch {
		case path[r] == sep:
			// empty path element
			r++
		case path[r] == '.' && (r+1 == n || path[r+1] == sep):
			// . element
			r++
		case path[r] == '.' && path[r+1] == '.' && (r+2 == n || path[r+2] == sep):
			// .. element: remove to last separator
			// (path[r+1] is in range here: the previous case already
			// handled r+1 == n.)
			r += 2
			switch {
			case out.w > dotdot:
				// can backtrack
				out.w--
				for out.w > dotdot && out.index(out.w) != sep {
					out.w--
				}
			case !rooted:
				// cannot backtrack, but not rooted, so append .. element.
				if out.w > 0 {
					out.append(sep)
				}
				out.append('.')
				out.append('.')
				dotdot = out.w
			}
		default:
			// real path element.
			// add slash if needed
			if rooted && out.w != 1 || !rooted && out.w != 0 {
				out.append(sep)
			}
			// copy element
			for ; r < n && path[r] != sep; r++ {
				out.append(path[r])
			}
		}
	}

	// Turn empty string into "."
	if out.w == 0 {
		out.append('.')
	}

	return FromSlash(sep, out.string())
}

// Dir mimics path/filepath for the given path separator.
// It returns the volume name followed by the Clean'ed portion of path up
// to and including its last separator (after the volume name).
func Dir(sep uint8, volumeName func(string) string, path string) string {
	vol := volumeName(path)
	i := len(path) - 1
	for i >= len(vol) && path[i] != sep {
		i--
	}
	dir := Clean(sep, volumeName, path[len(vol):i+1])
	return vol + dir
}

// Ext mimics path/filepath for the given path separator.
// It returns the suffix starting at the last '.' of the final path element,
// or "" if that element contains no '.'.
func Ext(sep uint8, path string) string {
	for i := len(path) - 1; i >= 0 && path[i] != sep; i-- {
		if path[i] == '.' {
			return path[i:]
		}
	}
	return ""
}

// FromSlash mimics path/filepath for the given path separator.
// Every '/' in path is replaced by sep; path is returned as-is when sep is
// already '/'.
func FromSlash(sep uint8, path string) string {
	if sep == '/' {
		return path
	}
	return strings.Replace(path, "/", string(sep), -1)
}

// Join mimics path/filepath for the given path separator.
func Join(sep uint8, volumeName func(string) string, elem ...string) string { for i, e := range elem { if e != "" { return Clean(sep, volumeName, strings.Join(elem[i:], string(sep))) } } return "" } // Split mimics path/filepath for the given path separator. func Split(sep uint8, volumeName func(string) string, path string) (dir, file string) { vol := volumeName(path) i := len(path) - 1 for i >= len(vol) && path[i] != sep { i-- } return path[:i+1], path[i+1:] } // ToSlash mimics path/filepath for the given path separator. func ToSlash(sep uint8, path string) string { if sep == '/' { return path } return strings.Replace(path, string(sep), "/", -1) } ================================================ FILE: filepath/stdlib_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath_test import ( gofilepath "path/filepath" "runtime" "strings" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filepath" ) // The tests here are mostly just sanity checks against the behavior // of the stdlib path/filepath. We are not trying for high coverage levels. type stdlibSuite struct { testing.IsolationSuite path string volumeName func(string) string } var _ = gc.Suite(&stdlibSuite{}) func (s *stdlibSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) switch runtime.GOOS { case "windows": s.path = `C:\a\b\c.xyz` s.volumeName = func(path string) string { return "C:" } default: s.path = "/a/b/c.xyz" s.volumeName = func(string) string { return "" } } } func (s stdlibSuite) TestBase(c *gc.C) { path := filepath.Base(gofilepath.Separator, s.volumeName, s.path) gopath := gofilepath.Base(s.path) c.Check(path, gc.Equals, gopath) c.Check(path, gc.Equals, "c.xyz") } func (s stdlibSuite) TestClean(c *gc.C) { // TODO(ericsnow) Add more cases. 
originals := map[string]string{ s.path: s.path, } for original, expected := range originals { c.Logf("checking %q", original) path := filepath.Clean(gofilepath.Separator, s.volumeName, original) gopath := gofilepath.Clean(original) c.Check(path, gc.Equals, gopath) c.Check(path, gc.Equals, expected) } } func (s stdlibSuite) TestDir(c *gc.C) { path := filepath.Dir(gofilepath.Separator, s.volumeName, s.path) gopath := gofilepath.Dir(s.path) c.Check(path, gc.Equals, gopath) switch runtime.GOOS { case "windows": c.Check(path, gc.Equals, `\a\b`) default: c.Check(path, gc.Equals, "/a/b") } } func (s stdlibSuite) TestExt(c *gc.C) { ext := filepath.Ext(gofilepath.Separator, s.path) goext := gofilepath.Ext(s.path) c.Check(ext, gc.Equals, goext) c.Check(ext, gc.Equals, ".xyz") } func (s stdlibSuite) TestFromSlash(c *gc.C) { original := "/a/b/c.xyz" path := filepath.FromSlash(gofilepath.Separator, original) gopath := gofilepath.FromSlash(original) c.Check(path, gc.Equals, gopath) c.Check(path, gc.Equals, s.path) } func (s stdlibSuite) TestJoin(c *gc.C) { path := filepath.Join(gofilepath.Separator, s.volumeName, "a", "b", "c.xyz") gopath := gofilepath.Join("a", "b", "c.xyz") c.Check(path, gc.Equals, gopath) expected := s.path[strings.Index(s.path, string(gofilepath.Separator))+1:] c.Check(path, gc.Equals, expected) } func (s stdlibSuite) TestSplit(c *gc.C) { dir, base := filepath.Split(gofilepath.Separator, s.volumeName, s.path) godir, gobase := gofilepath.Split(s.path) c.Check(dir, gc.Equals, godir) c.Check(base, gc.Equals, gobase) switch runtime.GOOS { case "windows": c.Check(dir, gc.Equals, `\a\b\`) default: c.Check(dir, gc.Equals, "/a/b/") } c.Check(base, gc.Equals, "c.xyz") } func (s stdlibSuite) TestToSlash(c *gc.C) { path := filepath.ToSlash(gofilepath.Separator, s.path) gopath := gofilepath.ToSlash(s.path) c.Check(path, gc.Equals, gopath) c.Check(path, gc.Equals, "/a/b/c.xyz") } func (s stdlibSuite) TestMatchTrue(c *gc.C) { tests := map[string]string{ "abc": "abc", 
"ab[c]": "abc", "": "", "*": "abc", "a*c": "abc", "?": "a", "a?c": "abc", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) matched, err := filepath.Match(gofilepath.Separator, pattern, name) c.Assert(err, jc.ErrorIsNil) gomatched, err := gofilepath.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, gc.Equals, gomatched) c.Check(matched, jc.IsTrue) } } func (s stdlibSuite) TestMatchFalse(c *gc.C) { tests := map[string]string{ "abc": "xyz", "": "abc", "a*c": "a", "?": "", "a?c": "ac", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) matched, err := filepath.Match(gofilepath.Separator, pattern, name) c.Assert(err, jc.ErrorIsNil) gomatched, err := gofilepath.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, gc.Equals, gomatched) c.Check(matched, jc.IsFalse) } } func (s stdlibSuite) TestMatchBadPattern(c *gc.C) { tests := map[string]string{ "ab[": "abc", "ab[-c]": "abc", "ab[]": "abc", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) _, err := filepath.Match(gofilepath.Separator, pattern, name) _, goerr := gofilepath.Match(pattern, name) c.Check(err, gc.Equals, goerr) c.Check(err, gc.Equals, gofilepath.ErrBadPattern) } } ================================================ FILE: filepath/stdlibmatch.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE.golang file. package filepath import ( "path/filepath" "strings" "unicode/utf8" ) // The following functions are adapted from the GO stdlib source. // Match returns true if name matches the shell file name pattern. 
// The pattern syntax is: // // pattern: // { term } // term: // '*' matches any sequence of non-Separator characters // '?' matches any single non-Separator character // '[' [ '^' ] { character-range } ']' // character class (must be non-empty) // c matches character c (c != '*', '?', '\\', '[') // '\\' c matches character c // // character-range: // c matches character c (c != '\\', '-', ']') // '\\' c matches character c // lo '-' hi matches character c for lo <= c <= hi // // Match requires pattern to match all of name, not just a substring. // The only possible returned error is ErrBadPattern, when pattern // is malformed. // // On Windows, escaping is disabled. Instead, '\\' is treated as // path separator. // func Match(sep uint8, pattern, name string) (matched bool, err error) { Pattern: for len(pattern) > 0 { var star bool var chunk string star, chunk, pattern = scanChunk(sep, pattern) if star && chunk == "" { // Trailing * matches rest of string unless it has a /. return strings.Index(name, string(sep)) < 0, nil } // Look for match at current position. t, ok, err := matchChunk(sep, chunk, name) // if we're the last chunk, make sure we've exhausted the name // otherwise we'll give a false result even if we could still match // using the star if ok && (len(t) == 0 || len(pattern) > 0) { name = t continue } if err != nil { return false, err } if star { // Look for match skipping i+1 bytes. // Cannot skip /. for i := 0; i < len(name) && name[i] != sep; i++ { t, ok, err := matchChunk(sep, chunk, name[i+1:]) if ok { // if we're the last chunk, make sure we exhausted the name if len(pattern) == 0 && len(t) > 0 { continue } name = t continue Pattern } if err != nil { return false, err } } } return false, nil } return len(name) == 0, nil } // scanChunk gets the next segment of pattern, which is a non-star string // possibly preceded by a star. 
func scanChunk(sep uint8, pattern string) (star bool, chunk, rest string) { for len(pattern) > 0 && pattern[0] == '*' { pattern = pattern[1:] star = true } inrange := false var i int Scan: for i = 0; i < len(pattern); i++ { switch pattern[i] { case '\\': if sep == '\\' { // error check handled in matchChunk: bad pattern. if i+1 < len(pattern) { i++ } } case '[': inrange = true case ']': inrange = false case '*': if !inrange { break Scan } } } return star, pattern[0:i], pattern[i:] } // matchChunk checks whether chunk matches the beginning of s. // If so, it returns the remainder of s (after the match). // Chunk is all single-character operators: literals, char classes, and ?. func matchChunk(sep uint8, chunk, s string) (rest string, ok bool, err error) { for len(chunk) > 0 { if len(s) == 0 { return } switch chunk[0] { case '[': // character class r, n := utf8.DecodeRuneInString(s) s = s[n:] chunk = chunk[1:] // We can't end right after '[', we're expecting at least // a closing bracket and possibly a caret. if len(chunk) == 0 { err = filepath.ErrBadPattern return } // possibly negated negated := chunk[0] == '^' if negated { chunk = chunk[1:] } // parse all ranges match := false nrange := 0 for { if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 { chunk = chunk[1:] break } var lo, hi rune if lo, chunk, err = getEsc(sep, chunk); err != nil { return } hi = lo if chunk[0] == '-' { if hi, chunk, err = getEsc(sep, chunk[1:]); err != nil { return } } if lo <= r && r <= hi { match = true } nrange++ } if match == negated { return } case '?': if s[0] == sep { return } _, n := utf8.DecodeRuneInString(s) s = s[n:] chunk = chunk[1:] case '\\': if sep != '\\' { chunk = chunk[1:] if len(chunk) == 0 { err = filepath.ErrBadPattern return } } fallthrough default: if chunk[0] != s[0] { return } s = s[1:] chunk = chunk[1:] } } return s, true, nil } // getEsc gets a possibly-escaped character from chunk, for a character class. 
func getEsc(sep uint8, chunk string) (r rune, nchunk string, err error) { if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' { err = filepath.ErrBadPattern return } if chunk[0] == '\\' && sep != '\\' { chunk = chunk[1:] if len(chunk) == 0 { err = filepath.ErrBadPattern return } } r, n := utf8.DecodeRuneInString(chunk) if r == utf8.RuneError && n == 1 { err = filepath.ErrBadPattern } nchunk = chunk[n:] if len(nchunk) == 0 { err = filepath.ErrBadPattern } return } ================================================ FILE: filepath/unix.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath import ( "strings" ) // A substantial portion of this code comes from the Go stdlib code. const ( UnixSeparator = '/' // OS-specific path separator UnixListSeparator = ':' // OS-specific path list separator ) // UnixRenderer is a Renderer implementation for most flavors of Unix. type UnixRenderer struct{} // Base implements Renderer. func (ur UnixRenderer) Base(path string) string { return Base(UnixSeparator, ur.VolumeName, path) } // Clean implements Renderer. func (ur UnixRenderer) Clean(path string) string { return Clean(UnixSeparator, ur.VolumeName, path) } // Dir implements Renderer. func (ur UnixRenderer) Dir(path string) string { return Dir(UnixSeparator, ur.VolumeName, path) } // Ext implements Renderer. func (UnixRenderer) Ext(path string) string { return Ext(UnixSeparator, path) } // FromSlash implements Renderer. func (UnixRenderer) FromSlash(path string) string { return FromSlash(UnixSeparator, path) } // IsAbs implements Renderer. func (UnixRenderer) IsAbs(path string) bool { return strings.HasPrefix(path, string(UnixSeparator)) } // Join implements Renderer. func (ur UnixRenderer) Join(path ...string) string { return Join(UnixSeparator, ur.VolumeName, path...) } // Match implements Renderer. 
func (UnixRenderer) Match(pattern, name string) (matched bool, err error) { return Match(UnixSeparator, pattern, name) } // Split implements Renderer. func (ur UnixRenderer) Split(path string) (dir, file string) { return Split(UnixSeparator, ur.VolumeName, path) } // SplitList implements Renderer. func (UnixRenderer) SplitList(path string) []string { if path == "" { return []string{} } return strings.Split(path, string(UnixListSeparator)) } // ToSlash implements Renderer. func (UnixRenderer) ToSlash(path string) string { return ToSlash(UnixSeparator, path) } // VolumeName implements Renderer. func (UnixRenderer) VolumeName(path string) string { return "" } // NormCase implements Renderer. func (UnixRenderer) NormCase(path string) string { return path } // SplitSuffix implements Renderer. func (UnixRenderer) SplitSuffix(path string) (string, string) { return splitSuffix(path) } ================================================ FILE: filepath/unix_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package filepath_test import ( gofilepath "path/filepath" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filepath" ) var _ = gc.Suite(&unixSuite{}) var _ = gc.Suite(&unixThinWrapperSuite{}) type unixBaseSuite struct { testing.IsolationSuite path string renderer *filepath.UnixRenderer } func (s *unixBaseSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.path = "/a/b/c.xyz" s.renderer = &filepath.UnixRenderer{} } func (s *unixBaseSuite) matchesRuntime() bool { return gofilepath.Separator == filepath.UnixSeparator } type unixSuite struct { unixBaseSuite } func (s unixSuite) TestIsAbs(c *gc.C) { isAbs := s.renderer.IsAbs(s.path) c.Check(isAbs, jc.IsTrue) if s.matchesRuntime() { c.Check(isAbs, gc.Equals, gofilepath.IsAbs(s.path)) } } func (s unixSuite) TestSplitList(c *gc.C) { list := s.renderer.SplitList("/a:b:/c/d") c.Check(list, jc.DeepEquals, []string{"/a", "b", "/c/d"}) if s.matchesRuntime() { golist := gofilepath.SplitList("/a:b:/c/d") c.Check(list, jc.DeepEquals, golist) } } func (s unixSuite) TestVolumeName(c *gc.C) { volumeName := s.renderer.VolumeName(s.path) c.Check(volumeName, gc.Equals, "") } func (s unixSuite) TestNormCaseLower(c *gc.C) { normalized := s.renderer.NormCase("spam") c.Check(normalized, gc.Equals, "spam") } func (s unixSuite) TestNormCaseUpper(c *gc.C) { normalized := s.renderer.NormCase("SPAM") c.Check(normalized, gc.Equals, "SPAM") } func (s unixSuite) TestNormCaseMixed(c *gc.C) { normalized := s.renderer.NormCase("sPaM") c.Check(normalized, gc.Equals, "sPaM") } func (s unixSuite) TestNormCaseCapitalized(c *gc.C) { normalized := s.renderer.NormCase("Spam") c.Check(normalized, gc.Equals, "Spam") } func (s unixSuite) TestNormCasePunctuation(c *gc.C) { normalized := s.renderer.NormCase("spam-eggs.ext") c.Check(normalized, gc.Equals, "spam-eggs.ext") } func (s unixSuite) TestSplitSuffix(c *gc.C) { // This is just a sanity check. 
The splitSuffix tests are more // comprehensive. path, suffix := s.renderer.SplitSuffix("spam.ext") c.Check(path, gc.Equals, "spam") c.Check(suffix, gc.Equals, ".ext") } // unixThinWrapperSuite contains test methods for UnixRenderer methods // that are just thin wrappers around the corresponding helpers in the // filepath package. As such the test coverage is minimal (more of a // sanity check). type unixThinWrapperSuite struct { unixBaseSuite } func (s unixThinWrapperSuite) TestBase(c *gc.C) { path := s.renderer.Base(s.path) c.Check(path, gc.Equals, "c.xyz") if s.matchesRuntime() { gopath := gofilepath.Base(s.path) c.Check(path, gc.Equals, gopath) } } func (s unixThinWrapperSuite) TestClean(c *gc.C) { // TODO(ericsnow) Add more cases. originals := map[string]string{ s.path: s.path, } for original, expected := range originals { c.Logf("checking %q", original) path := s.renderer.Clean(original) c.Check(path, gc.Equals, expected) if s.matchesRuntime() { gopath := gofilepath.Clean(original) c.Check(path, gc.Equals, gopath) } } } func (s unixThinWrapperSuite) TestDir(c *gc.C) { path := s.renderer.Dir(s.path) c.Check(path, gc.Equals, "/a/b") if s.matchesRuntime() { gopath := gofilepath.Dir(s.path) c.Check(path, gc.Equals, gopath) } } func (s unixThinWrapperSuite) TestExt(c *gc.C) { ext := s.renderer.Ext(s.path) c.Check(ext, gc.Equals, ".xyz") if s.matchesRuntime() { goext := gofilepath.Ext(s.path) c.Check(ext, gc.Equals, goext) } } func (s unixThinWrapperSuite) TestFromSlash(c *gc.C) { original := "/a/b/c.xyz" path := s.renderer.FromSlash(original) c.Check(path, gc.Equals, s.path) if s.matchesRuntime() { gopath := gofilepath.FromSlash(original) c.Check(path, gc.Equals, gopath) } } func (s unixThinWrapperSuite) TestJoin(c *gc.C) { path := s.renderer.Join("a", "b", "c.xyz") c.Check(path, gc.Equals, s.path[1:]) if s.matchesRuntime() { gopath := gofilepath.Join("a", "b", "c.xyz") c.Check(path, gc.Equals, gopath) } } func (s unixThinWrapperSuite) TestSplit(c *gc.C) { dir, 
base := s.renderer.Split(s.path) c.Check(dir, gc.Equals, "/a/b/") c.Check(base, gc.Equals, "c.xyz") if s.matchesRuntime() { godir, gobase := gofilepath.Split(s.path) c.Check(dir, gc.Equals, godir) c.Check(base, gc.Equals, gobase) } } func (s unixThinWrapperSuite) TestToSlash(c *gc.C) { path := s.renderer.ToSlash(s.path) c.Check(path, gc.Equals, "/a/b/c.xyz") if s.matchesRuntime() { gopath := gofilepath.ToSlash(s.path) c.Check(path, gc.Equals, gopath) } } func (s unixThinWrapperSuite) TestMatchTrue(c *gc.C) { tests := map[string]string{ "abc": "abc", "ab[c]": "abc", "": "", "*": "abc", "a*c": "abc", "?": "a", "a?c": "abc", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) matched, err := s.renderer.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, jc.IsTrue) if s.matchesRuntime() { gomatched, err := gofilepath.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, gc.Equals, gomatched) } } } func (s unixThinWrapperSuite) TestMatchFalse(c *gc.C) { tests := map[string]string{ "abc": "xyz", "": "abc", "a*c": "a", "?": "", "a?c": "ac", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) matched, err := s.renderer.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, jc.IsFalse) if s.matchesRuntime() { gomatched, err := gofilepath.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, gc.Equals, gomatched) } } } func (s unixThinWrapperSuite) TestMatchBadPattern(c *gc.C) { tests := map[string]string{ "ab[": "abc", "ab[-c]": "abc", "ab[]": "abc", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) _, err := s.renderer.Match(pattern, name) c.Check(err, gc.Equals, gofilepath.ErrBadPattern) if s.matchesRuntime() { _, goerr := gofilepath.Match(pattern, name) c.Check(err, gc.Equals, goerr) } } } ================================================ FILE: filepath/win.go 
================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filepath import ( "strings" ) // A substantial portion of this code comes from the Go stdlib code. const ( WindowsSeparator = '\\' // OS-specific path separator WindowsListSeparator = ';' // OS-specific path list separator ) // WindowsRenderer is a Renderer implementation for Windows. type WindowsRenderer struct{} // Base implements Renderer. func (ur WindowsRenderer) Base(path string) string { return Base(WindowsSeparator, ur.VolumeName, path) } // Clean implements Renderer. func (ur WindowsRenderer) Clean(path string) string { return Clean(WindowsSeparator, ur.VolumeName, path) } // Dir implements Renderer. func (ur WindowsRenderer) Dir(path string) string { return Dir(WindowsSeparator, ur.VolumeName, path) } // Ext implements Renderer. func (WindowsRenderer) Ext(path string) string { return Ext(WindowsSeparator, path) } // FromSlash implements Renderer. func (WindowsRenderer) FromSlash(path string) string { return FromSlash(WindowsSeparator, path) } // IsAbs implements Renderer. func (WindowsRenderer) IsAbs(path string) bool { l := volumeNameLen(path) if l == 0 { return false } path = path[l:] if path == "" { return false } return isSlash(path[0]) } // Join implements Renderer. func (ur WindowsRenderer) Join(path ...string) string { return Join(WindowsSeparator, ur.VolumeName, path...) } // Match implements Renderer. func (WindowsRenderer) Match(pattern, name string) (matched bool, err error) { return Match(WindowsSeparator, pattern, name) } // Split implements Renderer. func (ur WindowsRenderer) Split(path string) (dir, file string) { return Split(WindowsSeparator, ur.VolumeName, path) } // SplitList implements Renderer. func (WindowsRenderer) SplitList(path string) []string { if path == "" { return []string{} } // Split path, respecting but preserving quotes. 
list := []string{} start := 0 quo := false for i := 0; i < len(path); i++ { switch c := path[i]; { case c == '"': quo = !quo case c == WindowsListSeparator && !quo: list = append(list, path[start:i]) start = i + 1 } } list = append(list, path[start:]) // Remove quotes. for i, s := range list { if strings.Contains(s, `"`) { list[i] = strings.Replace(s, `"`, ``, -1) } } return list } // ToSlash implements Renderer. func (WindowsRenderer) ToSlash(path string) string { return ToSlash(WindowsSeparator, path) } // VolumeName implements Renderer. func (WindowsRenderer) VolumeName(path string) string { return path[:volumeNameLen(path)] } // NormCase implements Renderer. func (WindowsRenderer) NormCase(path string) string { return strings.ToLower(path) } // SplitSuffix implements Renderer. func (WindowsRenderer) SplitSuffix(path string) (string, string) { return splitSuffix(path) } func isSlash(c uint8) bool { return c == WindowsSeparator || c == '/' } // volumeNameLen returns length of the leading volume name on Windows. // It returns 0 elsewhere. func volumeNameLen(path string) int { if len(path) < 2 { return 0 } // with drive letter c := path[0] if path[1] == ':' && ('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') { return 2 } // is it UNC if l := len(path); l >= 5 && isSlash(path[0]) && isSlash(path[1]) && !isSlash(path[2]) && path[2] != '.' { // first, leading `\\` and next shouldn't be `\`. its server name. for n := 3; n < l-1; n++ { // second, next '\' shouldn't be repeated. if isSlash(path[n]) { n++ // third, following something characters. its share name. if !isSlash(path[n]) { if path[n] == '.' { break } for ; n < l; n++ { if isSlash(path[n]) { break } } return n } break } } } return 0 } ================================================ FILE: filepath/win_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package filepath_test import ( gofilepath "path/filepath" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filepath" ) var _ = gc.Suite(&windowsSuite{}) var _ = gc.Suite(&windowsThinWrapperSuite{}) type windowsBaseSuite struct { testing.IsolationSuite path string renderer *filepath.WindowsRenderer } func (s *windowsBaseSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.path = `c:\a\b\c.xyz` s.renderer = &filepath.WindowsRenderer{} } func (s *windowsBaseSuite) matchesRuntime() bool { return gofilepath.Separator == filepath.WindowsSeparator } type windowsSuite struct { windowsBaseSuite } func (s windowsSuite) TestIsAbs(c *gc.C) { isAbs := s.renderer.IsAbs(s.path) c.Check(isAbs, jc.IsTrue) if s.matchesRuntime() { c.Check(isAbs, gc.Equals, gofilepath.IsAbs(s.path)) } } func (s windowsSuite) TestSplitList(c *gc.C) { list := s.renderer.SplitList(`\a;b;\c\d`) c.Check(list, jc.DeepEquals, []string{`\a`, "b", `\c\d`}) if s.matchesRuntime() { golist := gofilepath.SplitList(`\a;b;\c\d`) c.Check(list, jc.DeepEquals, golist) } } func (s windowsSuite) TestVolumeName(c *gc.C) { volumeName := s.renderer.VolumeName(s.path) c.Check(volumeName, gc.Equals, "c:") if s.matchesRuntime() { goresult := gofilepath.VolumeName(s.path) c.Check(volumeName, gc.Equals, goresult) } } func (s windowsSuite) TestNormCaseLower(c *gc.C) { normalized := s.renderer.NormCase("spam") c.Check(normalized, gc.Equals, "spam") } func (s windowsSuite) TestNormCaseUpper(c *gc.C) { normalized := s.renderer.NormCase("SPAM") c.Check(normalized, gc.Equals, "spam") } func (s windowsSuite) TestNormCaseMixed(c *gc.C) { normalized := s.renderer.NormCase("sPaM") c.Check(normalized, gc.Equals, "spam") } func (s windowsSuite) TestNormCaseCapitalized(c *gc.C) { normalized := s.renderer.NormCase("Spam") c.Check(normalized, gc.Equals, "spam") } func (s windowsSuite) TestNormCasePunctuation(c *gc.C) { normalized := s.renderer.NormCase("spam-eggs.ext") 
c.Check(normalized, gc.Equals, "spam-eggs.ext") } func (s windowsSuite) TestSplitSuffix(c *gc.C) { // This is just a sanity check. The splitSuffix tests are more // comprehensive. path, suffix := s.renderer.SplitSuffix("spam.ext") c.Check(path, gc.Equals, "spam") c.Check(suffix, gc.Equals, ".ext") } // windowsThinWrapperSuite contains test methods for WindowsRenderer methods // that are just thin wrappers around the corresponding helpers in the // filepath package. As such the test coverage is minimal (more of a // sanity check). type windowsThinWrapperSuite struct { windowsBaseSuite } func (s windowsThinWrapperSuite) TestBase(c *gc.C) { path := s.renderer.Base(s.path) c.Check(path, gc.Equals, "c.xyz") if s.matchesRuntime() { gopath := gofilepath.Base(s.path) c.Check(path, gc.Equals, gopath) } } func (s windowsThinWrapperSuite) TestClean(c *gc.C) { // TODO(ericsnow) Add more cases. originals := map[string]string{ s.path: s.path, } for original, expected := range originals { c.Logf("checking %q", original) path := s.renderer.Clean(original) c.Check(path, gc.Equals, expected) if s.matchesRuntime() { gopath := gofilepath.Clean(original) c.Check(path, gc.Equals, gopath) } } } func (s windowsThinWrapperSuite) TestDir(c *gc.C) { path := s.renderer.Dir(s.path) c.Check(path, gc.Equals, `c:\a\b`) if s.matchesRuntime() { gopath := gofilepath.Dir(s.path) c.Check(path, gc.Equals, gopath) } } func (s windowsThinWrapperSuite) TestExt(c *gc.C) { ext := s.renderer.Ext(s.path) c.Check(ext, gc.Equals, ".xyz") if s.matchesRuntime() { goext := gofilepath.Ext(s.path) c.Check(ext, gc.Equals, goext) } } func (s windowsThinWrapperSuite) TestFromSlash(c *gc.C) { original := "/a/b/c.xyz" path := s.renderer.FromSlash(original) c.Check(path, gc.Equals, s.path[2:]) if s.matchesRuntime() { gopath := gofilepath.FromSlash(original) c.Check(path, gc.Equals, gopath) } } func (s windowsThinWrapperSuite) TestJoin(c *gc.C) { path := s.renderer.Join("a", "b", "c.xyz") c.Check(path, gc.Equals, 
s.path[3:]) if s.matchesRuntime() { gopath := gofilepath.Join("a", "b", "c.xyz") c.Check(path, gc.Equals, gopath) } } func (s windowsThinWrapperSuite) TestSplit(c *gc.C) { dir, base := s.renderer.Split(s.path) c.Check(dir, gc.Equals, `c:\a\b\`) c.Check(base, gc.Equals, "c.xyz") if s.matchesRuntime() { godir, gobase := gofilepath.Split(s.path) c.Check(dir, gc.Equals, godir) c.Check(base, gc.Equals, gobase) } } func (s windowsThinWrapperSuite) TestToSlash(c *gc.C) { path := s.renderer.ToSlash(s.path) c.Check(path, gc.Equals, "c:/a/b/c.xyz") if s.matchesRuntime() { gopath := gofilepath.ToSlash(s.path) c.Check(path, gc.Equals, gopath) } } func (s windowsThinWrapperSuite) TestMatchTrue(c *gc.C) { tests := map[string]string{ "abc": "abc", "ab[c]": "abc", "": "", "*": "abc", "a*c": "abc", "?": "a", "a?c": "abc", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) matched, err := s.renderer.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, jc.IsTrue) if s.matchesRuntime() { gomatched, err := gofilepath.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, gc.Equals, gomatched) } } } func (s windowsThinWrapperSuite) TestMatchFalse(c *gc.C) { tests := map[string]string{ "abc": "xyz", "": "abc", "a*c": "a", "?": "", "a?c": "ac", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) matched, err := s.renderer.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, jc.IsFalse) if s.matchesRuntime() { gomatched, err := gofilepath.Match(pattern, name) c.Assert(err, jc.ErrorIsNil) c.Check(matched, gc.Equals, gomatched) } } } func (s windowsThinWrapperSuite) TestMatchBadPattern(c *gc.C) { tests := map[string]string{ "ab[": "abc", "ab[-c]": "abc", "ab[]": "abc", } for pattern, name := range tests { c.Logf("- checking pattern %q against %q -", pattern, name) _, err := s.renderer.Match(pattern, name) c.Check(err, gc.Equals, gofilepath.ErrBadPattern) if 
s.matchesRuntime() { _, goerr := gofilepath.Match(pattern, name) c.Check(err, gc.Equals, goerr) } } } ================================================ FILE: filestorage/doc.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. /* utils/filestorage provides types for abstracting and implementing a system that stores files, including their metadata. Each file in the system is identified by a unique ID, determined by the system at the time the file is stored. File metadata includes such information as the size of the file, its checksum, and when it was created. Regardless of how it is stored in the system, at the abstraction level it is represented as a document. Metadata can exist in the system without an associated file. However, every file must have a corresponding metadata doc stored in the system. A file can be added for a metadata doc that does not have one already. The main type is the FileStorage interface. It exposes the core functionality of such a system. This includes adding/removing files, retrieving them or their metadata, and listing all files in the system. The package also provides a basic implementation of FileStorage, available through NewFileStorage(). This implementation simply wraps two more focused systems: doc storage and raw file storage. The wrapper uses the doc storage to store the metadata and raw file storage to store the files. The two subsystems are exposed via corresponding interfaces: DocStorage (and its specialization MetadataStorage) and RawFileStorage. While a single type could implement both, in practice they will be separate. The doc storage is responsible to generating the unique IDs. The raw file storage defers to the doc storage for any information about the file, including the ID. 
*/ package filestorage ================================================ FILE: filestorage/export_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage ================================================ FILE: filestorage/fakes_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage_test import ( "io" "github.com/juju/errors" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filestorage" ) // FakeMetadataStorage is used as a DocStorage and MetadataStorage for // testing purposes. type FakeMetadataStorage struct { calls []string id string meta filestorage.Metadata metaList []filestorage.Metadata err error idArg string metaArg filestorage.Metadata } // Check verfies the state of the fake. func (s *FakeMetadataStorage) Check(c *gc.C, id string, meta filestorage.Metadata, calls ...string) { c.Check(s.calls, jc.DeepEquals, calls) c.Check(s.idArg, gc.Equals, id) c.Check(s.metaArg, gc.Equals, meta) } func (s *FakeMetadataStorage) Doc(id string) (filestorage.Document, error) { s.calls = append(s.calls, "Doc") s.idArg = id if s.err != nil { return nil, s.err } return s.meta, nil } func (s *FakeMetadataStorage) ListDocs() ([]filestorage.Document, error) { s.calls = append(s.calls, "ListDoc") if s.err != nil { return nil, s.err } var docs []filestorage.Document for _, doc := range s.metaList { docs = append(docs, doc) } return docs, nil } func (s *FakeMetadataStorage) AddDoc(doc filestorage.Document) (string, error) { s.calls = append(s.calls, "AddDoc") meta, err := filestorage.Convert(doc) if err != nil { return "", errors.Trace(err) } s.metaArg = meta return s.id, nil } func (s *FakeMetadataStorage) RemoveDoc(id string) error { s.calls = append(s.calls, "RemoveDoc") s.idArg = id return s.err } func (s 
*FakeMetadataStorage) Close() error { s.calls = append(s.calls, "Close") return s.err } func (s *FakeMetadataStorage) Metadata(id string) (filestorage.Metadata, error) { s.calls = append(s.calls, "Metadata") s.idArg = id if s.err != nil { return nil, s.err } return s.meta, nil } func (s *FakeMetadataStorage) ListMetadata() ([]filestorage.Metadata, error) { s.calls = append(s.calls, "ListMetadata") if s.err != nil { return nil, s.err } return s.metaList, nil } func (s *FakeMetadataStorage) AddMetadata(meta filestorage.Metadata) (string, error) { s.calls = append(s.calls, "AddMetadata") s.metaArg = meta if s.err != nil { return "", s.err } return s.id, nil } func (s *FakeMetadataStorage) RemoveMetadata(id string) error { s.calls = append(s.calls, "RemoveMetadata") s.idArg = id return s.err } func (s *FakeMetadataStorage) SetStored(id string) error { s.calls = append(s.calls, "SetStored") s.idArg = id return s.err } // FakeRawFileStorage is used in testing as a RawFileStorage. type FakeRawFileStorage struct { calls []string file io.ReadCloser err error idArg string fileArg io.Reader sizeArg int64 } // Check verfies the state of the fake. func (s *FakeRawFileStorage) Check(c *gc.C, id string, file io.Reader, size int64, calls ...string) { c.Check(s.calls, jc.DeepEquals, calls) c.Check(s.idArg, gc.Equals, id) c.Check(s.fileArg, gc.Equals, file) c.Check(s.sizeArg, gc.Equals, size) } // CheckNotUsed verifies that the fake was not used. 
func (s *FakeRawFileStorage) CheckNotUsed(c *gc.C) { s.Check(c, "", nil, 0) } func (s *FakeRawFileStorage) File(id string) (io.ReadCloser, error) { s.calls = append(s.calls, "File") s.idArg = id if s.err != nil { return nil, s.err } return s.file, nil } func (s *FakeRawFileStorage) AddFile(id string, file io.Reader, size int64) error { s.calls = append(s.calls, "AddFile") s.idArg = id s.fileArg = file s.sizeArg = size return s.err } func (s *FakeRawFileStorage) RemoveFile(id string) error { s.calls = append(s.calls, "RemoveFile") s.idArg = id return s.err } func (s *FakeRawFileStorage) Close() error { s.calls = append(s.calls, "Close") return s.err } ================================================ FILE: filestorage/interfaces.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage import ( "io" "time" ) // FileStorage is an abstraction that can be used for the storage of files. type FileStorage interface { io.Closer // Metadata returns a file's metadata. Metadata(id string) (Metadata, error) // Get returns a file and its metadata. Get(id string) (Metadata, io.ReadCloser, error) // List returns the metadata for each stored file. List() ([]Metadata, error) // Add stores a file and its metadata. Add(meta Metadata, archive io.Reader) (string, error) // SetFile stores a file for an existing metadata entry. SetFile(id string, file io.Reader) error // Remove removes a file from storage. Remove(id string) error } // Document represents a document that can be identified uniquely // by a string. type Document interface { // ID returns the unique ID of the document. ID() string // SetID sets the ID of the document. If the ID is already set, // SetID() should return true (false otherwise). SetID(id string) (alreadySet bool) } // Metadata is the meta information for a stored file. type Metadata interface { Document // Size is the size of the file (in bytes). 
Size() int64 // Checksum is the checksum for the file. Checksum() string // ChecksumFormat is the kind (and encoding) of checksum. ChecksumFormat() string // Stored returns when the file was last stored. If it has not been // stored yet, nil is returned. If it has been stored but the // timestamp is not available, a zero value is returned // (see Time.IsZero). Stored() *time.Time // SetFileInfo sets the file info on the metadata. SetFileInfo(size int64, checksum, checksumFormat string) error // SetStored records when the file was last stored. If the previous // value matters, be sure to call Stored() first. SetStored(timestamp *time.Time) } // DocStorage is an abstraction for a system that can store docs (structs). // The system is expected to generate its own unique ID for each doc. type DocStorage interface { io.Closer // Doc returns the doc that matches the ID. If there is no match, // an error is returned (see errors.IsNotFound). Any other problem // also results in an error. Doc(id string) (Document, error) // ListDocs returns a list of all the docs in the storage. ListDocs() ([]Document, error) // AddDoc adds the doc to the storage. If successful, the storage- // generated ID for the doc is returned. Otherwise an error is // returned. AddDoc(doc Document) (string, error) // RemoveDoc removes the matching doc from the storage. If there // is no match an error is returned (see errors.IsNotFound). Any // other problem also results in an error. RemoveDoc(id string) error } // RawFileStorage is an abstraction around a system that can store files. // The system is expected to rely on the user for unique IDs. type RawFileStorage interface { io.Closer // File returns the matching file. If there is no match an error is // returned (see errors.IsNotFound). Any other problem also results // in an error. File(id string) (io.ReadCloser, error) // AddFile adds the file to the storage. If it fails to do so, // it returns an error. 
If a file is already stored for the ID, // AddFile() fails (see errors.IsAlreadyExists). AddFile(id string, file io.Reader, size int64) error // RemoveFile removes the matching file from the storage. It fails // if there is no such file (see errors.IsNotFound). Any other problem // also results in an error. RemoveFile(id string) error } // MetadataStorage is an extension of DocStorage adapted to file metadata. type MetadataStorage interface { io.Closer // Metadata returns the matching Metadata. It fails if there is no // match (see errors.IsNotFound). Any other problems likewise // results in an error. Metadata(id string) (Metadata, error) // ListMetadata returns a list of all metadata in the storage. ListMetadata() ([]Metadata, error) // AddMetadata adds the metadata to the storage. If successful, the // storage-generated ID for the metadata is returned. Otherwise an // error is returned. AddMetadata(meta Metadata) (string, error) // RemoveMetadata removes the matching metadata from the storage. // If there is no match an error is returned (see errors.IsNotFound). // Any other problem also results in an error. RemoveMetadata(id string) error // SetStored updates the stored metadata to indicate that the // associated file has been successfully stored in a RawFileStorage // system. If it does not find a stored metadata with the matching // ID, it will return an error (see errors.IsNotFound). It also // returns an error if it fails to update the stored metadata. SetStored(id string) error } ================================================ FILE: filestorage/metadata.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage import ( "time" "github.com/juju/errors" ) // RawDoc is a basic, uniquely identifiable document. type RawDoc struct { // ID is the unique identifier for the document. ID string } // Doc wraps a document in the Document interface. 
type Doc struct { Raw RawDoc } // ID returns the document's unique identifier. func (d *Doc) ID() string { return d.Raw.ID } // SetID sets the document's unique identifier. If the ID is already // set, SetID() returns true (false otherwise). func (d *Doc) SetID(id string) bool { if d.Raw.ID != "" { return true } d.Raw.ID = id return false } // RawFileMetadata holds info specific to stored files. type RawFileMetadata struct { // Size is the size (in bytes) of the stored file. Size int64 // Checksum is the checksum of the stored file. Checksum string // ChecksumFormat describes the kind of the checksum. ChecksumFormat string // Stored records the timestamp of when the file was last stored. Stored *time.Time } // FileMetadata contains the metadata for a single stored file. type FileMetadata struct { Doc Raw RawFileMetadata } // NewMetadata returns a new Metadata for a stored file. func NewMetadata() *FileMetadata { meta := FileMetadata{} return &meta } func (m *FileMetadata) Size() int64 { return m.Raw.Size } func (m *FileMetadata) Checksum() string { return m.Raw.Checksum } func (m *FileMetadata) ChecksumFormat() string { return m.Raw.ChecksumFormat } func (m *FileMetadata) Stored() *time.Time { return m.Raw.Stored } func (m *FileMetadata) SetFileInfo(size int64, checksum, format string) error { // Fall back to existing values. if size == 0 { size = m.Raw.Size } if checksum == "" { checksum = m.Raw.Checksum } if format == "" { format = m.Raw.ChecksumFormat } if checksum != "" { if format == "" { return errors.Errorf("missing checksum format") } } else if format != "" { return errors.Errorf("missing checksum") } // Only allow setting once. 
if m.Raw.Size != 0 && size != m.Raw.Size { return errors.Errorf("file information (size) already set") } if m.Raw.Checksum != "" && checksum != m.Raw.Checksum { return errors.Errorf("file information (checksum) already set") } if m.Raw.ChecksumFormat != "" && format != m.Raw.ChecksumFormat { return errors.Errorf("file information (checksum format) already set") } // Set the values. m.Raw.Size = size m.Raw.Checksum = checksum m.Raw.ChecksumFormat = format return nil } func (m *FileMetadata) SetStored(timestamp *time.Time) { if timestamp == nil { now := time.Now().UTC() m.Raw.Stored = &now } else { m.Raw.Stored = timestamp } } ================================================ FILE: filestorage/metadata_store.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage import ( "github.com/juju/errors" ) // Convert turns a Document into a Metadata if possible. func Convert(doc Document) (Metadata, error) { meta, ok := doc.(Metadata) if !ok { return nil, errors.Errorf("expected a Metadata doc, got %v", doc) } return meta, nil } // MetadataDocStorage provides the MetadataStorage methods than can be // derived from DocStorage methods. To fully implement MetadataStorage, // this type must be embedded in a type that implements the remaining // methods. type MetadataDocStorage struct { DocStorage } // Metadata implements MetadataStorage.Metadata. func (s *MetadataDocStorage) Metadata(id string) (Metadata, error) { doc, err := s.Doc(id) if err != nil { return nil, errors.Trace(err) } meta, err := Convert(doc) return meta, errors.Trace(err) } // ListMetadata implements MetadataStorage.ListMetadata. 
func (s *MetadataDocStorage) ListMetadata() ([]Metadata, error) { docs, err := s.ListDocs() if err != nil { return nil, errors.Trace(err) } var metaList []Metadata for _, doc := range docs { if doc == nil { continue } meta, err := Convert(doc) if err != nil { return nil, errors.Trace(err) } metaList = append(metaList, meta) } return metaList, nil } // ListMetadata implements MetadataStorage.ListMetadata. func (s *MetadataDocStorage) AddMetadata(meta Metadata) (string, error) { id, err := s.AddDoc(meta) return id, errors.Trace(err) } // ListMetadata implements MetadataStorage.ListMetadata. func (s *MetadataDocStorage) RemoveMetadata(id string) error { return errors.Trace(s.RemoveDoc(id)) } ================================================ FILE: filestorage/metadata_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage_test import ( "time" "github.com/juju/testing" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filestorage" ) var ( _ filestorage.Document = (*filestorage.Doc)(nil) _ filestorage.Metadata = (*filestorage.FileMetadata)(nil) ) var _ = gc.Suite(&MetadataSuite{}) type MetadataSuite struct { testing.IsolationSuite } func (s *MetadataSuite) TestFileMetadataNewMetadata(c *gc.C) { meta := filestorage.NewMetadata() c.Check(meta.ID(), gc.Equals, "") c.Check(meta.Size(), gc.Equals, int64(0)) c.Check(meta.Checksum(), gc.Equals, "") c.Check(meta.ChecksumFormat(), gc.Equals, "") c.Check(meta.Stored(), gc.IsNil) } func (s *MetadataSuite) TestFileMetadataSetIDInitial(c *gc.C) { meta := filestorage.NewMetadata() meta.SetFileInfo(10, "some sum", "SHA-1") c.Assert(meta.ID(), gc.Equals, "") success := meta.SetID("some id") c.Check(success, gc.Equals, false) c.Check(meta.ID(), gc.Equals, "some id") } func (s *MetadataSuite) TestFileMetadataSetIDAlreadySetSame(c *gc.C) { meta := filestorage.NewMetadata() meta.SetFileInfo(10, "some sum", "SHA-1") success := 
meta.SetID("some id") c.Assert(success, gc.Equals, false) success = meta.SetID("some id") c.Check(success, gc.Equals, true) c.Check(meta.ID(), gc.Equals, "some id") } func (s *MetadataSuite) TestFileMetadataSetIDAlreadySetDifferent(c *gc.C) { meta := filestorage.NewMetadata() meta.SetFileInfo(10, "some sum", "SHA-1") success := meta.SetID("some id") c.Assert(success, gc.Equals, false) success = meta.SetID("another id") c.Check(success, gc.Equals, true) c.Check(meta.ID(), gc.Equals, "some id") } func (s *MetadataSuite) TestFileMetadataSetFileInfo(c *gc.C) { meta := filestorage.NewMetadata() c.Assert(meta.Size(), gc.Equals, int64(0)) c.Assert(meta.Checksum(), gc.Equals, "") c.Assert(meta.ChecksumFormat(), gc.Equals, "") c.Assert(meta.Stored(), gc.IsNil) meta.SetFileInfo(10, "some sum", "SHA-1") c.Check(meta.Size(), gc.Equals, int64(10)) c.Check(meta.Checksum(), gc.Equals, "some sum") c.Check(meta.ChecksumFormat(), gc.Equals, "SHA-1") c.Check(meta.Stored(), gc.IsNil) } func (s *MetadataSuite) TestFileMetadataSetStored(c *gc.C) { meta := filestorage.NewMetadata() timestamp := time.Now().UTC() meta.SetStored(×tamp) c.Check(meta.Stored(), gc.Equals, ×tamp) } func (s *MetadataSuite) TestFileMetadataSetStoredDefault(c *gc.C) { meta := filestorage.NewMetadata() c.Assert(meta.Stored(), gc.IsNil) meta.SetStored(nil) c.Check(meta.Stored(), gc.NotNil) } ================================================ FILE: filestorage/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package filestorage_test import ( "testing" gc "gopkg.in/check.v1" ) func TestPackage(t *testing.T) { gc.TestingT(t) } ================================================ FILE: filestorage/wrapper.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package filestorage import ( "io" "github.com/juju/errors" ) // Ensure fileStorage implements FileStorage. var _ = FileStorage((*fileStorage)(nil)) type fileStorage struct { metaStorage MetadataStorage rawStorage RawFileStorage } // NewFileStorage returns a new FileStorage value that wraps a // MetadataStorage and a RawFileStorage. It coordinates the two even // though they may not be designed to be compatible (or the two may be // the same value). // // A stored file will always have a metadata value stored. However, it // is not required to have a raw file stored. func NewFileStorage(meta MetadataStorage, files RawFileStorage) FileStorage { stor := fileStorage{ metaStorage: meta, rawStorage: files, } return &stor } // Metadata returns the matching metadata. Failure to find it (see // errors.IsNotFound) or any other problem results in an error. func (s *fileStorage) Metadata(id string) (Metadata, error) { meta, err := s.metaStorage.Metadata(id) if err != nil { return nil, errors.Trace(err) } return meta, nil } // Get returns the matching file and its associated metadata. If there // is no match (see errors.IsNotFound) or any other problem, it returns // an error. Both the metadata and file must have been stored for the // file to be considered found. func (s *fileStorage) Get(id string) (Metadata, io.ReadCloser, error) { meta, err := s.Metadata(id) if err != nil { return nil, nil, errors.Trace(err) } if meta.Stored() == nil { return nil, nil, errors.NotFoundf("no file stored for %q", id) } file, err := s.rawStorage.File(id) if err != nil { return nil, nil, errors.Trace(err) } return meta, file, nil } // List returns a list of the metadata for all files in the storage. 
func (s *fileStorage) List() ([]Metadata, error) { return s.metaStorage.ListMetadata() } func (s *fileStorage) addFile(id string, size int64, file io.Reader) error { err := s.rawStorage.AddFile(id, file, size) if err != nil { return errors.Trace(err) } err = s.metaStorage.SetStored(id) if err != nil { return errors.Trace(err) } return nil } // Add adds the file to the storage. It returns the unique ID generated // by the storage for the file. If no file is provided, only the // metadata is stored. While the passed-in "meta" is not modified, the // new ID and "stored" flag will be saved in metadata storage. Feel // free to explicitly call meta.SetID() and meta.SetStored() afterward. // // Any problem (including an existing file, see errors.IsAlreadyExists) // results in an error. If there is an error while storing either the // file or metadata, neither will be stored. func (s *fileStorage) Add(meta Metadata, file io.Reader) (string, error) { id, err := s.metaStorage.AddMetadata(meta) if err != nil { return "", errors.Trace(err) } if file != nil { err = s.addFile(id, meta.Size(), file) if err != nil { // Remove the metadata we just added. context := err err = s.metaStorage.RemoveMetadata(id) if err != nil { err = errors.Annotate(err, "while handling another error") return "", errors.Wrap(context, err) } return "", errors.Trace(context) } } return id, nil } // SetFile stores the raw file for an existing metadata. If there is no // matching stored metadata an error is returned (see errors.IsNotFound). // If a file has already been stored an error is returned (see // errors.IsAlreadyExists). Any other failure to add the file also // results in an error. func (s *fileStorage) SetFile(id string, file io.Reader) error { meta, err := s.Metadata(id) if err != nil { return errors.Trace(err) } err = s.addFile(id, meta.Size(), file) if err != nil { return errors.Trace(err) } return nil } // Remove removes both the metadata and raw file from the storage. 
If // there is no match an error is returned (see errors.IsNotFound). // // The raw file is removed first. Thus if there is any problem after // removing the raw file, the metadata will still be stored. However, // in that case the stored metadata is not guaranteed to accurately // represent that there is no corresponding raw file in storage. func (s *fileStorage) Remove(id string) error { err := s.rawStorage.RemoveFile(id) if err != nil && !errors.IsNotFound(err) { return errors.Trace(err) } err = s.metaStorage.RemoveMetadata(id) if err != nil { return errors.Trace(err) } return nil } // Close implements io.Closer.Close. func (s *fileStorage) Close() error { ferr := s.rawStorage.Close() merr := s.metaStorage.Close() if ferr == nil { return errors.Trace(merr) } else if merr == nil { return errors.Trace(ferr) } else { msg := "closing both failed: metadata (%v) and files (%v)" return errors.Errorf(msg, merr, ferr) } } ================================================ FILE: filestorage/wrapper_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package filestorage_test import ( "bytes" "io" "io/ioutil" "github.com/juju/errors" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/filestorage" ) var _ = gc.Suite(&WrapperSuite{}) type WrapperSuite struct { testing.IsolationSuite rawstor *FakeRawFileStorage metastor *FakeMetadataStorage stor filestorage.FileStorage } func (s *WrapperSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.rawstor = &FakeRawFileStorage{} s.metastor = &FakeMetadataStorage{} s.stor = filestorage.NewFileStorage(s.metastor, s.rawstor) } func (s *WrapperSuite) metadata() filestorage.Metadata { meta := filestorage.NewMetadata() meta.SetFileInfo(10, "", "") return meta } func (s *WrapperSuite) setMeta() (string, filestorage.Metadata) { id := "" meta := s.metadata() meta.SetID(id) s.metastor.meta = meta s.metastor.metaList = append(s.metastor.metaList, meta) return id, meta } func (s *WrapperSuite) setFile(data string) (string, filestorage.Metadata, io.ReadCloser) { id, meta := s.setMeta() file := ioutil.NopCloser(bytes.NewBufferString(data)) s.rawstor.file = file meta.SetStored(nil) return id, meta, file } func (s *WrapperSuite) TestFileStorageNewFileStorage(c *gc.C) { stor := filestorage.NewFileStorage(s.metastor, s.rawstor) c.Check(stor, gc.NotNil) } func (s *WrapperSuite) TestFileStorageMetadata(c *gc.C) { id, original := s.setMeta() meta, err := s.stor.Metadata(id) c.Assert(err, gc.IsNil) c.Check(meta, jc.DeepEquals, original) s.metastor.Check(c, id, nil, "Metadata") s.rawstor.CheckNotUsed(c) } func (s *WrapperSuite) TestFileStorageGet(c *gc.C) { id, origmeta, origfile := s.setFile("spam") meta, file, err := s.stor.Get(id) c.Assert(err, gc.IsNil) c.Check(meta, gc.Equals, origmeta) c.Check(file, gc.Equals, origfile) } func (s *WrapperSuite) TestFileStorageListEmpty(c *gc.C) { list, err := s.stor.List() c.Assert(err, gc.IsNil) c.Check(list, gc.HasLen, 0) } func (s *WrapperSuite) TestFileStorageListOne(c *gc.C) { id, _ 
:= s.setMeta() list, err := s.stor.List() c.Assert(err, gc.IsNil) c.Check(list, gc.HasLen, 1) c.Assert(list[0], gc.NotNil) c.Check(list[0].ID(), gc.Equals, id) } func (s *WrapperSuite) TestFileStorageListTwo(c *gc.C) { id1, _ := s.setMeta() id2, _ := s.setMeta() list, err := s.stor.List() c.Assert(err, gc.IsNil) c.Assert(list, gc.HasLen, 2) c.Assert(list[0], gc.NotNil) c.Assert(list[1], gc.NotNil) if list[0].ID() == id1 { c.Check(list[1].ID(), gc.Equals, id2) } else { c.Check(list[1].ID(), gc.Equals, id1) } } func (s *WrapperSuite) TestFileStorageAddMeta(c *gc.C) { s.metastor.id = "" meta := s.metadata() c.Assert(meta.ID(), gc.Equals, "") id, err := s.stor.Add(meta, nil) c.Assert(err, gc.IsNil) c.Check(id, gc.Equals, "") c.Check(meta.ID(), gc.Equals, "") s.metastor.Check(c, "", meta, "AddMetadata") s.rawstor.CheckNotUsed(c) } func (s *WrapperSuite) TestFileStorageAddFile(c *gc.C) { s.metastor.id = "" var file *bytes.Buffer meta := s.metadata() id, err := s.stor.Add(meta, file) c.Assert(err, gc.IsNil) c.Check(meta.ID(), gc.Equals, "") c.Check(meta.Stored(), gc.IsNil) c.Check(id, gc.Equals, "") c.Check(meta.ID(), gc.Equals, "") s.metastor.Check(c, id, meta, "AddMetadata", "SetStored") s.rawstor.Check(c, id, file, 10, "AddFile") } func (s *WrapperSuite) TestFileStorageAddIDNotSet(c *gc.C) { original := s.metadata() c.Assert(original.ID(), gc.Equals, "") _, err := s.stor.Add(original, nil) c.Check(err, gc.IsNil) c.Check(original.ID(), gc.Equals, "") } func (s *WrapperSuite) TestFileStorageAddMetaOnly(c *gc.C) { id, original := s.setMeta() meta, err := s.stor.Metadata(id) c.Assert(err, gc.IsNil) c.Check(meta, gc.Equals, original) c.Check(meta.Stored(), gc.IsNil) } func (s *WrapperSuite) TestFileStorageAddIDAlreadySet(c *gc.C) { original := s.metadata() original.SetID("eggs") _, err := s.stor.Add(original, nil) c.Check(err, gc.IsNil) // This should be handled at the lower level. 
} func (s *WrapperSuite) TestFileStorageAddFileFailureDropsMetadata(c *gc.C) { original := s.metadata() failure := errors.New("failed!") raw := &FakeRawFileStorage{err: failure} stor := filestorage.NewFileStorage(s.metastor, raw) _, err := stor.Add(original, &bytes.Buffer{}) c.Assert(errors.Cause(err), gc.Equals, failure) metalist, metaErr := s.metastor.ListMetadata() c.Assert(metaErr, gc.IsNil) c.Check(metalist, gc.HasLen, 0) c.Check(original.ID(), gc.Equals, "") } func (s *WrapperSuite) TestFileStorageSetFile(c *gc.C) { id, _ := s.setMeta() _, _, err := s.stor.Get(id) c.Assert(err, gc.NotNil) file := bytes.NewBufferString("spam") err = s.stor.SetFile(id, file) c.Assert(err, gc.IsNil) s.metastor.Check(c, id, nil, "Metadata", "Metadata", "SetStored") s.rawstor.Check(c, id, file, 10, "AddFile") } func (s *WrapperSuite) TestFileStorageRemove(c *gc.C) { id := "" err := s.stor.Remove(id) c.Assert(err, gc.IsNil) s.metastor.Check(c, id, nil, "RemoveMetadata") s.rawstor.Check(c, id, nil, 0, "RemoveFile") } func (s *WrapperSuite) TestClose(c *gc.C) { metaStor := &FakeMetadataStorage{} fileStor := &FakeRawFileStorage{} stor := filestorage.NewFileStorage(metaStor, fileStor) err := stor.Close() c.Assert(err, gc.IsNil) c.Check(metaStor.calls, gc.DeepEquals, []string{"Close"}) c.Check(fileStor.calls, gc.DeepEquals, []string{"Close"}) } ================================================ FILE: fs/copy.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package fs import ( "fmt" "io" "os" "path/filepath" ) // Copy recursively copies the file, directory or symbolic link at src // to dst. The destination must not exist. Symbolic links are not // followed. // // If the copy fails half way through, the destination might be left // partially written. 
func Copy(src, dst string) error { srcInfo, srcErr := os.Lstat(src) if srcErr != nil { return srcErr } _, dstErr := os.Lstat(dst) if dstErr == nil { // TODO(rog) add a flag to permit overwriting? return fmt.Errorf("will not overwrite %q", dst) } if !os.IsNotExist(dstErr) { return dstErr } switch mode := srcInfo.Mode(); mode & os.ModeType { case os.ModeSymlink: return copySymLink(src, dst) case os.ModeDir: return copyDir(src, dst, mode) case 0: return copyFile(src, dst, mode) default: return fmt.Errorf("cannot copy file with mode %v", mode) } } func copySymLink(src, dst string) error { target, err := os.Readlink(src) if err != nil { return err } return os.Symlink(target, dst) } func copyFile(src, dst string, mode os.FileMode) error { srcf, err := os.Open(src) if err != nil { return err } defer srcf.Close() dstf, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm()) if err != nil { return err } defer dstf.Close() // Make the actual permissions match the source permissions // even in the presence of umask. if err := os.Chmod(dstf.Name(), mode.Perm()); err != nil { return err } if _, err := io.Copy(dstf, srcf); err != nil { return fmt.Errorf("cannot copy %q to %q: %v", src, dst, err) } return nil } func copyDir(src, dst string, mode os.FileMode) error { srcf, err := os.Open(src) if err != nil { return err } defer srcf.Close() if mode&0500 == 0 { // The source directory doesn't have write permission, // so give the new directory write permission anyway // so that we have permission to create its contents. // We'll make the permissions match at the end. 
mode |= 0500 } if err := os.Mkdir(dst, mode.Perm()); err != nil { return err } for { names, err := srcf.Readdirnames(100) for _, name := range names { if err := Copy(filepath.Join(src, name), filepath.Join(dst, name)); err != nil { return err } } if err == io.EOF { break } if err != nil { return fmt.Errorf("error reading directory %q: %v", src, err) } } if err := os.Chmod(dst, mode.Perm()); err != nil { return err } return nil } ================================================ FILE: fs/copy_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package fs_test import ( "path/filepath" "testing" ft "github.com/juju/testing/filetesting" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/fs" ) type copySuite struct{} var _ = gc.Suite(©Suite{}) func TestPackage(t *testing.T) { gc.TestingT(t) } var copyTests = []struct { about string src ft.Entries dst ft.Entries err string }{{ about: "one file", src: []ft.Entry{ ft.File{"file", "data", 0756}, }, }, { about: "one directory", src: []ft.Entry{ ft.Dir{"dir", 0777}, }, }, { about: "one symlink", src: []ft.Entry{ ft.Symlink{"link", "/foo"}, }, }, { about: "several entries", src: []ft.Entry{ ft.Dir{"top", 0755}, ft.File{"top/foo", "foodata", 0644}, ft.File{"top/bar", "bardata", 0633}, ft.Dir{"top/next", 0721}, ft.Symlink{"top/next/link", "../foo"}, ft.File{"top/next/another", "anotherdata", 0644}, }, }, { about: "destination already exists", src: []ft.Entry{ ft.Dir{"dir", 0777}, }, dst: []ft.Entry{ ft.Dir{"dir", 0777}, }, err: `will not overwrite ".+dir"`, }, { about: "source with unwritable directory", src: []ft.Entry{ ft.Dir{"dir", 0555}, }, }} func (*copySuite) TestCopy(c *gc.C) { for i, test := range copyTests { c.Logf("test %d: %v", i, test.about) src, dst := c.MkDir(), c.MkDir() test.src.Create(c, src) test.dst.Create(c, dst) path := test.src[0].GetPath() err := fs.Copy( filepath.Join(src, path), filepath.Join(dst, path), ) if 
test.err != "" { c.Check(err, gc.ErrorMatches, test.err) } else { c.Assert(err, gc.IsNil) test.src.Check(c, dst) } } } ================================================ FILE: go.mod ================================================ module github.com/juju/utils/v4 go 1.24.4 require ( github.com/juju/clock v1.0.3 github.com/juju/collections v1.0.4 github.com/juju/errors v1.0.0 github.com/juju/loggo/v2 v2.0.0 github.com/juju/mutex/v2 v2.0.0 github.com/juju/testing v1.2.0 golang.org/x/crypto v0.39.0 golang.org/x/net v0.41.0 golang.org/x/text v0.26.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 gopkg.in/yaml.v2 v2.4.0 ) require ( github.com/juju/loggo v1.0.0 // indirect github.com/juju/utils/v3 v3.1.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/term v0.32.0 // indirect ) ================================================ FILE: go.sum ================================================ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/juju/clock v1.0.3 h1:yJHIsWXeU8j3QcBdiess09SzfiXRRrsjKPn2whnMeds= github.com/juju/clock v1.0.3/go.mod h1:HIBvJ8kiV/n7UHwKuCkdYL4l/MDECztHR2sAvWDxxf0= github.com/juju/collections v1.0.4 h1:GjL+aN512m2rVDqhPII7P6qB0e+iYFubz8sqBhZaZtk= github.com/juju/collections v1.0.4/go.mod h1:hVrdB0Zwq9wIU1Fl6ItD2+UETeNeOEs+nGvJufVe+0c= github.com/juju/errors v1.0.0 h1:yiq7kjCLll1BiaRuNY53MGI0+EQ3rF6GB+wvboZDefM= github.com/juju/errors v1.0.0/go.mod h1:B5x9thDqx0wIMH3+aLIMP9HjItInYWObRovoCFM5Qe8= github.com/juju/loggo v1.0.0 
h1:Y6ZMQOGR9Aj3BGkiWx7HBbIx6zNwNkxhVNOHU2i1bl0= github.com/juju/loggo v1.0.0/go.mod h1:NIXFioti1SmKAlKNuUwbMenNdef59IF52+ZzuOmHYkg= github.com/juju/loggo/v2 v2.0.0 h1:PzyVIn+NgoZ22QUtPgKF/lh+6SnaCOEXhcP+sE4FhOk= github.com/juju/loggo/v2 v2.0.0/go.mod h1:647d6WvXBLj5lvka2qBvccr7vMIvF2KFkEH+0ZuFOUM= github.com/juju/mutex/v2 v2.0.0 h1:rVmJdOaXGWF8rjcFHBNd4x57/1tks5CgXHx55O55SB0= github.com/juju/mutex/v2 v2.0.0/go.mod h1:jwCfBs/smYDaeZLqeaCi8CB8M+tOes4yf827HoOEoqk= github.com/juju/testing v1.2.0 h1:Q0wxjaxx4XPVEN+SgzxKr3d82pjmSBcuM3WndAU391c= github.com/juju/testing v1.2.0/go.mod h1:lqZVzNwBKAbylGZidK77ts6kIdoOkmD52+4m0ysetPo= github.com/juju/utils/v3 v3.1.0 h1:NrNo73oVtfr7kLP17/BDpubXwa7YEW16+Ult6z9kpHI= github.com/juju/utils/v3 v3.1.0/go.mod h1:nAj3sHtdYfAkvnkqttTy3Xzm2HzkD9Hfgnc+upOW2Z8= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lunixbochs/vtclean v0.0.0-20160125035106-4fbf7632a2c6/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.0-20160806122752-66b8e73f3f5c/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= golang.org/x/crypto 
v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20160105164936-4f90aeace3a2/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= ================================================ FILE: gomaxprocs.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "os" "runtime" ) var gomaxprocs = runtime.GOMAXPROCS var numCPU = runtime.NumCPU // UseMultipleCPUs sets GOMAXPROCS to the number of CPU cores unless it has // already been overridden by the GOMAXPROCS environment variable. 
func UseMultipleCPUs() { if envGOMAXPROCS := os.Getenv("GOMAXPROCS"); envGOMAXPROCS != "" { n := gomaxprocs(0) logger.Debugf("GOMAXPROCS already set in environment to %q, %d internally", envGOMAXPROCS, n) return } n := numCPU() logger.Debugf("setting GOMAXPROCS to %d", n) gomaxprocs(n) } ================================================ FILE: gomaxprocs_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "os" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type gomaxprocsSuite struct { testing.IsolationSuite setmaxprocs chan int numCPUResponse int setMaxProcs int } var _ = gc.Suite(&gomaxprocsSuite{}) func (s *gomaxprocsSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) // always stub out GOMAXPROCS so we don't actually change anything s.numCPUResponse = 2 s.setMaxProcs = -1 maxProcsFunc := func(n int) int { s.setMaxProcs = n return 1 } numCPUFunc := func() int { return s.numCPUResponse } s.PatchValue(utils.GOMAXPROCS, maxProcsFunc) s.PatchValue(utils.NumCPU, numCPUFunc) s.PatchEnvironment("GOMAXPROCS", "") } func (s *gomaxprocsSuite) TestUseMultipleCPUsDoesNothingWhenGOMAXPROCSSet(c *gc.C) { err := os.Setenv("GOMAXPROCS", "1") c.Assert(err, jc.ErrorIsNil) utils.UseMultipleCPUs() c.Check(s.setMaxProcs, gc.Equals, 0) } func (s *gomaxprocsSuite) TestUseMultipleCPUsWhenEnabled(c *gc.C) { utils.UseMultipleCPUs() c.Check(s.setMaxProcs, gc.Equals, 2) s.numCPUResponse = 4 utils.UseMultipleCPUs() c.Check(s.setMaxProcs, gc.Equals, 4) } ================================================ FILE: hash/fingerprint.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package hash import ( "encoding/base64" "encoding/hex" "hash" "io" "github.com/juju/errors" ) // Fingerprint represents the checksum for some data. type Fingerprint struct { sum []byte } // NewFingerprint returns wraps the provided raw hash sum. This function // roundtrips with Fingerprint.Bytes(). func NewFingerprint(sum []byte, validate func([]byte) error) (Fingerprint, error) { if validate == nil { return Fingerprint{}, errors.New("missing validate func") } if err := validate(sum); err != nil { return Fingerprint{}, errors.Trace(err) } return newFingerprint(sum), nil } // NewValidFingerprint returns a Fingerprint corresponding // to the current of the provided hash. func NewValidFingerprint(hash hash.Hash) Fingerprint { sum := hash.Sum(nil) return newFingerprint(sum) } func newFingerprint(sum []byte) Fingerprint { return Fingerprint{ sum: append([]byte{}, sum...), // Use an isolated copy. } } // GenerateFingerprint returns the fingerprint for the provided data. func GenerateFingerprint(reader io.Reader, newHash func() hash.Hash) (Fingerprint, error) { var fp Fingerprint if reader == nil { return fp, errors.New("missing reader") } if newHash == nil { return fp, errors.New("missing new hash func") } hash := newHash() if _, err := io.Copy(hash, reader); err != nil { return fp, errors.Trace(err) } fp.sum = hash.Sum(nil) return fp, nil } // ParseHexFingerprint returns wraps the provided raw fingerprint string. // This function roundtrips with Fingerprint.Hex(). func ParseHexFingerprint(hexSum string, validate func([]byte) error) (Fingerprint, error) { if validate == nil { return Fingerprint{}, errors.New("missing validate func") } sum, err := hex.DecodeString(hexSum) if err != nil { return Fingerprint{}, errors.Trace(err) } fp, err := NewFingerprint(sum, validate) if err != nil { return Fingerprint{}, errors.Trace(err) } return fp, nil } // ParseBase64Fingerprint returns wraps the provided raw fingerprint string. // This function roundtrips with Fingerprint.Base64(). 
func ParseBase64Fingerprint(b64Sum string, validate func([]byte) error) (Fingerprint, error) { if validate == nil { return Fingerprint{}, errors.New("missing validate func") } sum, err := base64.StdEncoding.DecodeString(b64Sum) if err != nil { return Fingerprint{}, errors.Trace(err) } fp, err := NewFingerprint(sum, validate) if err != nil { return Fingerprint{}, errors.Trace(err) } return fp, nil } // String implements fmt.Stringer. func (fp Fingerprint) String() string { return fp.Hex() } // Hex returns the hex string representation of the fingerprint. func (fp Fingerprint) Hex() string { return hex.EncodeToString(fp.sum) } // Base64 returns the base64 encoded fingerprint. func (fp Fingerprint) Base64() string { return base64.StdEncoding.EncodeToString(fp.sum) } // Bytes returns the raw (sum) bytes of the fingerprint. func (fp Fingerprint) Bytes() []byte { return append([]byte{}, fp.sum...) } // IsZero returns whether or not the fingerprint is the zero value. func (fp Fingerprint) IsZero() bool { return len(fp.sum) == 0 } // Validate returns an error if the fingerprint is invalid. func (fp Fingerprint) Validate() error { if fp.IsZero() { return errors.NotValidf("zero-value fingerprint") } return nil } ================================================ FILE: hash/fingerprint_test.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package hash_test import ( "crypto/sha512" "encoding/hex" stdhash "hash" "github.com/juju/errors" "github.com/juju/testing" jc "github.com/juju/testing/checkers" "github.com/juju/testing/filetesting" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/hash" ) var _ = gc.Suite(&FingerprintSuite{}) type FingerprintSuite struct { stub *testing.Stub hash *filetesting.StubHash } func (s *FingerprintSuite) SetUpTest(c *gc.C) { s.stub = &testing.Stub{} s.hash = filetesting.NewStubHash(s.stub, nil) } func (s *FingerprintSuite) newHash() stdhash.Hash { s.stub.AddCall("newHash") s.stub.NextErr() // Pop one off. return s.hash } func (s *FingerprintSuite) validate(sum []byte) error { s.stub.AddCall("validate", sum) if err := s.stub.NextErr(); err != nil { return errors.Trace(err) } return nil } func (s *FingerprintSuite) TestNewFingerprintOkay(c *gc.C) { expected, _ := newFingerprint(c, "spamspamspam") fp, err := hash.NewFingerprint(expected, s.validate) c.Assert(err, jc.ErrorIsNil) sum := fp.Bytes() s.stub.CheckCallNames(c, "validate") c.Check(sum, jc.DeepEquals, expected) } func (s *FingerprintSuite) TestNewFingerprintInvalid(c *gc.C) { expected, _ := newFingerprint(c, "spamspamspam") failure := errors.NewNotValid(nil, "bogus!!!") s.stub.SetErrors(failure) _, err := hash.NewFingerprint(expected, s.validate) s.stub.CheckCallNames(c, "validate") c.Check(errors.Cause(err), gc.Equals, failure) } func (s *FingerprintSuite) TestNewValidFingerprint(c *gc.C) { expected, _ := newFingerprint(c, "spamspamspam") s.hash.ReturnSum = expected fp := hash.NewValidFingerprint(s.hash) sum := fp.Bytes() s.stub.CheckCallNames(c, "Sum") c.Check(sum, jc.DeepEquals, expected) } func (s *FingerprintSuite) TestGenerateFingerprintOkay(c *gc.C) { expected, _ := newFingerprint(c, "spamspamspam") s.hash.ReturnSum = expected s.hash.Writer, _ = filetesting.NewStubWriter(s.stub) reader := filetesting.NewStubReader(s.stub, "spamspamspam") fp, err := hash.GenerateFingerprint(reader, s.newHash) c.Assert(err, 
jc.ErrorIsNil) sum := fp.Bytes() s.stub.CheckCallNames(c, "newHash", "Read", "Write", "Read", "Sum") c.Check(sum, jc.DeepEquals, expected) } func (s *FingerprintSuite) TestGenerateFingerprintNil(c *gc.C) { _, err := hash.GenerateFingerprint(nil, s.newHash) s.stub.CheckNoCalls(c) c.Check(err, gc.ErrorMatches, `missing reader`) } func (s *FingerprintSuite) TestParseHexFingerprint(c *gc.C) { expected, hexSum := newFingerprint(c, "spamspamspam") fp, err := hash.ParseHexFingerprint(hexSum, s.validate) c.Assert(err, jc.ErrorIsNil) sum := fp.Bytes() s.stub.CheckCallNames(c, "validate") c.Check(sum, jc.DeepEquals, expected) } func (s *FingerprintSuite) TestString(c *gc.C) { sum, expected := newFingerprint(c, "spamspamspam") fp, err := hash.NewFingerprint(sum, s.validate) c.Assert(err, jc.ErrorIsNil) hex := fp.String() c.Check(hex, gc.Equals, expected) } func (s *FingerprintSuite) TestHex(c *gc.C) { sum, expected := newFingerprint(c, "spamspamspam") fp, err := hash.NewFingerprint(sum, s.validate) c.Assert(err, jc.ErrorIsNil) hex := fp.String() c.Check(hex, gc.Equals, expected) } func (s *FingerprintSuite) TestBytes(c *gc.C) { expected, _ := newFingerprint(c, "spamspamspam") fp, err := hash.NewFingerprint(expected, s.validate) c.Assert(err, jc.ErrorIsNil) sum := fp.Bytes() c.Check(sum, jc.DeepEquals, expected) } func (s *FingerprintSuite) TestValidateOkay(c *gc.C) { sum, _ := newFingerprint(c, "spamspamspam") fp, err := hash.NewFingerprint(sum, s.validate) c.Assert(err, jc.ErrorIsNil) err = fp.Validate() c.Check(err, jc.ErrorIsNil) } func (s *FingerprintSuite) TestValidateZero(c *gc.C) { var fp hash.Fingerprint err := fp.Validate() c.Check(err, jc.Satisfies, errors.IsNotValid) c.Check(err, gc.ErrorMatches, `zero-value fingerprint not valid`) } func newFingerprint(c *gc.C, data string) ([]byte, string) { hash := sha512.New384() _, err := hash.Write([]byte(data)) c.Assert(err, jc.ErrorIsNil) sum := hash.Sum(nil) hexStr := hex.EncodeToString(sum) return sum, hexStr } 
================================================ FILE: hash/hash.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // The hash package provides utilities that support use of the stdlib // hash.Hash. Most notably is the Fingerprint type that wraps the // checksum of a hash. // // Conversion between checksums and strings are facailitated through // Fingerprint. // // Here are some hash-related recipes that bring it all together: // // - Extract the SHA384 hash while writing to elsewhere, then get the // raw checksum: // // newHash, _ := hash.SHA384() // h := newHash() // hashingWriter := io.MultiWriter(writer, h) // if err := writeAll(hashingWriter); err != nil { ... } // fp := hash.NewValidFingerprint(h) // checksum := fp.Bytes() // // - Extract the SHA384 hash while reading from elsewhere, then get the // hex-encoded checksum to send over the wire: // // newHash, _ := hash.SHA384() // h := newHash() // hashingReader := io.TeeReader(reader, h) // if err := processStream(hashingReader); err != nil { ... } // fp := hash.NewValidFingerprint(h) // hexSum := fp.Hex() // req.Header.Set("Content-Sha384", hexSum) // // * Turn a checksum sent over the wire back into a fingerprint: // // _, validate := hash.SHA384() // hexSum := req.Header.Get("Content-Sha384") // var fp hash.Fingerprint // if len(hexSum) != 0 { // fp, err = hash.ParseHexFingerprint(hexSum, validate) // ... // } // if fp.IsZero() { // ... // } package hash import ( "crypto/sha512" "hash" "github.com/juju/errors" ) // SHA384 returns the newHash and validate functions for use // with SHA384 hashes. SHA384 is used in several key places in Juju. 
func SHA384() (newHash func() hash.Hash, validate func([]byte) error) { const digestLenBytes = 384 / 8 validate = newSizeChecker(digestLenBytes) return sha512.New384, validate } func newSizeChecker(size int) func([]byte) error { return func(sum []byte) error { if len(sum) < size { return errors.NewNotValid(nil, "invalid fingerprint (too small)") } if len(sum) > size { return errors.NewNotValid(nil, "invalid fingerprint (too big)") } return nil } } ================================================ FILE: hash/hash_test.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package hash_test import ( "bytes" "io" "io/ioutil" "strings" "github.com/juju/testing" jc "github.com/juju/testing/checkers" "github.com/juju/testing/filetesting" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/hash" ) var _ = gc.Suite(&HashSuite{}) type HashSuite struct { testing.IsolationSuite } func (s *HashSuite) TestHashingWriter(c *gc.C) { data := "some data" newHash, _ := hash.SHA384() expected, err := hash.GenerateFingerprint(strings.NewReader(data), newHash) c.Assert(err, jc.ErrorIsNil) var writer bytes.Buffer h := newHash() hashingWriter := io.MultiWriter(&writer, h) _, err = hashingWriter.Write([]byte(data)) c.Assert(err, jc.ErrorIsNil) fp := hash.NewValidFingerprint(h) c.Check(fp, jc.DeepEquals, expected) c.Check(writer.String(), gc.Equals, data) } func (s *HashSuite) TestHashingReader(c *gc.C) { expected := "some data" stub := &testing.Stub{} reader := &filetesting.StubReader{ Stub: stub, ReturnRead: &fakeStream{ data: expected, }, } newHash, validate := hash.SHA384() h := newHash() hashingReader := io.TeeReader(reader, h) data, err := ioutil.ReadAll(hashingReader) c.Assert(err, jc.ErrorIsNil) fp := hash.NewValidFingerprint(h) hexSum := fp.Hex() fpAgain, err := hash.ParseHexFingerprint(hexSum, validate) c.Assert(err, jc.ErrorIsNil) stub.CheckCallNames(c, "Read") // The EOF was mixed with the 
data. c.Check(string(data), gc.Equals, expected) c.Check(fpAgain, jc.DeepEquals, fp) } type fakeStream struct { data string pos uint64 } func (f *fakeStream) Read(data []byte) (int, error) { n := copy(data, f.data[f.pos:]) f.pos += uint64(n) if f.pos >= uint64(len(f.data)) { return n, io.EOF } return n, nil } ================================================ FILE: hash/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package hash_test import ( stdtesting "testing" gc "gopkg.in/check.v1" ) func Test(t *stdtesting.T) { gc.TestingT(t) } ================================================ FILE: hash/writer.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package hash import ( "encoding/base64" "hash" "io" ) // TODO(ericsnow) Remove HashingWriter and NewHashingWriter(). // HashingWriter wraps an io.Writer, providing the checksum of all data // written to it. A HashingWriter may be used in place of the writer it // wraps. // // Note: HashingWriter is deprecated. Please do not use it. We will // remove it ASAP. type HashingWriter struct { hash hash.Hash wrapped io.Writer } // NewHashingWriter returns a new HashingWriter that wraps the provided // writer and the hasher. // // Example: // hw := NewHashingWriter(w, sha1.New()) // io.Copy(hw, reader) // hash := hw.Base64Sum() // // Note: NewHashingWriter is deprecated. Please do not use it. We will // remove it ASAP. func NewHashingWriter(writer io.Writer, hasher hash.Hash) *HashingWriter { return &HashingWriter{ hash: hasher, wrapped: io.MultiWriter(writer, hasher), } } // Base64Sum returns the base64 encoded hash. func (hw HashingWriter) Base64Sum() string { sumBytes := hw.hash.Sum(nil) return base64.StdEncoding.EncodeToString(sumBytes) } // Write writes to both the wrapped file and the hash. 
func (hw *HashingWriter) Write(data []byte) (int, error) { // No trace because some callers, like ioutil.ReadAll(), won't work. return hw.wrapped.Write(data) } ================================================ FILE: hash/writer_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package hash_test import ( "bytes" "github.com/juju/errors" "github.com/juju/testing" jc "github.com/juju/testing/checkers" "github.com/juju/testing/filetesting" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/hash" ) var _ = gc.Suite(&WriterSuite{}) type WriterSuite struct { testing.IsolationSuite stub *testing.Stub wBuffer *bytes.Buffer writer *filetesting.StubWriter hBuffer *bytes.Buffer hash *filetesting.StubHash } func (s *WriterSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.stub = &testing.Stub{} s.wBuffer = new(bytes.Buffer) s.writer = &filetesting.StubWriter{ Stub: s.stub, ReturnWrite: s.wBuffer, } s.hBuffer = new(bytes.Buffer) s.hash = filetesting.NewStubHash(s.stub, s.hBuffer) } func (s *WriterSuite) TestHashingWriterWriteEmpty(c *gc.C) { w := hash.NewHashingWriter(s.writer, s.hash) n, err := w.Write(nil) c.Assert(err, jc.ErrorIsNil) s.stub.CheckCallNames(c, "Write", "Write") c.Check(n, gc.Equals, 0) c.Check(s.wBuffer.String(), gc.Equals, "") c.Check(s.hBuffer.String(), gc.Equals, "") } func (s *WriterSuite) TestHashingWriterWriteSmall(c *gc.C) { w := hash.NewHashingWriter(s.writer, s.hash) n, err := w.Write([]byte("spam")) c.Assert(err, jc.ErrorIsNil) s.stub.CheckCallNames(c, "Write", "Write") c.Check(n, gc.Equals, 4) c.Check(s.wBuffer.String(), gc.Equals, "spam") c.Check(s.hBuffer.String(), gc.Equals, "spam") } func (s *WriterSuite) TestHashingWriterWriteFileError(c *gc.C) { w := hash.NewHashingWriter(s.writer, s.hash) failure := errors.New("") s.stub.SetErrors(failure) _, err := w.Write([]byte("spam")) s.stub.CheckCallNames(c, "Write") c.Check(errors.Cause(err), 
gc.Equals, failure) } func (s *WriterSuite) TestHashingWriterBase64Sum(c *gc.C) { s.hash.ReturnSum = []byte("spam") w := hash.NewHashingWriter(s.writer, s.hash) b64sum := w.Base64Sum() s.stub.CheckCallNames(c, "Sum") c.Check(b64sum, gc.Equals, "c3BhbQ==") } ================================================ FILE: home_unix.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. //go:build !windows // +build !windows package utils import ( "os" ) // Home returns the os-specific home path. // Always returns the "real" home, not the // confined home that is used when running // inside a strictly confined snap. func Home() string { // Used when running inside a confined snap. realHome, exists := os.LookupEnv("SNAP_REAL_HOME") if exists { return realHome } return os.Getenv("HOME") } // SetHome sets the os-specific home path in the environment. func SetHome(s string) error { if _, exists := os.LookupEnv("SNAP_REAL_HOME"); exists { return os.Setenv("SNAP_REAL_HOME", s) } return os.Setenv("HOME", s) } ================================================ FILE: home_unix_test.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
//go:build !windows // +build !windows package utils_test import ( "github.com/juju/testing" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type homeSuite struct { testing.IsolationSuite } var _ = gc.Suite(&homeSuite{}) func (s *homeSuite) TestHomeLinux(c *gc.C) { h := "/home/foo/bar" s.PatchEnvironment("HOME", h) c.Check(utils.Home(), gc.Equals, h) } func (s *homeSuite) TestHomeConfined(c *gc.C) { h := "/home/foo/bar" s.PatchEnvironment("HOME", "/home/user/snap/foo/1") s.PatchEnvironment("SNAP_REAL_HOME", h) c.Check(utils.Home(), gc.Equals, h) } ================================================ FILE: home_windows.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "os" "path/filepath" ) // Home returns the os-specific home path as specified in the environment. func Home() string { return filepath.Join(os.Getenv("HOMEDRIVE"), os.Getenv("HOMEPATH")) } // SetHome sets the os-specific home path in the environment. func SetHome(s string) error { v := filepath.VolumeName(s) if v != "" { if err := os.Setenv("HOMEDRIVE", v); err != nil { return err } } return os.Setenv("HOMEPATH", s[len(v):]) } ================================================ FILE: home_windows_test.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test import ( "os" "github.com/juju/testing" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type homeSuite struct { testing.IsolationSuite } var _ = gc.Suite(&homeSuite{}) func (s *homeSuite) TestHome(c *gc.C) { s.PatchEnvironment("HOMEPATH", "") s.PatchEnvironment("HOMEDRIVE", "") drive := "P:" path := `\home\foo\bar` h := drive + path utils.SetHome(h) c.Check(os.Getenv("HOMEPATH"), gc.Equals, path) c.Check(os.Getenv("HOMEDRIVE"), gc.Equals, drive) c.Check(utils.Home(), gc.Equals, h) // now test that if we only set the path, we don't mess with the drive path2 := `\home\someotherfoo\bar` utils.SetHome(path2) c.Check(os.Getenv("HOMEPATH"), gc.Equals, path2) c.Check(os.Getenv("HOMEDRIVE"), gc.Equals, drive) c.Check(utils.Home(), gc.Equals, drive+path2) } ================================================ FILE: isubuntu.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "strings" ) // IsUbuntu executes lxb_release to see if the host OS is Ubuntu. func IsUbuntu() bool { out, err := RunCommand("lsb_release", "-i", "-s") if err != nil { return false } return strings.TrimSpace(out) == "Ubuntu" } ================================================ FILE: isubuntu_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test import ( "fmt" "runtime" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type IsUbuntuSuite struct { testing.IsolationSuite } var _ = gc.Suite(&IsUbuntuSuite{}) func (s *IsUbuntuSuite) patchLsbRelease(c *gc.C, name string) { var content string var execName string if runtime.GOOS != "windows" { content = fmt.Sprintf("#!/bin/bash --norc\n%s", name) execName = "lsb_release" } else { execName = "lsb_release.bat" content = fmt.Sprintf("@echo off\r\n%s", name) } patchExecutable(s, c.MkDir(), execName, content) } func (s *IsUbuntuSuite) TestIsUbuntu(c *gc.C) { s.patchLsbRelease(c, "echo Ubuntu") c.Assert(utils.IsUbuntu(), jc.IsTrue) } func (s *IsUbuntuSuite) TestIsNotUbuntu(c *gc.C) { s.patchLsbRelease(c, "echo Windows NT") c.Assert(utils.IsUbuntu(), jc.IsFalse) } func (s *IsUbuntuSuite) TestIsNotUbuntuLsbReleaseNotFound(c *gc.C) { if runtime.GOOS != "windows" { s.patchLsbRelease(c, "exit 127") } c.Assert(utils.IsUbuntu(), jc.IsFalse) } ================================================ FILE: jsonhttp/jsonhttp.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // Package jsonhttp provides general functions for returning // JSON responses to HTTP requests. It is agnostic about // the specific form of any returned errors. package jsonhttp import ( "encoding/json" "net/http" "github.com/juju/errors" ) // ErrorToResponse represents a function that can convert a Go error // into a form that can be returned as a JSON body from an HTTP request. // The httpStatus value reports the desired HTTP status. type ErrorToResponse func(err error) (httpStatus int, errorBody any) // ErrorHandler is like http.Handler except it returns an error // which may be returned as the error body of the response. // An ErrorHandler function should not itself write to the ResponseWriter // if it returns an error. 
type ErrorHandler func(http.ResponseWriter, *http.Request) error // HandleErrors returns a function that can be used to convert an ErrorHandler // into an http.Handler. The given errToResp parameter is used to convert // any non-nil error returned by handle to the response in the HTTP body. func HandleErrors(errToResp ErrorToResponse) func(handle ErrorHandler) http.Handler { writeError := WriteError(errToResp) return func(handle ErrorHandler) http.Handler { f := func(w http.ResponseWriter, req *http.Request) { w1 := responseWriter{ ResponseWriter: w, } if err := handle(&w1, req); err != nil { // We write the error only if the header hasn't // already been written, because if it has, then // we will not be able to set the appropriate error // response code, and there's a danger that we // may be corrupting output by appending // a JSON error message to it. if !w1.headerWritten { writeError(w, err) } // TODO log the error? } } return http.HandlerFunc(f) } } // responseWriter wraps http.ResponseWriter but allows us // to find out whether any body has already been written. type responseWriter struct { headerWritten bool http.ResponseWriter } func (w *responseWriter) Write(data []byte) (int, error) { w.headerWritten = true return w.ResponseWriter.Write(data) } func (w *responseWriter) WriteHeader(code int) { w.headerWritten = true w.ResponseWriter.WriteHeader(code) } // Flush implements http.Flusher.Flush. func (w *responseWriter) Flush() { w.headerWritten = true if f, ok := w.ResponseWriter.(http.Flusher); ok { f.Flush() } } // Ensure statically that responseWriter does implement http.Flusher. var _ http.Flusher = (*responseWriter)(nil) // WriteError returns a function that can be used to write an error to a ResponseWriter // and set the HTTP status code. The errToResp parameter is used to determine // the actual error value and status to write. 
func WriteError(errToResp ErrorToResponse) func(w http.ResponseWriter, err error) { return func(w http.ResponseWriter, err error) { status, resp := errToResp(err) _ = WriteJSON(w, status, resp) } } // WriteJSON writes the given value to the ResponseWriter // and sets the HTTP status to the given code. func WriteJSON(w http.ResponseWriter, code int, val any) error { // TODO consider marshalling directly to w using json.NewEncoder. // pro: this will not require a full buffer allocation. // con: if there's an error after the first write, it will be lost. data, err := json.Marshal(val) if err != nil { // TODO(rog) log an error if this fails and lose the // error return, because most callers will need // to do that anyway. return errors.Mask(err) } w.Header().Set("content-type", "application/json") w.WriteHeader(code) _, _ = w.Write(data) return nil } // JSONHandler is like http.Handler except that it returns a // body (to be converted to JSON) and an error. // The Header parameter can be used to set // custom header on the response. type JSONHandler func(http.Header, *http.Request) (any, error) // HandleJSON returns a function that can be used to convert an JSONHandler // into an http.Handler. The given errToResp parameter is used to convert // any non-nil error returned by handle to the response in the HTTP body // If it returns a nil value, the original error is returned as a JSON string. func HandleJSON(errToResp ErrorToResponse) func(handle JSONHandler) http.Handler { handleErrors := HandleErrors(errToResp) return func(handle JSONHandler) http.Handler { f := func(w http.ResponseWriter, req *http.Request) error { val, err := handle(w.Header(), req) if err != nil { return errors.Trace(err) } return WriteJSON(w, http.StatusOK, val) } return handleErrors(f) } } ================================================ FILE: jsonhttp/jsonhttp_test.go ================================================ // Copyright 2014 Canonical Ltd. 
// Licensed under the LGPLv3, see LICENCE file for details. package jsonhttp_test import ( "encoding/json" "net/http" "net/http/httptest" "github.com/juju/errors" "github.com/juju/utils/v4/jsonhttp" gc "gopkg.in/check.v1" ) type suite struct{} var _ = gc.Suite(&suite{}) func (*suite) TestWriteJSON(c *gc.C) { rec := httptest.NewRecorder() type Number struct { N int } err := jsonhttp.WriteJSON(rec, http.StatusTeapot, Number{1234}) c.Assert(err, gc.IsNil) c.Assert(rec.Code, gc.Equals, http.StatusTeapot) c.Assert(rec.Body.String(), gc.Equals, `{"N":1234}`) c.Assert(rec.Header().Get("content-type"), gc.Equals, "application/json") } var ( errUnauth = errors.New("unauth") errBadReq = errors.New("bad request") errOther = errors.New("other") errNil = errors.New("nil result") ) type errorResponse struct { Message string } func errorToResponse(err error) (int, any) { resp := &errorResponse{ Message: err.Error(), } status := http.StatusInternalServerError switch errors.Cause(err) { case errUnauth: status = http.StatusUnauthorized case errBadReq: status = http.StatusBadRequest case errNil: return status, nil } return status, &resp } var writeErrorTests = []struct { err error expectStatus int expectResp *errorResponse }{{ err: errUnauth, expectStatus: http.StatusUnauthorized, expectResp: &errorResponse{ Message: errUnauth.Error(), }, }, { err: errBadReq, expectStatus: http.StatusBadRequest, expectResp: &errorResponse{ Message: errBadReq.Error(), }, }, { err: errOther, expectStatus: http.StatusInternalServerError, expectResp: &errorResponse{ Message: errOther.Error(), }, }, { err: errNil, expectStatus: http.StatusInternalServerError, }} func (s *suite) TestWriteError(c *gc.C) { writeError := jsonhttp.WriteError(errorToResponse) for i, test := range writeErrorTests { c.Logf("%d: %s", i, test.err) rec := httptest.NewRecorder() writeError(rec, test.err) resp := parseErrorResponse(c, rec.Body.Bytes()) c.Assert(resp, gc.DeepEquals, test.expectResp) c.Assert(rec.Code, gc.Equals, 
test.expectStatus) } } func parseErrorResponse(c *gc.C, body []byte) *errorResponse { var errResp *errorResponse err := json.Unmarshal(body, &errResp) c.Assert(err, gc.IsNil) return errResp } func (s *suite) TestHandleErrors(c *gc.C) { handleErrors := jsonhttp.HandleErrors(errorToResponse) // Test when handler returns an error. handler := handleErrors(func(http.ResponseWriter, *http.Request) error { return errUnauth }) rec := httptest.NewRecorder() handler.ServeHTTP(rec, new(http.Request)) c.Assert(rec.Code, gc.Equals, http.StatusUnauthorized) resp := parseErrorResponse(c, rec.Body.Bytes()) c.Assert(resp, gc.DeepEquals, &errorResponse{ Message: errUnauth.Error(), }) // Test when handler returns nil. handler = handleErrors(func(w http.ResponseWriter, _ *http.Request) error { w.WriteHeader(http.StatusCreated) w.Write([]byte("something")) return nil }) rec = httptest.NewRecorder() handler.ServeHTTP(rec, new(http.Request)) c.Assert(rec.Code, gc.Equals, http.StatusCreated) c.Assert(rec.Body.String(), gc.Equals, "something") } var handleErrorsWithErrorAfterWriteHeaderTests = []struct { about string causeWriteHeader func(w http.ResponseWriter) }{{ about: "write", causeWriteHeader: func(w http.ResponseWriter) { w.Write([]byte("")) }, }, { about: "write header", causeWriteHeader: func(w http.ResponseWriter) { w.WriteHeader(http.StatusOK) }, }, { about: "flush", causeWriteHeader: func(w http.ResponseWriter) { w.(http.Flusher).Flush() }, }} func (s *suite) TestHandleErrorsWithErrorAfterWriteHeader(c *gc.C) { handleErrors := jsonhttp.HandleErrors(errorToResponse) for i, test := range handleErrorsWithErrorAfterWriteHeaderTests { c.Logf("test %d: %s", i, test.about) handler := handleErrors(func(w http.ResponseWriter, _ *http.Request) error { test.causeWriteHeader(w) return errors.New("unexpected") }) rec := httptest.NewRecorder() handler.ServeHTTP(rec, new(http.Request)) c.Assert(rec.Code, gc.Equals, http.StatusOK) c.Assert(rec.Body.String(), gc.Equals, "") } } func (s *suite) 
TestHandleJSON(c *gc.C) { handleJSON := jsonhttp.HandleJSON(errorToResponse) // Test when handler returns an error. handler := handleJSON(func(http.Header, *http.Request) (any, error) { return nil, errUnauth }) rec := httptest.NewRecorder() handler.ServeHTTP(rec, new(http.Request)) resp := parseErrorResponse(c, rec.Body.Bytes()) c.Assert(resp, gc.DeepEquals, &errorResponse{ Message: errUnauth.Error(), }) c.Assert(rec.Code, gc.Equals, http.StatusUnauthorized) // Test when handler returns a body. handler = handleJSON(func(h http.Header, _ *http.Request) (any, error) { h.Set("Some-Header", "value") return "something", nil }) rec = httptest.NewRecorder() handler.ServeHTTP(rec, new(http.Request)) c.Assert(rec.Code, gc.Equals, http.StatusOK) c.Assert(rec.Body.String(), gc.Equals, `"something"`) c.Assert(rec.Header().Get("Some-Header"), gc.Equals, "value") } ================================================ FILE: jsonhttp/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package jsonhttp_test import ( "testing" gc "gopkg.in/check.v1" ) func TestPackage(t *testing.T) { gc.TestingT(t) } ================================================ FILE: keyvalues/keyvalues.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // The keyvalues package implements a set of functions for parsing key=value data, // usually passed in as command-line parameters to juju subcommands, e.g. // juju-set mongodb logging=true package keyvalues import ( "fmt" "strings" ) // DuplicateError signals that a duplicate key was encountered while parsing // the input into a map. type DuplicateError string func (e DuplicateError) Error() string { return string(e) } // Parse parses the supplied string slice into a map mapping // keys to values. Duplicate keys cause an error to be returned. 
func Parse(src []string, allowEmptyValues bool) (map[string]string, error) { results := map[string]string{} for _, kv := range src { if kv == "" { continue } parts := strings.SplitN(kv, "=", 2) if len(parts) != 2 { return nil, fmt.Errorf(`expected "key=value", got %q`, kv) } key, value := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) if len(key) == 0 || (!allowEmptyValues && len(value) == 0) { return nil, fmt.Errorf(`expected "key=value", got "%s=%s"`, key, value) } if _, exists := results[key]; exists { return nil, DuplicateError(fmt.Sprintf("key %q specified more than once", key)) } results[key] = value } return results, nil } ================================================ FILE: keyvalues/keyvalues_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package keyvalues_test import ( gc "gopkg.in/check.v1" "github.com/juju/utils/v4/keyvalues" ) type keyValuesSuite struct{} var _ = gc.Suite(&keyValuesSuite{}) var testCases = []struct { about string input []string allowEmptyVal bool output map[string]string error string }{{ about: "simple test case", input: []string{"key=value"}, allowEmptyVal: false, output: map[string]string{"key": "value"}, error: "", }, { about: "empty list", input: []string{}, allowEmptyVal: false, output: map[string]string{}, error: "", }, { about: "nil list", input: nil, allowEmptyVal: false, output: map[string]string{}, error: "", }, { about: "invalid format - missing value", input: []string{"key"}, allowEmptyVal: false, output: nil, error: `expected "key=value", got "key"`, }, { about: "invalid format - missing value", input: []string{"key="}, allowEmptyVal: false, output: nil, error: `expected "key=value", got "key="`, }, { about: "invalid format - missing key", input: []string{"=value"}, allowEmptyVal: false, output: nil, error: `expected "key=value", got "=value"`, }, { about: "invalid format", input: []string{"="}, allowEmptyVal: 
false, output: nil, error: `expected "key=value", got "="`, }, { about: "invalid format, allowing empty", input: []string{"="}, allowEmptyVal: true, output: nil, error: `expected "key=value", got "="`, }, { about: "duplicate keys", input: []string{"key=value", "key=value"}, allowEmptyVal: true, output: nil, error: `key "key" specified more than once`, }, { about: "multiple keys", input: []string{"key=value", "key2=value", "key3=value"}, allowEmptyVal: true, output: map[string]string{"key": "value", "key2": "value", "key3": "value"}, error: "", }, { about: "empty value", input: []string{"key="}, allowEmptyVal: true, output: map[string]string{"key": ""}, error: "", }, { about: "whitespace trimmed", input: []string{"key=value\n", "key2\t=\tvalue2"}, allowEmptyVal: true, output: map[string]string{"key": "value", "key2": "value2"}, error: "", }, { about: "whitespace trimming and duplicate keys", input: []string{"key =value", "key\t=\tvalue2"}, allowEmptyVal: true, output: nil, error: `key "key" specified more than once`, }, { about: "whitespace trimming and empty value not allowed", input: []string{"key= "}, allowEmptyVal: false, output: nil, error: `expected "key=value", got "key="`, }, { about: "whitespace trimming and empty value", input: []string{"key= "}, allowEmptyVal: true, output: map[string]string{"key": ""}, error: "", }, { about: "whitespace trimming and missing key", input: []string{" =value"}, allowEmptyVal: true, output: nil, error: `expected "key=value", got "=value"`, }, { about: "empty inputs are skipped", input: []string{"key=value", "", "foo=bar"}, allowEmptyVal: true, output: map[string]string{"key": "value", "foo": "bar"}, error: "", }} func (keyValuesSuite) TestMapParsing(c *gc.C) { for i, t := range testCases { c.Logf("test %d: %s", i, t.about) result, err := keyvalues.Parse(t.input, t.allowEmptyVal) c.Check(result, gc.DeepEquals, t.output) if t.error == "" { c.Check(err, gc.IsNil) } else { c.Check(err, gc.ErrorMatches, t.error) } } } 
================================================ FILE: keyvalues/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package keyvalues_test import ( "testing" gc "gopkg.in/check.v1" ) func TestPackage(t *testing.T) { gc.TestingT(t) } ================================================ FILE: limiter.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "fmt" "math/rand" "time" "github.com/juju/clock" ) type empty struct{} type limiter struct { wait chan empty minPause time.Duration maxPause time.Duration clock clock.Clock } // Limiter represents a limited resource (eg a semaphore). type Limiter interface { // Acquire another unit of the resource. // Acquire returns false to indicate there is no more availability, // until another entity calls Release. Acquire() bool // AcquireWait requests a unit of resource, but blocks until one is // available. AcquireWait() // Release returns a unit of the resource. Calling Release when there // are no units Acquired is an error. Release() error } // NewLimiter creates a limiter. func NewLimiter(maxAllowed int) Limiter { return NewLimiterWithPause(maxAllowed, 0, 0, nil) } // NewLimiterWithPause creates a limiter. If minpause and maxPause is > 0, // there will be a random delay in that duration range before attempting an Acquire. func NewLimiterWithPause(maxAllowed int, minPause, maxPause time.Duration, clk clock.Clock) Limiter { rand.Seed(time.Now().UTC().UnixNano()) if clk == nil { clk = clock.WallClock } return limiter{ wait: make(chan empty, maxAllowed), minPause: minPause, maxPause: maxPause, clock: clk, } } // Acquire requests some resources that you can return later // It returns 'true' if there are resources available, but false if they are // not. 
Callers are responsible for calling Release if this returns true, but // should not release if this returns false. func (l limiter) Acquire() bool { // Pause before attempting to grab a slot. // This is optional depending on what was used to // construct this limiter, and is used to throttle // incoming connections. l.pause() e := empty{} select { case l.wait <- e: return true default: return false } } // AcquireWait waits for the resource to become available before returning. func (l limiter) AcquireWait() { e := empty{} l.wait <- e } // Release returns the resource to the available pool. func (l limiter) Release() error { select { case <-l.wait: return nil default: return fmt.Errorf("Release without an associated Acquire") } } func (l limiter) pause() { if l.minPause <= 0 || l.maxPause <= 0 { return } pauseRange := int((l.maxPause - l.minPause) / time.Millisecond) pauseTime := time.Duration(rand.Intn(pauseRange)) * time.Millisecond pauseTime += l.minPause select { case <-l.clock.After(pauseTime): } } ================================================ FILE: limiter_test.go ================================================ // Copyright 2011, 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test

import (
	"fmt"
	"time"

	"github.com/juju/clock/testclock"
	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

const longWait = 10 * time.Second

type limiterSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&limiterSuite{})

func (*limiterSuite) TestAcquireUntilFull(c *gc.C) {
	l := utils.NewLimiter(2)
	c.Check(l.Acquire(), jc.IsTrue)
	c.Check(l.Acquire(), jc.IsTrue)
	c.Check(l.Acquire(), jc.IsFalse)
}

func (*limiterSuite) TestBadRelease(c *gc.C) {
	l := utils.NewLimiter(2)
	c.Check(l.Release(), gc.ErrorMatches, "Release without an associated Acquire")
}

func (*limiterSuite) TestAcquireAndRelease(c *gc.C) {
	l := utils.NewLimiter(2)
	c.Check(l.Acquire(), jc.IsTrue)
	c.Check(l.Acquire(), jc.IsTrue)
	c.Check(l.Acquire(), jc.IsFalse)
	c.Check(l.Release(), gc.IsNil)
	c.Check(l.Acquire(), jc.IsTrue)
	c.Check(l.Release(), gc.IsNil)
	c.Check(l.Release(), gc.IsNil)
	c.Check(l.Release(), gc.ErrorMatches, "Release without an associated Acquire")
}

func (*limiterSuite) TestAcquireWaitBlocksUntilRelease(c *gc.C) {
	l := utils.NewLimiter(2)
	calls := make([]string, 0, 10)
	start := make(chan bool, 0)
	waiting := make(chan bool, 0)
	done := make(chan bool, 0)
	go func() {
		<-start
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		waiting <- true
		l.AcquireWait()
		calls = append(calls, "waited")
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		done <- true
	}()
	// Start the routine, and wait for it to get to the first checkpoint
	start <- true
	select {
	case <-waiting:
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for 'waiting' to trigger")
	}
	c.Check(l.Acquire(), jc.IsFalse)
	l.Release()
	select {
	case <-done:
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for 'done' to trigger")
	}
	c.Check(calls, gc.DeepEquals, []string{"true", "true", "false", "waited", "false"})
}

func (*limiterSuite) TestAcquirePauses(c *gc.C) {
	clk := testclock.NewClock(time.Now())
	l := utils.NewLimiterWithPause(2, 10*time.Millisecond, 20*time.Millisecond, clk)
	acquired := make(chan bool, 1)
	start := make(chan bool, 0)
	go func() {
		<-start
		defer l.Release()
		acquired <- l.Acquire()
	}()
	start <- true
	// Minimum pause time not exceeded, acquire should not happen.
	clk.Advance(9 * time.Millisecond)
	select {
	case <-acquired:
		c.Fail()
	case <-time.After(50 * time.Millisecond):
	}
	clk.Advance(11 * time.Millisecond)
	select {
	case <-acquired:
	case <-time.After(50 * time.Millisecond):
		c.Fatal("acquire failed")
	}
}

================================================
FILE: multireader.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

import (
	"io"
	"sort"

	"github.com/juju/errors"
)

// SizeReaderAt combines io.ReaderAt with a Size method.
type SizeReaderAt interface {
	// Size returns the size of the data readable
	// from the reader.
	Size() int64
	io.ReaderAt
}

// NewMultiReaderAt is like io.MultiReader but produces a ReaderAt
// (and Size), instead of just a reader.
//
// Note: this implementation was taken from a talk given
// by Brad Fitzpatrick at OSCON 2013.
//
// http://talks.golang.org/2013/oscon-dl.slide#49
// https://github.com/golang/talks/blob/master/2013/oscon-dl/server-compose.go
func NewMultiReaderAt(parts ...SizeReaderAt) SizeReaderAt {
	m := &multiReaderAt{
		parts: make([]offsetAndSource, 0, len(parts)),
	}
	var off int64
	for _, p := range parts {
		m.parts = append(m.parts, offsetAndSource{off, p})
		off += p.Size()
	}
	m.size = off
	return m
}

// offsetAndSource records where a part starts within the combined data.
type offsetAndSource struct {
	off int64
	SizeReaderAt
}

// multiReaderAt is the SizeReaderAt returned by NewMultiReaderAt; parts
// are ordered by their starting offset.
type multiReaderAt struct {
	parts []offsetAndSource
	size  int64
}

// Size implements SizeReaderAt.Size.
func (m *multiReaderAt) Size() int64 {
	return m.size
}

// ReadAt implements io.ReaderAt by reading from every part that overlaps
// [off, off+len(p)). If fewer than len(p) bytes are available and no part
// reported an error, it returns io.ErrUnexpectedEOF.
func (m *multiReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
	if off < 0 {
		return 0, errors.New("negative offset")
	}
	wantN := len(p)

	// Skip past the requested offset.
	skipParts := sort.Search(len(m.parts), func(i int) bool {
		// This function returns whether parts[i] will
		// contribute any bytes to our output.
		part := m.parts[i]
		return part.off+part.Size() > off
	})
	parts := m.parts[skipParts:]

	// How far to skip in the first part.
	needSkip := off
	if len(parts) > 0 {
		needSkip -= parts[0].off
	}
	for len(parts) > 0 && len(p) > 0 {
		readP := p
		partSize := parts[0].Size()
		if int64(len(readP)) > partSize-needSkip {
			readP = readP[:partSize-needSkip]
		}
		pn, err0 := parts[0].ReadAt(readP, needSkip)
		// Bytes returned together with an error are still valid data,
		// so account for them before deciding whether to stop.
		n += pn
		p = p[pn:]
		// io.ReaderAt permits io.EOF alongside a complete read that ends at
		// the end of the part (e.g. an empty part); later parts may still
		// hold data, so that is not an error for the combined reader.
		if err0 != nil && !(err0 == io.EOF && pn == len(readP)) {
			return n, err0
		}
		if pn < len(readP) {
			// A short read must come with an error; rather than re-reading
			// this part from the wrong offset, report truncation.
			return n, io.ErrUnexpectedEOF
		}
		if int64(pn)+needSkip == partSize {
			parts = parts[1:]
		}
		needSkip = 0
	}
	if n != wantN {
		err = io.ErrUnexpectedEOF
	}
	return n, err
}

// NewMultiReaderSeeker returns an io.ReadSeeker that combines
// all the given readers into a single one. Each reader contributes its
// whole content from offset 0, whatever its current position; its size
// is found by seeking to its end, and NewMultiReaderSeeker panics if that
// seek fails.
func NewMultiReaderSeeker(readers ...io.ReadSeeker) io.ReadSeeker {
	sreaders := make([]SizeReaderAt, len(readers))
	for i, r := range readers {
		r1, err := newSizeReaderAt(r)
		if err != nil {
			panic(err)
		}
		sreaders[i] = r1
	}
	return &readSeeker{
		r: NewMultiReaderAt(sreaders...),
	}
}

// newSizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt.
// Note that it doesn't strictly adhere to the ReaderAt
// contract because it's not safe to call ReadAt concurrently.
// This doesn't matter because io.ReadSeeker doesn't
// need to be thread-safe and this is only used in that
// context.
func newSizeReaderAt(r io.ReadSeeker) (SizeReaderAt, error) {
	size, err := r.Seek(0, io.SeekEnd)
	if err != nil {
		return nil, err
	}
	return &sizeReaderAt{
		r:    r,
		size: size,
		// The seek above left r positioned at its end.
		off: size,
	}, nil
}

// sizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt.
type sizeReaderAt struct {
	r    io.ReadSeeker
	size int64
	off  int64 // current position of r
}

// ReadAt implements SizeReaderAt.ReadAt.
func (r *sizeReaderAt) ReadAt(buf []byte, off int64) (n int, err error) { if off != r.off { _, err = r.r.Seek(off, 0) if err != nil { return 0, err } r.off = off } n, err = io.ReadFull(r.r, buf) r.off += int64(n) return n, err } // Size implemnts SizeReaderAt.Size. func (r *sizeReaderAt) Size() int64 { return r.size } // readSeeker adapts a SizeReaderAt to an io.ReadSeeker. type readSeeker struct { r SizeReaderAt off int64 } // Seek implements io.Seeker.Seek. func (r *readSeeker) Seek(off int64, whence int) (int64, error) { switch whence { case 0: case 1: off += r.off case 2: off = r.r.Size() + off } if off < 0 { return 0, errors.New("negative position") } r.off = off return off, nil } // Read implements io.Reader.Read. func (r *readSeeker) Read(buf []byte) (int, error) { n, err := r.r.ReadAt(buf, r.off) r.off += int64(n) if err == io.ErrUnexpectedEOF { err = io.EOF } return n, err } ================================================ FILE: multireader_test.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test

import (
	"io"
	"io/ioutil"
	"strings"
	"testing/iotest"

	jc "github.com/juju/testing/checkers"
	"github.com/juju/utils/v4"
	gc "gopkg.in/check.v1"
)

type multiReaderSeekerSuite struct{}

var _ = gc.Suite(&multiReaderSeekerSuite{})

func (*multiReaderSeekerSuite) TestSequentialRead(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	r := newMultiStringReader(parts)
	data, err := ioutil.ReadAll(r)
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, strings.Join(parts, ""))
}

func (*multiReaderSeekerSuite) TestSeekStart(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	all := strings.Join(parts, "")
	for off := int64(0); off <= int64(len(all)); off++ {
		c.Logf("-- offset %d", off)
		r := newMultiStringReader(parts)
		gotOff, err := r.Seek(off, 0)
		c.Assert(err, gc.IsNil)
		c.Assert(gotOff, gc.Equals, off)
		data, err := ioutil.ReadAll(r)
		c.Assert(err, gc.IsNil)
		c.Assert(string(data), gc.Equals, all[off:])
	}
}

func (*multiReaderSeekerSuite) TestSeekEnd(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	all := strings.Join(parts, "")
	for off := int64(0); off <= int64(len(all)); off++ {
		r := newMultiStringReader(parts)
		expectOff := int64(len(all)) - off
		gotOff, err := r.Seek(-off, 2)
		c.Assert(err, gc.IsNil)
		c.Assert(gotOff, gc.Equals, expectOff)
		data, err := ioutil.ReadAll(r)
		c.Assert(err, gc.IsNil)
		c.Assert(string(data), gc.Equals, all[expectOff:])
	}
}

func (*multiReaderSeekerSuite) TestSeekCur(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	all := strings.Join(parts, "")
	for off := int64(0); off <= int64(len(all)); off++ {
		for newOff := int64(0); newOff <= int64(len(all)); newOff++ {
			readers := make([]io.ReadSeeker, len(parts))
			for i, part := range parts {
				readers[i] = strings.NewReader(part)
			}
			r := utils.NewMultiReaderSeeker(readers...)
			gotOff, err := r.Seek(off, 0)
			c.Assert(gotOff, gc.Equals, off)
			c.Assert(err, jc.ErrorIsNil)
			diff := newOff - off
			gotNewOff, err := r.Seek(diff, 1)
			c.Assert(err, gc.IsNil)
			c.Assert(gotNewOff, gc.Equals, newOff)
			data, err := ioutil.ReadAll(r)
			c.Assert(err, gc.IsNil)
			c.Assert(string(data), gc.Equals, all[newOff:])
		}
	}
}

func (*multiReaderSeekerSuite) TestSeekAfterRead(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	all := strings.Join(parts, "")
	r := newMultiStringReader(parts)
	data, err := ioutil.ReadAll(iotest.OneByteReader(r))
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, all)
	off, err := r.Seek(-8, 2)
	c.Assert(err, gc.IsNil)
	c.Assert(off, gc.Equals, int64(len(all)-8))
	data, err = ioutil.ReadAll(r)
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, "hreefour")
}

func (*multiReaderSeekerSuite) TestSeekNegative(c *gc.C) {
	r := newMultiStringReader([]string{"one", "two"})
	_, err := r.Seek(-1, 0)
	c.Assert(err, gc.ErrorMatches, "negative position")
	n, err := r.Seek(0, 0)
	c.Assert(err, gc.IsNil)
	c.Assert(n, gc.Equals, int64(0))
	_, err = r.Seek(-7, 2)
	c.Assert(err, gc.ErrorMatches, "negative position")
	n, err = r.Seek(0, 0)
	c.Assert(err, gc.IsNil)
	c.Assert(n, gc.Equals, int64(0))
	_, err = r.Seek(-1, 1)
	c.Assert(err, gc.ErrorMatches, "negative position")
	n, err = r.Seek(0, 0)
	c.Assert(err, gc.IsNil)
	c.Assert(n, gc.Equals, int64(0))
}

func (*multiReaderSeekerSuite) TestSeekPastEnd(c *gc.C) {
	r := newMultiStringReader([]string{"one", "two"})
	n, err := r.Seek(100, 0)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, int64(100))
	nr, err := r.Read(make([]byte, 1))
	c.Assert(nr, gc.Equals, 0)
	c.Assert(err, gc.Equals, io.EOF)
	n, err = r.Seek(-5, 1)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, int64(95))
	nr, err = r.Read(make([]byte, 1))
	c.Assert(nr, gc.Equals, 0)
	c.Assert(err, gc.Equals, io.EOF)
	n, err = r.Seek(-94, 1)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, int64(1))
	data, err := ioutil.ReadAll(r)
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, "netwo")
}

type multiReaderAtSuite struct{}

var _ = gc.Suite(&multiReaderAtSuite{})

func (*multiReaderAtSuite) TestReadComplete(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	all := strings.Join(parts, "")
	r := newMultistringReaderAt(parts)
	buf := make([]byte, len(all))
	n, err := r.ReadAt(buf, 0)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, len(buf))
	c.Assert(string(buf), gc.Equals, all)
}

func (*multiReaderAtSuite) TestReadPartial(c *gc.C) {
	parts := []string{
		"one",
		"two",
		"three",
		"four",
	}
	all := strings.Join(parts, "")
	r := newMultistringReaderAt(parts)
	buf := make([]byte, len(all)-4)
	n, err := r.ReadAt(buf, 2)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, len(buf))
	c.Assert(string(buf), gc.Equals, "etwothreefo")
}

func newMultiStringReader(parts []string) io.ReadSeeker {
	readers := make([]io.ReadSeeker, len(parts))
	for i, part := range parts {
		readers[i] = strings.NewReader(part)
	}
	return utils.NewMultiReaderSeeker(readers...)
}

type stringReader struct {
	*strings.Reader
}

// This method is implemented in later versions
// of Go's StringReader but not prior to Go 1.5.
func (r stringReader) Size() int64 {
	return int64(r.Len())
}

func newMultistringReaderAt(parts []string) io.ReaderAt {
	readers := make([]utils.SizeReaderAt, len(parts))
	for i, part := range parts {
		readers[i] = stringReader{strings.NewReader(part)}
	}
	return utils.NewMultiReaderAt(readers...)
}

================================================
FILE: naturalsort.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

import (
	"sort"
	"strings"
)

// SortStringsNaturally sorts strings according to their natural sort order.
// It sorts s in place and also returns it. Runs of ASCII digits compare by
// numeric value (leading zeros are ignored and runs of any length are
// supported); all other text compares byte-wise.
func SortStringsNaturally(s []string) []string {
	sort.Sort(naturally(s))
	return s
}

type naturally []string

func (n naturally) Len() int {
	return len(n)
}

func (n naturally) Swap(a, b int) {
	n[a], n[b] = n[b], n[a]
}

// Less compares the strings piece by piece: first the text before the next
// run of digits, then that run's numeric value, then the rest.
func (n naturally) Less(a, b int) bool {
	aVal, bVal := n[a], n[b]
	for {
		// If bVal is empty, then aVal can't be less than it.
		if bVal == "" {
			return false
		}
		// If aVal is empty here, then it must be less than bVal.
		if aVal == "" {
			return true
		}
		aPrefix, aDigits, aRemainder := splitAtDigits(aVal)
		bPrefix, bDigits, bRemainder := splitAtDigits(bVal)
		if aPrefix != bPrefix {
			return aPrefix < bPrefix
		}
		if c := compareNumbers(aDigits, bDigits); c != 0 {
			return c < 0
		}
		// Everything is the same so far, try again with the remainder.
		aVal, bVal = aRemainder, bRemainder
	}
}

// splitAtDigits splits str at its first ASCII digit, returning the prefix
// before it, the run of digits starting there, and the remainder after that
// run. If str contains no digits, digits and remainder are both empty.
func splitAtDigits(str string) (prefix, digits, remainder string) {
	i := indexOfDigit(str)
	if i == -1 {
		return str, "", ""
	}
	j := i + indexOfNonDigit(str[i:])
	return str[:i], str[i:j], str[j:]
}

// compareNumbers compares two runs of ASCII digits by numeric value and
// returns -1, 0 or +1. An empty run (no number at all) orders before every
// number, including zero. The runs are never parsed, so arbitrarily long
// numbers cannot overflow.
func compareNumbers(a, b string) int {
	switch {
	case a == "" && b == "":
		return 0
	case a == "":
		return -1
	case b == "":
		return 1
	}
	// Without leading zeros, a longer run is a larger number and runs of
	// equal length order lexically.
	a = strings.TrimLeft(a, "0")
	b = strings.TrimLeft(b, "0")
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	case a < b:
		return -1
	case a > b:
		return 1
	}
	return 0
}

// indexOfDigit returns the byte index of the first ASCII digit in str, or
// -1 if there is none. Only ASCII digits are considered; other Unicode
// digits are ordinary text here.
func indexOfDigit(str string) int {
	return strings.IndexFunc(str, isASCIIDigit)
}

// indexOfNonDigit returns the byte index of the first character in str
// that is not an ASCII digit, or len(str) if every character is one.
func indexOfNonDigit(str string) int {
	if i := strings.IndexFunc(str, func(r rune) bool { return !isASCIIDigit(r) }); i != -1 {
		return i
	}
	return len(str)
}

func isASCIIDigit(r rune) bool {
	return '0' <= r && r <= '9'
}

================================================
FILE: naturalsort_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test

import (
	"math/rand"

	gc "gopkg.in/check.v1"

	"github.com/juju/testing"
	"github.com/juju/utils/v4"
)

type naturalSortSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&naturalSortSuite{})

func (s *naturalSortSuite) TestEmpty(c *gc.C) {
	checkCorrectSort(c, []string{})
}

func (s *naturalSortSuite) TestAlpha(c *gc.C) {
	checkCorrectSort(c, []string{"abc", "bac", "cba"})
}

func (s *naturalSortSuite) TestNumVsString(c *gc.C) {
	checkCorrectSort(c, []string{"1", "a"})
}

func (s *naturalSortSuite) TestStringVsStringNum(c *gc.C) {
	checkCorrectSort(c, []string{"a", "a1"})
}

func (s *naturalSortSuite) TestCommonPrefix(c *gc.C) {
	checkCorrectSort(c, []string{"a1", "a1a", "a1b", "a2b", "a2c"})
}

func (s *naturalSortSuite) TestDifferentNumberLengths(c *gc.C) {
	checkCorrectSort(c, []string{"a1a", "a2", "a22a", "a333", "a333a", "a333b"})
}

func (s *naturalSortSuite) TestZeroPadding(c *gc.C) {
	checkCorrectSort(c, []string{"a1", "a002", "a3"})
}

func (s *naturalSortSuite) TestMixed(c *gc.C) {
	checkCorrectSort(c, []string{"1a", "a1", "a1/1", "a10", "a100"})
}

func (s *naturalSortSuite) TestSeveralNumericParts(c *gc.C) {
	checkCorrectSort(c, []string{
		"x",
		"x1",
		"x1-g0",
		"x1-g1",
		"x1-g2",
		"x1-g10",
		"x2",
		"x2-g0",
		"x2-g2",
		"x11-g0",
		"x11-g0-0",
		"x11-g0-1",
		"x11-g0-10",
		"x11-g0-11",
		"x11-g0-20",
		"x11-g0-100",
		"x11-g10-1",
		"x11-g10-10",
		"xx1",
		"xx10",
	})
}

func (s *naturalSortSuite) TestUnitNameLike(c *gc.C) {
	checkCorrectSort(c, []string{"a1/1", "a1/2", "a1/7", "a1/11", "a1/100"})
}

func (s *naturalSortSuite) TestMachineIdLike(c *gc.C) {
	checkCorrectSort(c, []string{
		"1",
		"1/lxc/0",
		"1/lxc/1",
		"1/lxc/2",
		"1/lxc/10",
		"1/lxd/0",
		"1/lxd/1",
		"1/lxd/10",
		"2",
		"11",
		"11/lxc/6",
		"11/lxc/60",
		"20",
		"21",
	})
}

func (s *naturalSortSuite) TestIPs(c *gc.C) {
	checkCorrectSort(c, []string{
		"1.1.10.122",
		"001.001.010.123",
		"001.002.010.123",
		"100.001.010.123",
		"100.1.10.124",
		"100.2.10.124",
	})
}

// checkCorrectSort checks that expected (already in natural order) is
// restored after being reversed and after several random shuffles.
func checkCorrectSort(c *gc.C, expected []string) {
	checkSort(c, expected, reverse)
	for round := 0; round < 5; round++ {
		checkSort(c, expected, shuffle)
	}
}

// checkSort scrambles a copy of expected with xform, sorts it naturally
// and compares the result with expected.
func checkSort(c *gc.C, expected []string, xform func([]string)) {
	scrambled := copyStrSlice(expected)
	xform(scrambled)
	before := copyStrSlice(scrambled)
	utils.SortStringsNaturally(scrambled)
	c.Check(scrambled, gc.DeepEquals, expected, gc.Commentf("input was: %#v", before))
}

// copyStrSlice returns a new, non-nil slice with the same elements as in.
func copyStrSlice(in []string) []string {
	return append(make([]string, 0, len(in)), in...)
}

// shuffle permutes a in place.
func shuffle(a []string) {
	// See https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#Modern_method
	for hi := len(a) - 1; hi >= 1; hi-- {
		pick := rand.Intn(hi + 1)
		a[hi], a[pick] = a[pick], a[hi]
	}
}

// reverse reverses a in place.
func reverse(a []string) {
	for lo, hi := 0, len(a)-1; lo < hi; lo, hi = lo+1, hi-1 {
		a[lo], a[hi] = a[hi], a[lo]
	}
}

================================================
FILE: network.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

import (
	"fmt"
	"net"

	"github.com/juju/loggo/v2"
)

var logger = loggo.GetLogger("juju.utils")

// GetIPv4Address returns the first IPv4 address found in addresses, each
// of which must be in CIDR notation as produced by
// (*net.Interface).Addrs. Parsing stops at the first address that is not
// valid CIDR, and that parse error is returned.
func GetIPv4Address(addresses []net.Addr) (string, error) {
	for _, address := range addresses {
		ip, _, err := net.ParseCIDR(address.String())
		if err != nil {
			return "", err
		}
		if v4 := ip.To4(); v4 != nil {
			return v4.String(), nil
		}
	}
	return "", fmt.Errorf("no addresses match")
}

// GetIPv6Address iterates through the addresses expecting the format from
// func (ifi *net.Interface) Addrs() ([]net.Addr, error) and returns the first
// non-link local address.
func GetIPv6Address(addresses []net.Addr) (string, error) { _, llNet, _ := net.ParseCIDR("fe80::/10") for _, addr := range addresses { ip, _, err := net.ParseCIDR(addr.String()) if err != nil { return "", err } if ip.To4() == nil && !llNet.Contains(ip) { return ip.String(), nil } } return "", fmt.Errorf("no addresses match") } // GetAddressForInterface looks for the network interface // and returns the IPv4 address from the possible addresses. func GetAddressForInterface(interfaceName string) (string, error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { logger.Errorf("cannot find network interface %q: %v", interfaceName, err) return "", err } addrs, err := iface.Addrs() if err != nil { logger.Errorf("cannot get addresses for network interface %q: %v", interfaceName, err) return "", err } return GetIPv4Address(addrs) } // GetV4OrV6AddressForInterface looks for the network interface // and returns preferably the IPv4 address, and if it doesn't // exists then IPv6 address. func GetV4OrV6AddressForInterface(interfaceName string) (string, error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { logger.Errorf("cannot find network interface %q: %v", interfaceName, err) return "", err } addrs, err := iface.Addrs() if err != nil { logger.Errorf("cannot get addresses for network interface %q: %v", interfaceName, err) return "", err } if ip, err := GetIPv4Address(addrs); err == nil { return ip, nil } return GetIPv6Address(addrs) } ================================================ FILE: network_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test

import (
	"net"

	"github.com/juju/testing"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

type networkSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&networkSuite{})

// fakeAddress is a net.Addr whose String method returns a fixed value.
type fakeAddress struct {
	address string
}

func (fake fakeAddress) Network() string {
	return "ignored"
}

func (fake fakeAddress) String() string {
	return fake.address
}

// makeAddresses wraps each value in a *fakeAddress.
func makeAddresses(values ...string) (result []net.Addr) {
	for i := range values {
		result = append(result, &fakeAddress{address: values[i]})
	}
	return result
}

func (*networkSuite) TestGetIPv4Address(c *gc.C) {
	for _, tc := range []struct {
		addresses   []net.Addr
		expected    string
		errorString string
	}{{
		addresses: makeAddresses(
			"complete",
			"nonsense"),
		errorString: "invalid CIDR address: complete",
	}, {
		addresses: makeAddresses(
			"fe80::90cf:9dff:fe6e:ece/64",
		),
		errorString: "no addresses match",
	}, {
		addresses: makeAddresses(
			"fe80::90cf:9dff:fe6e:ece/64",
			"10.0.3.1/24",
		),
		expected: "10.0.3.1",
	}, {
		addresses: makeAddresses(
			"10.0.3.1/24",
			"fe80::90cf:9dff:fe6e:ece/64",
		),
		expected: "10.0.3.1",
	}} {
		ip, err := utils.GetIPv4Address(tc.addresses)
		if tc.errorString != "" {
			c.Check(err, gc.ErrorMatches, tc.errorString)
			c.Check(ip, gc.Equals, "")
			continue
		}
		c.Check(err, gc.IsNil)
		c.Check(ip, gc.Equals, tc.expected)
	}
}

func (*networkSuite) TestGetIPv6Address(c *gc.C) {
	for _, tc := range []struct {
		addresses   []net.Addr
		expected    string
		errorString string
	}{{
		addresses: makeAddresses(
			"complete",
			"nonsense"),
		errorString: "invalid CIDR address: complete",
	}, {
		addresses: makeAddresses(
			"fe80::90cf:9dff:fe6e:ece/64",
		),
		errorString: "no addresses match",
	}, {
		addresses: makeAddresses(
			"fe80::90cf:9dff:fe6e:ece/64",
			"10.0.3.1/24",
		),
		errorString: "no addresses match",
	}, {
		addresses: makeAddresses(
			"10.0.3.1/24",
		),
		errorString: "no addresses match",
	}, {
		addresses: makeAddresses(
			"10.0.3.1/24",
			"2001:db8::90cf:9dff:fe6e:ece/64",
		),
		expected: "2001:db8::90cf:9dff:fe6e:ece",
	}, {
		addresses: makeAddresses(
			"2001:db8::90cf:9dff:fe6e:ece/64",
			"10.0.3.1/24",
		),
		expected: "2001:db8::90cf:9dff:fe6e:ece",
	}} {
		ip, err := utils.GetIPv6Address(tc.addresses)
		if tc.errorString != "" {
			c.Check(err, gc.ErrorMatches, tc.errorString)
			c.Check(ip, gc.Equals, "")
			continue
		}
		c.Check(err, gc.IsNil)
		c.Check(ip, gc.Equals, tc.expected)
	}
}

================================================
FILE: os.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

// These are the names of the operating systems recognized by Go.
const (
	OSWindows   = "windows"
	OSDarwin    = "darwin"
	OSDragonfly = "dragonfly"
	OSFreebsd   = "freebsd"
	OSLinux     = "linux"
	OSNacl      = "nacl"
	OSNetbsd    = "netbsd"
	OSOpenbsd   = "openbsd"
	OSSolaris   = "solaris"
)

// OSUnix is the list of unix-like operating systems recognized by Go.
// See http://golang.org/src/path/filepath/path_unix.go.
var OSUnix = []string{
	OSDarwin,
	OSDragonfly,
	OSFreebsd,
	OSLinux,
	OSNacl,
	OSNetbsd,
	OSOpenbsd,
	OSSolaris,
}

// OSIsUnix reports whether os exactly matches one of the names listed in
// OSUnix. The comparison is case-sensitive.
func OSIsUnix(os string) bool {
	for i := range OSUnix {
		if OSUnix[i] == os {
			return true
		}
	}
	return false
}

================================================
FILE: os_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test

import (
	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

var _ = gc.Suite(&osSuite{})

type osSuite struct {
	testing.IsolationSuite
}

func (osSuite) TestOSIsUnixKnown(c *gc.C) {
	for _, name := range utils.OSUnix {
		c.Logf("checking %q", name)
		c.Check(utils.OSIsUnix(name), jc.IsTrue)
	}
}

func (osSuite) TestOSIsUnixWindows(c *gc.C) {
	c.Check(utils.OSIsUnix("windows"), jc.IsFalse)
}

func (osSuite) TestOSIsUnixUnknown(c *gc.C) {
	c.Check(utils.OSIsUnix(""), jc.IsFalse)
}

================================================
FILE: package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils_test

import (
	"testing"

	gc "gopkg.in/check.v1"
)

func TestPackage(t *testing.T) {
	gc.TestingT(t)
}

================================================
FILE: parallel/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package parallel_test

import (
	"testing"

	gc "gopkg.in/check.v1"
)

func TestPackage(t *testing.T) {
	gc.TestingT(t)
}

================================================
FILE: parallel/parallel.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

// Package parallel provides utilities for running tasks
// concurrently.
package parallel

import (
	"fmt"
	"sync"
)

// Run represents a number of functions running concurrently.
type Run struct {
	// mu guards running.
	mu sync.Mutex
	// results receives one Errors value from each runner goroutine.
	results chan Errors
	// max is the upper bound on the number of runner goroutines.
	max int
	// running counts the runner goroutines started so far.
	running int
	// work hands functions to the runner goroutines.
	work chan func() error
}

// Errors holds any errors encountered during the parallel run.
type Errors []error func (errs Errors) Error() string { switch len(errs) { case 0: return "no error" case 1: return errs[0].Error() } return fmt.Sprintf("%s (and %d more)", errs[0].Error(), len(errs)-1) } // NewRun returns a new parallel instance. It provides a way of running // functions concurrently while limiting the maximum number running at // once to max. func NewRun(max int) *Run { if max < 1 { panic("parameter max must be >= 1") } return &Run{ max: max, results: make(chan Errors), work: make(chan func() error), } } // Do requests that r run f concurrently. If there are already the maximum // number of functions running concurrently, it will block until one of them // has completed. Do may itself be called concurrently, but may not be called // concurrently with Wait. func (r *Run) Do(f func() error) { select { case r.work <- f: return default: } r.mu.Lock() if r.running < r.max { r.running++ go r.runner() } r.mu.Unlock() r.work <- f } // Wait marks the parallel instance as complete and waits for all the functions // to complete. If any errors were encountered, it returns an Errors value // describing all the errors in arbitrary order. func (r *Run) Wait() error { close(r.work) var errs Errors for i := 0; i < r.running; i++ { errs = append(errs, <-r.results...) } if len(errs) == 0 { return nil } // TODO(rog) sort errors by original order of Do request? return errs } func (r *Run) runner() { var errs Errors for f := range r.work { if err := f(); err != nil { errs = append(errs, err) } } r.results <- errs } ================================================ FILE: parallel/parallel_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package parallel_test

import (
	"sort"
	"sync"
	"sync/atomic"
	stdtesting "testing"
	"time"

	"github.com/juju/testing"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/parallel"
)

type parallelSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&parallelSuite{})

func (*parallelSuite) TestParallelMaxPar(c *gc.C) {
	const (
		totalDo                 = 10
		maxConcurrentRunnersPar = 3
	)
	var mu sync.Mutex
	maxConcurrentRunners := 0
	nbRunners := 0
	nbRuns := 0
	parallelRunner := parallel.NewRun(maxConcurrentRunnersPar)
	for i := 0; i < totalDo; i++ {
		parallelRunner.Do(func() error {
			mu.Lock()
			nbRuns++
			nbRunners++
			if nbRunners > maxConcurrentRunners {
				maxConcurrentRunners = nbRunners
			}
			mu.Unlock()
			time.Sleep(time.Second / 10)
			mu.Lock()
			nbRunners--
			mu.Unlock()
			return nil
		})
	}
	err := parallelRunner.Wait()
	if nbRunners != 0 {
		c.Errorf("%d functions still running", nbRunners)
	}
	if nbRuns != totalDo {
		c.Errorf("all functions not executed; want %d got %d", totalDo, nbRuns)
	}
	c.Check(err, gc.IsNil)
	if maxConcurrentRunners != maxConcurrentRunnersPar {
		c.Errorf("wrong number of do's ran at once; want %d got %d", maxConcurrentRunnersPar, maxConcurrentRunners)
	}
}

func nothing() error {
	return nil
}

func BenchmarkRunSingle(b *stdtesting.B) {
	for i := 0; i < b.N; i++ {
		r := parallel.NewRun(1)
		r.Do(nothing)
		r.Wait()
	}
}

func BenchmarkRun1000p100(b *stdtesting.B) {
	for i := 0; i < b.N; i++ {
		r := parallel.NewRun(100)
		for j := 0; j < 1000; j++ {
			r.Do(nothing)
		}
		r.Wait()
	}
}

func (*parallelSuite) TestConcurrentDo(c *gc.C) {
	r := parallel.NewRun(3)
	var count int32
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			r.Do(func() error {
				atomic.AddInt32(&count, 1)
				return nil
			})
			wg.Done()
		}()
	}
	wg.Wait()

	err := r.Wait()
	c.Assert(err, gc.IsNil)
	c.Assert(count, gc.Equals, int32(100))
}

type intError int

func (intError) Error() string {
	return "error"
}

func (*parallelSuite) TestParallelError(c *gc.C) {
	const (
		totalDo = 10
		errDo   = 5
	)
	parallelRun := parallel.NewRun(6)
	for i := 0; i < totalDo; i++ {
		i := i
		if i >= errDo {
			parallelRun.Do(func() error {
				return intError(i)
			})
		} else {
			parallelRun.Do(func() error {
				return nil
			})
		}
	}
	err := parallelRun.Wait()
	c.Check(err, gc.NotNil)
	errs := err.(parallel.Errors)
	c.Check(len(errs), gc.Equals, totalDo-errDo)
	ints := make([]int, len(errs))
	for i, err := range errs {
		ints[i] = int(err.(intError))
	}
	sort.Ints(ints)
	for i, n := range ints {
		c.Check(n, gc.Equals, i+errDo)
	}
}

func (*parallelSuite) TestZeroWorkerPanics(c *gc.C) {
	defer func() {
		r := recover()
		c.Check(r, gc.Matches, "parameter max must be >= 1")
	}()
	parallel.NewRun(0)
}

================================================
FILE: parallel/try.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package parallel

import (
	"errors"
	"io"
	"sync"

	"gopkg.in/tomb.v1"
)

var (
	// ErrStopped is returned by Start once the Try is finishing, and by
	// Result/Wait when the Try was killed before any error was recorded.
	ErrStopped = errors.New("try was stopped")

	// ErrClosed is returned by Start once Close has been called.
	ErrClosed = errors.New("try was closed")
)

// Try represents an attempt made concurrently
// by a number of goroutines.
type Try struct {
	tomb tomb.Tomb

	// closeMutex serialises Close so that t.close is closed only once.
	closeMutex sync.Mutex
	close      chan struct{}

	// limiter holds one token per free concurrency slot; it is nil when
	// maxParallel <= 0 (no limit).
	limiter chan struct{}

	start         chan func()
	result        chan result
	combineErrors func(err0, err1 error) error
	maxParallel   int

	// endResult is written by the loop goroutine before the tomb is
	// marked done, and read by Result after tomb.Wait.
	endResult io.Closer
}

// NewTry returns an object that runs functions concurrently until one
// succeeds. The result of the first function that returns without an
// error is available from the Result method. If maxParallel is
// positive, it limits the number of concurrently running functions;
// otherwise the number is unlimited.
//
// The function combineErrors(oldErr, newErr) is called to determine
// the error return (see the Result method). The first time it is called,
// oldErr will be nil; subsequently oldErr will be the error previously
// returned by combineErrors. If combineErrors is nil, the last
// encountered error is chosen.
func NewTry(maxParallel int, combineErrors func(err0, err1 error) error) *Try {
	if combineErrors == nil {
		combineErrors = chooseLastError
	}
	t := &Try{
		combineErrors: combineErrors,
		maxParallel:   maxParallel,
		close:         make(chan struct{}, 1),
		result:        make(chan result),
		start:         make(chan func()),
	}
	if t.maxParallel > 0 {
		t.limiter = make(chan struct{}, t.maxParallel)
		for i := 0; i < t.maxParallel; i++ {
			t.limiter <- struct{}{}
		}
	}
	go func() {
		defer t.tomb.Done()
		val, err := t.loop()
		t.endResult = val
		t.tomb.Kill(err)
	}()
	return t
}

// chooseLastError is the default combineErrors: it keeps the newest error.
func chooseLastError(err0, err1 error) error {
	return err1
}

// result is what a started function delivers back to the loop.
type result struct {
	val io.Closer
	err error
}

// loop starts functions, collects their results and decides when the Try
// is finished. It returns the first successful value, or the combined
// error once the Try is closed with nothing left running or is killed.
func (t *Try) loop() (io.Closer, error) {
	var err error
	// A nil channel never becomes ready, so setting close to nil below
	// disables that select case after Close has been observed.
	close := t.close
	nrunning := 0
	for {
		select {
		case f := <-t.start:
			nrunning++
			go f()
		case r := <-t.result:
			if r.err == nil {
				return r.val, r.err
			}
			err = t.combineErrors(err, r.err)
			nrunning--
			if close == nil && nrunning == 0 {
				return nil, err
			}
		case <-t.tomb.Dying():
			if err == nil {
				return nil, ErrStopped
			}
			return nil, err
		case <-close:
			close = nil
			if nrunning == 0 {
				return nil, err
			}
		}
	}
}

// Start requests the given function to be started, waiting until there
// are less than maxParallel functions running if necessary. It returns
// an error if the function has not been started (ErrClosed if the Try
// has been closed, and ErrStopped if the try is finishing).
//
// The function should listen on the stop channel and return if it
// receives a value, though this is advisory only - the Try does not
// wait for all started functions to return before completing.
//
// If the function returns a nil error but some earlier try was
// successful (that is, the returned value is being discarded),
// its returned value will be closed by calling its Close method.
func (t *Try) Start(try func(stop <-chan struct{}) (io.Closer, error)) error {
	if t.limiter != nil {
		// Wait for availability slot.
		select {
		case <-t.limiter:
		case <-t.tomb.Dying():
			return ErrStopped
		case <-t.close:
			return ErrClosed
		}
	}
	dying := t.tomb.Dying()
	f := func() {
		val, err := try(dying)
		// Signal availability slot is now free.
		t.releaseSlot()
		// Deliver result.
		select {
		case t.result <- result{val, err}:
		case <-dying:
			// The value is being discarded; close it so it does not leak.
			// A nil value (nil interface) has nothing to close.
			if err == nil && val != nil {
				val.Close()
			}
		}
	}
	// select chooses randomly among ready cases, and the loop keeps
	// receiving on t.start while earlier functions are still running.
	// Check for Close/Kill first so a Start issued after them is reliably
	// refused instead of sometimes being launched.
	select {
	case <-dying:
		t.releaseSlot()
		return ErrStopped
	case <-t.close:
		t.releaseSlot()
		return ErrClosed
	default:
	}
	select {
	case t.start <- f:
		return nil
	case <-dying:
		t.releaseSlot()
		return ErrStopped
	case <-t.close:
		t.releaseSlot()
		return ErrClosed
	}
}

// releaseSlot returns a concurrency slot to the limiter, if there is one.
// It never blocks: the caller holds a token, so the buffer has room.
func (t *Try) releaseSlot() {
	if t.limiter != nil {
		t.limiter <- struct{}{}
	}
}

// Close closes the Try. No more functions will be started
// if Start is called, and the Try will terminate when all
// outstanding functions have completed (or earlier
// if one succeeds). Close may be called more than once.
func (t *Try) Close() {
	t.closeMutex.Lock()
	defer t.closeMutex.Unlock()
	select {
	case <-t.close:
	default:
		close(t.close)
	}
}

// Dead returns a channel that is closed when the
// Try completes.
func (t *Try) Dead() <-chan struct{} {
	return t.tomb.Dead()
}

// Wait waits for the Try to complete and returns the same
// error returned by Result.
func (t *Try) Wait() error {
	return t.tomb.Wait()
}

// Result waits for the Try to complete and returns the result of the
// first successful function started by Start.
//
// If no function succeeded, the error is the last value returned by
// combineErrors, or nil if no started function failed. The exception is
// a Try stopped by Kill with no error recorded (nothing failed, or
// combineErrors returned nil): then ErrStopped is returned.
func (t *Try) Result() (io.Closer, error) {
	err := t.tomb.Wait()
	return t.endResult, err
}

// Kill stops the Try. Functions still running are told to stop through
// their stop channel, but Kill does not wait for them to return.
func (t *Try) Kill() {
	t.tomb.Kill(nil)
}

================================================
FILE: parallel/try_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package parallel_test

import (
	"errors"
	"fmt"
	"io"
	"sort"
	"sync"
	"time"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/parallel"
)

// shortWait bounds "this should NOT happen yet" checks; longWait bounds
// "this should eventually happen" checks.
const (
	shortWait = 50 * time.Millisecond
	longWait  = 10 * time.Second
)

// result is a trivial io.Closer whose Close is a no-op, used as a
// comparable success value in the tests below.
type result string

func (r result) Close() error {
	return nil
}

type trySuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&trySuite{})

// tryFunc returns a Start-compatible function that sleeps for delay and
// then returns (val, err), ignoring its stop channel.
func tryFunc(delay time.Duration, val io.Closer, err error) func(<-chan struct{}) (io.Closer, error) {
	return func(<-chan struct{}) (io.Closer, error) {
		time.Sleep(delay)
		return val, err
	}
}

// TestOneSuccess checks that a single successful function's value is
// returned by Result.
func (*trySuite) TestOneSuccess(c *gc.C) {
	try := parallel.NewTry(0, nil)
	try.Start(tryFunc(0, result("hello"), nil))
	val, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, result("hello"))
}

// TestOneFailure checks that a failure alone does not finish the Try, and
// that after Close the failure's error is reported.
func (*trySuite) TestOneFailure(c *gc.C) {
	try := parallel.NewTry(0, nil)
	expectErr := errors.New("foo")
	err := try.Start(tryFunc(0, nil, expectErr))
	c.Assert(err, gc.IsNil)
	select {
	case <-try.Dead():
		c.Fatalf("try died before it should")
	case <-time.After(shortWait):
	}
	try.Close()
	select {
	case <-try.Dead():
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for Try to complete")
	}
	val, err := try.Result()
	c.Assert(val, gc.IsNil)
	c.Assert(err, gc.Equals, expectErr)
}

// TestStartReturnsErrorAfterClose checks that Start refuses new work with
// ErrClosed once Close has been called.
func (*trySuite) TestStartReturnsErrorAfterClose(c *gc.C) {
	try := parallel.NewTry(0, nil)
	expectErr := errors.New("foo")
	err := try.Start(tryFunc(0, nil, expectErr))
	c.Assert(err, gc.IsNil)
	try.Close()
	err = try.Start(tryFunc(0, result("goodbye"), nil))
	c.Assert(err, gc.Equals, parallel.ErrClosed)
	// Wait for the first try to deliver its result
	time.Sleep(shortWait)
	try.Kill()
	err = try.Wait()
	c.Assert(err, gc.Equals, expectErr)
}

// TestOutOfOrderResults checks that the first function to finish wins,
// not the first one started.
func (*trySuite) TestOutOfOrderResults(c *gc.C) {
	try := parallel.NewTry(0, nil)
	try.Start(tryFunc(50*time.Millisecond, result("first"), nil))
	try.Start(tryFunc(10*time.Millisecond, result("second"), nil))
	r, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(r, gc.Equals, result("second"))
}

// TestMaxParallel checks that no more than maxParallel (3) functions run
// at once, and that the limit is actually reached.
func (*trySuite) TestMaxParallel(c *gc.C) {
	try := parallel.NewTry(3, nil)
	var (
		mu    sync.Mutex
		count int
		max   int
	)

	for i := 0; i < 10; i++ {
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			mu.Lock()
			if count++; count > max {
				max = count
			}
			c.Check(count, gc.Not(jc.GreaterThan), 3)
			mu.Unlock()
			time.Sleep(20 * time.Millisecond)
			mu.Lock()
			count--
			mu.Unlock()
			return result("hello"), nil
		})
	}
	r, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(r, gc.Equals, result("hello"))
	mu.Lock()
	defer mu.Unlock()
	c.Assert(max, gc.Equals, 3)
}

// TestStartBlocksForMaxParallel checks that Start itself blocks while
// maxParallel functions are running, unblocks as slots free up, and that
// a blocked Start returns ErrClosed once the Try is closed.
func (*trySuite) TestStartBlocksForMaxParallel(c *gc.C) {
	try := parallel.NewTry(3, nil)

	started := make(chan struct{})
	begin := make(chan struct{})
	go func() {
		for i := 0; i < 6; i++ {
			err := try.Start(func(<-chan struct{}) (io.Closer, error) {
				<-begin
				return nil, fmt.Errorf("an error")
			})
			started <- struct{}{}
			if i < 5 {
				c.Check(err, gc.IsNil)
			} else {
				c.Check(err, gc.Equals, parallel.ErrClosed)
			}
		}
		close(started)
	}()
	// Check we can start the first three.
	timeout := time.After(longWait)
	for i := 0; i < 3; i++ {
		select {
		case <-started:
		case <-timeout:
			c.Fatalf("timed out")
		}
	}
	// Check we block when going above maxParallel.
	timeout = time.After(shortWait)
	select {
	case <-started:
		c.Fatalf("Start did not block")
	case <-timeout:
	}

	// Unblock two attempts.
	begin <- struct{}{}
	begin <- struct{}{}

	// Check we can start another two.
	timeout = time.After(longWait)
	for i := 0; i < 2; i++ {
		select {
		case <-started:
		case <-timeout:
			c.Fatalf("timed out")
		}
	}

	// Check we block again when going above maxParallel.
	timeout = time.After(shortWait)
	select {
	case <-started:
		c.Fatalf("Start did not block")
	case <-timeout:
	}

	// Close the Try - the last request should be discarded,
	// unblocking last remaining Start request.
	try.Close()

	timeout = time.After(longWait)
	select {
	case <-started:
	case <-timeout:
		c.Fatalf("Start did not unblock after Close")
	}

	// Ensure all checks are completed
	select {
	case _, ok := <-started:
		c.Assert(ok, gc.Equals, false)
	case <-timeout:
		c.Fatalf("Start goroutine did not finish")
	}
}

// TestAllConcurrent checks that with maxParallel 0 all ten functions are
// running at the same time: each blocks until the test replies to it.
func (*trySuite) TestAllConcurrent(c *gc.C) {
	try := parallel.NewTry(0, nil)
	started := make(chan chan struct{})
	for i := 0; i < 10; i++ {
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			reply := make(chan struct{})
			started <- reply
			<-reply
			return result("hello"), nil
		})
	}
	timeout := time.After(longWait)
	for i := 0; i < 10; i++ {
		select {
		case reply := <-started:
			reply <- struct{}{}
		case <-timeout:
			c.Fatalf("timed out")
		}
	}
}

// gradedError is an error with an integer importance.
type gradedError int

func (e gradedError) Error() string {
	return fmt.Sprintf("error with importance %d", e)
}

// gradedErrorCombine keeps whichever gradedError has the higher
// importance (the newer one on ties is not chosen).
func gradedErrorCombine(err0, err1 error) error {
	if err0 == nil || err0.(gradedError) < err1.(gradedError) {
		return err1
	}
	return err0
}

// multiError accumulates the importances of every gradedError seen.
type multiError struct {
	errs []int
}

func (e *multiError) Error() string {
	return fmt.Sprintf("%v", e.errs)
}

// TestErrorCombine checks that combineErrors is called once per failure
// and that its accumulated value is what Result returns.
func (*trySuite) TestErrorCombine(c *gc.C) {
	// Use maxParallel=1 to guarantee that all errors are processed sequentially.
	try := parallel.NewTry(1, func(err0, err1 error) error {
		if err0 == nil {
			err0 = &multiError{}
		}
		err0.(*multiError).errs = append(err0.(*multiError).errs, int(err1.(gradedError)))
		return err0
	})
	// Note: this local shadows the errors package for the rest of the test.
	errors := []gradedError{3, 2, 4, 0, 5, 5, 3}
	for _, err := range errors {
		err := err
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			return nil, err
		})
	}
	try.Close()
	val, err := try.Result()
	c.Assert(val, gc.IsNil)
	grades := err.(*multiError).errs
	sort.Ints(grades)
	c.Assert(grades, gc.DeepEquals, []int{0, 2, 3, 3, 4, 5, 5})
}

// TestTriesAreStopped checks that once one function succeeds, the stop
// channel given to the others becomes ready.
func (*trySuite) TestTriesAreStopped(c *gc.C) {
	try := parallel.NewTry(0, nil)
	stopped := make(chan struct{})
	try.Start(func(stop <-chan struct{}) (io.Closer, error) {
		<-stop
		stopped <- struct{}{}
		return nil, parallel.ErrStopped
	})
	try.Start(tryFunc(0, result("hello"), nil))
	val, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, result("hello"))

	select {
	case <-stopped:
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for stop")
	}
}

// TestCloseTwice checks that Close is idempotent and that a Try closed
// with nothing started finishes with a nil value and nil error.
func (*trySuite) TestCloseTwice(c *gc.C) {
	try := parallel.NewTry(0, nil)
	try.Close()
	try.Close()
	val, err := try.Result()
	c.Assert(val, gc.IsNil)
	c.Assert(err, gc.IsNil)
}

// closeResult records that Close was called by closing its channel.
type closeResult struct {
	closed chan struct{}
}

func (r *closeResult) Close() error {
	close(r.closed)
	return nil
}

// TestExtraResultsAreClosed checks that successful values arriving after
// the winner are closed, while the winning value is not.
func (*trySuite) TestExtraResultsAreClosed(c *gc.C) {
	try := parallel.NewTry(0, nil)
	begin := make([]chan struct{}, 4)
	results := make([]*closeResult, len(begin))
	for i := range begin {
		begin[i] = make(chan struct{})
		results[i] = &closeResult{make(chan struct{})}
		i := i
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			<-begin[i]
			return results[i], nil
		})
	}
	begin[0] <- struct{}{}
	val, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, results[0])

	timeout := time.After(shortWait)
	for i, r := range results[1:] {
		begin[i+1] <- struct{}{}
		select {
		case <-r.closed:
		case <-timeout:
			c.Fatalf("timed out waiting for close")
		}
	}

	select {
	case <-results[0].closed:
		c.Fatalf("result was inappropriately closed")
	case <-time.After(shortWait):
	}
}

// TestEverything mixes staggered starts, failures and successes under a
// limit of 5 and expects one of the two fast successes to win.
func (*trySuite) TestEverything(c *gc.C) {
	try := parallel.NewTry(5, gradedErrorCombine)

	tries := []struct {
		startAt time.Duration
		wait    time.Duration
		val     result
		err     error
	}{{
		wait: 30 * time.Millisecond,
		err:  gradedError(3),
	}, {
		startAt: 10 * time.Millisecond,
		wait:    20 * time.Millisecond,
		val:     result("result 1"),
	}, {
		startAt: 20 * time.Millisecond,
		wait:    10 * time.Millisecond,
		val:     result("result 2"),
	}, {
		startAt: 20 * time.Millisecond,
		wait:    5 * time.Second,
		val:     "delayed result",
	}, {
		startAt: 5 * time.Millisecond,
		err:     gradedError(4),
	}}
	for _, t := range tries {
		t := t
		go func() {
			time.Sleep(t.startAt)
			try.Start(tryFunc(t.wait, t.val, t.err))
		}()
	}
	val, err := try.Result()
	if val != result("result 1") && val != result("result 2") {
		c.Errorf(`expected "result 1" or "result 2" got %#v`, val)
	}
	c.Assert(err, gc.IsNil)
}

================================================
FILE: password.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

import (
	"crypto/rand"
	"crypto/sha512"
	"encoding/base64"
	"fmt"
	"io"

	"golang.org/x/crypto/pbkdf2"
)

// CompatSalt is the hard-coded salt that Juju 1.16 and older used to
// compute the password hash for all users and agents (presumably kept so
// those old hashes can still be reproduced).
var CompatSalt = string([]byte{0x75, 0x82, 0x81, 0xca})

// randomPasswordBytes is the number of random bytes in a generated
// password. 18 is a multiple of 3, so the base64 form (24 characters)
// carries no '=' padding.
const randomPasswordBytes = 18

// MinAgentPasswordLength describes how long agent passwords should be. We
// require this length because we assume enough entropy in the Agent password
// that it is safe to not do extra rounds of iterated hashing. It equals the
// base64 length of randomPasswordBytes raw bytes (24 characters).
var MinAgentPasswordLength = base64.StdEncoding.EncodedLen(randomPasswordBytes)

// RandomBytes returns n bytes read from crypto/rand.Reader, or an error
// if they cannot all be read.
func RandomBytes(n int) ([]byte, error) { buf := make([]byte, n) _, err := io.ReadFull(rand.Reader, buf) if err != nil { return nil, fmt.Errorf("cannot read random bytes: %v", err) } return buf, nil } // RandomPassword generates a random base64-encoded password. func RandomPassword() (string, error) { b, err := RandomBytes(randomPasswordBytes) if err != nil { return "", err } return base64.StdEncoding.EncodeToString(b), nil } // RandomSalt generates a random base64 data suitable for using as a password // salt The pbkdf2 guideline is to use 8 bytes of salt, so we do 12 raw bytes // into 16 base64 bytes. (The alternative is 6 raw into 8 base64). func RandomSalt() (string, error) { b, err := RandomBytes(12) if err != nil { return "", err } return base64.StdEncoding.EncodeToString(b), nil } // FastInsecureHash specifies whether a fast, insecure version of the hash // algorithm will be used. Changing this will cause PasswordHash to // produce incompatible passwords. It should only be changed for // testing purposes - to make tests run faster. var FastInsecureHash = false // UserPasswordHash returns base64-encoded one-way hash password that is // computationally hard to crack by iterating through possible passwords. func UserPasswordHash(password string, salt string) string { if salt == "" { panic("salt is not allowed to be empty") } iter := 8192 if FastInsecureHash { iter = 1 } // Generate 18 byte passwords because we know that MongoDB // uses the MD5 sum of the password anyway, so there's // no point in using more bytes. (18 so we don't get base 64 // padding characters). h := pbkdf2.Key([]byte(password), []byte(salt), iter, 18, sha512.New) return base64.StdEncoding.EncodeToString(h) } // AgentPasswordHash returns base64-encoded one-way hash of password. This is // not suitable for User passwords because those will have limited entropy (see // UserPasswordHash). 
However, since we generate long random passwords for // agents, we can trust that there is sufficient entropy to prevent brute force // search. And using a faster hash allows us to restart the state machines and // have 1000s of agents log in in a reasonable amount of time. func AgentPasswordHash(password string) string { sum := sha512.New() sum.Write([]byte(password)) h := sum.Sum(nil) return base64.StdEncoding.EncodeToString(h[:18]) } ================================================ FILE: password_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) type passwordSuite struct { testing.IsolationSuite } var _ = gc.Suite(&passwordSuite{}) // Base64 *can* include a tail of '=' characters, but all the tests here // explicitly *don't* want those because it is wasteful. var base64Chars = "^[A-Za-z0-9+/]+$" func (*passwordSuite) TestRandomBytes(c *gc.C) { b, err := utils.RandomBytes(16) c.Assert(err, gc.IsNil) c.Assert(b, gc.HasLen, 16) x0 := b[0] for _, x := range b { if x != x0 { return } } c.Errorf("all same bytes in result of RandomBytes") } func (*passwordSuite) TestRandomPassword(c *gc.C) { p, err := utils.RandomPassword() c.Assert(err, gc.IsNil) if len(p) < 18 { c.Errorf("password too short: %q", p) } c.Assert(p, gc.Matches, base64Chars) } func (*passwordSuite) TestRandomSalt(c *gc.C) { salt, err := utils.RandomSalt() c.Assert(err, gc.IsNil) if len(salt) < 12 { c.Errorf("salt too short: %q", salt) } // check we're not adding base64 padding. 
c.Assert(salt, gc.Matches, base64Chars) } var testPasswords = []string{"", "a", "a longer password than i would usually bother with"} var testSalts = []string{"abcd", "abcdefgh", "abcdefghijklmnop", utils.CompatSalt} func (*passwordSuite) TestUserPasswordHash(c *gc.C) { seenHashes := make(map[string]bool) for i, password := range testPasswords { for j, salt := range testSalts { c.Logf("test %d, %d %s %s", i, j, password, salt) hashed := utils.UserPasswordHash(password, salt) c.Logf("hash %q", hashed) c.Assert(len(hashed), gc.Equals, 24) c.Assert(seenHashes[hashed], gc.Equals, false) // check we're not adding base64 padding. c.Assert(hashed, gc.Matches, base64Chars) seenHashes[hashed] = true // check it's deterministic altHashed := utils.UserPasswordHash(password, salt) c.Assert(altHashed, gc.Equals, hashed) } } } func (*passwordSuite) TestAgentPasswordHash(c *gc.C) { seenValues := make(map[string]bool) for i := 0; i < 1000; i++ { password, err := utils.RandomPassword() c.Assert(err, gc.IsNil) c.Assert(seenValues[password], jc.IsFalse) seenValues[password] = true hashed := utils.AgentPasswordHash(password) c.Assert(hashed, gc.Not(gc.Equals), password) c.Assert(seenValues[hashed], jc.IsFalse) seenValues[hashed] = true c.Assert(len(hashed), gc.Equals, 24) // check we're not adding base64 padding. c.Assert(hashed, gc.Matches, base64Chars) } } ================================================ FILE: proxy/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package proxy_test import ( "testing" gc "gopkg.in/check.v1" ) func TestPackage(t *testing.T) { gc.TestingT(t) } ================================================ FILE: proxy/proxy.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package proxy

import (
	"fmt"
	"os"
	"strings"

	"github.com/juju/collections/set"
)

const (
	// Remove the likelihood of errors by mistyping string values.
	http_proxy  = "http_proxy"
	https_proxy = "https_proxy"
	ftp_proxy   = "ftp_proxy"
	no_proxy    = "no_proxy"
)

// Settings holds the values for the HTTP, HTTPS and FTP proxies as well as the
// no_proxy value found by DetectProxies.
// AutoNoProxy is filled with addresses of controllers, we never want to proxy those.
type Settings struct {
	Http        string
	Https       string
	Ftp         string
	NoProxy     string
	AutoNoProxy string
}

// getSetting returns the value of the environment variable key, falling
// back to its upper-case spelling when the lower-case one is unset or empty.
func getSetting(key string) string {
	value := os.Getenv(key)
	if value == "" {
		value = os.Getenv(strings.ToUpper(key))
	}
	return value
}

// DetectProxies returns the proxy settings found in the environment.
// Lower-case variables take precedence over their upper-case variants.
// AutoNoProxy is never populated from the environment.
func DetectProxies() Settings {
	return Settings{
		Http:    getSetting(http_proxy),
		Https:   getSetting(https_proxy),
		Ftp:     getSetting(ftp_proxy),
		NoProxy: getSetting(no_proxy),
	}
}

// AsScriptEnvironment returns a potentially multi-line string in a format
// that specifies exported key=value lines. There are two lines for each non-
// empty proxy value, one lower-case and one upper-case.
//
// NOTE(review): values are written unquoted, so a value containing spaces
// or shell metacharacters produces a broken (or unsafe) script line.
// Quoting is not added here because callers and tests depend on the exact
// output.
func (s *Settings) AsScriptEnvironment() string {
	var lines []string
	addLine := func(proxy, value string) {
		if value != "" {
			lines = append(
				lines,
				fmt.Sprintf("export %s=%s", proxy, value),
				fmt.Sprintf("export %s=%s", strings.ToUpper(proxy), value))
		}
	}
	addLine(http_proxy, s.Http)
	addLine(https_proxy, s.Https)
	addLine(ftp_proxy, s.Ftp)
	addLine(no_proxy, s.FullNoProxy())
	return strings.Join(lines, "\n")
}

// AsEnvironmentValues returns a slice of strings of the format "key=value"
// suitable to be used in a command environment. There are two values for each
// non-empty proxy value, one lower-case and one upper-case. The result is
// never nil.
func (s *Settings) AsEnvironmentValues() []string {
	lines := []string{}
	addLine := func(proxy, value string) {
		if value != "" {
			lines = append(
				lines,
				fmt.Sprintf("%s=%s", proxy, value),
				fmt.Sprintf("%s=%s", strings.ToUpper(proxy), value))
		}
	}
	addLine(http_proxy, s.Http)
	addLine(https_proxy, s.Https)
	addLine(ftp_proxy, s.Ftp)
	addLine(no_proxy, s.FullNoProxy())
	return lines
}

// AsSystemdDefaultEnv returns a string in the format understood by systemd:
// DefaultEnvironment="http_proxy=...." "HTTP_PROXY=..." ...
// Each key=value pair is wrapped in double quotes without escaping, and
// the result ends with a newline.
//
// NOTE(review): systemd's per-user manager drop-in directory is
// /etc/systemd/user.conf.d/; "users.conf.d" in the generated comment looks
// like a typo, but it is left unchanged because tests assert the exact text.
func (s *Settings) AsSystemdDefaultEnv() string {
	lines := s.AsEnvironmentValues()
	rv := `# To allow juju to control the global systemd proxy settings,
# create symbolic links to this file from within /etc/systemd/system.conf.d/
# and /etc/systemd/users.conf.d/.
[Manager]
DefaultEnvironment=`
	for _, line := range lines {
		rv += fmt.Sprintf(`"%s" `, line)
	}
	return rv + "\n"
}

// SetEnvironmentValues updates the process environment with the
// proxy values stored in the settings object. Both the lower-case
// and upper-case variants are set.
//
//	http_proxy, HTTP_PROXY
//	https_proxy, HTTPS_PROXY
//	ftp_proxy, FTP_PROXY
//	no_proxy, NO_PROXY
//
// Empty values set the variables to "" rather than unsetting them.
// Errors from os.Setenv are ignored.
func (s *Settings) SetEnvironmentValues() {
	setenv := func(proxy, value string) {
		os.Setenv(proxy, value)
		os.Setenv(strings.ToUpper(proxy), value)
	}
	setenv(http_proxy, s.Http)
	setenv(https_proxy, s.Https)
	setenv(ftp_proxy, s.Ftp)
	setenv(no_proxy, s.FullNoProxy())
}

// FullNoProxy merges NoProxy and AutoNoProxy into a single comma-separated
// list. Each entry is trimmed of surrounding whitespace, empty entries are
// dropped, duplicates are removed and the result is sorted.
func (s *Settings) FullNoProxy() string {
	noProxySet := set.NewStrings()
	for _, list := range []string{s.NoProxy, s.AutoNoProxy} {
		for _, entry := range strings.Split(list, ",") {
			// Without trimming, "a, b" would yield the distinct entry " b",
			// and "a," would contribute an empty entry (a stray comma).
			if entry = strings.TrimSpace(entry); entry != "" {
				noProxySet.Add(entry)
			}
		}
	}
	return strings.Join(noProxySet.SortedValues(), ",")
}

================================================
FILE: proxy/proxy_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details. package proxy_test import ( "os" "github.com/juju/testing" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/proxy" ) type proxySuite struct { testing.IsolationSuite } var _ = gc.Suite(&proxySuite{}) func (s *proxySuite) TestDetectNoSettings(c *gc.C) { // Patch all of the environment variables we check out just in case the // user has one set. s.PatchEnvironment("http_proxy", "") s.PatchEnvironment("HTTP_PROXY", "") s.PatchEnvironment("https_proxy", "") s.PatchEnvironment("HTTPS_PROXY", "") s.PatchEnvironment("ftp_proxy", "") s.PatchEnvironment("FTP_PROXY", "") s.PatchEnvironment("no_proxy", "") s.PatchEnvironment("NO_PROXY", "") proxies := proxy.DetectProxies() c.Assert(proxies, gc.DeepEquals, proxy.Settings{}) } func (s *proxySuite) TestDetectPrimary(c *gc.C) { // Patch all of the environment variables we check out just in case the // user has one set. s.PatchEnvironment("http_proxy", "http://user@10.0.0.1") s.PatchEnvironment("HTTP_PROXY", "") s.PatchEnvironment("https_proxy", "https://user@10.0.0.1") s.PatchEnvironment("HTTPS_PROXY", "") s.PatchEnvironment("ftp_proxy", "ftp://user@10.0.0.1") s.PatchEnvironment("FTP_PROXY", "") s.PatchEnvironment("no_proxy", "10.0.3.1,localhost") s.PatchEnvironment("NO_PROXY", "") proxies := proxy.DetectProxies() c.Assert(proxies, gc.DeepEquals, proxy.Settings{ Http: "http://user@10.0.0.1", Https: "https://user@10.0.0.1", Ftp: "ftp://user@10.0.0.1", NoProxy: "10.0.3.1,localhost", }) } func (s *proxySuite) TestDetectFallback(c *gc.C) { // Patch all of the environment variables we check out just in case the // user has one set. 
s.PatchEnvironment("http_proxy", "") s.PatchEnvironment("HTTP_PROXY", "http://user@10.0.0.2") s.PatchEnvironment("https_proxy", "") s.PatchEnvironment("HTTPS_PROXY", "https://user@10.0.0.2") s.PatchEnvironment("ftp_proxy", "") s.PatchEnvironment("FTP_PROXY", "ftp://user@10.0.0.2") s.PatchEnvironment("no_proxy", "") s.PatchEnvironment("NO_PROXY", "10.0.3.1,localhost") proxies := proxy.DetectProxies() c.Assert(proxies, gc.DeepEquals, proxy.Settings{ Http: "http://user@10.0.0.2", Https: "https://user@10.0.0.2", Ftp: "ftp://user@10.0.0.2", NoProxy: "10.0.3.1,localhost", }) } func (s *proxySuite) TestDetectPrimaryPreference(c *gc.C) { // Patch all of the environment variables we check out just in case the // user has one set. s.PatchEnvironment("http_proxy", "http://user@10.0.0.1") s.PatchEnvironment("https_proxy", "https://user@10.0.0.1") s.PatchEnvironment("ftp_proxy", "ftp://user@10.0.0.1") s.PatchEnvironment("no_proxy", "10.0.3.1,localhost") s.PatchEnvironment("HTTP_PROXY", "http://user@10.0.0.2") s.PatchEnvironment("HTTPS_PROXY", "https://user@10.0.0.2") s.PatchEnvironment("FTP_PROXY", "ftp://user@10.0.0.2") s.PatchEnvironment("NO_PROXY", "localhost") proxies := proxy.DetectProxies() c.Assert(proxies, gc.DeepEquals, proxy.Settings{ Http: "http://user@10.0.0.1", Https: "https://user@10.0.0.1", Ftp: "ftp://user@10.0.0.1", NoProxy: "10.0.3.1,localhost", }) } func (s *proxySuite) TestAsScriptEnvironmentEmpty(c *gc.C) { proxies := proxy.Settings{} c.Assert(proxies.AsScriptEnvironment(), gc.Equals, "") } func (s *proxySuite) TestAsScriptEnvironmentOneValue(c *gc.C) { proxies := proxy.Settings{ Http: "some-value", } expected := ` export http_proxy=some-value export HTTP_PROXY=some-value`[1:] c.Assert(proxies.AsScriptEnvironment(), gc.Equals, expected) } func (s *proxySuite) TestAsScriptEnvironmentAllValue(c *gc.C) { proxies := proxy.Settings{ Http: "some-value", Https: "special", Ftp: "who uses this?", NoProxy: "10.0.3.1,localhost", } expected := ` export 
http_proxy=some-value export HTTP_PROXY=some-value export https_proxy=special export HTTPS_PROXY=special export ftp_proxy=who uses this? export FTP_PROXY=who uses this? export no_proxy=10.0.3.1,localhost export NO_PROXY=10.0.3.1,localhost`[1:] c.Assert(proxies.AsScriptEnvironment(), gc.Equals, expected) } func (s *proxySuite) TestAsEnvironmentValuesEmpty(c *gc.C) { proxies := proxy.Settings{} c.Assert(proxies.AsEnvironmentValues(), gc.HasLen, 0) } func (s *proxySuite) TestAsEnvironmentValuesOneValue(c *gc.C) { proxies := proxy.Settings{ Http: "some-value", } expected := []string{ "http_proxy=some-value", "HTTP_PROXY=some-value", } c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, expected) } func (s *proxySuite) TestAsEnvironmentValuesAllValue(c *gc.C) { proxies := proxy.Settings{ Http: "some-value", Https: "special", Ftp: "who uses this?", NoProxy: "10.0.3.1,localhost", } expected := []string{ "http_proxy=some-value", "HTTP_PROXY=some-value", "https_proxy=special", "HTTPS_PROXY=special", "ftp_proxy=who uses this?", "FTP_PROXY=who uses this?", "no_proxy=10.0.3.1,localhost", "NO_PROXY=10.0.3.1,localhost", } c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, expected) } func (s *proxySuite) TestAsSystemdDefaultEnv(c *gc.C) { proxies := proxy.Settings{ Http: "some-value", Https: "special", Ftp: "who uses this?", NoProxy: "10.0.3.1,localhost", } expected := ` # To allow juju to control the global systemd proxy settings, # create symbolic links to this file from within /etc/systemd/system.conf.d/ # and /etc/systemd/users.conf.d/. [Manager] DefaultEnvironment="http_proxy=some-value" "HTTP_PROXY=some-value" "https_proxy=special" "HTTPS_PROXY=special" "ftp_proxy=who uses this?" "FTP_PROXY=who uses this?" 
"no_proxy=10.0.3.1,localhost" "NO_PROXY=10.0.3.1,localhost" `[1:] c.Assert(proxies.AsSystemdDefaultEnv(), gc.DeepEquals, expected) } func (s *proxySuite) TestSetEnvironmentValues(c *gc.C) { s.PatchEnvironment("http_proxy", "initial") s.PatchEnvironment("HTTP_PROXY", "initial") s.PatchEnvironment("https_proxy", "initial") s.PatchEnvironment("HTTPS_PROXY", "initial") s.PatchEnvironment("ftp_proxy", "initial") s.PatchEnvironment("FTP_PROXY", "initial") s.PatchEnvironment("no_proxy", "initial") s.PatchEnvironment("NO_PROXY", "initial") proxySettings := proxy.Settings{ Http: "http proxy", Https: "https proxy", // Ftp left blank to show clearing env. NoProxy: "10.0.3.1,localhost", } proxySettings.SetEnvironmentValues() obtained := proxy.DetectProxies() c.Assert(obtained, gc.DeepEquals, proxySettings) c.Assert(os.Getenv("http_proxy"), gc.Equals, "http proxy") c.Assert(os.Getenv("HTTP_PROXY"), gc.Equals, "http proxy") c.Assert(os.Getenv("https_proxy"), gc.Equals, "https proxy") c.Assert(os.Getenv("HTTPS_PROXY"), gc.Equals, "https proxy") c.Assert(os.Getenv("ftp_proxy"), gc.Equals, "") c.Assert(os.Getenv("FTP_PROXY"), gc.Equals, "") c.Assert(os.Getenv("no_proxy"), gc.Equals, "10.0.3.1,localhost") c.Assert(os.Getenv("NO_PROXY"), gc.Equals, "10.0.3.1,localhost") } func (s *proxySuite) TestAutoNoProxy(c *gc.C) { proxies := proxy.Settings{ NoProxy: "10.0.3.1,localhost", } expectedFirst := []string{ "no_proxy=10.0.3.1,localhost", "NO_PROXY=10.0.3.1,localhost", } expectedSecond := []string{ "no_proxy=10.0.3.1,10.0.3.2,localhost", "NO_PROXY=10.0.3.1,10.0.3.2,localhost", } c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, expectedFirst) proxies.AutoNoProxy = "10.0.3.1,10.0.3.2" c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, expectedSecond) } ================================================ FILE: randomstring.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils

import (
	"math/rand"
	"sync"
	"time"
)

// Rune sets that make sane default values for the validRunes argument of
// RandomString, either on their own or concatenated together.
var (
	LowerAlpha = []rune("abcdefghijklmnopqrstuvwxyz")
	UpperAlpha = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
	Digits     = []rune("0123456789")
)

// randomStringRand is the generator shared by all RandomString calls. A
// *rand.Rand built from rand.NewSource is not safe for concurrent use, so
// every access goes through randomStringMu.
var (
	randomStringMu   sync.Mutex
	randomStringRand *rand.Rand
)

func init() {
	// Seed from the wall clock so separate processes draw different sequences.
	randomStringRand = rand.New(rand.NewSource(time.Now().UnixNano()))
}

// RandomString returns a string of n runes, each picked at random from
// validRunes. Calls are serialized on randomStringMu, so it may be used
// from multiple goroutines.
//
// NOTE(review): with n > 0 and an empty validRunes, rand.Intn panics; a
// negative n panics in make. Callers are expected to pass sane arguments.
func RandomString(n int, validRunes []rune) string {
	picked := make([]rune, n)
	alphabetSize := len(validRunes)

	randomStringMu.Lock()
	// Deferred so the mutex is released even if Intn panics.
	defer randomStringMu.Unlock()
	for idx := range picked {
		picked[idx] = validRunes[randomStringRand.Intn(alphabetSize)]
	}
	return string(picked)
}

================================================
FILE: randomstring_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Copyright 2015 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

package utils_test

import (
	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	"github.com/juju/utils/v4"
	gc "gopkg.in/check.v1"
)

type randomStringSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&randomStringSuite{})

var (
	validChars = []rune("thisissorandom")
	length     = 7
)

func (randomStringSuite) TestLength(c *gc.C) {
	c.Assert(utils.RandomString(length, validChars), gc.HasLen, length)
}

func (randomStringSuite) TestContentInValidRunes(c *gc.C) {
	allowed := string(validChars)
	for _, r := range utils.RandomString(length, validChars) {
		c.Assert(allowed, jc.Contains, string(r))
	}
}

================================================
FILE: registry/export_test.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package registry

var (
	DescriptionFromVersions = descriptionFromVersions
)

================================================
FILE: registry/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package registry_test

import (
	"testing"

	gc "gopkg.in/check.v1"
)

func TestAll(t *testing.T) {
	gc.TestingT(t)
}

================================================
FILE: registry/registry.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package registry

import (
	"fmt"
	"reflect"
	"sort"

	"github.com/juju/errors"
)

// TypedNameVersion is a registry that will allow you to register objects based
// on a name and version pair. The objects must be convertible to the Type
// defined when the registry was created. They are converted during Register,
// so you can be sure all objects returned from Get() are safe to type-assert
// to that type.
type TypedNameVersion struct {
	requiredType reflect.Type
	versions     map[string]Versions
}

// NewTypedNameVersion creates an empty registry whose entries must be
// convertible to requiredType.
func NewTypedNameVersion(requiredType reflect.Type) *TypedNameVersion {
	return &TypedNameVersion{
		requiredType: requiredType,
		versions:     make(map[string]Versions),
	}
}

// Description gives the name and available versions in a registry.
type Description struct {
	Name     string
	Versions []int
}

// Versions maps concrete versions of the objects.
type Versions map[int]any

// Register records obj under the given name and version. obj is converted to
// the registry's required type before it is stored, so values returned by Get
// can be safely type-asserted to that type.
//
// An error is returned if obj is nil, if obj cannot be converted to the
// required type, or if an object is already registered with the given name
// and version.
func (r *TypedNameVersion) Register(name string, version int, obj any) error {
	fullname := fmt.Sprintf("%s(%d)", name, version)
	if obj == nil {
		// reflect.TypeOf(nil) yields a nil Type; calling ConvertibleTo on it
		// would panic, so reject untyped nil explicitly.
		return fmt.Errorf("cannot register nil object %q", fullname)
	}
	if !reflect.TypeOf(obj).ConvertibleTo(r.requiredType) {
		return fmt.Errorf("object of type %T cannot be converted to type %s.%s", obj, r.requiredType.PkgPath(), r.requiredType.Name())
	}
	obj = reflect.ValueOf(obj).Convert(r.requiredType).Interface()
	if r.versions == nil {
		r.versions = make(map[string]Versions, 1)
	}
	versions, ok := r.versions[name]
	if !ok {
		r.versions[name] = Versions{version: obj}
		return nil
	}
	if _, exists := versions[version]; exists {
		return fmt.Errorf("object %q already registered", fullname)
	}
	versions[version] = obj
	return nil
}

// descriptionFromVersions aggregates the information in a Versions map into a
// more friendly form for List(); the versions are sorted in ascending order.
func descriptionFromVersions(name string, versions Versions) Description {
	intVersions := make([]int, 0, len(versions))
	for version := range versions {
		intVersions = append(intVersions, version)
	}
	sort.Ints(intVersions)
	return Description{
		Name:     name,
		Versions: intVersions,
	}
}

// List returns a slice describing each of the registered objects, sorted
// by name.
func (r *TypedNameVersion) List() []Description {
	names := make([]string, 0, len(r.versions))
	for name := range r.versions {
		names = append(names, name)
	}
	sort.Strings(names)
	descriptions := make([]Description, len(names))
	for i, name := range names {
		descriptions[i] = descriptionFromVersions(name, r.versions[name])
	}
	return descriptions
}

// Get returns the object registered for a single name and version. If no such
// object is registered, the returned error satisfies errors.IsNotFound.
func (r *TypedNameVersion) Get(name string, version int) (any, error) {
	if versions, ok := r.versions[name]; ok {
		if factory, ok := versions[version]; ok {
			return factory, nil
		}
	}
	return nil, errors.NotFoundf("%s(%d)", name, version)
}

================================================
FILE: registry/registry_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package registry_test

import (
	"reflect"

	"github.com/juju/errors"
	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/registry"
)

type registrySuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&registrySuite{})

type Factory func() (any, error)

func nilFactory() (any, error) {
	return nil, nil
}

var factoryType = reflect.TypeOf((*Factory)(nil)).Elem()

type testFacade struct {
	version string
	called  bool
}

type stringVal struct {
	value string
}

func (t *testFacade) TestMethod() stringVal {
	t.called = true
	return stringVal{"called " + t.version}
}

func (s *registrySuite) TestDescriptionFromVersions(c *gc.C) {
	versions := registry.Versions{0: nilFactory}
	c.Check(registry.DescriptionFromVersions("name", versions),
		gc.DeepEquals,
		registry.Description{
			Name:     "name",
			Versions: []int{0},
		})
	versions[2] = nilFactory
	c.Check(registry.DescriptionFromVersions("name", versions),
		gc.DeepEquals,
		registry.Description{
			Name:     "name",
			Versions: []int{0, 2},
		})
}

func (s *registrySuite) TestDescriptionFromVersionsAreSorted(c *gc.C) {
	versions := registry.Versions{
		10: nilFactory,
		5:  nilFactory,
		0:  nilFactory,
		18: nilFactory,
		6:  nilFactory,
		4:  nilFactory,
	}
	c.Check(registry.DescriptionFromVersions("name", versions),
		gc.DeepEquals,
		registry.Description{
			Name:     "name",
			Versions: []int{0, 4, 5, 6, 10, 18},
		})
}

func (s *registrySuite) TestRegisterAndList(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	c.Assert(r.Register("name", 0, nilFactory), gc.IsNil)
	c.Check(r.List(), gc.DeepEquals, []registry.Description{
		{Name: "name", Versions: []int{0}},
	})
}

func (s *registrySuite) TestRegisterAndListMultiple(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	c.Assert(r.Register("other", 0, nilFactory), gc.IsNil)
	c.Assert(r.Register("name", 0, nilFactory), gc.IsNil)
	c.Assert(r.Register("third", 2, nilFactory), gc.IsNil)
	c.Check(r.List(), gc.DeepEquals, []registry.Description{
		{Name: "name", Versions: []int{0}},
		{Name: "other", Versions: []int{0}},
		{Name: "third", Versions: []int{2}},
	})
}

func (s *registrySuite) TestRegisterWrongType(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	err := r.Register("other", 0, "notAFactory")
	c.Check(err, gc.ErrorMatches, `object of type string cannot be converted to type .*registry_test.Factory`)
}

func (s *registrySuite) TestRegisterNil(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	err := r.Register("name", 0, nil)
	c.Check(err, gc.ErrorMatches, `cannot register nil object "name\(0\)"`)
	c.Check(r.List(), gc.HasLen, 0)
}

func (s *registrySuite) TestRegisterAlreadyPresent(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	err := r.Register("name", 0, func() (any, error) {
		return "orig", nil
	})
	c.Assert(err, gc.IsNil)
	err = r.Register("name", 0, func() (any, error) {
		return "broken", nil
	})
	c.Check(err, gc.ErrorMatches, `object "name\(0\)" already registered`)
	f, err := r.Get("name", 0)
	c.Assert(err, gc.IsNil)
	val, err := f.(Factory)()
	c.Assert(err, gc.IsNil)
	c.Check(val, gc.Equals, "orig")
}

func (s *registrySuite) TestGet(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	customFactory := func() (any, error) {
		return 10, nil
	}
	c.Assert(r.Register("name", 0, customFactory), gc.IsNil)
	f, err := r.Get("name", 0)
	c.Assert(err, gc.IsNil)
	c.Assert(f, gc.NotNil)
	res, err := f.(Factory)()
	c.Assert(err, gc.IsNil)
	c.Check(res, gc.Equals, 10)
}

func (s *registrySuite) TestGetUnknown(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	f, err := r.Get("name", 0)
	c.Check(err, jc.Satisfies, errors.IsNotFound)
	c.Check(err, gc.ErrorMatches, `name\(0\) not found`)
	c.Check(f, gc.IsNil)
}

func (s *registrySuite) TestGetUnknownVersion(c *gc.C) {
	r := registry.NewTypedNameVersion(factoryType)
	c.Assert(r.Register("name", 0, nilFactory), gc.IsNil)
	f, err := r.Get("name", 1)
	c.Check(err, jc.Satisfies, errors.IsNotFound)
	c.Check(err, gc.ErrorMatches, `name\(1\) not found`)
	c.Check(f, gc.IsNil)
}

================================================
FILE: relativeurl.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

import (
	"strings"

	"github.com/juju/errors"
)

// RelativeURLPath returns a relative URL path that is lexically
// equivalent to targPath when interpreted by url.URL.ResolveReference.
// On success, the returned path will always be non-empty and relative
// to basePath, even if basePath and targPath share no elements.
//
// It is assumed that both basePath and targPath are normalized
// (have no . or .. elements).
//
// An error is returned if basePath or targPath are not absolute paths.
func RelativeURLPath(basePath, targPath string) (string, error) {
	if !strings.HasPrefix(basePath, "/") {
		return "", errors.New("non-absolute base URL")
	}
	if !strings.HasPrefix(targPath, "/") {
		return "", errors.New("non-absolute target URL")
	}
	baseParts := strings.Split(basePath, "/")
	targParts := strings.Split(targPath, "/")

	// For the purposes of dotdot, the last element of
	// the paths are irrelevant. We save the last part
	// of the target path for later.
	lastElem := targParts[len(targParts)-1]
	baseParts = baseParts[0 : len(baseParts)-1]
	targParts = targParts[0 : len(targParts)-1]

	// Find the common prefix between the two paths:
	var i int
	for ; i < len(baseParts); i++ {
		if i >= len(targParts) || baseParts[i] != targParts[i] {
			break
		}
	}
	dotdotCount := len(baseParts) - i
	targOnly := targParts[i:]
	result := make([]string, 0, dotdotCount+len(targOnly)+1)
	for j := 0; j < dotdotCount; j++ {
		result = append(result, "..")
	}
	result = append(result, targOnly...)
	result = append(result, lastElem)
	final := strings.Join(result, "/")
	if final == "" {
		// If the final result is empty, the last element must
		// have been empty, so the target was slash terminated
		// and there were no previous elements, so "."
		// is appropriate.
		final = "."
	}
	return final, nil
}

================================================
FILE: relativeurl_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils_test

import (
	"net/url"

	jujutesting "github.com/juju/testing"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

type relativeURLSuite struct {
	jujutesting.LoggingSuite
}

var _ = gc.Suite(&relativeURLSuite{})

var relativeURLTests = []struct {
	base        string
	target      string
	expect      string
	expectError string
}{{
	expectError: "non-absolute base URL",
}, {
	base:        "/foo",
	expectError: "non-absolute target URL",
}, {
	base:        "foo",
	expectError: "non-absolute base URL",
}, {
	base:        "/foo",
	target:      "foo",
	expectError: "non-absolute target URL",
}, {
	base:   "/foo",
	target: "/bar",
	expect: "bar",
}, {
	base:   "/foo/",
	target: "/bar",
	expect: "../bar",
}, {
	base:   "/bar",
	target: "/foo/",
	expect: "foo/",
}, {
	base:   "/foo/",
	target: "/bar/",
	expect: "../bar/",
}, {
	base:   "/foo/bar",
	target: "/bar/",
	expect: "../bar/",
}, {
	base:   "/foo/bar/",
	target: "/bar/",
	expect: "../../bar/",
}, {
	base:   "/foo/bar/baz",
	target: "/foo/targ",
	expect: "../targ",
}, {
	base:   "/foo/bar/baz/frob",
	target: "/foo/bar/one/two/",
	expect: "../one/two/",
}, {
	base:   "/foo/bar/baz/",
	target: "/foo/targ",
	expect: "../../targ",
}, {
	base:   "/foo/bar/baz/frob/",
	target: "/foo/bar/one/two/",
	expect: "../../one/two/",
}, {
	base:   "/foo/bar",
	target: "/foot/bar",
	expect: "../foot/bar",
}, {
	base:   "/foo/bar/baz/frob",
	target: "/foo/bar",
	expect: "../../bar",
}, {
	base:   "/foo/bar/baz/frob/",
	target: "/foo/bar",
	expect: "../../../bar",
}, {
	base:   "/foo/bar/baz/frob/",
	target: "/foo/bar/",
	expect: "../../",
}, {
	base:   "/foo/bar/baz",
	target: "/foo/bar/other",
	expect: "other",
}, {
	base:   "/foo/bar/",
	target: "/foo/bar/",
	expect: ".",
}, {
	base:   "/foo/bar",
	target: "/foo/bar",
	expect: "bar",
}, {
	base:   "/foo/bar/",
	target: "/foo/bar/",
	expect: ".",
}, {
	base:   "/foo/bar",
	target: "/foo/",
	expect: ".",
}, {
	base:   "/foo",
	target: "/",
	expect: ".",
}, {
	base:   "/foo/",
	target: "/",
	expect: "../",
}, {
	base:   "/foo/bar",
	target: "/",
	expect: "../",
}, {
	base:   "/foo/bar/",
	target: "/",
	expect: "../../",
}}

func (*relativeURLSuite) TestRelativeURL(c *gc.C) {
	for i, test := range relativeURLTests {
		c.Logf("test %d: %q %q", i, test.base, test.target)
		// Sanity check the test itself.
		if test.expectError == "" {
			baseURL := &url.URL{Path: test.base}
			expectURL := &url.URL{Path: test.expect}
			targetURL := baseURL.ResolveReference(expectURL)
			c.Check(targetURL.Path, gc.Equals, test.target, gc.Commentf("resolve reference failure (%q + %q != %q)", test.base, test.expect, test.target))
		}
		result, err := utils.RelativeURLPath(test.base, test.target)
		if test.expectError != "" {
			c.Assert(err, gc.ErrorMatches, test.expectError)
			c.Assert(result, gc.Equals, "")
		} else {
			c.Assert(err, gc.IsNil)
			c.Check(result, gc.Equals, test.expect)
		}
	}
}

================================================
FILE: setenv.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

import (
	"strings"
)

// Setenv sets an environment variable entry in the given env slice (as
// returned by os.Environ or passed in exec.Cmd.Env) to the given
// value. The entry should be in the form "x=y" where x is the name of the
// environment variable and y is its value; if not, env will be
// returned unchanged.
//
// If an entry for the same name is already present it is replaced in place
// (the passed slice is modified); otherwise the entry is appended.
//
// The new environ slice is returned.
func Setenv(env []string, entry string) []string {
	i := strings.Index(entry, "=")
	if i == -1 {
		return env
	}
	prefix := entry[0 : i+1]
	for i, e := range env {
		if strings.HasPrefix(e, prefix) {
			env[i] = entry
			return env
		}
	}
	return append(env, entry)
}

================================================
FILE: setenv_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package utils_test

import (
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

type SetenvSuite struct{}

var _ = gc.Suite(&SetenvSuite{})

var setenvTests = []struct {
	set    string
	expect []string
}{
	{"foo=1", []string{"foo=1", "arble="}},
	{"foo=", []string{"foo=", "arble="}},
	{"arble=23", []string{"foo=bar", "arble=23"}},
	{"zaphod=42", []string{"foo=bar", "arble=", "zaphod=42"}},
	{"bar", []string{"foo=bar", "arble="}},
}

func (*SetenvSuite) TestSetenv(c *gc.C) {
	env0 := []string{"foo=bar", "arble="}
	for i, t := range setenvTests {
		c.Logf("test %d", i)
		env := make([]string, len(env0))
		copy(env, env0)
		env = utils.Setenv(env, t.set)
		c.Check(env, gc.DeepEquals, t.expect)
	}
}

================================================
FILE: shell/bash.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package shell

import (
	"strings"
)

// BashRenderer is the shell renderer for bash.
type BashRenderer struct {
	unixRenderer
}

// RenderScript implements ScriptRenderer. It prefixes the commands with a
// bash shebang line and a blank line, then joins everything with newlines.
func (*BashRenderer) RenderScript(commands []string) []byte {
	commands = append([]string{"#!/usr/bin/env bash", ""}, commands...)
	return []byte(strings.Join(commands, "\n"))
}

================================================
FILE: shell/bash_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test

import (
	"os"
	"time"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/shell"
)

type bashSuite struct {
	testing.IsolationSuite

	dirname  string
	filename string
	renderer *shell.BashRenderer
}

var _ = gc.Suite(&bashSuite{})

func (s *bashSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)

	s.dirname = `/some/dir`
	s.filename = s.dirname + `/file`
	s.renderer = &shell.BashRenderer{}
}

func (s bashSuite) TestExeSuffix(c *gc.C) {
	suffix := s.renderer.ExeSuffix()

	c.Check(suffix, gc.Equals, "")
}

func (s bashSuite) TestShQuote(c *gc.C) {
	quoted := s.renderer.Quote("abc")

	c.Check(quoted, gc.Equals, `'abc'`)
}

func (s bashSuite) TestChmod(c *gc.C) {
	commands := s.renderer.Chmod(s.filename, 0644)

	c.Check(commands, jc.DeepEquals, []string{
		"chmod 0644 '/some/dir/file'",
	})
}

func (s bashSuite) TestWriteFile(c *gc.C) {
	data := []byte("something\nhere\n")
	commands := s.renderer.WriteFile(s.filename, data)

	expected := `cat > '/some/dir/file' << 'EOF' something here EOF`
	c.Check(commands, jc.DeepEquals, []string{
		expected,
	})
}

func (s bashSuite) TestMkdir(c *gc.C) {
	commands := s.renderer.Mkdir(s.dirname)

	c.Check(commands, jc.DeepEquals, []string{
		`mkdir '/some/dir'`,
	})
}

func (s bashSuite) TestMkdirAll(c *gc.C) {
	commands := s.renderer.MkdirAll(s.dirname)

	c.Check(commands, jc.DeepEquals, []string{
		`mkdir -p '/some/dir'`,
	})
}

func (s bashSuite) TestChown(c *gc.C) {
	commands := s.renderer.Chown("/a/b/c", "x", "y")

	c.Check(commands, jc.DeepEquals, []string{
		"chown x:y '/a/b/c'",
	})
}

func (s bashSuite) TestTouchDefault(c *gc.C) {
	commands := s.renderer.Touch("/a/b/c", nil)

	c.Check(commands, jc.DeepEquals, []string{
		"touch '/a/b/c'",
	})
}

func (s bashSuite) TestTouchTimestamp(c *gc.C) {
	now := time.Date(2015, time.Month(3), 14, 12, 26, 38, 0, time.UTC)
	commands := s.renderer.Touch("/a/b/c", &now)

	c.Check(commands, jc.DeepEquals, []string{
		"touch -t 201503141226.38 '/a/b/c'",
	})
}

func (s bashSuite) TestRedirectFD(c *gc.C) {
	commands := s.renderer.RedirectFD("stdout", "stderr")

	c.Check(commands, jc.DeepEquals, []string{
		"exec 2>&1",
	})
}

func (s bashSuite) TestRedirectOutput(c *gc.C) {
	commands := s.renderer.RedirectOutput("/a/b/c")

	c.Check(commands, jc.DeepEquals, []string{
		"exec >> '/a/b/c'",
	})
}

func (s bashSuite) TestRedirectOutputReset(c *gc.C) {
	commands := s.renderer.RedirectOutputReset("/a/b/c")

	c.Check(commands, jc.DeepEquals, []string{
		"exec > '/a/b/c'",
	})
}

func (s bashSuite) TestScriptFilename(c *gc.C) {
	filename := s.renderer.ScriptFilename("spam", "/ham/eggs")

	c.Check(filename, gc.Equals, "/ham/eggs/spam.sh")
}

func (s bashSuite) TestScriptPermissions(c *gc.C) {
	perm := s.renderer.ScriptPermissions()

	c.Check(perm, gc.Equals, os.FileMode(0755))
}

================================================
FILE: shell/command.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package shell

import (
	"os"
	"time"
)

// CommandRenderer provides methods that may be used to generate shell
// commands for a variety of shell and filesystem operations.
type CommandRenderer interface {
	// Chown returns a shell command for changing the ownership of
	// a file or directory. This copies the behavior of os.Chown,
	// though it also supports names in addition to ints.
	Chown(name, user, group string) []string

	// Chmod returns a shell command that sets the given file's
	// permissions. The result is equivalent to os.Chmod.
	Chmod(path string, perm os.FileMode) []string

	// WriteFile returns a shell command that writes the provided
	// content to a file. The command is functionally equivalent to
	// ioutil.WriteFile with permissions from the current umask.
	WriteFile(filename string, data []byte) []string

	// Mkdir returns a shell command for creating a directory. The
	// command is functionally equivalent to os.Mkdir using permissions
	// appropriate for a directory.
	Mkdir(dirname string) []string

	// MkdirAll returns a shell command for creating a directory and
	// all missing parent directories. The command is functionally
	// equivalent to os.MkdirAll using permissions appropriate for
	// a directory.
	MkdirAll(dirname string) []string

	// Touch returns a shell command that updates the access and
	// modification times (atime and mtime) of the named file. If the
	// provided timestamp is nil then the current time is used. If the
	// file does not exist then it is created. If UTC is desired then
	// Time.UTC() should be called before calling Touch.
	Touch(filename string, timestamp *time.Time) []string
}

================================================
FILE: shell/interface_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package shell

var _ Renderer = (*BashRenderer)(nil)
var _ Renderer = (*PowershellRenderer)(nil)
var _ Renderer = (*WinCmdRenderer)(nil)

================================================
FILE: shell/output.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package shell

import (
	"strconv"
	"strings"
)

// OutputRenderer exposes the Renderer methods that relate to shell output.
//
// The methods all accept strings to identify their file descriptor
// arguments. While the interpretation of these values is up to the
// renderer, it will likely conform to the result of calling ResolveFD.
// If an FD arg is not recognized then no commands will be returned.
// Unless otherwise specified, the default file descriptor is stdout
// (FD 1). This applies to the empty string.
type OutputRenderer interface {
	// RedirectFD returns a shell command that redirects the src
	// file descriptor to the dst one.
	RedirectFD(dst, src string) []string

	// TODO(ericsnow) Add CopyFD and CreateFD?

	// TODO(ericsnow) Support passing the src FD as an arg?

	// RedirectOutput will cause all subsequent output from the shell
	// (or script) to be appended to the given file. Only stdout is
	// redirected (use RedirectFD to redirect stderr or other FDs).
	//
	// The file should already exist (so a call to Touch may be
	// necessary before calling RedirectOutput). If the file should have
	// specific permissions or a specific owner then Chmod and Chown
	// should be called before calling RedirectOutput.
	RedirectOutput(filename string) []string

	// RedirectOutputReset will cause all subsequent output from the
	// shell (or script) to be written to the given file. The file
	// will be reset (truncated to 0) before anything is written. Only
	// stdout is redirected (use RedirectFD to redirect stderr or
	// other FDs).
	//
	// The file should already exist (so a call to Touch may be
	// necessary before calling RedirectOutputReset). If the file should
	// have specific permissions or a specific owner then Chmod and
	// Chown should be called before calling RedirectOutputReset.
	RedirectOutputReset(filename string) []string
}

// ResolveFD converts the file descriptor name to the corresponding int.
// "stdout" and "out" match stdout (FD 1). "stderr" and "err" match
// stderr (FD 2), "stdin" and "in" match stdin (FD 0); these names are
// matched case-insensitively. Any unsigned decimal integer that fits in
// an int also matches and is returned as-is. If there should be an upper
// bound then the caller should check it on the result. If the provided
// name is empty then the result defaults to stdout. If the name is not
// recognized then (-1, false) is returned.
func ResolveFD(name string) (int, bool) {
	switch strings.ToLower(name) {
	case "stdout", "out", "":
		return 1, true
	case "stderr", "err":
		return 2, true
	case "stdin", "in":
		return 0, true
	default:
		// Parse into at most IntSize-1 bits so the result always fits in a
		// non-negative int; a full 64-bit parse would let values >= 2^63
		// wrap around to negative descriptors in the int conversion.
		fd, err := strconv.ParseUint(name, 10, strconv.IntSize-1)
		if err != nil {
			return -1, false
		}
		return int(fd), true
	}
}

================================================
FILE: shell/package_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details. package shell_test import ( "testing" gc "gopkg.in/check.v1" ) func Test(t *testing.T) { gc.TestingT(t) } ================================================ FILE: shell/powershell.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package shell import ( "encoding/base64" "fmt" "os" "golang.org/x/text/encoding/unicode" "github.com/juju/errors" "github.com/juju/utils/v4" ) // PowershellRenderer is a shell renderer for Windows Powershell. type PowershellRenderer struct { windowsRenderer } // Quote implements Renderer. func (pr *PowershellRenderer) Quote(str string) string { return utils.WinPSQuote(str) } // Chmod implements Renderer. func (pr *PowershellRenderer) Chmod(path string, perm os.FileMode) []string { // TODO(ericsnow) Is this necessary? Should we use Set-Acl? return nil } // WriteFile implements Renderer. func (pr *PowershellRenderer) WriteFile(filename string, data []byte) []string { filename = pr.Quote(filename) return []string{ fmt.Sprintf("Set-Content %s @\"\n%s\n\"@", filename, data), } } // MkDir implements Renderer. func (pr *PowershellRenderer) Mkdir(dirname string) []string { dirname = pr.FromSlash(dirname) return []string{ fmt.Sprintf(`mkdir %s`, pr.Quote(dirname)), } } // MkdirAll implements Renderer. func (pr *PowershellRenderer) MkdirAll(dirname string) []string { return pr.Mkdir(dirname) } // ScriptFilename implements ScriptWriter. func (pr *PowershellRenderer) ScriptFilename(name, dirname string) string { return pr.Join(dirname, name+".ps1") } // By default, winrm executes command usind cmd. Prefix the command we send over WinRM with powershell.exe. // the powershell.exe it's a program that will execute the "%s" encoded command. 
// A breakdown of the parameters: // // -NonInteractive - prevent any prompts from stopping the execution of the scrips // -ExecutionPolicy - sets the execution policy for the current command, regardless of the default ExecutionPolicy on the system. // -EncodedCommand - allows us to run a base64 encoded script. This spares us from having to quote/escape shell special characters. const psRemoteWrapper = "powershell.exe -Sta -NonInteractive -ExecutionPolicy RemoteSigned -EncodedCommand %s" // newEncodedPSScript returns a UTF16-LE, base64 encoded script. // The -EncodedCommand parameter expects this encoding for any base64 script we send over. func newEncodedPSScript(script string) (string, error) { uni := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) encoded, err := uni.NewEncoder().String(script) if err != nil { return "", err } return base64.StdEncoding.EncodeToString([]byte(encoded)), nil } // NewPSEncodedCommand converts the given string to a UTF16-LE, base64 encoded string, // suitable for execution using powershell.exe -EncodedCommand. This can be used on // local systems, as well as remote systems via WinRM. func NewPSEncodedCommand(script string) (string, error) { var err error script, err = newEncodedPSScript(script) if err != nil { return "", errors.Annotatef(err, "Cannot construct powershell command for remote execution") } return fmt.Sprintf(psRemoteWrapper, script), nil } ================================================ FILE: shell/powershell_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package shell_test import ( "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/shell" ) var _ = gc.Suite(&powershellSuite{}) type powershellSuite struct { testing.IsolationSuite dirname string filename string renderer *shell.PowershellRenderer } func (s *powershellSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.dirname = `C:\some\dir` s.filename = s.dirname + `\file` s.renderer = &shell.PowershellRenderer{} } func (s powershellSuite) TestExeSuffix(c *gc.C) { suffix := s.renderer.ExeSuffix() c.Check(suffix, gc.Equals, ".exe") } func (s powershellSuite) TestShQuote(c *gc.C) { quoted := s.renderer.Quote("abc") c.Check(quoted, gc.Equals, `'abc'`) } func (s powershellSuite) TestChmod(c *gc.C) { commands := s.renderer.Chmod(s.filename, 0644) c.Check(commands, gc.HasLen, 0) } func (s powershellSuite) TestWriteFile(c *gc.C) { data := []byte("something\nhere\n") commands := s.renderer.WriteFile(s.filename, data) expected := ` Set-Content 'C:\some\dir\file' @" something here "@`[1:] c.Check(commands, jc.DeepEquals, []string{ expected, }) } func (s powershellSuite) TestMkdir(c *gc.C) { commands := s.renderer.Mkdir(s.dirname) c.Check(commands, jc.DeepEquals, []string{ `mkdir 'C:\some\dir'`, }) } func (s powershellSuite) TestMkdirAll(c *gc.C) { commands := s.renderer.MkdirAll(s.dirname) c.Check(commands, jc.DeepEquals, []string{ `mkdir 'C:\some\dir'`, }) } func (s powershellSuite) TestNewPSEncodedCommand(c *gc.C) { script := ` Get-WmiObject win32_processor ` expected := "powershell.exe -Sta -NonInteractive -ExecutionPolicy RemoteSigned -EncodedCommand CgAJAEcAZQB0AC0AVwBtAGkATwBiAGoAZQBjAHQAIAB3AGkAbgAzADIAXwBwAHIAbwBjAGUAcwBzAG8AcgAKAA==" out, err := shell.NewPSEncodedCommand(script) c.Assert(err, gc.IsNil) c.Assert((len(out) > 0), gc.Equals, true) c.Assert(out, jc.DeepEquals, expected) } ================================================ FILE: shell/renderer.go 
================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package shell import ( "runtime" "strings" "github.com/juju/errors" "github.com/juju/utils/v4" "github.com/juju/utils/v4/filepath" ) // A PathRenderer generates paths that are appropriate for a given // shell environment. type PathRenderer interface { filepath.Renderer // Quote generates a new string with quotation marks and relevant // escape/control characters properly escaped. The resulting string // is wrapped in quotation marks such that it will be treated as a // single string by the shell. Quote(str string) string // ExeSuffix returns the filename suffix for executable files. ExeSuffix() string } // Renderer provides all the functionality needed to generate shell- // compatible paths and commands. type Renderer interface { PathRenderer CommandRenderer OutputRenderer } // NewRenderer returns a Renderer for the given shell, OS, or distro name. func NewRenderer(name string) (Renderer, error) { if name == "" { name = runtime.GOOS } else { name = strings.ToLower(name) } // Try known shell names first. switch name { case "bash": return &BashRenderer{}, nil case "ps", "powershell": return &PowershellRenderer{}, nil case "cmd", "batch", "bat": return &WinCmdRenderer{}, nil } // Fall back to operating systems. switch { case name == "windows": return &PowershellRenderer{}, nil case utils.OSIsUnix(name): return &BashRenderer{}, nil } // Finally try distros. switch name { case "ubuntu": return &BashRenderer{}, nil } return nil, errors.NotFoundf("renderer for %q", name) } ================================================ FILE: shell/renderer_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package shell_test

import (
	"runtime"

	"github.com/juju/errors"
	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
	"github.com/juju/utils/v4/shell"
)

// rendererSuite checks which concrete renderer shell.NewRenderer picks
// for various names. Methods use pointer receivers consistently, matching
// SetUpTest and the pointer registered with gc.Suite.
type rendererSuite struct {
	testing.IsolationSuite

	unix    *shell.BashRenderer
	windows *shell.PowershellRenderer
}

var _ = gc.Suite(&rendererSuite{})

func (s *rendererSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)

	s.unix = &shell.BashRenderer{}
	s.windows = &shell.PowershellRenderer{}
}

// checkRenderer asserts that renderer has the concrete type associated
// with expected ("powershell" or "bash"); any other kind fails the test.
func (s *rendererSuite) checkRenderer(c *gc.C, renderer shell.Renderer, expected string) {
	switch expected {
	case "powershell":
		c.Check(renderer, gc.FitsTypeOf, s.windows)
	case "bash":
		c.Check(renderer, gc.FitsTypeOf, s.unix)
	default:
		c.Errorf("unknown kind %q", expected)
	}
}

func (s *rendererSuite) TestNewRendererDefault(c *gc.C) {
	// All possible values of runtime.GOOS should be supported.
	renderer, err := shell.NewRenderer("")
	c.Assert(err, jc.ErrorIsNil)

	switch runtime.GOOS {
	case "windows":
		s.checkRenderer(c, renderer, "powershell")
	default:
		s.checkRenderer(c, renderer, "bash")
	}
}

func (s *rendererSuite) TestNewRendererGOOS(c *gc.C) {
	// All possible values of runtime.GOOS should be supported.
	renderer, err := shell.NewRenderer(runtime.GOOS)
	c.Assert(err, jc.ErrorIsNil)

	switch runtime.GOOS {
	case "windows":
		s.checkRenderer(c, renderer, "powershell")
	default:
		s.checkRenderer(c, renderer, "bash")
	}
}

func (s *rendererSuite) TestNewRendererWindows(c *gc.C) {
	renderer, err := shell.NewRenderer("windows")
	c.Assert(err, jc.ErrorIsNil)

	s.checkRenderer(c, renderer, "powershell")
}

func (s *rendererSuite) TestNewRendererUnix(c *gc.C) {
	for _, name := range utils.OSUnix {
		c.Logf("trying %q", name)
		renderer, err := shell.NewRenderer(name)
		c.Assert(err, jc.ErrorIsNil)

		s.checkRenderer(c, renderer, "bash")
	}
}

func (s *rendererSuite) TestNewRendererDistros(c *gc.C) {
	distros := []string{"ubuntu"}
	for _, distro := range distros {
		c.Logf("trying %q", distro)
		renderer, err := shell.NewRenderer(distro)
		c.Assert(err, jc.ErrorIsNil)

		s.checkRenderer(c, renderer, "bash")
	}
}

func (s *rendererSuite) TestNewRendererUnknown(c *gc.C) {
	// An empty name falls back to runtime.GOOS, which resolves
	// successfully (see TestNewRendererDefault), so a name no renderer
	// recognises is needed to provoke the NotFound error.
	_, err := shell.NewRenderer("<unknown OS>")
	c.Check(err, jc.Satisfies, errors.IsNotFound)
}
================================================
FILE: shell/script.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package shell

import (
	"fmt"
	"os"

	"github.com/juju/utils/v4"
)

// DumpFileOnErrorScript returns a bash script that installs an EXIT trap
// which, when the shell exits with a non-zero status and the specified
// file exists, copies the file's contents to stderr. The shell's exit
// status is preserved.
func DumpFileOnErrorScript(filename string) string {
	script := ` dump_file() { code=$? if [ $code -ne 0 -a -e %s ]; then cat %s >&2 fi exit $code } trap dump_file EXIT `[1:]
	// The name is interpolated into shell code, so it must be quoted.
	filename = utils.ShQuote(filename)
	return fmt.Sprintf(script, filename, filename)
}

// A ScriptRenderer provides the functionality necessary to render a
// sequence of shell commands into the content of a shell script.
type ScriptRenderer interface {
	// RenderScript generates the content of a shell script for the
	// provided shell commands.
	RenderScript(commands []string) []byte
}

// A ScriptWriter provides the functionality necessary to render and
// write a sequence of shell commands to a shell script that is ready
// to be run.
type ScriptWriter interface {
	ScriptRenderer

	// Chmod returns a shell command that sets the given file's
	// permissions. The result is equivalent to os.Chmod.
	Chmod(path string, perm os.FileMode) []string

	// WriteFile returns a shell command that writes the provided
	// content to a file. The command is functionally equivalent to
	// os.WriteFile with permissions from the current umask.
	WriteFile(filename string, data []byte) []string

	// ScriptFilename generates a filename appropriate for a script
	// from the provided file and directory names.
	ScriptFilename(name, dirname string) string

	// ScriptPermissions returns the permissions appropriate for a script.
	ScriptPermissions() os.FileMode
}

// WriteScript returns a sequence of shell commands that write the
// provided shell commands to a file. The filename is composed from the
// given directory name and name, and the appropriate suffix for a shell
// script is applied. The script content is prefixed with any necessary
// content related to shell scripts (e.g. a shebang line). The file's
// permissions are set to those appropriate for a script (e.g. 0755).
func WriteScript(renderer ScriptWriter, name, dirname string, script []string) []string {
	filename := renderer.ScriptFilename(name, dirname)
	perm := renderer.ScriptPermissions()

	var commands []string

	data := renderer.RenderScript(script)
	commands = append(commands, renderer.WriteFile(filename, data)...)

	commands = append(commands, renderer.Chmod(filename, perm)...)

	return commands
}
================================================
FILE: shell/script_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test

import (
	"bytes"
	"io/ioutil"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/shell"
)

type scriptSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&scriptSuite{})

// TestDumpFileOnErrorScriptOutput pins the exact script text, including
// the shell quoting of a file name that contains spaces.
func (*scriptSuite) TestDumpFileOnErrorScriptOutput(c *gc.C) {
	want := ` dump_file() { code=$? if [ $code -ne 0 -a -e 'a b c' ]; then cat 'a b c' >&2 fi exit $code } trap dump_file EXIT `[1:]

	got := shell.DumpFileOnErrorScript("a b c")
	c.Assert(got, gc.Equals, want)
}

// TestDumpFileOnErrorScript runs the generated script under bash and
// checks when the file's contents reach stderr.
func (*scriptSuite) TestDumpFileOnErrorScript(c *gc.C) {
	logFile := filepath.Join(c.MkDir(), "log.txt")
	c.Assert(ioutil.WriteFile(logFile, []byte("abc"), 0644), gc.IsNil)

	prelude := shell.DumpFileOnErrorScript(logFile)
	c.Logf("%s", prelude)

	// runBash feeds prelude+tail to bash on stdin and returns what it
	// wrote to stdout and stderr. The exit status is ignored on purpose:
	// several cases exit non-zero deliberately.
	runBash := func(tail string) (string, string) {
		var outBuf, errBuf bytes.Buffer
		cmd := exec.Command("/bin/bash", "-s")
		cmd.Stdin = strings.NewReader(prelude + tail)
		cmd.Stdout = &outBuf
		cmd.Stderr = &errBuf
		_ = cmd.Run()
		return outBuf.String(), errBuf.String()
	}

	// A clean exit produces no output at all.
	stdout, stderr := runBash("exit 0")
	c.Assert(stdout, gc.Equals, "")
	c.Assert(stderr, gc.Equals, "")

	// A failing exit copies the file to stderr.
	stdout, stderr = runBash("exit 1")
	c.Assert(stdout, gc.Equals, "")
	c.Assert(stderr, gc.Equals, "abc")

	// Once the file is gone, a failing exit produces nothing.
	c.Assert(os.Remove(logFile), gc.IsNil)
	stdout, stderr = runBash("exit 1")
	c.Assert(stdout, gc.Equals, "")
	c.Assert(stderr, gc.Equals, "")
}

func (*scriptSuite) TestWriteScriptUnix(c *gc.C) {
	script := ` exec a-command exec another-command `
	lines := strings.Split(script, "\n")

	commands := shell.WriteScript(&shell.BashRenderer{}, "spam", "/ham/eggs", lines)

	writeCmd := ` cat > '/ham/eggs/spam.sh' << 'EOF' #!/usr/bin/env bash exec a-command exec another-command EOF`[1:]
	c.Check(commands, jc.DeepEquals, []string{
		writeCmd,
		"chmod 0755 '/ham/eggs/spam.sh'",
	})
}

func (*scriptSuite) TestWriteScriptWindows(c *gc.C) {
	script := ` exec
a-command exec another-command `
	lines := strings.Split(script, "\n")

	commands := shell.WriteScript(&shell.PowershellRenderer{}, "spam", `C:\ham\eggs`, lines)

	c.Check(commands, jc.DeepEquals, []string{
		`Set-Content 'C:\ham\eggs\spam.ps1' @" exec a-command exec another-command "@`,
	})
}
================================================
FILE: shell/unix.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package shell

import (
	"fmt"
	"os"
	"time"

	"github.com/juju/utils/v4"
	"github.com/juju/utils/v4/filepath"
)

// unixRenderer is the shared base for renderers targeting unix-like
// shells; path handling comes from the embedded filepath.UnixRenderer.
type unixRenderer struct {
	filepath.UnixRenderer
}

// Quote implements Renderer using sh-style quoting (utils.ShQuote).
func (unixRenderer) Quote(str string) string {
	// NOTE: sh-style quoting is assumed to suit the target shell; it
	// might not be right for every unix shell.
	return utils.ShQuote(str)
}

// ExeSuffix implements Renderer. Unix executables carry no suffix.
func (unixRenderer) ExeSuffix() string {
	return ""
}

// Mkdir implements Renderer.
func (ur unixRenderer) Mkdir(dirname string) []string {
	return []string{"mkdir " + ur.Quote(dirname)}
}

// MkdirAll implements Renderer. It relies on "mkdir -p".
func (ur unixRenderer) MkdirAll(dirname string) []string {
	return []string{"mkdir -p " + ur.Quote(dirname)}
}

// Chmod implements Renderer, rendering perm as a four-digit octal mode.
func (ur unixRenderer) Chmod(path string, perm os.FileMode) []string {
	quoted := ur.Quote(path)
	return []string{fmt.Sprintf("chmod %04o %s", perm, quoted)}
}

// Chown implements Renderer.
func (ur unixRenderer) Chown(path, owner, group string) []string {
	return []string{"chown " + owner + ":" + group + " " + ur.Quote(path)}
}

// Touch implements Renderer. A nil timestamp leaves the time choice to
// touch itself.
func (ur unixRenderer) Touch(path string, timestamp *time.Time) []string {
	quoted := ur.Quote(path)
	opt := ""
	if timestamp != nil {
		// The layout renders CCYYMMDDhhmm.ss, the form "touch -t" takes.
		opt = timestamp.Format("-t 200601021504.05 ")
	}
	return []string{"touch " + opt + quoted}
}

// WriteFile implements Renderer.
func (ur unixRenderer) WriteFile(filename string, data []byte) []string { filename = ur.Quote(filename) return []string{ // An alternate approach would be to use printf. fmt.Sprintf("cat > %s << 'EOF'\n%s\nEOF", filename, data), } } func (unixRenderer) outFD(name string) (int, bool) { fd, ok := ResolveFD(name) if !ok || fd <= 0 { return -1, false } return fd, true } // RedirectFD implements OutputRenderer. func (ur unixRenderer) RedirectFD(dst, src string) []string { dstFD, ok := ur.outFD(dst) if !ok { return nil } srcFD, ok := ur.outFD(src) if !ok { return nil } return []string{ fmt.Sprintf("exec %d>&%d", srcFD, dstFD), } } // RedirectOutput implements OutputRenderer. func (ur unixRenderer) RedirectOutput(filename string) []string { filename = ur.Quote(filename) return []string{ "exec >> " + filename, } } // RedirectOutputReset implements OutputRenderer. func (ur unixRenderer) RedirectOutputReset(filename string) []string { filename = ur.Quote(filename) return []string{ "exec > " + filename, } } // ScriptFilename implements ScriptWriter. func (ur *unixRenderer) ScriptFilename(name, dirname string) string { return ur.Join(dirname, name+".sh") } // ScriptPermissions implements ScriptWriter. func (ur *unixRenderer) ScriptPermissions() os.FileMode { return 0755 } ================================================ FILE: shell/win.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package shell import ( "os" "strings" "time" "github.com/juju/utils/v4/filepath" ) // windowsRenderer is the base implementation for Windows shells. type windowsRenderer struct { filepath.WindowsRenderer } // ExeSuffix implements Renderer. func (w *windowsRenderer) ExeSuffix() string { return ".exe" } // ScriptPermissions implements ScriptWriter. func (w *windowsRenderer) ScriptPermissions() os.FileMode { return 0755 } // Render implements ScriptWriter. 
func (w *windowsRenderer) RenderScript(commands []string) []byte { return []byte(strings.Join(commands, "\n")) } // Chown implements Renderer. func (w windowsRenderer) Chown(path, owner, group string) []string { // TODO(ericsnow) Use ??? panic("not supported") } // Touch implements Renderer. func (w windowsRenderer) Touch(path string, timestamp *time.Time) []string { // TODO(ericsnow) Use ??? panic("not supported") } // RedirectFD implements OutputRenderer. func (w windowsRenderer) RedirectFD(dst, src string) []string { // TODO(ericsnow) Use ??? panic("not supported") } // RedirectOutput implements OutputRenderer. func (w windowsRenderer) RedirectOutput(filename string) []string { // TODO(ericsnow) Use ??? panic("not supported") } // RedirectOutputReset implements OutputRenderer. func (w windowsRenderer) RedirectOutputReset(filename string) []string { // TODO(ericsnow) Use ??? panic("not supported") } ================================================ FILE: shell/wincmd.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package shell import ( "bytes" "fmt" "os" "github.com/juju/utils/v4" ) // WinCmdRenderer is a shell renderer for Windows cmd.exe. type WinCmdRenderer struct { windowsRenderer } // Quote implements Renderer. func (wcr *WinCmdRenderer) Quote(str string) string { return utils.WinCmdQuote(str) } // Chmod implements Renderer. func (wcr *WinCmdRenderer) Chmod(path string, perm os.FileMode) []string { // TODO(ericsnow) Is this necessary? Should we use icacls? return nil } // WriteFile implements Renderer. func (wcr *WinCmdRenderer) WriteFile(filename string, data []byte) []string { filename = wcr.Quote(filename) var commands []string for _, line := range bytes.Split(data, []byte{'\n'}) { cmd := fmt.Sprintf(">>%s @echo %s", filename, line) commands = append(commands, cmd) } return commands } // MkDir implements Renderer. 
func (wcr *WinCmdRenderer) Mkdir(dirname string) []string { dirname = wcr.Quote(dirname) return []string{ fmt.Sprintf(`mkdir %s`, wcr.FromSlash(dirname)), } } // MkDirAll implements Renderer. func (wcr *WinCmdRenderer) MkdirAll(dirname string) []string { dirname = wcr.Quote(dirname) // TODO(ericsnow) Wrap in "setlocal enableextensions...endlocal"? return []string{ fmt.Sprintf(`mkdir %s`, wcr.FromSlash(dirname)), } } // ScriptFilename implements ScriptWriter. func (wcr *WinCmdRenderer) ScriptFilename(name, dirname string) string { return wcr.Join(dirname, name+".bat") } ================================================ FILE: shell/wincmd_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package shell_test import ( "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/shell" ) var _ = gc.Suite(&winCmdSuite{}) type winCmdSuite struct { testing.IsolationSuite dirname string filename string renderer *shell.WinCmdRenderer } func (s *winCmdSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.dirname = `C:\some\dir` s.filename = s.dirname + `\file` s.renderer = &shell.WinCmdRenderer{} } func (s winCmdSuite) TestExeSuffix(c *gc.C) { suffix := s.renderer.ExeSuffix() c.Check(suffix, gc.Equals, ".exe") } func (s winCmdSuite) TestShQuote(c *gc.C) { quoted := s.renderer.Quote("abc") c.Check(quoted, gc.Equals, `^"abc^"`) } func (s winCmdSuite) TestChmod(c *gc.C) { commands := s.renderer.Chmod(s.filename, 0644) c.Check(commands, gc.HasLen, 0) } func (s winCmdSuite) TestWriteFile(c *gc.C) { data := []byte("something\nhere\n") commands := s.renderer.WriteFile(s.filename, data) c.Check(commands, jc.DeepEquals, []string{ `>>^"C:\\some\\dir\\file^" @echo something`, `>>^"C:\\some\\dir\\file^" @echo here`, `>>^"C:\\some\\dir\\file^" @echo `, }) } func (s winCmdSuite) TestMkdir(c *gc.C) { commands := 
s.renderer.Mkdir(s.dirname) c.Check(commands, jc.DeepEquals, []string{ `mkdir ^"C:\\some\\dir^"`, }) } func (s winCmdSuite) TestMkdirAll(c *gc.C) { commands := s.renderer.MkdirAll(s.dirname) c.Check(commands, jc.DeepEquals, []string{ `mkdir ^"C:\\some\\dir^"`, }) } ================================================ FILE: size.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "math" "strconv" "strings" "unicode" "github.com/juju/errors" ) // ParseSize parses the string as a size, in mebibytes. // // The string must be a is a non-negative number with // an optional multiplier suffix (M, G, T, P, E, Z, or Y). // If the suffix is not specified, "M" is implied. func ParseSize(str string) (MB uint64, err error) { // Find the first non-digit/period: i := strings.IndexFunc(str, func(r rune) bool { return r != '.' && !unicode.IsDigit(r) }) var multiplier float64 = 1 if i > 0 { suffix := str[i:] multiplier = 0 for j := 0; j < len(sizeSuffixes); j++ { base := string(sizeSuffixes[j]) // M, MB, or MiB are all valid. switch suffix { case base, base + "B", base + "iB": multiplier = float64(sizeSuffixMultiplier(j)) break } } if multiplier == 0 { return 0, errors.Errorf("invalid multiplier suffix %q, expected one of %s", suffix, []byte(sizeSuffixes)) } str = str[:i] } val, err := strconv.ParseFloat(str, 64) if err != nil || val < 0 { return 0, errors.Errorf("expected a non-negative number, got %q", str) } val *= multiplier return uint64(math.Ceil(val)), nil } var sizeSuffixes = "MGTPEZY" func sizeSuffixMultiplier(i int) int { return 1 << uint(i*10) } // SizeTracker tracks the number of bytes passing through // its Write method (which is otherwise a no-op). // // Use SizeTracker with io.MultiWriter() to track number of bytes // written. Use with io.TeeReader() to track number of bytes read. type SizeTracker struct { // size is the number of bytes written so far. 
size int64 } // Size returns the number of bytes written so far. func (st SizeTracker) Size() int64 { return st.size } // Write implements io.Writer. func (st *SizeTracker) Write(data []byte) (n int, err error) { n = len(data) st.size += int64(n) return n, nil } ================================================ FILE: size_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "io" "io/ioutil" "github.com/juju/testing" jc "github.com/juju/testing/checkers" "github.com/juju/testing/filetesting" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) var _ = gc.Suite(&sizeSuite{}) type sizeSuite struct { testing.IsolationSuite } func (*sizeSuite) TestParseSize(c *gc.C) { type test struct { in string out uint64 err string } tests := []test{{ in: "", err: `expected a non-negative number, got ""`, }, { in: "-1", err: `expected a non-negative number, got "-1"`, }, { in: "1MZ", err: `invalid multiplier suffix "MZ", expected one of MGTPEZY`, }, { in: "0", out: 0, }, { in: "123", out: 123, }, { in: "1M", out: 1, }, { in: "0.5G", out: 512, }, { in: "0.5GB", out: 512, }, { in: "0.5GiB", out: 512, }, { in: "0.5T", out: 524288, }, { in: "0.5P", out: 536870912, }, { in: "0.0009765625E", out: 1073741824, }, { in: "1Z", out: 1125899906842624, }, { in: "1Y", out: 1152921504606846976, }} for i, test := range tests { c.Logf("test %d: %+v", i, test) size, err := utils.ParseSize(test.in) if test.err != "" { c.Assert(err, gc.NotNil) c.Assert(err, gc.ErrorMatches, test.err) } else { c.Assert(err, gc.IsNil) c.Assert(size, gc.Equals, test.out) } } } func (*sizeSuite) TestSizingReaderOkay(c *gc.C) { expected := "some data" stub := &testing.Stub{} reader := filetesting.NewStubReader(stub, expected) var st utils.SizeTracker sizingReader := io.TeeReader(reader, &st) data, err := ioutil.ReadAll(sizingReader) c.Assert(err, jc.ErrorIsNil) stub.CheckCallNames(c, "Read", "Read") 
c.Check(string(data), gc.Equals, expected) c.Check(st.Size(), gc.Equals, int64(len(expected))) } func (*sizeSuite) TestSizingReaderMixedEOF(c *gc.C) { expected := "some data" stub := &testing.Stub{} reader := &filetesting.StubReader{ Stub: stub, ReturnRead: &fakeStream{ data: expected, }, } var st utils.SizeTracker sizingReader := io.TeeReader(reader, &st) data, err := ioutil.ReadAll(sizingReader) c.Assert(err, jc.ErrorIsNil) stub.CheckCallNames(c, "Read") // The EOF was mixed with the data. c.Check(string(data), gc.Equals, expected) c.Check(st.Size(), gc.Equals, int64(len(expected))) } func (*sizeSuite) TestSizingWriter(c *gc.C) { expected := "some data" stub := &testing.Stub{} writer, buffer := filetesting.NewStubWriter(stub) var st utils.SizeTracker sizingWriter := io.MultiWriter(writer, &st) n, err := sizingWriter.Write([]byte(expected)) c.Assert(err, jc.ErrorIsNil) stub.CheckCallNames(c, "Write") c.Check(n, gc.Equals, len(expected)) c.Check(buffer.String(), gc.Equals, expected) c.Check(st.Size(), gc.Equals, int64(len(expected))) } type fakeStream struct { data string pos uint64 } func (f *fakeStream) Read(data []byte) (int, error) { n := copy(data, f.data[f.pos:]) f.pos += uint64(n) if f.pos >= uint64(len(f.data)) { return n, io.EOF } return n, nil } ================================================ FILE: ssh/authorisedkeys.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh

import (
	"fmt"
	"io/ioutil"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"

	"github.com/juju/errors"
	"github.com/juju/loggo/v2"
	"golang.org/x/crypto/ssh"

	"github.com/juju/utils/v4"
)

var logger = loggo.GetLogger("juju.utils.ssh")

// ListMode selects how the key-listing functions report each key.
type ListMode bool

var (
	// FullKeys reports each key line in full.
	FullKeys ListMode = true
	// Fingerprints reports each key by its fingerprint, followed by its
	// comment in parentheses when it has one.
	Fingerprints ListMode = false
)

const (
	defaultAuthKeysFile = "authorized_keys"
)

// AuthorisedKey holds the parts of a single authorized_keys entry.
type AuthorisedKey struct {
	Type    string
	Key     []byte
	Comment string
}

// authKeysDir returns the normalised path of the .ssh directory inside
// the home directory of username.
func authKeysDir(username string) (string, error) {
	home, err := utils.UserHomeDir(username)
	if err != nil {
		return "", err
	}
	if home, err = utils.NormalizePath(home); err != nil {
		return "", err
	}
	return filepath.Join(home, ".ssh"), nil
}

// ParseAuthorisedKey parses one non-comment line of an authorized_keys
// file (format as described in "man sshd") into its constituent parts.
// A line containing a newline is rejected as NotValid.
func ParseAuthorisedKey(line string) (*AuthorisedKey, error) {
	if strings.ContainsRune(line, '\n') {
		return nil, errors.NotValidf("newline in authorized_key %q", line)
	}
	pub, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(line))
	if err != nil {
		return nil, errors.Errorf("invalid authorized_key %q", line)
	}
	parsed := &AuthorisedKey{
		Type:    pub.Type(),
		Key:     pub.Marshal(),
		Comment: comment,
	}
	return parsed, nil
}

// ConcatAuthorisedKeys joins two authorised-key lists into one string
// that ssh programs can read, inserting a newline separator only when a
// does not already end with one. An empty argument yields the other one
// unchanged.
func ConcatAuthorisedKeys(a, b string) string {
	switch {
	case a == "":
		return b
	case b == "":
		return a
	case strings.HasSuffix(a, "\n"):
		return a + b
	default:
		return a + "\n" + b
	}
}

// SplitAuthorisedKeys extracts a key slice from the specified key data,
// by splitting the key data into lines and ignoring comments and blank lines.
func SplitAuthorisedKeys(keyData string) []string { var keys []string for _, key := range strings.Split(string(keyData), "\n") { key = strings.Trim(key, " \r") if len(key) == 0 { continue } if key[0] == '#' { continue } keys = append(keys, key) } return keys } func readAuthorisedKeys(username, filename string) ([]string, error) { keyDir, err := authKeysDir(username) if err != nil { return nil, err } sshKeyFile := filepath.Join(keyDir, filename) logger.Debugf("reading authorised keys file %s", sshKeyFile) keyData, err := ioutil.ReadFile(sshKeyFile) if os.IsNotExist(err) { return []string{}, nil } if err != nil { return nil, errors.Annotate(err, "reading ssh authorised keys file") } var keys []string for _, key := range strings.Split(string(keyData), "\n") { if len(strings.Trim(key, " \r")) == 0 { continue } keys = append(keys, key) } return keys, nil } func writeAuthorisedKeys(username, filename string, keys []string) error { keyDir, err := authKeysDir(username) if err != nil { return err } err = os.MkdirAll(keyDir, os.FileMode(0755)) if err != nil { return errors.Annotate(err, "cannot create ssh key directory") } keyData := strings.Join(keys, "\n") + "\n" // Get perms to use on auth keys file sshKeyFile := filepath.Join(keyDir, filename) perms := os.FileMode(0644) info, err := os.Stat(sshKeyFile) if err == nil { perms = info.Mode().Perm() } logger.Debugf("writing authorised keys file %s", sshKeyFile) err = utils.AtomicWriteFile(sshKeyFile, []byte(keyData), perms) if err != nil { return err } // TODO (wallyworld) - what to do on windows (if anything) // TODO(dimitern) - no need to use user.Current() if username // is "" - it will use the current user anyway. if runtime.GOOS != "windows" { // Ensure the resulting authorised keys file has its ownership // set to the specified username. 
var u *user.User if username == "" { u, err = user.Current() } else { u, err = user.Lookup(username) } if err != nil { return err } // chown requires ints but user.User has strings for windows. uid, err := strconv.Atoi(u.Uid) if err != nil { return err } gid, err := strconv.Atoi(u.Gid) if err != nil { return err } err = os.Chown(sshKeyFile, uid, gid) if err != nil { return err } } return nil } // We need a mutex because updates to the authorised keys file are done by // reading the contents, updating, and writing back out. So only one caller // at a time can use either Add, Delete, List. var keysMutex sync.Mutex // AddKeys adds the specified ssh keys to the authorized_keys file for user. // Returns an error if there is an issue with *any* of the supplied keys. func AddKeys(user string, newKeys ...string) error { keysMutex.Lock() defer keysMutex.Unlock() existingKeys, err := readAuthorisedKeys(user, defaultAuthKeysFile) if err != nil { return err } return addKeys(user, defaultAuthKeysFile, newKeys, existingKeys) } // DeleteKeys removes the specified ssh keys from the authorized ssh keys file for user. // keyIds may be either key comments or fingerprints. // Returns an error if there is an issue with *any* of the keys to delete. func DeleteKeys(user string, keyIds ...string) error { keysMutex.Lock() defer keysMutex.Unlock() existingKeys, err := readAuthorisedKeys(user, defaultAuthKeysFile) if err != nil { return err } return deleteKeys(user, defaultAuthKeysFile, existingKeys, keyIds, false) } // ReplaceKeys writes the specified ssh keys to the authorized_keys file for user, // replacing any that are already there. // Returns an error if there is an issue with *any* of the supplied keys. 
func ReplaceKeys(user string, newKeys ...string) error { keysMutex.Lock() defer keysMutex.Unlock() existingKeyData, err := readAuthorisedKeys(user, defaultAuthKeysFile) if err != nil { return err } var existingNonKeyLines []string for _, line := range existingKeyData { _, _, err := KeyFingerprint(line) if err != nil { existingNonKeyLines = append(existingNonKeyLines, line) } } return writeAuthorisedKeys(user, defaultAuthKeysFile, append(existingNonKeyLines, newKeys...)) } // ListKeys returns either the full keys or key comments from the authorized ssh keys file for user. func ListKeys(user string, mode ListMode) ([]string, error) { keysMutex.Lock() defer keysMutex.Unlock() keyData, err := readAuthorisedKeys(user, defaultAuthKeysFile) if err != nil { return nil, err } return listKeys(keyData, mode) } // Any ssh key added to the authorised keys list by Juju will have this prefix. // This allows Juju to know which keys have been added externally and any such keys // will always be retained by Juju when updating the authorised keys file. const JujuCommentPrefix = "Juju:" func EnsureJujuComment(key string) string { ak, err := ParseAuthorisedKey(key) // Just return an invalid key as is. if err != nil { logger.Warningf("invalid Juju ssh key %s: %v", key, err) return key } if ak.Comment == "" { return key + " " + JujuCommentPrefix + "sshkey" } else { // Add the Juju prefix to the comment if necessary. if !strings.HasPrefix(ak.Comment, JujuCommentPrefix) { commentIndex := strings.LastIndex(key, ak.Comment) return key[:commentIndex] + JujuCommentPrefix + ak.Comment } } return key } // AddKeysToFile adds the specified ssh keys to the specified file for user. // Returns an error if there is an issue with *any* of the supplied keys. 
func AddKeysToFile(user, file string, newKeys []string) error { keysMutex.Lock() defer keysMutex.Unlock() existingKeys, err := readAuthorisedKeys(user, file) if err != nil { return err } return addKeys(user, file, newKeys, existingKeys) } // DeleteKeysFromFile removes the specified ssh keys from the authorized ssh keys file for user. // keyIds may be either key comments or fingerprints. // Returns an error if there is an issue with *any* of the keys to delete. // // Unlike DeleteKeys, this version can delete ALL keys from the target file. func DeleteKeysFromFile(user, file string, keyIds []string) error { keysMutex.Lock() defer keysMutex.Unlock() existingKeys, err := readAuthorisedKeys(user, file) if err != nil { return err } return deleteKeys(user, file, existingKeys, keyIds, true) } // ListKeys returns either the full keys or key comments from the authorized ssh keys file for user. func ListKeysFromFile(user, file string, mode ListMode) ([]string, error) { keysMutex.Lock() defer keysMutex.Unlock() keyData, err := readAuthorisedKeys(user, file) if err != nil { return nil, err } return listKeys(keyData, mode) } func addKeys(user, file string, newKeys, existingKeys []string) error { for _, newKey := range newKeys { fingerprint, comment, err := KeyFingerprint(newKey) if err != nil { return err } if comment == "" { return errors.Errorf("cannot add ssh key without comment") } for _, key := range existingKeys { existingFingerprint, existingComment, err := KeyFingerprint(key) if err != nil { // Only log a warning if the unrecognised key line is not a comment. if key[0] != '#' { logger.Warningf("invalid existing ssh key %q: %v", key, err) } continue } if existingFingerprint == fingerprint { return errors.Errorf("cannot add duplicate ssh key: %v", fingerprint) } if existingComment == comment { return errors.Errorf("cannot add ssh key with duplicate comment: %v", comment) } } } sshKeys := append(existingKeys, newKeys...) 
return writeAuthorisedKeys(user, file, sshKeys) } func deleteKeys(user, file string, existingKeys, keyIdsToDelete []string, deleteAll bool) error { // Build up a map of keys indexed by fingerprint, and fingerprints indexed by comment // so we can easily get the key represented by each keyId, which may be either a fingerprint // or comment. var keysToWrite []string var sshKeys = make(map[string]string) var keyComments = make(map[string]string) for _, key := range existingKeys { fingerprint, comment, err := KeyFingerprint(key) if err != nil { logger.Debugf("keeping unrecognised existing ssh key %q: %v", key, err) keysToWrite = append(keysToWrite, key) continue } sshKeys[fingerprint] = key if comment != "" { keyComments[comment] = fingerprint } } for _, keyId := range keyIdsToDelete { // assume keyId may be a fingerprint fingerprint := keyId _, ok := sshKeys[keyId] if !ok { // keyId is a comment fingerprint, ok = keyComments[keyId] } if !ok { return errors.Errorf("cannot delete non existent key: %v", keyId) } delete(sshKeys, fingerprint) } for _, key := range sshKeys { keysToWrite = append(keysToWrite, key) } if len(keysToWrite) == 0 && !deleteAll { return errors.Errorf("cannot delete all keys") } return writeAuthorisedKeys(user, file, keysToWrite) } func listKeys(existingKeys []string, mode ListMode) ([]string, error) { var keys []string for _, key := range existingKeys { fingerprint, comment, err := KeyFingerprint(key) if err != nil { // Only log a warning if the unrecognised key line is not a comment. if key[0] != '#' { logger.Warningf("ignoring invalid ssh key %q: %v", key, err) } continue } if mode == FullKeys { keys = append(keys, key) } else { shortKey := fingerprint if comment != "" { shortKey += fmt.Sprintf(" (%s)", comment) } keys = append(keys, shortKey) } } return keys, nil } ================================================ FILE: ssh/authorisedkeys_test.go ================================================ // Copyright 2013 Canonical Ltd. 
// Licensed under the LGPLv3, see LICENCE file for details. package ssh_test import ( "encoding/base64" "strings" gitjujutesting "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/ssh" sshtesting "github.com/juju/utils/v4/ssh/testing" ) type AuthorisedKeysKeysSuite struct { gitjujutesting.FakeHomeSuite } const ( // We'll use the current user for ssh tests. testSSHUser = "" authKeysFile = "authorized_keys" alternativeKeysFile2 = "authorized_keys2" alternativeKeysFile3 = "authorized_keys3" ) var _ = gc.Suite(&AuthorisedKeysKeysSuite{}) func writeAuthKeysFile(c *gc.C, keys []string, file string) { err := ssh.WriteAuthorisedKeys(testSSHUser, file, keys) c.Assert(err, jc.ErrorIsNil) } func (s *AuthorisedKeysKeysSuite) TestListKeys(c *gc.C) { keys := []string{ sshtesting.ValidKeyOne.Key + " user@host", sshtesting.ValidKeyTwo.Key, } writeAuthKeysFile(c, keys, authKeysFile) keys, err := ssh.ListKeys(testSSHUser, ssh.Fingerprints) c.Assert(err, jc.ErrorIsNil) c.Assert( keys, gc.DeepEquals, []string{sshtesting.ValidKeyOne.Fingerprint + " (user@host)", sshtesting.ValidKeyTwo.Fingerprint}) } func (s *AuthorisedKeysKeysSuite) TestListKeysFull(c *gc.C) { keys := []string{ sshtesting.ValidKeyOne.Key + " user@host", sshtesting.ValidKeyTwo.Key + " anotheruser@host", } writeAuthKeysFile(c, keys, authKeysFile) actual, err := ssh.ListKeys(testSSHUser, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, keys) } func (s *AuthorisedKeysKeysSuite) TestAddNewKey(c *gc.C) { key := sshtesting.ValidKeyOne.Key + " user@host" err := ssh.AddKeys(testSSHUser, key) c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ListKeys(testSSHUser, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, []string{key}) } func (s *AuthorisedKeysKeysSuite) TestAddMoreKeys(c *gc.C) { firstKey := sshtesting.ValidKeyOne.Key + " user@host" writeAuthKeysFile(c, []string{firstKey}, authKeysFile) moreKeys := []string{ 
sshtesting.ValidKeyTwo.Key + " anotheruser@host", sshtesting.ValidKeyThree.Key + " yetanotheruser@host", } err := ssh.AddKeys(testSSHUser, moreKeys...) c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ListKeys(testSSHUser, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, append([]string{firstKey}, moreKeys...)) } func (s *AuthorisedKeysKeysSuite) TestAddDuplicateKey(c *gc.C) { key := sshtesting.ValidKeyOne.Key + " user@host" err := ssh.AddKeys(testSSHUser, key) c.Assert(err, jc.ErrorIsNil) moreKeys := []string{ sshtesting.ValidKeyOne.Key + " user@host", sshtesting.ValidKeyTwo.Key + " yetanotheruser@host", } err = ssh.AddKeys(testSSHUser, moreKeys...) c.Assert(err, gc.ErrorMatches, "cannot add duplicate ssh key: "+sshtesting.ValidKeyOne.Fingerprint) } func (s *AuthorisedKeysKeysSuite) TestAddDuplicateComment(c *gc.C) { key := sshtesting.ValidKeyOne.Key + " user@host" err := ssh.AddKeys(testSSHUser, key) c.Assert(err, jc.ErrorIsNil) moreKeys := []string{ sshtesting.ValidKeyTwo.Key + " user@host", sshtesting.ValidKeyThree.Key + " yetanotheruser@host", } err = ssh.AddKeys(testSSHUser, moreKeys...) c.Assert(err, gc.ErrorMatches, "cannot add ssh key with duplicate comment: user@host") } func (s *AuthorisedKeysKeysSuite) TestAddKeyWithoutComment(c *gc.C) { keys := []string{ sshtesting.ValidKeyOne.Key + " user@host", sshtesting.ValidKeyTwo.Key, } err := ssh.AddKeys(testSSHUser, keys...) 
c.Assert(err, gc.ErrorMatches, "cannot add ssh key without comment") } func (s *AuthorisedKeysKeysSuite) TestAddKeepsUnrecognised(c *gc.C) { writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key, "invalid-key"}, authKeysFile) anotherKey := sshtesting.ValidKeyTwo.Key + " anotheruser@host" err := ssh.AddKeys(testSSHUser, anotherKey) c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ReadAuthorisedKeys(testSSHUser, authKeysFile) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, []string{sshtesting.ValidKeyOne.Key, "invalid-key", anotherKey}) } func (s *AuthorisedKeysKeysSuite) TestDeleteKeys(c *gc.C) { firstKey := sshtesting.ValidKeyOne.Key + " user@host" anotherKey := sshtesting.ValidKeyTwo.Key thirdKey := sshtesting.ValidKeyThree.Key + " anotheruser@host" writeAuthKeysFile(c, []string{firstKey, anotherKey, thirdKey}, authKeysFile) err := ssh.DeleteKeys(testSSHUser, "user@host", sshtesting.ValidKeyTwo.Fingerprint) c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ListKeys(testSSHUser, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, []string{thirdKey}) } func (s *AuthorisedKeysKeysSuite) TestDeleteKeysKeepsUnrecognised(c *gc.C) { firstKey := sshtesting.ValidKeyOne.Key + " user@host" writeAuthKeysFile(c, []string{firstKey, sshtesting.ValidKeyTwo.Key, "invalid-key"}, authKeysFile) err := ssh.DeleteKeys(testSSHUser, "user@host") c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ReadAuthorisedKeys(testSSHUser, authKeysFile) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, []string{"invalid-key", sshtesting.ValidKeyTwo.Key}) } func (s *AuthorisedKeysKeysSuite) TestDeleteNonExistentComment(c *gc.C) { firstKey := sshtesting.ValidKeyOne.Key + " user@host" writeAuthKeysFile(c, []string{firstKey}, authKeysFile) err := ssh.DeleteKeys(testSSHUser, "someone@host") c.Assert(err, gc.ErrorMatches, "cannot delete non existent key: someone@host") } func (s *AuthorisedKeysKeysSuite) TestDeleteNonExistentFingerprint(c *gc.C) { firstKey := 
sshtesting.ValidKeyOne.Key + " user@host" writeAuthKeysFile(c, []string{firstKey}, authKeysFile) err := ssh.DeleteKeys(testSSHUser, sshtesting.ValidKeyTwo.Fingerprint) c.Assert(err, gc.ErrorMatches, "cannot delete non existent key: "+sshtesting.ValidKeyTwo.Fingerprint) } func (s *AuthorisedKeysKeysSuite) TestDeleteLastKeyForbidden(c *gc.C) { keys := []string{ sshtesting.ValidKeyOne.Key + " user@host", sshtesting.ValidKeyTwo.Key + " yetanotheruser@host", } writeAuthKeysFile(c, keys, authKeysFile) err := ssh.DeleteKeys(testSSHUser, "user@host", sshtesting.ValidKeyTwo.Fingerprint) c.Assert(err, gc.ErrorMatches, "cannot delete all keys") } func (s *AuthorisedKeysKeysSuite) TestReplaceKeys(c *gc.C) { firstKey := sshtesting.ValidKeyOne.Key + " user@host" anotherKey := sshtesting.ValidKeyTwo.Key writeAuthKeysFile(c, []string{firstKey, anotherKey}, authKeysFile) // replaceKey is created without a comment so test that // ReplaceKeys handles keys without comments. This is // because existing keys may not have a comment and // ReplaceKeys is used to rewrite the entire authorized_keys // file when adding new keys. 
replaceKey := sshtesting.ValidKeyThree.Key err := ssh.ReplaceKeys(testSSHUser, replaceKey) c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ListKeys(testSSHUser, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, []string{replaceKey}) } func (s *AuthorisedKeysKeysSuite) TestReplaceKeepsUnrecognised(c *gc.C) { writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key, "invalid-key"}, authKeysFile) anotherKey := sshtesting.ValidKeyTwo.Key + " anotheruser@host" err := ssh.ReplaceKeys(testSSHUser, anotherKey) c.Assert(err, jc.ErrorIsNil) actual, err := ssh.ReadAuthorisedKeys(testSSHUser, authKeysFile) c.Assert(err, jc.ErrorIsNil) c.Assert(actual, gc.DeepEquals, []string{"invalid-key", anotherKey}) } func (s *AuthorisedKeysKeysSuite) TestEnsureJujuComment(c *gc.C) { sshKey := sshtesting.ValidKeyOne.Key for _, test := range []struct { key string expected string }{ {"invalid-key", "invalid-key"}, {sshKey, sshKey + " Juju:sshkey"}, {sshKey + " user@host", sshKey + " Juju:user@host"}, {sshKey + " Juju:user@host", sshKey + " Juju:user@host"}, {sshKey + " " + sshKey[3:5], sshKey + " Juju:" + sshKey[3:5]}, } { actual := ssh.EnsureJujuComment(test.key) c.Assert(actual, gc.Equals, test.expected) } } func (s *AuthorisedKeysKeysSuite) TestSplitAuthorisedKeys(c *gc.C) { sshKey := sshtesting.ValidKeyOne.Key for _, test := range []struct { keyData string expected []string }{ {"", nil}, {sshKey, []string{sshKey}}, {sshKey + "\n", []string{sshKey}}, {sshKey + "\n\n", []string{sshKey}}, {sshKey + "\n#comment\n", []string{sshKey}}, {sshKey + "\n #comment\n", []string{sshKey}}, {sshKey + "\ninvalid\n", []string{sshKey, "invalid"}}, } { actual := ssh.SplitAuthorisedKeys(test.keyData) c.Assert(actual, gc.DeepEquals, test.expected) } } func b64decode(c *gc.C, s string) []byte { b, err := base64.StdEncoding.DecodeString(s) c.Assert(err, jc.ErrorIsNil) return b } func (s *AuthorisedKeysKeysSuite) TestParseAuthorisedKey(c *gc.C) { for i, test := range []struct { line 
string key []byte comment string err string }{{ line: sshtesting.ValidKeyOne.Key, key: b64decode(c, strings.Fields(sshtesting.ValidKeyOne.Key)[1]), }, { line: sshtesting.ValidKeyOne.Key + " a b c", key: b64decode(c, strings.Fields(sshtesting.ValidKeyOne.Key)[1]), comment: "a b c", }, { line: "ssh-xsa blah", err: "invalid authorized_key \"ssh-xsa blah\"", }, { // options should be skipped line: `no-pty,principals="\"",command="\!" ` + sshtesting.ValidKeyOne.Key, key: b64decode(c, strings.Fields(sshtesting.ValidKeyOne.Key)[1]), }, { line: "ssh-rsa", err: "invalid authorized_key \"ssh-rsa\"", }, { line: sshtesting.ValidKeyOne.Key + " line1\nline2", err: "newline in authorized_key \".*", }} { c.Logf("test %d: %s", i, test.line) ak, err := ssh.ParseAuthorisedKey(test.line) if test.err != "" { c.Assert(err, gc.ErrorMatches, test.err) } else { c.Assert(err, jc.ErrorIsNil) c.Assert(ak, gc.Not(gc.IsNil)) c.Assert(ak.Key, gc.DeepEquals, test.key) c.Assert(ak.Comment, gc.Equals, test.comment) } } } func (s *AuthorisedKeysKeysSuite) TestConcatAuthorisedKeys(c *gc.C) { for _, test := range []struct{ a, b, result string }{ {"a", "", "a"}, {"", "b", "b"}, {"a", "b", "a\nb"}, {"a\n", "b", "a\nb"}, } { c.Check(ssh.ConcatAuthorisedKeys(test.a, test.b), gc.Equals, test.result) } } func (s *AuthorisedKeysKeysSuite) TestAddKeysToFileToDifferentFiles(c *gc.C) { key1 := sshtesting.ValidKeyOne.Key + " user@host" err := ssh.AddKeysToFile(testSSHUser, alternativeKeysFile2, []string{key1}) c.Assert(err, jc.ErrorIsNil) list1, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(list1, gc.DeepEquals, []string{key1}) key2 := sshtesting.ValidKeyTwo.Key + " user@host" err = ssh.AddKeysToFile(testSSHUser, alternativeKeysFile3, []string{key2}) c.Assert(err, jc.ErrorIsNil) list2, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile3, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(list2, gc.DeepEquals, []string{key2}) } func 
(s *AuthorisedKeysKeysSuite) TestAddKeysToFileMultipleKeys(c *gc.C) { key1 := sshtesting.ValidKeyOne.Key + " user@host" key2 := sshtesting.ValidKeyTwo.Key + " alice@host" err := ssh.AddKeysToFile(testSSHUser, alternativeKeysFile2, []string{key1, key2}) c.Assert(err, jc.ErrorIsNil) list, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(list, jc.DeepEquals, []string{key1, key2}) } func (s *AuthorisedKeysKeysSuite) TestDeleteAllKeysFromFile(c *gc.C) { key1 := sshtesting.ValidKeyOne.Key + " user@host" writeAuthKeysFile(c, []string{key1}, alternativeKeysFile2) err := ssh.DeleteKeysFromFile(testSSHUser, alternativeKeysFile2, []string{sshtesting.ValidKeyOne.Fingerprint}) c.Assert(err, jc.ErrorIsNil) emptyList, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(emptyList, gc.HasLen, 0) } func (s *AuthorisedKeysKeysSuite) TestDeleteSomeKeysFromFile(c *gc.C) { key1 := sshtesting.ValidKeyOne.Key + " user@host" key2 := sshtesting.ValidKeyTwo.Key + " alice@host" key3 := sshtesting.ValidKeyThree.Key + " bob@host" writeAuthKeysFile(c, []string{key1, key2, key3}, alternativeKeysFile2) err := ssh.DeleteKeysFromFile(testSSHUser, alternativeKeysFile2, []string{sshtesting.ValidKeyTwo.Fingerprint}) c.Assert(err, jc.ErrorIsNil) keys, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys) c.Assert(err, jc.ErrorIsNil) c.Assert(keys, gc.HasLen, 2) c.Assert(keys, jc.SameContents, []string{key1, key3}) } ================================================ FILE: ssh/clientkeys.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh import ( "fmt" "io/ioutil" "os" "path/filepath" "strings" "sync" "github.com/juju/collections/set" "golang.org/x/crypto/ssh" "github.com/juju/utils/v4" ) const clientKeyName = "juju_id_ed25519" // PublicKeySuffix is the file extension for public key files. const PublicKeySuffix = ".pub" var ( clientKeysMutex sync.Mutex // clientKeys is a cached map of private key filenames // to ssh.Signers. The private keys are those loaded // from the client key directory, passed to LoadClientKeys. clientKeys map[string]ssh.Signer ) // LoadClientKeys loads the client SSH keys from the // specified directory, and caches them as a process-wide // global. If the directory does not exist, it is created; // if the directory did not exist, or contains no keys, it // is populated with a new key pair. // // If the directory exists, then all pairs of files where one // has the same name as the other + ".pub" will be loaded as // private/public key pairs. // // Calls to LoadClientKeys will clear the previously loaded // keys, and recompute the keys. func LoadClientKeys(dir string) error { clientKeysMutex.Lock() defer clientKeysMutex.Unlock() dir, err := utils.NormalizePath(dir) if err != nil { return err } if _, err := os.Stat(dir); err == nil { keys, err := loadClientKeys(dir) if err != nil { return err } else if len(keys) > 0 { clientKeys = keys return nil } // Directory exists but contains no keys; // fall through and create one. } if err := os.MkdirAll(dir, 0700); err != nil { return err } keyfile, key, err := generateClientKey(dir) if err != nil { os.RemoveAll(dir) return err } clientKeys = map[string]ssh.Signer{keyfile: key} return nil } // ClearClientKeys clears the client keys cached in memory. 
func ClearClientKeys() { clientKeysMutex.Lock() defer clientKeysMutex.Unlock() clientKeys = nil } func generateClientKey(dir string) (keyfile string, key ssh.Signer, err error) { private, public, err := GenerateKey("juju-client-key") if err != nil { return "", nil, err } clientPrivateKey, err := ssh.ParsePrivateKey([]byte(private)) if err != nil { return "", nil, err } privkeyFilename := filepath.Join(dir, clientKeyName) if err = ioutil.WriteFile(privkeyFilename, []byte(private), 0600); err != nil { return "", nil, err } if err := ioutil.WriteFile(privkeyFilename+PublicKeySuffix, []byte(public), 0600); err != nil { os.Remove(privkeyFilename) return "", nil, err } return privkeyFilename, clientPrivateKey, nil } func loadClientKeys(dir string) (map[string]ssh.Signer, error) { publicKeyFiles, err := publicKeyFiles(dir) if err != nil { return nil, err } keys := make(map[string]ssh.Signer, len(publicKeyFiles)) for _, filename := range publicKeyFiles { filename = filename[:len(filename)-len(PublicKeySuffix)] data, err := ioutil.ReadFile(filename) if err != nil { return nil, err } keys[filename], err = ssh.ParsePrivateKey(data) if err != nil { return nil, fmt.Errorf("parsing key file %q: %v", filename, err) } } return keys, nil } // privateKeys returns the private keys loaded by LoadClientKeys. func privateKeys() (signers []ssh.Signer) { clientKeysMutex.Lock() defer clientKeysMutex.Unlock() for _, key := range clientKeys { signers = append(signers, key) } return signers } // PrivateKeyFiles returns the filenames of private SSH keys loaded by // LoadClientKeys. func PrivateKeyFiles() []string { clientKeysMutex.Lock() defer clientKeysMutex.Unlock() keyfiles := make([]string, 0, len(clientKeys)) for f := range clientKeys { keyfiles = append(keyfiles, f) } return keyfiles } // PublicKeyFiles returns the filenames of public SSH keys loaded by // LoadClientKeys. 
func PublicKeyFiles() []string { privkeys := PrivateKeyFiles() pubkeys := make([]string, len(privkeys)) for i, priv := range privkeys { pubkeys[i] = priv + PublicKeySuffix } return pubkeys } // publicKeyFiles returns the filenames of public SSH keys // in the specified directory (all the files ending with .pub). func publicKeyFiles(clientKeysDir string) ([]string, error) { if clientKeysDir == "" { return nil, nil } var keys []string dir, err := os.Open(clientKeysDir) if err != nil { return nil, err } names, err := dir.Readdirnames(-1) dir.Close() if err != nil { return nil, err } candidates := set.NewStrings(names...) for _, name := range names { if !strings.HasSuffix(name, PublicKeySuffix) { continue } // If the private key filename also exists, add the file. priv := name[:len(name)-len(PublicKeySuffix)] if candidates.Contains(priv) { keys = append(keys, filepath.Join(dir.Name(), name)) } } return keys, nil } ================================================ FILE: ssh/clientkeys_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh_test import ( "io/ioutil" "os" gitjujutesting "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" "github.com/juju/utils/v4/ssh" ) type ClientKeysSuite struct { gitjujutesting.FakeHomeSuite } var _ = gc.Suite(&ClientKeysSuite{}) func (s *ClientKeysSuite) SetUpTest(c *gc.C) { s.FakeHomeSuite.SetUpTest(c) s.AddCleanup(func(*gc.C) { ssh.ClearClientKeys() }) generateKeyRestorer := overrideGenerateKey() s.AddCleanup(func(*gc.C) { generateKeyRestorer.Restore() }) } func checkFiles(c *gc.C, obtained, expected []string) { var err error for i, e := range expected { expected[i], err = utils.NormalizePath(e) c.Assert(err, jc.ErrorIsNil) } c.Assert(obtained, jc.SameContents, expected) } func checkPublicKeyFiles(c *gc.C, expected ...string) { keys := ssh.PublicKeyFiles() checkFiles(c, keys, expected) } func checkPrivateKeyFiles(c *gc.C, expected ...string) { keys := ssh.PrivateKeyFiles() checkFiles(c, keys, expected) } func (s *ClientKeysSuite) TestPublicKeyFiles(c *gc.C) { // LoadClientKeys will create the specified directory // and populate it with a key pair. err := ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) checkPublicKeyFiles(c, "~/.juju/ssh/juju_id_ed25519.pub") // All files ending with .pub in the client key dir get picked up. priv, pub, err := ssh.GenerateKey("whatever") c.Assert(err, jc.ErrorIsNil) err = ioutil.WriteFile(gitjujutesting.HomePath(".juju", "ssh", "whatever.pub"), []byte(pub), 0600) c.Assert(err, jc.ErrorIsNil) err = ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) // The new public key won't be observed until the // corresponding private key exists. 
checkPublicKeyFiles(c, "~/.juju/ssh/juju_id_ed25519.pub") err = ioutil.WriteFile(gitjujutesting.HomePath(".juju", "ssh", "whatever"), []byte(priv), 0600) c.Assert(err, jc.ErrorIsNil) err = ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) checkPublicKeyFiles(c, "~/.juju/ssh/juju_id_ed25519.pub", "~/.juju/ssh/whatever.pub") } func (s *ClientKeysSuite) TestPrivateKeyFiles(c *gc.C) { // Create/load client keys. They will be cached in memory: // any files added to the directory will not be considered // unless LoadClientKeys is called again. err := ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) checkPrivateKeyFiles(c, "~/.juju/ssh/juju_id_ed25519") priv, pub, err := ssh.GenerateKey("whatever") c.Assert(err, jc.ErrorIsNil) err = ioutil.WriteFile(gitjujutesting.HomePath(".juju", "ssh", "whatever"), []byte(priv), 0600) c.Assert(err, jc.ErrorIsNil) err = ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) // The new private key won't be observed until the // corresponding public key exists. 
checkPrivateKeyFiles(c, "~/.juju/ssh/juju_id_ed25519") err = ioutil.WriteFile(gitjujutesting.HomePath(".juju", "ssh", "whatever.pub"), []byte(pub), 0600) c.Assert(err, jc.ErrorIsNil) // new keys won't be reported until we call LoadClientKeys again checkPublicKeyFiles(c, "~/.juju/ssh/juju_id_ed25519.pub") checkPrivateKeyFiles(c, "~/.juju/ssh/juju_id_ed25519") err = ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) checkPublicKeyFiles(c, "~/.juju/ssh/juju_id_ed25519.pub", "~/.juju/ssh/whatever.pub") checkPrivateKeyFiles(c, "~/.juju/ssh/juju_id_ed25519", "~/.juju/ssh/whatever") } func (s *ClientKeysSuite) TestLoadClientKeysDirExists(c *gc.C) { err := os.MkdirAll(gitjujutesting.HomePath(".juju", "ssh"), 0755) c.Assert(err, jc.ErrorIsNil) err = ssh.LoadClientKeys("~/.juju/ssh") c.Assert(err, jc.ErrorIsNil) checkPrivateKeyFiles(c, "~/.juju/ssh/juju_id_ed25519") } ================================================ FILE: ssh/export_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh import ( "sync/atomic" gc "gopkg.in/check.v1" "github.com/juju/testing" ) var ( ReadAuthorisedKeys = readAuthorisedKeys WriteAuthorisedKeys = writeAuthorisedKeys InitDefaultClient = initDefaultClient DefaultIdentities = &defaultIdentities SSHDial = &sshDial ED25519GenerateKey = &ed25519GenerateKey TestCopyReader = copyReader TestNewCmd = newCmd ) type ReadLineWriter readLineWriter func PatchTerminal(s *testing.CleanupSuite, rlw ReadLineWriter) { var balance int64 s.PatchValue(&getTerminal, func() (readLineWriter, func(), error) { atomic.AddInt64(&balance, 1) cleanup := func() { atomic.AddInt64(&balance, -1) } return rlw, cleanup, nil }) s.AddCleanup(func(c *gc.C) { c.Assert(atomic.LoadInt64(&balance), gc.Equals, int64(0)) }) } func PatchNilTerminal(s *testing.CleanupSuite) { s.PatchValue(&getTerminal, func() (readLineWriter, func(), error) { return nil, func() {}, nil }) } ================================================ FILE: ssh/fakes_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh_test import ( "bytes" "io" "io/ioutil" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/ssh" ) type fakeClient struct { calls []string hostArg string commandArg []string optionsArg *ssh.Options copyArgs []string err error cmd *ssh.Cmd impl fakeCommandImpl } func (cl *fakeClient) checkCalls(c *gc.C, host string, command []string, options *ssh.Options, copyArgs []string, calls ...string) { c.Check(cl.hostArg, gc.Equals, host) c.Check(cl.commandArg, jc.DeepEquals, command) c.Check(cl.optionsArg, gc.Equals, options) c.Check(cl.copyArgs, jc.DeepEquals, copyArgs) c.Check(cl.calls, jc.DeepEquals, calls) } func (cl *fakeClient) Command(host string, command []string, options *ssh.Options) *ssh.Cmd { cl.calls = append(cl.calls, "Command") cl.hostArg = host cl.commandArg = command cl.optionsArg = options cmd := cl.cmd if cmd == nil { cmd = ssh.TestNewCmd(&cl.impl) } return cmd } func (cl *fakeClient) Copy(args []string, options *ssh.Options) error { cl.calls = append(cl.calls, "Copy") cl.copyArgs = args cl.optionsArg = options return cl.err } type bufferWriter struct { bytes.Buffer } func (*bufferWriter) Close() error { return nil } type fakeCommandImpl struct { calls []string stdinArg io.Reader stdoutArg io.Writer stderrArg io.Writer stdinData bufferWriter err error stdinRaw io.Reader stdoutRaw io.Writer stderrRaw io.Writer stdoutData bytes.Buffer stderrData bytes.Buffer } func (ci *fakeCommandImpl) checkCalls(c *gc.C, stdin io.Reader, stdout, stderr io.Writer, calls ...string) { c.Check(ci.stdinArg, gc.Equals, stdin) c.Check(ci.stdoutArg, gc.Equals, stdout) c.Check(ci.stderrArg, gc.Equals, stderr) c.Check(ci.calls, jc.DeepEquals, calls) } func (ci *fakeCommandImpl) checkStdin(c *gc.C, data string) { c.Check(ci.stdinData.String(), gc.Equals, data) } func (ci *fakeCommandImpl) Start() error { ci.calls = append(ci.calls, "Start") return ci.err } func (ci *fakeCommandImpl) Wait() error { ci.calls = append(ci.calls, "Wait") 
return ci.err } func (ci *fakeCommandImpl) Kill() error { ci.calls = append(ci.calls, "Kill") return ci.err } func (ci *fakeCommandImpl) SetStdio(stdin io.Reader, stdout, stderr io.Writer) { ci.calls = append(ci.calls, "SetStdio") ci.stdinArg = stdin ci.stdoutArg = stdout ci.stderrArg = stderr } func (ci *fakeCommandImpl) StdinPipe() (io.WriteCloser, io.Reader, error) { ci.calls = append(ci.calls, "StdinPipe") return &ci.stdinData, ci.stdinRaw, ci.err } func (ci *fakeCommandImpl) StdoutPipe() (io.ReadCloser, io.Writer, error) { ci.calls = append(ci.calls, "StdoutPipe") return ioutil.NopCloser(&ci.stdoutData), ci.stdoutRaw, ci.err } func (ci *fakeCommandImpl) StderrPipe() (io.ReadCloser, io.Writer, error) { ci.calls = append(ci.calls, "StderrPipe") return ioutil.NopCloser(&ci.stderrData), ci.stderrRaw, ci.err } ================================================ FILE: ssh/fingerprint.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh import ( "bytes" "crypto/md5" "fmt" "github.com/juju/errors" ) // KeyFingerprint returns the fingerprint and comment for the specified key // in authorized_key format. Fingerprints are generated according to RFC4716. // See ttp://www.ietf.org/rfc/rfc4716.txt, section 4. func KeyFingerprint(key string) (fingerprint, comment string, err error) { ak, err := ParseAuthorisedKey(key) if err != nil { return "", "", errors.Errorf("generating key fingerprint: %v", err) } hash := md5.New() hash.Write(ak.Key) sum := hash.Sum(nil) var buf bytes.Buffer for i := 0; i < hash.Size(); i++ { if i > 0 { buf.WriteByte(':') } buf.WriteString(fmt.Sprintf("%02x", sum[i])) } return buf.String(), ak.Comment, nil } ================================================ FILE: ssh/fingerprint_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh_test import ( "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/ssh" sshtesting "github.com/juju/utils/v4/ssh/testing" ) type FingerprintSuite struct { testing.IsolationSuite } var _ = gc.Suite(&FingerprintSuite{}) func (s *FingerprintSuite) TestKeyFingerprint(c *gc.C) { keys := []sshtesting.SSHKey{ sshtesting.ValidKeyOne, sshtesting.ValidKeyTwo, sshtesting.ValidKeyThree, } for _, k := range keys { fingerprint, _, err := ssh.KeyFingerprint(k.Key) c.Assert(err, jc.ErrorIsNil) c.Assert(fingerprint, gc.Equals, k.Fingerprint) } } func (s *FingerprintSuite) TestKeyFingerprintError(c *gc.C) { _, _, err := ssh.KeyFingerprint("invalid key") c.Assert(err, gc.ErrorMatches, `generating key fingerprint: invalid authorized_key "invalid key"`) } ================================================ FILE: ssh/generate.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh import ( "crypto/ed25519" "crypto/rand" "encoding/pem" "fmt" "strings" "github.com/juju/errors" "golang.org/x/crypto/ssh" ) // ed25519GenerateKey allows for tests to patch out ed25519 key generation var ed25519GenerateKey = ed25519.GenerateKey // GenerateKey makes a ED25519 no-passphrase SSH capable key. // The private key returned is encoded to ASCII using the PKCS1 encoding. // The public key is suitable to be added into an authorized_keys file, // and has the comment passed in as the comment part of the key. 
func GenerateKey(comment string) (private, public string, err error) { _, privateKey, err := ed25519GenerateKey(rand.Reader) if err != nil { return "", "", errors.Trace(err) } pemBlock, err := ssh.MarshalPrivateKey(privateKey, comment) if err != nil { return "", "", errors.Trace(err) } identity := pem.EncodeToMemory(pemBlock) public, err = PublicKey(identity, comment) if err != nil { return "", "", errors.Trace(err) } return string(identity), public, nil } // PublicKey returns the public key for any private key. The public key is // suitable to be added into an authorized_keys file, and has the comment // passed in as the comment part of the key. func PublicKey(privateKey []byte, comment string) (string, error) { signer, err := ssh.ParsePrivateKey(privateKey) if err != nil { return "", errors.Annotate(err, "failed to load key") } auth_key := string(ssh.MarshalAuthorizedKey(signer.PublicKey())) // Strip off the trailing new line so we can add a comment. auth_key = strings.TrimSpace(auth_key) public := fmt.Sprintf("%s %s\n", auth_key, comment) return public, nil } ================================================ FILE: ssh/generate_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh_test import ( "crypto/dsa" "crypto/ed25519" "io" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/ssh" ) type GenerateSuite struct { testing.IsolationSuite } var _ = gc.Suite(&GenerateSuite{}) var ( pregeneratedKey ed25519.PrivateKey ) // overrideGenerateKey patches out rsa.GenerateKey to create a single testing // key which is saved and used between tests to save computation time. 
func overrideGenerateKey() testing.Restorer { restorer := testing.PatchValue(ssh.ED25519GenerateKey, func(random io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) { if pregeneratedKey != nil { return ed25519.PublicKey{}, pregeneratedKey, nil } public, private, err := generateED25519Key(random) if err != nil { return nil, nil, err } pregeneratedKey = private return public, private, nil }) return restorer } func generateED25519Key(random io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) { // Ignore requested bits and just use 512 bits for speed public, private, err := ed25519.GenerateKey(random) if err != nil { return nil, nil, err } return public, private, nil } func generateDSAKey(random io.Reader) (*dsa.PrivateKey, error) { var privKey dsa.PrivateKey if err := dsa.GenerateParameters(&privKey.Parameters, random, dsa.L1024N160); err != nil { return nil, err } if err := dsa.GenerateKey(&privKey, random); err != nil { return nil, err } return &privKey, nil } func (s *GenerateSuite) TestGenerate(c *gc.C) { defer overrideGenerateKey().Restore() private, public, err := ssh.GenerateKey("some-comment") c.Check(err, jc.ErrorIsNil) c.Check(private, jc.HasPrefix, "-----BEGIN OPENSSH PRIVATE KEY-----\n") c.Check(private, jc.HasSuffix, "-----END OPENSSH PRIVATE KEY-----\n") c.Check(public, jc.HasPrefix, "ssh-ed25519 ") c.Check(public, jc.HasSuffix, " some-comment\n") } ================================================ FILE: ssh/package_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh_test import ( "testing" gc "gopkg.in/check.v1" ) func TestPackage(t *testing.T) { gc.TestingT(t) } ================================================ FILE: ssh/run.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh

import (
	"bytes"
	"os/exec"
	"strings"
	"syscall"
	"time"

	"github.com/juju/clock"
	"github.com/juju/errors"

	utilexec "github.com/juju/utils/v4/exec"
)

// ExecParams are used for the parameters for ExecuteCommandOnMachine.
type ExecParams struct {
	IdentityFile string
	Host         string
	Command      string
	Timeout      time.Duration
}

// StartCommandOnMachine executes the command on the given host. The
// command is run in a Bash shell over an SSH connection. All output
// is captured. A RunningCmd is returned that may be used to wait
// for the command to finish running.
func StartCommandOnMachine(params ExecParams) (*RunningCmd, error) {
	// execute bash accepting commands on stdin
	if params.Host == "" {
		return nil, errors.Errorf("missing host address")
	}
	logger.Debugf("execute on %s", params.Host)

	var options Options
	if params.IdentityFile != "" {
		options.SetIdentities(params.IdentityFile)
	}
	command := Command(params.Host, []string{"/bin/bash", "-s"}, &options)

	// Run the command.
	running := &RunningCmd{
		SSHCmd: command,
	}
	command.Stdout = &running.Stdout
	command.Stderr = &running.Stderr
	command.Stdin = strings.NewReader(params.Command + "\n")
	if err := command.Start(); err != nil {
		return nil, errors.Trace(err)
	}

	return running, nil
}

// RunningCmd represents a command that has been started.
type RunningCmd struct {
	// SSHCmd is the command that was started.
	SSHCmd *Cmd

	// Stdout and Stderr are the output streams the command is using.
	Stdout bytes.Buffer
	Stderr bytes.Buffer
}

// Wait waits for the command to complete and returns the result.
// Whatever was captured on stdout and stderr is always copied into the
// result, even when an error is returned. An *exec.ExitError for a process
// that exited normally is reported through result.Code, not as an error.
func (cmd *RunningCmd) Wait() (result utilexec.ExecResponse, _ error) {
	defer func() {
		// Gather as much as we have from stdout and stderr.
		result.Stdout = cmd.Stdout.Bytes()
		result.Stderr = cmd.Stderr.Bytes()
	}()

	err := cmd.SSHCmd.Wait()
	logger.Debugf("command.Wait finished (err: %v)", err)
	code, err := getExitCode(err)
	if err != nil {
		return result, errors.Trace(err)
	}

	result.Code = code
	return result, nil
}

// TODO(ericsnow) Add RunningCmd.WaitAbortable(abortChan <-chan error) ...
// based on WaitWithTimeout and update WaitWithTimeout to use it. We
// could make it WaitAbortable(abortChans ...<-chan error), which would
// require using reflect.Select(). Then that could simply replace Wait().
// It may make more sense, however, to have a helper function:
//   Wait(cmd T, abortChans ...<-chan error) ...

// Cancelled is an error indicating that a command timed out.
var Cancelled = errors.New("command timed out")

// WaitWithCancel waits for the command to complete and returns the result.
// If cancel is closed before the command finishes, the command is killed
// and Cancelled is returned together with whatever output was captured.
func (cmd *RunningCmd) WaitWithCancel(cancel <-chan struct{}) (utilexec.ExecResponse, error) {
	var result utilexec.ExecResponse

	done := make(chan error, 1)
	go func() {
		defer close(done)
		waitResult, err := cmd.Wait()
		result = waitResult
		done <- err
	}()

	select {
	case err := <-done:
		return result, errors.Trace(err)
	case <-cancel:
		logger.Infof("killing the command due to cancellation")
		cmd.SSHCmd.Kill()

		<-done            // Ensure that the original cmd.Wait() call completed.
		cmd.SSHCmd.Wait() // Finalize cmd.SSHCmd, if necessary.
		return result, Cancelled
	}
}

// getExitCode converts the error returned by Cmd.Wait into an exit code.
// A nil error is exit code 0. An *exec.ExitError whose process exited
// normally yields that exit status with a nil error. Anything else yields
// -1 and the (cause of the) original error.
func getExitCode(err error) (int, error) {
	if err == nil {
		return 0, nil
	}
	err = errors.Cause(err)
	if ee, ok := err.(*exec.ExitError); ok {
		raw := ee.ProcessState.Sys()
		status, ok := raw.(syscall.WaitStatus)
		if !ok {
			logger.Errorf("unexpected type %T from ProcessState.Sys()", raw)
		} else if status.Exited() {
			// A non-zero return code isn't considered an error here.
			return status.ExitStatus(), nil
		}
	}
	return -1, err
}

// ExecuteCommandOnMachine will execute the command passed through on
// the host specified. This is done using ssh, and passing the commands
// through /bin/bash. If the command is not finished within the timeout
// specified, an error is returned. Any output captured during that time
// is also returned in the remote response.
func ExecuteCommandOnMachine(args ExecParams) (utilexec.ExecResponse, error) {
	var result utilexec.ExecResponse
	cmd, err := StartCommandOnMachine(args)
	if err != nil {
		return result, errors.Trace(err)
	}

	// The timeout goroutine must not outlive this call. finished is closed
	// on return, so the goroutine exits as soon as the command completes
	// rather than lingering until the timeout elapses.
	cancel := make(chan struct{})
	finished := make(chan struct{})
	defer close(finished)
	go func() {
		select {
		case <-clock.WallClock.After(args.Timeout):
			close(cancel)
		case <-finished:
		}
	}()
	result, err = cmd.WaitWithCancel(cancel)
	if err != nil {
		return result, errors.Trace(err)
	}

	return result, nil
}

================================================
FILE: ssh/run_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test

import (
	"io/ioutil"
	"os"
	"path/filepath"
	"runtime"
	"time"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/ssh"
)

const (
	shortWait = 50 * time.Millisecond
	longWait  = 10 * time.Second
)

type ExecuteSSHCommandSuite struct {
	testing.IsolationSuite
	originalPath string
	testbin      string
	fakessh      string
}

var _ = gc.Suite(&ExecuteSSHCommandSuite{})

func (s *ExecuteSSHCommandSuite) SetUpSuite(c *gc.C) {
	s.originalPath = os.Getenv("PATH")
	s.IsolationSuite.SetUpSuite(c)
}

func (s *ExecuteSSHCommandSuite) SetUpTest(c *gc.C) {
	if runtime.GOOS == "windows" {
		c.Skip("issue 1403084: Tests use OpenSSH only")
	}
	s.IsolationSuite.SetUpTest(c)
	err := os.Setenv("PATH", s.originalPath)
	c.Assert(err, jc.ErrorIsNil)
	s.testbin = c.MkDir()
	s.fakessh = filepath.Join(s.testbin, "ssh")
	s.PatchEnvPathPrepend(s.testbin)
}

// fakeSSH writes cmd as an executable "ssh" script that shadows the real
// client via the PATH prepended in SetUpTest.
func (s *ExecuteSSHCommandSuite) fakeSSH(c *gc.C, cmd string) {
	err := ioutil.WriteFile(s.fakessh, []byte(cmd), 0755)
	c.Assert(err, jc.ErrorIsNil)
}

func (s *ExecuteSSHCommandSuite) TestCaptureOutput(c *gc.C) {
	s.fakeSSH(c, echoSSH)

	response, err := ssh.ExecuteCommandOnMachine(ssh.ExecParams{
		Host:    "hostname",
		Command: "sudo apt-get update\nsudo apt-get upgrade",
		Timeout: longWait,
	})

	c.Assert(err, jc.ErrorIsNil)
	c.Assert(response.Code, gc.Equals, 0)
	c.Assert(string(response.Stdout), gc.Equals, "sudo apt-get update\nsudo apt-get upgrade\n")
	c.Assert(string(response.Stderr), gc.Equals,
		"-o PasswordAuthentication no -o ServerAliveInterval 30 hostname /bin/bash -s\n")
}

func (s *ExecuteSSHCommandSuite) TestIdentityFile(c *gc.C) {
	s.fakeSSH(c, echoSSH)

	response, err := ssh.ExecuteCommandOnMachine(ssh.ExecParams{
		IdentityFile: "identity-file",
		Host:         "hostname",
		Timeout:      longWait,
	})

	c.Assert(err, jc.ErrorIsNil)
	c.Assert(string(response.Stderr), jc.Contains, " -i identity-file ")
}

func (s *ExecuteSSHCommandSuite) TestTimoutCaptureOutput(c *gc.C) {
	s.fakeSSH(c, slowSSH)

	response, err := ssh.ExecuteCommandOnMachine(ssh.ExecParams{
		IdentityFile: "identity-file",
		Host:         "hostname",
		Command:      "ignored",
		Timeout:      shortWait,
	})

	c.Check(err, gc.ErrorMatches, "command timed out")
	c.Assert(response.Code, gc.Equals, 0)
	c.Assert(string(response.Stdout), gc.Equals, "stdout\n")
	c.Assert(string(response.Stderr), gc.Equals, "stderr\n")
}

func (s *ExecuteSSHCommandSuite) TestCapturesReturnCode(c *gc.C) {
	s.fakeSSH(c, passthroughSSH)

	response, err := ssh.ExecuteCommandOnMachine(ssh.ExecParams{
		IdentityFile: "identity-file",
		Host:         "hostname",
		Command:      "echo stdout; exit 42",
		Timeout:      longWait,
	})

	c.Check(err, jc.ErrorIsNil)
	c.Assert(response.Code, gc.Equals, 42)
	c.Assert(string(response.Stdout), gc.Equals, "stdout\n")
	c.Assert(string(response.Stderr), gc.Equals, "")
}

// echoSSH outputs the command args to stderr, and copies stdin to stdout
var echoSSH = `#!/bin/bash
# Write the args to stderr
echo "$*" >&2
cat /dev/stdin
`

// slowSSH sleeps for a while after outputting some text to stdout and stderr
var slowSSH = `#!/bin/bash
echo "stderr" >&2
echo "stdout"
sleep 5s
`

// passthroughSSH creates an ssh that executes stdin.
var passthroughSSH = `#!/bin/bash -s`

================================================
FILE: ssh/ssh.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

// Package ssh contains utilities for dealing with SSH connections,
// key management, and so on. All SSH-based command executions in
// Juju should use the Command/ScpCommand functions in this package.
package ssh

import (
	"bytes"
	"io"
	"os/exec"
	"syscall"

	"github.com/juju/errors"

	"github.com/juju/utils/v4"
)

// StrictHostChecksOption defines the possible values taken by
// Options.SetStrictHostKeyChecking().
type StrictHostChecksOption int

const (
	// StrictHostChecksDefault configures the default,
	// implementation-specific, behaviour.
	//
	// For the OpenSSH implementation, this elides the
	// StrictHostKeyChecking option, which means the
	// user's personal configuration will be used.
	//
	// For the go.crypto implementation, the default is
	// the equivalent of "ask".
	StrictHostChecksDefault StrictHostChecksOption = iota

	// StrictHostChecksNo disables strict host key checking.
	StrictHostChecksNo

	// StrictHostChecksYes enables strict host key checking.
	// Target hosts must appear in the known_hosts file or
	// connections will fail.
	StrictHostChecksYes

	// StrictHostChecksAsk will cause openssh to ask the user about
	// hosts that don't appear in the known_hosts file.
	StrictHostChecksAsk
)

// Options is a client-implementation independent SSH options set.
type Options struct {
	// proxyCommand specifies the command to
	// execute to proxy SSH traffic through.
	proxyCommand []string
	// ssh server port; zero means use the default (22)
	port int
	// no PTY forced by default
	allocatePTY bool
	// password authentication is disallowed by default
	passwordAuthAllowed bool
	// identities is a sequence of paths to private key/identity files
	// to use when attempting to login. A client implementation may attempt
	// with additional identities, but must give preference to these.
	identities []string
	// knownHostsFile is a path to a file in which to save the host's
	// fingerprint.
	knownHostsFile string
	// strictHostKeyChecking sets that the host being connected to must
	// exist in the known_hosts file, and with a matching public key.
	strictHostKeyChecking StrictHostChecksOption
	// hostKeyAlgorithms sets the host key types that the client will
	// accept from the server, in order of preference. By default the
	// client implementation will specify a set of reasonable types.
	hostKeyAlgorithms []string
}

// SetProxyCommand sets a command to execute to proxy traffic through.
// The arguments are copied, so later changes to the caller's slice do
// not affect the options.
func (o *Options) SetProxyCommand(command ...string) {
	o.proxyCommand = append([]string{}, command...)
}

// SetPort sets the SSH server port to connect to.
func (o *Options) SetPort(port int) {
	o.port = port
}

// EnablePTY forces the allocation of a pseudo-TTY.
//
// Forcing a pseudo-TTY is required, for example, for sudo
// prompts on the target host.
func (o *Options) EnablePTY() {
	o.allocatePTY = true
}

// SetKnownHostsFile sets the host's fingerprint to be saved in the given file.
//
// Host fingerprints are saved in ~/.ssh/known_hosts by default.
func (o *Options) SetKnownHostsFile(file string) {
	o.knownHostsFile = file
}

// SetStrictHostKeyChecking sets the desired host key checking
// behaviour. It takes one of the StrictHostChecksOption constants.
func (o *Options) SetStrictHostKeyChecking(value StrictHostChecksOption) {
	o.strictHostKeyChecking = value
}

// AllowPasswordAuthentication allows the SSH
// client to prompt the user for a password.
//
// Password authentication is disallowed by default.
func (o *Options) AllowPasswordAuthentication() {
	o.passwordAuthAllowed = true
}

// SetIdentities sets a sequence of paths to private key/identity files
// to use when attempting login. Client implementations may attempt to
// use additional identities, but must give preference to the ones
// specified here.
func (o *Options) SetIdentities(identityFiles ...string) {
	o.identities = append([]string{}, identityFiles...)
}

// SetHostKeyAlgorithms sets the host key types that the client will
// accept from the server, in order of preference. If not specified,
// the client implementation may choose its own defaults.
//
// The arguments are copied, as SetIdentities and SetProxyCommand do, so
// later changes to the caller's slice do not leak into the options.
// Calling it with no arguments leaves the list nil.
func (o *Options) SetHostKeyAlgorithms(algos ...string) {
	o.hostKeyAlgorithms = append([]string(nil), algos...)
}

// Client is an interface for SSH clients to implement
type Client interface {
	// Command returns a Command for executing a command
	// on the specified host. Each Command is executed
	// within its own SSH session.
	//
	// Host is specified in the format [user@]host.
	Command(host string, command []string, options *Options) *Cmd

	// Copy copies file(s) between local and remote host(s).
	// Paths are specified in the scp format, [[user@]host:]path. If
	// any extra arguments are specified in extraArgs, they are passed
	// verbatim.
	Copy(args []string, options *Options) error
}

// Cmd represents a command to be (or being) executed
// on a remote host.
type Cmd struct {
	Stdin  io.Reader
	Stdout io.Writer
	Stderr io.Writer
	impl   command
}

func newCmd(impl command) *Cmd {
	return &Cmd{impl: impl}
}

// CombinedOutput runs the command, and returns the
// combined stdout/stderr output and result of
// executing the command.
func (c *Cmd) CombinedOutput() ([]byte, error) {
	if c.Stdout != nil {
		return nil, errors.New("ssh: Stdout already set")
	}
	if c.Stderr != nil {
		return nil, errors.New("ssh: Stderr already set")
	}
	var b bytes.Buffer
	c.Stdout = &b
	c.Stderr = &b
	err := c.Run()
	return b.Bytes(), err
}

// Output runs the command, and returns the stdout
// output and result of executing the command.
func (c *Cmd) Output() ([]byte, error) {
	if c.Stdout != nil {
		return nil, errors.New("ssh: Stdout already set")
	}
	var b bytes.Buffer
	c.Stdout = &b
	err := c.Run()
	return b.Bytes(), err
}

// Run runs the command, and returns the result as an error. When Wait
// reports an *exec.ExitError for a process that exited normally, the
// exit status is returned via utils.NewRcPassthroughError instead.
func (c *Cmd) Run() error {
	if err := c.Start(); err != nil {
		return err
	}
	err := c.Wait()
	if exitError, ok := err.(*exec.ExitError); ok && exitError != nil {
		// ProcessState.Sys() is platform specific, so use a checked
		// assertion (as getExitCode does) rather than risking a panic.
		if status, ok := exitError.ProcessState.Sys().(syscall.WaitStatus); ok && status.Exited() {
			return utils.NewRcPassthroughError(status.ExitStatus())
		}
	}
	return err
}

// Start starts the command running, but does not wait for
// it to complete. If the command could not be started, an
// error is returned.
func (c *Cmd) Start() error {
	c.impl.SetStdio(c.Stdin, c.Stdout, c.Stderr)
	return c.impl.Start()
}

// Wait waits for the started command to complete,
// and returns the result as an error.
func (c *Cmd) Wait() error {
	return c.impl.Wait()
}

// Kill kills the started command.
func (c *Cmd) Kill() error {
	return c.impl.Kill()
}

// StdinPipe creates a pipe and connects it to the command's stdin.
// On success the read end becomes c.Stdin and the write end is
// returned to the caller; on failure c.Stdin is left untouched.
func (c *Cmd) StdinPipe() (io.WriteCloser, error) {
	w, rd, err := c.impl.StdinPipe()
	if err == nil {
		c.Stdin = rd
		return w, nil
	}
	return nil, err
}

// StdoutPipe creates a pipe and connects it to the command's stdout.
// On success the write end becomes c.Stdout and the read end is
// returned to the caller; on failure c.Stdout is left untouched.
func (c *Cmd) StdoutPipe() (io.ReadCloser, error) {
	rd, w, err := c.impl.StdoutPipe()
	if err == nil {
		c.Stdout = w
		return rd, nil
	}
	return nil, err
}

// StderrPipe creates a pipe and connects it to the command's stderr.
// On success the write end becomes c.Stderr and the read end is
// returned to the caller; on failure c.Stderr is left untouched.
func (c *Cmd) StderrPipe() (io.ReadCloser, error) {
	rd, w, err := c.impl.StderrPipe()
	if err == nil {
		c.Stderr = w
		return rd, nil
	}
	return nil, err
}

// command is the implementation-specific half of a Cmd: a command
// prepared to execute against one particular host.
type command interface {
	Start() error
	Wait() error
	Kill() error
	SetStdio(stdin io.Reader, stdout, stderr io.Writer)
	StdinPipe() (io.WriteCloser, io.Reader, error)
	StdoutPipe() (io.ReadCloser, io.Writer, error)
	StderrPipe() (io.ReadCloser, io.Writer, error)
}

// DefaultClient is the default SSH client for the process.
//
// If the OpenSSH client is found in $PATH, then it will be
// used for DefaultClient; otherwise, DefaultClient will use
// an embedded client based on go.crypto/ssh.
var DefaultClient Client

// chosenClient names the kind of client stored in DefaultClient,
// purely so Command, Copy and CopyReader can log it.
var chosenClient string

func init() {
	initDefaultClient()
}

// initDefaultClient prefers OpenSSH and falls back to the embedded
// go.crypto client when OpenSSH cannot be set up.
func initDefaultClient() {
	if openSSH, err := NewOpenSSHClient(); err == nil {
		DefaultClient, chosenClient = openSSH, "OpenSSH"
		return
	}
	if goCrypto, err := NewGoCryptoClient(); err == nil {
		DefaultClient, chosenClient = goCrypto, "go.crypto (embedded)"
	}
}

// Command is a short-cut for DefaultClient.Command.
func Command(host string, command []string, options *Options) *Cmd { logger.Debugf("using %s ssh client", chosenClient) return DefaultClient.Command(host, command, options) } // Copy is a short-cut for DefaultClient.Copy. func Copy(args []string, options *Options) error { logger.Debugf("using %s ssh client", chosenClient) return DefaultClient.Copy(args, options) } // CopyReader sends the reader's data to a file on the remote host over SSH. func CopyReader(host, filename string, r io.Reader, options *Options) error { logger.Debugf("using %s ssh client", chosenClient) return copyReader(DefaultClient, host, filename, r, options) } func copyReader(client Client, host, filename string, r io.Reader, options *Options) error { cmd := client.Command(host, []string{"cat - > " + filename}, options) cmd.Stdin = r return errors.Trace(cmd.Run()) } ================================================ FILE: ssh/ssh_gocrypto.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh import ( "bytes" "fmt" "io" "io/ioutil" "net" "os" "os/exec" "os/user" "strconv" "strings" "sync" "time" "github.com/juju/clock" "github.com/juju/errors" "github.com/juju/mutex/v2" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" "golang.org/x/crypto/ssh/terminal" "github.com/juju/utils/v4" ) const sshDefaultPort = 22 // GoCryptoClient is an implementation of Client that // uses the embedded go.crypto/ssh SSH client. // // GoCryptoClient is intentionally limited in the // functionality that it enables, as it is currently // intended to be used only for non-interactive command // execution. type GoCryptoClient struct { signers []ssh.Signer } // NewGoCryptoClient creates a new GoCryptoClient. // // If no signers are specified, NewGoCryptoClient will // use the private key generated by LoadClientKeys. 
func NewGoCryptoClient(signers ...ssh.Signer) (*GoCryptoClient, error) { return &GoCryptoClient{signers: signers}, nil } // Command implements Client.Command. func (c *GoCryptoClient) Command(host string, command []string, options *Options) *Cmd { shellCommand := utils.CommandString(command...) signers := c.signers if len(signers) == 0 { signers = privateKeys() } user, host := splitUserHost(host) port := sshDefaultPort var proxyCommand []string var knownHostsFile string var strictHostKeyChecking StrictHostChecksOption var hostKeyAlgorithms []string if options != nil { if options.port != 0 { port = options.port } proxyCommand = options.proxyCommand knownHostsFile = options.knownHostsFile strictHostKeyChecking = options.strictHostKeyChecking hostKeyAlgorithms = options.hostKeyAlgorithms } logger.Tracef(`running (equivalent of): ssh "%s@%s" -p %d '%s'`, user, host, port, shellCommand) return &Cmd{impl: &goCryptoCommand{ signers: signers, user: user, addr: net.JoinHostPort(host, strconv.Itoa(port)), command: shellCommand, proxyCommand: proxyCommand, knownHostsFile: knownHostsFile, strictHostKeyChecking: strictHostKeyChecking, hostKeyAlgorithms: hostKeyAlgorithms, }} } // Copy implements Client.Copy. // // Copy is currently unimplemented, and will always return an error. func (c *GoCryptoClient) Copy(args []string, options *Options) error { return errors.Errorf("scp command is not implemented (OpenSSH scp not available in PATH)") } type goCryptoCommand struct { signers []ssh.Signer user string addr string command string proxyCommand []string knownHostsFile string strictHostKeyChecking StrictHostChecksOption hostKeyAlgorithms []string stdin io.Reader stdout io.Writer stderr io.Writer client *ssh.Client sess *ssh.Session } var sshDial = ssh.Dial var sshDialWithProxy = func(addr string, proxyCommand []string, config *ssh.ClientConfig) (*ssh.Client, error) { if len(proxyCommand) == 0 { return sshDial("tcp", addr, config) } // User has specified a proxy. 
Create a pipe and // redirect the proxy command's stdin/stdout to it. host, port, err := net.SplitHostPort(addr) if err != nil { host = addr } for i, arg := range proxyCommand { arg = strings.Replace(arg, "%h", host, -1) if port != "" { arg = strings.Replace(arg, "%p", port, -1) } arg = strings.Replace(arg, "%r", config.User, -1) proxyCommand[i] = arg } client, server := net.Pipe() logger.Tracef(`executing proxy command %q`, proxyCommand) cmd := exec.Command(proxyCommand[0], proxyCommand[1:]...) cmd.Stdin = server cmd.Stdout = server cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { return nil, err } conn, chans, reqs, err := ssh.NewClientConn(client, addr, config) if err != nil { return nil, err } return ssh.NewClient(conn, chans, reqs), nil } func (c *goCryptoCommand) ensureSession() (*ssh.Session, error) { if c.sess != nil { return c.sess, nil } if len(c.signers) == 0 { return nil, errors.Errorf("no private keys available") } if c.user == "" { currentUser, err := user.Current() if err != nil { return nil, errors.Errorf("getting current user: %v", err) } c.user = currentUser.Username } config := &ssh.ClientConfig{ User: c.user, HostKeyCallback: c.hostKeyCallback, HostKeyAlgorithms: c.hostKeyAlgorithms, Auth: []ssh.AuthMethod{ ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { return c.signers, nil }), }, } client, err := sshDialWithProxy(c.addr, c.proxyCommand, config) if err != nil { return nil, err } sess, err := client.NewSession() if err != nil { client.Close() return nil, err } c.client = client c.sess = sess c.sess.Stdin = WrapStdin(c.stdin) c.sess.Stdout = c.stdout c.sess.Stderr = c.stderr return sess, nil } func (c *goCryptoCommand) Start() error { sess, err := c.ensureSession() if err != nil { return err } if c.command == "" { return sess.Shell() } return sess.Start(c.command) } func (c *goCryptoCommand) Close() error { if c.sess == nil { return nil } err0 := c.sess.Close() err1 := c.client.Close() if err0 == nil { err0 = err1 } c.sess = nil 
c.client = nil return err0 } func (c *goCryptoCommand) Wait() error { if c.sess == nil { return errors.Errorf("command has not been started") } err := c.sess.Wait() c.Close() return err } func (c *goCryptoCommand) Kill() error { if c.sess == nil { return errors.Errorf("command has not been started") } return c.sess.Signal(ssh.SIGKILL) } func (c *goCryptoCommand) SetStdio(stdin io.Reader, stdout, stderr io.Writer) { c.stdin = stdin c.stdout = stdout c.stderr = stderr } func (c *goCryptoCommand) StdinPipe() (io.WriteCloser, io.Reader, error) { sess, err := c.ensureSession() if err != nil { return nil, nil, err } wc, err := sess.StdinPipe() return wc, sess.Stdin, err } func (c *goCryptoCommand) StdoutPipe() (io.ReadCloser, io.Writer, error) { sess, err := c.ensureSession() if err != nil { return nil, nil, err } wc, err := sess.StdoutPipe() return ioutil.NopCloser(wc), sess.Stdout, err } func (c *goCryptoCommand) StderrPipe() (io.ReadCloser, io.Writer, error) { sess, err := c.ensureSession() if err != nil { return nil, nil, err } wc, err := sess.StderrPipe() return ioutil.NopCloser(wc), sess.Stderr, err } func (c *goCryptoCommand) hostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error { knownHostsFile := c.knownHostsFile if knownHostsFile == "" { knownHostsFile = GoCryptoKnownHostsFile() if knownHostsFile == "" { return errors.New("known_hosts file not configured") } } var printError func(string) error term, cleanupTerm, err := getTerminal() if err != nil { return errors.Trace(err) } else if term != nil { defer cleanupTerm() printError = func(message string) error { _, err := fmt.Fprintln(term, message) return err } } else { printError = func(message string) error { logger.Errorf("%s", message) return nil } } matched, err := checkHostKey(hostname, remote, key, knownHostsFile, printError) if err != nil || matched { return errors.Trace(err) } // We did not find a matching key, so what we do next depends on the // strict host key checking 
configuration. var warnAdd bool switch c.strictHostKeyChecking { case StrictHostChecksNo: // Don't ask, just add. warnAdd = true case StrictHostChecksDefault, StrictHostChecksAsk: message := fmt.Sprintf(`The authenticity of host '%s (%s)' can't be established. %s key fingerprint is %s. `, hostname, remote, key.Type(), ssh.FingerprintSHA256(key), ) if term == nil { // If we're not running in a terminal, // we can't ask the user if they want // to accept. logger.Errorf("%s", message) return errors.New("not running in a terminal, cannot prompt for verification") } // Prompt user, asking if they trust the key. fmt.Fprint(term, message+"Are you sure you want to continue connecting (yes/no)? ") for { line, err := term.ReadLine() if err != nil { return errors.Trace(err) } var yes bool switch strings.ToLower(line) { case "yes": yes = true case "no": return errors.New("Host key verification failed.") default: fmt.Fprint(term, "Please type 'yes' or 'no': ") } if yes { break } } default: return errors.Errorf( `no %s host key is known for %s and you have requested strict checking`, key.Type(), hostname, ) } if knownHostsFile != os.DevNull { // Make sure no other process modifies the file. releaser, err := mutex.Acquire(mutex.Spec{ Name: "juju-ssh-client", Clock: clock.WallClock, Delay: time.Second, }) if err != nil { return errors.Trace(err) } defer releaser.Release() // Write the file atomically, so the initial ReadAll above // doesn't have to hold the mutex. 
knownHostsData, err := ioutil.ReadFile(knownHostsFile) if err != nil && !os.IsNotExist(err) { return errors.Trace(err) } buf := bytes.NewBuffer(knownHostsData) if len(knownHostsData) > 0 && !bytes.HasSuffix(knownHostsData, []byte("\n")) { buf.WriteRune('\n') } buf.WriteString(knownhosts.Line([]string{hostname}, key)) buf.WriteRune('\n') if err := utils.AtomicWriteFile(knownHostsFile, buf.Bytes(), 0600); err != nil { return errors.Trace(err) } } if warnAdd { printError(fmt.Sprintf( "Warning: permanently added '%s' (%s) to the list of known hosts.", hostname, key.Type(), )) } return nil } type readLineWriter interface { io.Writer ReadLine() (string, error) } var getTerminal = func() (readLineWriter, func(), error) { if fd := int(os.Stdin.Fd()); terminal.IsTerminal(fd) { oldState, err := terminal.MakeRaw(fd) if err != nil { return nil, nil, errors.Trace(err) } cleanup := func() { terminal.Restore(fd, oldState) } return terminal.NewTerminal(os.Stdin, ""), cleanup, nil } return nil, nil, nil } // checkHostKey checks the given (hostname, address, public key) tuple // against the local known-hosts database, if it exists, and returns a // boolean indicating whether a match was found, and any errors encountered. func checkHostKey( hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string, printError func(string) error, ) (bool, error) { // NOTE(axw) the knownhosts code is incomplete, but enough for // our limited use cases. We do not support parsing a known_hosts // file managed by OpenSSH (due to hashed hosts, etc.), but that // is OK since this client exists only to support systems that // do not have access to OpenSSH. callback, err := knownhosts.New(knownHostsFile) if err != nil { if os.IsNotExist(err) { // The known_hosts file does not exist. return false, nil } return false, errors.Trace(err) } err = callback(hostname, remote, key) switch err := err.(type) { case nil: // Known host with matching key. 
return true, nil case *knownhosts.KeyError: if len(err.Want) == 0 { // Unknown host. return false, nil } head := fmt.Sprintf(` @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @ @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now (man-in-the-middle attack)! It is also possible that a host key has just been changed. The fingerprint for the %s key sent by the remote host is %s. Please contact your system administrator. Add correct host key in %s to get rid of this message. `[1:], key.Type(), ssh.FingerprintSHA256(key), knownHostsFile) var typeKey *knownhosts.KnownKey for i, knownKey := range err.Want { if knownKey.Key.Type() == key.Type() { typeKey = &err.Want[i] } } var tail string if typeKey != nil { tail = fmt.Sprintf( "Offending %s key in %s:%d", typeKey.Key.Type(), typeKey.Filename, typeKey.Line, ) } else { tail = "Host was previously using different host key algorithms:" for _, knownKey := range err.Want { tail += fmt.Sprintf( "\n - %s key in %s:%d", knownKey.Key.Type(), knownKey.Filename, knownKey.Line, ) } } if err := printError(head + tail); err != nil { // Not being able to display the warning // should be considered fatal. return false, errors.Annotate( err, "failed to print host key mismatch warning", ) } } return false, errors.Trace(err) } func splitUserHost(s string) (user, host string) { userHost := strings.SplitN(s, "@", 2) if len(userHost) == 2 { return userHost[0], userHost[1] } return "", userHost[0] } var ( goCryptoKnownHostsMutex sync.Mutex goCryptoKnownHostsFile string ) // GoCryptoKnownHostsFile returns the known_hosts file used // by the golang.org/x/crypto/ssh-based client by default. 
func GoCryptoKnownHostsFile() string { goCryptoKnownHostsMutex.Lock() defer goCryptoKnownHostsMutex.Unlock() return goCryptoKnownHostsFile } // SetGoCryptoKnownHostsFile returns the known_hosts file used // by the golang.org/x/crypto/ssh-based client. func SetGoCryptoKnownHostsFile(file string) { goCryptoKnownHostsMutex.Lock() defer goCryptoKnownHostsMutex.Unlock() goCryptoKnownHostsFile = file } ================================================ FILE: ssh/ssh_gocrypto_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh_test import ( "bytes" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "io/ioutil" "net" "os" "os/exec" "path/filepath" "regexp" "sync" "time" "github.com/juju/testing" jc "github.com/juju/testing/checkers" cryptossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/testdata" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/ssh" ) var ( testCommand = []string{"echo", "$abc"} testCommandFlat = `echo "\$abc"` ) type sshServer struct { cfg *cryptossh.ServerConfig listener net.Listener client *cryptossh.Client } func (s *sshServer) run(errorCh chan error, done chan bool) { netconn, err := s.listener.Accept() if err != nil { errorCh <- fmt.Errorf("accepting connection: %w", err) return } defer netconn.Close() conn, chans, reqs, err := cryptossh.NewServerConn(netconn, s.cfg) if err != nil { errorCh <- fmt.Errorf("getting ssh server connection: %w", err) return } s.client = cryptossh.NewClient(conn, chans, reqs) var wg sync.WaitGroup defer func() { wg.Wait() close(errorCh) }() sessionChannels := s.client.HandleChannelOpen("session") select { case <-done: return case newChannel := <-sessionChannels: if sCh := newChannel.ChannelType(); sCh != "session" { errorCh <- fmt.Errorf("unexpected session channel %q", sCh) return } channel, reqs, err := newChannel.Accept() if err != nil { errorCh <- fmt.Errorf("accepting session connection: %w", err) return 
} wg.Add(1) go func() { defer wg.Done() defer channel.Close() for req := range reqs { switch req.Type { case "exec": if !req.WantReply { errorCh <- fmt.Errorf("no reply wanted for request %+v", req) return } n := binary.BigEndian.Uint32(req.Payload[:4]) command := string(req.Payload[4 : n+4]) if command != testCommandFlat { errorCh <- fmt.Errorf("unexpected request command: %q", command) return } err = req.Reply(true, nil) if err != nil { errorCh <- fmt.Errorf("error sending reply: %w", err) return } channel.Write([]byte("abc value\n")) _, err := channel.SendRequest("exit-status", false, cryptossh.Marshal(&struct{ n uint32 }{0})) if err != nil { errorCh <- fmt.Errorf("error sending request: %w", err) } return default: errorCh <- fmt.Errorf("unexpected request type: %q", req.Type) return } } }() } } func newClient(c *gc.C) (*ssh.GoCryptoClient, cryptossh.PublicKey) { private, _, err := ssh.GenerateKey("test-client") c.Assert(err, jc.ErrorIsNil) key, err := cryptossh.ParsePrivateKey([]byte(private)) c.Assert(err, jc.ErrorIsNil) client, err := ssh.NewGoCryptoClient(key) c.Assert(err, jc.ErrorIsNil) return client, key.PublicKey() } type SSHGoCryptoCommandSuite struct { testing.IsolationSuite client ssh.Client knownHostsFile string testPrivateKeys map[string]any testSigners map[string]cryptossh.Signer testPublicKeys map[string]cryptossh.PublicKey } var _ = gc.Suite(&SSHGoCryptoCommandSuite{}) func (s *SSHGoCryptoCommandSuite) SetUpSuite(c *gc.C) { s.IsolationSuite.SetUpSuite(c) var err error n := len(testdata.PEMBytes) s.testPrivateKeys = make(map[string]any, n) s.testSigners = make(map[string]cryptossh.Signer, n) s.testPublicKeys = make(map[string]cryptossh.PublicKey, n) for t, k := range testdata.PEMBytes { s.testPrivateKeys[t], err = cryptossh.ParseRawPrivateKey(k) c.Assert(err, jc.ErrorIsNil) s.testSigners[t], err = cryptossh.NewSignerFromKey(s.testPrivateKeys[t]) c.Assert(err, jc.ErrorIsNil) s.testPublicKeys[t] = s.testSigners[t].PublicKey() } // Create a cert and 
sign it for use in tests. testCert := &cryptossh.Certificate{ Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage ValidAfter: 0, // unix epoch ValidBefore: cryptossh.CertTimeInfinity, // The end of currently representable time. Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil Key: s.testPublicKeys["ecdsa"], SignatureKey: s.testPublicKeys["ed25519"], Permissions: cryptossh.Permissions{ CriticalOptions: map[string]string{}, Extensions: map[string]string{}, }, } err = testCert.SignCert(rand.Reader, s.testSigners["ed25519"]) c.Assert(err, jc.ErrorIsNil) s.testPrivateKeys["cert"] = s.testPrivateKeys["ecdsa"] s.testSigners["cert"], err = cryptossh.NewCertSigner(testCert, s.testSigners["ecdsa"]) c.Assert(err, jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) generateKeyRestorer := overrideGenerateKey() s.AddCleanup(func(*gc.C) { generateKeyRestorer.Restore() }) s.knownHostsFile = filepath.Join(c.MkDir(), "known_hosts") ssh.SetGoCryptoKnownHostsFile(s.knownHostsFile) ssh.PatchNilTerminal(&s.CleanupSuite) } func (s *SSHGoCryptoCommandSuite) newServer(c *gc.C, serverConfig cryptossh.ServerConfig) (*sshServer, cryptossh.PublicKey) { server := &sshServer{cfg: &serverConfig} server.cfg.AddHostKey(s.testSigners["ed25519"]) var err error server.listener, err = net.Listen("tcp", "127.0.0.1:0") c.Assert(err, jc.ErrorIsNil) c.Logf("Server listening on %s", server.listener.Addr().String()) return server, s.testPublicKeys["ed25519"] } func (s *SSHGoCryptoCommandSuite) TestNewGoCryptoClient(c *gc.C) { _, err := ssh.NewGoCryptoClient() c.Assert(err, jc.ErrorIsNil) private, _, err := ssh.GenerateKey("test-client") c.Assert(err, jc.ErrorIsNil) key, err := cryptossh.ParsePrivateKey([]byte(private)) c.Assert(err, jc.ErrorIsNil) _, err = ssh.NewGoCryptoClient(key) c.Assert(err, 
jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestClientNoKeys(c *gc.C) { client, err := ssh.NewGoCryptoClient() c.Assert(err, jc.ErrorIsNil) cmd := client.Command("0.1.2.3", []string{"echo", "123"}, nil) _, err = cmd.Output() c.Assert(err, gc.ErrorMatches, "no private keys available") defer ssh.ClearClientKeys() err = ssh.LoadClientKeys(c.MkDir()) c.Assert(err, jc.ErrorIsNil) s.PatchValue(ssh.SSHDial, func(network, address string, cfg *cryptossh.ClientConfig) (*cryptossh.Client, error) { return nil, errors.New("ssh.Dial failed") }) cmd = client.Command("0.1.2.3", []string{"echo", "123"}, nil) _, err = cmd.Output() // error message differs based on whether using cgo or not c.Assert(err, gc.ErrorMatches, "ssh.Dial failed") } func waitForServer(c *gc.C, errorCh chan error) error { select { case err, _ := <-errorCh: return err case <-time.After(testing.LongWait): c.Fatal("timed out waiting for ssh server") return nil } } func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) { client, clientKey := newClient(c) server, serverKey := s.newServer(c, cryptossh.ServerConfig{}) serverPort := server.listener.Addr().(*net.TCPAddr).Port var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksNo) cmd := client.Command("127.0.0.1", testCommand, &opts) checkedKey := false server.cfg.PublicKeyCallback = func(conn cryptossh.ConnMetadata, pubkey cryptossh.PublicKey) (*cryptossh.Permissions, error) { c.Check(pubkey, gc.DeepEquals, clientKey) checkedKey = true return nil, nil } errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) out, err := cmd.Output() c.Assert(err, jc.ErrorIsNil) c.Assert(string(out), gc.Equals, "abc value\n") c.Assert(checkedKey, jc.IsTrue) knownHosts, err := ioutil.ReadFile(s.knownHostsFile) c.Assert(err, jc.ErrorIsNil) c.Assert(string(knownHosts), gc.Equals, fmt.Sprintf( "[127.0.0.1]:%d %s", serverPort, cryptossh.MarshalAuthorizedKey(serverKey)), ) 
c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestCopy(c *gc.C) { client, err := ssh.NewGoCryptoClient() c.Assert(err, jc.ErrorIsNil) err = client.Copy([]string{"0.1.2.3:b", c.MkDir()}, nil) c.Assert(err, gc.ErrorMatches, `scp command is not implemented \(OpenSSH scp not available in PATH\)`) } func (s *SSHGoCryptoCommandSuite) TestProxyCommand(c *gc.C) { realNetcat, err := exec.LookPath("nc") if err != nil { c.Skip("skipping test, couldn't find netcat: %v") return } netcat := filepath.Join(c.MkDir(), "nc") err = ioutil.WriteFile(netcat, []byte("#!/bin/sh\necho $0 \"$@\" > $0.args && exec "+realNetcat+" \"$@\""), 0755) c.Assert(err, jc.ErrorIsNil) client, _ := newClient(c) server, _ := s.newServer(c, cryptossh.ServerConfig{}) var opts ssh.Options port := server.listener.Addr().(*net.TCPAddr).Port opts.SetProxyCommand(netcat, "-q0", "%h", "%p") opts.SetPort(port) cmd := client.Command("127.0.0.1", testCommand, &opts) server.cfg.PublicKeyCallback = func(_ cryptossh.ConnMetadata, pubkey cryptossh.PublicKey) (*cryptossh.Permissions, error) { return nil, nil } errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) out, err := cmd.Output() c.Assert(err, jc.ErrorIsNil) c.Assert(string(out), gc.Equals, "abc value\n") // Ensure the proxy command was executed with the appropriate arguments. 
data, err := ioutil.ReadFile(netcat + ".args") c.Assert(err, jc.ErrorIsNil) c.Assert(string(data), gc.Equals, fmt.Sprintf("%s -q0 127.0.0.1 %v\n", netcat, port)) c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksYes(c *gc.C) { server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksYes) client, _ := newClient(c) cmd := client.Command("127.0.0.1", testCommand, &opts) _, err := cmd.Output() c.Assert(err, gc.ErrorMatches, fmt.Sprintf( "ssh: handshake failed: no ssh-ed25519 host key is known for 127.0.0.1:%d and you have requested strict checking", serverPort, )) _, err = os.Stat(s.knownHostsFile) c.Assert(err, jc.Satisfies, os.IsNotExist) _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskNonTerminal(c *gc.C) { server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksAsk) client, _ := newClient(c) cmd := client.Command("127.0.0.1", testCommand, &opts) _, err := cmd.Output() c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: not running in a terminal, cannot prompt for verification") _, err = os.Stat(s.knownHostsFile) c.Assert(err, jc.Satisfies, os.IsNotExist) _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalYes(c *gc.C) { var readLineWriter mockReadLineWriter ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter) readLineWriter.addLine("") readLineWriter.addLine("yes") server, serverKey := 
s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksAsk) client, _ := newClient(c) cmd := client.Command("127.0.0.1", testCommand, &opts) _, err := cmd.Output() c.Assert(err, jc.ErrorIsNil) knownHosts, err := ioutil.ReadFile(s.knownHostsFile) c.Assert(err, jc.ErrorIsNil) c.Assert(string(knownHosts), gc.Equals, fmt.Sprintf( "[127.0.0.1]:%d %s", serverPort, cryptossh.MarshalAuthorizedKey(serverKey), )) c.Assert(readLineWriter.written.String(), gc.Equals, fmt.Sprintf(` The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be established. ssh-ed25519 key fingerprint is %[2]s. Are you sure you want to continue connecting (yes/no)? Please type 'yes' or 'no': `[1:], serverPort, cryptossh.FingerprintSHA256(serverKey))) c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalNo(c *gc.C) { var readLineWriter mockReadLineWriter ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter) readLineWriter.addLine("no") server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksAsk) client, _ := newClient(c) cmd := client.Command("127.0.0.1", testCommand, &opts) _, err := cmd.Output() c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: Host key verification failed.") _, err = os.Stat(s.knownHostsFile) c.Assert(err, jc.Satisfies, os.IsNotExist) c.Assert(readLineWriter.written.String(), gc.Equals, fmt.Sprintf(` The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be 
established. ssh-ed25519 key fingerprint is %[2]s. Are you sure you want to continue connecting (yes/no)? `[1:], serverPort, cryptossh.FingerprintSHA256(serverKey))) _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksNoMismatch(c *gc.C) { var readLineWriter mockReadLineWriter ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter) server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) // Write a mismatching key to the known_hosts file. Even with // StrictHostChecksNo, we should be verifying against an existing // host key. _, alternativeKey, err := generateED25519Key(rand.Reader) c.Assert(err, jc.ErrorIsNil) alternativePublicKey, err := cryptossh.NewPublicKey(alternativeKey.Public()) c.Assert(err, jc.ErrorIsNil) err = ioutil.WriteFile(s.knownHostsFile, []byte(fmt.Sprintf( "[127.0.0.1]:%d %s", serverPort, cryptossh.MarshalAuthorizedKey(alternativePublicKey), )), 0600) c.Assert(err, jc.ErrorIsNil) var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksNo) client, _ := newClient(c) cmd := client.Command("127.0.0.1", testCommand, &opts) _, err = cmd.Output() c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: knownhosts: key mismatch") c.Assert(readLineWriter.written.String(), gc.Matches, fmt.Sprintf(` @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @ @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now \(man-in-the-middle attack\)! It is also possible that a host key has just been changed. The fingerprint for the ssh-ed25519 key sent by the remote host is %s. Please contact your system administrator. 
Add correct host key in .*/known_hosts to get rid of this message. Offending ssh-ed25519 key in .*/known_hosts:1 `[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey)))) _ = waitForServer(c, errorCh) } func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksDifferentKeyTypes(c *gc.C) { var readLineWriter mockReadLineWriter ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter) server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true}) serverPort := server.listener.Addr().(*net.TCPAddr).Port errorCh := make(chan error, 1) done := make(chan bool) defer close(done) go server.run(errorCh, done) // Write a mismatching key to the known_hosts file with a different // key type. Even with StrictHostChecksNo, we should be verifying // against an existing host key. dsaKey, err := generateDSAKey(rand.Reader) c.Assert(err, jc.ErrorIsNil) alternativePublicKey, err := cryptossh.NewPublicKey(&dsaKey.PublicKey) c.Assert(err, jc.ErrorIsNil) err = ioutil.WriteFile(s.knownHostsFile, []byte(fmt.Sprintf( "[127.0.0.1]:%d %s", serverPort, cryptossh.MarshalAuthorizedKey(alternativePublicKey), )), 0600) c.Assert(err, jc.ErrorIsNil) var opts ssh.Options opts.SetPort(serverPort) opts.SetStrictHostKeyChecking(ssh.StrictHostChecksNo) client, _ := newClient(c) cmd := client.Command("127.0.0.1", testCommand, &opts) _, err = cmd.Output() c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: knownhosts: key mismatch") c.Assert(readLineWriter.written.String(), gc.Matches, fmt.Sprintf(` @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @ @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now \(man-in-the-middle attack\)! It is also possible that a host key has just been changed. The fingerprint for the ssh-ed25519 key sent by the remote host is %s. Please contact your system administrator. 
Add correct host key in .*/known_hosts to get rid of this message. Host was previously using different host key algorithms: - ssh-dss key in .*/known_hosts:1 `[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey)))) _ = waitForServer(c, errorCh) } type mockReadLineWriter struct { testing.Stub lines []string written bytes.Buffer } func (m *mockReadLineWriter) addLine(line string) { m.lines = append(m.lines, line) } func (m *mockReadLineWriter) ReadLine() (string, error) { m.MethodCall(m, "ReadLine") if len(m.lines) == 0 { return "", io.EOF } line := m.lines[0] m.lines = m.lines[1:] return line, nil } func (m *mockReadLineWriter) Write(data []byte) (int, error) { m.MethodCall(m, "Write", data) return m.written.Write(data) } ================================================ FILE: ssh/ssh_openssh.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh import ( "bytes" "fmt" "io" "os" "os/exec" "strings" "github.com/juju/errors" "github.com/juju/utils/v4" ) // default identities will not be attempted if // -i is specified and they are not explcitly // included. var defaultIdentities = []string{ "~/.ssh/identity", "~/.ssh/id_rsa", "~/.ssh/id_dsa", "~/.ssh/id_ecdsa", "~/.ssh/id_ed25519", } type opensshCommandKind int const ( sshKind opensshCommandKind = iota scpKind ) // sshpassWrap wraps the command/args with sshpass if it is found in $PATH // and the SSHPASS environment variable is set. Otherwise, the original // command/args are returned. func sshpassWrap(cmd string, args []string) (string, []string) { if os.Getenv("SSHPASS") != "" { if path, err := exec.LookPath("sshpass"); err == nil { return path, append([]string{"-e", cmd}, args...) } } return cmd, args } // OpenSSHClient is an implementation of Client that // uses the ssh and scp executables found in $PATH. type OpenSSHClient struct{} // NewOpenSSHClient creates a new OpenSSHClient. 
// If the ssh and scp programs cannot be found // in $PATH, then an error is returned. func NewOpenSSHClient() (*OpenSSHClient, error) { var c OpenSSHClient if _, err := exec.LookPath("ssh"); err != nil { return nil, err } if _, err := exec.LookPath("scp"); err != nil { return nil, err } return &c, nil } func opensshOptions(options *Options, commandKind opensshCommandKind) []string { if options == nil { options = &Options{} } var args []string var hostChecks string switch options.strictHostKeyChecking { case StrictHostChecksYes: hostChecks = "yes" case StrictHostChecksNo: hostChecks = "no" case StrictHostChecksAsk: hostChecks = "ask" default: // StrictHostChecksUnset and invalid values are handled the // same way (the option doesn't get included). } if hostChecks != "" { args = append(args, "-o", "StrictHostKeyChecking "+hostChecks) } if len(options.proxyCommand) > 0 { args = append(args, "-o", "ProxyCommand "+utils.CommandString(options.proxyCommand...)) } if !options.passwordAuthAllowed { args = append(args, "-o", "PasswordAuthentication no") } // We must set ServerAliveInterval or the server may // think we've become unresponsive on long running // command executions such as "apt-get upgrade". args = append(args, "-o", "ServerAliveInterval 30") if options.allocatePTY { args = append(args, "-t", "-t") // twice to force } if options.knownHostsFile != "" { args = append(args, "-o", "UserKnownHostsFile "+utils.CommandString(options.knownHostsFile)) } if len(options.hostKeyAlgorithms) > 0 { args = append(args, "-o", "HostKeyAlgorithms "+utils.CommandString(strings.Join(options.hostKeyAlgorithms, ","))) } identities := append([]string{}, options.identities...) if pk := PrivateKeyFiles(); len(pk) > 0 { // Add client keys as implicit identities identities = append(identities, pk...) } // If any identities are specified, the // default ones must be explicitly specified. 
if len(identities) > 0 { for _, identity := range defaultIdentities { path, err := utils.NormalizePath(identity) if err != nil { logger.Warningf("failed to normalize path %q: %v", identity, err) continue } if _, err := os.Stat(path); err == nil { identities = append(identities, path) } } } for _, identity := range identities { args = append(args, "-i", identity) } if options.port != 0 { port := fmt.Sprint(options.port) if commandKind == scpKind { // scp uses -P instead of -p (-p means preserve). args = append(args, "-P", port) } else { args = append(args, "-p", port) } } return args } // Command implements Client.Command. func (c *OpenSSHClient) Command(host string, command []string, options *Options) *Cmd { args := opensshOptions(options, sshKind) args = append(args, host) if len(command) > 0 { args = append(args, command...) } bin, args := sshpassWrap("ssh", args) logger.Tracef("running: %s %s", bin, utils.CommandString(args...)) return &Cmd{impl: &opensshCmd{exec.Command(bin, args...)}} } // Copy implements Client.Copy. func (c *OpenSSHClient) Copy(args []string, userOptions *Options) error { var options Options if userOptions != nil { options = *userOptions options.allocatePTY = false // doesn't make sense for scp } allArgs := opensshOptions(&options, scpKind) allArgs = append(allArgs, args...) bin, allArgs := sshpassWrap("scp", allArgs) cmd := exec.Command(bin, allArgs...) 
var stderr bytes.Buffer cmd.Stderr = &stderr logger.Tracef("running: %s %s", bin, utils.CommandString(args...)) if err := cmd.Run(); err != nil { stderr := strings.TrimSpace(stderr.String()) if len(stderr) > 0 { err = errors.Errorf("%v (%v)", err, stderr) } return err } return nil } type opensshCmd struct { *exec.Cmd } func (c *opensshCmd) SetStdio(stdin io.Reader, stdout, stderr io.Writer) { c.Stdin, c.Stdout, c.Stderr = stdin, stdout, stderr } func (c *opensshCmd) StdinPipe() (io.WriteCloser, io.Reader, error) { wc, err := c.Cmd.StdinPipe() if err != nil { return nil, nil, err } return wc, c.Stdin, nil } func (c *opensshCmd) StdoutPipe() (io.ReadCloser, io.Writer, error) { rc, err := c.Cmd.StdoutPipe() if err != nil { return nil, nil, err } return rc, c.Stdout, nil } func (c *opensshCmd) StderrPipe() (io.ReadCloser, io.Writer, error) { rc, err := c.Cmd.StderrPipe() if err != nil { return nil, nil, err } return rc, c.Stderr, nil } func (c *opensshCmd) Kill() error { if c.Process == nil { return errors.Errorf("process has not been started") } return c.Process.Kill() } ================================================ FILE: ssh/ssh_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
//go:build !windows // +build !windows package ssh_test import ( "bytes" "fmt" "io/ioutil" "os" "path/filepath" "strings" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" "github.com/juju/utils/v4/ssh" ) const ( echoCommand = "/bin/echo" echoScript = "#!/bin/sh\n" + echoCommand + " $0 \"$@\" | /usr/bin/tee $0.args" ) type SSHCommandSuite struct { testing.IsolationSuite originalPath string testbin string fakessh string fakescp string client ssh.Client } var _ = gc.Suite(&SSHCommandSuite{}) func (s *SSHCommandSuite) SetUpTest(c *gc.C) { s.IsolationSuite.SetUpTest(c) s.testbin = c.MkDir() s.fakessh = filepath.Join(s.testbin, "ssh") s.fakescp = filepath.Join(s.testbin, "scp") err := ioutil.WriteFile(s.fakessh, []byte(echoScript), 0755) c.Assert(err, jc.ErrorIsNil) err = ioutil.WriteFile(s.fakescp, []byte(echoScript), 0755) c.Assert(err, jc.ErrorIsNil) s.PatchEnvPathPrepend(s.testbin) s.client, err = ssh.NewOpenSSHClient() c.Assert(err, jc.ErrorIsNil) s.PatchValue(ssh.DefaultIdentities, nil) } func (s *SSHCommandSuite) command(args ...string) *ssh.Cmd { return s.commandOptions(args, nil) } func (s *SSHCommandSuite) commandOptions(args []string, opts *ssh.Options) *ssh.Cmd { return s.client.Command("localhost", args, opts) } func (s *SSHCommandSuite) assertCommandArgs(c *gc.C, cmd *ssh.Cmd, expected string) { out, err := cmd.Output() c.Assert(err, jc.ErrorIsNil) c.Assert(strings.TrimSpace(string(out)), gc.Equals, expected) } func (s *SSHCommandSuite) TestDefaultClient(c *gc.C) { ssh.InitDefaultClient() c.Assert(ssh.DefaultClient, gc.FitsTypeOf, &ssh.OpenSSHClient{}) s.PatchEnvironment("PATH", "") ssh.InitDefaultClient() c.Assert(ssh.DefaultClient, gc.FitsTypeOf, &ssh.GoCryptoClient{}) } func (s *SSHCommandSuite) TestCommandSSHPass(c *gc.C) { // First create a fake sshpass, but don't set $SSHPASS fakesshpass := filepath.Join(s.testbin, "sshpass") err := ioutil.WriteFile(fakesshpass, []byte(echoScript), 0755) 
s.assertCommandArgs(c, s.command(echoCommand, "123"), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123", s.fakessh, echoCommand), ) // Now set $SSHPASS. s.PatchEnvironment("SSHPASS", "anyoldthing") s.assertCommandArgs(c, s.command(echoCommand, "123"), fmt.Sprintf("%s -e ssh -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123", fakesshpass, echoCommand), ) // Finally, remove sshpass from $PATH. err = os.Remove(fakesshpass) c.Assert(err, jc.ErrorIsNil) s.assertCommandArgs(c, s.command(echoCommand, "123"), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestCommand(c *gc.C) { s.assertCommandArgs(c, s.command(echoCommand, "123"), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestCommandEnablePTY(c *gc.C) { var opts ssh.Options opts.EnablePTY() s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -t -t localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestCommandSetKnownHostsFile(c *gc.C) { var opts ssh.Options opts.SetKnownHostsFile("/tmp/known hosts") s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -o UserKnownHostsFile \"/tmp/known hosts\" localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestSetStrictHostKeyChecking(c *gc.C) { commandPattern := fmt.Sprintf("%s%%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123", s.fakessh, echoCommand) tests := []struct { input ssh.StrictHostChecksOption expected string }{ {ssh.StrictHostChecksNo, "no"}, {ssh.StrictHostChecksYes, "yes"}, {ssh.StrictHostChecksAsk, "ask"}, {ssh.StrictHostChecksDefault, ""}, 
{ssh.StrictHostChecksOption(999), ""}, } for _, t := range tests { var opts ssh.Options opts.SetStrictHostKeyChecking(t.input) expectedOpt := "" if t.expected != "" { expectedOpt = " -o StrictHostKeyChecking " + t.expected } s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf(commandPattern, expectedOpt)) } } func (s *SSHCommandSuite) TestCommandAllowPasswordAuthentication(c *gc.C) { var opts ssh.Options opts.AllowPasswordAuthentication() s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o ServerAliveInterval 30 localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestCommandIdentities(c *gc.C) { var opts ssh.Options opts.SetIdentities("x", "y") s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -i x -i y localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestCommandPort(c *gc.C) { var opts ssh.Options opts.SetPort(2022) s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -p 2022 localhost %s 123", s.fakessh, echoCommand), ) } func (s *SSHCommandSuite) TestCopy(c *gc.C) { var opts ssh.Options opts.EnablePTY() opts.AllowPasswordAuthentication() opts.SetIdentities("x", "y") opts.SetPort(2022) err := s.client.Copy([]string{"/tmp/blah", "foo@bar.com:baz"}, &opts) c.Assert(err, jc.ErrorIsNil) out, err := ioutil.ReadFile(s.fakescp + ".args") c.Assert(err, jc.ErrorIsNil) // EnablePTY has no effect for Copy c.Assert(string(out), gc.Equals, s.fakescp+" -o ServerAliveInterval 30 -i x -i y -P 2022 /tmp/blah foo@bar.com:baz\n") // Try passing extra args err = s.client.Copy([]string{"/tmp/blah", "foo@bar.com:baz", "-r", "-v"}, &opts) c.Assert(err, jc.ErrorIsNil) out, err = ioutil.ReadFile(s.fakescp + ".args") c.Assert(err, jc.ErrorIsNil) c.Assert(string(out), 
gc.Equals, s.fakescp+" -o ServerAliveInterval 30 -i x -i y -P 2022 /tmp/blah foo@bar.com:baz -r -v\n") // Try interspersing extra args err = s.client.Copy([]string{"-r", "/tmp/blah", "-v", "foo@bar.com:baz"}, &opts) c.Assert(err, jc.ErrorIsNil) out, err = ioutil.ReadFile(s.fakescp + ".args") c.Assert(err, jc.ErrorIsNil) c.Assert(string(out), gc.Equals, s.fakescp+" -o ServerAliveInterval 30 -i x -i y -P 2022 -r /tmp/blah -v foo@bar.com:baz\n") } func (s *SSHCommandSuite) TestCommandClientKeys(c *gc.C) { defer overrideGenerateKey().Restore() clientKeysDir := c.MkDir() defer ssh.ClearClientKeys() err := ssh.LoadClientKeys(clientKeysDir) c.Assert(err, jc.ErrorIsNil) ck := filepath.Join(clientKeysDir, "juju_id_ed25519") var opts ssh.Options opts.SetIdentities("x", "y") s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -i x -i y -i %s localhost %s 123", s.fakessh, ck, echoCommand), ) } func (s *SSHCommandSuite) TestCommandError(c *gc.C) { var opts ssh.Options err := ioutil.WriteFile(s.fakessh, []byte("#!/bin/sh\nexit 42"), 0755) c.Assert(err, jc.ErrorIsNil) command := s.client.Command("ignored", []string{echoCommand, "foo"}, &opts) err = command.Run() c.Assert(utils.IsRcPassthroughError(err), jc.IsTrue) } func (s *SSHCommandSuite) TestCommandDefaultIdentities(c *gc.C) { var opts ssh.Options tempdir := c.MkDir() def1 := filepath.Join(tempdir, "def1") def2 := filepath.Join(tempdir, "def2") s.PatchValue(ssh.DefaultIdentities, []string{def1, def2}) // If no identities are specified, then the defaults aren't added. s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123", s.fakessh, echoCommand), ) // If identities are specified, then the defaults are must added. // Only the defaults that exist on disk will be added. 
err := ioutil.WriteFile(def2, nil, 0644) c.Assert(err, jc.ErrorIsNil) opts.SetIdentities("x", "y") s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts), fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -i x -i y -i %s localhost %s 123", s.fakessh, def2, echoCommand), ) } func (s *SSHCommandSuite) TestCopyReader(c *gc.C) { client := &fakeClient{} r := bytes.NewBufferString("") err := ssh.TestCopyReader(client, "foo@bar.com:baz", "/tmp/blah", r, nil) c.Assert(err, jc.ErrorIsNil) client.checkCalls(c, "foo@bar.com:baz", []string{"cat - > /tmp/blah"}, nil, nil, "Command") client.impl.checkCalls(c, r, nil, nil, "SetStdio", "Start", "Wait") } ================================================ FILE: ssh/stream.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package ssh import ( "bytes" "io" ) // stripCR implements an io.Reader wrapper that removes carriage return bytes. type stripCR struct { reader io.Reader } // StripCRReader returns a new io.Reader wrapper that strips carriage returns. func StripCRReader(reader io.Reader) io.Reader { if reader == nil { return nil } return &stripCR{reader: reader} } var byteEmpty = []byte{} var byteCR = []byte{'\r'} // Read implements io.Reader interface. // This copies data around much more than needed so should be optimized if // used on a performance critical path. func (s *stripCR) Read(bufOut []byte) (int, error) { bufTemp := make([]byte, len(bufOut)) n, err := s.reader.Read(bufTemp) bufReplaced := bytes.Replace(bufTemp[:n], byteCR, byteEmpty, -1) copy(bufOut, bufReplaced) return len(bufReplaced), err } ================================================ FILE: ssh/stream_test.go ================================================ // Copyright 2016 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package ssh_test

import (
	"io/ioutil"
	"strings"
	"testing/iotest"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/ssh"
)

// SSHStreamSuite exercises the carriage-return stripping reader returned by
// ssh.StripCRReader.
type SSHStreamSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&SSHStreamSuite{})

// TestNewStripCRNil checks that wrapping a nil reader yields nil.
func (s *SSHStreamSuite) TestNewStripCRNil(c *gc.C) {
	reader := ssh.StripCRReader(nil)
	c.Assert(reader, gc.IsNil)
}

// TestStripCR checks that a "\r\n" pair is reduced to "\n".
func (s *SSHStreamSuite) TestStripCR(c *gc.C) {
	reader := ssh.StripCRReader(strings.NewReader("One\r\nTwo"))
	output, err := ioutil.ReadAll(reader)
	c.Assert(err, jc.ErrorIsNil)
	c.Check(string(output), gc.Equals, "One\nTwo")
}

// TestStripCROneByte reads through iotest.OneByteReader, so individual reads
// can consist solely of carriage returns.
func (s *SSHStreamSuite) TestStripCROneByte(c *gc.C) {
	reader := ssh.StripCRReader(strings.NewReader("One\r\r\rTwo"))
	output, err := ioutil.ReadAll(iotest.OneByteReader(reader))
	c.Assert(err, jc.ErrorIsNil)
	c.Check(string(output), gc.Equals, "OneTwo")
}

// TestStripCRError checks that reading through iotest.TimeoutReader surfaces
// its "timeout" error to the caller.
func (s *SSHStreamSuite) TestStripCRError(c *gc.C) {
	reader := ssh.StripCRReader(strings.NewReader("One\r\r\rTwo"))
	_, err := ioutil.ReadAll(iotest.TimeoutReader(reader))
	c.Assert(err.Error(), gc.Equals, "timeout")
}

================================================
FILE: ssh/stream_wrapper_unix.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

//go:build !windows
// +build !windows

package ssh

import (
	"io"
)

// WrapStdin returns the original stdin stream unchanged on non-Windows
// platforms.
func WrapStdin(reader io.Reader) io.Reader {
	return reader
}

================================================
FILE: ssh/stream_wrapper_windows.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package ssh

import (
	"io"
)

// WrapStdin returns stdin with carriage returns stripped on windows.
func WrapStdin(reader io.Reader) io.Reader {
	// Delegates to StripCRReader, which also maps a nil reader to nil.
	return StripCRReader(reader)
}

================================================
FILE: ssh/testing/keys.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package testing

// SSHKey pairs an authorized_keys-style public key line with a fingerprint
// string. The fingerprints are colon-separated hex, presumably MD5
// fingerprints; confirm against the code that consumes them.
type SSHKey struct {
	Key         string
	Fingerprint string
}

// Public-key fixtures for tests.
var (
	// ValidKeyOne to ValidKeyFour are single ssh-rsa keys intended to parse
	// successfully, each with its fingerprint.
	ValidKeyOne = SSHKey{
		`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDEX/dPu4PmtvgK3La9zioCEDrJ` +
			`yUr6xEIK7Pr+rLgydcqWTU/kt7w7gKjOw4vvzgHfjKl09CWyvgb+y5dCiTk` +
			`9MxI+erGNhs3pwaoS+EavAbawB7iEqYyTep3YaJK+4RJ4OX7ZlXMAIMrTL+` +
			`UVrK89t56hCkFYaAgo3VY+z6rb/b3bDBYtE1Y2tS7C3au73aDgeb9psIrSV` +
			`86ucKBTl5X62FnYiyGd++xCnLB6uLximM5OKXfLzJQNS/QyZyk12g3D8y69` +
			`Xw1GzCSKX1u1+MQboyf0HJcG2ryUCLHdcDVppApyHx2OLq53hlkQ/yxdflD` +
			`qCqAE4j+doagSsIfC1T2T`,
		"86:ed:1b:cd:26:a0:a3:4c:27:35:49:60:95:b7:0f:68",
	}

	ValidKeyTwo = SSHKey{
		`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDNC6zK8UMazlVgp8en8N7m7H/Y6` +
			`DoMWbmPFjXYRXu6iQJJ18hCtsfMe63E5/PBaOjDT8am0Sx3Eqn4ZzpWMj+z` +
			`knTcSd8xnMHYYxH2HStRWC1akTe4tTno2u2mqzjKd8f62URPtIocYCNRBls` +
			`9yjnq9SogI5EXgcx6taQcrIFcIK0SlthxxcMVSlLpnbReujW65JHtiMqoYA` +
			`OIALyO+Rkmtvb/ObmViDnwCKCN1up/xWt6J10MrAUtpI5b4prqG7FOqVMM/` +
			`zdgrVg6rUghnzdYeQ8QMyEv4mVSLzX0XIPcxorkl9q06s5mZmAzysEbKZCO` +
			`aXcLeNlXx/nkmuWslYCJ`,
		"2f:fb:b0:65:68:c8:4e:a6:1b:a6:4b:8d:14:0b:40:79",
	}

	ValidKeyThree = SSHKey{
		`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCpGj1JMjGjAFt5wjARbIORyjQ/c` +
			`ZAiDyDHe/w8qmLKUG2KTs6586QqqM6DKPZiYesrzXqvZsWYV4B6OjLM1sxq` +
			`WjeDIl56PSnJ0+KP8pUV9KTkkKtRXxAoNg/II4l69e05qGffj9AcQ/7JPxx` +
			`eL14Ulvh/a69r3uVkw1UGVk9Bwm4eCOSCqKalYLA1k5da6crEAXn9hiXLGs` +
			`S9dOn3Lsqj5tK31aaUncue+a3iKb7R5LRFflDizzNS+h8tPuANQflOjOhR0` +
			`Vas0BsurgISseZZ0NIMISyWhZpr0eOBWA/YruN9r++kYPOnDy0eMaOVGLO7` +
			`SQwJ/6QHvf73yksJTncz`,
		"1d:cf:ab:66:8a:f6:77:fb:4c:b2:59:6f:12:cf:cb:2f",
	}

	ValidKeyFour = SSHKey{
		`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCSEDMH5RyjGtEMIqM2RiPYYQgUK` +
			`9wdHCo1/AXkuQ7m1iVjHhACp8Oawf2Grn7hO4e0JUn5FaEZOnDj/9HB2VPw` +
			`EDGBwSN1caVC3yrTVkqQcsxBY9nTV+spQQMsePOdUZALcoEilvAcLRETbyn` +
			`rybaS2bfzpqbA9MEEaKQKLKGdgqiMdNXAj5I/ik/BPp0ziOMlMl1A1zilnS` +
			`UXubs1U49WWV0A70vAASvZVTXr3zrPAmstH+9Ik6FdpeE99um08FXxKYWqZ` +
			`6rZF1M6L1/SqC7ediYdVgRCoti85kKhi7fZBzwrGcCnxer+D0GFz++KDSNS` +
			`iAnVZxyXhmBrwnR6Q/v7`,
		"37:99:ab:96:c4:e8:f8:0b:0d:04:3e:1e:ee:66:e8:9e",
	}

	// ValidKeyMulti is two ssh-rsa keys separated by a newline.
	ValidKeyMulti = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDW+8zWO6qqXrHlcMK7obliuYp7D` +
		`vZBsK6rHlnbeV5Hh38Qn0GUX4Ahm6XeQ/NSx53wqkBQDGOJFY3s4w1a/hbd` +
		`PyLM2/yFXCYsj5FRf01JmUjAzWhuJMH9ViqzD//l4v8cR/pHC2B8PD6abKd` +
		`mIH+yLI9Cl3C4ICMKteG54egsUyboBOVKCDIKmWRLAak6sE5DPpqKF53NvD` +
		`cuDufWtaCfVAOrq6NW8wSQ7PAvfDh8gsG5uvZjY3gcWl9yI3EJVGFHcdxcv` +
		`4LtQI8mKdeg3JoufnEmeBJTZMoo83Gru5Z7tjv8J4JTUeQpd9uCCED1JAMe` +
		`cJSKgQ2gZMTbTshobpHr` + "\n" +
		`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDSgfrzyGpE5eLiXusvLcxEmoE6e` +
		`SMUDvTW1dd2BZgfvUVwq+toQdZ6C0C1JmbC3X563n8fmKVUAQGo5JavzABG` +
		`Kpy90L3cwoGCFtb+A28YsT+bfuP+LdnCbFXm9c3DPJQx6Dch8prnDtzRjRV` +
		`CorbPvm35NY73liUXVF6g58Owlx5rWtb8OnoTh5KQps9JTSfyNckdV9bFxP` +
		`7bZvMyRYW5X33KaA+CQGpTNAKDHruSuKdAdaS6rBIZRvzzzSCF28BWwFL7Z` +
		`ghQo0ADlUMnqIeQ58nwRImZHpmvadsZi47aMKFeykk4JQUQlwjbM0xGi0uj` +
		`+hlaqGYbNo0Evcjn23cj`

	// PartValidKeyMulti is one well-formed key followed by a malformed one.
	PartValidKeyMulti = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDZRvG2miYVkbWOr2I+9xHWXqALb` +
		`eBcyxAlYtbjxBRwrq8oFOw9vtIIZSO0r1FM6+JHzKhLSiPCMR/PK78ZqPgZ` +
		`fia8Y7cEZKaUWLtZUAl0RF9w8EtsA/2gpuLZErjcoIx6fzfEYFCJcLgcQSc` +
		`RlKG8VZT6tWIjvoLj9ki6unkG5YGmapkT60afhf3/vd7pCJO/uyszkQ9qU8` +
		`odUDTTlwftpJtUb8xGmzpEZJTgk1lbZKlZm5pVXwjNEodH7Je88RBzR7PBB` +
		`Jct+vf8wVJ/UEFXCnamvHLanJTcJIi/I5qRlKns65Bwb8M0HszPYmvTfFRD` +
		`ZLi3sPUmw6PJCJ0SgATd` + "\n" +
		`ssh-rsa bad key`

	// MultiInvalid contains only malformed key lines.
	MultiInvalid = `ssh-rsa bad key` + "\n" +
		`ssh-rsa also bad`

	// EmptyKeyMulti is an empty key list.
	EmptyKeyMulti = ""
)

================================================
FILE: symlink/export_test.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package symlink

// GetLongPathAsString exposes the platform-specific getLongPathAsString
// helper to the external symlink_test package.
var (
	GetLongPathAsString = getLongPathAsString
)

================================================
FILE: symlink/symlink.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

package symlink

import (
	"fmt"
	"os"
	"path/filepath"

	"github.com/juju/utils/v4"
)

// Replace replaces the symlink at link with one pointing to newpath.
//
// The new link is first created under a unique temporary name in the same
// directory. Any existing entry at link is then removed, and the temporary
// link is renamed into place. Because of the removal step the replacement
// is not strictly atomic. On failure the temporary link is removed
// (best effort).
func Replace(link, newpath string) error {
	dstDir := filepath.Dir(link)
	uuid, err := utils.NewUUID()
	if err != nil {
		return err
	}
	tmpFile := filepath.Join(dstDir, "tmpfile"+uuid.String())

	// Create the new symlink before removing the old one. This way, if New()
	// fails, we still have a link to the old tools.
	if err := New(newpath, tmpFile); err != nil {
		return fmt.Errorf("cannot create symlink: %w", err)
	}

	// On Windows, symlinks may not be overwritten, so remove the existing
	// entry first and then rename tmpFile. Lstat (not Stat) is used so that
	// a dangling link, whose target no longer exists, is still detected.
	if _, err := os.Lstat(link); err == nil {
		if err := os.RemoveAll(link); err != nil {
			// Best effort: don't leave the temporary link behind.
			_ = os.Remove(tmpFile)
			return err
		}
	}
	if err := os.Rename(tmpFile, link); err != nil {
		// Best effort: don't leave the temporary link behind.
		_ = os.Remove(tmpFile)
		return fmt.Errorf("cannot update tools symlink: %w", err)
	}
	return nil
}

================================================
FILE: symlink/symlink_posix.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

//go:build linux || darwin
// +build linux darwin

package symlink

import (
	"os"

	"github.com/juju/errors"
)

// New is a wrapper function for os.Symlink() on Linux and macOS.
func New(oldname, newname string) error {
	return os.Symlink(oldname, newname)
}

// Read is a wrapper for os.Readlink() on Linux and macOS.
func Read(link string) (string, error) {
	return os.Readlink(link)
}

// IsSymlink reports whether path itself is a symbolic link; the link is not
// followed. Errors from os.Lstat (e.g. a missing path) are returned traced.
func IsSymlink(path string) (bool, error) {
	st, err := os.Lstat(path)
	if err != nil {
		return false, errors.Trace(err)
	}
	return st.Mode()&os.ModeSymlink != 0, nil
}

// getLongPathAsString does nothing on Linux and macOS. It's here for
// compatibility with the windows implementation.
func getLongPathAsString(path string) (string, error) {
	return path, nil
}

================================================
FILE: symlink/symlink_test.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

package symlink_test

import (
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"

	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
	"github.com/juju/utils/v4/symlink"
)

type SymlinkSuite struct{}

var _ = gc.Suite(&SymlinkSuite{})

func Test(t *testing.T) {
	gc.TestingT(t)
}

func (*SymlinkSuite) TestReplace(c *gc.C) {
	target, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	target_second, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	link := filepath.Join(target, "link")

	_, err = os.Stat(target)
	c.Assert(err, gc.IsNil)
	_, err = os.Stat(target_second)
	c.Assert(err, gc.IsNil)

	err = symlink.New(target, link)
	c.Assert(err, gc.IsNil)

	link_target, err := symlink.Read(link)
	c.Assert(err, gc.IsNil)
	c.Assert(link_target, gc.Equals, filepath.FromSlash(target))

	err = symlink.Replace(link, target_second)
	c.Assert(err, gc.IsNil)

	link_target, err = symlink.Read(link)
	c.Assert(err, gc.IsNil)
	c.Assert(link_target, gc.Equals, filepath.FromSlash(target_second))
}

func (*SymlinkSuite) TestIsSymlinkFile(c *gc.C) {
	dir, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)

	target := filepath.Join(dir, "file")
	err = ioutil.WriteFile(target, []byte("TOP SECRET"), 0644)
	c.Assert(err, gc.IsNil)

	link := filepath.Join(dir, "link")

	_, err = os.Stat(target)
	c.Assert(err, gc.IsNil)
	err = symlink.New(target, link)
	c.Assert(err, gc.IsNil)

	isSymlink, err := symlink.IsSymlink(link)
	c.Assert(err, gc.IsNil)
	c.Assert(isSymlink, jc.IsTrue)
}

func (*SymlinkSuite) TestIsSymlinkFolder(c *gc.C) {
	target, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)

	link := filepath.Join(target, "link")

	_, err = os.Stat(target)
	c.Assert(err, gc.IsNil)
	err = symlink.New(target, link)
	c.Assert(err, gc.IsNil)

	isSymlink, err := symlink.IsSymlink(link)
	c.Assert(err, gc.IsNil)
	c.Assert(isSymlink, jc.IsTrue)
}

func (*SymlinkSuite) TestIsSymlinkFalseFile(c *gc.C) {
	dir := c.MkDir()

	target := filepath.Join(dir, "file")
	err := ioutil.WriteFile(target, []byte("TOP SECRET"), 0644)
	c.Assert(err, gc.IsNil)

	_, err = os.Stat(target)
	c.Assert(err, gc.IsNil)

	isSymlink, err := symlink.IsSymlink(target)
	c.Assert(err, gc.IsNil)
	c.Assert(isSymlink, jc.IsFalse)
}

func (*SymlinkSuite) TestIsSymlinkFalseFolder(c *gc.C) {
	target, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)

	_, err = os.Stat(target)
	c.Assert(err, gc.IsNil)

	isSymlink, err := symlink.IsSymlink(target)
	c.Assert(err, gc.IsNil)
	c.Assert(isSymlink, jc.IsFalse)
}

func (*SymlinkSuite) TestIsSymlinkFileDoesNotExist(c *gc.C) {
	dir := c.MkDir()

	target := filepath.Join(dir, "file")

	isSymlink, err := symlink.IsSymlink(target)
	c.Assert(err, gc.ErrorMatches, ".*"+utils.NoSuchFileErrRegexp)
	c.Assert(isSymlink, jc.IsFalse)
}

================================================
FILE: symlink/symlink_windows.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
// Author: Robert Tingirica package symlink import ( "os" "strings" "syscall" "unicode/utf16" "unsafe" "github.com/juju/errors" ) const ( SYMBOLIC_LINK_FLAG_DIRECTORY = 1 // This is the equivalent of syscall.GENERIC_EXECUTION // Using syscall.GENERIC_EXECUTION results in an "Access denied" error GENERIC_EXECUTION = 33554432 // (TODO): bogdanteleaga or anybody else: // Remove this once we upgrade to a go version that has it in the syscall // package FILE_ATTRIBUTE_REPARSE_POINT = 0x00000400 ) //sys createSymbolicLink(symlinkname *uint16, targetname *uint16, flags uint32) (err error) = CreateSymbolicLinkW //sys getFinalPathNameByHandle(handle Handle, buf *uint16, buflen uint32, flags uint32) (n uint32, err error) = GetFinalPathNameByHandleW // New creates newname as a symbolic link to oldname. // If there is an error, it will be of type *LinkError. func New(oldname, newname string) error { fi, err := os.Stat(oldname) if err != nil { return &os.LinkError{"symlink", oldname, newname, err} } var flag uint32 if fi.IsDir() { flag = SYMBOLIC_LINK_FLAG_DIRECTORY } targetp, err := getLongPath(oldname) if err != nil { return &os.LinkError{"symlink", oldname, newname, err} } linkp, err := syscall.UTF16PtrFromString(newname) if err != nil { return &os.LinkError{"symlink", oldname, newname, err} } err = createSymbolicLink(linkp, &targetp[0], flag) if err != nil { return &os.LinkError{"symlink", oldname, newname, err} } return nil } // Read returns the destination of the named symbolic link. // If there is an error, it will be of type *PathError. 
func Read(link string) (string, error) { linkp, err := getLongPath(link) if err != nil { return "", err } h, err := syscall.CreateFile( &linkp[0], syscall.GENERIC_READ, syscall.FILE_SHARE_READ, nil, syscall.OPEN_EXISTING, GENERIC_EXECUTION, 0) if err != nil { return "", &os.PathError{"readlink", link, err} } defer syscall.CloseHandle(h) pathw := make([]uint16, syscall.MAX_PATH) n, err := getFinalPathNameByHandle(h, &pathw[0], uint32(len(pathw)), 0) if err != nil { return "", &os.PathError{"readlink", link, err} } if n > uint32(len(pathw)) { pathw = make([]uint16, n) n, err = getFinalPathNameByHandle(h, &pathw[0], uint32(len(pathw)), 0) if err != nil { return "", &os.PathError{"readlink", link, err} } if n > uint32(len(pathw)) { return "", &os.PathError{"readlink", link, errors.New("link length too long")} } } ret := string(utf16.Decode(pathw[0:n])) if strings.HasPrefix(ret, `\\?\`) { return ret[4:], nil } retp, err := getLongPath(ret) if err != nil { return "", &os.PathError{"readlink", link, err} } return syscall.UTF16ToString(retp), nil } func IsSymlink(path string) (bool, error) { var fa syscall.Win32FileAttributeData namep, err := syscall.UTF16PtrFromString(path) if err != nil { return false, errors.Trace(err) } err = syscall.GetFileAttributesEx(namep, syscall.GetFileExInfoStandard, (*byte)(unsafe.Pointer(&fa))) if err != nil { return false, errors.Trace(err) } return fa.FileAttributes&FILE_ATTRIBUTE_REPARSE_POINT != 0, nil } // getLongPath converts windows 8.1 short style paths (c:\Progra~1\foo) to full // long paths. 
func getLongPath(path string) ([]uint16, error) { pathp, err := syscall.UTF16FromString(path) if err != nil { return nil, err } longp := pathp n, err := syscall.GetLongPathName(&pathp[0], &longp[0], uint32(len(longp))) if err != nil { return nil, err } if n > uint32(len(longp)) { longp = make([]uint16, n) n, err = syscall.GetLongPathName(&pathp[0], &longp[0], uint32(len(longp))) if err != nil { return nil, err } } longp = longp[:n] return longp, nil } func getLongPathAsString(path string) (string, error) { longp, err := getLongPath(path) if err != nil { return "", err } return syscall.UTF16ToString(longp), nil } ================================================ FILE: symlink/symlink_windows_test.go ================================================ // Copyright 2014 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. package symlink_test import ( "io/ioutil" "os" "path/filepath" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/symlink" ) func (*SymlinkSuite) TestLongPath(c *gc.C) { programFiles := `C:\PROGRA~1` longProg := `C:\Program Files` target, err := symlink.GetLongPathAsString(programFiles) c.Assert(err, gc.IsNil) c.Assert(target, gc.Equals, longProg) } func (*SymlinkSuite) TestCreateSymLink(c *gc.C) { target, err := symlink.GetLongPathAsString(c.MkDir()) c.Assert(err, gc.IsNil) link := filepath.Join(target, "link") _, err = os.Stat(target) c.Assert(err, gc.IsNil) err = symlink.New(target, link) c.Assert(err, gc.IsNil) link, err = symlink.Read(link) c.Assert(err, gc.IsNil) c.Assert(link, gc.Equals, filepath.FromSlash(target)) } func (*SymlinkSuite) TestReadData(c *gc.C) { dir := c.MkDir() sub := filepath.Join(dir, "sub") err := os.Mkdir(sub, 0700) c.Assert(err, gc.IsNil) oldname := filepath.Join(sub, "foo") data := []byte("data") err = ioutil.WriteFile(oldname, data, 0644) c.Assert(err, gc.IsNil) newname := filepath.Join(dir, "bar") err = symlink.New(oldname, newname) c.Assert(err, gc.IsNil) b, 
err := ioutil.ReadFile(newname) c.Assert(err, gc.IsNil) c.Assert(string(b), gc.Equals, string(data)) } ================================================ FILE: symlink/zsymlink_windows_386.go ================================================ // Copyright 2014 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. // mksyscall_windows.pl -l32 symlink/symlink_windows.go // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT package symlink import "unsafe" import "syscall" var ( modkernel32 = syscall.NewLazyDLL("kernel32.dll") procCreateSymbolicLinkW = modkernel32.NewProc("CreateSymbolicLinkW") procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW") ) func createSymbolicLink(symlinkname *uint16, targetname *uint16, flags uint32) (err error) { r1, _, e1 := syscall.Syscall(procCreateSymbolicLinkW.Addr(), 3, uintptr(unsafe.Pointer(symlinkname)), uintptr(unsafe.Pointer(targetname)), uintptr(flags)) if r1 == 0 { if e1 != 0 { err = error(e1) } else { err = syscall.EINVAL } } return } func getFinalPathNameByHandle(handle syscall.Handle, buf *uint16, buflen uint32, flags uint32) (n uint32, err error) { r0, _, e1 := syscall.Syscall6(procGetFinalPathNameByHandleW.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(flags), 0, 0) n = uint32(r0) if n == 0 { if e1 != 0 { err = error(e1) } else { err = syscall.EINVAL } } return } ================================================ FILE: symlink/zsymlink_windows_amd64.go ================================================ // Copyright 2014 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. 
// mksyscall_windows.pl symlink/symlink_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT

package symlink

import "unsafe"
import "syscall"

var (
	modkernel32 = syscall.NewLazyDLL("kernel32.dll")

	procCreateSymbolicLinkW       = modkernel32.NewProc("CreateSymbolicLinkW")
	procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW")
)

func createSymbolicLink(symlinkname *uint16, targetname *uint16, flags uint32) (err error) {
	r1, _, e1 := syscall.Syscall(procCreateSymbolicLinkW.Addr(), 3, uintptr(unsafe.Pointer(symlinkname)), uintptr(unsafe.Pointer(targetname)), uintptr(flags))
	if r1 == 0 {
		if e1 != 0 {
			err = error(e1)
		} else {
			err = syscall.EINVAL
		}
	}
	return
}

func getFinalPathNameByHandle(handle syscall.Handle, buf *uint16, buflen uint32, flags uint32) (n uint32, err error) {
	r0, _, e1 := syscall.Syscall6(procGetFinalPathNameByHandleW.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(flags), 0, 0)
	n = uint32(r0)
	if n == 0 {
		if e1 != 0 {
			err = error(e1)
		} else {
			err = syscall.EINVAL
		}
	}
	return
}

================================================
FILE: systemerrmessages_unix.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

//go:build !windows
// +build !windows

package utils

// The following are strings/regex-es which match common Unix error messages
// that may be returned in case of failed calls to the system.
// Any extra leading/trailing regex-es are left to be added by the developer.
const (
	// NoSuchUserErrRegexp matches the message for an unknown user name.
	NoSuchUserErrRegexp = `user: unknown user [a-z0-9_-]*`
	// NoSuchFileErrRegexp matches the message for a missing file or directory.
	NoSuchFileErrRegexp = `no such file or directory`
	// MkdirFailErrRegexp matches a failure caused by a path component that
	// is not a directory.
	MkdirFailErrRegexp = `.* not a directory`
)

================================================
FILE: systemerrmessages_windows.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.

package utils

// The following are strings/regex-es which match common Windows error messages
// that may be returned in case of failed calls to the system.
// Any extra leading/trailing regex-es are left to be added by the developer.
const (
	// NoSuchUserErrRegexp matches the message for an unknown account name.
	NoSuchUserErrRegexp = `No mapping between account names and security IDs was done\.`
	// NoSuchFileErrRegexp matches the message for a missing file or path.
	NoSuchFileErrRegexp = `The system cannot find the (file|path) specified\.`
	// MkdirFailErrRegexp matches a mkdir failure caused by a missing path.
	MkdirFailErrRegexp = `mkdir .*` + NoSuchFileErrRegexp
)

================================================
FILE: tailer/export_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package tailer

// Test-only hooks into unexported package state: a pointer to the buffer size
// (so tests can patch it) and the constructor that takes a poll interval.
var (
	BufferSize    = &bufferSize
	NewTestTailer = newTailer
)

================================================
FILE: tailer/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package tailer_test

import (
	"testing"

	gc "gopkg.in/check.v1"
)

// TestPackage runs the gocheck suites registered in this package.
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}

================================================
FILE: tailer/tailer.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package tailer

import (
	"bufio"
	"bytes"
	"io"
	"os"
	"time"

	"gopkg.in/tomb.v1"
)

const (
	defaultBufferSize = 4096
	// polltime is the default interval between polls for new data.
	polltime  = time.Second
	delimiter = '\n'
)

var (
	// bufferSize is a variable rather than a constant so that tests can
	// patch it through the exported BufferSize pointer.
	bufferSize = defaultBufferSize
	delimiters = []byte{delimiter}
)

// TailerFilterFunc decides if a line shall be tailed (func is nil or
// returns true) or shall be omitted (func returns false).
type TailerFilterFunc func(line []byte) bool

// Tailer reads an input line by line and tails the lines into the passed
// Writer. The lines have to be terminated with a newline.
type Tailer struct { tomb tomb.Tomb readSeeker io.ReadSeeker reader *bufio.Reader writeCloser io.WriteCloser writer *bufio.Writer filter TailerFilterFunc polltime time.Duration } // NewTailer starts a Tailer which reads strings from the passed // ReadSeeker line by line. If a filter function is specified the read // lines are filtered. The matching lines are written to the passed // Writer. func NewTailer(readSeeker io.ReadSeeker, writer io.Writer, filter TailerFilterFunc) *Tailer { return newTailer(readSeeker, writer, filter, polltime) } // newTailer starts a Tailer like NewTailer but allows the setting of // the read buffer size and the time between pollings for testing. func newTailer(readSeeker io.ReadSeeker, writer io.Writer, filter TailerFilterFunc, polltime time.Duration) *Tailer { t := &Tailer{ readSeeker: readSeeker, reader: bufio.NewReaderSize(readSeeker, bufferSize), writer: bufio.NewWriter(writer), filter: filter, polltime: polltime, } go func() { defer t.tomb.Done() t.tomb.Kill(t.loop()) }() return t } // Stop tells the tailer to stop working. func (t *Tailer) Stop() error { t.tomb.Kill(nil) return t.tomb.Wait() } // Wait waits until the tailer is stopped due to command // or an error. In case of an error it returns the reason. func (t *Tailer) Wait() error { return t.tomb.Wait() } // Dead returns the channel that can be used to wait until // the tailer is stopped. func (t *Tailer) Dead() <-chan struct{} { return t.tomb.Dead() } // Err returns a possible error. func (t *Tailer) Err() error { return t.tomb.Err() } // loop writes the last lines based on the buffer size to the // writer and then polls for more data to write it to the // writer too. func (t *Tailer) loop() error { // Start polling. // TODO(mue) 2013-12-06 // Handling of read-seeker/files being truncated during // tailing is currently missing! 
timer := time.NewTimer(0) for { select { case <-t.tomb.Dying(): return nil case <-timer.C: for { line, readErr := t.readLine() _, writeErr := t.writer.Write(line) if writeErr != nil { return writeErr } if readErr != nil { if readErr != io.EOF { return readErr } break } } if writeErr := t.writer.Flush(); writeErr != nil { return writeErr } timer.Reset(t.polltime) } } } // SeekLastLines sets the read position of the ReadSeeker to the // wanted number of filtered lines before the end. func SeekLastLines(readSeeker io.ReadSeeker, lines uint, filter TailerFilterFunc) error { offset, err := readSeeker.Seek(0, os.SEEK_END) if err != nil { return err } if lines == 0 { // We are done, just seeking to the end is sufficient. return nil } seekPos := int64(0) found := uint(0) buffer := make([]byte, bufferSize) SeekLoop: for offset > 0 { // buffer contains the data left over from the // previous iteration. space := cap(buffer) - len(buffer) if space < bufferSize { // Grow buffer. newBuffer := make([]byte, len(buffer), cap(buffer)*2) copy(newBuffer, buffer) buffer = newBuffer space = cap(buffer) - len(buffer) } if int64(space) > offset { // Use exactly the right amount of space if there's // only a small amount remaining. space = int(offset) } // Copy data remaining from last time to the end of the buffer, // so we can read into the right place. copy(buffer[space:cap(buffer)], buffer) buffer = buffer[0 : len(buffer)+space] offset -= int64(space) _, err := readSeeker.Seek(offset, os.SEEK_SET) if err != nil { return err } _, err = io.ReadFull(readSeeker, buffer[0:space]) if err != nil { return err } // Find the end of the last line in the buffer. // This will discard any unterminated line at the end // of the file. end := bytes.LastIndex(buffer, delimiters) if end == -1 { // No end of line found - discard incomplete // line and continue looking. If this happens // at the beginning of the file, we don't care // because we're going to stop anyway. 
buffer = buffer[:0] continue } end++ for { start := bytes.LastIndex(buffer[0:end-1], delimiters) if start == -1 && offset >= 0 { break } start++ if filter == nil || filter(buffer[start:end]) { found++ if found >= lines { seekPos = offset + int64(start) break SeekLoop } } end = start } // Leave the last line in buffer, as we don't know whether // it's complete or not. buffer = buffer[0:end] } // Final positioning. readSeeker.Seek(seekPos, os.SEEK_SET) return nil } // readLine reads the next valid line from the reader, even if it is // larger than the reader buffer. func (t *Tailer) readLine() ([]byte, error) { for { slice, err := t.reader.ReadSlice(delimiter) if err == nil { if t.isValid(slice) { return slice, nil } continue } line := append([]byte(nil), slice...) for err == bufio.ErrBufferFull { slice, err = t.reader.ReadSlice(delimiter) line = append(line, slice...) } switch err { case nil: if t.isValid(line) { return line, nil } case io.EOF: // EOF without delimiter, step back. t.readSeeker.Seek(-int64(len(line)), os.SEEK_CUR) return nil, err default: return nil, err } } } // isValid checks if the passed line is valid by checking if the // line has content, the filter function is nil or it returns true. func (t *Tailer) isValid(line []byte) bool { if t.filter == nil { return true } return t.filter(line) } ================================================ FILE: tailer/tailer_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package tailer_test

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/juju/testing"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4/tailer"
)

type tailerSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&tailerSuite{})

// alphabetData is a 26-line input, one repeated word per line from
// "alpha" to "zulu".
var alphabetData = []string{
	"alpha alpha\n",
	"bravo bravo\n",
	"charlie charlie\n",
	"delta delta\n",
	"echo echo\n",
	"foxtrott foxtrott\n",
	"golf golf\n",
	"hotel hotel\n",
	"india india\n",
	"juliet juliet\n",
	"kilo kilo\n",
	"lima lima\n",
	"mike mike\n",
	"november november\n",
	"oscar oscar\n",
	"papa papa\n",
	"quebec quebec\n",
	"romeo romeo\n",
	"sierra sierra\n",
	"tango tango\n",
	"uniform uniform\n",
	"victor victor\n",
	"whiskey whiskey\n",
	"x-ray x-ray\n",
	"yankee yankee\n",
	"zulu zulu\n",
}

// tests drives TestTailer. Each case's data, initialLinesWritten and a
// signal channel are handed to startReadSeeker. The channel is signalled
// once initialCollectedData has been seen; presumably that releases the
// rest of the data (see startReadSeeker).
var tests = []struct {
	description           string
	data                  []string
	initialLinesWritten   int
	initialLinesRequested uint
	bufferSize            int
	filter                tailer.TailerFilterFunc
	injector              func(*tailer.Tailer, *readSeeker) func([]string)
	initialCollectedData  []string
	appendedCollectedData []string
	fromStart             bool
	err                   string
}{{
	description: "lines are longer than buffer size",
	data: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
		"0123456789012345678901234567890123456789012345678901\n",
	},
	initialLinesWritten:   1,
	initialLinesRequested: 1,
	bufferSize:            5,
	initialCollectedData: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
	},
	appendedCollectedData: []string{
		"0123456789012345678901234567890123456789012345678901\n",
	},
}, {
	description: "lines are longer than buffer size, missing termination of last line",
	data: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
		"0123456789012345678901234567890123456789012345678901\n",
		"the quick brown fox ",
	},
	initialLinesWritten:   1,
	initialLinesRequested: 1,
	bufferSize:            5,
	initialCollectedData: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
	},
	appendedCollectedData: []string{
		"0123456789012345678901234567890123456789012345678901\n",
	},
}, {
	description: "lines are longer than buffer size, last line is terminated later",
	data: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
		"0123456789012345678901234567890123456789012345678901\n",
		"the quick brown fox ",
		"jumps over the lazy dog\n",
	},
	initialLinesWritten:   1,
	initialLinesRequested: 1,
	bufferSize:            5,
	initialCollectedData: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
	},
	appendedCollectedData: []string{
		"0123456789012345678901234567890123456789012345678901\n",
		"the quick brown fox jumps over the lazy dog\n",
	},
}, {
	description: "missing termination of last line",
	data: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
		"0123456789012345678901234567890123456789012345678901\n",
		"the quick brown fox ",
	},
	initialLinesWritten:   1,
	initialLinesRequested: 1,
	initialCollectedData: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
	},
	appendedCollectedData: []string{
		"0123456789012345678901234567890123456789012345678901\n",
	},
}, {
	description: "last line is terminated later",
	data: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
		"0123456789012345678901234567890123456789012345678901\n",
		"the quick brown fox ",
		"jumps over the lazy dog\n",
	},
	initialLinesWritten:   1,
	initialLinesRequested: 1,
	initialCollectedData: []string{
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
	},
	appendedCollectedData: []string{
		"0123456789012345678901234567890123456789012345678901\n",
		"the quick brown fox jumps over the lazy dog\n",
	},
}, {
	description:           "more lines already written than initially requested",
	data:                  alphabetData,
	initialLinesWritten:   5,
	initialLinesRequested: 3,
	initialCollectedData: []string{
		"charlie charlie\n",
		"delta delta\n",
		"echo echo\n",
	},
	appendedCollectedData: alphabetData[5:],
}, {
	description:           "less lines already written than initially requested",
	data:                  alphabetData,
	initialLinesWritten:   3,
	initialLinesRequested: 5,
	initialCollectedData: []string{
		"alpha alpha\n",
		"bravo bravo\n",
		"charlie charlie\n",
	},
	appendedCollectedData: alphabetData[3:],
}, {
	description:           "lines are longer than buffer size, more lines already written than initially requested",
	data:                  alphabetData,
	initialLinesWritten:   5,
	initialLinesRequested: 3,
	bufferSize:            5,
	initialCollectedData: []string{
		"charlie charlie\n",
		"delta delta\n",
		"echo echo\n",
	},
	appendedCollectedData: alphabetData[5:],
}, {
	description:           "ignore current lines",
	data:                  alphabetData,
	initialLinesWritten:   5,
	bufferSize:            5,
	appendedCollectedData: alphabetData[5:],
}, {
	description:           "start from the start",
	data:                  alphabetData,
	initialLinesWritten:   5,
	bufferSize:            5,
	appendedCollectedData: alphabetData,
	fromStart:             true,
}, {
	description:           "lines are longer than buffer size, less lines already written than initially requested",
	data:                  alphabetData,
	initialLinesWritten:   3,
	initialLinesRequested: 5,
	bufferSize:            5,
	initialCollectedData: []string{
		"alpha alpha\n",
		"bravo bravo\n",
		"charlie charlie\n",
	},
	appendedCollectedData: alphabetData[3:],
}, {
	description:           "filter lines which contain the char 'e'",
	data:                  alphabetData,
	initialLinesWritten:   10,
	initialLinesRequested: 3,
	filter: func(line []byte) bool {
		return bytes.Contains(line, []byte{'e'})
	},
	initialCollectedData: []string{
		"echo echo\n",
		"hotel hotel\n",
		"juliet juliet\n",
	},
	appendedCollectedData: []string{
		"mike mike\n",
		"november november\n",
		"quebec quebec\n",
		"romeo romeo\n",
		"sierra sierra\n",
		"whiskey whiskey\n",
		"yankee yankee\n",
	},
}, {
	description:           "stop tailing after 10 collected lines",
	data:                  alphabetData,
	initialLinesWritten:   5,
	initialLinesRequested: 3,
	injector: func(t *tailer.Tailer, rs *readSeeker) func([]string) {
		return func(lines []string) {
			if len(lines) == 10 {
				t.Stop()
			}
		}
	},
	initialCollectedData: []string{
		"charlie charlie\n",
		"delta delta\n",
		"echo echo\n",
	},
	appendedCollectedData: alphabetData[5:],
}, {
	description:           "generate an error after 10 collected lines",
	data:                  alphabetData,
	initialLinesWritten:   5,
	initialLinesRequested: 3,
	injector: func(t *tailer.Tailer, rs *readSeeker) func([]string) {
		return func(lines []string) {
			if len(lines) == 10 {
				rs.setError(fmt.Errorf("ouch after 10 lines"))
			}
		}
	},
	initialCollectedData: []string{
		"charlie charlie\n",
		"delta delta\n",
		"echo echo\n",
	},
	appendedCollectedData: alphabetData[5:],
	err:                   "ouch after 10 lines",
}, {
	description: "more lines already written than initially requested, some empty, unfiltered",
	data: []string{
		"one one\n",
		"two two\n",
		"\n",
		"\n",
		"three three\n",
		"four four\n",
		"\n",
		"\n",
		"five five\n",
		"six six\n",
	},
	initialLinesWritten:   3,
	initialLinesRequested: 2,
	initialCollectedData: []string{
		"two two\n",
		"\n",
	},
	appendedCollectedData: []string{
		"\n",
		"three three\n",
		"four four\n",
		"\n",
		"\n",
		"five five\n",
		"six six\n",
	},
}, {
	description: "more lines already written than initially requested, some empty, those filtered",
	data: []string{
		"one one\n",
		"two two\n",
		"\n",
		"\n",
		"three three\n",
		"four four\n",
		"\n",
		"\n",
		"five five\n",
		"six six\n",
	},
	initialLinesWritten:   3,
	initialLinesRequested: 2,
	filter: func(line []byte) bool {
		return len(bytes.TrimSpace(line)) > 0
	},
	initialCollectedData: []string{
		"one one\n",
		"two two\n",
	},
	appendedCollectedData: []string{
		"three three\n",
		"four four\n",
		"five five\n",
		"six six\n",
	},
}}

func (s *tailerSuite) TestTailer(c *gc.C) {
	for i, test := range tests {
		c.Logf("Test #%d) %s", i, test.description)
		bufferSize := test.bufferSize
		if bufferSize == 0 {
			// Default value.
			bufferSize = 4096
		}
		s.PatchValue(tailer.BufferSize, bufferSize)
		reader, writer := io.Pipe()
		sigc := make(chan struct{}, 1)
		rs := startReadSeeker(c, test.data, test.initialLinesWritten, sigc)
		if !test.fromStart {
			err := tailer.SeekLastLines(rs, test.initialLinesRequested, test.filter)
			c.Assert(err, gc.IsNil)
		}
		// Named tl rather than tailer so the package name is not shadowed.
		tl := tailer.NewTestTailer(rs, writer, test.filter, 2*time.Millisecond)
		linec := startReading(c, tl, reader, writer)

		// Collect initial data.
		assertCollected(c, linec, test.initialCollectedData, nil)

		sigc <- struct{}{}

		// Collect remaining data, possibly with injection to stop
		// earlier or generate an error.
		var injection func([]string)
		if test.injector != nil {
			injection = test.injector(tl, rs)
		}

		assertCollected(c, linec, test.appendedCollectedData, injection)

		if test.err == "" {
			c.Assert(tl.Stop(), gc.IsNil)
		} else {
			c.Assert(tl.Err(), gc.ErrorMatches, test.err)
		}
	}
}

// startReading starts a goroutine that reads lines from reader in the
// background and sends them on a newly created string channel, which the
// assertions consume. The channel is closed when reading hits EOF or fails.
func startReading(c *gc.C, tl *tailer.Tailer, reader *io.PipeReader, writer *io.PipeWriter) chan string {
	linec := make(chan string)
	// Start goroutine for reading.
	go func() {
		defer close(linec)
		reader := bufio.NewReader(reader)
		for {
			line, err := reader.ReadString('\n')
			switch err {
			case nil:
				linec <- line
			case io.EOF:
				return
			default:
				// A failed pipe keeps returning its error, so looping
				// again would spin forever; fail and stop instead.
				c.Fail()
				return
			}
		}
	}()
	// Close writer when tailer is stopped or has an error. Tailer using
	// components can do it the same way.
	go func() {
		tl.Wait()
		writer.Close()
	}()
	return linec
}

// assertCollected reads lines from the string channel linec and compares
// them with the ones passed in compare, until a timeout. If the timeout is
// reached before all lines are collected, the assertion fails. The
// injection function allows the processing to be interrupted, either by
// generating an error or by stopping regularly during the tailing. If
// linec is closed because of stopping or an error, only the values
// collected so far are compared. Checking the reason for termination is
// done in the test.
func assertCollected(c *gc.C, linec chan string, compare []string, injection func([]string)) { if len(compare) == 0 { return } timeout := time.After(10 * time.Second) lines := []string{} for { select { case line, ok := <-linec: if ok { lines = append(lines, line) if injection != nil { injection(lines) } if len(lines) == len(compare) { // All data received. c.Assert(lines, gc.DeepEquals, compare) return } } else { // linec closed after stopping or error. c.Assert(lines, gc.DeepEquals, compare[:len(lines)]) return } case <-timeout: if injection == nil { c.Fatalf("timeout during tailer collection") } return } } } // startReadSeeker returns a ReadSeeker for the Tailer. It simulates // reading and seeking inside a file and also simulating an error. // The goroutine waits for a signal that it can start writing the // appended lines. func startReadSeeker(c *gc.C, data []string, initialLeg int, sigc chan struct{}) *readSeeker { // Write initial lines into the buffer. var rs readSeeker var i int for i = 0; i < initialLeg; i++ { rs.write(data[i]) } go func() { <-sigc for ; i < len(data); i++ { time.Sleep(5 * time.Millisecond) rs.write(data[i]) } }() return &rs } type readSeeker struct { mux sync.Mutex buffer []byte pos int err error } func (r *readSeeker) write(s string) { r.mux.Lock() defer r.mux.Unlock() r.buffer = append(r.buffer, []byte(s)...) 
} func (r *readSeeker) setError(err error) { r.mux.Lock() defer r.mux.Unlock() r.err = err } func (r *readSeeker) Read(p []byte) (n int, err error) { r.mux.Lock() defer r.mux.Unlock() if r.err != nil { return 0, r.err } if r.pos >= len(r.buffer) { return 0, io.EOF } n = copy(p, r.buffer[r.pos:]) r.pos += n return n, nil } func (r *readSeeker) Seek(offset int64, whence int) (ret int64, err error) { r.mux.Lock() defer r.mux.Unlock() var newPos int64 switch whence { case 0: newPos = offset case 1: newPos = int64(r.pos) + offset case 2: newPos = int64(len(r.buffer)) + offset default: return 0, fmt.Errorf("invalid whence: %d", whence) } if newPos < 0 { return 0, fmt.Errorf("negative position: %d", newPos) } if newPos >= 1<<31 { return 0, fmt.Errorf("position out of range: %d", newPos) } r.pos = int(newPos) return newPos, nil } ================================================ FILE: tar/tar.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. // This package provides convenience helpers on top of archive/tar // to be able to tar/untar files with a functionality closer // to gnu tar command. package tar import ( "archive/tar" "crypto/sha1" "encoding/base64" "fmt" "io" "os" "path/filepath" "strings" "github.com/juju/errors" "github.com/juju/collections/set" "github.com/juju/utils/v4/symlink" ) // FindFile returns the header and ReadCloser for the entry in the // tarfile that matches the filename. If nothing matches, an // errors.NotFound error is returned. func FindFile(tarFile io.Reader, filename string) (*tar.Header, io.Reader, error) { reader := tar.NewReader(tarFile) for { header, err := reader.Next() if err == io.EOF { break } if err != nil { return nil, nil, errors.Trace(err) } if header.Name == filename { return header, reader, nil } } return nil, nil, errors.NotFoundf(filename) } // TarFiles writes a tar stream into target holding the files listed // in fileList. 
strip will be removed from the beginning of all the paths // when stored (much like gnu tar -C option) // Returns a Sha sum of the tar and nil if everything went well // or empty sting and error in case of error. // We use a base64 encoded sha1 hash, because this is the hash // used by RFC 3230 Digest headers in http responses // It is not safe to mutate files passed during this function, // however at least the bytes up to the inital size are written // successfully if no error is returned. func TarFiles(fileList []string, target io.Writer, strip string) (shaSum string, err error) { shahash := sha1.New() if err := tarAndHashFiles(fileList, target, strip, shahash); err != nil { return "", err } encodedHash := base64.StdEncoding.EncodeToString(shahash.Sum(nil)) return encodedHash, nil } func tarAndHashFiles(fileList []string, target io.Writer, strip string, hashw io.Writer) (err error) { checkClose := func(w io.Closer) { if closeErr := w.Close(); closeErr != nil && err == nil { err = fmt.Errorf("error closing tar writer: %v", closeErr) } } w := io.MultiWriter(target, hashw) tarw := tar.NewWriter(w) defer checkClose(tarw) for _, ent := range fileList { if err := writeContents(ent, strip, tarw); err != nil { return fmt.Errorf("write to tar file failed: %v", err) } } return nil } // writeContents creates an entry for the given file // or directory in the given tar archive. 
func writeContents(fileName, strip string, tarw *tar.Writer) error { f, err := os.Open(fileName) if err != nil { return err } defer f.Close() fInfo, err := os.Lstat(fileName) if err != nil { return err } link := "" if fInfo.Mode()&os.ModeSymlink == os.ModeSymlink { link, err = filepath.EvalSymlinks(fileName) if err != nil { return fmt.Errorf("cannnot dereference symlink: %v", err) } } h, err := tar.FileInfoHeader(fInfo, link) if err != nil { return fmt.Errorf("cannot create tar header for %q: %v", fileName, err) } h.Name = filepath.ToSlash(strings.TrimPrefix(fileName, strip)) if err := tarw.WriteHeader(h); err != nil { return fmt.Errorf("cannot write header for %q: %v", fileName, err) } if fInfo.Mode()&os.ModeSymlink == os.ModeSymlink { return nil } if !fInfo.IsDir() { // Limit data copied to inital stat size included in tar header // or ErrWriteTooLong is raised by archive/tar Writer. if _, err := io.CopyN(tarw, f, fInfo.Size()); err != nil { return fmt.Errorf("failed to write %q: %v", fileName, err) } return nil } for { names, err := f.Readdirnames(100) // will return at most 100 names and if less than 100 remaining // next call will return io.EOF and no names if err == io.EOF { return nil } if err != nil { return fmt.Errorf("error reading directory %q: %v", fileName, err) } for _, name := range names { if err := writeContents(filepath.Join(fileName, name), strip, tarw); err != nil { return err } } } } func createAndFill(filePath string, mode int64, content io.Reader) error { fh, err := os.Create(filePath) defer fh.Close() if err != nil { return fmt.Errorf("some of the tar contents cannot be written to disk: %v", err) } _, err = io.Copy(fh, content) if err != nil { return fmt.Errorf("failed while reading tar contents: %v", err) } err = os.Chmod(fh.Name(), os.FileMode(mode)) if err != nil { return fmt.Errorf("cannot set proper mode on file %q: %v", filePath, err) } if err := fh.Sync(); err != nil { return fmt.Errorf("failed to sync contents of file %v: %v", 
filePath, err) } if err := fh.Close(); err != nil { return fmt.Errorf("failed to close file %v: %v", filePath, err) } return nil } // UntarFiles will extract the contents of tarFile using // outputFolder as root func UntarFiles(tarFile io.Reader, outputFolder string) error { tr := tar.NewReader(tarFile) // Ensure we still make directories for any files where we haven't // already seen the directory (for example, juju backup generates // files like this). seenDirs := set.NewStrings() maybeMkParentDir := func(path string) error { dirName := filepath.Dir(path) if seenDirs.Contains(dirName) { return nil } err := os.MkdirAll(dirName, os.FileMode(0755)) if err != nil { return fmt.Errorf("cannot create parent directory for %q: %v", path, err) } seenDirs.Add(dirName) return nil } for { hdr, err := tr.Next() if err == io.EOF { // end of tar archive return nil } if err != nil { return fmt.Errorf("failed while reading tar header: %v", err) } fullPath := filepath.Join(outputFolder, hdr.Name) switch hdr.Typeflag { case tar.TypeDir: if err = os.MkdirAll(fullPath, os.FileMode(hdr.Mode)); err != nil { return fmt.Errorf("cannot extract directory %q: %v", fullPath, err) } seenDirs.Add(fullPath) case tar.TypeSymlink: if err = maybeMkParentDir(fullPath); err != nil { return err } if err = symlink.New(hdr.Linkname, fullPath); err != nil { return fmt.Errorf("cannot extract symlink %q to %q: %v", hdr.Linkname, fullPath, err) } continue case tar.TypeReg, tar.TypeRegA: if err = maybeMkParentDir(fullPath); err != nil { return err } if err = createAndFill(fullPath, hdr.Mode, tr); err != nil { return fmt.Errorf("cannot extract file %q: %v", fullPath, err) } } } } ================================================ FILE: tar/tar_test.go ================================================ // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package tar

import (
	"archive/tar"
	"bytes"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"strings"
	stdtesting "testing"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"
)

func TestPackage(t *stdtesting.T) {
	gc.TestingT(t)
}

var _ = gc.Suite(&TarSuite{})

// TarSuite exercises TarFiles, UntarFiles and FindFile against fixtures
// created in a per-test temporary directory (cwd).
type TarSuite struct {
	testing.IsolationSuite
	cwd       string
	testFiles []string
}

func (t *TarSuite) SetUpTest(c *gc.C) {
	t.cwd = c.MkDir()
	t.IsolationSuite.SetUpTest(c)
}

// createTestFiles builds the fixture tree listed in testExpectedTarContents
// under t.cwd and records its top-level entries in t.testFiles.
func (t *TarSuite) createTestFiles(c *gc.C) {
	// writeFile creates path holding contents and checks every error.
	writeFile := func(path, contents string) {
		err := ioutil.WriteFile(path, []byte(contents), 0666)
		c.Check(err, gc.IsNil)
	}

	tarDirE := filepath.Join(t.cwd, "TarDirectoryEmpty")
	err := os.Mkdir(tarDirE, os.FileMode(0755))
	c.Check(err, gc.IsNil)

	tarDirP := filepath.Join(t.cwd, "TarDirectoryPopulated")
	err = os.Mkdir(tarDirP, os.FileMode(0755))
	c.Check(err, gc.IsNil)

	tarlink1 := filepath.Join(t.cwd, "TarLink")
	err = os.Symlink(tarDirP, tarlink1)
	c.Check(err, gc.IsNil)

	tarSubFile1 := filepath.Join(tarDirP, "TarSubFile1")
	writeFile(tarSubFile1, "TarSubFile1")

	tarSublink1 := filepath.Join(tarDirP, "TarSubLink")
	err = os.Symlink(tarSubFile1, tarSublink1)
	c.Check(err, gc.IsNil)

	tarSubDir := filepath.Join(tarDirP, "TarDirectoryPopulatedSubDirectory")
	err = os.Mkdir(tarSubDir, os.FileMode(0755))
	c.Check(err, gc.IsNil)

	tarFile1 := filepath.Join(t.cwd, "TarFile1")
	writeFile(tarFile1, "TarFile1")

	tarFile2 := filepath.Join(t.cwd, "TarFile2")
	writeFile(tarFile2, "TarFile2")

	t.testFiles = []string{tarDirE, tarDirP, tarlink1, tarFile1, tarFile2}
}

func (t *TarSuite) removeTestFiles(c *gc.C) {
	for _, removable := range t.testFiles {
		err := os.RemoveAll(removable)
		c.Assert(err, gc.IsNil)
	}
}

type expectedTarContents struct {
	Name string
	Body string
}

var testExpectedTarContents = []expectedTarContents{
	{"TarDirectoryEmpty", ""},
	{"TarDirectoryPopulated", ""},
	{"TarLink", ""},
	{"TarDirectoryPopulated/TarSubFile1", "TarSubFile1"},
	{"TarDirectoryPopulated/TarSubLink", ""},
	{"TarDirectoryPopulated/TarDirectoryPopulatedSubDirectory", ""},
	{"TarFile1", "TarFile1"},
	{"TarFile2", "TarFile2"},
}

// assertTarContents checks that the tar stream in tarFile contains the
// expected files.
// expectedContents: the entries (names relative to the archive root) that
// are expected to be in the tar file; a non-empty Body is also compared
// with the entry's contents.
// tarFile: the tar stream to be checked.
func (t *TarSuite) assertTarContents(c *gc.C, expectedContents []expectedTarContents,
	tarFile io.Reader) {
	tr := tar.NewReader(tarFile)
	tarContents := make(map[string]string)
	// Iterate through the files in the archive.
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			// end of tar archive
			break
		}
		c.Assert(err, gc.IsNil)
		buf, err := ioutil.ReadAll(tr)
		c.Assert(err, gc.IsNil)
		tarContents[hdr.Name] = string(buf)
	}
	for _, expectedContent := range expectedContents {
		fullExpectedContent := strings.TrimPrefix(expectedContent.Name, string(os.PathSeparator))
		body, ok := tarContents[fullExpectedContent]
		c.Log(tarContents)
		c.Log(expectedContents)
		c.Logf("checking for presence of %q on tar file", fullExpectedContent)
		c.Assert(ok, gc.Equals, true)
		if expectedContent.Body != "" {
			c.Log("Also checking the file contents")
			c.Assert(body, gc.Equals, expectedContent.Body)
		}
	}
}

// assertFilesWhereUntared checks that every expected entry exists under
// tarOutputFolder and, when a Body is given, that the file holds it.
func (t *TarSuite) assertFilesWhereUntared(c *gc.C,
	expectedContents []expectedTarContents,
	tarOutputFolder string) {
	tarContents := make(map[string]string)
	walkFn := func(path string, finfo os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		fileName := strings.TrimPrefix(path, tarOutputFolder)
		fileName = strings.TrimPrefix(fileName, string(os.PathSeparator))
		c.Log(fileName)
		if fileName == "" {
			return nil
		}
		if finfo.IsDir() || finfo.Mode()&os.ModeSymlink == os.ModeSymlink {
			tarContents[fileName] = ""
			return nil
		}
		readable, err := os.Open(path)
		if err != nil {
			return err
		}
		defer readable.Close()
		buf, err := ioutil.ReadAll(readable)
		c.Assert(err, gc.IsNil)
		tarContents[fileName] = string(buf)
		return nil
	}
	err := filepath.Walk(tarOutputFolder, walkFn)
	c.Assert(err, jc.ErrorIsNil)
	for _, expectedContent := range expectedContents {
		fullExpectedContent := strings.TrimPrefix(expectedContent.Name, string(os.PathSeparator))
		expectedPath := filepath.Join(tarOutputFolder, fullExpectedContent)
		_, err = os.Lstat(expectedPath)
		c.Assert(err, gc.IsNil)
		body, ok := tarContents[fullExpectedContent]
		c.Logf("checking for presence of %q on untar files", fullExpectedContent)
		c.Assert(ok, gc.Equals, true)
		if expectedContent.Body != "" {
			c.Log("Also checking the file contents")
			c.Assert(body, gc.Equals, expectedContent.Body)
		}
	}
}

// shaSumFile returns the base64-encoded SHA-1 sum of everything read from
// fileToSum, matching the format returned by TarFiles.
func shaSumFile(c *gc.C, fileToSum io.Reader) string {
	shahash := sha1.New()
	_, err := io.Copy(shahash, fileToSum)
	c.Assert(err, gc.IsNil)
	return base64.StdEncoding.EncodeToString(shahash.Sum(nil))
}

// Tar

func (t *TarSuite) TestTarFiles(c *gc.C) {
	t.createTestFiles(c)
	var outputTar bytes.Buffer
	trimPath := fmt.Sprintf("%s/", t.cwd)
	shaSum, err := TarFiles(t.testFiles, &outputTar, trimPath)
	c.Check(err, gc.IsNil)
	outputBytes := outputTar.Bytes()
	fileShaSum := shaSumFile(c, bytes.NewBuffer(outputBytes))
	c.Assert(shaSum, gc.Equals, fileShaSum)
	t.removeTestFiles(c)
	t.assertTarContents(c, testExpectedTarContents, bytes.NewBuffer(outputBytes))
}

func (t *TarSuite) TestSymlinksTar(c *gc.C) {
	tarDirP := filepath.Join(t.cwd, "TarDirectory")
	err := os.Mkdir(tarDirP, os.FileMode(0755))
	c.Check(err, gc.IsNil)

	tarlink1 := filepath.Join(t.cwd, "TarLink")
	err = os.Symlink(tarDirP, tarlink1)
	c.Check(err, gc.IsNil)
	testFiles := []string{tarDirP, tarlink1}

	var outputTar bytes.Buffer
	trimPath := fmt.Sprintf("%s/", t.cwd)
	_, err = TarFiles(testFiles, &outputTar, trimPath)
	c.Check(err, gc.IsNil)

	outputBytes := outputTar.Bytes()
	tr := tar.NewReader(bytes.NewBuffer(outputBytes))
	symlinks := 0
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			// end of tar archive
			break
		}
		c.Assert(err, gc.IsNil)
		if hdr.Typeflag == tar.TypeSymlink {
			symlinks++
			c.Assert(hdr.Linkname, gc.Equals, tarDirP)
		}
	}
	c.Assert(symlinks, gc.Equals, 1)
}

// UnTar

func (t *TarSuite) TestUnTarFilesUncompressed(c *gc.C) {
	t.createTestFiles(c)
	var outputTar bytes.Buffer
	trimPath := fmt.Sprintf("%s/", t.cwd)
	_, err := TarFiles(t.testFiles, &outputTar, trimPath)
	c.Check(err, gc.IsNil)

	t.removeTestFiles(c)

	outputDir := filepath.Join(t.cwd, "TarOuputFolder")
	err = os.Mkdir(outputDir, os.FileMode(0755))
	c.Check(err, gc.IsNil)

	err = UntarFiles(&outputTar, outputDir)
	c.Assert(err, jc.ErrorIsNil)
	t.assertFilesWhereUntared(c, testExpectedTarContents, outputDir)
}

func (t *TarSuite) TestFindFileFound(c *gc.C) {
	t.createTestFiles(c)
	var outputTar bytes.Buffer
	trimPath := fmt.Sprintf("%s/", t.cwd)
	_, err := TarFiles(t.testFiles, &outputTar, trimPath)
	c.Assert(err, gc.IsNil)
	t.removeTestFiles(c)

	_, file, err := FindFile(&outputTar, "TarDirectoryPopulated/TarSubFile1")
	c.Assert(err, gc.IsNil)

	data, err := ioutil.ReadAll(file)
	c.Assert(err, gc.IsNil)
	c.Check(string(data), gc.Equals, "TarSubFile1")
}

func (t *TarSuite) TestFindFileNotFound(c *gc.C) {
	t.createTestFiles(c)
	var outputTar bytes.Buffer
	trimPath := fmt.Sprintf("%s/", t.cwd)
	_, err := TarFiles(t.testFiles, &outputTar, trimPath)
	c.Assert(err, gc.IsNil)
	t.removeTestFiles(c)

	_, _, err = FindFile(&outputTar, "does_not_exist")
	c.Check(err, gc.ErrorMatches, "does_not_exist not found")
}

func (t *TarSuite) TestUntarFilesHeadersIgnored(c *gc.C) {
	var buf bytes.Buffer
	w := tar.NewWriter(&buf)
	err := w.WriteHeader(&tar.Header{
		Name:     "pax_global_header",
		Typeflag: tar.TypeXGlobalHeader,
	})
	c.Assert(err, gc.IsNil)
	err = w.Flush()
	c.Assert(err, gc.IsNil)

	err = UntarFiles(&buf, t.cwd)
	c.Assert(err, jc.ErrorIsNil)
	err = filepath.Walk(t.cwd, func(path string, finfo os.FileInfo, err error) error {
		if path != t.cwd {
			return fmt.Errorf("unexpected file: %v", path)
		}
		return err
	})
	c.Assert(err, gc.IsNil)
}
func (t *TarSuite) TestUntarFilesWithMissingDirectories(c *gc.C) { var buf bytes.Buffer w := tar.NewWriter(&buf) contents := []byte("file contents") err := w.WriteHeader(&tar.Header{ Name: "missingdir/otherdir/file", Typeflag: tar.TypeReg, Mode: 0700, Size: int64(len(contents)), }) c.Assert(err, jc.ErrorIsNil) _, err = w.Write(contents) c.Assert(err, jc.ErrorIsNil) err = w.WriteHeader(&tar.Header{ Name: "missingdir/otherdir/link", Typeflag: tar.TypeSymlink, Linkname: "viginti", }) c.Assert(err, jc.ErrorIsNil) err = w.Flush() c.Assert(err, jc.ErrorIsNil) err = UntarFiles(&buf, t.cwd) c.Assert(err, jc.ErrorIsNil) var names []string err = filepath.Walk(t.cwd, func(path string, finfo os.FileInfo, err error) error { names = append(names, path[len(t.cwd):]) return nil }) c.Assert(err, jc.ErrorIsNil) expected := []string{ "", "/missingdir", "/missingdir/otherdir", "/missingdir/otherdir/file", "/missingdir/otherdir/link", } c.Assert(names, gc.DeepEquals, expected) } ================================================ FILE: timer.go ================================================ // Copyright 2015 Canonical Ltd. // Copyright 2015 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "math/rand" "time" "github.com/juju/clock" ) // Countdown implements a timer that will call a provided function. // after a internally stored duration. The steps as well as min and max // durations are declared upon initialization and depend on // the particular implementation. // // TODO(katco): 2016-08-09: This type is deprecated: lp:1611427 type Countdown interface { // Reset stops the timer and resets its duration to the minimum one. // Start must be called to start the timer again. Reset() // Start starts the internal timer. // At the end of the timer, if Reset hasn't been called in the mean time // Func will be called and the duration is increased for the next call. 
Start() } // NewBackoffTimer creates and initializes a new BackoffTimer // A backoff timer starts at min and gets multiplied by factor // until it reaches max. Jitter determines whether a small // randomization is added to the duration. // // TODO(katco): 2016-08-09: This type is deprecated: lp:1611427 func NewBackoffTimer(config BackoffTimerConfig) *BackoffTimer { return &BackoffTimer{ config: config, currentDuration: config.Min, } } // BackoffTimer implements Countdown. // A backoff timer starts at min and gets multiplied by factor // until it reaches max. Jitter determines whether a small // randomization is added to the duration. // // TODO(katco): 2016-08-09: This type is deprecated: lp:1611427 type BackoffTimer struct { config BackoffTimerConfig timer clock.Timer currentDuration time.Duration } // BackoffTimerConfig is a helper struct for backoff timer // that encapsulates config information. // // TODO(katco): 2016-08-09: This type is deprecated: lp:1611427 type BackoffTimerConfig struct { // The minimum duration after which Func is called. Min time.Duration // The maximum duration after which Func is called. Max time.Duration // Determines whether a small randomization is applied to // the duration. Jitter bool // The factor by which you want the duration to increase // every time. Factor int64 // Func is the function that will be called when the countdown reaches 0. Func func() // Clock provides the AfterFunc function used to call func. // It is exposed here so it's easier to mock it in tests. Clock clock.Clock } // Start implements the Timer interface. // Any existing timer execution is stopped before // a new one is created. func (t *BackoffTimer) Start() { if t.timer != nil { t.timer.Stop() } t.timer = t.config.Clock.AfterFunc(t.currentDuration, t.config.Func) // Since it's a backoff timer we will increase // the duration after each signal. t.increaseDuration() } // Reset implements the Timer interface. 
func (t *BackoffTimer) Reset() { if t.timer != nil { t.timer.Stop() } if t.currentDuration > t.config.Min { t.currentDuration = t.config.Min } } // increaseDuration will increase the duration based on // the current value and the factor. If jitter is true // it will add a 0.3% jitter to the final value. func (t *BackoffTimer) increaseDuration() { current := int64(t.currentDuration) nextDuration := time.Duration(current * t.config.Factor) if t.config.Jitter { // Get a factor in [-1; 1]. randFactor := (rand.Float64() * 2) - 1 jitter := float64(nextDuration) * randFactor * 0.03 nextDuration = nextDuration + time.Duration(jitter) } if nextDuration > t.config.Max { nextDuration = t.config.Max } t.currentDuration = nextDuration } ================================================ FILE: timer_test.go ================================================ // Copyright 2015 Canonical Ltd. // Copyright 2015 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. package utils_test import ( "math" "time" gc "gopkg.in/check.v1" "github.com/juju/clock" "github.com/juju/testing" jc "github.com/juju/testing/checkers" "github.com/juju/utils/v4" ) type TestStdTimer struct { stdStub *testing.Stub } func (t *TestStdTimer) Stop() bool { t.stdStub.AddCall("Stop") return true } func (t *TestStdTimer) Reset(d time.Duration) bool { t.stdStub.AddCall("Reset", d) return true } func (t *TestStdTimer) Chan() <-chan time.Time { panic("should not be called") } type timerSuite struct { baseSuite testing.CleanupSuite timer *utils.BackoffTimer afterFuncCalls int64 properFuncCalled bool stub *testing.Stub min time.Duration max time.Duration factor int64 } var _ = gc.Suite(&timerSuite{}) type mockClock struct { stub *testing.Stub c *gc.C afterFuncCalls *int64 properFuncCalled *bool } // These 2 methods are not used here but are needed to satisfy the intergface func (c *mockClock) Now() time.Time { return time.Now() } func (c *mockClock) After(d time.Duration) <-chan time.Time { 
return time.After(d) } func (c *mockClock) NewTimer(d time.Duration) clock.Timer { panic("should not be called") } func (c *mockClock) AfterFunc(d time.Duration, f func()) clock.Timer { *c.afterFuncCalls++ f() c.c.Assert(*c.properFuncCalled, jc.IsTrue) *c.properFuncCalled = false return &TestStdTimer{c.stub} } func (s *timerSuite) SetUpTest(c *gc.C) { s.baseSuite.SetUpTest(c) s.stub = nil s.timer = nil } func (s *timerSuite) setup(c *gc.C) { s.afterFuncCalls = 0 s.stub = &testing.Stub{} // This along with the checks in afterFuncMock below assert // that mockFunc is indeed passed as the argument to afterFuncMock // to be executed. mockFunc := func() { s.properFuncCalled = true } mockClock := &mockClock{ stub: s.stub, c: c, afterFuncCalls: &s.afterFuncCalls, properFuncCalled: &s.properFuncCalled, } s.min = 2 * time.Second s.max = 16 * time.Second s.factor = 2 s.timer = utils.NewBackoffTimer( utils.BackoffTimerConfig{ Min: s.min, Max: s.max, Jitter: false, Factor: s.factor, Func: mockFunc, Clock: mockClock, }, ) } func (s *timerSuite) TestStart(c *gc.C) { s.setup(c) s.timer.Start() s.testStart(c, 1, 1) } func (s *timerSuite) TestMultipleStarts(c *gc.C) { s.setup(c) s.timer.Start() s.testStart(c, 1, 1) s.timer.Start() s.checkStopCalls(c, 1) s.testStart(c, 2, 2) s.timer.Start() s.checkStopCalls(c, 2) s.testStart(c, 3, 3) } func (s *timerSuite) TestResetNoStart(c *gc.C) { s.setup(c) s.timer.Reset() currentDuration := utils.ExposeBackoffTimerDuration(s.timer) c.Assert(currentDuration, gc.Equals, s.min) } func (s *timerSuite) TestResetAndStart(c *gc.C) { s.setup(c) s.timer.Reset() currentDuration := utils.ExposeBackoffTimerDuration(s.timer) c.Assert(currentDuration, gc.Equals, s.min) // These variables are used to track the number // of afterFuncCalls(signalCallsNo) and the number // of Stop calls(resetStopCallsNo + signalCallsNo) resetStopCallsNo := 0 signalCallsNo := 0 signalCallsNo++ s.timer.Start() s.testStart(c, 1, 1) resetStopCallsNo++ s.timer.Reset() 
s.checkStopCalls(c, resetStopCallsNo+signalCallsNo-1) currentDuration = utils.ExposeBackoffTimerDuration(s.timer) c.Assert(currentDuration, gc.Equals, s.min) for i := 1; i < 200; i++ { signalCallsNo++ s.timer.Start() s.testStart(c, int64(signalCallsNo), int64(i)) s.checkStopCalls(c, resetStopCallsNo+signalCallsNo-1) } resetStopCallsNo++ s.timer.Reset() s.checkStopCalls(c, signalCallsNo+resetStopCallsNo-1) for i := 1; i < 100; i++ { signalCallsNo++ s.timer.Start() s.testStart(c, int64(signalCallsNo), int64(i)) s.checkStopCalls(c, resetStopCallsNo+signalCallsNo-1) } resetStopCallsNo++ s.timer.Reset() s.checkStopCalls(c, signalCallsNo+resetStopCallsNo-1) } func (s *timerSuite) testStart(c *gc.C, afterFuncCalls int64, durationFactor int64) { c.Assert(s.afterFuncCalls, gc.Equals, afterFuncCalls) c.Logf("iteration %d", afterFuncCalls) expectedDuration := time.Duration(math.Pow(float64(s.factor), float64(durationFactor))) * s.min if expectedDuration > s.max || expectedDuration <= 0 { expectedDuration = s.max } currentDuration := utils.ExposeBackoffTimerDuration(s.timer) c.Assert(currentDuration, gc.Equals, expectedDuration) } func (s *timerSuite) checkStopCalls(c *gc.C, number int) { calls := make([]testing.StubCall, number) for i := 0; i < number; i++ { calls[i] = testing.StubCall{FuncName: "Stop"} } s.stub.CheckCalls(c, calls) } ================================================ FILE: trivial.go ================================================ // Copyright 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "bytes" "compress/gzip" "crypto/sha256" "encoding/hex" "io" "os" "strings" "unicode" ) // TODO(ericsnow) Move the quoting helpers into the shell package? // ShQuote quotes s so that when read by bash, no metacharacters // within s will be interpreted as such. 
func ShQuote(s string) string {
	// Each single quote becomes '"'"': close the single-quoted string,
	// emit a double-quoted single quote, then reopen the single-quoted
	// string.
	return `'` + strings.ReplaceAll(s, `'`, `'"'"'`) + `'`
}

// WinPSQuote quotes s so that when read by powershell, no metacharacters
// within s will be interpreted as such.
func WinPSQuote(s string) string {
	// See http://ss64.com/ps/syntax-esc.html#quotes.
	// Double quotes inside single quotes are fine, double single quotes inside
	// single quotes, not so much so. Having double quoted strings inside single
	// quoted strings, ensure no expansion happens.
	return `'` + strings.ReplaceAll(s, `'`, `"`) + `'`
}

// WinCmdQuote quotes s so that when read by cmd.exe, no metacharacters
// within s will be interpreted as such.
func WinCmdQuote(s string) string {
	// See http://blogs.msdn.com/b/twistylittlepassagesallalike/archive/2011/04/23/everyone-quotes-arguments-the-wrong-way.aspx.
	quoted := winCmdQuote(s)
	return winCmdEscapeMeta(quoted)
}

// winCmdQuote wraps s in double quotes following the CommandLineToArgvW /
// Microsoft C runtime argument rules: backslashes are literal unless they
// precede a double quote. A run of n backslashes is therefore emitted as
// 2n+1 backslashes before an embedded quote, as 2n before the closing
// quote, and unchanged everywhere else.
func winCmdQuote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('"')
	pending := 0 // backslashes seen but not yet written
	for _, c := range s {
		switch c {
		case '\\':
			pending++
			continue
		case '"':
			b.WriteString(strings.Repeat(`\`, 2*pending+1))
		default:
			b.WriteString(strings.Repeat(`\`, pending))
		}
		pending = 0
		b.WriteRune(c)
	}
	b.WriteString(strings.Repeat(`\`, 2*pending))
	b.WriteByte('"')
	return b.String()
}

// winCmdEscapeMeta prefixes every cmd.exe metacharacter in str with the
// caret escape character.
func winCmdEscapeMeta(str string) string {
	const meta = `()%!^"<>&|`
	var b strings.Builder
	b.Grow(len(str))
	for _, c := range str {
		if strings.ContainsRune(meta, c) {
			b.WriteByte('^')
		}
		b.WriteRune(c)
	}
	return b.String()
}

// CommandString flattens a sequence of command arguments into a
// string suitable for executing in a shell, escaping backslashes,
// dollar signs and double quotes as necessary; each argument is
// double-quoted if and only if necessary.
func CommandString(args ...string) string { var buf bytes.Buffer for i, arg := range args { needsQuotes := false var argBuf bytes.Buffer for _, r := range arg { if unicode.IsSpace(r) { needsQuotes = true } else if r == '"' || r == '$' || r == '\\' { needsQuotes = true argBuf.WriteByte('\\') } argBuf.WriteRune(r) } if i > 0 { buf.WriteByte(' ') } if needsQuotes { buf.WriteByte('"') _, _ = argBuf.WriteTo(&buf) buf.WriteByte('"') } else { _, _ = argBuf.WriteTo(&buf) } } return buf.String() } // Gzip compresses the given data. func Gzip(data []byte) []byte { var buf bytes.Buffer w := gzip.NewWriter(&buf) if _, err := w.Write(data); err != nil { // Compression should never fail unless it fails // to write to the underlying writer, which is a bytes.Buffer // that never fails. panic(err) } if err := w.Close(); err != nil { panic(err) } return buf.Bytes() } // Gunzip uncompresses the given data. func Gunzip(data []byte) ([]byte, error) { r, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return nil, err } return io.ReadAll(r) } // ReadSHA256 returns the SHA256 hash of the contents read from source // (hex encoded) and the size of the source in bytes. func ReadSHA256(source io.Reader) (string, int64, error) { hash := sha256.New() size, err := io.Copy(hash, source) if err != nil { return "", 0, err } digest := hex.EncodeToString(hash.Sum(nil)) return digest, size, nil } // ReadFileSHA256 is like ReadSHA256 but reads the contents of the // given file. func ReadFileSHA256(filename string) (string, int64, error) { f, err := os.Open(filename) if err != nil { return "", 0, err } defer func() { _ = f.Close() }() return ReadSHA256(f) } ================================================ FILE: trivial_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/juju/testing"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

type utilsSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&utilsSuite{})

func (*utilsSuite) TestCompression(c *gc.C) {
	data := []byte(strings.Repeat("some data to be compressed\n", 100))
	// A pre-computed gzip stream of data, to check that Gunzip can read
	// output that was not produced by Gzip in this process.
	compressedData := []byte{
		0x1f, 0x8b, 0x08, 0x00, 0x33, 0xb5, 0xf6, 0x50,
		0x00, 0x03, 0xed, 0xc9, 0xb1, 0x0d, 0x00, 0x20,
		0x08, 0x45, 0xc1, 0xde, 0x29, 0x58, 0x0d, 0xe5,
		0x97, 0x04, 0x23, 0xee, 0x1f, 0xa7, 0xb0, 0x7b,
		0xd7, 0x5e, 0x57, 0xca, 0xc2, 0xaf, 0xdb, 0x2d,
		0x9b, 0xb2, 0x55, 0xb9, 0x8f, 0xba, 0x15, 0xa3,
		0x29, 0x8a, 0xa2, 0x28, 0x8a, 0xa2, 0x28, 0xea,
		0x67, 0x3d, 0x71, 0x71, 0x6e, 0xbf, 0x8c, 0x0a,
		0x00, 0x00,
	}

	cdata := utils.Gzip(data)
	c.Assert(len(cdata) < len(data), gc.Equals, true)
	data1, err := utils.Gunzip(cdata)
	c.Assert(err, gc.IsNil)
	c.Assert(data1, gc.DeepEquals, data)

	data1, err = utils.Gunzip(compressedData)
	c.Assert(err, gc.IsNil)
	c.Assert(data1, gc.DeepEquals, data)
}

// checkQuoting runs shQuote over every key of tests and checks the result
// against the corresponding value.
func checkQuoting(c *gc.C, shQuote func(string) string, tests map[string]string) {
	for str, expected := range tests {
		c.Logf("- checking %q -", str)
		quoted := shQuote(str)
		c.Check(quoted, gc.Equals, expected)
	}
}

func (*utilsSuite) TestWinCmdQuote(c *gc.C) {
	args := map[string]string{
		"":                 `^"^"`,
		"a":                `^"a^"`,
		"'a'":              `^"'a'^"`,
		`"a`:               `^"\^"a^"`,
		`a"`:               `^"a\^"^"`,
		`"a"`:              `^"\^"a\^"^"`,
		"abc > xyz 2>&1 &": `^"abc ^> xyz 2^>^&1 ^&^"`,
	}
	checkQuoting(c, utils.WinCmdQuote, args)
}

func (*utilsSuite) TestWinPSQuote(c *gc.C) {
	args := map[string]string{
		"":                 "''",
		"a":                `'a'`,
		`"a"`:              `'"a"'`,
		"'a":               `'"a'`,
		"a'":               `'a"'`,
		"'a'":              `'"a"'`,
		"abc > xyz 2>&1 &": "'abc > xyz 2>&1 &'",
	}
	checkQuoting(c, utils.WinPSQuote, args)
}

func (*utilsSuite) TestCommandString(c *gc.C) {
	type test struct {
		args     []string
		expected string
	}
	tests := []test{
		{nil, ""},
		{[]string{"a"}, "a"},
		{[]string{"a$"}, `"a\$"`},
		{[]string{""}, ""},
		{[]string{"\\"}, `"\\"`},
		{[]string{"a", "'b'"}, "a 'b'"},
		{[]string{"a b"}, `"a b"`},
		{[]string{"a", `"b"`}, `a "\"b\""`},
		{[]string{"a", `"b\"`}, `a "\"b\\\""`},
		{[]string{"a\n"}, "\"a\n\""},
	}
	for i, test := range tests {
		c.Logf("test %d: %q", i, test.args)
		result := utils.CommandString(test.args...)
		c.Assert(result, gc.Equals, test.expected)
	}
}

func (*utilsSuite) TestReadSHA256AndReadFileSHA256(c *gc.C) {
	sha256Tests := []struct {
		content string
		sha256  string
	}{{
		content: "",
		sha256:  "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
	}, {
		content: "some content",
		sha256:  "290f493c44f5d63d06b374d0a5abd292fae38b92cab2fae5efefe1b0e9347f56",
	}, {
		content: "foo",
		sha256:  "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae",
	}, {
		content: "Foo",
		sha256:  "1cbec737f863e4922cee63cc2ebbfaafcd1cff8b790d8cfd2e6a5d550b648afa",
	}, {
		content: "multi\nline\ntext\nhere",
		sha256:  "c384f11c0294280792a44d9d6abb81f9fd991904cb7eb851a88311b04114231e",
	}}

	tempDir := c.MkDir()
	for i, test := range sha256Tests {
		c.Logf("test %d: %q -> %q", i, test.content, test.sha256)
		buf := bytes.NewBufferString(test.content)
		hash, size, err := utils.ReadSHA256(buf)
		c.Check(err, gc.IsNil)
		c.Check(hash, gc.Equals, test.sha256)
		c.Check(int(size), gc.Equals, len(test.content))

		tempFileName := filepath.Join(tempDir, fmt.Sprintf("sha256-%d", i))
		err = os.WriteFile(tempFileName, []byte(test.content), 0644)
		c.Check(err, gc.IsNil)
		fileHash, fileSize, err := utils.ReadFileSHA256(tempFileName)
		c.Check(err, gc.IsNil)
		c.Check(fileHash, gc.Equals, hash)
		c.Check(fileSize, gc.Equals, size)
	}
}

================================================
FILE: uptime/uptime_nix.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// //go:build !windows // +build !windows package uptime import ( "syscall" ) // Uptime returns the number of seconds since the system has booted func Uptime() (int64, error) { info := &syscall.Sysinfo_t{} err := syscall.Sysinfo(info) if err != nil { return 0, err } return int64(info.Uptime), nil } ================================================ FILE: uptime/uptime_windows.go ================================================ // Copyright 2014 Cloudbase Solutions SRL // Copyright 2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package uptime import ( "fmt" ) //sys getTickCount64() (uptime uint64, err error) = GetTickCount64 // Uptime returns the number of seconds since the system has booted func Uptime() (int64, error) { uptime, err := getTickCount64() if err != nil { return 0, fmt.Errorf("Failed to get uptime. Error number: %v", err) } return int64(uptime) / 1000, nil } ================================================ FILE: uptime/zuptime_windows_386.go ================================================ // Copyright 2014 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. // mksyscall_windows.pl -l32 uptime_windows.go // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT package uptime import "syscall" var ( modkernel32 = syscall.NewLazyDLL("kernel32.dll") procGetTickCount64 = modkernel32.NewProc("GetTickCount64") ) func getTickCount64() (uptime uint64, err error) { r0, _, e1 := syscall.Syscall(procGetTickCount64.Addr(), 0, 0, 0, 0) uptime = uint64(r0) if uptime == 0 { if e1 != 0 { err = error(e1) } else { err = syscall.EINVAL } } return } ================================================ FILE: uptime/zuptime_windows_amd64.go ================================================ // Copyright 2014 Canonical Ltd. // Copyright 2014 Cloudbase Solutions SRL // Licensed under the LGPLv3, see LICENCE file for details. 
// mksyscall_windows.pl uptime_windows.go // MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT package uptime import "syscall" var ( modkernel32 = syscall.NewLazyDLL("kernel32.dll") procGetTickCount64 = modkernel32.NewProc("GetTickCount64") ) func getTickCount64() (uptime uint64, err error) { r0, _, e1 := syscall.Syscall(procGetTickCount64.Addr(), 0, 0, 0, 0) uptime = uint64(r0) if uptime == 0 { if e1 != 0 { err = error(e1) } else { err = syscall.EINVAL } } return } ================================================ FILE: username.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "os" "os/user" "github.com/juju/errors" ) // ResolveSudo returns the original username if sudo was used. The // original username is extracted from the OS environment. func ResolveSudo(username string) string { return resolveSudo(username, os.Getenv) } func resolveSudo(username string, getenvFunc func(string) string) string { if username != "root" { return username } // sudo was probably called, get the original user. if username := getenvFunc("SUDO_USER"); username != "" { return username } return username } // EnvUsername returns the username from the OS environment. func EnvUsername() (string, error) { return os.Getenv("USER"), nil } // OSUsername returns the username of the current OS user (based on UID). func OSUsername() (string, error) { u, err := user.Current() if err != nil { return "", errors.Trace(err) } return u.Username, nil } // ResolveUsername returns the username determined by the provided // functions. The functions are tried in the same order in which they // were passed in. An error returned from any of them is immediately // returned. If an empty string is returned then that signals that the // function did not find the username and the next function is tried. 
// Once a username is found, the provided resolveSudo func (if any) is // called with that username and the result is returned. If no username // is found then errors.NotFound is returned. func ResolveUsername(resolveSudo func(string) string, usernameFuncs ...func() (string, error)) (string, error) { for _, usernameFunc := range usernameFuncs { username, err := usernameFunc() if err != nil { return "", errors.Trace(err) } if username != "" { if resolveSudo != nil { if original := resolveSudo(username); original != "" { username = original } } return username, nil } } return "", errors.NotFoundf("username") } // LocalUsername determines the current username on the local host. func LocalUsername() (string, error) { username, err := ResolveUsername(ResolveSudo, EnvUsername, OSUsername) if err != nil { return "", errors.Annotatef(err, "cannot get current user from the environment: %v", os.Environ()) } return username, nil } ================================================ FILE: username_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test import ( "github.com/juju/errors" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" "github.com/juju/utils/v4" ) var _ = gc.Suite(&usernameSuite{}) type usernameSuite struct { testing.IsolationSuite } func (s *usernameSuite) TestResolveUsername(c *gc.C) { type test struct { userEnv string sudoEnv string userOS string expected string err string } tests := []test{{ userEnv: "someone", sudoEnv: "notroot", userOS: "other", expected: "someone", }, { userOS: "other", expected: "other", }, { userEnv: "root", expected: "root", }, { userEnv: "root", sudoEnv: "other", expected: "other", }, { err: "failed to determine username for namespace: oh noes", }} resolveUsername := func(t test) (string, error) { if t.err != "" { return "", errors.New(t.err) } var funcs []func() (string, error) if t.userEnv != "" { funcs = append(funcs, func() (string, error) { return t.userEnv, nil }) } if t.userOS != "" { funcs = append(funcs, func() (string, error) { return t.userOS, nil }) } resolveSudo := func(username string) string { return utils.ResolveSudoByFunc(username, func(string) string { return t.sudoEnv }) } return utils.ResolveUsername(resolveSudo, funcs...) } for i, test := range tests { c.Logf("test %d: %v", i, test) username, err := resolveUsername(test) if test.err == "" { if c.Check(err, jc.ErrorIsNil) { c.Check(username, gc.Equals, test.expected) } } else { c.Check(err, gc.ErrorMatches, test.err) } } } ================================================ FILE: uuid.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "crypto/rand" "encoding/hex" "fmt" "io" "regexp" "strings" ) // UUID represent a universal identifier with 16 octets. type UUID [16]byte // regex for validating that the UUID matches RFC 4122. // This package generates version 4 UUIDs but // accepts any UUID version. 
// http://www.ietf.org/rfc/rfc4122.txt var ( block1 = "[0-9a-f]{8}" block2 = "[0-9a-f]{4}" block3 = "[0-9a-f]{4}" block4 = "[0-9a-f]{4}" block5 = "[0-9a-f]{12}" UUIDSnippet = block1 + "-" + block2 + "-" + block3 + "-" + block4 + "-" + block5 validUUID = regexp.MustCompile("^" + UUIDSnippet + "$") ) func UUIDFromString(s string) (UUID, error) { if !IsValidUUIDString(s) { return UUID{}, fmt.Errorf("invalid UUID: %q", s) } s = strings.Replace(s, "-", "", 4) raw, err := hex.DecodeString(s) if err != nil { return UUID{}, err } var uuid UUID copy(uuid[:], raw) return uuid, nil } // IsValidUUIDString returns true, if the given string matches a valid UUID (version 4, variant 2). func IsValidUUIDString(s string) bool { return validUUID.MatchString(s) } // MustNewUUID returns a new uuid, if an error occurs it panics. func MustNewUUID() UUID { uuid, err := NewUUID() if err != nil { panic(err) } return uuid } // NewUUID generates a new version 4 UUID relying only on random numbers. func NewUUID() (UUID, error) { uuid := UUID{} if _, err := io.ReadFull(rand.Reader, []byte(uuid[0:16])); err != nil { return UUID{}, err } // Set version (4) and variant (2) according to RfC 4122. var version byte = 4 << 4 var variant byte = 8 << 4 uuid[6] = version | (uuid[6] & 15) uuid[8] = variant | (uuid[8] & 15) return uuid, nil } // Copy returns a copy of the UUID. func (uuid UUID) Copy() UUID { uuidCopy := uuid return uuidCopy } // Raw returns a copy of the UUID bytes. func (uuid UUID) Raw() [16]byte { return [16]byte(uuid) } // String returns a hexadecimal string representation with // standardized separators. func (uuid UUID) String() string { return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16]) } ================================================ FILE: uuid_test.go ================================================ // Copyright 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils_test

import (
	"strings"

	"github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"

	"github.com/juju/utils/v4"
)

type uuidSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&uuidSuite{})

func (*uuidSuite) TestUUID(c *gc.C) {
	uuid, err := utils.NewUUID()
	c.Assert(err, gc.IsNil)
	uuidCopy := uuid.Copy()
	uuidRaw := uuid.Raw()
	uuidStr := uuid.String()
	c.Assert(uuidRaw, gc.HasLen, 16)
	c.Assert(uuidStr, jc.Satisfies, utils.IsValidUUIDString)
	// NewUUID produces version 4 UUIDs with the RFC 4122 variant.
	c.Assert(uuidStr[14], gc.Equals, byte('4'))
	c.Assert(strings.ContainsRune("89ab", rune(uuidStr[19])), jc.IsTrue)

	// Copy and Raw must return independent copies.
	uuid[0] = 0x00
	uuidCopy[0] = 0xFF
	c.Assert(uuid, gc.Not(gc.DeepEquals), uuidCopy)
	uuidRaw[0] = 0xFF
	// Convert so both sides have the same type; comparing a UUID with a
	// [16]byte would never be DeepEqual and the check would prove nothing.
	c.Assert(uuid, gc.Not(gc.DeepEquals), utils.UUID(uuidRaw))

	nextUUID, err := utils.NewUUID()
	c.Assert(err, gc.IsNil)
	c.Assert(uuid, gc.Not(gc.DeepEquals), nextUUID)
}

func (*uuidSuite) TestIsValidUUIDFailsWhenNotValid(c *gc.C) {
	tests := []struct {
		input    string
		expected bool
	}{
		{
			utils.UUID{}.String(),
			true,
		},
		{
			"",
			false,
		},
		{
			"blah",
			false,
		},
		{
			"blah-9f484882-2f18-4fd2-967d-db9663db7bea",
			false,
		},
		{
			"9f484882-2f18-4fd2-967d-db9663db7bea-blah",
			false,
		},
		{
			"9f484882-2f18-4fd2-967d-db9663db7bea",
			true,
		},
	}
	for i, t := range tests {
		c.Logf("Running test %d", i)
		c.Check(utils.IsValidUUIDString(t.input), gc.Equals, t.expected)
	}
}

func (*uuidSuite) TestUUIDFromString(c *gc.C) {
	_, err := utils.UUIDFromString("blah")
	c.Assert(err, gc.ErrorMatches, `invalid UUID: "blah"`)
	validUUID := "9f484882-2f18-4fd2-967d-db9663db7bea"
	uuid, err := utils.UUIDFromString(validUUID)
	c.Assert(err, gc.IsNil)
	c.Assert(uuid.String(), gc.Equals, validUUID)
}

================================================
FILE: voyeur/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package voyeur

import (
	"testing"

	gc "gopkg.in/check.v1"
)

// TestPackage runs the gocheck suites registered in this package under
// "go test".
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}

================================================
FILE: voyeur/value.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

// Package voyeur implements a concurrency-safe value that can be watched for
// changes.
package voyeur

import (
	"sync"
)

// Value represents a shared value that can be watched for changes. Methods on
// a Value may be called concurrently. The zero Value is
// ok to use, and is equivalent to a NewValue result
// with a nil initial value.
type Value struct {
	// val is the current value, returned by Get and copied by Watcher.Next.
	val any
	// version is incremented by every Set (and once by NewValue for a
	// non-nil initial value); watchers compare it to detect changes.
	version int
	// mu guards every field of the Value and the closed flag of its Watchers.
	mu sync.RWMutex
	// wait is broadcast after Set and Close; once initialised its L is the
	// read side of mu.
	wait sync.Cond
	// closed is set by Close.
	closed bool
}

// NewValue creates a new Value holding the given initial value. If initial is
// nil, any watchers will wait until a value is set.
func NewValue(initial any) *Value {
	v := new(Value)
	v.init()
	if initial != nil {
		v.val = initial
		v.version++
	}
	return v
}

// needsInit reports whether wait has not yet been bound to mu, which is the
// case for a zero Value that has not been used.
func (v *Value) needsInit() bool {
	return v.wait.L == nil
}

// init binds wait to the read side of mu if that has not happened yet.
// Callers in this file invoke it while holding mu for writing, or (in
// NewValue) before the Value is shared.
func (v *Value) init() {
	if v.needsInit() {
		v.wait.L = v.mu.RLocker()
	}
}

// Set sets the shared value to val.
func (v *Value) Set(val any) {
	v.mu.Lock()
	v.init()
	v.val = val
	v.version++
	v.mu.Unlock()
	// Wake every watcher blocked in Next; done after releasing the lock.
	v.wait.Broadcast()
}

// Close closes the Value, unblocking any outstanding watchers. Close always
// returns nil.
func (v *Value) Close() error {
	v.mu.Lock()
	v.init()
	v.closed = true
	v.mu.Unlock()
	v.wait.Broadcast()
	return nil
}

// Closed reports whether the value has been closed.
func (v *Value) Closed() bool {
	v.mu.RLock()
	defer v.mu.RUnlock()
	return v.closed
}

// Get returns the current value.
func (v *Value) Get() any {
	v.mu.RLock()
	defer v.mu.RUnlock()
	return v.val
}

// Watch returns a Watcher that can be used to watch for changes to the value.
// The new Watcher starts at version 0, so its first Next reports the current
// value if one has been set.
func (v *Value) Watch() *Watcher {
	return &Watcher{value: v}
}

// Watcher represents a single watcher of a shared value.
type Watcher struct { value *Value version int current any closed bool } // Next blocks until there is a new value to be retrieved from the value that is // being watched. It also unblocks when the value or the Watcher itself is // closed. Next returns false if the value or the Watcher itself have been // closed. func (w *Watcher) Next() bool { val := w.value val.mu.RLock() defer val.mu.RUnlock() if val.needsInit() { val.mu.RUnlock() val.mu.Lock() val.init() val.mu.Unlock() val.mu.RLock() } // We can go around this loop a maximum of two times, // because the only thing that can cause a Wait to // return is for the condition to be triggered, // which can only happen if the value is set (causing // the version to increment) or it is closed // causing the closed flag to be set. // Both these cases will cause Next to return. for { if w.version != val.version { w.version = val.version w.current = val.val return true } if val.closed || w.closed { return false } // Wait releases the lock until triggered and then reacquires the lock, // thus avoiding a deadlock. val.wait.Wait() } } // Close closes the Watcher without closing the underlying // value. It may be called concurrently with Next. func (w *Watcher) Close() { w.value.mu.Lock() w.value.init() w.closed = true w.value.mu.Unlock() w.value.wait.Broadcast() } // Value returns the last value that was retrieved from the watched Value by // Next. func (w *Watcher) Value() any { return w.current } ================================================ FILE: voyeur/value_test.go ================================================ // Copyright 2012, 2013 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package voyeur import ( "fmt" "github.com/juju/testing" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" ) type suite struct { testing.IsolationSuite } var _ = gc.Suite(&suite{}) func ExampleWatcher_Next() { v := NewValue(nil) // The channel is not necessary for normal use of the watcher. 
// It just makes the test output predictable. ch := make(chan bool) go func() { for x := 0; x < 3; x++ { v.Set(fmt.Sprintf("value%d", x)) ch <- true } v.Close() }() w := v.Watch() for w.Next() { fmt.Println(w.Value()) <-ch } // output: // value0 // value1 // value2 } func (s *suite) TestValueGetSet(c *gc.C) { v := NewValue(nil) expected := "12345" v.Set(expected) got := v.Get() c.Assert(got, gc.Equals, expected) c.Assert(v.Closed(), jc.IsFalse) } func (s *suite) TestValueInitial(c *gc.C) { expected := "12345" v := NewValue(expected) got := v.Get() c.Assert(got, gc.Equals, expected) c.Assert(v.Closed(), jc.IsFalse) } func (s *suite) TestValueClose(c *gc.C) { expected := "12345" v := NewValue(expected) c.Assert(v.Close(), gc.IsNil) isClosed := v.Closed() c.Assert(isClosed, jc.IsTrue) got := v.Get() c.Assert(got, gc.Equals, expected) // test that we can close multiple times without a problem c.Assert(v.Close(), gc.IsNil) } func (s *suite) TestWatcher(c *gc.C) { vals := []string{"one", "two", "three"} // blocking on the channel forces the scheduler to let the other goroutine // run for a bit, so we get predictable results. This is not necessary for // normal use of the watcher. ch := make(chan bool) v := NewValue(nil) go func() { for _, s := range vals { v.Set(s) ch <- true } v.Close() }() w := v.Watch() c.Assert(w.Next(), jc.IsTrue) c.Assert(w.Value(), gc.Equals, vals[0]) // test that we can get the same value multiple times c.Assert(w.Value(), gc.Equals, vals[0]) <-ch // now try skipping a value by calling next without getting the value c.Assert(w.Next(), jc.IsTrue) <-ch c.Assert(w.Next(), jc.IsTrue) c.Assert(w.Value(), gc.Equals, vals[2]) <-ch c.Assert(w.Next(), jc.IsFalse) } func (s *suite) TestDoubleSet(c *gc.C) { vals := []string{"one", "two", "three"} // blocking on the channel forces the scheduler to let the other goroutine // run for a bit, so we get predictable results. This is not necessary for // normal use of the watcher. 
ch := make(chan bool) v := NewValue(nil) go func() { v.Set(vals[0]) ch <- true v.Set(vals[1]) v.Set(vals[2]) ch <- true v.Close() ch <- true }() w := v.Watch() c.Assert(w.Next(), jc.IsTrue) c.Assert(w.Value(), gc.Equals, vals[0]) <-ch // since we did two sets before sending on the channel, // we should just get vals[2] here and not get vals[1] c.Assert(w.Next(), jc.IsTrue) c.Assert(w.Value(), gc.Equals, vals[2]) } func (s *suite) TestTwoReceivers(c *gc.C) { vals := []string{"one", "two", "three"} // blocking on the channel forces the scheduler to let the other goroutine // run for a bit, so we get predictable results. This is not necessary for // normal use of the watcher. ch := make(chan bool) v := NewValue(nil) watcher := func() { w := v.Watch() x := 0 for w.Next() { c.Assert(w.Value(), gc.Equals, vals[x]) x++ <-ch } c.Assert(x, gc.Equals, len(vals)) <-ch } go watcher() go watcher() for _, val := range vals { v.Set(val) ch <- true ch <- true } v.Close() ch <- true ch <- true } func (s *suite) TestCloseWatcher(c *gc.C) { vals := []string{"one", "two", "three"} // blocking on the channel forces the scheduler to let the other goroutine // run for a bit, so we get predictable results. This is not necessary for // normal use of the watcher. ch := make(chan bool) v := NewValue(nil) w := v.Watch() go func() { x := 0 for w.Next() { c.Assert(w.Value(), gc.Equals, vals[x]) x++ <-ch } // the value will only get set once before the watcher is closed c.Assert(x, gc.Equals, 1) <-ch }() v.Set(vals[0]) ch <- true w.Close() ch <- true // prove the value is not closed, even though the watcher is c.Assert(v.Closed(), jc.IsFalse) } func (s *suite) TestWatchZeroValue(c *gc.C) { var v Value ch := make(chan bool) go func() { w := v.Watch() ch <- true ch <- w.Next() }() <-ch v.Set(struct{}{}) c.Assert(<-ch, jc.IsTrue) } ================================================ FILE: yaml.go ================================================ // Copyright 2012, 2013 Canonical Ltd. 
// Licensed under the LGPLv3, see LICENCE file for details. package utils import ( "io/ioutil" "os" "path/filepath" "github.com/juju/errors" "gopkg.in/yaml.v2" ) // WriteYaml marshals obj as yaml to a temporary file in the same directory // as path, than atomically replaces path with the temporary file. func WriteYaml(path string, obj any) error { data, err := yaml.Marshal(obj) if err != nil { return errors.Trace(err) } dir := filepath.Dir(path) f, err := ioutil.TempFile(dir, "juju") if err != nil { return errors.Trace(err) } tmp := f.Name() if _, err := f.Write(data); err != nil { _ = f.Close() // don't leak file handle _ = os.Remove(tmp) // don't leak half written files on disk return errors.Trace(err) } if err := f.Sync(); err != nil { _ = f.Close() // don't leak file handle _ = os.Remove(tmp) // don't leak half written files on disk return errors.Trace(err) } // Explicitly close the file before moving it. This is needed on Windows // where the OS will not allow us to move a file that still has an open // file handle. Must check the error on close because filesystems can delay // reporting errors until the file is closed. if err := f.Close(); err != nil { _ = os.Remove(tmp) // don't leak half written files on disk return errors.Trace(err) } // ioutils.TempFile creates files 0600, but this function has a contract // that files will be world readable, 0644 after replacement. if err := os.Chmod(tmp, 0644); err != nil { _ = os.Remove(tmp) // remove file with incorrect permissions. return errors.Trace(err) } return ReplaceFile(tmp, path) } // ReadYaml unmarshals the yaml contained in the file at path into obj. See // goyaml.Unmarshal. If path is not found, the error returned will be compatible // with os.IsNotExist. func ReadYaml(path string, obj any) error { data, err := ioutil.ReadFile(path) if err != nil { return err // cannot wrap here because callers check for NotFound. 
} return yaml.Unmarshal(data, obj) } // ConformYAML ensures all keys of any nested maps are strings. This is // necessary because YAML unmarshals map[any]any in nested // maps, which cannot be serialized by json or bson. Also, handle // []any. cf. gopkg.in/juju/charm.v4/actions.go cleanse func ConformYAML(input any) (any, error) { switch typedInput := input.(type) { case map[string]any: newMap := make(map[string]any) for key, value := range typedInput { newValue, err := ConformYAML(value) if err != nil { return nil, err } newMap[key] = newValue } return newMap, nil case map[any]any: newMap := make(map[string]any) for key, value := range typedInput { typedKey, ok := key.(string) if !ok { return nil, errors.New("map keyed with non-string value") } newMap[typedKey] = value } return ConformYAML(newMap) case []any: newSlice := make([]any, len(typedInput)) for i, sliceValue := range typedInput { newSliceValue, err := ConformYAML(sliceValue) if err != nil { return nil, errors.New("map keyed with non-string value") } newSlice[i] = newSliceValue } return newSlice, nil default: return input, nil } } ================================================ FILE: yaml_test.go ================================================ // Copyright 2015 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. 
package utils

import (
	"os"
	"path/filepath"

	jc "github.com/juju/testing/checkers"
	gc "gopkg.in/check.v1"
)

type yamlSuite struct {
}

var _ = gc.Suite(&yamlSuite{})

func (*yamlSuite) TestYamlRoundTrip(c *gc.C) {
	// test happy path of round tripping an object via yaml

	type T struct {
		A int    `yaml:"a"`
		B bool   `yaml:"deleted"`
		C string `yaml:"omitempty"`
		D string
	}

	v := T{A: 1, B: true, C: "", D: ""}

	f, err := os.CreateTemp(c.MkDir(), "yaml")
	c.Assert(err, gc.IsNil)
	path := f.Name()
	c.Assert(f.Close(), gc.IsNil)

	err = WriteYaml(path, v)
	c.Assert(err, gc.IsNil)

	var v2 T
	err = ReadYaml(path, &v2)
	c.Assert(err, gc.IsNil)

	c.Assert(v, gc.Equals, v2)
}

func (*yamlSuite) TestReadYamlReturnsNotFound(c *gc.C) {
	// The contract for ReadYaml requires it returns an error
	// that can be inspected by os.IsNotExist. Notably, we cannot
	// use juju/errors gift wrapping.
	f, err := os.CreateTemp(c.MkDir(), "yaml")
	c.Assert(err, gc.IsNil)
	path := f.Name()
	// Close before removing: Windows refuses to delete a file that still
	// has an open handle.
	c.Assert(f.Close(), gc.IsNil)
	err = os.Remove(path)
	c.Assert(err, gc.IsNil)
	err = ReadYaml(path, nil)

	// assert that the error is reported as NotExist
	c.Assert(os.IsNotExist(err), gc.Equals, true)
}

func (*yamlSuite) TestWriteYamlMissingDirectory(c *gc.C) {
	// WriteYaml tries to create a temporary file in the same
	// directory as the target. Test what happens if the path's
	// directory is missing

	root := c.MkDir()
	missing := filepath.Join(root, "missing", "filename")

	v := struct{ A, B int }{1, 2}
	err := WriteYaml(missing, v)
	c.Assert(err, gc.NotNil)
}

func (*yamlSuite) TestWriteYamlWriteGarbage(c *gc.C) {
	c.Skip("https://github.com/go-yaml/yaml/issues/144")
	// some things cannot be marshalled into yaml, check that
	// WriteYaml detects this.

	root := c.MkDir()
	path := filepath.Join(root, "f")

	v := struct{ A, B [10]bool }{}
	err := WriteYaml(path, v)
	c.Assert(err, gc.NotNil)
}

type ConformSuite struct{}

var _ = gc.Suite(&ConformSuite{})

func (s *ConformSuite) TestConformYAML(c *gc.C) {
	var goodInterfaceTests = []struct {
		description       string
		inputInterface    any
		expectedInterface map[string]any
		expectedError     string
	}{{
		description: "An interface requiring no changes.",
		inputInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[string]any{
				"foo1": "val1",
				"foo2": "val2"}},
		expectedInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[string]any{
				"foo1": "val1",
				"foo2": "val2"}},
	}, {
		description: "Substitute a single inner map[i]i.",
		inputInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[any]any{
				"foo1": "val1",
				"foo2": "val2"}},
		expectedInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[string]any{
				"foo1": "val1",
				"foo2": "val2"}},
	}, {
		description: "Substitute nested inner map[i]i.",
		inputInterface: map[string]any{
			"key1a": "val1a",
			"key2a": "val2a",
			"key3a": map[any]any{
				"key1b": "val1b",
				"key2b": map[any]any{
					"key1c": "val1c"}}},
		expectedInterface: map[string]any{
			"key1a": "val1a",
			"key2a": "val2a",
			"key3a": map[string]any{
				"key1b": "val1b",
				"key2b": map[string]any{
					"key1c": "val1c"}}},
	}, {
		description: "Substitute nested map[i]i within []i.",
		inputInterface: map[string]any{
			"key1a": "val1a",
			"key2a": []any{5, "foo", map[string]any{
				"key1b": "val1b",
				"key2b": map[any]any{
					"key1c": "val1c"}}}},
		expectedInterface: map[string]any{
			"key1a": "val1a",
			"key2a": []any{5, "foo", map[string]any{
				"key1b": "val1b",
				"key2b": map[string]any{
					"key1c": "val1c"}}}},
	}, {
		description: "An inner map[any]any with an int key.",
		inputInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[any]any{
				"foo1": "val1",
				5:      "val2"}},
		expectedError: "map keyed with non-string value",
	}, {
		description: "An inner []any containing a map[i]i with an int key.",
		inputInterface: map[string]any{
			"key1a": "val1b",
			"key2a": "val2b",
			"key3a": []any{"foo1", 5, map[any]any{
				"key1b": "val1b",
				"key2b": map[any]any{
					"key1c": "val1c",
					5:       "val2c"}}}},
		expectedError: "map keyed with non-string value",
	}}

	for i, test := range goodInterfaceTests {
		c.Logf("test %d: %s", i, test.description)
		input := test.inputInterface
		cleansedInterfaceMap, err := ConformYAML(input)
		if test.expectedError == "" {
			if !c.Check(err, jc.ErrorIsNil) {
				continue
			}
			c.Check(cleansedInterfaceMap, jc.DeepEquals, test.expectedInterface)
		} else {
			c.Check(err, gc.ErrorMatches, test.expectedError)
		}
	}
}

================================================
FILE: zfile_windows.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

// mksyscall_windows.pl -l32 file_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT

package utils

import "unsafe"
import "syscall"

var (
	modkernel32 = syscall.NewLazyDLL("kernel32.dll")

	procMoveFileExW = modkernel32.NewProc("MoveFileExW")
)

func moveFileEx(lpExistingFileName *uint16, lpNewFileName *uint16, dwFlags uint32) (err error) {
	r1, _, e1 := syscall.Syscall(procMoveFileExW.Addr(), 3, uintptr(unsafe.Pointer(lpExistingFileName)), uintptr(unsafe.Pointer(lpNewFileName)), uintptr(dwFlags))
	if r1 == 0 {
		if e1 != 0 {
			err = error(e1)
		} else {
			err = syscall.EINVAL
		}
	}
	return
}

================================================
FILE: zip/package_test.go
================================================
// Copyright 2011-2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package zip_test

import (
	"testing"

	gc "gopkg.in/check.v1"
)

func TestPackage(t *testing.T) {
	gc.TestingT(t)
}

================================================
FILE: zip/zip.go
================================================
// Copyright 2011-2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package zip import ( "archive/zip" "bytes" "fmt" "io" "os" "path" "path/filepath" "strings" ) // FindAll returns the cleaned path of every file in the supplied zip reader. func FindAll(reader *zip.Reader) ([]string, error) { return Find(reader, "*") } // Find returns the cleaned path of every file in the supplied zip reader whose // base name matches the supplied pattern, which is interpreted as in path.Match. func Find(reader *zip.Reader, pattern string) ([]string, error) { // path.Match will only return an error if the pattern is not // valid (*and* the supplied name is not empty, hence "check"). if _, err := path.Match(pattern, "check"); err != nil { return nil, err } var matches []string for _, zipFile := range reader.File { cleanPath := path.Clean(zipFile.Name) baseName := path.Base(cleanPath) if match, _ := path.Match(pattern, baseName); match { matches = append(matches, cleanPath) } } return matches, nil } // ExtractAll extracts the supplied zip reader to the target path, overwriting // existing files and directories only where necessary. func ExtractAll(reader *zip.Reader, targetRoot string) error { return Extract(reader, targetRoot, "") } // Extract extracts files from the supplied zip reader, from the (internal, slash- // separated) source path into the (external, OS-specific) target path. If the // source path does not reference a directory, the referenced file will be written // directly to the target path. func Extract(reader *zip.Reader, targetRoot, sourceRoot string) error { sourceRoot = path.Clean(sourceRoot) if sourceRoot == "." 
{ sourceRoot = "" } if !isSanePath(sourceRoot) { return fmt.Errorf("cannot extract files rooted at %q", sourceRoot) } extractor := extractor{targetRoot, sourceRoot} for _, zipFile := range reader.File { if err := extractor.extract(zipFile); err != nil { cleanName := path.Clean(zipFile.Name) return fmt.Errorf("cannot extract %q: %v", cleanName, err) } } return nil } type extractor struct { targetRoot string sourceRoot string } // targetPath returns the target path for a given zip file and whether // it should be extracted. func (x extractor) targetPath(zipFile *zip.File) (string, bool) { cleanPath := path.Clean(zipFile.Name) if cleanPath == x.sourceRoot { return x.targetRoot, true } cleanPath = strings.TrimPrefix(cleanPath, "/") for strings.HasPrefix(cleanPath, "../") { cleanPath = cleanPath[len("../"):] } if x.sourceRoot != "" { mustPrefix := x.sourceRoot + "/" if !strings.HasPrefix(cleanPath, mustPrefix) { return "", false } cleanPath = cleanPath[len(mustPrefix):] } return filepath.Join(x.targetRoot, filepath.FromSlash(cleanPath)), true } func (x extractor) extract(zipFile *zip.File) error { targetPath, ok := x.targetPath(zipFile) if !ok { return nil } parentPath := filepath.Dir(targetPath) if err := os.MkdirAll(parentPath, 0777); err != nil { return err } mode := zipFile.Mode() modePerm := mode & os.ModePerm modeType := mode & os.ModeType switch modeType { case os.ModeDir: return x.writeDir(targetPath, modePerm) case os.ModeSymlink: return x.writeSymlink(targetPath, zipFile) case 0: return x.writeFile(targetPath, zipFile, modePerm) } return fmt.Errorf("unknown file type %d", modeType) } func (x extractor) writeDir(targetPath string, modePerm os.FileMode) error { fileInfo, err := os.Lstat(targetPath) switch { case err == nil: mode := fileInfo.Mode() if mode.IsDir() { if mode&os.ModePerm != modePerm { return os.Chmod(targetPath, modePerm) } return nil } fallthrough case !os.IsNotExist(err): if err := os.RemoveAll(targetPath); err != nil { return err } } return 
os.MkdirAll(targetPath, modePerm) } func (x extractor) writeFile(targetPath string, zipFile *zip.File, modePerm os.FileMode) error { if _, err := os.Lstat(targetPath); !os.IsNotExist(err) { if err := os.RemoveAll(targetPath); err != nil { return err } } writer, err := os.OpenFile(targetPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, modePerm) if err != nil { return err } defer writer.Close() if err := copyTo(writer, zipFile); err != nil { return err } if err := writer.Sync(); err != nil { return err } if err := writer.Close(); err != nil { return err } return nil } func (x extractor) writeSymlink(targetPath string, zipFile *zip.File) error { symlinkTarget, err := x.checkSymlink(targetPath, zipFile) if err != nil { return err } if _, err := os.Lstat(targetPath); !os.IsNotExist(err) { if err := os.RemoveAll(targetPath); err != nil { return err } } return os.Symlink(symlinkTarget, targetPath) } func (x extractor) checkSymlink(targetPath string, zipFile *zip.File) (string, error) { var buffer bytes.Buffer if err := copyTo(&buffer, zipFile); err != nil { return "", err } symlinkTarget := buffer.String() if filepath.IsAbs(symlinkTarget) { return "", fmt.Errorf("symlink %q is absolute", symlinkTarget) } finalPath := filepath.Join(filepath.Dir(targetPath), symlinkTarget) relativePath, err := filepath.Rel(x.targetRoot, finalPath) if err != nil { // Not tested, because I don't know how to trigger this condition. return "", fmt.Errorf("symlink %q not comprehensible", symlinkTarget) } if !isSanePath(relativePath) { return "", fmt.Errorf("symlink %q leads out of scope", symlinkTarget) } return symlinkTarget, nil } func copyTo(writer io.Writer, zipFile *zip.File) error { reader, err := zipFile.Open() if err != nil { return err } _, err = io.Copy(writer, reader) reader.Close() return err } func isSanePath(path string) bool { if path == ".." 
|| strings.HasPrefix(path, "../") { return false } return true } ================================================ FILE: zip/zip_test.go ================================================ // Copyright 2011-2014 Canonical Ltd. // Licensed under the LGPLv3, see LICENCE file for details. package zip_test import ( stdzip "archive/zip" "bytes" "fmt" "io/ioutil" "os" "os/exec" "path/filepath" "sort" "github.com/juju/testing" jc "github.com/juju/testing/checkers" ft "github.com/juju/testing/filetesting" gc "gopkg.in/check.v1" "github.com/juju/utils/v4/zip" ) type ZipSuite struct { testing.IsolationSuite } var _ = gc.Suite(&ZipSuite{}) func (s *ZipSuite) makeZip(c *gc.C, entries ...ft.Entry) *stdzip.Reader { basePath := c.MkDir() for _, entry := range entries { entry.Create(c, basePath) } defer os.RemoveAll(basePath) outPath := filepath.Join(c.MkDir(), "test.zip") cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("cd %q; zip --fifo --symlinks -r %q .", basePath, outPath)) output, err := cmd.CombinedOutput() c.Assert(err, gc.IsNil, gc.Commentf("Command output: %s", output)) file, err := os.Open(outPath) c.Assert(err, gc.IsNil) s.AddCleanup(func(c *gc.C) { err := file.Close() c.Assert(err, gc.IsNil) }) fileInfo, err := file.Stat() c.Assert(err, gc.IsNil) reader, err := stdzip.NewReader(file, fileInfo.Size()) c.Assert(err, gc.IsNil) return reader } func (s *ZipSuite) TestFind(c *gc.C) { reader := s.makeZip(c, ft.File{"some-file", "", 0644}, ft.File{"another-file", "", 0644}, ft.Symlink{"some-symlink", "some-file"}, ft.Dir{"some-dir", 0755}, ft.Dir{"some-dir/another-dir", 0755}, ft.File{"some-dir/another-file", "", 0644}, ) for i, test := range []struct { pattern string expect []string }{{ "", nil, }, { "no-matches", nil, }, { "some-file", []string{ "some-file"}, }, { "another-file", []string{ "another-file", "some-dir/another-file"}, }, { "some-*", []string{ "some-file", "some-symlink", "some-dir"}, }, { "another-*", []string{ "another-file", "some-dir/another-dir", 
"some-dir/another-file"}, }, { "*", []string{ "some-file", "another-file", "some-symlink", "some-dir", "some-dir/another-dir", "some-dir/another-file"}, }} { c.Logf("test %d: %q", i, test.pattern) actual, err := zip.Find(reader, test.pattern) c.Assert(err, gc.IsNil) sort.Strings(test.expect) sort.Strings(actual) c.Check(actual, jc.DeepEquals, test.expect) } c.Logf("test $spanish-inquisition: FindAll") expect, err := zip.Find(reader, "*") c.Assert(err, gc.IsNil) actual, err := zip.FindAll(reader) c.Assert(err, gc.IsNil) sort.Strings(expect) sort.Strings(actual) c.Check(actual, jc.DeepEquals, expect) } func (s *ZipSuite) TestFindError(c *gc.C) { reader := s.makeZip(c, ft.File{"some-file", "", 0644}) _, err := zip.Find(reader, "[]") c.Assert(err, gc.ErrorMatches, "syntax error in pattern") } func (s *ZipSuite) TestExtractAll(c *gc.C) { entries := []ft.Entry{ ft.File{"some-file", "content 1", 0644}, ft.File{"another-file", "content 2", 0640}, ft.Symlink{"some-symlink", "some-file"}, ft.Dir{"some-dir", 0750}, ft.File{"some-dir/another-file", "content 3", 0644}, ft.Dir{"some-dir/another-dir", 0755}, ft.Symlink{"some-dir/another-dir/another-symlink", "../../another-file"}, } reader := s.makeZip(c, entries...) 
targetPath := c.MkDir() err := zip.ExtractAll(reader, targetPath) c.Assert(err, gc.IsNil) for i, entry := range entries { c.Logf("test %d: %#v", i, entry) entry.Check(c, targetPath) } } func (s *ZipSuite) TestExtractAllOverwriteFiles(c *gc.C) { name := "some-file" for i, test := range []ft.Entry{ ft.File{name, "content", 0644}, ft.Dir{name, 0751}, ft.Symlink{name, "wherever"}, } { c.Logf("test %d: %#v", i, test) targetPath := c.MkDir() ft.File{name, "original", 0}.Create(c, targetPath) reader := s.makeZip(c, test) err := zip.ExtractAll(reader, targetPath) c.Check(err, gc.IsNil) test.Check(c, targetPath) } } func (s *ZipSuite) TestExtractAllOverwriteSymlinks(c *gc.C) { name := "some-symlink" for i, test := range []ft.Entry{ ft.File{name, "content", 0644}, ft.Dir{name, 0751}, ft.Symlink{name, "wherever"}, } { c.Logf("test %d: %#v", i, test) targetPath := c.MkDir() original := ft.File{"original", "content", 0644} original.Create(c, targetPath) ft.Symlink{name, "original"}.Create(c, targetPath) reader := s.makeZip(c, test) err := zip.ExtractAll(reader, targetPath) c.Check(err, gc.IsNil) test.Check(c, targetPath) original.Check(c, targetPath) } } func (s *ZipSuite) TestExtractAllOverwriteDirs(c *gc.C) { name := "some-dir" for i, test := range []ft.Entry{ ft.File{name, "content", 0644}, ft.Dir{name, 0751}, ft.Symlink{name, "wherever"}, } { c.Logf("test %d: %#v", i, test) targetPath := c.MkDir() ft.Dir{name, 0}.Create(c, targetPath) reader := s.makeZip(c, test) err := zip.ExtractAll(reader, targetPath) c.Check(err, gc.IsNil) test.Check(c, targetPath) } } func (s *ZipSuite) TestExtractAllMergeDirs(c *gc.C) { targetPath := c.MkDir() ft.Dir{"dir", 0755}.Create(c, targetPath) originals := []ft.Entry{ ft.Dir{"dir/original-dir", 0751}, ft.File{"dir/original-file", "content 1", 0600}, ft.Symlink{"dir/original-symlink", "original-file"}, } for _, entry := range originals { entry.Create(c, targetPath) } merges := []ft.Entry{ ft.Dir{"dir", 0751}, ft.Dir{"dir/merge-dir", 0750}, 
ft.File{"dir/merge-file", "content 2", 0640}, ft.Symlink{"dir/merge-symlink", "merge-file"}, } reader := s.makeZip(c, merges...) err := zip.ExtractAll(reader, targetPath) c.Assert(err, gc.IsNil) for i, test := range append(originals, merges...) { c.Logf("test %d: %#v", i, test) test.Check(c, targetPath) } } func (s *ZipSuite) TestExtractAllSymlinkErrors(c *gc.C) { for i, test := range []struct { content []ft.Entry error string }{{ content: []ft.Entry{ ft.Symlink{"symlink", "/blah"}, }, error: `cannot extract "symlink": symlink "/blah" is absolute`, }, { content: []ft.Entry{ ft.Symlink{"symlink", "../blah"}, }, error: `cannot extract "symlink": symlink "../blah" leads out of scope`, }, { content: []ft.Entry{ ft.Dir{"dir", 0755}, ft.Symlink{"dir/symlink", "../../blah"}, }, error: `cannot extract "dir/symlink": symlink "../../blah" leads out of scope`, }} { c.Logf("test %d: %s", i, test.error) targetPath := c.MkDir() reader := s.makeZip(c, test.content...) err := zip.ExtractAll(reader, targetPath) c.Check(err, gc.ErrorMatches, test.error) } } func (s *ZipSuite) TestExtractDir(c *gc.C) { reader := s.makeZip(c, ft.File{"bad-file", "xxx", 0644}, ft.Dir{"bad-dir", 0755}, ft.Symlink{"bad-symlink", "bad-file"}, ft.Dir{"some-dir", 0751}, ft.File{"some-dir-bad-lol", "xxx", 0644}, ft.File{"some-dir/some-file", "content 1", 0644}, ft.File{"some-dir/another-file", "content 2", 0600}, ft.Dir{"some-dir/another-dir", 0750}, ft.Symlink{"some-dir/another-dir/some-symlink", "../some-file"}, ) targetParent := c.MkDir() targetPath := filepath.Join(targetParent, "random-dir") err := zip.Extract(reader, targetPath, "some-dir") c.Assert(err, gc.IsNil) for i, test := range []ft.Entry{ ft.Dir{"random-dir", 0751}, ft.File{"random-dir/some-file", "content 1", 0644}, ft.File{"random-dir/another-file", "content 2", 0600}, ft.Dir{"random-dir/another-dir", 0750}, ft.Symlink{"random-dir/another-dir/some-symlink", "../some-file"}, } { c.Logf("test %d: %#v", i, test) test.Check(c, targetParent) } 
fileInfos, err := ioutil.ReadDir(targetParent) c.Check(err, gc.IsNil) c.Check(fileInfos, gc.HasLen, 1) fileInfos, err = ioutil.ReadDir(targetPath) c.Check(err, gc.IsNil) c.Check(fileInfos, gc.HasLen, 3) } func (s *ZipSuite) TestExtractSingleFile(c *gc.C) { reader := s.makeZip(c, ft.Dir{"dir", 0755}, ft.Dir{"dir/dir", 0755}, ft.File{"dir/dir/some-file", "content 1", 0644}, ft.File{"dir/dir/some-file-wtf", "content 2", 0644}, ) targetParent := c.MkDir() targetPath := filepath.Join(targetParent, "just-the-one-file") err := zip.Extract(reader, targetPath, "dir/dir/some-file") c.Assert(err, gc.IsNil) fileInfos, err := ioutil.ReadDir(targetParent) c.Check(err, gc.IsNil) c.Check(fileInfos, gc.HasLen, 1) ft.File{"just-the-one-file", "content 1", 0644}.Check(c, targetParent) } func (s *ZipSuite) TestClosesFile(c *gc.C) { reader := s.makeZip(c, ft.File{"f", "echo hullo!", 0755}) targetPath := c.MkDir() err := zip.ExtractAll(reader, targetPath) c.Assert(err, gc.IsNil) cmd := exec.Command("/bin/sh", "-c", filepath.Join(targetPath, "f")) var buffer bytes.Buffer cmd.Stdout = &buffer err = cmd.Run() c.Assert(err, gc.IsNil) c.Assert(buffer.String(), gc.Equals, "hullo!\n") } func (s *ZipSuite) TestExtractSymlinkErrors(c *gc.C) { for i, test := range []struct { content []ft.Entry source string error string }{{ content: []ft.Entry{ ft.Dir{"dir", 0755}, ft.Symlink{"dir/symlink", "/blah"}, }, source: "dir", error: `cannot extract "dir/symlink": symlink "/blah" is absolute`, }, { content: []ft.Entry{ ft.Dir{"dir", 0755}, ft.Symlink{"dir/symlink", "../blah"}, }, source: "dir", error: `cannot extract "dir/symlink": symlink "../blah" leads out of scope`, }, { content: []ft.Entry{ ft.Symlink{"symlink", "blah"}, }, source: "symlink", error: `cannot extract "symlink": symlink "blah" leads out of scope`, }} { c.Logf("test %d: %s", i, test.error) targetPath := c.MkDir() reader := s.makeZip(c, test.content...) 
err := zip.Extract(reader, targetPath, test.source) c.Check(err, gc.ErrorMatches, test.error) } } func (s *ZipSuite) TestExtractSourceError(c *gc.C) { reader := s.makeZip(c, ft.Dir{"dir", 0755}) err := zip.Extract(reader, c.MkDir(), "../lol") c.Assert(err, gc.ErrorMatches, `cannot extract files rooted at "../lol"`) }