Repository: juju/utils
Branch: master
Commit: b92083fa0866
Files: 204
Total size: 543.2 KB
Directory structure:
gitextract_0jyt02x8/
├── .gitignore
├── ISSUE_TEMPLATE.md
├── LICENSE
├── LICENSE.golang
├── Makefile
├── README.md
├── SECURITY.md
├── arch/
│ ├── arch.go
│ ├── arch_test.go
│ └── package_test.go
├── attempt.go
├── attempt_test.go
├── bzr/
│ ├── bzr.go
│ ├── bzr_test.go
│ ├── bzr_unix_test.go
│ └── bzr_windows_test.go
├── cache/
│ ├── cache.go
│ ├── cache_test.go
│ ├── export_test.go
│ └── package_test.go
├── cert/
│ ├── cert.go
│ ├── cert_test.go
│ └── exports_test.go
├── command.go
├── command_test.go
├── context.go
├── context_test.go
├── du/
│ ├── LICENSE.ricochet2200
│ ├── diskusage.go
│ └── diskusage_windows.go
├── errors.go
├── exec/
│ ├── exec.go
│ ├── exec_internal_test.go
│ ├── exec_linux_test.go
│ ├── exec_test.go
│ ├── exec_unix.go
│ ├── exec_windows.go
│ ├── exec_windows_test.go
│ └── package_test.go
├── export_test.go
├── file.go
├── file_test.go
├── file_unix.go
├── file_unix_test.go
├── file_windows.go
├── file_windows_test.go
├── filepath/
│ ├── common.go
│ ├── common_test.go
│ ├── export_test.go
│ ├── filepath.go
│ ├── filepath_test.go
│ ├── interface_test.go
│ ├── package_test.go
│ ├── stdlib.go
│ ├── stdlib_test.go
│ ├── stdlibmatch.go
│ ├── unix.go
│ ├── unix_test.go
│ ├── win.go
│ └── win_test.go
├── filestorage/
│ ├── doc.go
│ ├── export_test.go
│ ├── fakes_test.go
│ ├── interfaces.go
│ ├── metadata.go
│ ├── metadata_store.go
│ ├── metadata_test.go
│ ├── package_test.go
│ ├── wrapper.go
│ └── wrapper_test.go
├── fs/
│ ├── copy.go
│ └── copy_test.go
├── go.mod
├── go.sum
├── gomaxprocs.go
├── gomaxprocs_test.go
├── hash/
│ ├── fingerprint.go
│ ├── fingerprint_test.go
│ ├── hash.go
│ ├── hash_test.go
│ ├── package_test.go
│ ├── writer.go
│ └── writer_test.go
├── home_unix.go
├── home_unix_test.go
├── home_windows.go
├── home_windows_test.go
├── isubuntu.go
├── isubuntu_test.go
├── jsonhttp/
│ ├── jsonhttp.go
│ ├── jsonhttp_test.go
│ └── package_test.go
├── keyvalues/
│ ├── keyvalues.go
│ ├── keyvalues_test.go
│ └── package_test.go
├── limiter.go
├── limiter_test.go
├── multireader.go
├── multireader_test.go
├── naturalsort.go
├── naturalsort_test.go
├── network.go
├── network_test.go
├── os.go
├── os_test.go
├── package_test.go
├── parallel/
│ ├── package_test.go
│ ├── parallel.go
│ ├── parallel_test.go
│ ├── try.go
│ └── try_test.go
├── password.go
├── password_test.go
├── proxy/
│ ├── package_test.go
│ ├── proxy.go
│ └── proxy_test.go
├── randomstring.go
├── randomstring_test.go
├── registry/
│ ├── export_test.go
│ ├── package_test.go
│ ├── registry.go
│ └── registry_test.go
├── relativeurl.go
├── relativeurl_test.go
├── setenv.go
├── setenv_test.go
├── shell/
│ ├── bash.go
│ ├── bash_test.go
│ ├── command.go
│ ├── interface_test.go
│ ├── output.go
│ ├── package_test.go
│ ├── powershell.go
│ ├── powershell_test.go
│ ├── renderer.go
│ ├── renderer_test.go
│ ├── script.go
│ ├── script_test.go
│ ├── unix.go
│ ├── win.go
│ ├── wincmd.go
│ └── wincmd_test.go
├── size.go
├── size_test.go
├── ssh/
│ ├── authorisedkeys.go
│ ├── authorisedkeys_test.go
│ ├── clientkeys.go
│ ├── clientkeys_test.go
│ ├── export_test.go
│ ├── fakes_test.go
│ ├── fingerprint.go
│ ├── fingerprint_test.go
│ ├── generate.go
│ ├── generate_test.go
│ ├── package_test.go
│ ├── run.go
│ ├── run_test.go
│ ├── ssh.go
│ ├── ssh_gocrypto.go
│ ├── ssh_gocrypto_test.go
│ ├── ssh_openssh.go
│ ├── ssh_test.go
│ ├── stream.go
│ ├── stream_test.go
│ ├── stream_wrapper_unix.go
│ ├── stream_wrapper_windows.go
│ └── testing/
│ └── keys.go
├── symlink/
│ ├── export_test.go
│ ├── symlink.go
│ ├── symlink_posix.go
│ ├── symlink_test.go
│ ├── symlink_windows.go
│ ├── symlink_windows_test.go
│ ├── zsymlink_windows_386.go
│ └── zsymlink_windows_amd64.go
├── systemerrmessages_unix.go
├── systemerrmessages_windows.go
├── tailer/
│ ├── export_test.go
│ ├── package_test.go
│ ├── tailer.go
│ └── tailer_test.go
├── tar/
│ ├── tar.go
│ └── tar_test.go
├── timer.go
├── timer_test.go
├── trivial.go
├── trivial_test.go
├── uptime/
│ ├── uptime_nix.go
│ ├── uptime_windows.go
│ ├── zuptime_windows_386.go
│ └── zuptime_windows_amd64.go
├── username.go
├── username_test.go
├── uuid.go
├── uuid_test.go
├── voyeur/
│ ├── package_test.go
│ ├── value.go
│ └── value_test.go
├── yaml.go
├── yaml_test.go
├── zfile_windows.go
└── zip/
├── package_test.go
├── zip.go
└── zip_test.go
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# GoLand
.idea/
# Dependency directories (remove the comment below to include it)
# vendor/
================================================
FILE: ISSUE_TEMPLATE.md
================================================
## Issues tracked in Launchpad
Please file an issue against https://bugs.launchpad.net/juju/+filebug
================================================
FILE: LICENSE
================================================
All files in this repository are licensed as follows. If you contribute
to this repository, it is assumed that you license your contribution
under the same license unless you state otherwise.
All files Copyright (C) 2015 Canonical Ltd. unless otherwise specified in the file.
This software is licensed under the LGPLv3, included below.
As a special exception to the GNU Lesser General Public License version 3
("LGPL3"), the copyright holders of this Library give you permission to
convey to a third party a Combined Work that links statically or dynamically
to this Library without providing any Minimal Corresponding Source or
Minimal Application Code as set out in 4d or providing the installation
information set out in section 4e, provided that you comply with the other
provisions of LGPL3 and provided that you meet, for the Application the
terms and conditions of the license(s) which apply to the Application.
Except as stated in this special exception, the provisions of LGPL3 will
continue to comply in full to this Library. If you modify this Library, you
may apply this exception to your version of this Library, but you are not
obliged to do so. If you do not wish to do so, delete this exception
statement from your version. This exception does not (and cannot) modify any
license terms which apply to the Application, with which you must still
comply.
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.
================================================
FILE: LICENSE.golang
================================================
This licence applies to the following files:
* filepath/stdlib.go
* filepath/stdlibmatch.go
Copyright (c) 2010 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: Makefile
================================================
# PROJECT is the Go module path whose packages "make check" tests.
PROJECT := github.com/juju/utils/v4
.PHONY: check-licence check-go check
# check runs the licence and formatting checks, then every test in PROJECT.
check: check-licence check-go
go test -v $(PROJECT)/...
# check-licence prints "FAIL: licence missed: <file>" for every name that
# appears exactly once in the combined output of the three header greps and
# the find below. In practice this is each .go file that carries none of the
# recognised licence headers or the machine-generated marker.
# NOTE(review): it only prints; it never makes the target fail.
check-licence:
@(grep -rFl "Licensed under the LGPLv3" .;\
grep -rFl "MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT" .;\
grep -rFl "license that can be found in the LICENSE.ricochet2200 file" .; \
find . -name "*.go") | sed -e 's,\./,,' | sort | uniq -u | \
xargs -I {} echo FAIL: licence missed: {}
# check-go fails if gofmt would reformat any file, then runs go vet on the
# root package.
check-go:
	$(eval GOFMT := $(strip $(shell gofmt -l .| sed -e "s/^/ /g")))
	# Quote the file list: unquoted, two or more names expand to several
	# words, `[` fails with "too many arguments" and the check silently passes.
	@(if [ -n "$(GOFMT)" ]; then \
		echo go fmt is sad: $(GOFMT); \
		exit 1; \
	fi )
	@(go vet -all -composites=false -copylocks=false .)
# Install packages required to develop in utils and run tests.
# NOTE(review): install-dependencies and the install-* targets below are not
# listed in .PHONY, so a file with one of these names would stop them running.
install-dependencies: install-snap-dependencies install-mongo-dependencies
@echo Installing dependencies
@echo Installing bzr
@sudo apt install bzr --yes
@echo Installing zip
@sudo apt install zip --yes
# install-snap-dependencies installs Go 1.17 from the snap store.
install-snap-dependencies:
## install-snap-dependencies: Install the supported snap dependencies
@echo Installing go-1.17 snap
@sudo snap install go --channel=1.17/stable --classic
# install-mongo-dependencies adds the juju PPA, then installs the first
# package listed by `apt-cache madison` among the MongoDB candidates below.
# NOTE(review): DEPENDENCIES is not defined in this Makefile, so it expands to
# nothing unless it is set on the command line or in the environment.
install-mongo-dependencies:
## install-mongo-dependencies: Install Mongo and its dependencies
@echo Adding juju PPA for mongodb
@sudo apt-add-repository --yes ppa:juju/stable
@sudo apt-get update
@echo Installing mongodb
@sudo apt-get --yes install \
$(strip $(DEPENDENCIES)) \
$(shell apt-cache madison mongodb-server-core juju-mongodb3.2 juju-mongodb mongodb-server | head -1 | cut -d '|' -f1)
================================================
FILE: README.md
================================================
juju/utils
============
This repository provides general-purpose utility packages and functions for Go programs.
================================================
FILE: SECURITY.md
================================================
# Security policy
## Reporting a vulnerability
Please provide a description of the issue, the steps you took to
create the issue, affected versions, and, if known, mitigations for
the issue.
The preferred way to report a security issue is through
[GitHub's security advisory for this project](https://github.com/juju/utils/security/advisories/new). See
[Privately reporting a security
vulnerability](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing/privately-reporting-a-security-vulnerability)
for instructions on reporting using GitHub's security advisory feature.
The [Ubuntu Security disclosure and embargo
policy](https://ubuntu.com/security/disclosure-policy) contains more
information about how you can contact us, what you can expect when you contact us,
and what we expect from you.
================================================
FILE: arch/arch.go
================================================
// Copyright 2014-2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package arch
import (
"regexp"
"runtime"
"strings"
)
// The following constants name the machine architectures recognised by Juju.
// Every constant except LEGACY_PPC64 is listed in AllSupportedArches.
const (
AMD64 = "amd64"
I386 = "i386"
ARM = "armhf"
ARM64 = "arm64"
PPC64EL = "ppc64el"
S390X = "s390x"
RISCV64 = "riscv64"
// Older versions of Juju used "ppc64" instead of ppc64el.
// LEGACY_PPC64 is not in AllSupportedArches, so IsSupportedArch
// reports false for it.
LEGACY_PPC64 = "ppc64"
)
// AllSupportedArches records the machine architectures recognised by Juju.
var AllSupportedArches = []string{
	AMD64,
	I386,
	ARM,
	ARM64,
	PPC64EL,
	S390X,
	RISCV64,
}

// Info records the information regarding each architecture recognised by Juju.
//
// Entries use keyed fields so that adding a field to ArchInfo cannot silently
// break or mis-assign them.
var Info = map[string]ArchInfo{
	AMD64:   {WordSize: 64},
	I386:    {WordSize: 32},
	ARM:     {WordSize: 32},
	ARM64:   {WordSize: 64},
	PPC64EL: {WordSize: 64},
	S390X:   {WordSize: 64},
	RISCV64: {WordSize: 64},
}

// ArchInfo is a struct containing information about a supported architecture.
type ArchInfo struct {
	// WordSize is the architecture's word size, in bits.
	WordSize int
}
// archREs maps regular expressions for matching
// `uname -m` to architectures recognised by Juju.
//
// NormaliseArch tries the entries in order and returns the arch of the first
// one that matches, so the order of entries matters. The patterns are not
// anchored: each may match anywhere within the raw string.
var archREs = []struct {
// Pattern tested against the (whitespace-trimmed) raw architecture.
*regexp.Regexp
// Juju architecture returned when the pattern matches.
arch string
}{
{regexp.MustCompile("amd64|x86_64"), AMD64},
{regexp.MustCompile("i?[3-9]86"), I386},
// "arm" at the end of the string, or "armv" followed by anything. "arm64"
// matches neither, so NormaliseArch returns it unchanged (it already
// equals ARM64).
{regexp.MustCompile("(arm$)|(armv.*)"), ARM},
{regexp.MustCompile("aarch64"), ARM64},
// The "ppc64" alternative on its own already matches "ppc64el" and
// "ppc64le", so all three map to PPC64EL.
{regexp.MustCompile("ppc64|ppc64el|ppc64le"), PPC64EL},
{regexp.MustCompile("s390x"), S390X},
{regexp.MustCompile("riscv64|risc$|risc-[vV]64"), RISCV64},
}
// HostArch reports the Juju architecture of the running machine. It is a
// variable rather than a plain function so that tests can override it.
var HostArch = hostArch

// hostArch normalises the Go runtime's GOARCH value into the corresponding
// Juju architecture name.
func hostArch() string {
	goArch := runtime.GOARCH
	return NormaliseArch(goArch)
}
// NormaliseArch returns the Juju architecture corresponding to a machine's
// reported architecture (such as the output of `uname -m` or a Go GOARCH
// value). Leading and trailing whitespace is ignored. The result is the arch
// of the first entry in archREs whose pattern matches. If none matches, the
// trimmed input is returned unchanged. The Juju architecture is used to
// filter simple streams lookup of tools and images.
func NormaliseArch(rawArch string) string {
	rawArch = strings.TrimSpace(rawArch)
	for _, re := range archREs {
		// MatchString avoids converting rawArch to a fresh []byte on
		// every iteration.
		if re.MatchString(rawArch) {
			return re.arch
		}
	}
	return rawArch
}
// IsSupportedArch reports whether arch appears in AllSupportedArches.
// The comparison is exact; the value is not normalised first.
func IsSupportedArch(arch string) bool {
	found := false
	for _, candidate := range AllSupportedArches {
		if candidate == arch {
			found = true
			break
		}
	}
	return found
}
================================================
FILE: arch/arch_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package arch_test
import (
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/arch"
)
// archSuite groups the gocheck tests for the arch package.
type archSuite struct{}

var _ = gc.Suite(&archSuite{})
// TestHostArch checks that the test host's architecture is one Juju supports.
func (s *archSuite) TestHostArch(c *gc.C) {
	host := arch.HostArch()
	supported := arch.IsSupportedArch(host)
	c.Assert(supported, jc.IsTrue)
}
// TestNormaliseArch checks that raw architecture strings map to the expected
// Juju architecture. Unrecognised values are returned unchanged, and
// surrounding whitespace is ignored.
func (s *archSuite) TestNormaliseArch(c *gc.C) {
	for _, test := range []struct {
		raw  string
		arch string
	}{
		{"windows", "windows"},
		{"amd64", "amd64"},
		{"x86_64", "amd64"},
		{"386", "i386"},
		{"i386", "i386"},
		{"i486", "i386"},
		{"arm", "armhf"},
		{"armv", "armhf"},
		{"armv7", "armhf"},
		{"aarch64", "arm64"},
		{"arm64", "arm64"},
		{"ppc64el", "ppc64el"},
		{"ppc64le", "ppc64el"},
		{"ppc64", "ppc64el"},
		{"s390x", "s390x"},
		{"riscv64", "riscv64"},
		{"risc", "riscv64"},
		{"risc-v64", "riscv64"},
		{"risc-V64", "riscv64"},
		// NormaliseArch trims surrounding whitespace, e.g. the trailing
		// newline in `uname -m` output.
		{"x86_64\n", "amd64"},
		{" aarch64 ", "arm64"},
	} {
		// Use a name that does not shadow the arch package.
		got := arch.NormaliseArch(test.raw)
		c.Check(got, gc.Equals, test.arch, gc.Commentf("raw arch %q", test.raw))
	}
}
// TestIsSupportedArch checks that every listed architecture is reported as
// supported and that an unknown name is not.
func (s *archSuite) TestIsSupportedArch(c *gc.C) {
	for _, supported := range arch.AllSupportedArches {
		c.Assert(arch.IsSupportedArch(supported), jc.IsTrue)
	}
	unknown := "invalid"
	c.Assert(arch.IsSupportedArch(unknown), jc.IsFalse)
}
// TestArchInfo checks that every supported architecture has an Info entry
// with a plausible word size.
func (s *archSuite) TestArchInfo(c *gc.C) {
	for _, a := range arch.AllSupportedArches {
		info, ok := arch.Info[a]
		c.Assert(ok, jc.IsTrue, gc.Commentf("no Info entry for %q", a))
		// Every architecture currently listed is either 32- or 64-bit.
		c.Check(info.WordSize == 32 || info.WordSize == 64, jc.IsTrue,
			gc.Commentf("unexpected word size %d for %q", info.WordSize, a))
	}
}
================================================
FILE: arch/package_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package arch_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// Test hooks the gocheck suites registered in this package into "go test".
func Test(t *testing.T) {
gc.TestingT(t)
}
================================================
FILE: attempt.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"time"
)
// The Attempt and AttemptStrategy types are copied from those in launchpad.net/goamz/aws.
// AttemptStrategy describes how to retry an action until it succeeds:
// for how long to keep trying, how far apart tries should be, and a
// minimum number of tries that takes precedence over the total duration.
type AttemptStrategy struct {
	Total time.Duration // total duration of attempt.
	Delay time.Duration // interval between each try in the burst.
	Min   int           // minimum number of retries; overrides Total
}

// Attempt tracks one sequence of tries made according to an
// AttemptStrategy. Obtain one from AttemptStrategy.Start and drive it
// with Next (and optionally HasNext).
type Attempt struct {
	strategy AttemptStrategy
	last     time.Time // when the most recent try began
	end      time.Time // start time plus strategy.Total
	force    bool      // when set, the next call to Next returns true
	count    int       // number of tries made so far
}

// Start begins a new sequence of attempts for the given strategy.
func (s AttemptStrategy) Start() *Attempt {
	started := time.Now()
	a := &Attempt{strategy: s}
	a.last = started
	a.end = started.Add(s.Total)
	// Guarantee that the very first Next succeeds.
	a.force = true
	return a
}

// Next reports whether another try should be made. Before returning true it
// sleeps, if needed, so that consecutive tries start at least strategy.Delay
// apart. The first call always returns true, so at least one attempt is made.
// Later calls return true while fewer than Min tries have been made, after
// HasNext has returned true, or while the next try could begin before Total
// has elapsed.
func (a *Attempt) Next() bool {
	now := time.Now()
	wait := a.nextSleep(now)
	mustTry := a.force || a.count < a.strategy.Min || now.Add(wait).Before(a.end)
	if !mustTry {
		return false
	}
	a.force = false
	if a.count > 0 && wait > 0 {
		time.Sleep(wait)
		now = time.Now()
	}
	a.count++
	a.last = now
	return true
}

// nextSleep returns how long to wait, as of now, before the next try may
// start. It is never negative.
func (a *Attempt) nextSleep(now time.Time) time.Duration {
	if remaining := a.strategy.Delay - now.Sub(a.last); remaining > 0 {
		return remaining
	}
	return 0
}

// HasNext reports whether another attempt will be made if the current one
// fails. When it returns true, the following call to Next is guaranteed to
// return true as well.
func (a *Attempt) HasNext() bool {
	if a.force || a.count < a.strategy.Min {
		return true
	}
	now := time.Now()
	if !now.Add(a.nextSleep(now)).Before(a.end) {
		return false
	}
	// Commit to the extra try so that Next cannot refuse it, even if the
	// deadline passes before Next is called.
	a.force = true
	return true
}
================================================
FILE: attempt_test.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"time"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// doSomething, shouldRetry and doSomethingWith are placeholders used only by
// ExampleAttempt_HasNext; they perform no real work.
func doSomething() (result int, err error) {
	return result, err
}

func shouldRetry(_ error) bool {
	return false
}

func doSomethingWith(_ int) {}
func ExampleAttempt_HasNext() {
	// This example shows how Attempt.HasNext can be used to help
	// structure an attempt loop. If the godoc example code allowed
	// us to make the example return an error, we would uncomment
	// the commented return statements.
	strategy := utils.AttemptStrategy{
		Delay: 250 * time.Millisecond,
		Total: 1 * time.Second,
	}
	for a := strategy.Start(); a.Next(); {
		result, err := doSomething()
		if shouldRetry(err) && a.HasNext() {
			// HasNext guarantees the next call to Next succeeds.
			continue
		}
		if err != nil {
			// return err
			return
		}
		doSomethingWith(result)
	}
	// return ErrTimedOut
	return
}
// TestAttemptTiming checks that tries are spaced by the strategy's Delay and
// that the sequence stops once the next try could not begin before Total
// has elapsed.
func (*utilsSuite) TestAttemptTiming(c *gc.C) {
	testAttempt := utils.AttemptStrategy{
		Total: 0.25e9,
		Delay: 0.1e9,
	}
	want := []time.Duration{0, 0.1e9, 0.2e9, 0.2e9}
	got := make([]time.Duration, 0, len(want)) // avoid allocation when testing timing
	t0 := time.Now()
	for a := testAttempt.Start(); a.Next(); {
		got = append(got, time.Since(t0))
	}
	got = append(got, time.Since(t0))
	c.Assert(got, gc.HasLen, len(want))
	const margin = 0.01e9
	// Range over the measurements (not over want) so that each measured
	// time is compared with its expectation; comparing want with itself
	// could never fail.
	for i, g := range got {
		lo := want[i] - margin
		hi := want[i] + margin
		if g < lo || g > hi {
			c.Errorf("attempt %d want %g got %g", i, want[i].Seconds(), g.Seconds())
		}
	}
}
// TestAttemptNextHasNext exercises how Next and HasNext interact for several
// strategies. The statement order and sleeps are significant.
func (*utilsSuite) TestAttemptNextHasNext(c *gc.C) {
// A zero strategy allows exactly one attempt.
a := utils.AttemptStrategy{}.Start()
c.Assert(a.Next(), gc.Equals, true)
c.Assert(a.Next(), gc.Equals, false)
// HasNext agrees that a zero strategy makes no further attempt.
a = utils.AttemptStrategy{}.Start()
c.Assert(a.Next(), gc.Equals, true)
c.Assert(a.HasNext(), gc.Equals, false)
c.Assert(a.Next(), gc.Equals, false)
// Once HasNext has returned true, the following Next must return true even
// though Total has elapsed by the time it is called.
a = utils.AttemptStrategy{Total: 2e8}.Start()
c.Assert(a.Next(), gc.Equals, true)
c.Assert(a.HasNext(), gc.Equals, true)
time.Sleep(2e8)
c.Assert(a.HasNext(), gc.Equals, true)
c.Assert(a.Next(), gc.Equals, true)
c.Assert(a.Next(), gc.Equals, false)
// Min forces that many attempts even after Total has already elapsed.
a = utils.AttemptStrategy{Total: 1e8, Min: 2}.Start()
time.Sleep(1e8)
c.Assert(a.Next(), gc.Equals, true)
c.Assert(a.HasNext(), gc.Equals, true)
c.Assert(a.Next(), gc.Equals, true)
c.Assert(a.HasNext(), gc.Equals, false)
c.Assert(a.Next(), gc.Equals, false)
}
================================================
FILE: bzr/bzr.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package bzr offers an interface to manage branches of the Bazaar VCS.
package bzr
import (
"bytes"
"fmt"
"os"
"os/exec"
"path"
"strings"
)
// Branch represents a Bazaar branch and the environment used when running
// bzr commands against it.
type Branch struct {
	// location is where the branch lives: a local path or a remote
	// location such as "lp:foo".
	location string
	// env is the environment given to every bzr invocation.
	env []string
}
// New returns a new Branch for the Bazaar branch at location. If location
// exists on the local filesystem and "bzr root" succeeds there, the Branch
// uses the root that bzr reports instead. A path inside a branch therefore
// resolves to that branch's root.
func New(location string) *Branch {
	b := &Branch{location: location, env: cenv()}
	if _, statErr := os.Stat(location); statErr != nil {
		return b
	}
	out, _, err := b.bzr("root")
	if err != nil {
		return b
	}
	// Trim \r as well as \n so Windows line endings are handled too.
	b.location = strings.TrimRight(string(out), "\r\n")
	return b
}
// cenv returns a copy of the current process environment in which LC_ALL is
// forced to "C", so bzr runs under the C locale. An existing LC_ALL entry is
// overwritten in place (only the first one found); otherwise an entry is
// appended.
func cenv() []string {
	const forced = "LC_ALL=C"
	env := os.Environ()
	for i := range env {
		if !strings.HasPrefix(env[i], "LC_ALL=") {
			continue
		}
		env[i] = forced
		return env
	}
	return append(env, forced)
}
// Location returns where branch b lives: the value given to New, or the
// branch root that bzr reported when New was able to resolve one.
func (b *Branch) Location() string { return b.location }
// Join returns b's location with parts appended as path components.
// In other words, if b's location is "lp:foo" and parts is {"bar", "baz"},
// Join returns "lp:foo/bar/baz". The result is cleaned by path.Join, which
// also ignores empty parts.
func (b *Branch) Join(parts ...string) string {
	elems := make([]string, 0, 1+len(parts))
	elems = append(elems, b.location)
	elems = append(elems, parts...)
	return path.Join(elems...)
}
// bzr runs "bzr subcommand args..." with b's environment. When b's location
// exists locally, it is used as the working directory. On success it returns
// the command's stdout and stderr. It returns an error (and nil output) when
// the command fails, or when its stderr contains "ERROR", because some bzr
// commands report failure with a zero exit status.
func (b *Branch) bzr(subcommand string, args ...string) (stdout, stderr []byte, err error) {
	argv := append([]string{subcommand}, args...)
	cmd := exec.Command("bzr", argv...)
	if _, statErr := os.Stat(b.location); statErr == nil {
		cmd.Dir = b.location
	}
	var errOut bytes.Buffer
	cmd.Stderr = &errOut
	cmd.Env = b.env
	stdout, err = cmd.Output()
	failed := err != nil || bytes.Contains(errOut.Bytes(), []byte("ERROR"))
	if !failed {
		return stdout, errOut.Bytes(), nil
	}
	detail := ""
	if err != nil {
		detail = err.Error()
	}
	return nil, nil, fmt.Errorf(`error running "bzr %s": %s%s%s`, subcommand, stdout, errOut.Bytes(), detail)
}
// Init initializes a new branch at b's location by running "bzr init".
func (b *Branch) Init() error {
	_, _, err := b.bzr("init", b.location)
	return err
}
// Add runs "bzr add" on the path obtained from b.Join(parts...).
func (b *Branch) Add(parts ...string) error {
	target := b.Join(parts...)
	if _, _, err := b.bzr("add", target); err != nil {
		return err
	}
	return nil
}
// Commit quietly commits pending changes into b ("bzr commit -q"), using
// message as the commit message.
func (b *Branch) Commit(message string) error {
	args := []string{"-q", "-m", message}
	_, _, err := b.bzr("commit", args...)
	return err
}
// RevisionId returns the Bazaar revision id for the tip of b, taken as the
// second whitespace-separated field of "bzr revision-info" output. It fails
// if the command fails, if the output does not have exactly two fields, or
// if the branch has no revisions yet.
func (b *Branch) RevisionId() (string, error) {
	out, errOut, err := b.bzr("revision-info", "-d", b.location)
	if err != nil {
		return "", err
	}
	fields := bytes.Fields(out)
	if len(fields) != 2 {
		return "", fmt.Errorf(`invalid output from "bzr revision-info": %s%s`, out, errOut)
	}
	// bzr reports the id "null:" for a branch without any commits.
	if revID := string(fields[1]); revID != "null:" {
		return revID, nil
	}
	return "", fmt.Errorf("branch has no content")
}
// PushLocation returns the default push location for b, as reported on the
// "push branch:" line of "bzr info". Surrounding whitespace is removed. It
// returns an error if no non-empty push location is reported.
func (b *Branch) PushLocation() (string, error) {
	stdout, _, err := b.bzr("info", b.location)
	if err != nil {
		return "", err
	}
	const marker = "push branch:"
	if i := bytes.Index(stdout, []byte(marker)); i >= 0 {
		loc := stdout[i+len(marker):]
		// The value runs to the end of the line. If the output is not
		// newline-terminated, it runs to the end of the output; the old
		// fixed-offset slicing panicked in that case.
		if eol := bytes.IndexAny(loc, "\r\n"); eol >= 0 {
			loc = loc[:eol]
		}
		if loc = bytes.TrimSpace(loc); len(loc) > 0 {
			return string(loc), nil
		}
	}
	return "", fmt.Errorf("no push branch location defined")
}
// PushAttr holds options for the Branch.Push method.
type PushAttr struct {
	// Location to push to. The default push location is used if it is empty.
	Location string
	// Remember makes bzr remember the location being pushed to as the
	// default ("--remember").
	Remember bool
}
// Push pushes any new revisions in b with "bzr push". The revisions go to
// attr.Location when it is set, otherwise to the default push location.
// A nil attr is treated as the zero PushAttr. See PushAttr for other options.
func (b *Branch) Push(attr *PushAttr) error {
	opts := PushAttr{}
	if attr != nil {
		opts = *attr
	}
	var args []string
	if opts.Remember {
		args = append(args, "--remember")
	}
	if opts.Location != "" {
		args = append(args, opts.Location)
	}
	_, _, err := b.bzr("push", args...)
	return err
}
// CheckClean returns an error if 'bzr status' is not clean. Output that is
// just the single-line note about shelved changes still counts as clean.
func (b *Branch) CheckClean() error {
	out, _, err := b.bzr("status", b.location)
	switch {
	case err != nil:
		return err
	case len(out) == 0:
		return nil
	case bytes.Count(out, []byte{'\n'}) == 1 && bytes.Contains(out, []byte(`See "bzr shelve --list" for details.`)):
		return nil // Shelves are fine.
	default:
		return fmt.Errorf("branch is not clean (bzr status)")
	}
}
================================================
FILE: bzr/bzr_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package bzr_test
import (
"io/ioutil"
"os"
"os/exec"
"path/filepath"
stdtesting "testing"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/bzr"
)
// Test hooks the gocheck suites registered in this package into "go test".
func Test(t *stdtesting.T) {
gc.TestingT(t)
}
var _ = gc.Suite(&BzrSuite{})
// BzrSuite tests the bzr package against branches created in temporary
// directories.
type BzrSuite struct {
testing.CleanupSuite
// b is a freshly initialised branch created by SetUpTest.
b *bzr.Branch
}
// bzr_config is written by SetUpTest to bazaar.conf under the test's
// BZR_HOME. It supplies the email identity bzr records for commits.
const bzr_config = `[DEFAULT]
email = testing
`
// SetUpTest points BZR_HOME at a fresh directory containing a minimal
// bazaar.conf, then creates and initialises a new branch in s.b.
func (s *BzrSuite) SetUpTest(c *gc.C) {
	s.CleanupSuite.SetUpTest(c)

	home := c.MkDir()
	s.PatchEnvironment("BZR_HOME", home)
	confDir := filepath.Join(home, bzrHome)
	c.Assert(os.MkdirAll(confDir, 0755), jc.ErrorIsNil)
	confFile := filepath.Join(confDir, "bazaar.conf")
	c.Assert(ioutil.WriteFile(confFile, []byte(bzr_config), 0644), jc.ErrorIsNil)

	s.b = bzr.New(c.MkDir())
	c.Assert(s.b.Init(), gc.IsNil)
}
// TestNewFindsRoot checks that New, given a subdirectory of a branch,
// resolves to the branch's root.
func (s *BzrSuite) TestNewFindsRoot(c *gc.C) {
	subdir := s.b.Join("dir")
	c.Assert(os.Mkdir(subdir, 0755), jc.ErrorIsNil)
	sub := bzr.New(subdir)

	// bzr expands any symlinks it meets while searching upwards for the
	// root, so compare against the fully resolved location.
	want, err := filepath.EvalSymlinks(s.b.Location())
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(sub.Location(), jc.SamePath, want)
}
// TestJoin checks that Join appends slash-separated elements to a
// location that is not a local path, such as lp:foo.
func (s *BzrSuite) TestJoin(c *gc.C) {
	got := bzr.New("lp:foo").Join("baz", "bar")
	c.Assert(got, gc.Equals, "lp:foo/baz/bar")
}
// TestErrorHandling checks that a failing bzr command is reported with
// the command that was run and bzr's own diagnostic.
func (s *BzrSuite) TestErrorHandling(c *gc.C) {
	const missing = "/non/existent/path"
	c.Assert(bzr.New(missing).Init(), gc.ErrorMatches, `(?s)error running "bzr init":.*does not exist.*`)
}
// TestInit checks that Init, run by SetUpTest, created the .bzr
// directory.
func (s *BzrSuite) TestInit(c *gc.C) {
	_, statErr := os.Stat(s.b.Join(".bzr"))
	c.Assert(statErr, jc.ErrorIsNil)
}
// TestRevisionIdOnEmpty checks that RevisionId fails, returning an empty
// id, on a branch with no commits.
func (s *BzrSuite) TestRevisionIdOnEmpty(c *gc.C) {
	id, err := s.b.RevisionId()
	c.Assert(err, gc.ErrorMatches, "branch has no content")
	c.Assert(id, gc.Equals, "")
}
// TestCommit commits a new file, then checks that "bzr log" shows the
// revision id reported by RevisionId, the commit message and the added
// file.
func (s *BzrSuite) TestCommit(c *gc.C) {
	file, err := os.Create(s.b.Join("myfile"))
	c.Assert(err, jc.ErrorIsNil)
	file.Close()

	c.Assert(s.b.Add("myfile"), jc.ErrorIsNil)
	c.Assert(s.b.Commit("my log message"), jc.ErrorIsNil)
	revid, err := s.b.RevisionId()
	c.Assert(err, jc.ErrorIsNil)

	logCmd := exec.Command("bzr", "log", "--long", "--show-ids", "-v", s.b.Location())
	out, err := logCmd.CombinedOutput()
	c.Assert(err, jc.ErrorIsNil)
	pattern := "(?s).*revision-id: " + revid + "\n.*message:\n.*my log message\n.*added:\n.*myfile .*"
	c.Assert(string(out), gc.Matches, pattern)
}
// TestPush pushes one branch to two others and checks how the default
// push location is set, kept, and then moved by the Remember option.
func (s *BzrSuite) TestPush(c *gc.C) {
	var branches [3]*bzr.Branch
	for i := range branches {
		branches[i] = bzr.New(c.MkDir())
	}
	for _, br := range branches {
		c.Assert(br.Init(), gc.IsNil)
	}
	b1, b2, b3 := branches[0], branches[1], branches[2]

	// assertPushLocation checks that b1's default push location is want.
	assertPushLocation := func(want *bzr.Branch) {
		loc, err := b1.PushLocation()
		c.Assert(err, jc.ErrorIsNil)
		c.Assert(loc, jc.SamePath, want.Location())
	}

	// Commit a new file in b1.
	f, err := os.Create(b1.Join("file"))
	c.Assert(err, jc.ErrorIsNil)
	f.Close()
	c.Assert(b1.Add("file"), jc.ErrorIsNil)
	c.Assert(b1.Commit("added file"), jc.ErrorIsNil)

	// The first push, to b2, makes b2 the default.
	c.Assert(b1.Push(&bzr.PushAttr{Location: b2.Location()}), jc.ErrorIsNil)
	assertPushLocation(b2)

	// Pushing to b3 without Remember leaves the default on b2.
	c.Assert(b1.Push(&bzr.PushAttr{Location: b3.Location()}), jc.ErrorIsNil)
	assertPushLocation(b2)

	// With Remember set, the default moves to b3.
	c.Assert(b1.Push(&bzr.PushAttr{Location: b3.Location(), Remember: true}), jc.ErrorIsNil)
	assertPushLocation(b3)

	// Both targets received the file.
	for _, target := range []*bzr.Branch{b2, b3} {
		_, err := os.Stat(target.Join("file"))
		c.Assert(err, jc.ErrorIsNil)
	}
}
// TestCheckClean checks that a freshly initialised branch is clean, and
// that creating a file in it makes CheckClean fail.
func (s *BzrSuite) TestCheckClean(c *gc.C) {
	c.Assert(s.b.CheckClean(), jc.ErrorIsNil)

	// Create a file in the working tree without adding it to the branch.
	f, err := os.Create(s.b.Join("file"))
	c.Assert(err, jc.ErrorIsNil)
	f.Close()

	c.Assert(s.b.CheckClean(), gc.ErrorMatches, `branch is not clean \(bzr status\)`)
}
================================================
FILE: bzr/bzr_unix_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package bzr_test
// bzrHome is the directory, relative to BZR_HOME, in which the tests
// write bazaar.conf on non-Windows systems.
const bzrHome = ".bazaar"
================================================
FILE: bzr/bzr_windows_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build windows
// +build windows
package bzr_test
// bzrHome is the directory, relative to BZR_HOME, in which the tests
// write bazaar.conf on Windows.
const bzrHome = "Bazaar/2.0"
================================================
FILE: cache/cache.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package cache provides a simple caching mechanism
// that limits the age of cache entries and tries to avoid large
// repopulation events by staggering refresh times.
package cache
import (
"math/rand"
"sync"
"time"
"github.com/juju/errors"
)
// entry is one cached value together with its expiry time. The entry
// is treated as invalid once the current time is after expire.
type entry struct {
	value  any
	expire time.Time
}
// Key is the type of a cache key. Keys are used as map keys, so the
// dynamic value must be of a comparable type.
type Key any
// Cache holds a time-limited set of values for arbitrary keys.
// Create one with New.
type Cache struct {
	maxAge time.Duration

	// mu guards all of the fields below it.
	mu sync.Mutex

	// expire is when the cache is next due to be refreshed.
	expire time.Time

	// Two generations of entries are kept so that a refresh never has
	// to scan every item. Entries move from old to new when they are
	// accessed, and old is thrown away as a whole at refresh time.
	old, new map[Key]entry

	// inFlight records fetches in progress, so that concurrent requests
	// for the same key can share a single fetch.
	inFlight map[Key]*fetchCall
}
// fetchCall represents an in-progress call to a fetch function. A Get
// for a key that is already being fetched waits on wg, then uses val and
// err instead of calling its own fetch function.
type fetchCall struct {
	wg  sync.WaitGroup
	val any
	err error
}
// New creates a Cache whose entries live for at most maxAge. A maxAge of
// zero disables caching entirely.
func New(maxAge time.Duration) *Cache {
	c := new(Cache)
	c.maxAge = maxAge
	c.inFlight = make(map[Key]*fetchCall)
	// expire is left at its zero time, so the first lookup finds the
	// cache already expired and allocates the new map.
	return c
}
// Len reports how many entries the cache currently holds across both
// generations.
func (c *Cache) Len() int {
	c.mu.Lock()
	n := len(c.new) + len(c.old)
	c.mu.Unlock()
	return n
}
// Evict drops any entry for key, whichever generation it lives in.
func (c *Cache) Evict(key Key) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, m := range [...]map[Key]entry{c.old, c.new} {
		delete(m, key)
	}
}
// EvictAll empties the cache.
func (c *Cache) EvictAll() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.old, c.new = nil, make(map[Key]entry)
}
// Get returns the cached value for key. If the value is not cached, Get
// calls fetch to obtain it. An error returned by fetch is passed back
// with its cause preserved.
func (c *Cache) Get(key Key, fetch func() (any, error)) (any, error) {
	now := time.Now()
	return c.getAtTime(key, fetch, now)
}
// getAtTime is the internal version of Get, useful for testing; now
// represents the current time.
//
// Concurrent calls for the same uncached key share one call to fetch.
// The first caller records a fetchCall in c.inFlight and runs fetch
// without holding c.mu; later callers wait on that fetchCall and return
// its result.
//
// If fetch panics, the in-flight record is removed and any waiters get
// an error. The next call for the key then fetches again, instead of
// the key returning a nil value with a nil error forever.
func (c *Cache) getAtTime(key Key, fetch func() (any, error), now time.Time) (any, error) {
	if val, ok := c.cachedValue(key, now); ok {
		return val, nil
	}
	c.mu.Lock()
	if f, ok := c.inFlight[key]; ok {
		// There's already an in-flight request for the key, so wait
		// for that to complete and use its results.
		c.mu.Unlock()
		f.wg.Wait()
		// The value will have been added to the cache by the first fetch,
		// so no need to add it here.
		if f.err == nil {
			return f.val, nil
		}
		return nil, errors.Trace(f.err)
	}
	var f fetchCall
	f.wg.Add(1)
	c.inFlight[key] = &f
	// Mark the request as done when we return, after the value has been
	// added to the cache. Deferred calls run last-in first-out, so this
	// runs after every defer registered below.
	defer f.wg.Done()
	// Fetch the data without the mutex held
	// so that one slow fetch doesn't hold up
	// all the other cache accesses.
	c.mu.Unlock()

	fetchReturned := false
	defer func() {
		if fetchReturned {
			return
		}
		// fetch panicked; the panic keeps propagating to our caller.
		// Without this cleanup the key would stay in c.inFlight for
		// good, and every later Get for it would wait on this call and
		// return a nil value with a nil error.
		c.mu.Lock()
		delete(c.inFlight, key)
		c.mu.Unlock()
		f.err = errors.New("cache fetch panicked")
	}()
	val, err := fetch()
	fetchReturned = true

	c.mu.Lock()
	defer c.mu.Unlock()
	// Set the result in the fetchCall so that other calls can see it.
	f.val, f.err = val, err
	if err == nil && c.maxAge >= 2*time.Nanosecond {
		// If maxAge is < 2ns then the expiry code will panic because the
		// actual expiry time will be maxAge - a random value in the
		// interval [0, maxAge/2). If maxAge is < 2ns then this requires
		// a random interval in [0, 0) which causes a panic.
		//
		// This value is so small that there's no need to cache anyway,
		// which makes tests more obviously deterministic when using
		// a zero expiry time.
		c.new[key] = entry{
			value:  val,
			expire: now.Add(c.maxAge - time.Duration(rand.Int63n(int64(c.maxAge/2)))),
		}
	}
	delete(c.inFlight, key)
	if err == nil {
		return f.val, nil
	}
	return nil, errors.Trace(err)
}
// cachedValue looks key up in the cache as of time now and reports
// whether a live value was found. If the cache's refresh time has
// passed, the generations are rotated first.
func (c *Cache) cachedValue(key Key, now time.Time) (any, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if now.After(c.expire) {
		// Refresh: the current generation becomes the old one, and
		// whatever was in the old one is dropped.
		c.old, c.new = c.new, make(map[Key]entry)
		c.expire = now.Add(c.maxAge)
	}
	if e, ok := c.entry(c.new, key, now); ok {
		return e.value, true
	}
	e, ok := c.entry(c.old, key, now)
	if !ok {
		return nil, false
	}
	// Promote the accessed entry to the new map so that later lookups
	// need only one map access. Because the refresh period (c.expire)
	// equals the maximum entry age, this is strictly unnecessary: any
	// entry in old will have expired by the time old is dropped.
	c.new[key] = e
	delete(c.old, key)
	return e.value, true
}
// entry fetches key from m and reports whether a live entry was found.
// An entry that has expired as of now is deleted from m and reported as
// missing.
func (c *Cache) entry(m map[Key]entry, key Key, now time.Time) (entry, bool) {
	e, found := m[key]
	switch {
	case !found:
		return entry{}, false
	case now.After(e.expire):
		delete(m, key)
		return entry{}, false
	}
	return e, true
}
================================================
FILE: cache/cache_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package cache_test
import (
"fmt"
"sync"
"time"
"github.com/juju/errors"
"github.com/juju/utils/v4/cache"
gc "gopkg.in/check.v1"
)
// suite holds the tests for the cache package.
type suite struct{}

var _ = gc.Suite(&suite{})
// TestSimpleGet checks that Get on an empty cache returns the fetched
// value.
func (*suite) TestSimpleGet(c *gc.C) {
	store := cache.New(time.Hour)
	got, err := store.Get("a", fetchValue(2))
	c.Assert(err, gc.IsNil)
	c.Assert(got, gc.Equals, 2)
}
// TestEvict checks that a cached value is served until it is evicted,
// after which the next Get fetches afresh.
func (*suite) TestEvict(c *gc.C) {
	store := cache.New(time.Hour)
	// assertGet calls Get for "a" with a fetch returning fetched and
	// checks that the result is want.
	assertGet := func(fetched, want int) {
		got, err := store.Get("a", fetchValue(fetched))
		c.Assert(err, gc.IsNil)
		c.Assert(got, gc.Equals, want)
	}
	assertGet(2, 2) // populates the cache
	assertGet(4, 2) // served from the cache
	store.Evict("a")
	assertGet(3, 3) // fetched again after eviction
	assertGet(4, 3) // and cached again
}
// TestEvictOld checks that Evict also removes entries that live in the
// old map.
func (*suite) TestEvictOld(c *gc.C) {
	now := time.Now()
	store := cache.New(time.Minute)
	// get calls GetAtTime for key at time at, with a fetch returning
	// want, and checks that want is returned.
	get := func(key, want string, at time.Time) {
		v, err := cache.GetAtTime(store, key, fetchValue(want), at)
		c.Assert(err, gc.IsNil)
		c.Assert(v, gc.Equals, want)
	}

	// Populate the cache with two entries.
	get("a", "a", now)
	c.Assert(store.Len(), gc.Equals, 1)
	get("b", "b", now.Add(time.Minute/2))
	c.Assert(store.Len(), gc.Equals, 2)

	// A Get after the expiry time moves the current entries to old.
	get("a", "a1", now.Add(time.Minute+1))
	c.Assert(store.Len(), gc.Equals, 2)
	c.Assert(cache.OldLen(store), gc.Equals, 1)

	// After eviction "b" must be fetched again.
	store.Evict("b")
	get("b", "b1", now.Add(time.Minute+2))
}
// TestFetchError checks that an error from fetch is returned with its
// cause intact and a nil value.
func (*suite) TestFetchError(c *gc.C) {
	store := cache.New(time.Hour)
	wantErr := errors.New("hello")
	got, err := store.Get("a", fetchError(wantErr))
	c.Assert(err, gc.ErrorMatches, "hello")
	c.Assert(errors.Cause(err), gc.Equals, wantErr)
	c.Assert(got, gc.Equals, nil)
}
// TestFetchOnlyOnce checks that a second Get for a cached key does not
// call its fetch function.
func (*suite) TestFetchOnlyOnce(c *gc.C) {
	store := cache.New(time.Hour)
	fetches := []func() (any, error){fetchValue(2), fetchError(errUnexpectedFetch)}
	for _, fetch := range fetches {
		v, err := store.Get("a", fetch)
		c.Assert(err, gc.IsNil)
		c.Assert(v, gc.Equals, 2)
	}
}
// TestEntryExpiresAfterMaxEntryAge checks that an entry is served from
// the cache before half its maximum age and fetched again after the
// full maximum age.
func (*suite) TestEntryExpiresAfterMaxEntryAge(c *gc.C) {
	now := time.Now()
	p := cache.New(time.Minute)
	v, err := cache.GetAtTime(p, "a", fetchValue(2), now)
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)
	// Entry is definitely not expired before half the entry expiry time.
	v, err = cache.GetAtTime(p, "a", fetchError(errUnexpectedFetch), now.Add(time.Minute/2-1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 2)
	// Entry is definitely expired after the entry expiry time, so the
	// new value is fetched. The error is checked too; previously it was
	// assigned and then ignored.
	v, err = cache.GetAtTime(p, "a", fetchValue(3), now.Add(time.Minute+1))
	c.Assert(err, gc.IsNil)
	c.Assert(v, gc.Equals, 3)
}
// TestEntriesRemovedWhenNotRetrieved checks that an entry not accessed
// during a whole refresh period is dropped at the next refresh.
func (*suite) TestEntriesRemovedWhenNotRetrieved(c *gc.C) {
	now := time.Now()
	store := cache.New(time.Minute)
	// get fetches key at time at; the fetched value equals the key.
	get := func(key string, at time.Time) {
		v, err := cache.GetAtTime(store, key, fetchValue(key), at)
		c.Assert(err, gc.IsNil)
		c.Assert(v, gc.Equals, key)
	}

	// Populate the cache with an initial entry.
	get("a", now)
	c.Assert(store.Len(), gc.Equals, 1)

	// Fetching another key after the expiry time moves the current
	// entries to old.
	get("b", now.Add(time.Minute+1))
	c.Assert(store.Len(), gc.Equals, 2)
	c.Assert(cache.OldLen(store), gc.Equals, 1)

	// After another expiry period, "a" is discarded because nothing
	// fetched it in the meantime.
	get("b", now.Add(2*time.Minute+2))
	c.Assert(store.Len(), gc.Equals, 1)
}
// TestRefreshedEntry tests the code path where a value accessed after a
// refresh is moved from the old map to the new one.
func (*suite) TestRefreshedEntry(c *gc.C) {
	now := time.Now()
	store := cache.New(time.Minute)
	// get calls GetAtTime for key with fetch at time at, and checks that
	// the value returned equals the key.
	get := func(key string, fetch func() (any, error), at time.Time) {
		v, err := cache.GetAtTime(store, key, fetch, at)
		c.Assert(err, gc.IsNil)
		c.Assert(v, gc.Equals, key)
	}

	// Populate the cache with an initial entry.
	get("a", fetchValue("a"), now)
	c.Assert(store.Len(), gc.Equals, 1)

	// Fetch another item very close to the expiry time.
	get("b", fetchValue("b"), now.Add(time.Minute-1))
	c.Assert(store.Len(), gc.Equals, 2)

	// Reading it just after the expiry time must be served from the
	// cache (a fetch would fail), which moves it into the new map.
	get("b", fetchError(errUnexpectedFetch), now.Add(time.Minute+1))
	c.Assert(store.Len(), gc.Equals, 2)

	// The next refresh drops "a", which was never re-read, and keeps "b".
	get("c", fetchValue("c"), now.Add(2*time.Minute+2))
	c.Assert(store.Len(), gc.Equals, 2)
}
// TestConcurrentFetch fetches two keys from separate goroutines. It is
// meant to expose data races when run under the race detector.
func (*suite) TestConcurrentFetch(c *gc.C) {
	store := cache.New(time.Minute)
	var wg sync.WaitGroup
	for _, key := range []string{"a", "b"} {
		key := key
		wg.Add(1)
		go func() {
			defer wg.Done()
			v, err := store.Get(key, fetchValue(key))
			c.Check(err, gc.IsNil)
			c.Check(v, gc.Equals, key)
		}()
	}
	wg.Wait()
}
// TestRefreshSpread checks that expiry times are randomised, so that
// refetches are spread over the second half of the maximum age rather
// than bunched together.
func (*suite) TestRefreshSpread(c *gc.C) {
	const (
		N    = 100
		step = 10 * time.Millisecond
	)
	now := time.Now()
	store := cache.New(time.Minute)
	// Get all values to start with.
	for i := 0; i < N; i++ {
		v, err := cache.GetAtTime(store, fmt.Sprint(i), fetchValue(i), now)
		c.Assert(err, gc.IsNil)
		c.Assert(v, gc.Equals, i)
	}

	// counts[s] is the number of fetches made during time step s, as the
	// values are read repeatedly over the course of the expiry time.
	counts := make([]int, time.Minute/step+1)
	slot := 0
	for t := now; t.Before(now.Add(time.Minute + 1)); t = t.Add(step) {
		for i := 0; i < N; i++ {
			cache.GetAtTime(store, fmt.Sprint(i), func() (any, error) {
				counts[slot]++
				return i, nil
			}, t)
		}
		slot++
	}

	// There should be no fetches in the first half of the cycle.
	for i := 0; i < len(counts)/2; i++ {
		c.Assert(counts[i], gc.Equals, 0, gc.Commentf("slot %d", i))
	}
	busiest, total := 0, 0
	for _, n := range counts {
		if n > busiest {
			busiest = n
		}
		total += n
	}
	if busiest > 10 {
		c.Errorf("requests grouped too closely (max %d)", busiest)
	}
	c.Assert(total, gc.Equals, N)
}
// TestSingleFlight checks that a Get for a key already being fetched
// waits for that fetch instead of calling its own fetch function, while
// other keys remain accessible.
func (*suite) TestSingleFlight(c *gc.C) {
	store := cache.New(time.Minute)
	// The first fetch sends on handshake to say it has started, then
	// blocks receiving on it until the test lets it finish.
	handshake := make(chan struct{})
	var wg sync.WaitGroup
	// checkX runs Get("x") with fetch and expects the first fetch's result.
	checkX := func(fetch func() (any, error)) {
		defer wg.Done()
		x, err := store.Get("x", fetch)
		c.Check(x, gc.Equals, 99)
		c.Check(err, gc.Equals, nil)
	}

	wg.Add(1)
	go checkX(func() (any, error) {
		handshake <- struct{}{}
		<-handshake
		return 99, nil
	})
	// Wait for the fetch to start.
	<-handshake

	wg.Add(1)
	go checkX(func() (any, error) {
		c.Errorf("fetch function unexpectedly called with inflight request")
		return 55, nil
	})

	// Other keys can still be fetched while "x" is in flight.
	y, err := store.Get("y", func() (any, error) {
		return 88, nil
	})
	c.Check(y, gc.Equals, 88)
	c.Check(err, gc.Equals, nil)

	// Give the second goroutine a moment to start its request, then let
	// the first fetch complete, which releases both goroutines.
	time.Sleep(time.Millisecond)
	handshake <- struct{}{}
	wg.Wait()
}
// errUnexpectedFetch is returned by fetch functions that a test expects
// never to be called, so an unwanted fetch shows up as an error.
var errUnexpectedFetch = errors.New("fetch called unexpectedly")
// fetchError returns a fetch function that always fails with err.
func fetchError(err error) func() (any, error) {
	fail := func() (any, error) { return nil, err }
	return fail
}
// fetchValue returns a fetch function that always succeeds with val.
func fetchValue(val any) func() (any, error) {
	succeed := func() (any, error) { return val, nil }
	return succeed
}
================================================
FILE: cache/export_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package cache
// GetAtTime exposes (*Cache).getAtTime to the external test package, so
// that tests can choose the time used for expiry decisions.
var GetAtTime = (*Cache).getAtTime
// OldLen reports how many entries c's old map currently holds. It reads
// the map without taking c.mu.
func OldLen(c *Cache) int {
	old := c.old
	return len(old)
}
================================================
FILE: cache/package_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package cache_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage runs the gocheck suites of the cache package under
// "go test".
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: cert/cert.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Copyright 2016 Cloudbase solutions
// Licensed under the LGPLv3, see LICENCE file for details.
package cert
import (
"crypto"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"github.com/juju/errors"
)
// OtherName is the otherName payload used when ASN.1-encoding a UPN;
// A is marshalled as a UTF8String.
type OtherName struct {
	A string `asn1:"utf8"`
}
// GeneralName pairs an object identifier with an OtherName value. The
// OtherName is encoded with context-specific tag 0.
type GeneralName struct {
	OID       asn1.ObjectIdentifier
	OtherName `asn1:"tag:0"`
}
// GeneralNames wraps a single GeneralName, which is encoded with
// context-specific tag 0.
type GeneralNames struct {
	GeneralName `asn1:"tag:0"`
}
var (
	// szOID is szOID_NT_PRINCIPAL_NAME (1.3.6.1.4.1.311.20.2.3), the
	// object identifier used for a user principal name.
	// See https://support.microsoft.com/en-us/kb/287547.
	szOID = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 20, 2, 3}

	// subjAltName is the subjectAltName extension identifier (2.5.29.17).
	// See http://www.umich.edu/~x509/ssleay/asn1-oids.html.
	subjAltName = asn1.ObjectIdentifier{2, 5, 29, 17}
)
// getUPNExtensionValue returns marsheled asn1 encoded info
func getUPNExtensionValue(subject pkix.Name) ([]byte, error) {
// returns the ASN.1 encoding of val
// in addition to the struct tags recognized
// we used:
// utf8 => causes string to be marsheled as ASN.1, UTF8 strings
// tag:x => specifies the ASN.1 tag number; imples ASN.1 CONTEXT SPECIFIC
return asn1.Marshal(GeneralNames{
GeneralName: GeneralName{
// init our ASN.1 object identifier
OID: szOID,
OtherName: OtherName{
A: subject.CommonName,
},
},
})
}
// ParseCert parses the first PEM block of type CERTIFICATE in certPEM as
// an X.509 certificate, skipping blocks of other types. It returns an
// error if there is no such block or if the certificate is malformed.
func ParseCert(certPEM string) (*x509.Certificate, error) {
	rest := []byte(certPEM)
	for len(rest) > 0 {
		block, remaining := pem.Decode(rest)
		if block == nil {
			break
		}
		rest = remaining
		if block.Type != "CERTIFICATE" {
			continue
		}
		return x509.ParseCertificate(block.Bytes)
	}
	return nil, errors.New("no certificates found")
}
// ParseCertAndKey parses a PEM-encoded X.509 certificate and its
// PEM-encoded private key using tls.X509KeyPair. It returns the first
// certificate and the key, which must implement crypto.Signer.
func ParseCertAndKey(certPEM, keyPEM string) (*x509.Certificate, crypto.Signer, error) {
	pair, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if err != nil {
		return nil, nil, err
	}
	leaf, err := x509.ParseCertificate(pair.Certificate[0])
	if err != nil {
		return nil, nil, err
	}
	if signer, ok := pair.PrivateKey.(crypto.Signer); ok {
		return leaf, signer, nil
	}
	return nil, nil, fmt.Errorf("private key with unexpected type %T", pair.PrivateKey)
}
================================================
FILE: cert/cert_test.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Copyright 2016 Cloudbase solutions
// Licensed under the LGPLv3, see LICENCE file for details.
package cert_test
import (
"testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/cert"
)
// TestAll runs the gocheck suites of the cert package under "go test".
func TestAll(t *testing.T) {
	gc.TestingT(t)
}
// certSuite holds the tests for the cert package.
type certSuite struct{}

var _ = gc.Suite(certSuite{})
// TestParseCertificate checks that ParseCert accepts a PEM certificate
// and rejects input that contains no CERTIFICATE block.
func (certSuite) TestParseCertificate(c *gc.C) {
	parsed, err := cert.ParseCert(caCertPEM)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(parsed.Subject.CommonName, gc.Equals, `juju-generated CA for model "juju testing"`)

	for _, input := range []string{caKeyPEM, "hello"} {
		parsed, err := cert.ParseCert(input)
		c.Check(parsed, gc.IsNil)
		c.Assert(err, gc.ErrorMatches, "no certificates found")
	}
}
// TestParseCertAndKey checks that ParseCertAndKey returns the
// certificate together with a key whose public half matches it.
func (certSuite) TestParseCertAndKey(c *gc.C) {
	parsed, signer, err := cert.ParseCertAndKey(caCertPEM, caKeyPEM)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(parsed.Subject.CommonName, gc.Equals, `juju-generated CA for model "juju testing"`)
	c.Assert(signer, gc.NotNil)
	c.Assert(parsed.PublicKey, gc.DeepEquals, signer.Public())
}
// Test fixtures for this file. caCertPEM is a PEM certificate whose common
// name is `juju-generated CA for model "juju testing"`. caKeyPEM is the
// matching PEM RSA private key; TestParseCertAndKey relies on its public
// half equalling the certificate's public key.
var (
caCertPEM = `
-----BEGIN CERTIFICATE-----
MIICHDCCAcagAwIBAgIUfzWn5ktGMxD6OiTgfiZyvKdM+ZYwDQYJKoZIhvcNAQEL
BQAwazENMAsGA1UEChMEanVqdTEzMDEGA1UEAwwqanVqdS1nZW5lcmF0ZWQgQ0Eg
Zm9yIG1vZGVsICJqdWp1IHRlc3RpbmciMSUwIwYDVQQFExwxMjM0LUFCQ0QtSVMt
Tk9ULUEtUkVBTC1VVUlEMB4XDTE2MDkyMTEwNDgyN1oXDTI2MDkyODEwNDgyN1ow
azENMAsGA1UEChMEanVqdTEzMDEGA1UEAwwqanVqdS1nZW5lcmF0ZWQgQ0EgZm9y
IG1vZGVsICJqdWp1IHRlc3RpbmciMSUwIwYDVQQFExwxMjM0LUFCQ0QtSVMtTk9U
LUEtUkVBTC1VVUlEMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAL+0X+1zl2vt1wI4
1Q+RnlltJyaJmtwCbHRhREXVGU7t0kTMMNERxqLnuNUyWRz90Rg8s9XvOtCqNYW7
mypGrFECAwEAAaNCMEAwDgYDVR0PAQH/BAQDAgKkMA8GA1UdEwEB/wQFMAMBAf8w
HQYDVR0OBBYEFHueMLZ1QJ/2sKiPIJ28TzjIMRENMA0GCSqGSIb3DQEBCwUAA0EA
ovZN0RbUHrO8q9Eazh0qPO4mwW9jbGTDz126uNrLoz1g3TyWxIas1wRJ8IbCgxLy
XUrBZO5UPZab66lJWXyseA==
-----END CERTIFICATE-----
`
// caKeyPEM is the private key belonging to caCertPEM. It is not a
// certificate, so ParseCert must reject it.
caKeyPEM = `
-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBAL+0X+1zl2vt1wI41Q+RnlltJyaJmtwCbHRhREXVGU7t0kTMMNER
xqLnuNUyWRz90Rg8s9XvOtCqNYW7mypGrFECAwEAAQJAMPa+JaUHgO6foxam/LIB
0u95N3OgFR+dWeBaEsgKDclpREdJ0rXNI+3C3kwqeEZR4omoPlBeSEewSkwHxpmI
0QIhAOjKiHZ5v6R8haleipbDzkGUnZW07hEwL5Ld4MNx/QQ1AiEA0tEzSSNAdM0C
M/vY0x5mekIYai8/tFSEG9PJ3ZkpEy0CIQCo9B3YxwI1Un777vbs903iQQeiWP+U
EAHnOQvhLgDxpQIgGkpml+9igW5zoOH+h02aQBLwEoXz7tw/YW0HFrCcE70CIGkS
ve4WjiEqnQaHNAPy0hY/1DfIgBOSpOfnkFHOk9vX
-----END RSA PRIVATE KEY-----
`
)
================================================
FILE: cert/exports_test.go
================================================
// Copyright 2016 Canonical ltd.
// Copyright 2016 Cloudbase solutions
// Licensed under the LGPLv3, see LICENCE file for details.
package cert
================================================
FILE: command.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"os/exec"
)
// RunCommand runs command with args and returns everything it wrote to
// stdout and stderr, interleaved, together with any error from running
// it. The output is returned even when the command fails.
func RunCommand(command string, args ...string) (output string, err error) {
	raw, err := exec.Command(command, args...).CombinedOutput()
	return string(raw), err
}
================================================
FILE: command_test.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"io/ioutil"
"path/filepath"
"runtime"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// EnvironmentPatcher is implemented by test suites, such as those
// embedding testing.IsolationSuite, that can patch an environment
// variable.
type EnvironmentPatcher interface {
	PatchEnvironment(name, value string)
}
// patchExecutable makes dir the only entry in PATH and writes script to
// an executable file named execName in dir, so that running the command
// by name runs the script.
//
// It panics if the script cannot be written. Previously the error was
// ignored, which left tests failing later with a misleading "command
// not found"-style error.
func patchExecutable(patcher EnvironmentPatcher, dir, execName, script string) {
	patcher.PatchEnvironment("PATH", dir)
	filename := filepath.Join(dir, execName)
	if err := ioutil.WriteFile(filename, []byte(script), 0755); err != nil {
		panic(err)
	}
}
// commandSuite tests RunCommand against small scripts placed on PATH.
type commandSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&commandSuite{})
// TestRunCommandCombinesOutput checks that RunCommand returns both
// stdout and stderr of the command.
func (s *commandSuite) TestRunCommandCombinesOutput(c *gc.C) {
	cmdName := "test-output"
	content := "#!/bin/bash --norc\necho stdout\necho stderr 1>&2\n"
	expect := "stdout\nstderr\n"
	if runtime.GOOS == "windows" {
		cmdName = "test-output.bat"
		content = "@echo off\necho stdout\necho stderr 1>&2\n"
		// Note the space before "\r\n" in the expected stderr line.
		expect = "stdout\r\nstderr \r\n"
	}
	patchExecutable(s, c.MkDir(), cmdName, content)

	output, err := utils.RunCommand("test-output")
	c.Assert(err, gc.IsNil)
	c.Assert(output, gc.Equals, expect)
}
// TestRunCommandNonZeroExit checks that RunCommand reports a non-zero
// exit status as an error while still returning the output.
func (s *commandSuite) TestRunCommandNonZeroExit(c *gc.C) {
	cmdName := "test-output"
	content := "#!/bin/bash --norc\necho stdout\nexit 42\n"
	expect := "stdout\n"
	if runtime.GOOS == "windows" {
		cmdName = "test-output.bat"
		content = "@echo off\necho stdout\nexit 42\n"
		expect = "stdout\r\n"
	}
	patchExecutable(s, c.MkDir(), cmdName, content)

	output, err := utils.RunCommand("test-output")
	c.Assert(err, gc.ErrorMatches, `exit status 42`)
	c.Assert(output, gc.Equals, expect)
}
================================================
FILE: context.go
================================================
// Copyright 2018 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"fmt"
"sync"
"time"
"golang.org/x/net/context"
"github.com/juju/clock"
)
// timerCtx is a context.Context that becomes done when any of these
// happens first:
//   - its deadline passes, as measured by clock;
//   - its parent becomes done;
//   - it is canceled explicitly.
type timerCtx struct {
	clock    clock.Clock
	timer    clock.Timer
	deadline time.Time
	parent   context.Context
	done     chan struct{}

	// mu guards err.
	mu sync.Mutex
	// err is nil until the context is canceled. After that it holds the
	// reason (context.Canceled, context.DeadlineExceeded or the parent's
	// error), and done has been closed.
	err error
}
// Deadline returns the deadline the context was created with. A
// timerCtx always has a deadline, so ok is always true.
func (ctx *timerCtx) Deadline() (deadline time.Time, ok bool) {
	return ctx.deadline, true
}
// Err returns nil while the context is live, and the reason it was
// canceled afterwards.
func (ctx *timerCtx) Err() error {
	ctx.mu.Lock()
	err := ctx.err
	ctx.mu.Unlock()
	return err
}
// Value delegates to the parent context; timerCtx stores no values of
// its own.
func (ctx *timerCtx) Value(key any) any {
	parent := ctx.parent
	return parent.Value(key)
}
// Done returns the channel that cancel closes when the context finishes.
func (ctx *timerCtx) Done() <-chan struct{} {
	var ch <-chan struct{} = ctx.done
	return ch
}
// cancel finishes the context with reason err. It records err, stops the
// timer if one was started, and closes the done channel. Calls after the
// first have no effect. Passing a nil err is a programming error and
// panics.
func (ctx *timerCtx) cancel(err error) {
	ctx.mu.Lock()
	defer ctx.mu.Unlock()
	switch {
	case err == nil:
		panic("cancel with nil error!")
	case ctx.err != nil:
		// Already canceled - no need to do anything.
		return
	}
	ctx.err = err
	if t := ctx.timer; t != nil {
		t.Stop()
	}
	close(ctx.done)
}
// String describes the context as its parent plus the deadline and the
// time remaining until it, according to ctx.clock.
func (ctx *timerCtx) String() string {
	remaining := ctx.deadline.Sub(ctx.clock.Now())
	return fmt.Sprintf("%v.WithDeadline(%s [%s])", ctx.parent, ctx.deadline, remaining)
}
// ContextWithTimeout is like context.WithTimeout, but it measures the
// timeout with clk instead of the wall clock. It is equivalent to
// ContextWithDeadline(parent, clk, clk.Now().Add(timeout)).
func ContextWithTimeout(parent context.Context, clk clock.Clock, timeout time.Duration) (context.Context, context.CancelFunc) {
	deadline := clk.Now().Add(timeout)
	return ContextWithDeadline(parent, clk, deadline)
}
// ContextWithDeadline is like context.WithDeadline, but it measures time
// with clk instead of the wall clock.
//
// The returned context finishes with context.DeadlineExceeded when clk
// reaches deadline, with the parent's error when the parent is done, or
// with context.Canceled when the returned function is called. If the
// deadline has already passed, the context is returned already done and
// the cancel function does nothing.
func ContextWithDeadline(parent context.Context, clk clock.Clock, deadline time.Time) (context.Context, context.CancelFunc) {
	remaining := deadline.Sub(clk.Now())
	ctx := &timerCtx{
		clock:    clk,
		parent:   parent,
		deadline: deadline,
		done:     make(chan struct{}),
	}
	if remaining <= 0 {
		ctx.cancel(context.DeadlineExceeded)
		return ctx, func() {}
	}

	ctx.timer = clk.NewTimer(remaining)
	// The watcher exits as soon as the context finishes for any reason.
	go func() {
		select {
		case <-ctx.timer.Chan():
			ctx.cancel(context.DeadlineExceeded)
		case <-parent.Done():
			ctx.cancel(parent.Err())
		case <-ctx.done:
			// Canceled explicitly; nothing left to do.
		}
	}()
	stop := func() {
		ctx.cancel(context.Canceled)
	}
	return ctx, stop
}
================================================
FILE: context_test.go
================================================
// Copyright 2018 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"fmt"
"time"
"golang.org/x/net/context"
"github.com/juju/clock/testclock"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// contextSuite tests ContextWithTimeout and ContextWithDeadline.
type contextSuite struct{}

var _ = gc.Suite(&contextSuite{})
// Note: the logic in these tests was copied from the tests
// in the Go standard library.

// TestDeadline checks ContextWithDeadline with deadlines in the future
// (directly, and through a parent of another Context type), in the past,
// and exactly now. Every cancel function is deferred, including the
// last one and the inner one in the otherContext case; previously those
// two were never called.
func (*contextSuite) TestDeadline(c *gc.C) {
	clk := testclock.NewClock(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC))
	ctx, cancel := utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(50*time.Millisecond))
	defer cancel()
	c.Assert(fmt.Sprint(ctx), gc.Equals, `context.Background.WithDeadline(2000-01-01 00:00:00.05 +0000 UTC [50ms])`)
	testContextDeadline(c, ctx, "WithDeadline", clk, 1, 50*time.Millisecond)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(50*time.Millisecond))
	defer cancel()
	o := otherContext{ctx}
	testContextDeadline(c, o, "WithDeadline+otherContext", clk, 1, 50*time.Millisecond)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(50*time.Millisecond))
	defer cancel()
	o = otherContext{ctx}
	ctx, cancelInner := utils.ContextWithDeadline(o, clk, clk.Now().Add(4*time.Second))
	defer cancelInner()
	testContextDeadline(c, ctx, "WithDeadline+otherContext+WithDeadline", clk, 2, 50*time.Millisecond)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now().Add(-time.Millisecond))
	defer cancel()
	testContextDeadline(c, ctx, "WithDeadline+inthepast", clk, 0, 0)

	ctx, cancel = utils.ContextWithDeadline(context.Background(), clk, clk.Now())
	defer cancel()
	testContextDeadline(c, ctx, "WithDeadline+now", clk, 0, 0)
}
// TestTimeout checks ContextWithTimeout directly, with an otherContext
// wrapper, and when derived from an otherContext whose own deadline is
// earlier.
func (*contextSuite) TestTimeout(c *gc.C) {
	clk := testclock.NewClock(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC))
	const timeout = 50 * time.Millisecond

	ctx, _ := utils.ContextWithTimeout(context.Background(), clk, timeout)
	c.Assert(fmt.Sprint(ctx), gc.Equals, `context.Background.WithDeadline(2000-01-01 00:00:00.05 +0000 UTC [50ms])`)
	testContextDeadline(c, ctx, "WithTimeout", clk, 1, timeout)

	ctx, _ = utils.ContextWithTimeout(context.Background(), clk, timeout)
	testContextDeadline(c, otherContext{ctx}, "WithTimeout+otherContext", clk, 1, timeout)

	ctx, _ = utils.ContextWithTimeout(context.Background(), clk, timeout)
	ctx, _ = utils.ContextWithTimeout(otherContext{ctx}, clk, 3*time.Second)
	testContextDeadline(c, ctx, "WithTimeout+otherContext+WithTimeout", clk, 2, timeout)
}
// TestCanceledTimeout checks that calling the cancel function of a
// context derived from an otherContext parent makes it done at once,
// with context.Canceled.
func (*contextSuite) TestCanceledTimeout(c *gc.C) {
	clk := testclock.NewClock(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC))
	parent, cancelParent := utils.ContextWithTimeout(context.Background(), clk, time.Second)
	// The test clock is never advanced. Without this call the parent's
	// watcher goroutine would never exit.
	defer cancelParent()
	o := otherContext{parent}
	ctx, cancel := utils.ContextWithTimeout(o, clk, 2*time.Second)
	// cancel closes ctx's done channel before it returns, so there is no
	// need to sleep before checking it.
	cancel()
	select {
	case <-ctx.Done():
	default:
		c.Errorf("<-ctx.Done() blocked, but shouldn't have")
	}
	c.Assert(ctx.Err(), gc.Equals, context.Canceled)
}
// testContextDeadline advances clk by failAfter using WaitAdvance, which
// expects waiters pending timers. It then checks that ctx becomes done
// within one second of real time, with context.DeadlineExceeded. name
// labels the case in failure messages.
func testContextDeadline(c *gc.C, ctx context.Context, name string, clk *testclock.Clock, waiters int, failAfter time.Duration) {
	c.Assert(clk.WaitAdvance(failAfter, 0, waiters), jc.ErrorIsNil)
	select {
	case <-ctx.Done():
	case <-time.After(time.Second):
		c.Fatalf("%s: context should have timed out", name)
	}
	c.Assert(ctx.Err(), gc.Equals, context.DeadlineExceeded)
}
// otherContext is a Context whose concrete type is not one of the types
// defined in context.go. It delegates everything to the embedded Context.
// Wrapping a context in it lets the tests exercise code paths that differ
// based on the underlying type of a parent Context.
type otherContext struct {
context.Context
}
================================================
FILE: du/LICENSE.ricochet2200
================================================
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org>
================================================
FILE: du/diskusage.go
================================================
// Copied from https://github.com/ricochet2200/go-disk-usage
// Copyright 2011 Rick Smith.
// Use of this source code is governed by a public domain
// license that can be found in the LICENSE.ricochet2200 file.
//
//go:build !windows
// +build !windows
package du
import "syscall"
// DiskUsage reports space usage for the file system containing a path. The
// values are derived from a statfs(2) snapshot taken by NewDiskUsage.
type DiskUsage struct {
	stat *syscall.Statfs_t
}

// NewDiskUsage returns an object holding the disk usage of volumePath.
//
// The signature has no error result, so a failing statfs call is ignored.
// In that case the statistics stay zeroed and every accessor reports 0.
// Callers that need to tell the difference should validate volumePath first.
func NewDiskUsage(volumePath string) *DiskUsage {
	var stat syscall.Statfs_t
	// Deliberately best-effort; see the doc comment above.
	_ = syscall.Statfs(volumePath, &stat)
	return &DiskUsage{stat: &stat}
}

// Free returns the total number of free bytes on the file system. This
// includes blocks that only a privileged user may use.
func (d *DiskUsage) Free() uint64 {
	return d.stat.Bfree * uint64(d.stat.Bsize)
}

// Available returns the number of free bytes available to an unprivileged user.
func (d *DiskUsage) Available() uint64 {
	return d.stat.Bavail * uint64(d.stat.Bsize)
}

// Size returns the total size of the file system in bytes.
func (d *DiskUsage) Size() uint64 {
	return d.stat.Blocks * uint64(d.stat.Bsize)
}

// Used returns the number of bytes in use, i.e. Size minus Free.
func (d *DiskUsage) Used() uint64 {
	return d.Size() - d.Free()
}

// Usage returns the fraction of the file system in use, from 0 to 1. It is a
// fraction, not a percentage. A file system that reports zero size yields 0
// rather than NaN; this happens, for example, when NewDiskUsage could not
// stat the path.
func (d *DiskUsage) Usage() float32 {
	size := d.Size()
	if size == 0 {
		return 0
	}
	return float32(d.Used()) / float32(size)
}
================================================
FILE: du/diskusage_windows.go
================================================
// Copied from https://github.com/ricochet2200/go-disk-usage
// Copyright 2011 Rick Smith.
// Use of this source code is governed by a public domain
// license that can be found in the LICENSE.ricochet2200 file.
//
package du
import (
"syscall"
"unsafe"
)
type DiskUsage struct {
freeBytes int64
totalBytes int64
availBytes int64
}
// Returns an object holding the disk usage of volumePath
// This function assumes volumePath is a valid path
func NewDiskUsage(volumePath string) *DiskUsage {
h := syscall.MustLoadDLL("kernel32.dll")
c := h.MustFindProc("GetDiskFreeSpaceExW")
du := &DiskUsage{}
c.Call(
uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(volumePath))),
uintptr(unsafe.Pointer(&du.freeBytes)),
uintptr(unsafe.Pointer(&du.totalBytes)),
uintptr(unsafe.Pointer(&du.availBytes)))
return du
}
// Total free bytes on file system
func (this *DiskUsage) Free() uint64 {
return uint64(this.freeBytes)
}
// Total available bytes on file system to an unpriveleged user
func (this *DiskUsage) Available() uint64 {
return uint64(this.availBytes)
}
// Total size of the file system
func (this *DiskUsage) Size() uint64 {
return uint64(this.totalBytes)
}
// Total bytes used in file system
func (this *DiskUsage) Used() uint64 {
return this.Size() - this.Free()
}
// Percentage of use on the file system
func (this *DiskUsage) Usage() float32 {
return float32(this.Used()) / float32(this.Size())
}
================================================
FILE: errors.go
================================================
// Copyright 2024 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"fmt"
)
// RcPassthroughError reports that a Juju plugin command exited with a
// non-zero status. Code carries that status so the program can exit with it.
type RcPassthroughError struct {
	Code int
}

// Error implements the error interface.
func (e *RcPassthroughError) Error() string {
	return fmt.Sprintf("subprocess encountered error code %v", e.Code)
}

// IsRcPassthroughError reports whether err's dynamic type is
// *RcPassthroughError. Wrapped errors are not unwrapped.
func IsRcPassthroughError(err error) bool {
	switch err.(type) {
	case *RcPassthroughError:
		return true
	default:
		return false
	}
}

// NewRcPassthroughError returns an error whose code is used as the return
// code of the cmd.Main function, rather than the default of 1.
func NewRcPassthroughError(code int) error {
	return &RcPassthroughError{Code: code}
}
================================================
FILE: exec/exec.go
================================================
// Copyright 2016 Canonical Ltd.
// Copyright 2016 Cloudbase Solutions
// Licensed under the LGPLv3, see LICENCE file for details.
package exec
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/juju/clock"
"github.com/juju/errors"
"github.com/juju/loggo/v2"
)
// logger is the package logger. It is used for warnings about temp-directory
// cleanup and for reporting command results.
var logger = loggo.GetLogger("juju.util.exec")
// RunParams holds the parameters for RunCommands and the state of a command
// started with Run.
//
// Commands contains one or more commands. They are written to a script file
// that is run by bash (non-Windows) or PowerShell (Windows). When set,
// WorkingDir and Environment are used as the process's working directory and
// environment.
// TODO: refactor this to use a config struct and a constructor. Remove todo
// and extra code from WaitWithCancel once this is done.
type RunParams struct {
// Commands is the script body to execute.
Commands string
// WorkingDir, if non-empty, is the directory the command runs in.
WorkingDir string
// Environment, if non-nil, is the process environment. On Windows, Run first
// merges it over the current process environment.
Environment []string
// Clock times WaitWithCancel's kill timeout; nil means clock.WallClock.
Clock clock.Clock
// KillProcess is used by WaitWithCancel to stop the process. Run sets it to
// the package-level KillProcess when it is nil.
KillProcess func(*os.Process) error
// User, if non-empty, runs the script as that user via "/bin/su --login"
// (non-Windows only).
User string
// tempDir holds the generated script; Wait removes it.
tempDir string
// stdout and stderr capture the process output streams.
stdout *bytes.Buffer
stderr *bytes.Buffer
// ps is the command created by Run; nil before Run has built it.
ps *exec.Cmd
}
// ExecResponse contains the exit code and captured output of a command run
// through RunParams.
type ExecResponse struct {
// Code is the process exit status. Wait only sets it when the process exited
// normally; on Unix, a process killed by a signal leaves it at 0 and Wait
// returns the error instead.
Code int
// Stdout and Stderr hold everything the process wrote to each stream.
Stdout []byte
Stderr []byte
}
// mergeEnvironment overlays env (a list of "KEY=value" entries) on top of the
// current process environment and returns the combined list. A nil env
// yields nil, which exec.Cmd treats as "inherit the parent environment".
//
// On Windows, clearing the environment or omitting variables may stop
// standard Go packages from working (os.TempDir relies on $env:TEMP), and it
// may make PowerShell error out. So callers merge rather than replace.
// Currently this function is only used on Windows.
//
// The result is deterministic:
//   - Variables from the current environment keep their os.Environ order and
//     are overridden in place.
//   - Variables only present in env are appended in the order given.
//
// Entries without "=" are kept verbatim instead of causing a panic. A
// leading "=" is treated as part of the name, as with Windows' hidden
// per-drive variables ("=C:=C:\dir"), so those entries stay distinct.
func mergeEnvironment(env []string) []string {
	if env == nil {
		return nil
	}
	base := os.Environ()
	merged := make([]string, 0, len(base)+len(env))
	position := make(map[string]int, len(base)+len(env))
	add := func(entry string) {
		start := 0
		if strings.HasPrefix(entry, "=") {
			start = 1
		}
		key := entry
		if i := strings.Index(entry[start:], "="); i >= 0 {
			key = entry[:start+i]
		}
		if at, seen := position[key]; seen {
			merged[at] = entry
			return
		}
		position[key] = len(merged)
		merged = append(merged, entry)
	}
	for _, entry := range base {
		add(entry)
	}
	for _, entry := range env {
		add(entry)
	}
	return merged
}
// shellAndArgs writes script into tempDir and returns the interpreter command
// and arguments needed to run it:
//   - Windows: powershell.exe -File script.ps1.
//   - Otherwise: /bin/bash script.sh, or /bin/su <user> --login --command
//     "/bin/bash script.sh" when user is non-empty.
//
// tempDir must remain until the process exits.
func shellAndArgs(tempDir, script, user string) (string, []string, error) {
	if runtime.GOOS == "windows" {
		ps1 := filepath.Join(tempDir, "script.ps1")
		// Exceptions don't result in a non-zero exit code by default
		// when using -File. The exit code of an explicit "exit" when
		// using -Command is ignored and results in an exit code of 1.
		// We use -File and trap exceptions to cover both.
		body := "trap {Write-Error $_; exit 1}\n" + script
		if err := ioutil.WriteFile(ps1, []byte(body), 0644); err != nil {
			return "", nil, err
		}
		return "powershell.exe", []string{
			"-NoProfile",
			"-NonInteractive",
			"-ExecutionPolicy", "RemoteSigned",
			"-File", ps1,
		}, nil
	}

	sh := filepath.Join(tempDir, "script.sh")
	shell, shellArgs := "/bin/bash", []string{sh}
	if user != "" {
		// The target user must be able to read the script's directory.
		if err := os.Chmod(tempDir, 0755); err != nil {
			return "", nil, errors.Annotatef(err, "making tempdir readable by %q", user)
		}
		shell = "/bin/su"
		shellArgs = []string{user, "--login", "--command", fmt.Sprintf("/bin/bash %s", sh)}
	}
	if err := ioutil.WriteFile(sh, []byte(script), 0644); err != nil {
		return "", nil, err
	}
	return shell, shellArgs, nil
}
// Run sets up the command environment (environment variables, working dir)
// and starts the process. The commands are written to a script in a fresh
// temporary directory. The script is run by bash on non-Windows machines
// (through /bin/su when User is set) and by PowerShell on Windows. Use Wait
// or WaitWithCancel to collect the result; Wait removes the temporary
// directory.
//
// If the process cannot be started, the temporary directory is removed
// before the error is returned.
func (r *RunParams) Run() error {
	if runtime.GOOS == "windows" {
		r.Environment = mergeEnvironment(r.Environment)
	}
	tempDir, err := ioutil.TempDir("", "juju-exec")
	if err != nil {
		return err
	}
	// cleanup is best-effort removal of tempDir on the failure paths below.
	cleanup := func() {
		if err := os.RemoveAll(tempDir); err != nil {
			logger.Warningf("failed to remove temporary directory: %v", err)
		}
	}
	shell, args, err := shellAndArgs(tempDir, r.Commands, r.User)
	if err != nil {
		cleanup()
		return err
	}
	r.ps = exec.Command(shell, args...)
	if r.Environment != nil {
		r.ps.Env = r.Environment
	}
	if r.WorkingDir != "" {
		r.ps.Dir = r.WorkingDir
	}
	r.populateSysProcAttr()
	// If there is no user provided KillProcess function we
	// use the default one.
	if r.KillProcess == nil {
		r.KillProcess = KillProcess
	}
	r.tempDir = tempDir
	r.stdout = &bytes.Buffer{}
	r.stderr = &bytes.Buffer{}
	r.ps.Stdout = r.stdout
	r.ps.Stderr = r.stderr
	if err := r.ps.Start(); err != nil {
		// Callers such as RunCommands do not call Wait after Run fails, so
		// nothing else would remove the script directory.
		cleanup()
		return err
	}
	return nil
}
// Process returns the *os.Process of the started command, or nil if Run has
// not created and started one. Callers can use it to kill the process or to
// inspect it (e.g. its Pid).
func (r *RunParams) Process() *os.Process {
	if r.ps == nil {
		return nil
	}
	return r.ps.Process
}
// Wait blocks until the process started by Run exits, removes the temporary
// script directory, and returns an ExecResponse with the captured stdout,
// stderr and exit code.
//
// A process that exits normally with a non-zero code is not treated as an
// error; the code is recorded in the response instead. Any other wait
// failure (including termination by a signal on Unix) is returned together
// with the output captured so far.
func (r *RunParams) Wait() (*ExecResponse, error) {
	if r.ps == nil {
		return nil, errors.New("No process has been started yet")
	}
	waitErr := r.ps.Wait()
	if rmErr := os.RemoveAll(r.tempDir); rmErr != nil {
		logger.Warningf("failed to remove temporary directory: %v", rmErr)
	}
	response := &ExecResponse{
		Stdout: r.stdout.Bytes(),
		Stderr: r.stderr.Bytes(),
	}
	exitErr, isExitErr := waitErr.(*exec.ExitError)
	if !isExitErr {
		return response, waitErr
	}
	if ws := exitErr.ProcessState.Sys().(syscall.WaitStatus); ws.Exited() {
		// A non-zero return code isn't considered an error here.
		response.Code = ws.ExitStatus()
		waitErr = nil
	}
	logger.Infof("run result: %v", exitErr)
	return response, waitErr
}
// ErrCancelled is returned by WaitWithCancel when it was cancelled and the
// process then exited after the kill attempt.
var ErrCancelled = errors.New("command cancelled")
// timeWaitForKill represents how long WaitWithCancel waits, after attempting
// to kill a process, for it to exit before giving up and returning.
const timeWaitForKill = 30 * time.Second
// resultWithError carries the outcome of Wait from the goroutine that
// WaitWithCancel starts back to the caller.
type resultWithError struct {
execResult *ExecResponse
err error
}
// WaitWithCancel waits until the process exits, or until a value is received
// on cancel (or cancel is closed).
//
// On cancellation it tries to kill the process. If the process then exits
// within timeWaitForKill, it returns the collected output together with
// ErrCancelled. Otherwise it returns anyway with an error naming the
// problematic PID.
//
// A nil cancel channel never fires, so WaitWithCancel(nil) behaves like Wait.
// If no process has been started, the error from Wait is returned directly.
func (r *RunParams) WaitWithCancel(cancel <-chan struct{}) (*ExecResponse, error) {
	if r.Process() == nil {
		// Nothing is running, so there is nothing to kill. Handling this up
		// front avoids dereferencing a nil process when cancel is already
		// signalled.
		result, err := r.Wait()
		return result, errors.Trace(err)
	}
	// TODO: Remove this once we make Clock a required field
	clk := r.Clock
	if clk == nil {
		clk = clock.WallClock
	}
	kill := r.KillProcess
	if kill == nil {
		kill = KillProcess
	}

	// done is buffered so the goroutine can always deliver its result and
	// exit, even after we stop listening because the kill timed out.
	done := make(chan resultWithError, 1)
	go func() {
		defer close(done)
		waitResult, err := r.Wait()
		done <- resultWithError{waitResult, err}
	}()

	select {
	case res := <-done:
		return res.execResult, errors.Trace(res.err)
	case <-cancel:
		logger.Debugf("attempting to kill process")
		if err := kill(r.ps.Process); err != nil {
			logger.Debugf("kill returned: %s", err)
		}
		// After we issue a kill we expect the wait above to return within
		// timeWaitForKill. If it doesn't, assume the process is stuck and
		// return rather than block.
		select {
		case res := <-done:
			return res.execResult, ErrCancelled
		case <-clk.After(timeWaitForKill):
			return nil, errors.Errorf("tried to kill process %v, but timed out", r.ps.Process.Pid)
		}
	}
}
// RunCommands starts the Commands in run (see RunParams.Run) and waits for
// them to finish (see RunParams.Wait). The commands are written to a script
// file executed by PowerShell on Windows and by /bin/bash elsewhere. Stdout
// and stderr are collected. A non-zero return code is reported as the
// response's Code and is not classified as an error.
func RunCommands(run RunParams) (*ExecResponse, error) {
	if err := run.Run(); err != nil {
		return nil, err
	}
	return run.Wait()
}
================================================
FILE: exec/exec_internal_test.go
================================================
// Copyright 2017 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package exec
import (
"os"
"path/filepath"
"runtime"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
)
// execSuite is the gocheck suite for the package-internal exec tests; it
// embeds IsolationSuite for per-test isolation.
type execSuite struct {
testing.IsolationSuite
}
// Register the suite with gocheck.
var _ = gc.Suite(&execSuite{})
// TestShellAndArgsNoUserSpecified checks that without a user the script is
// run directly by /bin/bash.
func (*execSuite) TestShellAndArgsNoUserSpecified(c *gc.C) {
	if runtime.GOOS == "windows" {
		c.Skip("non-windows only test")
	}
	dir := c.MkDir()
	// The directory starts out private to the current user.
	info, err := os.Stat(dir)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(info.Mode().Perm(), gc.Equals, os.FileMode(0700))

	shell, shellArgs, err := shellAndArgs(dir, "env", "")
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(shell, gc.Equals, "/bin/bash")
	c.Assert(shellArgs, jc.DeepEquals, []string{filepath.Join(dir, "script.sh")})
}
// TestShellAndArgsAsUser checks that running as another user goes through
// /bin/su and makes the temp directory world-readable.
func (*execSuite) TestShellAndArgsAsUser(c *gc.C) {
	if runtime.GOOS == "windows" {
		c.Skip("non-windows only test")
	}
	dir := c.MkDir()
	perm := func(p string) os.FileMode {
		info, err := os.Stat(p)
		c.Assert(err, jc.ErrorIsNil)
		return info.Mode().Perm()
	}
	// The directory starts out private to the current user.
	c.Assert(perm(dir), gc.Equals, os.FileMode(0700))

	shell, shellArgs, err := shellAndArgs(dir, "env", "ubuntu")
	c.Assert(err, jc.ErrorIsNil)
	script := filepath.Join(dir, "script.sh")
	c.Assert(shell, gc.Equals, "/bin/su")
	c.Assert(shellArgs, jc.DeepEquals, []string{"ubuntu", "--login", "--command", "/bin/bash " + script})

	// The directory is now readable by everyone...
	c.Assert(perm(dir), gc.Equals, os.FileMode(0755))
	// ...and so is the script itself.
	c.Assert(perm(script), gc.Equals, os.FileMode(0644))
}
================================================
FILE: exec/exec_linux_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package exec_test
import (
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/exec"
)
// cancelErrCode is the exit code expected in the response of a cancelled
// command. It is 0 on Linux because RunParams.Wait only sets the code when the
// process exits normally, and a process killed by a signal does not.
const cancelErrCode = 0
// TestRunCommands runs each table case three ways: via RunCommands, via
// Run+Wait, and via Run+WaitWithCancel(nil). All three must produce the same
// stdout, stderr and exit code.
func (*execSuite) TestRunCommands(c *gc.C) {
newDir := c.MkDir()
for i, test := range []struct {
message string
commands string
workingDir string
environment []string
stdout string
stderr string
code int
}{
{
message: "test stdout capture",
commands: "echo testing stdout",
stdout: "testing stdout\n",
}, {
message: "test stderr capture",
commands: "echo testing stderr >&2",
stderr: "testing stderr\n",
}, {
// A non-zero exit is reported in Code, not as an error.
message: "test return code",
commands: "exit 42",
code: 42,
}, {
message: "test working dir",
commands: "pwd",
workingDir: newDir,
stdout: newDir + "\n",
}, {
message: "test environment",
commands: "echo $OMG_IT_WORKS",
environment: []string{"OMG_IT_WORKS=like magic"},
stdout: "like magic\n",
}, {
// The commands come from a script file and the process's stdin is not
// set, so cat reads nothing and only "123" is printed.
message: "multiple commands",
commands: "cat\necho 123",
stdout: "123\n",
},
} {
c.Logf("%v: %s", i, test.message)
params := exec.RunParams{
Commands: test.commands,
WorkingDir: test.workingDir,
Environment: test.environment,
}
// RunCommands takes params by value, so params itself is still unstarted
// and can be reused below.
result, err := exec.RunCommands(params)
c.Assert(err, gc.IsNil)
c.Assert(string(result.Stdout), gc.Equals, test.stdout)
c.Assert(string(result.Stderr), gc.Equals, test.stderr)
c.Assert(result.Code, gc.Equals, test.code)
err = params.Run()
c.Assert(err, gc.IsNil)
c.Assert(params.Process(), gc.Not(gc.IsNil))
result, err = params.Wait()
c.Assert(err, gc.IsNil)
c.Assert(string(result.Stdout), gc.Equals, test.stdout)
c.Assert(string(result.Stderr), gc.Equals, test.stderr)
c.Assert(result.Code, gc.Equals, test.code)
err = params.Run()
c.Assert(err, gc.IsNil)
c.Assert(params.Process(), gc.Not(gc.IsNil))
// A nil cancel channel never fires, so this behaves like Wait.
result, err = params.WaitWithCancel(nil)
c.Assert(err, gc.IsNil)
c.Assert(string(result.Stdout), gc.Equals, test.stdout)
c.Assert(string(result.Stderr), gc.Equals, test.stderr)
c.Assert(result.Code, gc.Equals, test.code)
}
}
// TestExecUnknownCommand checks that a missing command is reported through
// stderr and bash's exit code rather than as an error.
func (*execSuite) TestExecUnknownCommand(c *gc.C) {
	params := exec.RunParams{Commands: "unknown-command"}
	result, err := exec.RunCommands(params)
	c.Assert(err, gc.IsNil)
	c.Assert(result.Stdout, gc.HasLen, 0)
	c.Assert(string(result.Stderr), jc.Contains, "unknown-command: command not found")
	// bash exits with 127 when a command cannot be found.
	c.Assert(result.Code, gc.Equals, 127)
}
================================================
FILE: exec/exec_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Copyright 2016 Cloudbase Solutions
// Licensed under the LGPLv3, see LICENCE file for details.
package exec_test
import (
"fmt"
"os"
"time"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/clock"
"github.com/juju/utils/v4/exec"
)
// execSuite is the gocheck suite for the external (exec_test) tests; it embeds
// IsolationSuite for per-test isolation.
type execSuite struct {
testing.IsolationSuite
}
// Register the suite with gocheck.
var _ = gc.Suite(&execSuite{})
// TestWaitWithCancel checks that cancelling a long-running command kills it
// and reports ErrCancelled along with the (empty) captured output.
func (*execSuite) TestWaitWithCancel(c *gc.C) {
	params := exec.RunParams{
		Commands: "sleep 100",
		// This clock's After channel never fires, so the test relies on the
		// kill actually ending the process.
		Clock: &mockClock{C: make(chan time.Time)},
	}
	c.Assert(params.Run(), gc.IsNil)
	c.Assert(params.Process(), gc.Not(gc.IsNil))

	// Queue the cancellation before waiting.
	cancelChan := make(chan struct{}, 1)
	defer close(cancelChan)
	cancelChan <- struct{}{}

	result, err := params.WaitWithCancel(cancelChan)
	c.Assert(err, gc.Equals, exec.ErrCancelled)
	c.Assert(string(result.Stdout), gc.Equals, "")
	c.Assert(string(result.Stderr), gc.Equals, "")
	c.Assert(result.Code, gc.Equals, cancelErrCode)
}
// TestKillAbortedIfUnsuccessfull checks that WaitWithCancel gives up and
// reports the PID when the kill does not make the process exit before the
// (mocked) timeout fires.
func (s *execSuite) TestKillAbortedIfUnsuccessfull(c *gc.C) {
	killCalled := false
	mockChan := make(chan time.Time, 1)
	defer close(mockChan)
	params := exec.RunParams{
		Commands:    "sleep 100",
		WorkingDir:  "",
		Environment: []string{},
		Clock:       &mockClock{C: mockChan},
		KillProcess: func(*os.Process) error {
			// Pretend to kill, but leave the process running.
			killCalled = true
			return nil
		},
	}
	err := params.Run()
	c.Assert(err, gc.IsNil)
	c.Assert(params.Process(), gc.Not(gc.IsNil))
	// The mocked KillProcess deliberately leaves "sleep 100" running. Kill it
	// for real when the test ends so it does not outlive the test; on Unix
	// exec.KillProcess signals the whole process group.
	proc := params.Process()
	defer func() { _ = exec.KillProcess(proc) }()

	cancelChan := make(chan struct{}, 1)
	defer close(cancelChan)
	cancelChan <- struct{}{}
	// Make the kill timeout fire immediately.
	mockChan <- time.Now()
	res, err := params.WaitWithCancel(cancelChan)
	c.Assert(err, gc.ErrorMatches, fmt.Sprintf("tried to kill process %d, but timed out", params.Process().Pid))
	c.Assert(res, gc.IsNil)
	c.Assert(killCalled, jc.IsTrue)
}
// mockClock is a clock.Clock whose After always returns the channel C, so
// tests decide exactly when a timeout fires. The tests build it without an
// embedded Clock, so calling any other clock.Clock method would panic on the
// nil interface.
type mockClock struct {
	clock.Clock
	C <-chan time.Time
}

// After ignores the requested duration and returns m.C.
func (m *mockClock) After(_ time.Duration) <-chan time.Time {
	ch := m.C
	return ch
}
================================================
FILE: exec/exec_unix.go
================================================
// Copyright 2016 Canonical Ltd.
// Copyright 2016 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package exec
import (
"os"
"syscall"
)
// KillProcess sends SIGTERM to the whole process group of proc.
//
// Everything run by the bash script is spawned as a separate process, so a
// plain proc.Kill() would leave those children running. Run places the
// command in its own process group (see populateSysProcAttr), so signalling
// the group reaches all of them. For background see
// https://groups.google.com/forum/#!topic/golang-nuts/XoQ3RhFBJl8
//
// If the process group cannot be looked up (for example because the process
// has already gone), nothing is signalled and nil is returned.
func KillProcess(proc *os.Process) error {
	pgid, err := syscall.Getpgid(proc.Pid)
	if err != nil {
		return nil
	}
	// A negative pid addresses every process in the group.
	return syscall.Kill(-pgid, syscall.SIGTERM)
}
// populateSysProcAttr starts the command in a new process group (Setpgid) so
// that KillProcess can signal the whole group.
func (r *RunParams) populateSysProcAttr() {
	attrs := &syscall.SysProcAttr{Setpgid: true}
	r.ps.SysProcAttr = attrs
}
================================================
FILE: exec/exec_windows.go
================================================
// Copyright 2016 Canonical Ltd.
// Copyright 2016 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build windows
// +build windows
package exec
import (
"os"
)
// KillProcess forcibly terminates proc via os.Process.Kill. Unlike the Unix
// implementation, no process group is signalled.
func KillProcess(proc *os.Process) error {
	err := proc.Kill()
	return err
}
// populateSysProcAttr does nothing on Windows: KillProcess there kills the
// process directly and needs no special process attributes.
func (r *RunParams) populateSysProcAttr() {
	// Intentionally empty.
}
================================================
FILE: exec/exec_windows_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package exec_test
import (
"path/filepath"
"strings"
"syscall"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/exec"
)
// cancelErrCode is the exit code expected in the response of a cancelled
// command: PowerShell reports 1 after the command is cancelled.
const cancelErrCode = 1
// longPath converts path to its long (non-8.3) form as a UTF-16 slice
// without a trailing NUL. It is copied over from the symlink package. This
// should be removed if we add it to gc or in some other convenience package.
func longPath(path string) ([]uint16, error) {
	short, err := syscall.UTF16FromString(path)
	if err != nil {
		return nil, err
	}
	// The first attempt deliberately uses the input buffer as the output
	// buffer too; a bigger buffer is allocated only if the result won't fit.
	buf := short
	n, err := syscall.GetLongPathName(&short[0], &buf[0], uint32(len(buf)))
	if err != nil {
		return nil, err
	}
	if n > uint32(len(buf)) {
		buf = make([]uint16, n)
		if n, err = syscall.GetLongPathName(&short[0], &buf[0], uint32(len(buf))); err != nil {
			return nil, err
		}
	}
	return buf[:n], nil
}
// longPathAsString is like longPath but returns the result as a Go string.
func longPathAsString(path string) (string, error) {
	utf16Path, err := longPath(path)
	if err != nil {
		return "", err
	}
	long := syscall.UTF16ToString(utf16Path)
	return long, nil
}
// TestRunCommands runs each table case three ways: via RunCommands, via
// Run+Wait, and via Run+WaitWithCancel(nil). It checks stdout, stderr and the
// exit code each time. Stderr uses Contains because PowerShell adds its own
// text around Write-Error output.
func (*execSuite) TestRunCommands(c *gc.C) {
// Expand the temp dir to its long form so it matches what (pwd).Path prints.
newDir, err := longPathAsString(c.MkDir())
c.Assert(err, gc.IsNil)
for i, test := range []struct {
message string
commands string
workingDir string
environment []string
stdout string
stderr string
code int
}{
{
message: "test stdout capture",
commands: "echo 'testing stdout'",
stdout: "testing stdout\r\n",
}, {
message: "test stderr capture",
commands: "Write-Error 'testing stderr'",
stderr: "testing stderr\r\n",
}, {
message: "test return code",
commands: "exit 42",
code: 42,
}, {
message: "test working dir",
commands: "(pwd).Path",
workingDir: newDir,
stdout: filepath.FromSlash(newDir) + "\r\n",
}, {
message: "test environment",
commands: "echo $env:OMG_IT_WORKS",
environment: []string{"OMG_IT_WORKS=like magic"},
stdout: "like magic\r\n",
},
} {
c.Logf("%v: %s", i, test.message)
params := exec.RunParams{
Commands: test.commands,
WorkingDir: test.workingDir,
Environment: test.environment,
}
// RunCommands takes params by value, so params is still unstarted below.
result, err := exec.RunCommands(params)
c.Assert(err, gc.IsNil)
c.Assert(string(result.Stdout), gc.Equals, test.stdout)
c.Assert(string(result.Stderr), jc.Contains, test.stderr)
c.Assert(result.Code, gc.Equals, test.code)
err = params.Run()
c.Assert(err, gc.IsNil)
c.Assert(params.Process(), gc.Not(gc.IsNil))
result, err = params.Wait()
c.Assert(err, gc.IsNil)
c.Assert(string(result.Stdout), gc.Equals, test.stdout)
c.Assert(string(result.Stderr), jc.Contains, test.stderr)
c.Assert(result.Code, gc.Equals, test.code)
err = params.Run()
c.Assert(err, gc.IsNil)
c.Assert(params.Process(), gc.Not(gc.IsNil))
// A nil cancel channel never fires, so this behaves like Wait.
result, err = params.WaitWithCancel(nil)
c.Assert(err, gc.IsNil)
c.Assert(string(result.Stdout), gc.Equals, test.stdout)
c.Assert(string(result.Stderr), jc.Contains, test.stderr)
c.Assert(result.Code, gc.Equals, test.code)
}
}
// TestExecUnknownCommand checks that an unknown cmdlet is reported through
// stderr and exit code 1 rather than as an error.
func (*execSuite) TestExecUnknownCommand(c *gc.C) {
	result, err := exec.RunCommands(exec.RunParams{Commands: "unknown-command"})
	c.Assert(err, gc.IsNil)
	c.Assert(result.Stdout, gc.HasLen, 0)
	// Drop line breaks (presumably PowerShell wraps long messages) before matching.
	joined := strings.ReplaceAll(string(result.Stderr), "\r\n", "")
	c.Assert(joined, jc.Contains, "is not recognized as the name of a cmdlet")
	// 1 is returned by RunCommands when powershell commands throw exceptions
	c.Assert(result.Code, gc.Equals, 1)
}
================================================
FILE: exec/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package exec_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// Test hooks all gocheck suites registered in this package into "go test".
func Test(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: export_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"time"
)
// Test-only exports of package internals.
var (
// GOMAXPROCS and NumCPU expose the addresses of the unexported gomaxprocs
// and numCPU variables, presumably so tests can swap in fakes.
GOMAXPROCS = &gomaxprocs
NumCPU = &numCPU
// ResolveSudoByFunc exposes resolveSudo to the external test package.
ResolveSudoByFunc = resolveSudo
)
// ExposeBackoffTimerDuration returns bot's unexported current back-off
// duration so tests can inspect it.
func ExposeBackoffTimerDuration(bot *BackoffTimer) time.Duration {
	current := bot.currentDuration
	return current
}
================================================
FILE: file.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"regexp"
"github.com/juju/errors"
)
// UserHomeDir returns the home directory for the specified user, or the
// home directory for the current user if the specified user is empty.
func UserHomeDir(userName string) (hDir string, err error) {
if userName == "" {
// TODO (wallyworld) - fix tests on Windows
// Ordinarily, we'd always use user.Current() to get the current user
// and then get the HomeDir from that. But our tests rely on poking
// a value into $HOME in order to override the normal home dir for the
// current user. So we're forced to use Home() to make the tests pass.
// All of our tests currently construct paths with the default user in
// mind eg "~/foo".
return Home(), nil
}
hDir, err = homeDir(userName)
if err != nil {
return "", err
}
return hDir, nil
}
// userHomePathRegexp matches a leading "~" or "~user" prefix (~user/test,
// ~/test). It is anchored at the start, which prevents accidental expansion
// on Windows when short-form paths are present (C:\users\ADMINI~1\test).
// Submatch 1 is the (possibly empty) user name.
var userHomePathRegexp = regexp.MustCompile(`^~([^/]*)`)

// NormalizePath expands a leading ~ or ~user in dir to that user's home
// directory, then removes any .. or . path elements via filepath.Clean.
//
// The home directory and the rest of the path are concatenated literally.
// They are not passed through regexp template expansion, so a '$' in either
// one is preserved.
func NormalizePath(dir string) (string, error) {
	if m := userHomePathRegexp.FindStringSubmatch(dir); m != nil {
		home, err := UserHomeDir(m[1])
		if err != nil {
			return "", err
		}
		dir = home + dir[len(m[0]):]
	}
	return filepath.Clean(dir), nil
}
// ExpandPath normalises path with NormalizePath (expanding ~ and cleaning
// it) and then makes it absolute relative to the current working directory.
func ExpandPath(path string) (string, error) {
	normalised, err := NormalizePath(path)
	if err == nil {
		return filepath.Abs(normalised)
	}
	return "", errors.Annotate(err, "unable to normalise file path")
}
// EnsureBaseDir returns path rooted under baseDir. Any Windows volume name
// (such as "C:") is stripped from path first. An empty baseDir returns path
// unchanged.
func EnsureBaseDir(baseDir, path string) string {
	if baseDir == "" {
		return path
	}
	withoutVolume := path[len(filepath.VolumeName(path)):]
	return filepath.Join(baseDir, withoutVolume)
}
// JoinServerPath joins any number of path elements into a single
// slash-separated path via path.Join, independent of the local OS. The result
// is Cleaned; in particular, all empty strings are ignored.
func JoinServerPath(elem ...string) string {
	joined := path.Join(elem...)
	return joined
}
// UniqueDirectory returns "path/name" if that directory doesn't exist. If it
// does, it tries "name.1", "name.2", etc. until a name that doesn't exist is
// found.
//
// An error other than "does not exist" from os.Stat is returned for every
// candidate, including the first. Previously the first candidate's error was
// ignored, which could skip to "name.1" for a path whose state was unknown.
func UniqueDirectory(path, name string) (string, error) {
	candidate := filepath.Join(path, name)
	for i := 1; ; i++ {
		_, err := os.Stat(candidate)
		if os.IsNotExist(err) {
			return candidate, nil
		}
		if err != nil {
			return "", err
		}
		candidate = filepath.Join(path, fmt.Sprintf("%s.%d", name, i))
	}
}
// CopyFile writes the contents of the given source file to dest, creating or
// truncating dest.
//
// The source is opened first, so a missing or unreadable source leaves dest
// untouched. dest is always closed, and an error from closing it is returned,
// since a failed close may mean the data was not fully written.
func CopyFile(dest, source string) (err error) {
	src, err := os.Open(source)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.Create(dest)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := dst.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(dst, src)
	return err
}
// AtomicWriteFileAndChange atomically writes contents to filename.
//
// The steps are:
//  1. Write contents to a temp file in the same directory.
//  2. Sync and close the temp file.
//  3. Call change with the temp file's name.
//  4. Move the temp file over filename with ReplaceFile.
//
// On any failure the temp file is closed and removed.
func AtomicWriteFileAndChange(filename string, contents []byte, change func(string) error) (err error) {
	dir, prefix := filepath.Split(filename)
	tmp, err := ioutil.TempFile(dir, prefix)
	if err != nil {
		return fmt.Errorf("cannot create temp file: %v", err)
	}
	tmpName := tmp.Name()
	defer func() {
		// Always release the handle (closing twice is harmless). On error,
		// also delete the temp file; it must be closed first because Windows
		// will not remove an open file.
		_ = tmp.Close()
		if err != nil {
			_ = os.Remove(tmpName)
		}
	}()

	if _, werr := tmp.Write(contents); werr != nil {
		return fmt.Errorf("cannot write %q contents: %v", filename, werr)
	}
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	if err = change(tmpName); err != nil {
		return err
	}
	if rerr := ReplaceFile(tmpName, filename); rerr != nil {
		return fmt.Errorf("cannot replace %q with %q: %v", tmpName, filename, rerr)
	}
	return nil
}
// AtomicWriteFile atomically writes contents to filename with the given
// permissions, replacing any existing file at the same path. The permissions
// are applied to the temporary file before it is moved into place.
func AtomicWriteFile(filename string, contents []byte, perms os.FileMode) (err error) {
	setPerms := func(tmpName string) error {
		// os.Chmod (by path) is used because File.Chmod was not implemented
		// on Windows.
		if chmodErr := os.Chmod(tmpName, perms); chmodErr != nil {
			return fmt.Errorf("cannot set permissions: %v", chmodErr)
		}
		return nil
	}
	return AtomicWriteFileAndChange(filename, contents, setPerms)
}
================================================
FILE: file_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"fmt"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// fileSuite is the gocheck suite for the path and file helpers; it embeds
// IsolationSuite for per-test isolation.
type fileSuite struct {
testing.IsolationSuite
}
// Register the suite with gocheck.
var _ = gc.Suite(&fileSuite{})
// TestNormalizePath checks ~ and ~user expansion plus path cleaning. Bare ~
// uses the overridden home (SetHome); ~<current user> uses the real home
// directory from user.Current.
func (*fileSuite) TestNormalizePath(c *gc.C) {
home := filepath.FromSlash(c.MkDir())
err := utils.SetHome(home)
c.Assert(err, gc.IsNil)
// TODO (frankban) bug 1324841: improve the isolation of this suite.
currentUser, err := user.Current()
c.Assert(err, gc.IsNil)
for i, test := range []struct {
path string
expected string
err string
}{{
path: filepath.FromSlash("/var/lib/juju"),
expected: filepath.FromSlash("/var/lib/juju"),
}, {
path: "~/foo",
expected: filepath.Join(home, "foo"),
}, {
path: "~/foo//../bar",
expected: filepath.Join(home, "bar"),
}, {
path: "~",
expected: home,
}, {
path: "~" + currentUser.Username,
expected: currentUser.HomeDir,
}, {
path: "~" + currentUser.Username + "/foo",
expected: filepath.Join(currentUser.HomeDir, "foo"),
}, {
path: "~" + currentUser.Username + "/foo//../bar",
expected: filepath.Join(currentUser.HomeDir, "bar"),
}, {
// A ~ that is not at the start must not be expanded.
path: filepath.FromSlash("foo~bar/baz"),
expected: filepath.FromSlash("foo~bar/baz"),
}, {
// An unknown user is an error.
path: "~foobar/path",
err: ".*" + utils.NoSuchUserErrRegexp,
}} {
c.Logf("test %d: %s", i, test.path)
actual, err := utils.NormalizePath(test.path)
if test.err != "" {
c.Check(err, gc.ErrorMatches, test.err)
} else {
c.Check(err, gc.IsNil)
c.Check(actual, gc.Equals, test.expected)
}
}
}
// TestExpandPath checks that ExpandPath behaves like NormalizePath for
// "~" forms and additionally anchors relative paths at the working
// directory, so every successful result is absolute.
func (*fileSuite) TestExpandPath(c *gc.C) {
	home := filepath.FromSlash(c.MkDir())
	c.Assert(utils.SetHome(home), gc.IsNil)
	me, err := user.Current()
	c.Assert(err, gc.IsNil)
	cwd, err := os.Getwd()
	c.Assert(err, gc.IsNil)
	tildeMe := "~" + me.Username

	cases := []struct {
		path, expected, err string
	}{
		{path: filepath.FromSlash("/var/lib/juju"), expected: filepath.FromSlash("/var/lib/juju")},
		{path: "~/foo", expected: filepath.Join(home, "foo")},
		{path: "~/foo//../bar", expected: filepath.Join(home, "bar")},
		{path: "~", expected: home},
		{path: tildeMe, expected: me.HomeDir},
		{path: tildeMe + "/foo", expected: filepath.Join(me.HomeDir, "foo")},
		{path: tildeMe + "/foo//../bar", expected: filepath.Join(me.HomeDir, "bar")},
		{path: filepath.FromSlash("foo~bar/baz"), expected: filepath.Join(cwd, "foo~bar/baz")},
		{path: filepath.FromSlash("foo/bar"), expected: filepath.Join(cwd, "foo", "bar")},
		{path: filepath.FromSlash("foo/../bar"), expected: filepath.Join(cwd, "bar")},
		{path: filepath.FromSlash("foo/./bar"), expected: filepath.Join(cwd, "foo", "bar")},
		{path: "~foobar/path", err: ".*" + utils.NoSuchUserErrRegexp},
	}
	for i, tc := range cases {
		c.Logf("test %d: %s", i, tc.path)
		got, err := utils.ExpandPath(tc.path)
		if tc.err != "" {
			c.Check(err, gc.ErrorMatches, tc.err)
			continue
		}
		c.Check(err, gc.IsNil)
		c.Check(got, gc.Equals, tc.expected)
		c.Check(filepath.IsAbs(got), jc.IsTrue)
	}
}
// TestCopyFile checks that CopyFile reproduces the source file's bytes at
// the destination path.
func (*fileSuite) TestCopyFile(c *gc.C) {
	dir := c.MkDir()
	src, err := ioutil.TempFile(dir, "source")
	c.Assert(err, gc.IsNil)
	defer src.Close()

	const payload = "hello world"
	_, err = src.Write([]byte(payload))
	c.Assert(err, gc.IsNil)

	dest := filepath.Join(dir, "dest")
	c.Assert(utils.CopyFile(dest, src.Name()), gc.IsNil)

	copied, err := ioutil.ReadFile(dest)
	c.Assert(err, gc.IsNil)
	c.Assert(string(copied), gc.Equals, payload)
}
// atomicWriteFileTests drives TestAtomicWriteFile. Each entry writes a file
// through change. When expectErr is empty the write must succeed and, if
// check is set, check is called with the FileInfo of the written file.
// Otherwise the write must fail with an error matching expectErr.
var atomicWriteFileTests = []struct {
	summary   string
	change    func(filename string, contents []byte) error
	check     func(c *gc.C, fileInfo os.FileInfo)
	expectErr string
}{{
	summary: "atomic file write and chmod 0765",
	change: func(filename string, contents []byte) error {
		return utils.AtomicWriteFile(filename, contents, 0765)
	},
	check: func(c *gc.C, fi os.FileInfo) {
		// gc.Equals compares interface values, so the expectation must be
		// an os.FileMode; an untyped constant would box as an int and
		// never compare equal.
		c.Assert(fi.Mode().Perm(), gc.Equals, os.FileMode(0765))
	},
}, {
	summary: "atomic file write and change",
	change: func(filename string, contents []byte) error {
		chmodChange := func(f string) error {
			// File.Chmod is not implemented on Windows, but os.Chmod is.
			return os.Chmod(f, 0700)
		}
		return utils.AtomicWriteFileAndChange(filename, contents, chmodChange)
	},
	check: func(c *gc.C, fi os.FileInfo) {
		c.Assert(fi.Mode().Perm(), gc.Equals, os.FileMode(0700))
	},
}, {
	summary: "atomic file write and no-op change",
	change: func(filename string, contents []byte) error {
		nopChange := func(string) error {
			return nil
		}
		return utils.AtomicWriteFileAndChange(filename, contents, nopChange)
	},
}, {
	summary: "atomic file write and failing change func",
	change: func(filename string, contents []byte) error {
		errChange := func(string) error {
			return fmt.Errorf("pow!")
		}
		return utils.AtomicWriteFileAndChange(filename, contents, errChange)
	},
	expectErr: "pow!",
}}

// TestAtomicWriteFile runs every atomicWriteFileTests entry twice: once with
// no file at the target path and once with a file already there.
func (*fileSuite) TestAtomicWriteFile(c *gc.C) {
	dir := c.MkDir()
	name := "test.file"
	path := filepath.Join(dir, name)
	assertDirContents := func(names ...string) {
		fis, err := ioutil.ReadDir(dir)
		c.Assert(err, gc.IsNil)
		c.Assert(fis, gc.HasLen, len(names))
		for i, name := range names {
			c.Assert(fis[i].Name(), gc.Equals, name)
		}
	}
	assertNotExist := func(path string) {
		_, err := os.Lstat(path)
		c.Assert(err, jc.Satisfies, os.IsNotExist)
	}
	// assertWritten verifies the file holds contents and is the only entry
	// in dir, then runs the entry's mode check where it is meaningful.
	assertWritten := func(contents []byte, check func(*gc.C, os.FileInfo)) {
		data, err := ioutil.ReadFile(path)
		c.Assert(err, gc.IsNil)
		c.Assert(data, jc.DeepEquals, contents)
		assertDirContents(name)
		// On Windows os.Chmod only uses the owner-writable bit, so the
		// permission checks are skipped there. filepath.Separator is '\\'
		// only on Windows.
		if check == nil || filepath.Separator == '\\' {
			return
		}
		fi, err := os.Stat(path)
		c.Assert(err, gc.IsNil)
		check(c, fi)
	}
	for i, test := range atomicWriteFileTests {
		c.Logf("test %d: %s", i, test.summary)
		// First - test with file not already there.
		assertDirContents()
		assertNotExist(path)
		contents := []byte("some\ncontents")
		err := test.change(path, contents)
		if test.expectErr != "" {
			c.Assert(err, gc.ErrorMatches, test.expectErr)
			assertDirContents()
			continue
		}
		c.Assert(err, gc.IsNil)
		assertWritten(contents, test.check)

		// Second - test with a file already there.
		contents = []byte("new\ncontents")
		c.Assert(test.change(path, contents), gc.IsNil)
		assertWritten(contents, test.check)

		// Remove the file to reset scenario.
		c.Assert(os.Remove(path), gc.IsNil)
	}
}
// TestMoveFile checks that MoveFile moves a file into place and refuses to
// overwrite a destination that already exists.
func (*fileSuite) TestMoveFile(c *gc.C) {
	dir := c.MkDir()
	dest := filepath.Join(dir, "foo")
	first := filepath.Join(dir, ".foo1")
	second := filepath.Join(dir, ".foo2")
	c.Assert(ioutil.WriteFile(first, []byte("macaroni"), 0644), gc.IsNil)
	c.Assert(ioutil.WriteFile(second, []byte("cheese"), 0644), gc.IsNil)

	moved, err := utils.MoveFile(first, dest)
	c.Assert(moved, gc.Equals, true)
	c.Assert(err, gc.IsNil)

	// dest now exists, so the second move must fail and leave it intact.
	moved, err = utils.MoveFile(second, dest)
	c.Assert(moved, gc.Equals, false)
	c.Assert(err, gc.NotNil)

	got, err := ioutil.ReadFile(dest)
	c.Assert(err, gc.IsNil)
	c.Assert(got, gc.DeepEquals, []byte("macaroni"))
}
================================================
FILE: file_unix.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package utils
import (
"fmt"
"os"
"os/user"
"strconv"
"strings"
"syscall"
"github.com/juju/errors"
)
// homeDir returns the home directory of the named local user. A failed
// lookup is reported as a UserNotFound error wrapping the lookup error.
func homeDir(userName string) (string, error) {
	found, lookupErr := user.Lookup(userName)
	if lookupErr == nil {
		return found.HomeDir, nil
	}
	return "", errors.NewUserNotFound(lookupErr, "no such user")
}
// MoveFile atomically moves the source file to the destination, returning
// whether the file was moved successfully. If the destination already
// exists, it returns an error rather than overwrite it.
//
// On unix systems the move is a hard link followed by removal of the
// source, so an error may be returned alongside a successful move (true)
// when the source location cannot be unlinked.
func MoveFile(source, destination string) (bool, error) {
	// os.Link refuses to replace an existing destination, which is what
	// provides the no-overwrite guarantee.
	if linkErr := os.Link(source, destination); linkErr != nil {
		return false, linkErr
	}
	// The data is now reachable at destination; report the move as done
	// while still surfacing any failure to drop the old name.
	return true, os.Remove(source)
}
// ReplaceFile atomically replaces the destination file or directory with
// the source. On unix it is exactly os.Rename, so any error returned is the
// one os.Rename produced.
func ReplaceFile(source, destination string) error {
	if err := os.Rename(source, destination); err != nil {
		return err
	}
	return nil
}
// MakeFileURL turns an absolute unix path (one starting with "/") into a
// "file://" URL. Any other input is returned unchanged.
func MakeFileURL(in string) string {
	if !strings.HasPrefix(in, "/") {
		return in
	}
	return "file://" + in
}
// ChownPath sets the uid and gid of path to match those of the user
// specified.
//
// Lookup and id-parsing failures are wrapped with %w, so callers can
// inspect the underlying cause with errors.Is / errors.As (for example
// user.UnknownUserError); the error text is unchanged. Errors from the
// chown itself are returned as-is.
func ChownPath(path, username string) error {
	u, err := user.Lookup(username)
	if err != nil {
		return fmt.Errorf("cannot lookup %q user id: %w", username, err)
	}
	uid, err := strconv.Atoi(u.Uid)
	if err != nil {
		return fmt.Errorf("invalid user id %q: %w", u.Uid, err)
	}
	gid, err := strconv.Atoi(u.Gid)
	if err != nil {
		return fmt.Errorf("invalid group id %q: %w", u.Gid, err)
	}
	return os.Chown(path, uid, gid)
}
// IsFileOwner reports whether the file at path belongs to the named user:
// both the file's uid and gid must equal the user's uid and gid.
// An unknown user, an unreadable path, or stat data that is not a
// *syscall.Stat_t is reported as an error.
func IsFileOwner(path, username string) (bool, error) {
	owner, err := user.Lookup(username)
	if err != nil {
		return false, errors.Annotatef(err, "cannot lookup %q user id", username)
	}
	info, err := os.Stat(path)
	if err != nil {
		return false, errors.Trace(err)
	}
	st, isStat := info.Sys().(*syscall.Stat_t)
	if !isStat {
		return false, fmt.Errorf("cannot lookup %q file", path)
	}
	uidMatches := strconv.Itoa(int(st.Uid)) == owner.Uid
	gidMatches := strconv.Itoa(int(st.Gid)) == owner.Gid
	return uidMatches && gidMatches, nil
}
================================================
FILE: file_unix_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package utils_test
import (
"fmt"
"os"
"path/filepath"
"time"
gc "gopkg.in/check.v1"
"github.com/juju/errors"
"github.com/juju/utils/v4"
)
// unixFileSuite holds the tests for the unix-only file helpers.
type unixFileSuite struct{}

var _ = gc.Suite(&unixFileSuite{})
// TestEnsureBaseDir checks that EnsureBaseDir roots unix paths under the
// given base, treating "" and "/" as no base at all.
func (s *unixFileSuite) TestEnsureBaseDir(c *gc.C) {
	for _, tc := range []struct{ base, path, want string }{
		{`/a`, `/b/c`, `/a/b/c`},
		{`/`, `/b/c`, `/b/c`},
		{``, `/b/c`, `/b/c`},
	} {
		c.Assert(utils.EnsureBaseDir(tc.base, tc.path), gc.Equals, tc.want)
	}
}
// TestFileOwner checks that a freshly created file is owned by the current
// user. The file lives in a gocheck-managed directory (c.MkDir) rather than
// the shared system temp dir, and its handle is closed immediately.
func (s *unixFileSuite) TestFileOwner(c *gc.C) {
	username, err := utils.LocalUsername()
	c.Assert(err, gc.IsNil)
	path := filepath.Join(c.MkDir(), fmt.Sprintf("file-%d", time.Now().UnixNano()))
	f, err := os.Create(path)
	c.Assert(err, gc.IsNil)
	c.Assert(f.Close(), gc.IsNil)
	ok, err := utils.IsFileOwner(path, username)
	c.Assert(err, gc.IsNil)
	c.Assert(ok, gc.Equals, true)
}
// TestFileOwnerUsingRoot checks that a file created by a non-root user is
// not reported as owned by root. It is skipped when running as root, since
// the file would then genuinely belong to root.
func (s *unixFileSuite) TestFileOwnerUsingRoot(c *gc.C) {
	username, err := utils.LocalUsername()
	c.Assert(err, gc.IsNil)
	if username == "root" {
		c.Skip("running as root: the test file would be owned by root")
	}
	path := filepath.Join(c.MkDir(), fmt.Sprintf("file-%d", time.Now().UnixNano()))
	f, err := os.Create(path)
	c.Assert(err, gc.IsNil)
	c.Assert(f.Close(), gc.IsNil)
	ok, err := utils.IsFileOwner(path, "root")
	c.Assert(err, gc.IsNil)
	c.Assert(ok, gc.Equals, false)
}
// TestFileOwnerWithInvalidPath checks that a missing file yields the stat
// error. The path is inside a fresh, empty c.MkDir directory so it is
// guaranteed not to exist (a fixed name in the system temp dir might).
func (s *unixFileSuite) TestFileOwnerWithInvalidPath(c *gc.C) {
	username, err := utils.LocalUsername()
	c.Assert(err, gc.IsNil)
	path := filepath.Join(c.MkDir(), "file-bad")
	ok, err := utils.IsFileOwner(path, username)
	c.Assert(errors.Cause(err), gc.ErrorMatches, "stat .*: no such file or directory")
	c.Assert(ok, gc.Equals, false)
}
// TestFileOwnerWithInvalidUsername checks that an unknown user name is
// reported as a lookup error. The file is created in a gocheck-managed
// directory and its handle is closed straight away.
func (s *unixFileSuite) TestFileOwnerWithInvalidUsername(c *gc.C) {
	path := filepath.Join(c.MkDir(), fmt.Sprintf("file-%d", time.Now().UnixNano()))
	f, err := os.Create(path)
	c.Assert(err, gc.IsNil)
	c.Assert(f.Close(), gc.IsNil)
	ok, err := utils.IsFileOwner(path, "invalid")
	c.Assert(errors.Cause(err), gc.ErrorMatches, "user: unknown user invalid")
	c.Assert(ok, gc.Equals, false)
}
================================================
FILE: file_windows.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build windows
// +build windows
package utils
import (
"fmt"
"os"
"path/filepath"
"syscall"
"unsafe"
"github.com/juju/errors"
)
// Flag values passed as dwFlags to MoveFileExW (see the MoveFileEx
// documentation linked from MoveFile and ReplaceFile below).
const (
	// movefile_replace_existing lets the move overwrite an existing
	// destination; ReplaceFile sets it and MoveFile deliberately does not.
	movefile_replace_existing = 0x1
	// movefile_write_through is passed on every move. NOTE(review):
	// presumably MOVEFILE_WRITE_THROUGH, i.e. return only once the move is
	// flushed to disk — confirm against the Win32 docs.
	movefile_write_through = 0x8
)
//sys moveFileEx(lpExistingFileName *uint16, lpNewFileName *uint16, dwFlags uint32) (err error) = MoveFileExW
// MoveFile atomically moves the source file to the destination, returning
// whether the file was moved successfully. If the destination already exists,
// it returns an error rather than overwrite it.
//
// Every failure is reported as an *os.LinkError with Op "move".
func MoveFile(source, destination string) (bool, error) {
	wrap := func(err error) error {
		return &os.LinkError{Op: "move", Old: source, New: destination, Err: err}
	}
	src, err := syscall.UTF16PtrFromString(source)
	if err != nil {
		return false, wrap(err)
	}
	dest, err := syscall.UTF16PtrFromString(destination)
	if err != nil {
		return false, wrap(err)
	}
	// movefile_replace_existing is intentionally omitted, so the call fails
	// when destination already exists.
	// see http://msdn.microsoft.com/en-us/library/windows/desktop/aa365240(v=vs.85).aspx
	if err := moveFileEx(src, dest, movefile_write_through); err != nil {
		return false, wrap(err)
	}
	return true, nil
}
// ReplaceFile atomically replaces the destination file or directory with the source.
// Every failure is reported as an *os.LinkError with Op "replace", matching
// the error type os.Rename returns.
func ReplaceFile(source, destination string) error {
	wrap := func(err error) error {
		return &os.LinkError{Op: "replace", Old: source, New: destination, Err: err}
	}
	src, err := syscall.UTF16PtrFromString(source)
	if err != nil {
		return wrap(err)
	}
	dest, err := syscall.UTF16PtrFromString(destination)
	if err != nil {
		return wrap(err)
	}
	// see http://msdn.microsoft.com/en-us/library/windows/desktop/aa365240(v=vs.85).aspx
	if err := moveFileEx(src, dest, movefile_replace_existing|movefile_write_through); err != nil {
		return wrap(err)
	}
	return nil
}
// MakeFileURL converts in to slash form and, when it looks like a
// drive-letter path (its second byte is ':'), prefixes it with "file://".
// Anything else, including input shorter than two bytes, is returned in
// slash form without a prefix.
func MakeFileURL(in string) string {
	slashed := filepath.ToSlash(in)
	// Since Go 1.6 the http client only accepts file URLs in this format.
	if len(slashed) >= 2 && slashed[1] == ':' {
		return "file://" + slashed
	}
	return slashed
}
// getUserSID resolves username on the local system and returns its SID in
// string form.
func getUserSID(username string) (string, error) {
	sid, _, _, err := syscall.LookupSID("", username)
	if err != nil {
		return "", err
	}
	return sid.String()
}
// readRegString reads the string value named key from the open registry
// key h. The value name is converted to UTF-16 once, using the
// error-returning UTF16PtrFromString instead of the deprecated
// StringToUTF16Ptr, which panics on names containing NUL.
func readRegString(h syscall.Handle, key string) (value string, err error) {
	name, err := syscall.UTF16PtrFromString(key)
	if err != nil {
		return "", err
	}
	var typ, size uint32
	// First call with no data buffer to get the size of the value in bytes.
	if err = syscall.RegQueryValueEx(h, name, nil, &typ, nil, &size); err != nil {
		return "", err
	}
	// size is in bytes; the extra zeroed uint16 guarantees the buffer is
	// NUL-terminated for UTF16ToString.
	buf := make([]uint16, size/2+1)
	if err = syscall.RegQueryValueEx(h, name, nil, &typ, (*byte)(unsafe.Pointer(&buf[0])), &size); err != nil {
		return "", err
	}
	return syscall.UTF16ToString(buf), nil
}
func homeFromRegistry(sid string) (string, error) {
var h syscall.Handle
// This key will exist on all platforms we support the agent on (windows server 2008 and above)
keyPath := fmt.Sprintf("Software\\Microsoft\\Windows NT\\CurrentVersion\\ProfileList\\%s", sid)
err := syscall.RegOpenKeyEx(syscall.HKEY_LOCAL_MACHINE,
syscall.StringToUTF16Ptr(keyPath),
0, syscall.KEY_READ, &h)
if err != nil {
return "", err
}
defer syscall.RegCloseKey(h)
str, err := readRegString(h, "ProfileImagePath")
if err != nil {
return "", err
}
return str, nil
}
// homeDir returns the home directory of a local Windows user.
// user.Lookup does not populate Gid and HomeDir on Windows, so the user's
// SID is resolved and the profile path is read from the registry instead.
// A failed SID lookup is reported as a UserNotFound error.
func homeDir(user string) (string, error) {
	sid, lookupErr := getUserSID(user)
	if lookupErr != nil {
		return "", errors.NewUserNotFound(lookupErr, "no such user")
	}
	return homeFromRegistry(sid)
}
// ChownPath is not implemented for Windows: it ignores its arguments and
// always returns nil.
//
// It exists only so that the package builds on Windows; user lookup and
// file ownership work completely differently there and have not been
// implemented yet.
func ChownPath(path, username string) error {
	return nil
}
// IsFileOwner is not implemented for Windows: it ignores its arguments and
// always reports that username owns path, with a nil error.
func IsFileOwner(path, username string) (bool, error) {
	const owned = true
	return owned, nil
}
================================================
FILE: file_windows_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build windows
// +build windows
package utils_test
import (
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// windowsFileSuite holds the tests for the Windows-only file helpers.
type windowsFileSuite struct{}

var _ = gc.Suite(&windowsFileSuite{})
// TestMakeFileURL checks that drive-letter paths become file URLs and that
// URLs (file or otherwise) are left in slash form.
func (s *windowsFileSuite) TestMakeFileURL(c *gc.C) {
	for i, tc := range []struct{ in, want string }{
		{`file://C:\foo\baz`, "file://C:/foo/baz"},
		{`C:\foo\baz`, "file://C:/foo/baz"},
		{"http://foo/baz", "http://foo/baz"},
		{"file://C:/foo/baz", "file://C:/foo/baz"},
	} {
		c.Logf("Test %d", i)
		c.Assert(utils.MakeFileURL(tc.in), gc.Equals, tc.want)
	}
}
// TestEnsureBaseDir checks that EnsureBaseDir re-roots Windows paths under
// the base, discarding the path's own drive letter.
func (s *windowsFileSuite) TestEnsureBaseDir(c *gc.C) {
	for _, tc := range []struct{ base, path, want string }{
		{`C:\r`, `C:\a\b`, `C:\r\a\b`},
		{`C:\r`, `D:\a\b`, `C:\r\a\b`},
		{`C:`, `D:\a\b`, `C:\a\b`},
		{`C:`, `\a\b`, `C:\a\b`},
		{``, `C:\a\b`, `C:\a\b`},
	} {
		c.Assert(utils.EnsureBaseDir(tc.base, tc.path), gc.Equals, tc.want)
	}
}
// TestFileOwner checks the Windows IsFileOwner stub, which always reports
// ownership. IsFileOwner returns (bool, error), so both values must be
// captured; passing the call straight to c.Assert does not compile.
func (s *windowsFileSuite) TestFileOwner(c *gc.C) {
	ok, err := utils.IsFileOwner("file://C:\\foo\\baz", "timmy")
	c.Assert(err, gc.IsNil)
	c.Assert(ok, gc.Equals, true)
}
================================================
FILE: filepath/common.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath
// splitSuffix splits path into (root, suffix) such that root+suffix == path.
// suffix is empty or starts at the last '.' of the final path element.
// Periods at the start of the final element never begin a suffix, so
// ".cshrc", "dir/.cshrc" and "..foo" have no suffix. A '.' inside a
// directory component (as in "a.b/c") is never used either.
//
// splitSuffix does not know the target platform, so both '/' and '\\' end
// the final element. NOTE(review): a backslash is a legal filename byte on
// unix, so names containing one are split as if it were a separator.
func splitSuffix(path string) (string, string) {
	// Scan backwards for the last '.' and the start of the final element.
	dot, start := -1, 0
	for i := len(path) - 1; i >= 0; i-- {
		c := path[i]
		if c == '/' || c == '\\' {
			start = i + 1
			break
		}
		if c == '.' && dot < 0 {
			dot = i
		}
	}
	if dot < 0 {
		return path, ""
	}
	// Only a dot preceded by some non-dot byte of the element starts a suffix.
	for i := start; i < dot; i++ {
		if path[i] != '.' {
			return path[:dot], path[dot:]
		}
	}
	return path, ""
}
================================================
FILE: filepath/common_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath_test
import (
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filepath"
)
// commonSuite tests the platform-independent helpers in common.go.
type commonSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&commonSuite{})
// TestSplitSuffixHasSuffix: a plain name with an extension is split at the dot.
func (commonSuite) TestSplitSuffixHasSuffix(c *gc.C) {
	root, ext := filepath.SplitSuffix("spam.ext")
	c.Check(root, gc.Equals, "spam")
	c.Check(ext, gc.Equals, ".ext")
}
// TestSplitSuffixNoSuffix: a name without a dot has an empty suffix.
func (commonSuite) TestSplitSuffixNoSuffix(c *gc.C) {
	root, ext := filepath.SplitSuffix("spam")
	c.Check(root, gc.Equals, "spam")
	c.Check(ext, gc.Equals, "")
}
// TestSplitSuffixEmpty: the empty path splits into two empty strings.
func (commonSuite) TestSplitSuffixEmpty(c *gc.C) {
	root, ext := filepath.SplitSuffix("")
	c.Check(root, gc.Equals, "")
	c.Check(ext, gc.Equals, "")
}
// TestSplitSuffixDotFilePlain: a leading dot does not start a suffix.
func (commonSuite) TestSplitSuffixDotFilePlain(c *gc.C) {
	root, ext := filepath.SplitSuffix(".spam")
	c.Check(root, gc.Equals, ".spam")
	c.Check(ext, gc.Equals, "")
}
// TestSplitSuffixDofileWithSuffix: a dot file can still carry an extension.
func (commonSuite) TestSplitSuffixDofileWithSuffix(c *gc.C) {
	root, ext := filepath.SplitSuffix(".spam.ext")
	c.Check(root, gc.Equals, ".spam")
	c.Check(ext, gc.Equals, ".ext")
}
================================================
FILE: filepath/export_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath
// Aliases of unexported identifiers for this package's external tests
// (package filepath_test). The file is only compiled under "go test".
var (
	// SplitSuffix exposes splitSuffix.
	SplitSuffix = splitSuffix
)
================================================
FILE: filepath/filepath.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath
import (
"runtime"
"strings"
"github.com/juju/errors"
"github.com/juju/utils/v4"
)
// Renderer provides methods for the different functions in
// the stdlib path/filepath package that don't relate to a concrete
// filesystem. So Abs, EvalSymlinks, Glob, Rel, and Walk are not
// included. Also, while the functions in path/filepath relate to the
// current host, the Renderer methods relate to the renderer's
// target platform. So for example, a windows-oriented implementation
// will give windows-specific results even when used on linux.
type Renderer interface {
	// Base mimics path/filepath.
	Base(path string) string

	// Clean mimics path/filepath.
	Clean(path string) string

	// Dir mimics path/filepath.
	Dir(path string) string

	// Ext mimics path/filepath.
	Ext(path string) string

	// FromSlash mimics path/filepath.
	FromSlash(path string) string

	// IsAbs mimics path/filepath.
	IsAbs(path string) bool

	// Join mimics path/filepath.
	Join(path ...string) string

	// Match mimics path/filepath.
	Match(pattern, name string) (matched bool, err error)

	// NormCase normalizes the case of a pathname. On Unix and Mac OS X,
	// this returns the path unchanged; on case-insensitive filesystems,
	// it converts the path to lowercase.
	NormCase(path string) string

	// Split mimics path/filepath.
	Split(path string) (dir, file string)

	// SplitList mimics path/filepath.
	SplitList(path string) []string

	// SplitSuffix splits the pathname into a pair (root, suffix) such
	// that root + suffix == path, and suffix is empty or begins with a
	// period and contains at most one period. Leading periods on the
	// basename are ignored; SplitSuffix(".cshrc") returns (".cshrc", "").
	SplitSuffix(path string) (string, string)

	// ToSlash mimics path/filepath.
	ToSlash(path string) string

	// VolumeName mimics path/filepath.
	VolumeName(path string) string
}
// NewRenderer returns the Renderer for the named operating system, matched
// case-insensitively. An empty name selects the host OS (runtime.GOOS).
// Windows gets a *WindowsRenderer; any OS accepted by utils.OSIsUnix, as
// well as "ubuntu", gets a *UnixRenderer. Any other name yields a NotFound
// error.
func NewRenderer(os string) (Renderer, error) {
	name := os
	if name == "" {
		name = runtime.GOOS
	}
	name = strings.ToLower(name)

	if name == utils.OSWindows {
		return &WindowsRenderer{}, nil
	}
	if utils.OSIsUnix(name) || name == "ubuntu" {
		return &UnixRenderer{}, nil
	}
	return nil, errors.NotFoundf("renderer for %q", name)
}
================================================
FILE: filepath/filepath_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath_test
import (
"runtime"
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/utils/v4"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filepath"
)
// filepathSuite tests NewRenderer. unix and windows are typed samples used
// by checkRenderer to compare renderer types.
type filepathSuite struct {
	testing.IsolationSuite

	unix    *filepath.UnixRenderer
	windows *filepath.WindowsRenderer
}

var _ = gc.Suite(&filepathSuite{})
// SetUpTest is the per-test fixture. gocheck only recognises the exact name
// "SetUpTest" (capital U); under any other spelling it is never invoked, and
// the suite's sample renderers would stay nil.
func (s *filepathSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.unix = &filepath.UnixRenderer{}
	s.windows = &filepath.WindowsRenderer{}
}
// checkRenderer asserts that renderer has the concrete type for expected,
// which must be "windows" or "unix"; any other kind is reported as an error.
func (s filepathSuite) checkRenderer(c *gc.C, renderer filepath.Renderer, expected string) {
	if expected == "windows" {
		c.Check(renderer, gc.FitsTypeOf, s.windows)
		return
	}
	if expected == "unix" {
		c.Check(renderer, gc.FitsTypeOf, s.unix)
		return
	}
	c.Errorf("unknown kind %q", expected)
}
// TestNewRendererDefault checks that an empty OS name selects the host's
// renderer.
func (s filepathSuite) TestNewRendererDefault(c *gc.C) {
	// All possible values of runtime.GOOS should be supported.
	renderer, err := filepath.NewRenderer("")
	c.Assert(err, jc.ErrorIsNil)
	kind := "unix"
	if runtime.GOOS == "windows" {
		kind = "windows"
	}
	s.checkRenderer(c, renderer, kind)
}
// TestNewRendererGOOS checks that naming the host OS explicitly gives the
// same renderer kind as the default.
func (s filepathSuite) TestNewRendererGOOS(c *gc.C) {
	// All possible values of runtime.GOOS should be supported.
	renderer, err := filepath.NewRenderer(runtime.GOOS)
	c.Assert(err, jc.ErrorIsNil)
	kind := "unix"
	if runtime.GOOS == "windows" {
		kind = "windows"
	}
	s.checkRenderer(c, renderer, kind)
}
// TestNewRendererWindows checks that "windows" yields a WindowsRenderer.
func (s filepathSuite) TestNewRendererWindows(c *gc.C) {
	r, err := filepath.NewRenderer("windows")
	c.Assert(err, jc.ErrorIsNil)
	s.checkRenderer(c, r, "windows")
}
// TestNewRendererUnix checks that every OS listed in utils.OSUnix yields a
// UnixRenderer.
func (s filepathSuite) TestNewRendererUnix(c *gc.C) {
	for _, name := range utils.OSUnix {
		c.Logf("trying %q", name)
		r, err := filepath.NewRenderer(name)
		c.Assert(err, jc.ErrorIsNil)
		s.checkRenderer(c, r, "unix")
	}
}
// TestNewRendererDistros checks that supported distro names map to the
// unix renderer.
func (s filepathSuite) TestNewRendererDistros(c *gc.C) {
	for _, distro := range [...]string{"ubuntu"} {
		c.Logf("trying %q", distro)
		r, err := filepath.NewRenderer(distro)
		c.Assert(err, jc.ErrorIsNil)
		s.checkRenderer(c, r, "unix")
	}
}
// TestNewRendererUnknown checks that an unsupported OS name is reported as
// NotFound. An empty name cannot be used here: NewRenderer maps "" to
// runtime.GOOS, which is always supported.
func (s filepathSuite) TestNewRendererUnknown(c *gc.C) {
	_, err := filepath.NewRenderer("<unknown OS>")
	c.Check(err, jc.Satisfies, errors.IsNotFound)
}
================================================
FILE: filepath/interface_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath
// Compile-time checks that both concrete renderers implement Renderer.
var _ Renderer = (*UnixRenderer)(nil)
var _ Renderer = (*WindowsRenderer)(nil)
================================================
FILE: filepath/package_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// Test hooks the gocheck suites registered in this package into "go test".
func Test(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: filepath/stdlib.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.golang file.
package filepath
import (
"strings"
)
// The following functions are adapted from the Go standard library source.
// Base mimics path/filepath for the given path separator: it returns the
// last element of path. Trailing separators are dropped first and the
// volume name is discarded. An empty path yields "."; a path made only of
// separators yields a single separator.
func Base(sep uint8, volumeName func(string) string, path string) string {
	if len(path) == 0 {
		return "."
	}
	end := len(path)
	for end > 0 && path[end-1] == sep {
		end--
	}
	rest := path[:end]
	rest = rest[len(volumeName(rest)):]
	if cut := strings.LastIndexByte(rest, sep); cut >= 0 {
		rest = rest[cut+1:]
	}
	if rest == "" {
		return string(sep)
	}
	return rest
}
// A lazybuf is a lazily constructed path buffer used by Clean. It supports
// appending bytes, reading back bytes already written, and producing the
// final string. Until the output diverges from the input no buffer is
// allocated; the output is then just a prefix of volAndPath.
type lazybuf struct {
	path       string // input being cleaned (volume name removed)
	buf        []byte // nil until the output first differs from path
	w          int    // index of the next byte to write
	volAndPath string // original input, including the volume name
	volLen     int    // length of the volume-name prefix of volAndPath
}

// index returns byte i of the output written so far.
func (b *lazybuf) index(i int) byte {
	if b.buf == nil {
		return b.path[i]
	}
	return b.buf[i]
}

// append writes c at position w. While the output still matches path it
// only advances w; the first mismatch copies the prefix into a new buffer.
func (b *lazybuf) append(c byte) {
	if b.buf != nil {
		b.buf[b.w] = c
		b.w++
		return
	}
	if b.w < len(b.path) && b.path[b.w] == c {
		// Output is still identical to the input: nothing to copy.
		b.w++
		return
	}
	b.buf = make([]byte, len(b.path))
	copy(b.buf, b.path[:b.w])
	b.buf[b.w] = c
	b.w++
}

// string returns the volume name followed by the output written so far.
func (b *lazybuf) string() string {
	vol := b.volAndPath[:b.volLen]
	if b.buf != nil {
		return vol + string(b.buf[:b.w])
	}
	return b.volAndPath[:b.volLen+b.w]
}
// Clean mimics path/filepath for the given path separator. The volume name
// (as reported by volumeName) is kept verbatim. In the remainder, repeated
// separators are collapsed, "." elements are dropped, and ".." elements
// remove the preceding element where possible. A ".." that cannot be
// resolved is kept in a relative path and dropped at the root. An empty
// result becomes ".", and the output is passed through FromSlash.
//
// NOTE(review): only sep itself is treated as a separator while cleaning.
// For sep '\\', any '/' in the input is converted by the final FromSlash
// but is not cleaned, unlike path/filepath on Windows, which treats both as
// separators.
func Clean(sep uint8, volumeName func(string) string, path string) string {
	originalPath := path
	volLen := len(volumeName(path))
	path = path[volLen:]
	if path == "" {
		// Nothing but a volume name (or nothing at all).
		if volLen > 1 && originalPath[1] != ':' {
			// should be UNC
			return FromSlash(sep, originalPath)
		}
		return originalPath + "."
	}
	rooted := (path[0] == sep)

	// Invariants:
	//	reading from path; r is index of next byte to process.
	//	writing to buf; w is index of next byte to write.
	//	dotdot is index in buf where .. must stop, either because
	//		it is the leading slash or it is a leading ../../.. prefix.
	n := len(path)
	out := lazybuf{path: path, volAndPath: originalPath, volLen: volLen}
	r, dotdot := 0, 0
	if rooted {
		out.append(sep)
		r, dotdot = 1, 1
	}

	for r < n {
		switch {
		case path[r] == sep:
			// empty path element
			r++
		case path[r] == '.' && (r+1 == n || path[r+1] == sep):
			// . element
			r++
		case path[r] == '.' && path[r+1] == '.' && (r+2 == n || path[r+2] == sep):
			// .. element: remove to last separator
			// (path[r+1] is in range: the previous case handled r+1 == n.)
			r += 2
			switch {
			case out.w > dotdot:
				// can backtrack
				out.w--
				for out.w > dotdot && out.index(out.w) != sep {
					out.w--
				}
			case !rooted:
				// cannot backtrack, but not rooted, so append .. element.
				if out.w > 0 {
					out.append(sep)
				}
				out.append('.')
				out.append('.')
				dotdot = out.w
			}
			// Rooted and nothing to backtrack over: ".." is dropped.
		default:
			// real path element.
			// add slash if needed
			if rooted && out.w != 1 || !rooted && out.w != 0 {
				out.append(sep)
			}
			// copy element
			for ; r < n && path[r] != sep; r++ {
				out.append(path[r])
			}
		}
	}

	// Turn empty string into "."
	if out.w == 0 {
		out.append('.')
	}

	return FromSlash(sep, out.string())
}
// Dir mimics path/filepath for the given path separator: it returns all but
// the last element of path, cleaned, with the volume name preserved.
// As in path/filepath, when only a volume longer than a drive letter
// remains (a UNC prefix such as `\\host\share`), the volume itself is
// returned rather than the volume with "." appended.
func Dir(sep uint8, volumeName func(string) string, path string) string {
	vol := volumeName(path)
	i := len(path) - 1
	for i >= len(vol) && path[i] != sep {
		i--
	}
	dir := Clean(sep, volumeName, path[len(vol):i+1])
	if dir == "." && len(vol) > 2 {
		// must be UNC
		return vol
	}
	return vol + dir
}
// Ext mimics path/filepath for the given path separator: it returns the
// suffix of the final path element starting at its last '.', or "" when
// that element contains no '.'.
func Ext(sep uint8, path string) string {
	for i := len(path) - 1; i >= 0; i-- {
		switch path[i] {
		case sep:
			return ""
		case '.':
			return path[i:]
		}
	}
	return ""
}
// FromSlash mimics path/filepath for the given path separator: every '/'
// in path is replaced by sep. When sep is '/', path is returned unchanged.
func FromSlash(sep uint8, path string) string {
	if sep != '/' {
		path = strings.Replace(path, "/", string(sep), -1)
	}
	return path
}
// Join mimics path/filepath for the given path separator: leading empty
// elements are skipped, the rest are joined with sep and cleaned. If every
// element is empty the result is "".
func Join(sep uint8, volumeName func(string) string, elem ...string) string {
	first := 0
	for first < len(elem) && elem[first] == "" {
		first++
	}
	if first == len(elem) {
		return ""
	}
	return Clean(sep, volumeName, strings.Join(elem[first:], string(sep)))
}
// Split mimics path/filepath for the given path separator: it splits path
// just after its last separator into (dir, file). The volume name is never
// split, and dir+file always equals path.
func Split(sep uint8, volumeName func(string) string, path string) (dir, file string) {
	volLen := len(volumeName(path))
	cut := len(path)
	for cut > volLen && path[cut-1] != sep {
		cut--
	}
	return path[:cut], path[cut:]
}
// ToSlash mimics path/filepath for the given path separator: every sep in
// path is replaced by '/'. When sep is '/', path is returned unchanged.
func ToSlash(sep uint8, path string) string {
	if sep != '/' {
		path = strings.Replace(path, string(sep), "/", -1)
	}
	return path
}
================================================
FILE: filepath/stdlib_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath_test
import (
gofilepath "path/filepath"
"runtime"
"strings"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filepath"
)
// The tests here are mostly just sanity checks against the behavior
// of the stdlib path/filepath. We are not trying for high coverage levels.
// stdlibSuite compares this package's separator-parameterised helpers with
// the host's path/filepath on a sample path. path and volumeName are set
// per platform in SetUpTest.
type stdlibSuite struct {
	testing.IsolationSuite

	path       string
	volumeName func(string) string
}

var _ = gc.Suite(&stdlibSuite{})
// SetUpTest picks a sample absolute path and a matching volumeName function
// for the host platform.
func (s *stdlibSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	switch runtime.GOOS {
	case "windows":
		s.path = `C:\a\b\c.xyz`
		// Report a drive-letter volume only when the path actually has one.
		// Returning "C:" for every path would make Clean, Dir and Join strip
		// the first two bytes of volume-less paths.
		s.volumeName = func(path string) string {
			if len(path) >= 2 && path[1] == ':' {
				return path[:2]
			}
			return ""
		}
	default:
		s.path = "/a/b/c.xyz"
		s.volumeName = func(string) string { return "" }
	}
}
// TestBase checks Base against path/filepath.Base.
func (s stdlibSuite) TestBase(c *gc.C) {
	got := filepath.Base(gofilepath.Separator, s.volumeName, s.path)
	c.Check(got, gc.Equals, gofilepath.Base(s.path))
	c.Check(got, gc.Equals, "c.xyz")
}
// TestClean checks Clean against path/filepath.Clean on already-clean input.
func (s stdlibSuite) TestClean(c *gc.C) {
	// TODO(ericsnow) Add more cases.
	cases := []struct{ original, expected string }{
		{s.path, s.path},
	}
	for _, tc := range cases {
		c.Logf("checking %q", tc.original)
		got := filepath.Clean(gofilepath.Separator, s.volumeName, tc.original)
		c.Check(got, gc.Equals, gofilepath.Clean(tc.original))
		c.Check(got, gc.Equals, tc.expected)
	}
}
// TestDir checks Dir against path/filepath.Dir. Dir keeps the volume name,
// so on Windows the expected directory includes the "C:" prefix.
func (s stdlibSuite) TestDir(c *gc.C) {
	path := filepath.Dir(gofilepath.Separator, s.volumeName, s.path)
	gopath := gofilepath.Dir(s.path)
	c.Check(path, gc.Equals, gopath)
	switch runtime.GOOS {
	case "windows":
		c.Check(path, gc.Equals, `C:\a\b`)
	default:
		c.Check(path, gc.Equals, "/a/b")
	}
}
// TestExt checks Ext against path/filepath.Ext.
func (s stdlibSuite) TestExt(c *gc.C) {
	got := filepath.Ext(gofilepath.Separator, s.path)
	c.Check(got, gc.Equals, gofilepath.Ext(s.path))
	c.Check(got, gc.Equals, ".xyz")
}
// TestFromSlash checks FromSlash against path/filepath.FromSlash.
// FromSlash only swaps separators and never adds a volume name, so on
// Windows the expectation is the sample path without its "C:" prefix.
func (s stdlibSuite) TestFromSlash(c *gc.C) {
	original := "/a/b/c.xyz"
	path := filepath.FromSlash(gofilepath.Separator, original)
	gopath := gofilepath.FromSlash(original)
	c.Check(path, gc.Equals, gopath)
	switch runtime.GOOS {
	case "windows":
		c.Check(path, gc.Equals, `\a\b\c.xyz`)
	default:
		c.Check(path, gc.Equals, s.path)
	}
}
// TestJoin checks Join against path/filepath.Join for a relative path. The
// expectation is the sample path with everything up to and including its
// first separator removed.
func (s stdlibSuite) TestJoin(c *gc.C) {
	got := filepath.Join(gofilepath.Separator, s.volumeName, "a", "b", "c.xyz")
	c.Check(got, gc.Equals, gofilepath.Join("a", "b", "c.xyz"))
	firstSep := strings.Index(s.path, string(gofilepath.Separator))
	c.Check(got, gc.Equals, s.path[firstSep+1:])
}
// TestSplit checks Split against path/filepath.Split. Split never separates
// the volume name from the directory, so on Windows the expected directory
// keeps the "C:" prefix.
func (s stdlibSuite) TestSplit(c *gc.C) {
	dir, base := filepath.Split(gofilepath.Separator, s.volumeName, s.path)
	godir, gobase := gofilepath.Split(s.path)
	c.Check(dir, gc.Equals, godir)
	c.Check(base, gc.Equals, gobase)
	switch runtime.GOOS {
	case "windows":
		c.Check(dir, gc.Equals, `C:\a\b\`)
	default:
		c.Check(dir, gc.Equals, "/a/b/")
	}
	c.Check(base, gc.Equals, "c.xyz")
}
// TestToSlash checks ToSlash against path/filepath.ToSlash. ToSlash keeps
// the volume name, so on Windows the expectation starts with "C:".
func (s stdlibSuite) TestToSlash(c *gc.C) {
	path := filepath.ToSlash(gofilepath.Separator, s.path)
	gopath := gofilepath.ToSlash(s.path)
	c.Check(path, gc.Equals, gopath)
	switch runtime.GOOS {
	case "windows":
		c.Check(path, gc.Equals, "C:/a/b/c.xyz")
	default:
		c.Check(path, gc.Equals, "/a/b/c.xyz")
	}
}
// TestMatchTrue checks patterns that must match, comparing Match with
// path/filepath.Match.
func (s stdlibSuite) TestMatchTrue(c *gc.C) {
	for _, tc := range []struct{ pattern, name string }{
		{"abc", "abc"},
		{"ab[c]", "abc"},
		{"", ""},
		{"*", "abc"},
		{"a*c", "abc"},
		{"?", "a"},
		{"a?c", "abc"},
	} {
		c.Logf("- checking pattern %q against %q -", tc.pattern, tc.name)
		matched, err := filepath.Match(gofilepath.Separator, tc.pattern, tc.name)
		c.Assert(err, jc.ErrorIsNil)
		gomatched, err := gofilepath.Match(tc.pattern, tc.name)
		c.Assert(err, jc.ErrorIsNil)
		c.Check(matched, gc.Equals, gomatched)
		c.Check(matched, jc.IsTrue)
	}
}
// TestMatchFalse checks well-formed patterns that must not match,
// comparing Match with path/filepath.Match.
func (s stdlibSuite) TestMatchFalse(c *gc.C) {
	for _, tc := range []struct{ pattern, name string }{
		{"abc", "xyz"},
		{"", "abc"},
		{"a*c", "a"},
		{"?", ""},
		{"a?c", "ac"},
	} {
		c.Logf("- checking pattern %q against %q -", tc.pattern, tc.name)
		matched, err := filepath.Match(gofilepath.Separator, tc.pattern, tc.name)
		c.Assert(err, jc.ErrorIsNil)
		gomatched, err := gofilepath.Match(tc.pattern, tc.name)
		c.Assert(err, jc.ErrorIsNil)
		c.Check(matched, gc.Equals, gomatched)
		c.Check(matched, jc.IsFalse)
	}
}
// TestMatchBadPattern checks that malformed patterns yield ErrBadPattern,
// exactly as path/filepath.Match does.
func (s stdlibSuite) TestMatchBadPattern(c *gc.C) {
	cases := map[string]string{
		"ab[":    "abc",
		"ab[-c]": "abc",
		"ab[]":   "abc",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		_, err := filepath.Match(gofilepath.Separator, pattern, name)
		_, goErr := gofilepath.Match(pattern, name)
		c.Check(err, gc.Equals, goErr)
		c.Check(err, gc.Equals, gofilepath.ErrBadPattern)
	}
}
================================================
FILE: filepath/stdlibmatch.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.golang file.
package filepath
import (
"path/filepath"
"strings"
"unicode/utf8"
)
// The following functions are adapted from the Go standard library source.
// Match reports whether name matches the shell file name pattern, using
// sep as the path separator. The pattern syntax is:
//
//	pattern:
//		{ term }
//	term:
//		'*'         matches any sequence of non-separator characters
//		'?'         matches any single non-separator character
//		'[' [ '^' ] { character-range } ']'
//		            character class (must be non-empty)
//		c           matches character c (c != '*', '?', '\\', '[')
//		'\\' c      matches character c
//
//	character-range:
//		c           matches character c (c != '\\', '-', ']')
//		'\\' c      matches character c
//		lo '-' hi   matches character c for lo <= c <= hi
//
// Match requires pattern to match all of name, not just a substring.
// The only possible returned error is filepath.ErrBadPattern, when pattern
// is malformed. Before reporting a non-match, Match hands every remaining
// chunk of the pattern to matchChunk for validation, so syntax errors that
// follow the point of failure can still be reported (as path/filepath.Match
// does since Go 1.16).
//
// When sep is '\\' (as on Windows), escaping is disabled and '\\' is
// treated as a path separator.
func Match(sep uint8, pattern, name string) (matched bool, err error) {
Pattern:
	for len(pattern) > 0 {
		var star bool
		var chunk string
		star, chunk, pattern = scanChunk(sep, pattern)
		if star && chunk == "" {
			// Trailing * matches rest of string unless it has a separator.
			return strings.IndexByte(name, sep) < 0, nil
		}
		// Look for match at current position.
		t, ok, err := matchChunk(sep, chunk, name)
		// if we're the last chunk, make sure we've exhausted the name
		// otherwise we'll give a false result even if we could still match
		// using the star
		if ok && (len(t) == 0 || len(pattern) > 0) {
			name = t
			continue
		}
		if err != nil {
			return false, err
		}
		if star {
			// Look for match skipping i+1 bytes.
			// Cannot skip the separator.
			for i := 0; i < len(name) && name[i] != sep; i++ {
				t, ok, err := matchChunk(sep, chunk, name[i+1:])
				if ok {
					// if we're the last chunk, make sure we exhausted the name
					if len(pattern) == 0 && len(t) > 0 {
						continue
					}
					name = t
					continue Pattern
				}
				if err != nil {
					return false, err
				}
			}
		}
		// Before returning false with no error, check that the remainder
		// of the pattern is syntactically valid.
		for len(pattern) > 0 {
			_, chunk, pattern = scanChunk(sep, pattern)
			if _, _, err := matchChunk(sep, chunk, ""); err != nil {
				return false, err
			}
		}
		return false, nil
	}
	return len(name) == 0, nil
}
// scanChunk gets the next segment of pattern, which is a non-star string
// possibly preceded by a star. It returns whether one or more stars
// preceded the segment, the segment itself, and the unscanned rest of
// pattern.
//
// When sep is not '\\', a backslash escapes the byte that follows it, so an
// escaped '*' does not end the chunk. When sep is '\\', backslash is an
// ordinary character. This must agree with matchChunk and getEsc, which use
// the same rule. A trailing lone backslash is left in the chunk for
// matchChunk to report as ErrBadPattern.
func scanChunk(sep uint8, pattern string) (star bool, chunk, rest string) {
	for len(pattern) > 0 && pattern[0] == '*' {
		pattern = pattern[1:]
		star = true
	}
	inrange := false
	var i int
Scan:
	for i = 0; i < len(pattern); i++ {
		switch pattern[i] {
		case '\\':
			// Escaping is enabled only when '\\' is not the separator.
			if sep != '\\' {
				// error check handled in matchChunk: bad pattern.
				if i+1 < len(pattern) {
					i++
				}
			}
		case '[':
			inrange = true
		case ']':
			inrange = false
		case '*':
			if !inrange {
				break Scan
			}
		}
	}
	return star, pattern[0:i], pattern[i:]
}
// matchChunk checks whether chunk matches the beginning of s.
// If so, it returns the remainder of s (after the match) and ok == true.
// Chunk is all single-character operators: literals, char classes, and ?.
//
// sep is the path separator: '?' never matches it, and a backslash is an
// escape character only when sep is not '\\'.
//
// Once the match has failed (including because s ran out), the rest of
// chunk is still parsed without reading s. A malformed chunk therefore
// yields filepath.ErrBadPattern regardless of s, and calling matchChunk with
// an empty s simply validates chunk.
func matchChunk(sep uint8, chunk, s string) (rest string, ok bool, err error) {
	// failed records whether the match has failed. After a failure the
	// loop keeps consuming chunk to validate it, but no longer reads s.
	failed := false
	for len(chunk) > 0 {
		if !failed && len(s) == 0 {
			failed = true
		}
		switch chunk[0] {
		case '[':
			// character class
			var r rune
			if !failed {
				var n int
				r, n = utf8.DecodeRuneInString(s)
				s = s[n:]
			}
			chunk = chunk[1:]
			// We can't end right after '[', we're expecting at least
			// a closing bracket and possibly a caret.
			if len(chunk) == 0 {
				return "", false, filepath.ErrBadPattern
			}
			// possibly negated
			negated := chunk[0] == '^'
			if negated {
				chunk = chunk[1:]
			}
			// parse all ranges
			match := false
			nrange := 0
			for {
				if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 {
					chunk = chunk[1:]
					break
				}
				var lo, hi rune
				if lo, chunk, err = getEsc(sep, chunk); err != nil {
					return "", false, err
				}
				hi = lo
				// getEsc guarantees chunk is non-empty on success.
				if chunk[0] == '-' {
					if hi, chunk, err = getEsc(sep, chunk[1:]); err != nil {
						return "", false, err
					}
				}
				if lo <= r && r <= hi {
					match = true
				}
				nrange++
			}
			if match == negated {
				failed = true
			}
		case '?':
			if !failed {
				if s[0] == sep {
					failed = true
				}
				_, n := utf8.DecodeRuneInString(s)
				s = s[n:]
			}
			chunk = chunk[1:]
		case '\\':
			if sep != '\\' {
				chunk = chunk[1:]
				if len(chunk) == 0 {
					return "", false, filepath.ErrBadPattern
				}
			}
			fallthrough
		default:
			if !failed {
				if chunk[0] != s[0] {
					failed = true
				}
				s = s[1:]
			}
			chunk = chunk[1:]
		}
	}
	if failed {
		return "", false, nil
	}
	return s, true, nil
}
// getEsc gets a possibly-escaped character from chunk, for a character class.
func getEsc(sep uint8, chunk string) (r rune, nchunk string, err error) {
if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' {
err = filepath.ErrBadPattern
return
}
if chunk[0] == '\\' && sep != '\\' {
chunk = chunk[1:]
if len(chunk) == 0 {
err = filepath.ErrBadPattern
return
}
}
r, n := utf8.DecodeRuneInString(chunk)
if r == utf8.RuneError && n == 1 {
err = filepath.ErrBadPattern
}
nchunk = chunk[n:]
if len(nchunk) == 0 {
err = filepath.ErrBadPattern
}
return
}
================================================
FILE: filepath/unix.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath
import (
"strings"
)
// A substantial portion of this code comes from the Go stdlib code.
const (
	// UnixSeparator is the path separator used by UnixRenderer. It is
	// fixed for Unix paths and does not depend on the host OS.
	UnixSeparator = '/'

	// UnixListSeparator separates entries in a list of Unix paths, as
	// split by UnixRenderer.SplitList.
	UnixListSeparator = ':'
)
// UnixRenderer is a Renderer implementation for most flavors of Unix.
// It has no fields; its methods are parameterized by UnixSeparator and
// UnixListSeparator, and Unix paths never carry a volume name.
type UnixRenderer struct{}

// Base implements Renderer.
func (u UnixRenderer) Base(path string) string {
	return Base(UnixSeparator, u.VolumeName, path)
}

// Clean implements Renderer.
func (u UnixRenderer) Clean(path string) string {
	return Clean(UnixSeparator, u.VolumeName, path)
}

// Dir implements Renderer.
func (u UnixRenderer) Dir(path string) string {
	return Dir(UnixSeparator, u.VolumeName, path)
}

// Ext implements Renderer.
func (u UnixRenderer) Ext(path string) string {
	return Ext(UnixSeparator, path)
}

// FromSlash implements Renderer.
func (u UnixRenderer) FromSlash(path string) string {
	return FromSlash(UnixSeparator, path)
}

// IsAbs implements Renderer. A Unix path is absolute when it begins with
// the separator.
func (u UnixRenderer) IsAbs(path string) bool {
	return len(path) > 0 && path[0] == UnixSeparator
}

// Join implements Renderer.
func (u UnixRenderer) Join(path ...string) string {
	return Join(UnixSeparator, u.VolumeName, path...)
}

// Match implements Renderer.
func (u UnixRenderer) Match(pattern, name string) (matched bool, err error) {
	return Match(UnixSeparator, pattern, name)
}

// Split implements Renderer.
func (u UnixRenderer) Split(path string) (dir, file string) {
	return Split(UnixSeparator, u.VolumeName, path)
}

// SplitList implements Renderer. An empty path yields an empty (non-nil)
// list; otherwise path is split on UnixListSeparator.
func (u UnixRenderer) SplitList(path string) []string {
	if len(path) == 0 {
		return []string{}
	}
	return strings.Split(path, string(UnixListSeparator))
}

// ToSlash implements Renderer.
func (u UnixRenderer) ToSlash(path string) string {
	return ToSlash(UnixSeparator, path)
}

// VolumeName implements Renderer. Unix paths have no volume name, so the
// result is always empty.
func (u UnixRenderer) VolumeName(path string) string {
	return ""
}

// NormCase implements Renderer. Unix paths are case-sensitive, so path
// is returned unchanged.
func (u UnixRenderer) NormCase(path string) string {
	return path
}

// SplitSuffix implements Renderer.
func (u UnixRenderer) SplitSuffix(path string) (string, string) {
	return splitSuffix(path)
}
================================================
FILE: filepath/unix_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath_test
import (
gofilepath "path/filepath"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filepath"
)
var (
	_ = gc.Suite(&unixSuite{})
	_ = gc.Suite(&unixThinWrapperSuite{})
)

// unixBaseSuite holds the fixture shared by the Unix renderer suites: an
// absolute Unix path and a fresh UnixRenderer for each test.
type unixBaseSuite struct {
	testing.IsolationSuite

	path     string
	renderer *filepath.UnixRenderer
}

// SetUpTest resets the fixture before each test.
func (s *unixBaseSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.renderer = &filepath.UnixRenderer{}
	s.path = "/a/b/c.xyz"
}

// matchesRuntime reports whether the Go runtime uses the Unix separator,
// in which case results can also be cross-checked with path/filepath.
func (s *unixBaseSuite) matchesRuntime() bool {
	return filepath.UnixSeparator == gofilepath.Separator
}
// unixSuite tests the UnixRenderer methods that are not covered by
// unixThinWrapperSuite.
type unixSuite struct {
	unixBaseSuite
}

func (s unixSuite) TestIsAbs(c *gc.C) {
	got := s.renderer.IsAbs(s.path)
	c.Check(got, jc.IsTrue)
	if !s.matchesRuntime() {
		return
	}
	c.Check(got, gc.Equals, gofilepath.IsAbs(s.path))
}

func (s unixSuite) TestSplitList(c *gc.C) {
	const input = "/a:b:/c/d"
	parts := s.renderer.SplitList(input)
	c.Check(parts, jc.DeepEquals, []string{"/a", "b", "/c/d"})
	if !s.matchesRuntime() {
		return
	}
	c.Check(parts, jc.DeepEquals, gofilepath.SplitList(input))
}

func (s unixSuite) TestVolumeName(c *gc.C) {
	c.Check(s.renderer.VolumeName(s.path), gc.Equals, "")
}

// checkNormCase checks that NormCase maps in to want.
func (s unixSuite) checkNormCase(c *gc.C, in, want string) {
	c.Check(s.renderer.NormCase(in), gc.Equals, want)
}

func (s unixSuite) TestNormCaseLower(c *gc.C) {
	s.checkNormCase(c, "spam", "spam")
}

func (s unixSuite) TestNormCaseUpper(c *gc.C) {
	s.checkNormCase(c, "SPAM", "SPAM")
}

func (s unixSuite) TestNormCaseMixed(c *gc.C) {
	s.checkNormCase(c, "sPaM", "sPaM")
}

func (s unixSuite) TestNormCaseCapitalized(c *gc.C) {
	s.checkNormCase(c, "Spam", "Spam")
}

func (s unixSuite) TestNormCasePunctuation(c *gc.C) {
	s.checkNormCase(c, "spam-eggs.ext", "spam-eggs.ext")
}

func (s unixSuite) TestSplitSuffix(c *gc.C) {
	// This is just a sanity check. The splitSuffix tests are more
	// comprehensive.
	base, suffix := s.renderer.SplitSuffix("spam.ext")
	c.Check(base, gc.Equals, "spam")
	c.Check(suffix, gc.Equals, ".ext")
}
// unixThinWrapperSuite tests the UnixRenderer methods that merely delegate
// to the package's shared helpers, so coverage here is only a sanity
// check. When the Go runtime uses the Unix separator, each result is also
// compared with path/filepath.
type unixThinWrapperSuite struct {
	unixBaseSuite
}

func (s unixThinWrapperSuite) TestBase(c *gc.C) {
	base := s.renderer.Base(s.path)
	c.Check(base, gc.Equals, "c.xyz")
	if !s.matchesRuntime() {
		return
	}
	c.Check(base, gc.Equals, gofilepath.Base(s.path))
}

func (s unixThinWrapperSuite) TestClean(c *gc.C) {
	// TODO(ericsnow) Add more cases.
	cases := map[string]string{
		s.path: s.path,
	}
	for original, want := range cases {
		c.Logf("checking %q", original)
		cleaned := s.renderer.Clean(original)
		c.Check(cleaned, gc.Equals, want)
		if s.matchesRuntime() {
			c.Check(cleaned, gc.Equals, gofilepath.Clean(original))
		}
	}
}

func (s unixThinWrapperSuite) TestDir(c *gc.C) {
	dir := s.renderer.Dir(s.path)
	c.Check(dir, gc.Equals, "/a/b")
	if !s.matchesRuntime() {
		return
	}
	c.Check(dir, gc.Equals, gofilepath.Dir(s.path))
}

func (s unixThinWrapperSuite) TestExt(c *gc.C) {
	ext := s.renderer.Ext(s.path)
	c.Check(ext, gc.Equals, ".xyz")
	if !s.matchesRuntime() {
		return
	}
	c.Check(ext, gc.Equals, gofilepath.Ext(s.path))
}

func (s unixThinWrapperSuite) TestFromSlash(c *gc.C) {
	const slashed = "/a/b/c.xyz"
	native := s.renderer.FromSlash(slashed)
	c.Check(native, gc.Equals, s.path)
	if !s.matchesRuntime() {
		return
	}
	c.Check(native, gc.Equals, gofilepath.FromSlash(slashed))
}

func (s unixThinWrapperSuite) TestJoin(c *gc.C) {
	joined := s.renderer.Join("a", "b", "c.xyz")
	c.Check(joined, gc.Equals, s.path[1:])
	if !s.matchesRuntime() {
		return
	}
	c.Check(joined, gc.Equals, gofilepath.Join("a", "b", "c.xyz"))
}

func (s unixThinWrapperSuite) TestSplit(c *gc.C) {
	dir, file := s.renderer.Split(s.path)
	c.Check(dir, gc.Equals, "/a/b/")
	c.Check(file, gc.Equals, "c.xyz")
	if !s.matchesRuntime() {
		return
	}
	goDir, goFile := gofilepath.Split(s.path)
	c.Check(dir, gc.Equals, goDir)
	c.Check(file, gc.Equals, goFile)
}

func (s unixThinWrapperSuite) TestToSlash(c *gc.C) {
	slashed := s.renderer.ToSlash(s.path)
	c.Check(slashed, gc.Equals, "/a/b/c.xyz")
	if !s.matchesRuntime() {
		return
	}
	c.Check(slashed, gc.Equals, gofilepath.ToSlash(s.path))
}

func (s unixThinWrapperSuite) TestMatchTrue(c *gc.C) {
	cases := map[string]string{
		"abc":   "abc",
		"ab[c]": "abc",
		"":      "",
		"*":     "abc",
		"a*c":   "abc",
		"?":     "a",
		"a?c":   "abc",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		matched, err := s.renderer.Match(pattern, name)
		c.Assert(err, jc.ErrorIsNil)
		c.Check(matched, jc.IsTrue)
		if !s.matchesRuntime() {
			continue
		}
		goMatched, goErr := gofilepath.Match(pattern, name)
		c.Assert(goErr, jc.ErrorIsNil)
		c.Check(matched, gc.Equals, goMatched)
	}
}

func (s unixThinWrapperSuite) TestMatchFalse(c *gc.C) {
	cases := map[string]string{
		"abc": "xyz",
		"":    "abc",
		"a*c": "a",
		"?":   "",
		"a?c": "ac",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		matched, err := s.renderer.Match(pattern, name)
		c.Assert(err, jc.ErrorIsNil)
		c.Check(matched, jc.IsFalse)
		if !s.matchesRuntime() {
			continue
		}
		goMatched, goErr := gofilepath.Match(pattern, name)
		c.Assert(goErr, jc.ErrorIsNil)
		c.Check(matched, gc.Equals, goMatched)
	}
}

func (s unixThinWrapperSuite) TestMatchBadPattern(c *gc.C) {
	cases := map[string]string{
		"ab[":    "abc",
		"ab[-c]": "abc",
		"ab[]":   "abc",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		_, err := s.renderer.Match(pattern, name)
		c.Check(err, gc.Equals, gofilepath.ErrBadPattern)
		if !s.matchesRuntime() {
			continue
		}
		_, goErr := gofilepath.Match(pattern, name)
		c.Check(err, gc.Equals, goErr)
	}
}
================================================
FILE: filepath/win.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath
import (
"strings"
)
// A substantial portion of this code comes from the Go stdlib code.
const (
	// WindowsSeparator is the canonical path separator used by
	// WindowsRenderer. It is fixed for Windows paths and does not depend on
	// the host OS ('/' is also accepted as a separator; see isSlash).
	WindowsSeparator = '\\'

	// WindowsListSeparator separates entries in a list of Windows paths,
	// as split by WindowsRenderer.SplitList.
	WindowsListSeparator = ';'
)
// WindowsRenderer is a Renderer implementation for Windows. It has no
// fields; its methods are parameterized by WindowsSeparator and
// WindowsListSeparator, and volume names are recognized by volumeNameLen.
type WindowsRenderer struct{}

// Base implements Renderer.
func (w WindowsRenderer) Base(path string) string {
	return Base(WindowsSeparator, w.VolumeName, path)
}

// Clean implements Renderer.
func (w WindowsRenderer) Clean(path string) string {
	return Clean(WindowsSeparator, w.VolumeName, path)
}

// Dir implements Renderer.
func (w WindowsRenderer) Dir(path string) string {
	return Dir(WindowsSeparator, w.VolumeName, path)
}

// Ext implements Renderer.
func (w WindowsRenderer) Ext(path string) string {
	return Ext(WindowsSeparator, path)
}

// FromSlash implements Renderer.
func (w WindowsRenderer) FromSlash(path string) string {
	return FromSlash(WindowsSeparator, path)
}

// IsAbs implements Renderer. A path is absolute only when it has a volume
// name that is immediately followed by a separator.
func (w WindowsRenderer) IsAbs(path string) bool {
	n := volumeNameLen(path)
	if n == 0 || n == len(path) {
		return false
	}
	return isSlash(path[n])
}

// Join implements Renderer.
func (w WindowsRenderer) Join(path ...string) string {
	return Join(WindowsSeparator, w.VolumeName, path...)
}

// Match implements Renderer.
func (w WindowsRenderer) Match(pattern, name string) (matched bool, err error) {
	return Match(WindowsSeparator, pattern, name)
}

// Split implements Renderer.
func (w WindowsRenderer) Split(path string) (dir, file string) {
	return Split(WindowsSeparator, w.VolumeName, path)
}

// SplitList implements Renderer. It splits path on WindowsListSeparator,
// ignoring separators inside double-quoted sections, and then removes all
// double quotes from each element. An empty path yields an empty
// (non-nil) list.
func (w WindowsRenderer) SplitList(path string) []string {
	if len(path) == 0 {
		return []string{}
	}
	var (
		parts   = []string{}
		begin   int
		inQuote bool
	)
	for i := 0; i < len(path); i++ {
		switch path[i] {
		case '"':
			inQuote = !inQuote
		case WindowsListSeparator:
			if !inQuote {
				parts = append(parts, path[begin:i])
				begin = i + 1
			}
		}
	}
	parts = append(parts, path[begin:])
	for i := range parts {
		parts[i] = strings.Replace(parts[i], `"`, ``, -1)
	}
	return parts
}

// ToSlash implements Renderer.
func (w WindowsRenderer) ToSlash(path string) string {
	return ToSlash(WindowsSeparator, path)
}

// VolumeName implements Renderer. It returns the leading drive letter or
// UNC prefix of path, or "" if there is none.
func (w WindowsRenderer) VolumeName(path string) string {
	return path[:volumeNameLen(path)]
}

// NormCase implements Renderer. Windows paths are compared
// case-insensitively, so path is lower-cased.
func (w WindowsRenderer) NormCase(path string) string {
	return strings.ToLower(path)
}

// SplitSuffix implements Renderer.
func (w WindowsRenderer) SplitSuffix(path string) (string, string) {
	return splitSuffix(path)
}
// isSlash reports whether c is a Windows path separator: either the
// backslash (WindowsSeparator) or the forward slash.
func isSlash(c uint8) bool {
	switch c {
	case WindowsSeparator, '/':
		return true
	}
	return false
}
// volumeNameLen returns the length of the leading volume name of path,
// always using Windows rules (whatever the host OS). It returns 0 when
// path has no recognizable volume name.
//
// Two forms are recognized (either slash kind counts as a separator, see
// isSlash):
//   - a drive letter: an ASCII letter followed by ':', e.g. `c:` (length 2);
//   - a UNC prefix `\\server\share`, whose length runs up to, but not
//     including, the separator after the share name (or to the end of path).
//
// A UNC-looking prefix yields 0 when any of these hold:
//   - path is shorter than 5 bytes;
//   - the server name starts with a separator or '.';
//   - no share name follows the server name;
//   - the separator after the server name is doubled;
//   - the share name starts with '.'.
func volumeNameLen(path string) int {
	if len(path) < 2 {
		return 0
	}
	// with drive letter: an ASCII letter followed by ':'
	c := path[0]
	if path[1] == ':' && ('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') {
		return 2
	}
	// is it UNC
	if l := len(path); l >= 5 && isSlash(path[0]) && isSlash(path[1]) &&
		!isSlash(path[2]) && path[2] != '.' {
		// Leading `\\` followed by a non-separator: the server name starts
		// at index 2. Scan for the separator that ends it; a separator in
		// the last byte is not considered since no share name could follow.
		for n := 3; n < l-1; n++ {
			// End of the server name: the share name must follow a single
			// separator (a repeated separator is rejected below).
			if isSlash(path[n]) {
				n++
				// Share name: consume up to the next separator or the end of
				// path; that position is the volume name's length.
				if !isSlash(path[n]) {
					if path[n] == '.' {
						break
					}
					for ; n < l; n++ {
						if isSlash(path[n]) {
							break
						}
					}
					return n
				}
				break
			}
		}
	}
	return 0
}
================================================
FILE: filepath/win_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filepath_test
import (
gofilepath "path/filepath"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filepath"
)
var (
	_ = gc.Suite(&windowsSuite{})
	_ = gc.Suite(&windowsThinWrapperSuite{})
)

// windowsBaseSuite holds the fixture shared by the Windows renderer
// suites: an absolute drive-letter path and a fresh WindowsRenderer for
// each test.
type windowsBaseSuite struct {
	testing.IsolationSuite

	path     string
	renderer *filepath.WindowsRenderer
}

// SetUpTest resets the fixture before each test.
func (s *windowsBaseSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.renderer = &filepath.WindowsRenderer{}
	s.path = `c:\a\b\c.xyz`
}

// matchesRuntime reports whether the Go runtime uses the Windows
// separator, in which case results can also be cross-checked with
// path/filepath.
func (s *windowsBaseSuite) matchesRuntime() bool {
	return filepath.WindowsSeparator == gofilepath.Separator
}
// windowsSuite tests the WindowsRenderer methods that are not covered by
// windowsThinWrapperSuite.
type windowsSuite struct {
	windowsBaseSuite
}

func (s windowsSuite) TestIsAbs(c *gc.C) {
	got := s.renderer.IsAbs(s.path)
	c.Check(got, jc.IsTrue)
	if !s.matchesRuntime() {
		return
	}
	c.Check(got, gc.Equals, gofilepath.IsAbs(s.path))
}

func (s windowsSuite) TestSplitList(c *gc.C) {
	const input = `\a;b;\c\d`
	parts := s.renderer.SplitList(input)
	c.Check(parts, jc.DeepEquals, []string{`\a`, "b", `\c\d`})
	if !s.matchesRuntime() {
		return
	}
	c.Check(parts, jc.DeepEquals, gofilepath.SplitList(input))
}

func (s windowsSuite) TestVolumeName(c *gc.C) {
	volume := s.renderer.VolumeName(s.path)
	c.Check(volume, gc.Equals, "c:")
	if !s.matchesRuntime() {
		return
	}
	c.Check(volume, gc.Equals, gofilepath.VolumeName(s.path))
}

// checkNormCase checks that NormCase maps in to want.
func (s windowsSuite) checkNormCase(c *gc.C, in, want string) {
	c.Check(s.renderer.NormCase(in), gc.Equals, want)
}

func (s windowsSuite) TestNormCaseLower(c *gc.C) {
	s.checkNormCase(c, "spam", "spam")
}

func (s windowsSuite) TestNormCaseUpper(c *gc.C) {
	s.checkNormCase(c, "SPAM", "spam")
}

func (s windowsSuite) TestNormCaseMixed(c *gc.C) {
	s.checkNormCase(c, "sPaM", "spam")
}

func (s windowsSuite) TestNormCaseCapitalized(c *gc.C) {
	s.checkNormCase(c, "Spam", "spam")
}

func (s windowsSuite) TestNormCasePunctuation(c *gc.C) {
	s.checkNormCase(c, "spam-eggs.ext", "spam-eggs.ext")
}

func (s windowsSuite) TestSplitSuffix(c *gc.C) {
	// This is just a sanity check. The splitSuffix tests are more
	// comprehensive.
	base, suffix := s.renderer.SplitSuffix("spam.ext")
	c.Check(base, gc.Equals, "spam")
	c.Check(suffix, gc.Equals, ".ext")
}
// windowsThinWrapperSuite tests the WindowsRenderer methods that merely
// delegate to the package's shared helpers, so coverage here is only a
// sanity check. When the Go runtime uses the Windows separator, each
// result is also compared with path/filepath.
type windowsThinWrapperSuite struct {
	windowsBaseSuite
}

func (s windowsThinWrapperSuite) TestBase(c *gc.C) {
	base := s.renderer.Base(s.path)
	c.Check(base, gc.Equals, "c.xyz")
	if !s.matchesRuntime() {
		return
	}
	c.Check(base, gc.Equals, gofilepath.Base(s.path))
}

func (s windowsThinWrapperSuite) TestClean(c *gc.C) {
	// TODO(ericsnow) Add more cases.
	cases := map[string]string{
		s.path: s.path,
	}
	for original, want := range cases {
		c.Logf("checking %q", original)
		cleaned := s.renderer.Clean(original)
		c.Check(cleaned, gc.Equals, want)
		if s.matchesRuntime() {
			c.Check(cleaned, gc.Equals, gofilepath.Clean(original))
		}
	}
}

func (s windowsThinWrapperSuite) TestDir(c *gc.C) {
	dir := s.renderer.Dir(s.path)
	c.Check(dir, gc.Equals, `c:\a\b`)
	if !s.matchesRuntime() {
		return
	}
	c.Check(dir, gc.Equals, gofilepath.Dir(s.path))
}

func (s windowsThinWrapperSuite) TestExt(c *gc.C) {
	ext := s.renderer.Ext(s.path)
	c.Check(ext, gc.Equals, ".xyz")
	if !s.matchesRuntime() {
		return
	}
	c.Check(ext, gc.Equals, gofilepath.Ext(s.path))
}

func (s windowsThinWrapperSuite) TestFromSlash(c *gc.C) {
	const slashed = "/a/b/c.xyz"
	native := s.renderer.FromSlash(slashed)
	c.Check(native, gc.Equals, s.path[2:])
	if !s.matchesRuntime() {
		return
	}
	c.Check(native, gc.Equals, gofilepath.FromSlash(slashed))
}

func (s windowsThinWrapperSuite) TestJoin(c *gc.C) {
	joined := s.renderer.Join("a", "b", "c.xyz")
	c.Check(joined, gc.Equals, s.path[3:])
	if !s.matchesRuntime() {
		return
	}
	c.Check(joined, gc.Equals, gofilepath.Join("a", "b", "c.xyz"))
}

func (s windowsThinWrapperSuite) TestSplit(c *gc.C) {
	dir, file := s.renderer.Split(s.path)
	c.Check(dir, gc.Equals, `c:\a\b\`)
	c.Check(file, gc.Equals, "c.xyz")
	if !s.matchesRuntime() {
		return
	}
	goDir, goFile := gofilepath.Split(s.path)
	c.Check(dir, gc.Equals, goDir)
	c.Check(file, gc.Equals, goFile)
}

func (s windowsThinWrapperSuite) TestToSlash(c *gc.C) {
	slashed := s.renderer.ToSlash(s.path)
	c.Check(slashed, gc.Equals, "c:/a/b/c.xyz")
	if !s.matchesRuntime() {
		return
	}
	c.Check(slashed, gc.Equals, gofilepath.ToSlash(s.path))
}

func (s windowsThinWrapperSuite) TestMatchTrue(c *gc.C) {
	cases := map[string]string{
		"abc":   "abc",
		"ab[c]": "abc",
		"":      "",
		"*":     "abc",
		"a*c":   "abc",
		"?":     "a",
		"a?c":   "abc",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		matched, err := s.renderer.Match(pattern, name)
		c.Assert(err, jc.ErrorIsNil)
		c.Check(matched, jc.IsTrue)
		if !s.matchesRuntime() {
			continue
		}
		goMatched, goErr := gofilepath.Match(pattern, name)
		c.Assert(goErr, jc.ErrorIsNil)
		c.Check(matched, gc.Equals, goMatched)
	}
}

func (s windowsThinWrapperSuite) TestMatchFalse(c *gc.C) {
	cases := map[string]string{
		"abc": "xyz",
		"":    "abc",
		"a*c": "a",
		"?":   "",
		"a?c": "ac",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		matched, err := s.renderer.Match(pattern, name)
		c.Assert(err, jc.ErrorIsNil)
		c.Check(matched, jc.IsFalse)
		if !s.matchesRuntime() {
			continue
		}
		goMatched, goErr := gofilepath.Match(pattern, name)
		c.Assert(goErr, jc.ErrorIsNil)
		c.Check(matched, gc.Equals, goMatched)
	}
}

func (s windowsThinWrapperSuite) TestMatchBadPattern(c *gc.C) {
	cases := map[string]string{
		"ab[":    "abc",
		"ab[-c]": "abc",
		"ab[]":   "abc",
	}
	for pattern, name := range cases {
		c.Logf("- checking pattern %q against %q -", pattern, name)
		_, err := s.renderer.Match(pattern, name)
		c.Check(err, gc.Equals, gofilepath.ErrBadPattern)
		if !s.matchesRuntime() {
			continue
		}
		_, goErr := gofilepath.Match(pattern, name)
		c.Check(err, gc.Equals, goErr)
	}
}
================================================
FILE: filestorage/doc.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
/*
Package filestorage provides types for abstracting and implementing a
system that stores files, including their metadata.
Each file in the system is identified by a unique ID, determined by the
system at the time the file is stored.
File metadata includes such information as the size of the file, its
checksum, and when it was created. Regardless of how it is stored in
the system, at the abstraction level it is represented as a document.
Metadata can exist in the system without an associated file. However,
every file must have a corresponding metadata doc stored in the system.
A file can be added for a metadata doc that does not have one already.
The main type is the FileStorage interface. It exposes the core
functionality of such a system. This includes adding/removing files,
retrieving them or their metadata, and listing all files in the system.
The package also provides a basic implementation of FileStorage,
available through NewFileStorage(). This implementation simply wraps
two more focused systems: doc storage and raw file storage. The wrapper
uses the doc storage to store the metadata and raw file storage to
store the files.
The two subsystems are exposed via corresponding interfaces: DocStorage
(and its specialization MetadataStorage) and RawFileStorage. While a
single type could implement both, in practice they will be separate.
The doc storage is responsible for generating the unique IDs. The raw
file storage defers to the doc storage for any information about the
file, including the ID.
*/
package filestorage
================================================
FILE: filestorage/export_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage
================================================
FILE: filestorage/fakes_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage_test
import (
"io"
"github.com/juju/errors"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filestorage"
)
// FakeMetadataStorage is a test double used as a DocStorage and a
// MetadataStorage. Every method appends its name to calls and records its
// arguments in idArg/metaArg. Methods return the canned id, meta and
// metaList values, or err when it is set (AddDoc ignores err; it only
// fails if the document cannot be converted to Metadata).
type FakeMetadataStorage struct {
	calls    []string
	id       string
	meta     filestorage.Metadata
	metaList []filestorage.Metadata
	err      error
	idArg    string
	metaArg  filestorage.Metadata
}

// Check verifies the recorded calls and arguments against the expected
// values.
func (s *FakeMetadataStorage) Check(c *gc.C, id string, meta filestorage.Metadata, calls ...string) {
	c.Check(s.calls, jc.DeepEquals, calls)
	c.Check(s.idArg, gc.Equals, id)
	c.Check(s.metaArg, gc.Equals, meta)
}

// record appends name to the call log.
func (s *FakeMetadataStorage) record(name string) {
	s.calls = append(s.calls, name)
}

func (s *FakeMetadataStorage) Doc(id string) (filestorage.Document, error) {
	s.record("Doc")
	s.idArg = id
	if err := s.err; err != nil {
		return nil, err
	}
	return s.meta, nil
}

func (s *FakeMetadataStorage) ListDocs() ([]filestorage.Document, error) {
	s.record("ListDoc")
	if err := s.err; err != nil {
		return nil, err
	}
	var docs []filestorage.Document
	for i := range s.metaList {
		docs = append(docs, s.metaList[i])
	}
	return docs, nil
}

func (s *FakeMetadataStorage) AddDoc(doc filestorage.Document) (string, error) {
	s.record("AddDoc")
	meta, err := filestorage.Convert(doc)
	if err != nil {
		return "", errors.Trace(err)
	}
	s.metaArg = meta
	return s.id, nil
}

func (s *FakeMetadataStorage) RemoveDoc(id string) error {
	s.record("RemoveDoc")
	s.idArg = id
	return s.err
}

func (s *FakeMetadataStorage) Close() error {
	s.record("Close")
	return s.err
}

func (s *FakeMetadataStorage) Metadata(id string) (filestorage.Metadata, error) {
	s.record("Metadata")
	s.idArg = id
	if err := s.err; err != nil {
		return nil, err
	}
	return s.meta, nil
}

func (s *FakeMetadataStorage) ListMetadata() ([]filestorage.Metadata, error) {
	s.record("ListMetadata")
	if err := s.err; err != nil {
		return nil, err
	}
	return s.metaList, nil
}

func (s *FakeMetadataStorage) AddMetadata(meta filestorage.Metadata) (string, error) {
	s.record("AddMetadata")
	s.metaArg = meta
	if err := s.err; err != nil {
		return "", err
	}
	return s.id, nil
}

func (s *FakeMetadataStorage) RemoveMetadata(id string) error {
	s.record("RemoveMetadata")
	s.idArg = id
	return s.err
}

func (s *FakeMetadataStorage) SetStored(id string) error {
	s.record("SetStored")
	s.idArg = id
	return s.err
}
// FakeRawFileStorage is a RawFileStorage test double. Every method
// appends its name to calls and records its arguments in
// idArg/fileArg/sizeArg; File returns the canned file, and all methods
// return err when it is set.
type FakeRawFileStorage struct {
	calls   []string
	file    io.ReadCloser
	err     error
	idArg   string
	fileArg io.Reader
	sizeArg int64
}

// Check verifies the recorded calls and arguments against the expected
// values.
func (s *FakeRawFileStorage) Check(c *gc.C, id string, file io.Reader, size int64, calls ...string) {
	c.Check(s.calls, jc.DeepEquals, calls)
	c.Check(s.idArg, gc.Equals, id)
	c.Check(s.fileArg, gc.Equals, file)
	c.Check(s.sizeArg, gc.Equals, size)
}

// CheckNotUsed verifies that no method of the fake was called.
func (s *FakeRawFileStorage) CheckNotUsed(c *gc.C) {
	s.Check(c, "", nil, 0)
}

// record appends name to the call log.
func (s *FakeRawFileStorage) record(name string) {
	s.calls = append(s.calls, name)
}

func (s *FakeRawFileStorage) File(id string) (io.ReadCloser, error) {
	s.record("File")
	s.idArg = id
	if err := s.err; err != nil {
		return nil, err
	}
	return s.file, nil
}

func (s *FakeRawFileStorage) AddFile(id string, file io.Reader, size int64) error {
	s.record("AddFile")
	s.idArg, s.fileArg, s.sizeArg = id, file, size
	return s.err
}

func (s *FakeRawFileStorage) RemoveFile(id string) error {
	s.record("RemoveFile")
	s.idArg = id
	return s.err
}

func (s *FakeRawFileStorage) Close() error {
	s.record("Close")
	return s.err
}
================================================
FILE: filestorage/interfaces.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage
import (
"io"
"time"
)
// FileStorage is an abstraction for storing files together with their
// metadata. Implementations are also io.Closers.
type FileStorage interface {
	// Add stores a file and its metadata.
	Add(meta Metadata, archive io.Reader) (string, error)

	// SetFile stores a file for an existing metadata entry.
	SetFile(id string, file io.Reader) error

	// Get returns a file and its metadata.
	Get(id string) (Metadata, io.ReadCloser, error)

	// Metadata returns a file's metadata.
	Metadata(id string) (Metadata, error)

	// List returns the metadata for each stored file.
	List() ([]Metadata, error)

	// Remove removes a file from storage.
	Remove(id string) error

	io.Closer
}
// Document represents a document that is uniquely identified by a string
// ID.
type Document interface {
	// SetID sets the ID of the document. It should return true if the ID
	// was already set, and false otherwise.
	SetID(id string) (alreadySet bool)

	// ID returns the unique ID of the document.
	ID() string
}
// Metadata is the meta information for a stored file. It is a Document,
// so it carries the file's unique ID.
type Metadata interface {
	Document

	// SetFileInfo sets the file info on the metadata.
	SetFileInfo(size int64, checksum, checksumFormat string) error

	// Size is the size of the file (in bytes).
	Size() int64

	// Checksum is the checksum for the file.
	Checksum() string

	// ChecksumFormat is the kind (and encoding) of checksum.
	ChecksumFormat() string

	// SetStored records when the file was last stored. If the previous
	// value matters, be sure to call Stored() first.
	SetStored(timestamp *time.Time)

	// Stored returns when the file was last stored. If it has not been
	// stored yet, nil is returned. If it has been stored but the
	// timestamp is not available, a zero value is returned
	// (see Time.IsZero).
	Stored() *time.Time
}
// DocStorage is an abstraction for a system that stores documents. The
// system generates its own unique ID for each document it stores.
// Implementations are also io.Closers.
type DocStorage interface {
	// AddDoc adds the doc to the storage. If successful, the
	// storage-generated ID for the doc is returned. Otherwise an error is
	// returned.
	AddDoc(doc Document) (string, error)

	// Doc returns the doc that matches the ID. If there is no match,
	// an error is returned (see errors.IsNotFound). Any other problem
	// also results in an error.
	Doc(id string) (Document, error)

	// ListDocs returns a list of all the docs in the storage.
	ListDocs() ([]Document, error)

	// RemoveDoc removes the matching doc from the storage. If there
	// is no match an error is returned (see errors.IsNotFound). Any
	// other problem also results in an error.
	RemoveDoc(id string) error

	io.Closer
}
// RawFileStorage abstracts a store of raw files. Unlike DocStorage, it
// relies on the caller to supply each file's unique ID.
type RawFileStorage interface {
	io.Closer

	// File opens the file stored under id. A missing file yields an
	// error satisfying errors.IsNotFound; any other problem also yields
	// an error.
	File(id string) (io.ReadCloser, error)
	// AddFile stores file, along with its size, under id. It fails if
	// a file is already stored under id (see errors.IsAlreadyExists) or
	// if the file cannot be stored for any other reason.
	AddFile(id string, file io.Reader, size int64) error
	// RemoveFile deletes the file stored under id. A missing file yields
	// an error satisfying errors.IsNotFound; any other problem also
	// yields an error.
	RemoveFile(id string) error
}
// MetadataStorage specializes DocStorage for file Metadata.
type MetadataStorage interface {
	io.Closer

	// Metadata looks up the Metadata with the given ID. A missing entry
	// yields an error satisfying errors.IsNotFound; any other problem
	// also yields an error.
	Metadata(id string) (Metadata, error)
	// ListMetadata returns every Metadata held in the storage.
	ListMetadata() ([]Metadata, error)
	// AddMetadata stores meta and, on success, returns the ID the
	// storage generated for it; otherwise it returns an error.
	AddMetadata(meta Metadata) (string, error)
	// RemoveMetadata deletes the Metadata with the given ID. A missing
	// entry yields an error satisfying errors.IsNotFound; any other
	// problem also yields an error.
	RemoveMetadata(id string) error
	// SetStored marks the Metadata with the given ID as having had its
	// file successfully stored in a RawFileStorage system. A missing
	// entry yields an error satisfying errors.IsNotFound, and failing to
	// update the entry also yields an error.
	SetStored(id string) error
}
================================================
FILE: filestorage/metadata.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage
import (
"time"
"github.com/juju/errors"
)
// RawDoc holds the plain data shared by every uniquely identifiable
// document.
type RawDoc struct {
	// ID uniquely identifies the document.
	ID string
}

// Doc adapts a RawDoc to the Document interface.
type Doc struct {
	Raw RawDoc
}

// ID reports the document's unique identifier ("" when unset).
func (d *Doc) ID() string {
	return d.Raw.ID
}

// SetID assigns id only while the document has no ID yet. It reports
// true when an ID was already present (leaving it untouched) and false
// when id has just been assigned.
func (d *Doc) SetID(id string) bool {
	alreadySet := d.Raw.ID != ""
	if !alreadySet {
		d.Raw.ID = id
	}
	return alreadySet
}
// RawFileMetadata is the plain data describing a stored file.
type RawFileMetadata struct {
	// Size is the stored file's length in bytes.
	Size int64
	// Checksum is the stored file's checksum.
	Checksum string
	// ChecksumFormat names the kind (and encoding) of Checksum.
	ChecksumFormat string
	// Stored is when the file was last stored (nil if it never was).
	Stored *time.Time
}
// FileMetadata is the metadata for one stored file. The embedded Doc
// supplies its ID; Raw holds the file details.
type FileMetadata struct {
	Doc
	Raw RawFileMetadata
}
// NewMetadata returns empty FileMetadata: no ID, no file information,
// and no stored timestamp.
func NewMetadata() *FileMetadata {
	return &FileMetadata{}
}
// Size returns the size of the file in bytes (0 if not yet set).
func (m *FileMetadata) Size() int64 {
	return m.Raw.Size
}

// Checksum returns the file's checksum ("" if not yet set).
func (m *FileMetadata) Checksum() string {
	return m.Raw.Checksum
}

// ChecksumFormat returns the kind (and encoding) of the checksum
// ("" if not yet set).
func (m *FileMetadata) ChecksumFormat() string {
	return m.Raw.ChecksumFormat
}

// Stored returns when the file was last stored, or nil if it has not
// been recorded as stored.
func (m *FileMetadata) Stored() *time.Time {
	return m.Raw.Stored
}
// SetFileInfo records the file's size, checksum and checksum format.
// A zero size or empty string keeps the corresponding current value. A
// checksum and its format must be present together, and a value that
// is already set may be supplied again but not changed.
func (m *FileMetadata) SetFileInfo(size int64, checksum, format string) error {
	// Arguments left at their zero value inherit the current values.
	if size == 0 {
		size = m.Raw.Size
	}
	if checksum == "" {
		checksum = m.Raw.Checksum
	}
	if format == "" {
		format = m.Raw.ChecksumFormat
	}

	// A checksum is meaningless without its format, and vice versa.
	switch {
	case checksum != "" && format == "":
		return errors.Errorf("missing checksum format")
	case checksum == "" && format != "":
		return errors.Errorf("missing checksum")
	}

	// Values may only be set once.
	switch {
	case m.Raw.Size != 0 && m.Raw.Size != size:
		return errors.Errorf("file information (size) already set")
	case m.Raw.Checksum != "" && m.Raw.Checksum != checksum:
		return errors.Errorf("file information (checksum) already set")
	case m.Raw.ChecksumFormat != "" && m.Raw.ChecksumFormat != format:
		return errors.Errorf("file information (checksum format) already set")
	}

	m.Raw.Size, m.Raw.Checksum, m.Raw.ChecksumFormat = size, checksum, format
	return nil
}
// SetStored records when the file was last stored. The given pointer
// is kept as-is; a nil timestamp means "now" in UTC.
func (m *FileMetadata) SetStored(timestamp *time.Time) {
	if timestamp != nil {
		m.Raw.Stored = timestamp
		return
	}
	now := time.Now().UTC()
	m.Raw.Stored = &now
}
================================================
FILE: filestorage/metadata_store.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage
import (
"github.com/juju/errors"
)
// Convert returns doc as a Metadata. It returns an error when doc
// (including a nil doc) does not implement Metadata.
func Convert(doc Document) (Metadata, error) {
	if meta, ok := doc.(Metadata); ok {
		return meta, nil
	}
	return nil, errors.Errorf("expected a Metadata doc, got %v", doc)
}
// MetadataDocStorage provides the MetadataStorage methods that can be
// derived from DocStorage methods. To fully implement MetadataStorage,
// this type must be embedded in a type that implements the remaining
// methods (notably SetStored).
type MetadataDocStorage struct {
	DocStorage
}
// Metadata implements MetadataStorage.Metadata by fetching the doc with
// the given ID and converting it to a Metadata.
func (s *MetadataDocStorage) Metadata(id string) (Metadata, error) {
	doc, err := s.Doc(id)
	if err != nil {
		return nil, errors.Trace(err)
	}
	meta, convErr := Convert(doc)
	if convErr != nil {
		return nil, errors.Trace(convErr)
	}
	return meta, nil
}
// ListMetadata implements MetadataStorage.ListMetadata. Nil docs are
// skipped; a doc that is not a Metadata aborts the listing with an
// error.
func (s *MetadataDocStorage) ListMetadata() ([]Metadata, error) {
	docs, err := s.ListDocs()
	if err != nil {
		return nil, errors.Trace(err)
	}
	var result []Metadata
	for i := range docs {
		if docs[i] == nil {
			continue
		}
		meta, convErr := Convert(docs[i])
		if convErr != nil {
			return nil, errors.Trace(convErr)
		}
		result = append(result, meta)
	}
	return result, nil
}
// AddMetadata implements MetadataStorage.AddMetadata by storing meta as
// a document. On success it returns the ID generated by the underlying
// DocStorage; on failure it returns "" and the error.
func (s *MetadataDocStorage) AddMetadata(meta Metadata) (string, error) {
	id, err := s.AddDoc(meta)
	if err != nil {
		return "", errors.Trace(err)
	}
	return id, nil
}
// RemoveMetadata implements MetadataStorage.RemoveMetadata by removing
// the doc with the given ID from the underlying DocStorage.
func (s *MetadataDocStorage) RemoveMetadata(id string) error {
	return errors.Trace(s.RemoveDoc(id))
}
================================================
FILE: filestorage/metadata_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage_test
import (
"time"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filestorage"
)
// Compile-time checks that the concrete types satisfy their interfaces.
var (
	_ filestorage.Document = (*filestorage.Doc)(nil)
	_ filestorage.Metadata = (*filestorage.FileMetadata)(nil)
)

var _ = gc.Suite(&MetadataSuite{})

// MetadataSuite exercises FileMetadata and its embedded Doc.
type MetadataSuite struct {
	testing.IsolationSuite
}
// Fresh metadata starts with every field at its zero value.
func (s *MetadataSuite) TestFileMetadataNewMetadata(c *gc.C) {
	m := filestorage.NewMetadata()
	c.Check(m.ID(), gc.Equals, "")
	c.Check(m.Size(), gc.Equals, int64(0))
	c.Check(m.Checksum(), gc.Equals, "")
	c.Check(m.ChecksumFormat(), gc.Equals, "")
	c.Check(m.Stored(), gc.IsNil)
}
// The first SetID assigns the ID and reports it was not already set.
func (s *MetadataSuite) TestFileMetadataSetIDInitial(c *gc.C) {
	m := filestorage.NewMetadata()
	m.SetFileInfo(10, "some sum", "SHA-1")
	c.Assert(m.ID(), gc.Equals, "")
	alreadySet := m.SetID("some id")
	c.Check(alreadySet, gc.Equals, false)
	c.Check(m.ID(), gc.Equals, "some id")
}

// Setting the same ID twice reports the second call as already set.
func (s *MetadataSuite) TestFileMetadataSetIDAlreadySetSame(c *gc.C) {
	m := filestorage.NewMetadata()
	m.SetFileInfo(10, "some sum", "SHA-1")
	c.Assert(m.SetID("some id"), gc.Equals, false)
	alreadySet := m.SetID("some id")
	c.Check(alreadySet, gc.Equals, true)
	c.Check(m.ID(), gc.Equals, "some id")
}

// A second, different ID is rejected and the original ID is kept.
func (s *MetadataSuite) TestFileMetadataSetIDAlreadySetDifferent(c *gc.C) {
	m := filestorage.NewMetadata()
	m.SetFileInfo(10, "some sum", "SHA-1")
	c.Assert(m.SetID("some id"), gc.Equals, false)
	alreadySet := m.SetID("another id")
	c.Check(alreadySet, gc.Equals, true)
	c.Check(m.ID(), gc.Equals, "some id")
}
// SetFileInfo populates size and checksum fields without marking the
// file as stored.
func (s *MetadataSuite) TestFileMetadataSetFileInfo(c *gc.C) {
	m := filestorage.NewMetadata()
	c.Assert(m.Size(), gc.Equals, int64(0))
	c.Assert(m.Checksum(), gc.Equals, "")
	c.Assert(m.ChecksumFormat(), gc.Equals, "")
	c.Assert(m.Stored(), gc.IsNil)

	m.SetFileInfo(10, "some sum", "SHA-1")

	c.Check(m.Size(), gc.Equals, int64(10))
	c.Check(m.Checksum(), gc.Equals, "some sum")
	c.Check(m.ChecksumFormat(), gc.Equals, "SHA-1")
	c.Check(m.Stored(), gc.IsNil)
}
func (s *MetadataSuite) TestFileMetadataSetStored(c *gc.C) {
meta := filestorage.NewMetadata()
timestamp := time.Now().UTC()
meta.SetStored(×tamp)
c.Check(meta.Stored(), gc.Equals, ×tamp)
}
// SetStored(nil) still records a (current) timestamp.
func (s *MetadataSuite) TestFileMetadataSetStoredDefault(c *gc.C) {
	m := filestorage.NewMetadata()
	c.Assert(m.Stored(), gc.IsNil)
	m.SetStored(nil)
	c.Check(m.Stored(), gc.NotNil)
}
================================================
FILE: filestorage/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hands control to gocheck so the registered suites run
// under "go test".
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: filestorage/wrapper.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage
import (
"io"
"github.com/juju/errors"
)
// Compile-time check that *fileStorage satisfies FileStorage.
var _ FileStorage = (*fileStorage)(nil)
// fileStorage coordinates a metadata store and a raw-file store.
type fileStorage struct {
	metaStorage MetadataStorage // holds one Metadata entry per file
	rawStorage  RawFileStorage  // holds the raw file contents, keyed by ID
}
// NewFileStorage returns a FileStorage built from a MetadataStorage and
// a RawFileStorage. It coordinates the two even when they were not
// designed to work together (they may even be the same value).
//
// Every stored file always has metadata stored; a raw file, however,
// is optional.
func NewFileStorage(meta MetadataStorage, files RawFileStorage) FileStorage {
	return &fileStorage{
		metaStorage: meta,
		rawStorage:  files,
	}
}
// Metadata looks up the metadata with the given ID. A missing entry
// (see errors.IsNotFound) or any other problem yields an error.
func (s *fileStorage) Metadata(id string) (Metadata, error) {
	meta, err := s.metaStorage.Metadata(id)
	if err == nil {
		return meta, nil
	}
	return nil, errors.Trace(err)
}
// Get returns the file with the given ID together with its metadata.
// A file only counts as found when its metadata exists and is marked
// as stored and the raw file can be opened; otherwise an error is
// returned (see errors.IsNotFound).
func (s *fileStorage) Get(id string) (Metadata, io.ReadCloser, error) {
	meta, err := s.Metadata(id)
	if err != nil {
		return nil, nil, errors.Trace(err)
	}
	// Metadata without a stored timestamp is treated as having no file.
	if stored := meta.Stored(); stored == nil {
		return nil, nil, errors.NotFoundf("no file stored for %q", id)
	}
	file, fileErr := s.rawStorage.File(id)
	if fileErr != nil {
		return nil, nil, errors.Trace(fileErr)
	}
	return meta, file, nil
}
// List returns the metadata for every file in the storage, whether or
// not its raw file has been stored. Errors are traced like those of the
// other fileStorage methods.
func (s *fileStorage) List() ([]Metadata, error) {
	metaList, err := s.metaStorage.ListMetadata()
	if err != nil {
		return nil, errors.Trace(err)
	}
	return metaList, nil
}
// addFile stores the raw file under id and then marks the matching
// metadata as stored. If marking the metadata fails, the raw file that
// was just added is removed again, so a failed call does not leave a
// raw file behind that its metadata does not record as stored. The
// SetStored failure remains the error's cause.
func (s *fileStorage) addFile(id string, size int64, file io.Reader) error {
	if err := s.rawStorage.AddFile(id, file, size); err != nil {
		return errors.Trace(err)
	}
	if err := s.metaStorage.SetStored(id); err != nil {
		// Roll back the raw file added above.
		if rmErr := s.rawStorage.RemoveFile(id); rmErr != nil {
			return errors.Annotatef(err, "marking file %q as stored (cleanup also failed: %v)", id, rmErr)
		}
		return errors.Trace(err)
	}
	return nil
}
// Add stores meta and, when file is non-nil, the file itself. It
// returns the ID generated by the metadata storage. With a nil file
// only the metadata is stored. The passed-in meta is left untouched,
// although the new ID and "stored" flag are saved in metadata storage;
// call meta.SetID() and meta.SetStored() afterwards if you want them.
//
// Any problem (including an existing file, see errors.IsAlreadyExists)
// results in an error. If storing the file fails, the metadata that was
// just added is removed again.
func (s *fileStorage) Add(meta Metadata, file io.Reader) (string, error) {
	id, err := s.metaStorage.AddMetadata(meta)
	if err != nil {
		return "", errors.Trace(err)
	}
	if file == nil {
		return id, nil
	}
	addErr := s.addFile(id, meta.Size(), file)
	if addErr == nil {
		return id, nil
	}
	// Undo the metadata add so the failed Add leaves no entry behind.
	if rmErr := s.metaStorage.RemoveMetadata(id); rmErr != nil {
		rmErr = errors.Annotate(rmErr, "while handling another error")
		return "", errors.Wrap(addErr, rmErr)
	}
	return "", errors.Trace(addErr)
}
// SetFile stores the raw file for metadata that already exists. Missing
// metadata yields an error (see errors.IsNotFound). A file that has
// already been stored yields an error (see errors.IsAlreadyExists), and
// so does any other failure to add the file.
func (s *fileStorage) SetFile(id string, file io.Reader) error {
	meta, err := s.Metadata(id)
	if err != nil {
		return errors.Trace(err)
	}
	if err := s.addFile(id, meta.Size(), file); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Remove deletes both the raw file and the metadata with the given ID.
// Missing metadata yields an error (see errors.IsNotFound); a missing
// raw file is tolerated.
//
// The raw file goes first. If a later step fails, the metadata is still
// stored, but it may then wrongly suggest that a raw file exists.
func (s *fileStorage) Remove(id string) error {
	if err := s.rawStorage.RemoveFile(id); err != nil && !errors.IsNotFound(err) {
		return errors.Trace(err)
	}
	if err := s.metaStorage.RemoveMetadata(id); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Close implements io.Closer.Close. Both stores are always closed, the
// raw-file store first; when both fail, the two errors are combined.
func (s *fileStorage) Close() error {
	fileErr := s.rawStorage.Close()
	metaErr := s.metaStorage.Close()
	switch {
	case fileErr == nil:
		return errors.Trace(metaErr)
	case metaErr == nil:
		return errors.Trace(fileErr)
	default:
		return errors.Errorf("closing both failed: metadata (%v) and files (%v)", metaErr, fileErr)
	}
}
================================================
FILE: filestorage/wrapper_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package filestorage_test
import (
"bytes"
"io"
"io/ioutil"
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/filestorage"
)
var _ = gc.Suite(&WrapperSuite{})

// WrapperSuite exercises the FileStorage returned by NewFileStorage
// against fake metadata and raw-file stores.
type WrapperSuite struct {
	testing.IsolationSuite
	rawstor  *FakeRawFileStorage
	metastor *FakeMetadataStorage
	stor     filestorage.FileStorage
}

// SetUpTest gives every test fresh fakes wrapped in a FileStorage.
func (s *WrapperSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.rawstor, s.metastor = &FakeRawFileStorage{}, &FakeMetadataStorage{}
	s.stor = filestorage.NewFileStorage(s.metastor, s.rawstor)
}

// metadata returns fresh metadata describing a 10-byte file.
func (s *WrapperSuite) metadata() filestorage.Metadata {
	m := filestorage.NewMetadata()
	m.SetFileInfo(10, "", "")
	return m
}

// setMeta registers fresh metadata (with an empty ID) in the fake
// metadata store and returns that ID along with the metadata.
func (s *WrapperSuite) setMeta() (string, filestorage.Metadata) {
	const id = ""
	m := s.metadata()
	m.SetID(id)
	s.metastor.meta = m
	s.metastor.metaList = append(s.metastor.metaList, m)
	return id, m
}

// setFile is setMeta plus a raw file containing data, with the
// metadata marked as stored.
func (s *WrapperSuite) setFile(data string) (string, filestorage.Metadata, io.ReadCloser) {
	id, m := s.setMeta()
	f := ioutil.NopCloser(bytes.NewBufferString(data))
	s.rawstor.file = f
	m.SetStored(nil)
	return id, m, f
}
func (s *WrapperSuite) TestFileStorageNewFileStorage(c *gc.C) {
	c.Check(filestorage.NewFileStorage(s.metastor, s.rawstor), gc.NotNil)
}

// Metadata consults only the metadata store.
func (s *WrapperSuite) TestFileStorageMetadata(c *gc.C) {
	id, want := s.setMeta()
	got, err := s.stor.Metadata(id)
	c.Assert(err, gc.IsNil)
	c.Check(got, jc.DeepEquals, want)
	s.metastor.Check(c, id, nil, "Metadata")
	s.rawstor.CheckNotUsed(c)
}

func (s *WrapperSuite) TestFileStorageGet(c *gc.C) {
	id, wantMeta, wantFile := s.setFile("spam")
	gotMeta, gotFile, err := s.stor.Get(id)
	c.Assert(err, gc.IsNil)
	c.Check(gotMeta, gc.Equals, wantMeta)
	c.Check(gotFile, gc.Equals, wantFile)
}

func (s *WrapperSuite) TestFileStorageListEmpty(c *gc.C) {
	list, err := s.stor.List()
	c.Assert(err, gc.IsNil)
	c.Check(list, gc.HasLen, 0)
}

func (s *WrapperSuite) TestFileStorageListOne(c *gc.C) {
	id, _ := s.setMeta()
	list, err := s.stor.List()
	c.Assert(err, gc.IsNil)
	c.Check(list, gc.HasLen, 1)
	c.Assert(list[0], gc.NotNil)
	c.Check(list[0].ID(), gc.Equals, id)
}

// The listing order is not significant, so accept either order.
func (s *WrapperSuite) TestFileStorageListTwo(c *gc.C) {
	id1, _ := s.setMeta()
	id2, _ := s.setMeta()
	list, err := s.stor.List()
	c.Assert(err, gc.IsNil)
	c.Assert(list, gc.HasLen, 2)
	c.Assert(list[0], gc.NotNil)
	c.Assert(list[1], gc.NotNil)
	expectSecond := id1
	if list[0].ID() == id1 {
		expectSecond = id2
	}
	c.Check(list[1].ID(), gc.Equals, expectSecond)
}
// Adding with a nil file stores only the metadata.
func (s *WrapperSuite) TestFileStorageAddMeta(c *gc.C) {
	s.metastor.id = ""
	m := s.metadata()
	c.Assert(m.ID(), gc.Equals, "")
	id, err := s.stor.Add(m, nil)
	c.Assert(err, gc.IsNil)
	c.Check(id, gc.Equals, "")
	c.Check(m.ID(), gc.Equals, "")
	s.metastor.Check(c, "", m, "AddMetadata")
	s.rawstor.CheckNotUsed(c)
}

func (s *WrapperSuite) TestFileStorageAddFile(c *gc.C) {
	s.metastor.id = ""
	// A nil *bytes.Buffer is still a non-nil io.Reader, so Add takes
	// the file-storing path.
	var file *bytes.Buffer
	m := s.metadata()
	id, err := s.stor.Add(m, file)
	c.Assert(err, gc.IsNil)
	c.Check(m.ID(), gc.Equals, "")
	c.Check(m.Stored(), gc.IsNil)
	c.Check(id, gc.Equals, "")
	c.Check(m.ID(), gc.Equals, "")
	s.metastor.Check(c, id, m, "AddMetadata", "SetStored")
	s.rawstor.Check(c, id, file, 10, "AddFile")
}

// Add leaves the caller's metadata without an ID.
func (s *WrapperSuite) TestFileStorageAddIDNotSet(c *gc.C) {
	m := s.metadata()
	c.Assert(m.ID(), gc.Equals, "")
	_, err := s.stor.Add(m, nil)
	c.Check(err, gc.IsNil)
	c.Check(m.ID(), gc.Equals, "")
}

func (s *WrapperSuite) TestFileStorageAddMetaOnly(c *gc.C) {
	id, want := s.setMeta()
	got, err := s.stor.Metadata(id)
	c.Assert(err, gc.IsNil)
	c.Check(got, gc.Equals, want)
	c.Check(got.Stored(), gc.IsNil)
}

func (s *WrapperSuite) TestFileStorageAddIDAlreadySet(c *gc.C) {
	m := s.metadata()
	m.SetID("eggs")
	_, err := s.stor.Add(m, nil)
	c.Check(err, gc.IsNil) // This should be handled at the lower level.
}
// A failing raw store makes Add roll back the metadata it added.
func (s *WrapperSuite) TestFileStorageAddFileFailureDropsMetadata(c *gc.C) {
	m := s.metadata()
	failure := errors.New("failed!")
	stor := filestorage.NewFileStorage(s.metastor, &FakeRawFileStorage{err: failure})
	_, err := stor.Add(m, &bytes.Buffer{})
	c.Assert(errors.Cause(err), gc.Equals, failure)

	remaining, listErr := s.metastor.ListMetadata()
	c.Assert(listErr, gc.IsNil)
	c.Check(remaining, gc.HasLen, 0)
	c.Check(m.ID(), gc.Equals, "")
}

// SetFile attaches a raw file to metadata that had none.
func (s *WrapperSuite) TestFileStorageSetFile(c *gc.C) {
	id, _ := s.setMeta()
	_, _, getErr := s.stor.Get(id)
	c.Assert(getErr, gc.NotNil)

	data := bytes.NewBufferString("spam")
	c.Assert(s.stor.SetFile(id, data), gc.IsNil)
	s.metastor.Check(c, id, nil, "Metadata", "Metadata", "SetStored")
	s.rawstor.Check(c, id, data, 10, "AddFile")
}

func (s *WrapperSuite) TestFileStorageRemove(c *gc.C) {
	const id = ""
	c.Assert(s.stor.Remove(id), gc.IsNil)
	s.metastor.Check(c, id, nil, "RemoveMetadata")
	s.rawstor.Check(c, id, nil, 0, "RemoveFile")
}
// Close closes both underlying stores exactly once.
func (s *WrapperSuite) TestClose(c *gc.C) {
	meta, files := &FakeMetadataStorage{}, &FakeRawFileStorage{}
	c.Assert(filestorage.NewFileStorage(meta, files).Close(), gc.IsNil)
	c.Check(meta.calls, gc.DeepEquals, []string{"Close"})
	c.Check(files.calls, gc.DeepEquals, []string{"Close"})
}
================================================
FILE: fs/copy.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package fs
import (
"fmt"
"io"
"os"
"path/filepath"
)
// Copy recursively copies the file, directory or symbolic link at src
// to dst. The destination must not exist. Symbolic links are not
// followed.
//
// If the copy fails half way through, the destination might be left
// partially written.
func Copy(src, dst string) error {
	srcInfo, srcErr := os.Lstat(src)
	if srcErr != nil {
		return srcErr
	}
	_, dstErr := os.Lstat(dst)
	if dstErr == nil {
		// TODO(rog) add a flag to permit overwriting?
		return fmt.Errorf("will not overwrite %q", dst)
	}
	if !os.IsNotExist(dstErr) {
		return dstErr
	}
	switch mode := srcInfo.Mode(); mode & os.ModeType {
	case os.ModeSymlink:
		return copySymLink(src, dst)
	case os.ModeDir:
		return copyDir(src, dst, mode)
	case 0:
		return copyFile(src, dst, mode)
	default:
		return fmt.Errorf("cannot copy file with mode %v", mode)
	}
}

// copySymLink recreates the symbolic link src at dst with the same,
// unresolved target.
func copySymLink(src, dst string) error {
	target, err := os.Readlink(src)
	if err != nil {
		return err
	}
	return os.Symlink(target, dst)
}

// copyFile copies the regular file src to dst and gives dst the
// permission bits of mode regardless of the process umask. An error
// from closing dst is reported, because it can signal a failed write.
func copyFile(src, dst string, mode os.FileMode) (err error) {
	srcf, err := os.Open(src)
	if err != nil {
		return err
	}
	defer srcf.Close()
	dstf, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm())
	if err != nil {
		return err
	}
	defer func() {
		// Surface a Close failure unless an earlier error is already set.
		if cerr := dstf.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	// Make the actual permissions match the source permissions
	// even in the presence of umask.
	if err := os.Chmod(dstf.Name(), mode.Perm()); err != nil {
		return err
	}
	if _, err := io.Copy(dstf, srcf); err != nil {
		return fmt.Errorf("cannot copy %q to %q: %v", src, dst, err)
	}
	return nil
}

// copyDir recursively copies the directory src to dst, which must not
// exist, and finally gives dst exactly the permission bits of mode.
func copyDir(src, dst string, mode os.FileMode) error {
	srcf, err := os.Open(src)
	if err != nil {
		return err
	}
	defer srcf.Close()
	// Create the directory with owner read/write/search permission, so
	// its contents can be created even when the source directory is not
	// writable by its owner. The source permissions are restored once
	// the contents have been copied.
	if err := os.Mkdir(dst, mode.Perm()|0700); err != nil {
		return err
	}
	for {
		names, err := srcf.Readdirnames(100)
		for _, name := range names {
			if err := Copy(filepath.Join(src, name), filepath.Join(dst, name)); err != nil {
				return err
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("error reading directory %q: %v", src, err)
		}
	}
	if err := os.Chmod(dst, mode.Perm()); err != nil {
		return err
	}
	return nil
}
================================================
FILE: fs/copy_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package fs_test
import (
"path/filepath"
"testing"
ft "github.com/juju/testing/filetesting"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/fs"
)
type copySuite struct{}
var _ = gc.Suite(©Suite{})
func TestPackage(t *testing.T) {
gc.TestingT(t)
}
// copyTests lists the scenarios run by TestCopy. For each case, src
// (and dst, if any) are created under fresh temporary directories, and
// the first entry of src is the one passed to fs.Copy. A non-empty err
// is a regular expression the returned error must match; otherwise the
// destination must mirror src.
var copyTests = []struct {
about string
src ft.Entries
dst ft.Entries
err string
}{{
about: "one file",
src: []ft.Entry{
ft.File{"file", "data", 0756},
},
}, {
about: "one directory",
src: []ft.Entry{
ft.Dir{"dir", 0777},
},
}, {
about: "one symlink",
src: []ft.Entry{
ft.Symlink{"link", "/foo"},
},
}, {
about: "several entries",
src: []ft.Entry{
ft.Dir{"top", 0755},
ft.File{"top/foo", "foodata", 0644},
ft.File{"top/bar", "bardata", 0633},
ft.Dir{"top/next", 0721},
ft.Symlink{"top/next/link", "../foo"},
ft.File{"top/next/another", "anotherdata", 0644},
},
}, {
about: "destination already exists",
src: []ft.Entry{
ft.Dir{"dir", 0777},
},
// The destination is pre-created, so Copy must refuse to overwrite it.
dst: []ft.Entry{
ft.Dir{"dir", 0777},
},
err: `will not overwrite ".+dir"`,
}, {
about: "source with unwritable directory",
// NOTE(review): this directory is empty, so it does not exercise
// creating entries inside a directory lacking owner write permission.
src: []ft.Entry{
ft.Dir{"dir", 0555},
},
}}
// TestCopy runs every copyTests case. It builds the source tree (and
// any pre-existing destination), copies the first source entry, and
// then checks either the expected error or that the copy mirrors the
// source.
func (*copySuite) TestCopy(c *gc.C) {
	for i := range copyTests {
		tc := copyTests[i]
		c.Logf("test %d: %v", i, tc.about)
		srcDir, dstDir := c.MkDir(), c.MkDir()
		tc.src.Create(c, srcDir)
		tc.dst.Create(c, dstDir)
		rel := tc.src[0].GetPath()
		err := fs.Copy(filepath.Join(srcDir, rel), filepath.Join(dstDir, rel))
		if tc.err == "" {
			c.Assert(err, gc.IsNil)
			tc.src.Check(c, dstDir)
			continue
		}
		c.Check(err, gc.ErrorMatches, tc.err)
	}
}
================================================
FILE: go.mod
================================================
module github.com/juju/utils/v4
go 1.24.4
require (
github.com/juju/clock v1.0.3
github.com/juju/collections v1.0.4
github.com/juju/errors v1.0.0
github.com/juju/loggo/v2 v2.0.0
github.com/juju/mutex/v2 v2.0.0
github.com/juju/testing v1.2.0
golang.org/x/crypto v0.39.0
golang.org/x/net v0.41.0
golang.org/x/text v0.26.0
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7
gopkg.in/yaml.v2 v2.4.0
)
require (
github.com/juju/loggo v1.0.0 // indirect
github.com/juju/utils/v3 v3.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/term v0.32.0 // indirect
)
================================================
FILE: go.sum
================================================
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g=
github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU=
github.com/juju/clock v1.0.3 h1:yJHIsWXeU8j3QcBdiess09SzfiXRRrsjKPn2whnMeds=
github.com/juju/clock v1.0.3/go.mod h1:HIBvJ8kiV/n7UHwKuCkdYL4l/MDECztHR2sAvWDxxf0=
github.com/juju/collections v1.0.4 h1:GjL+aN512m2rVDqhPII7P6qB0e+iYFubz8sqBhZaZtk=
github.com/juju/collections v1.0.4/go.mod h1:hVrdB0Zwq9wIU1Fl6ItD2+UETeNeOEs+nGvJufVe+0c=
github.com/juju/errors v1.0.0 h1:yiq7kjCLll1BiaRuNY53MGI0+EQ3rF6GB+wvboZDefM=
github.com/juju/errors v1.0.0/go.mod h1:B5x9thDqx0wIMH3+aLIMP9HjItInYWObRovoCFM5Qe8=
github.com/juju/loggo v1.0.0 h1:Y6ZMQOGR9Aj3BGkiWx7HBbIx6zNwNkxhVNOHU2i1bl0=
github.com/juju/loggo v1.0.0/go.mod h1:NIXFioti1SmKAlKNuUwbMenNdef59IF52+ZzuOmHYkg=
github.com/juju/loggo/v2 v2.0.0 h1:PzyVIn+NgoZ22QUtPgKF/lh+6SnaCOEXhcP+sE4FhOk=
github.com/juju/loggo/v2 v2.0.0/go.mod h1:647d6WvXBLj5lvka2qBvccr7vMIvF2KFkEH+0ZuFOUM=
github.com/juju/mutex/v2 v2.0.0 h1:rVmJdOaXGWF8rjcFHBNd4x57/1tks5CgXHx55O55SB0=
github.com/juju/mutex/v2 v2.0.0/go.mod h1:jwCfBs/smYDaeZLqeaCi8CB8M+tOes4yf827HoOEoqk=
github.com/juju/testing v1.2.0 h1:Q0wxjaxx4XPVEN+SgzxKr3d82pjmSBcuM3WndAU391c=
github.com/juju/testing v1.2.0/go.mod h1:lqZVzNwBKAbylGZidK77ts6kIdoOkmD52+4m0ysetPo=
github.com/juju/utils/v3 v3.1.0 h1:NrNo73oVtfr7kLP17/BDpubXwa7YEW16+Ult6z9kpHI=
github.com/juju/utils/v3 v3.1.0/go.mod h1:nAj3sHtdYfAkvnkqttTy3Xzm2HzkD9Hfgnc+upOW2Z8=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lunixbochs/vtclean v0.0.0-20160125035106-4fbf7632a2c6/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/mattn/go-colorable v0.0.6/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-isatty v0.0.0-20160806122752-66b8e73f3f5c/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20160105164936-4f90aeace3a2/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
================================================
FILE: gomaxprocs.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"os"
"runtime"
)
// Indirections over the runtime functions so that tests can stub them.
var (
	gomaxprocs = runtime.GOMAXPROCS
	numCPU     = runtime.NumCPU
)
// UseMultipleCPUs sets GOMAXPROCS to the number of CPU cores unless it has
// already been overridden by the GOMAXPROCS environment variable.
func UseMultipleCPUs() {
if envGOMAXPROCS := os.Getenv("GOMAXPROCS"); envGOMAXPROCS != "" {
n := gomaxprocs(0)
logger.Debugf("GOMAXPROCS already set in environment to %q, %d internally",
envGOMAXPROCS, n)
return
}
n := numCPU()
logger.Debugf("setting GOMAXPROCS to %d", n)
gomaxprocs(n)
}
================================================
FILE: gomaxprocs_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"os"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// gomaxprocsSuite tests UseMultipleCPUs with the runtime hooks stubbed.
type gomaxprocsSuite struct {
	testing.IsolationSuite
	setmaxprocs    chan int
	numCPUResponse int // value returned by the stubbed NumCPU
	setMaxProcs    int // last argument passed to the stubbed GOMAXPROCS
}

var _ = gc.Suite(&gomaxprocsSuite{})
// SetUpTest stubs the runtime hooks, so the real GOMAXPROCS is never
// changed, and clears the GOMAXPROCS environment variable.
func (s *gomaxprocsSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.numCPUResponse, s.setMaxProcs = 2, -1
	s.PatchValue(utils.GOMAXPROCS, func(n int) int {
		s.setMaxProcs = n
		return 1
	})
	s.PatchValue(utils.NumCPU, func() int { return s.numCPUResponse })
	s.PatchEnvironment("GOMAXPROCS", "")
}
// With GOMAXPROCS in the environment, only the query gomaxprocs(0) is
// made.
func (s *gomaxprocsSuite) TestUseMultipleCPUsDoesNothingWhenGOMAXPROCSSet(c *gc.C) {
	c.Assert(os.Setenv("GOMAXPROCS", "1"), jc.ErrorIsNil)
	utils.UseMultipleCPUs()
	c.Check(s.setMaxProcs, gc.Equals, 0)
}

// Without the environment variable, GOMAXPROCS tracks the CPU count.
func (s *gomaxprocsSuite) TestUseMultipleCPUsWhenEnabled(c *gc.C) {
	for _, cpus := range []int{2, 4} {
		s.numCPUResponse = cpus
		utils.UseMultipleCPUs()
		c.Check(s.setMaxProcs, gc.Equals, cpus)
	}
}
================================================
FILE: hash/fingerprint.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package hash
import (
"encoding/base64"
"encoding/hex"
"hash"
"io"
"github.com/juju/errors"
)
// Fingerprint represents the checksum for some data.
// The zero value holds no checksum bytes; see IsZero and Validate.
type Fingerprint struct {
	// sum holds the raw checksum bytes. Bytes hands out a copy, so
	// callers cannot mutate it through that method.
	sum []byte
}
// NewFingerprint checks the provided raw hash sum with validate and wraps
// it in a Fingerprint. The sum is copied, so later changes to the caller's
// slice do not affect the result. This function roundtrips with
// Fingerprint.Bytes().
//
// An error is returned if validate is nil or if it rejects the sum.
func NewFingerprint(sum []byte, validate func([]byte) error) (Fingerprint, error) {
	if validate == nil {
		return Fingerprint{}, errors.New("missing validate func")
	}
	if err := validate(sum); err != nil {
		return Fingerprint{}, errors.Trace(err)
	}
	return newFingerprint(sum), nil
}
// NewValidFingerprint returns a Fingerprint corresponding to the current
// sum of the provided hash. No validate func is applied to the sum. The
// hash is not reset, so it may keep being written to afterwards.
func NewValidFingerprint(h hash.Hash) Fingerprint {
	// Parameter named h so it does not shadow the hash package.
	return newFingerprint(h.Sum(nil))
}
// newFingerprint wraps sum in a Fingerprint. It stores a private copy so
// later mutation of the caller's slice cannot affect the result.
func newFingerprint(sum []byte) Fingerprint {
	isolated := append([]byte{}, sum...)
	return Fingerprint{sum: isolated}
}
// GenerateFingerprint returns the fingerprint for the provided data.
// All of reader is copied into a hash obtained from newHash, and the
// resulting sum is wrapped without running any validate func. Both
// arguments are required.
func GenerateFingerprint(reader io.Reader, newHash func() hash.Hash) (Fingerprint, error) {
	switch {
	case reader == nil:
		return Fingerprint{}, errors.New("missing reader")
	case newHash == nil:
		return Fingerprint{}, errors.New("missing new hash func")
	}
	h := newHash()
	if _, err := io.Copy(h, reader); err != nil {
		return Fingerprint{}, errors.Trace(err)
	}
	return Fingerprint{sum: h.Sum(nil)}, nil
}
// ParseHexFingerprint decodes a hex-encoded checksum, checks it with
// validate, and wraps it in a Fingerprint. This function roundtrips with
// Fingerprint.Hex().
//
// An error is returned if validate is nil, if hexSum is not valid hex,
// or if validate rejects the decoded sum.
func ParseHexFingerprint(hexSum string, validate func([]byte) error) (Fingerprint, error) {
	// Checked up front so a missing validator is reported before any
	// decoding work is done.
	if validate == nil {
		return Fingerprint{}, errors.New("missing validate func")
	}
	sum, err := hex.DecodeString(hexSum)
	if err != nil {
		return Fingerprint{}, errors.Trace(err)
	}
	fp, err := NewFingerprint(sum, validate)
	if err != nil {
		return Fingerprint{}, errors.Trace(err)
	}
	return fp, nil
}
// ParseBase64Fingerprint decodes a checksum in standard (padded) base64,
// checks it with validate, and wraps it in a Fingerprint. This function
// roundtrips with Fingerprint.Base64().
//
// An error is returned if validate is nil, if b64Sum is not valid
// base64, or if validate rejects the decoded sum.
func ParseBase64Fingerprint(b64Sum string, validate func([]byte) error) (Fingerprint, error) {
	// Checked up front so a missing validator is reported before any
	// decoding work is done.
	if validate == nil {
		return Fingerprint{}, errors.New("missing validate func")
	}
	sum, err := base64.StdEncoding.DecodeString(b64Sum)
	if err != nil {
		return Fingerprint{}, errors.Trace(err)
	}
	fp, err := NewFingerprint(sum, validate)
	if err != nil {
		return Fingerprint{}, errors.Trace(err)
	}
	return fp, nil
}
// String implements fmt.Stringer. It renders the fingerprint in
// lowercase hex, exactly as Hex does.
func (fp Fingerprint) String() string {
	return hex.EncodeToString(fp.sum)
}
// Hex returns the lowercase hex encoding of the fingerprint's sum
// (an empty string for the zero value).
func (fp Fingerprint) Hex() string {
	encoded := make([]byte, hex.EncodedLen(len(fp.sum)))
	hex.Encode(encoded, fp.sum)
	return string(encoded)
}
// Base64 returns the fingerprint's sum encoded with standard, padded
// base64.
func (fp Fingerprint) Base64() string {
	enc := base64.StdEncoding
	return enc.EncodeToString(fp.sum)
}
// Bytes returns a fresh copy of the raw (sum) bytes of the fingerprint.
// The result is never nil, even for the zero value.
func (fp Fingerprint) Bytes() []byte {
	out := make([]byte, len(fp.sum))
	copy(out, fp.sum)
	return out
}
// IsZero reports whether the fingerprint holds no checksum bytes. This
// covers the zero value and also any fingerprint wrapping an empty sum.
func (fp Fingerprint) IsZero() bool {
	return len(fp.sum) < 1
}
// Validate returns a NotValid error when the fingerprint holds no
// checksum bytes, and nil otherwise. It does not check the sum's length
// against any particular hash algorithm.
func (fp Fingerprint) Validate() error {
	if !fp.IsZero() {
		return nil
	}
	return errors.NotValidf("zero-value fingerprint")
}
================================================
FILE: hash/fingerprint_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package hash_test
import (
"crypto/sha512"
"encoding/hex"
stdhash "hash"
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/testing/filetesting"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/hash"
)
// Register the suite with gocheck.
var _ = gc.Suite(new(FingerprintSuite))
type FingerprintSuite struct {
stub *testing.Stub
hash *filetesting.StubHash
}
// SetUpTest gives each test a fresh stub and a stub hash bound to it.
func (s *FingerprintSuite) SetUpTest(c *gc.C) {
	stub := new(testing.Stub)
	s.stub = stub
	s.hash = filetesting.NewStubHash(stub, nil)
}
// newHash records a "newHash" call and consumes the next queued stub
// error without using it. It then returns the suite's stub hash.
func (s *FingerprintSuite) newHash() stdhash.Hash {
	s.stub.AddCall("newHash")
	_ = s.stub.NextErr() // Pop one off; the error itself is ignored.
	return s.hash
}
// validate records a "validate" call with sum and returns the next
// queued stub error (traced), or nil if none is queued.
func (s *FingerprintSuite) validate(sum []byte) error {
	s.stub.AddCall("validate", sum)
	err := s.stub.NextErr()
	if err == nil {
		return nil
	}
	return errors.Trace(err)
}
// A sum accepted by validate is wrapped unchanged.
func (s *FingerprintSuite) TestNewFingerprintOkay(c *gc.C) {
	want, _ := newFingerprint(c, "spamspamspam")
	fp, err := hash.NewFingerprint(want, s.validate)
	c.Assert(err, jc.ErrorIsNil)
	s.stub.CheckCallNames(c, "validate")
	c.Check(fp.Bytes(), jc.DeepEquals, want)
}
// A validate failure is propagated with its cause intact.
func (s *FingerprintSuite) TestNewFingerprintInvalid(c *gc.C) {
	sum, _ := newFingerprint(c, "spamspamspam")
	failure := errors.NewNotValid(nil, "bogus!!!")
	s.stub.SetErrors(failure)

	_, err := hash.NewFingerprint(sum, s.validate)

	s.stub.CheckCallNames(c, "validate")
	c.Check(errors.Cause(err), gc.Equals, failure)
}
// NewValidFingerprint takes the hash's current Sum and nothing else.
func (s *FingerprintSuite) TestNewValidFingerprint(c *gc.C) {
	want, _ := newFingerprint(c, "spamspamspam")
	s.hash.ReturnSum = want
	fp := hash.NewValidFingerprint(s.hash)
	s.stub.CheckCallNames(c, "Sum")
	c.Check(fp.Bytes(), jc.DeepEquals, want)
}
// GenerateFingerprint streams the reader into a new hash and then sums it.
func (s *FingerprintSuite) TestGenerateFingerprintOkay(c *gc.C) {
	const data = "spamspamspam"
	want, _ := newFingerprint(c, data)
	s.hash.ReturnSum = want
	s.hash.Writer, _ = filetesting.NewStubWriter(s.stub)
	reader := filetesting.NewStubReader(s.stub, data)

	fp, err := hash.GenerateFingerprint(reader, s.newHash)
	c.Assert(err, jc.ErrorIsNil)

	s.stub.CheckCallNames(c, "newHash", "Read", "Write", "Read", "Sum")
	c.Check(fp.Bytes(), jc.DeepEquals, want)
}
// A nil reader is rejected before newHash is ever called.
func (s *FingerprintSuite) TestGenerateFingerprintNil(c *gc.C) {
	_, err := hash.GenerateFingerprint(nil, s.newHash)
	s.stub.CheckNoCalls(c)
	c.Check(err, gc.ErrorMatches, `missing reader`)
}
// A valid hex string decodes back to the original sum.
func (s *FingerprintSuite) TestParseHexFingerprint(c *gc.C) {
	want, hexSum := newFingerprint(c, "spamspamspam")
	fp, err := hash.ParseHexFingerprint(hexSum, s.validate)
	c.Assert(err, jc.ErrorIsNil)
	s.stub.CheckCallNames(c, "validate")
	c.Check(fp.Bytes(), jc.DeepEquals, want)
}
// String renders the sum as hex.
func (s *FingerprintSuite) TestString(c *gc.C) {
	sum, want := newFingerprint(c, "spamspamspam")
	fp, err := hash.NewFingerprint(sum, s.validate)
	c.Assert(err, jc.ErrorIsNil)
	c.Check(fp.String(), gc.Equals, want)
}
// TestHex checks that Hex renders the sum in hex. It calls Hex directly
// rather than going through String, so Hex itself is exercised.
func (s *FingerprintSuite) TestHex(c *gc.C) {
	sum, expected := newFingerprint(c, "spamspamspam")
	fp, err := hash.NewFingerprint(sum, s.validate)
	c.Assert(err, jc.ErrorIsNil)
	// Named got, not hex, so it does not shadow the encoding/hex import.
	got := fp.Hex()
	c.Check(got, gc.Equals, expected)
}
// Bytes hands back the wrapped sum.
func (s *FingerprintSuite) TestBytes(c *gc.C) {
	want, _ := newFingerprint(c, "spamspamspam")
	fp, err := hash.NewFingerprint(want, s.validate)
	c.Assert(err, jc.ErrorIsNil)
	got := fp.Bytes()
	c.Check(got, jc.DeepEquals, want)
}
// A fingerprint built from a real sum validates cleanly.
func (s *FingerprintSuite) TestValidateOkay(c *gc.C) {
	sum, _ := newFingerprint(c, "spamspamspam")
	fp, err := hash.NewFingerprint(sum, s.validate)
	c.Assert(err, jc.ErrorIsNil)
	c.Check(fp.Validate(), jc.ErrorIsNil)
}
// The zero value fails validation with a NotValid error.
func (s *FingerprintSuite) TestValidateZero(c *gc.C) {
	fp := hash.Fingerprint{}
	err := fp.Validate()
	c.Check(err, jc.Satisfies, errors.IsNotValid)
	c.Check(err, gc.ErrorMatches, `zero-value fingerprint not valid`)
}
// newFingerprint returns the SHA-384 sum of data and its hex encoding,
// for use as expected values in tests.
func newFingerprint(c *gc.C, data string) ([]byte, string) {
	h := sha512.New384() // named h so it does not shadow the hash package
	_, err := h.Write([]byte(data))
	c.Assert(err, jc.ErrorIsNil)
	sum := h.Sum(nil)
	return sum, hex.EncodeToString(sum)
}
================================================
FILE: hash/hash.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package hash provides utilities that support use of the stdlib
// hash.Hash. Most notable is the Fingerprint type that wraps the
// checksum of a hash.
//
// Conversion between checksums and strings is facilitated through
// Fingerprint.
//
// Here are some hash-related recipes that bring it all together:
//
// - Extract the SHA384 hash while writing to elsewhere, then get the
// raw checksum:
//
// newHash, _ := hash.SHA384()
// h := newHash()
// hashingWriter := io.MultiWriter(writer, h)
// if err := writeAll(hashingWriter); err != nil { ... }
// fp := hash.NewValidFingerprint(h)
// checksum := fp.Bytes()
//
// - Extract the SHA384 hash while reading from elsewhere, then get the
// hex-encoded checksum to send over the wire:
//
// newHash, _ := hash.SHA384()
// h := newHash()
// hashingReader := io.TeeReader(reader, h)
// if err := processStream(hashingReader); err != nil { ... }
// fp := hash.NewValidFingerprint(h)
// hexSum := fp.Hex()
// req.Header.Set("Content-Sha384", hexSum)
//
// - Turn a checksum sent over the wire back into a fingerprint:
//
// _, validate := hash.SHA384()
// hexSum := req.Header.Get("Content-Sha384")
// var fp hash.Fingerprint
// if len(hexSum) != 0 {
// fp, err = hash.ParseHexFingerprint(hexSum, validate)
// ...
// }
// if fp.IsZero() {
// ...
// }
package hash
import (
"crypto/sha512"
"hash"
"github.com/juju/errors"
)
// SHA384 returns the newHash and validate functions for use
// with SHA384 hashes. SHA384 is used in several key places in Juju.
// The validate func accepts only sums of exactly 48 bytes.
func SHA384() (newHash func() hash.Hash, validate func([]byte) error) {
	newHash = sha512.New384
	validate = newSizeChecker(sha512.Size384)
	return newHash, validate
}
// newSizeChecker returns a validate func that rejects any sum whose
// length is not exactly size. The NotValid error says whether the sum
// was too small or too big.
func newSizeChecker(size int) func([]byte) error {
	return func(sum []byte) error {
		switch n := len(sum); {
		case n < size:
			return errors.NewNotValid(nil, "invalid fingerprint (too small)")
		case n > size:
			return errors.NewNotValid(nil, "invalid fingerprint (too big)")
		default:
			return nil
		}
	}
}
================================================
FILE: hash/hash_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package hash_test
import (
"bytes"
"io"
"io/ioutil"
"strings"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/testing/filetesting"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/hash"
)
// Register the suite with gocheck.
var _ = gc.Suite(new(HashSuite))
// HashSuite exercises the hashing-writer and hashing-reader recipes
// from the package documentation.
type HashSuite struct {
	testing.IsolationSuite
}
// Hashing through io.MultiWriter gives the same fingerprint as hashing
// the data directly, while still passing the data to the other writer.
func (s *HashSuite) TestHashingWriter(c *gc.C) {
	const data = "some data"
	newHash, _ := hash.SHA384()
	expected, err := hash.GenerateFingerprint(strings.NewReader(data), newHash)
	c.Assert(err, jc.ErrorIsNil)

	var out bytes.Buffer
	h := newHash()
	_, err = io.MultiWriter(&out, h).Write([]byte(data))
	c.Assert(err, jc.ErrorIsNil)

	c.Check(hash.NewValidFingerprint(h), jc.DeepEquals, expected)
	c.Check(out.String(), gc.Equals, data)
}
// Hashing through io.TeeReader yields a fingerprint that survives a
// hex round trip.
func (s *HashSuite) TestHashingReader(c *gc.C) {
	const expected = "some data"
	stub := new(testing.Stub)
	reader := &filetesting.StubReader{
		Stub:       stub,
		ReturnRead: &fakeStream{data: expected},
	}
	newHash, validate := hash.SHA384()
	h := newHash()

	data, err := ioutil.ReadAll(io.TeeReader(reader, h))
	c.Assert(err, jc.ErrorIsNil)

	fp := hash.NewValidFingerprint(h)
	fpAgain, err := hash.ParseHexFingerprint(fp.Hex(), validate)
	c.Assert(err, jc.ErrorIsNil)

	stub.CheckCallNames(c, "Read") // The EOF was mixed with the data.
	c.Check(string(data), gc.Equals, expected)
	c.Check(fpAgain, jc.DeepEquals, fp)
}
// fakeStream is an in-memory reader over data. pos is the offset of the
// next unread byte.
type fakeStream struct {
	pos  uint64
	data string
}
// Read copies as much of the remaining data as fits into data. It returns
// io.EOF on the same call that consumes the last byte, or immediately if
// nothing remains.
func (f *fakeStream) Read(data []byte) (int, error) {
	n := copy(data, f.data[f.pos:])
	f.pos += uint64(n)
	var err error
	if f.pos >= uint64(len(f.data)) {
		err = io.EOF
	}
	return n, err
}
================================================
FILE: hash/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package hash_test
import (
stdtesting "testing"
gc "gopkg.in/check.v1"
)
// Test hooks the gocheck suites of this package into "go test".
func Test(t *stdtesting.T) { gc.TestingT(t) }
================================================
FILE: hash/writer.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package hash
import (
"encoding/base64"
"hash"
"io"
)
// TODO(ericsnow) Remove HashingWriter and NewHashingWriter().
// HashingWriter wraps an io.Writer, providing the checksum of all data
// written to it. A HashingWriter may be used in place of the writer it
// wraps.
//
// Note: HashingWriter is deprecated. Please do not use it. We will
// remove it ASAP.
type HashingWriter struct {
	// wrapped sends each Write to both the original writer and hash.
	wrapped io.Writer
	// hash accumulates the checksum reported by Base64Sum.
	hash hash.Hash
}
// NewHashingWriter returns a new HashingWriter that wraps the provided
// writer and the hasher. Each Write goes to writer first, then to hasher.
//
// Example:
//
//	hw := NewHashingWriter(w, sha1.New())
//	io.Copy(hw, reader)
//	hash := hw.Base64Sum()
//
// Note: NewHashingWriter is deprecated. Please do not use it. We will
// remove it ASAP.
func NewHashingWriter(writer io.Writer, hasher hash.Hash) *HashingWriter {
	hw := new(HashingWriter)
	hw.hash = hasher
	hw.wrapped = io.MultiWriter(writer, hasher)
	return hw
}
// Base64Sum returns the current sum of the hash, encoded with standard,
// padded base64. It does not reset the hash.
func (hw HashingWriter) Base64Sum() string {
	return base64.StdEncoding.EncodeToString(hw.hash.Sum(nil))
}
// Write writes to both the wrapped file and the hash. It stops at the
// first failing destination, as io.MultiWriter does.
func (hw *HashingWriter) Write(data []byte) (int, error) {
	// The error is returned untraced because some callers, like
	// ioutil.ReadAll(), won't work otherwise.
	n, err := hw.wrapped.Write(data)
	return n, err
}
================================================
FILE: hash/writer_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package hash_test
import (
"bytes"
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/testing/filetesting"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/hash"
)
// Register the suite with gocheck.
var _ = gc.Suite(new(WriterSuite))
// WriterSuite tests HashingWriter against a stub writer and a stub hash.
// Both record their calls on one shared stub.
type WriterSuite struct {
	testing.IsolationSuite

	stub *testing.Stub

	// writer stands in for the wrapped writer; the data written to it
	// ends up in wBuffer.
	writer  *filetesting.StubWriter
	wBuffer *bytes.Buffer

	// hash stands in for the hasher; hBuffer is passed as its writer.
	hash    *filetesting.StubHash
	hBuffer *bytes.Buffer
}
// SetUpTest wires fresh buffers and stubs for every test.
func (s *WriterSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.stub = new(testing.Stub)
	s.wBuffer, s.hBuffer = new(bytes.Buffer), new(bytes.Buffer)
	s.writer = &filetesting.StubWriter{Stub: s.stub, ReturnWrite: s.wBuffer}
	s.hash = filetesting.NewStubHash(s.stub, s.hBuffer)
}
// An empty write still reaches both destinations and writes nothing.
func (s *WriterSuite) TestHashingWriterWriteEmpty(c *gc.C) {
	hw := hash.NewHashingWriter(s.writer, s.hash)
	n, err := hw.Write(nil)
	c.Assert(err, jc.ErrorIsNil)
	s.stub.CheckCallNames(c, "Write", "Write")
	c.Check(n, gc.Equals, 0)
	for _, buf := range []*bytes.Buffer{s.wBuffer, s.hBuffer} {
		c.Check(buf.String(), gc.Equals, "")
	}
}
// A small write is copied in full to both the writer and the hash.
func (s *WriterSuite) TestHashingWriterWriteSmall(c *gc.C) {
	const data = "spam"
	hw := hash.NewHashingWriter(s.writer, s.hash)
	n, err := hw.Write([]byte(data))
	c.Assert(err, jc.ErrorIsNil)
	s.stub.CheckCallNames(c, "Write", "Write")
	c.Check(n, gc.Equals, len(data))
	c.Check(s.wBuffer.String(), gc.Equals, data)
	c.Check(s.hBuffer.String(), gc.Equals, data)
}
// If the wrapped writer fails, io.MultiWriter stops there, so the hash
// is never written.
func (s *WriterSuite) TestHashingWriterWriteFileError(c *gc.C) {
	hw := hash.NewHashingWriter(s.writer, s.hash)
	failure := errors.New("")
	s.stub.SetErrors(failure)

	_, err := hw.Write([]byte("spam"))

	s.stub.CheckCallNames(c, "Write")
	c.Check(errors.Cause(err), gc.Equals, failure)
}
// Base64Sum base64-encodes whatever the hash's Sum returns.
func (s *WriterSuite) TestHashingWriterBase64Sum(c *gc.C) {
	s.hash.ReturnSum = []byte("spam")
	got := hash.NewHashingWriter(s.writer, s.hash).Base64Sum()
	s.stub.CheckCallNames(c, "Sum")
	c.Check(got, gc.Equals, "c3BhbQ==")
}
================================================
FILE: home_unix.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package utils
import (
"os"
)
// Home returns the os-specific home path.
// It always returns the "real" home, not the confined home that is used
// when running inside a strictly confined snap. SNAP_REAL_HOME, when it
// is present in the environment (even if empty), takes precedence over
// HOME.
func Home() string {
	if realHome, confined := os.LookupEnv("SNAP_REAL_HOME"); confined {
		return realHome
	}
	return os.Getenv("HOME")
}
// SetHome sets the os-specific home path in the environment. Inside a
// confined snap (SNAP_REAL_HOME present) it updates SNAP_REAL_HOME;
// otherwise it updates HOME.
func SetHome(s string) error {
	name := "HOME"
	if _, confined := os.LookupEnv("SNAP_REAL_HOME"); confined {
		name = "SNAP_REAL_HOME"
	}
	return os.Setenv(name, s)
}
================================================
FILE: home_unix_test.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package utils_test
import (
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// homeSuite tests Home on non-Windows platforms.
type homeSuite struct {
	testing.IsolationSuite
}
// Register the suite with gocheck.
var _ = gc.Suite(new(homeSuite))
// Outside a snap, Home reports HOME.
// NOTE(review): this relies on IsolationSuite having cleared any
// SNAP_REAL_HOME from the environment.
func (s *homeSuite) TestHomeLinux(c *gc.C) {
	const home = "/home/foo/bar"
	s.PatchEnvironment("HOME", home)
	c.Check(utils.Home(), gc.Equals, home)
}
// Inside a confined snap, SNAP_REAL_HOME wins over the snap's HOME.
func (s *homeSuite) TestHomeConfined(c *gc.C) {
	const realHome = "/home/foo/bar"
	s.PatchEnvironment("HOME", "/home/user/snap/foo/1")
	s.PatchEnvironment("SNAP_REAL_HOME", realHome)
	c.Check(utils.Home(), gc.Equals, realHome)
}
================================================
FILE: home_windows.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"os"
"path/filepath"
)
// Home returns the os-specific home path as specified in the environment,
// by joining the HOMEDRIVE and HOMEPATH variables.
func Home() string {
	drive, path := os.Getenv("HOMEDRIVE"), os.Getenv("HOMEPATH")
	return filepath.Join(drive, path)
}
// SetHome sets the os-specific home path in the environment. A volume
// name at the start of s goes into HOMEDRIVE and the rest into HOMEPATH.
// If s has no volume name, HOMEDRIVE is left untouched.
func SetHome(s string) error {
	drive := filepath.VolumeName(s)
	if drive == "" {
		return os.Setenv("HOMEPATH", s)
	}
	if err := os.Setenv("HOMEDRIVE", drive); err != nil {
		return err
	}
	return os.Setenv("HOMEPATH", s[len(drive):])
}
================================================
FILE: home_windows_test.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"os"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// homeSuite tests Home and SetHome on Windows.
type homeSuite struct {
	testing.IsolationSuite
}
// Register the suite with gocheck.
var _ = gc.Suite(new(homeSuite))
// TestHome checks three things: SetHome splits a path into HOMEDRIVE and
// HOMEPATH, Home joins them back together, and setting a path without a
// volume leaves HOMEDRIVE untouched.
func (s *homeSuite) TestHome(c *gc.C) {
	s.PatchEnvironment("HOMEPATH", "")
	s.PatchEnvironment("HOMEDRIVE", "")
	drive := "P:"
	path := `\home\foo\bar`
	h := drive + path
	// Check SetHome's error; previously it was silently ignored.
	err := utils.SetHome(h)
	c.Assert(err, gc.IsNil)
	c.Check(os.Getenv("HOMEPATH"), gc.Equals, path)
	c.Check(os.Getenv("HOMEDRIVE"), gc.Equals, drive)
	c.Check(utils.Home(), gc.Equals, h)
	// now test that if we only set the path, we don't mess with the drive
	path2 := `\home\someotherfoo\bar`
	err = utils.SetHome(path2)
	c.Assert(err, gc.IsNil)
	c.Check(os.Getenv("HOMEPATH"), gc.Equals, path2)
	c.Check(os.Getenv("HOMEDRIVE"), gc.Equals, drive)
	c.Check(utils.Home(), gc.Equals, drive+path2)
}
================================================
FILE: isubuntu.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"strings"
)
// IsUbuntu executes lsb_release to see if the host OS is Ubuntu.
// It runs "lsb_release -i -s" and reports whether the trimmed output is
// exactly "Ubuntu". Any failure to run the command counts as "not Ubuntu".
func IsUbuntu() bool {
	out, err := RunCommand("lsb_release", "-i", "-s")
	if err != nil {
		return false
	}
	return strings.TrimSpace(out) == "Ubuntu"
}
================================================
FILE: isubuntu_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"fmt"
"runtime"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// IsUbuntuSuite tests IsUbuntu against a patched lsb_release executable.
type IsUbuntuSuite struct {
	testing.IsolationSuite
}
// Register the suite with gocheck.
var _ = gc.Suite(new(IsUbuntuSuite))
// patchLsbRelease installs a fake lsb_release in a temporary directory
// whose script body is name. On Windows it is a .bat file; elsewhere it
// is a bash script.
func (s *IsUbuntuSuite) patchLsbRelease(c *gc.C, name string) {
	execName, content := "lsb_release.bat", fmt.Sprintf("@echo off\r\n%s", name)
	if runtime.GOOS != "windows" {
		execName, content = "lsb_release", fmt.Sprintf("#!/bin/bash --norc\n%s", name)
	}
	patchExecutable(s, c.MkDir(), execName, content)
}
// An lsb_release that prints "Ubuntu" is recognised.
func (s *IsUbuntuSuite) TestIsUbuntu(c *gc.C) {
	s.patchLsbRelease(c, "echo Ubuntu")
	isUbuntu := utils.IsUbuntu()
	c.Assert(isUbuntu, jc.IsTrue)
}
// Any other distributor string is rejected.
func (s *IsUbuntuSuite) TestIsNotUbuntu(c *gc.C) {
	s.patchLsbRelease(c, "echo Windows NT")
	isUbuntu := utils.IsUbuntu()
	c.Assert(isUbuntu, jc.IsFalse)
}
// A failing lsb_release means "not Ubuntu".
func (s *IsUbuntuSuite) TestIsNotUbuntuLsbReleaseNotFound(c *gc.C) {
	switch runtime.GOOS {
	case "windows":
		// Nothing is patched. NOTE(review): this presumably relies on no
		// lsb_release being on PATH on Windows hosts.
	default:
		// Exit status 127 is the shell's "command not found" code.
		s.patchLsbRelease(c, "exit 127")
	}
	c.Assert(utils.IsUbuntu(), jc.IsFalse)
}
================================================
FILE: jsonhttp/jsonhttp.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package jsonhttp provides general functions for returning
// JSON responses to HTTP requests. It is agnostic about
// the specific form of any returned errors.
package jsonhttp
import (
"encoding/json"
"net/http"
"github.com/juju/errors"
)
// ErrorToResponse represents a function that can convert a Go error
// into a form that can be returned as a JSON body from an HTTP request.
// The httpStatus value reports the desired HTTP status. WriteError
// marshals errorBody with encoding/json (via WriteJSON), so a nil
// errorBody is written as the JSON value null.
type ErrorToResponse func(err error) (httpStatus int, errorBody any)
// ErrorHandler is like http.Handler except it returns an error
// which may be returned as the error body of the response.
// An ErrorHandler function should not itself write to the ResponseWriter
// if it returns an error: HandleErrors drops the error once anything has
// been written.
type ErrorHandler func(http.ResponseWriter, *http.Request) error
// HandleErrors returns a function that can be used to convert an ErrorHandler
// into an http.Handler. The given errToResp parameter is used to convert
// any non-nil error returned by handle to the response in the HTTP body.
//
// The error response is written only if the handler has not already
// written a header or body, or flushed. Otherwise the error is dropped.
func HandleErrors(errToResp ErrorToResponse) func(handle ErrorHandler) http.Handler {
	writeError := WriteError(errToResp)
	return func(handle ErrorHandler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			tracked := &responseWriter{ResponseWriter: w}
			err := handle(tracked, req)
			if err == nil {
				return
			}
			if tracked.headerWritten {
				// Once the header is out we can no longer set an error
				// status, and appending a JSON error message could
				// corrupt the output that was already written.
				// TODO log the error?
				return
			}
			writeError(w, err)
		})
	}
}
// responseWriter wraps an http.ResponseWriter and records whether
// anything (a header, body bytes, or a flush) has gone through it. This
// lets HandleErrors know whether it can still write an error response.
type responseWriter struct {
	http.ResponseWriter

	// headerWritten is set by Write, WriteHeader and Flush.
	headerWritten bool
}
// Write marks the header as written and forwards data to the wrapped
// writer. net/http sends an implicit 200 header on the first body write.
func (w *responseWriter) Write(data []byte) (n int, err error) {
	w.headerWritten = true
	n, err = w.ResponseWriter.Write(data)
	return n, err
}
// WriteHeader marks the header as written and then passes code to the
// wrapped writer.
func (w *responseWriter) WriteHeader(code int) {
	w.headerWritten = true
	inner := w.ResponseWriter
	inner.WriteHeader(code)
}
// Flush implements http.Flusher.Flush. It always marks the header as
// written, but only flushes if the wrapped writer supports it.
func (w *responseWriter) Flush() {
	w.headerWritten = true
	flusher, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	flusher.Flush()
}
// Compile-time check that *responseWriter satisfies http.Flusher.
var _ http.Flusher = new(responseWriter)
// WriteError returns a function that can be used to write an error to a ResponseWriter
// and set the HTTP status code. The errToResp parameter is used to determine
// the actual error value and status to write.
//
// If the error value cannot be marshaled as JSON, the status code is
// still written (with an empty body). Without that, the client would get
// an implicit 200 OK for a failed request.
func WriteError(errToResp ErrorToResponse) func(w http.ResponseWriter, err error) {
	return func(w http.ResponseWriter, err error) {
		status, resp := errToResp(err)
		if writeErr := WriteJSON(w, status, resp); writeErr != nil {
			// WriteJSON fails only on marshaling, before it has written
			// anything, so the status can still be sent here.
			w.WriteHeader(status)
		}
	}
}
// WriteJSON writes the given value to the ResponseWriter
// and sets the HTTP status to the given code. The content-type header
// is set to application/json. If val cannot be marshaled, nothing is
// written and the (masked) marshal error is returned.
func WriteJSON(w http.ResponseWriter, code int, val any) error {
	// TODO consider marshalling directly to w using json.NewEncoder.
	// pro: this will not require a full buffer allocation.
	// con: if there's an error after the first write, it will be lost.
	data, err := json.Marshal(val)
	if err != nil {
		// TODO(rog) log an error if this fails and lose the
		// error return, because most callers will need
		// to do that anyway.
		return errors.Mask(err)
	}
	header := w.Header()
	header.Set("content-type", "application/json")
	w.WriteHeader(code)
	// A body write error is deliberately discarded: WriteHeader has
	// already committed the status, so there is nothing left to report.
	_, _ = w.Write(data)
	return nil
}
// JSONHandler is like http.Handler except that it returns a
// body (to be converted to JSON) and an error.
// The Header parameter can be used to set
// custom header on the response. HandleJSON writes a successful body
// with status 200 OK.
type JSONHandler func(http.Header, *http.Request) (any, error)
// HandleJSON returns a function that can be used to convert a JSONHandler
// into an http.Handler.
//
// If the JSONHandler returns a nil error, its value is marshaled as JSON
// and written with status 200 OK. If it returns a non-nil error, the
// given errToResp parameter converts that error to the response status
// and JSON body. If errToResp returns a nil body, the response body is
// the JSON value null. A failure to marshal the handler's value is also
// passed to errToResp, because nothing has been written at that point.
func HandleJSON(errToResp ErrorToResponse) func(handle JSONHandler) http.Handler {
	handleErrors := HandleErrors(errToResp)
	return func(handle JSONHandler) http.Handler {
		f := func(w http.ResponseWriter, req *http.Request) error {
			val, err := handle(w.Header(), req)
			if err != nil {
				return errors.Trace(err)
			}
			return WriteJSON(w, http.StatusOK, val)
		}
		return handleErrors(f)
	}
}
================================================
FILE: jsonhttp/jsonhttp_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package jsonhttp_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"github.com/juju/errors"
"github.com/juju/utils/v4/jsonhttp"
gc "gopkg.in/check.v1"
)
// suite holds the jsonhttp tests; it needs no per-test state.
type suite struct{}
// Register the suite with gocheck.
var _ = gc.Suite(new(suite))
// WriteJSON sets the status, the body and the JSON content type.
func (*suite) TestWriteJSON(c *gc.C) {
	type Number struct {
		N int
	}
	rec := httptest.NewRecorder()
	err := jsonhttp.WriteJSON(rec, http.StatusTeapot, Number{N: 1234})
	c.Assert(err, gc.IsNil)
	c.Assert(rec.Code, gc.Equals, http.StatusTeapot)
	c.Assert(rec.Body.String(), gc.Equals, `{"N":1234}`)
	c.Assert(rec.Header().Get("content-type"), gc.Equals, "application/json")
}
// Sentinel errors that errorToResponse maps to distinct responses.
var (
	errUnauth = errors.New("unauth")
	errBadReq = errors.New("bad request")
	errOther = errors.New("other")
	// errNil makes errorToResponse return a nil body.
	errNil = errors.New("nil result")
)
// errorResponse is the JSON error body produced by errorToResponse.
type errorResponse struct {
	Message string
}
func errorToResponse(err error) (int, any) {
resp := &errorResponse{
Message: err.Error(),
}
status := http.StatusInternalServerError
switch errors.Cause(err) {
case errUnauth:
status = http.StatusUnauthorized
case errBadReq:
status = http.StatusBadRequest
case errNil:
return status, nil
}
return status, &resp
}
// writeErrorTests pairs each sentinel error with the status and decoded
// body that WriteError(errorToResponse) should produce. A nil expectResp
// corresponds to a JSON null body.
var writeErrorTests = []struct {
	err          error
	expectStatus int
	expectResp   *errorResponse
}{{
	err:          errUnauth,
	expectStatus: http.StatusUnauthorized,
	expectResp: &errorResponse{
		Message: errUnauth.Error(),
	},
}, {
	err:          errBadReq,
	expectStatus: http.StatusBadRequest,
	expectResp: &errorResponse{
		Message: errBadReq.Error(),
	},
}, {
	err:          errOther,
	expectStatus: http.StatusInternalServerError,
	expectResp: &errorResponse{
		Message: errOther.Error(),
	},
}, {
	err:          errNil,
	expectStatus: http.StatusInternalServerError,
}}
// WriteError writes the status and body chosen by errorToResponse.
func (s *suite) TestWriteError(c *gc.C) {
	writeError := jsonhttp.WriteError(errorToResponse)
	for i := range writeErrorTests {
		test := &writeErrorTests[i]
		c.Logf("%d: %s", i, test.err)
		rec := httptest.NewRecorder()
		writeError(rec, test.err)
		c.Assert(parseErrorResponse(c, rec.Body.Bytes()), gc.DeepEquals, test.expectResp)
		c.Assert(rec.Code, gc.Equals, test.expectStatus)
	}
}
// parseErrorResponse decodes body into an *errorResponse. A JSON null
// body decodes to nil.
func parseErrorResponse(c *gc.C, body []byte) *errorResponse {
	var resp *errorResponse
	c.Assert(json.Unmarshal(body, &resp), gc.IsNil)
	return resp
}
// HandleErrors converts handler errors into responses and leaves
// successful output alone.
func (s *suite) TestHandleErrors(c *gc.C) {
	handleErrors := jsonhttp.HandleErrors(errorToResponse)

	// A returned error is turned into a response via errorToResponse.
	failing := handleErrors(func(http.ResponseWriter, *http.Request) error {
		return errUnauth
	})
	rec := httptest.NewRecorder()
	failing.ServeHTTP(rec, new(http.Request))
	c.Assert(rec.Code, gc.Equals, http.StatusUnauthorized)
	c.Assert(parseErrorResponse(c, rec.Body.Bytes()), gc.DeepEquals, &errorResponse{
		Message: errUnauth.Error(),
	})

	// A nil error leaves the handler's own output untouched.
	succeeding := handleErrors(func(w http.ResponseWriter, _ *http.Request) error {
		w.WriteHeader(http.StatusCreated)
		w.Write([]byte("something"))
		return nil
	})
	rec = httptest.NewRecorder()
	succeeding.ServeHTTP(rec, new(http.Request))
	c.Assert(rec.Code, gc.Equals, http.StatusCreated)
	c.Assert(rec.Body.String(), gc.Equals, "something")
}
// handleErrorsWithErrorAfterWriteHeaderTests lists the ways a handler can
// commit the response header (write, write header, flush) before it
// returns an error.
var handleErrorsWithErrorAfterWriteHeaderTests = []struct {
	about            string
	causeWriteHeader func(w http.ResponseWriter)
}{{
	about: "write",
	causeWriteHeader: func(w http.ResponseWriter) {
		w.Write([]byte(""))
	},
}, {
	about: "write header",
	causeWriteHeader: func(w http.ResponseWriter) {
		w.WriteHeader(http.StatusOK)
	},
}, {
	about: "flush",
	causeWriteHeader: func(w http.ResponseWriter) {
		w.(http.Flusher).Flush()
	},
}}
// Once the header has been committed, a later handler error must not
// append an error body or change the status.
func (s *suite) TestHandleErrorsWithErrorAfterWriteHeader(c *gc.C) {
	handleErrors := jsonhttp.HandleErrors(errorToResponse)
	for i := range handleErrorsWithErrorAfterWriteHeaderTests {
		test := handleErrorsWithErrorAfterWriteHeaderTests[i]
		c.Logf("test %d: %s", i, test.about)
		handler := handleErrors(func(w http.ResponseWriter, _ *http.Request) error {
			test.causeWriteHeader(w)
			return errors.New("unexpected")
		})
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, new(http.Request))
		c.Assert(rec.Code, gc.Equals, http.StatusOK)
		c.Assert(rec.Body.String(), gc.Equals, "")
	}
}
// HandleJSON writes errors via errorToResponse. Successful values are
// written as JSON, and header changes made by the handler are kept.
func (s *suite) TestHandleJSON(c *gc.C) {
	handleJSON := jsonhttp.HandleJSON(errorToResponse)

	// An error from the handler becomes an error response.
	failing := handleJSON(func(http.Header, *http.Request) (any, error) {
		return nil, errUnauth
	})
	rec := httptest.NewRecorder()
	failing.ServeHTTP(rec, new(http.Request))
	c.Assert(parseErrorResponse(c, rec.Body.Bytes()), gc.DeepEquals, &errorResponse{
		Message: errUnauth.Error(),
	})
	c.Assert(rec.Code, gc.Equals, http.StatusUnauthorized)

	// A returned value is JSON-encoded, with the handler's headers.
	succeeding := handleJSON(func(h http.Header, _ *http.Request) (any, error) {
		h.Set("Some-Header", "value")
		return "something", nil
	})
	rec = httptest.NewRecorder()
	succeeding.ServeHTTP(rec, new(http.Request))
	c.Assert(rec.Code, gc.Equals, http.StatusOK)
	c.Assert(rec.Body.String(), gc.Equals, `"something"`)
	c.Assert(rec.Header().Get("Some-Header"), gc.Equals, "value")
}
================================================
FILE: jsonhttp/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package jsonhttp_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hooks the gocheck suites of this package into "go test".
func TestPackage(t *testing.T) { gc.TestingT(t) }
================================================
FILE: keyvalues/keyvalues.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package keyvalues implements a set of functions for parsing key=value data,
// usually passed in as command-line parameters to juju subcommands, e.g.
// juju-set mongodb logging=true
package keyvalues
import (
"fmt"
"strings"
)
// DuplicateError signals that a duplicate key was encountered while parsing
// the input into a map. Its string value is the full error message.
type DuplicateError string
// Error implements the error interface by returning the message itself.
func (e DuplicateError) Error() string {
	msg := string(e)
	return msg
}
// Parse parses the supplied string slice into a map mapping
// keys to values. Duplicate keys cause an error to be returned.
//
// Empty strings in src are skipped. Keys and values are trimmed of
// surrounding whitespace, and only the first "=" separates them. An
// empty key is always an error; an empty value is an error unless
// allowEmptyValues is set. Duplicates are reported as a DuplicateError.
// The result is non-nil (possibly empty) on success.
func Parse(src []string, allowEmptyValues bool) (map[string]string, error) {
	parsed := make(map[string]string)
	for _, item := range src {
		if item == "" {
			continue
		}
		rawKey, rawValue, found := strings.Cut(item, "=")
		if !found {
			return nil, fmt.Errorf(`expected "key=value", got %q`, item)
		}
		key, value := strings.TrimSpace(rawKey), strings.TrimSpace(rawValue)
		emptyValueRejected := !allowEmptyValues && value == ""
		if key == "" || emptyValueRejected {
			return nil, fmt.Errorf(`expected "key=value", got "%s=%s"`, key, value)
		}
		if _, dup := parsed[key]; dup {
			return nil, DuplicateError(fmt.Sprintf("key %q specified more than once", key))
		}
		parsed[key] = value
	}
	return parsed, nil
}
================================================
FILE: keyvalues/keyvalues_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package keyvalues_test
import (
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/keyvalues"
)
// keyValuesSuite groups the gocheck tests for the keyvalues package.
type keyValuesSuite struct{}

var _ = gc.Suite(new(keyValuesSuite))
// testCases drives TestMapParsing. Fields left out take their zero value:
// allowEmptyVal false, a nil output map and no expected error. error is a
// gocheck ErrorMatches pattern.
var testCases = []struct {
	about         string
	input         []string
	allowEmptyVal bool
	output        map[string]string
	error         string
}{{
	about:  "simple test case",
	input:  []string{"key=value"},
	output: map[string]string{"key": "value"},
}, {
	about:  "empty list",
	input:  []string{},
	output: map[string]string{},
}, {
	about:  "nil list",
	input:  nil,
	output: map[string]string{},
}, {
	about: "invalid format - missing value",
	input: []string{"key"},
	error: `expected "key=value", got "key"`,
}, {
	// Distinct description so log output identifies this case.
	about: "invalid format - empty value",
	input: []string{"key="},
	error: `expected "key=value", got "key="`,
}, {
	about: "invalid format - missing key",
	input: []string{"=value"},
	error: `expected "key=value", got "=value"`,
}, {
	about: "invalid format",
	input: []string{"="},
	error: `expected "key=value", got "="`,
}, {
	about:         "invalid format, allowing empty",
	input:         []string{"="},
	allowEmptyVal: true,
	error:         `expected "key=value", got "="`,
}, {
	about:         "duplicate keys",
	input:         []string{"key=value", "key=value"},
	allowEmptyVal: true,
	error:         `key "key" specified more than once`,
}, {
	about:         "multiple keys",
	input:         []string{"key=value", "key2=value", "key3=value"},
	allowEmptyVal: true,
	output:        map[string]string{"key": "value", "key2": "value", "key3": "value"},
}, {
	about:         "empty value",
	input:         []string{"key="},
	allowEmptyVal: true,
	output:        map[string]string{"key": ""},
}, {
	about:         "whitespace trimmed",
	input:         []string{"key=value\n", "key2\t=\tvalue2"},
	allowEmptyVal: true,
	output:        map[string]string{"key": "value", "key2": "value2"},
}, {
	about:         "whitespace trimming and duplicate keys",
	input:         []string{"key =value", "key\t=\tvalue2"},
	allowEmptyVal: true,
	error:         `key "key" specified more than once`,
}, {
	about: "whitespace trimming and empty value not allowed",
	input: []string{"key= "},
	error: `expected "key=value", got "key="`,
}, {
	about:         "whitespace trimming and empty value",
	input:         []string{"key= "},
	allowEmptyVal: true,
	output:        map[string]string{"key": ""},
}, {
	about:         "whitespace trimming and missing key",
	input:         []string{" =value"},
	allowEmptyVal: true,
	error:         `expected "key=value", got "=value"`,
}, {
	about:         "empty inputs are skipped",
	input:         []string{"key=value", "", "foo=bar"},
	allowEmptyVal: true,
	output:        map[string]string{"key": "value", "foo": "bar"},
}}
// TestMapParsing runs each entry of testCases through keyvalues.Parse and
// checks both the returned map and the error.
func (keyValuesSuite) TestMapParsing(c *gc.C) {
	for idx, tc := range testCases {
		c.Logf("test %d: %s", idx, tc.about)
		got, err := keyvalues.Parse(tc.input, tc.allowEmptyVal)
		c.Check(got, gc.DeepEquals, tc.output)
		if tc.error != "" {
			c.Check(err, gc.ErrorMatches, tc.error)
			continue
		}
		c.Check(err, gc.IsNil)
	}
}
================================================
FILE: keyvalues/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package keyvalues_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hooks every gocheck suite registered in this package into
// the standard "go test" runner.
func TestPackage(tt *testing.T) {
	gc.TestingT(tt)
}
================================================
FILE: limiter.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"fmt"
"math/rand"
"time"
"github.com/juju/clock"
)
// empty is the zero-size token stored in a limiter's channel; each value
// buffered in the channel stands for one acquired unit.
type empty struct{}

// limiter implements Limiter with a buffered channel acting as a counting
// semaphore, plus an optional random pause before Acquire.
type limiter struct {
	wait     chan empty    // buffer capacity is the maximum number of units
	minPause time.Duration // lower bound of the pre-Acquire pause
	maxPause time.Duration // upper bound of the pre-Acquire pause
	clock    clock.Clock   // clock used to wait out the pause
}
// Limiter models a resource with a fixed number of units, such as a
// counting semaphore.
type Limiter interface {
	// Acquire tries to take one unit. It does not wait for a unit to be
	// freed: it reports false when all units are taken, and more only
	// become available after someone calls Release.
	Acquire() bool

	// AcquireWait takes one unit, blocking until one is available.
	AcquireWait()

	// Release hands back one unit. Calling it when no unit is currently
	// acquired is an error.
	Release() error
}
// NewLimiter returns a Limiter that allows at most maxAllowed concurrent
// acquisitions and never pauses before Acquire.
func NewLimiter(maxAllowed int) Limiter {
	var noPause time.Duration
	return NewLimiterWithPause(maxAllowed, noPause, noPause, nil)
}
// NewLimiterWithPause returns a Limiter that allows at most maxAllowed
// concurrent acquisitions. When both minPause and maxPause are positive,
// every Acquire first waits a random duration between the two, measured
// on clk. A nil clk selects the wall clock.
//
// NOTE: each call reseeds the global math/rand source.
func NewLimiterWithPause(maxAllowed int, minPause, maxPause time.Duration, clk clock.Clock) Limiter {
	rand.Seed(time.Now().UTC().UnixNano())
	useClock := clk
	if useClock == nil {
		useClock = clock.WallClock
	}
	l := limiter{
		wait:     make(chan empty, maxAllowed),
		minPause: minPause,
		maxPause: maxPause,
		clock:    useClock,
	}
	return l
}
// Acquire attempts to take one unit without waiting for one to be freed.
// On success it returns true and the caller must eventually call Release.
// When the limiter is full it returns false and the caller must not call
// Release. Any pause configured via NewLimiterWithPause happens before the
// attempt, which throttles callers such as incoming connections.
func (l limiter) Acquire() bool {
	l.pause()
	select {
	case l.wait <- empty{}:
		return true
	default:
	}
	return false
}
// AcquireWait takes one unit, blocking until one is available. Unlike
// Acquire, it does not apply the configured pause.
func (l limiter) AcquireWait() {
	l.wait <- empty{}
}
// Release hands one unit back to the pool. It returns an error, without
// blocking, when nothing is currently acquired.
func (l limiter) Release() error {
	select {
	case <-l.wait:
	default:
		return fmt.Errorf("Release without an associated Acquire")
	}
	return nil
}
// pause blocks, on l.clock, before an Acquire attempt. It does nothing
// unless both minPause and maxPause are positive.
//
// The pause is minPause plus a random whole number of milliseconds below
// (maxPause - minPause). rand.Intn panics on a non-positive argument, so
// when that range is under one millisecond (including maxPause <= minPause)
// the pause is exactly minPause.
func (l limiter) pause() {
	if l.minPause <= 0 || l.maxPause <= 0 {
		return
	}
	pauseTime := l.minPause
	if pauseRange := int((l.maxPause - l.minPause) / time.Millisecond); pauseRange > 0 {
		pauseTime += time.Duration(rand.Intn(pauseRange)) * time.Millisecond
	}
	<-l.clock.After(pauseTime)
}
================================================
FILE: limiter_test.go
================================================
// Copyright 2011, 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"fmt"
"time"
"github.com/juju/clock/testclock"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// longWait bounds how long a test waits for something that should happen.
const longWait = 10 * time.Second

// limiterSuite holds the Limiter tests.
type limiterSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(new(limiterSuite))
// TestAcquireUntilFull checks that Acquire succeeds exactly maxAllowed
// times and then fails.
func (*limiterSuite) TestAcquireUntilFull(c *gc.C) {
	lim := utils.NewLimiter(2)
	for i := 0; i < 2; i++ {
		c.Check(lim.Acquire(), jc.IsTrue)
	}
	c.Check(lim.Acquire(), jc.IsFalse)
}
// TestBadRelease checks that releasing without a prior Acquire fails.
func (*limiterSuite) TestBadRelease(c *gc.C) {
	lim := utils.NewLimiter(2)
	err := lim.Release()
	c.Check(err, gc.ErrorMatches, "Release without an associated Acquire")
}
// TestAcquireAndRelease interleaves Acquire and Release, checking that a
// released unit can be reacquired and that over-releasing fails.
func (*limiterSuite) TestAcquireAndRelease(c *gc.C) {
	lim := utils.NewLimiter(2)
	// Fill the limiter.
	c.Check(lim.Acquire(), jc.IsTrue)
	c.Check(lim.Acquire(), jc.IsTrue)
	c.Check(lim.Acquire(), jc.IsFalse)
	// Free one unit and take it again.
	c.Check(lim.Release(), gc.IsNil)
	c.Check(lim.Acquire(), jc.IsTrue)
	// Drain completely, then release once too often.
	c.Check(lim.Release(), gc.IsNil)
	c.Check(lim.Release(), gc.IsNil)
	c.Check(lim.Release(), gc.ErrorMatches, "Release without an associated Acquire")
}
// TestAcquireWaitBlocksUntilRelease checks that AcquireWait blocks while
// the limiter is full and completes once another caller releases a unit.
func (*limiterSuite) TestAcquireWaitBlocksUntilRelease(c *gc.C) {
	l := utils.NewLimiter(2)
	// calls is written only by the goroutine below and read by the test
	// after receiving from done, so the channel operations order accesses.
	calls := make([]string, 0, 10)
	start := make(chan bool)
	waiting := make(chan bool)
	done := make(chan bool)
	go func() {
		<-start
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		waiting <- true
		l.AcquireWait()
		calls = append(calls, "waited")
		calls = append(calls, fmt.Sprintf("%v", l.Acquire()))
		done <- true
	}()
	// Start the routine, and wait for it to get to the first checkpoint.
	start <- true
	select {
	case <-waiting:
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for 'waiting' to trigger")
	}
	c.Check(l.Acquire(), jc.IsFalse)
	// Free a unit so the goroutine's AcquireWait can proceed; this must
	// succeed because both units are held.
	c.Check(l.Release(), gc.IsNil)
	select {
	case <-done:
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for 'done' to trigger")
	}
	c.Check(calls, gc.DeepEquals, []string{"true", "true", "false", "waited", "false"})
}
// TestAcquirePauses checks that Acquire waits out the configured pause,
// measured on the supplied test clock, before taking a unit.
func (*limiterSuite) TestAcquirePauses(c *gc.C) {
	clk := testclock.NewClock(time.Now())
	l := utils.NewLimiterWithPause(2, 10*time.Millisecond, 20*time.Millisecond, clk)
	acquired := make(chan bool, 1)
	start := make(chan bool)
	go func() {
		<-start
		defer l.Release()
		acquired <- l.Acquire()
	}()
	start <- true
	// Minimum pause time not exceeded, acquire should not happen.
	// WaitAdvance moves the clock only after Acquire has registered its
	// timer. A plain Advance could run first, and the timer would then be
	// measured from the already-advanced time, making this test flaky.
	err := clk.WaitAdvance(9*time.Millisecond, longWait, 1)
	c.Assert(err, jc.ErrorIsNil)
	select {
	case <-acquired:
		c.Fail()
	case <-time.After(50 * time.Millisecond):
	}
	// 20ms have now elapsed in total, covering the longest possible pause.
	clk.Advance(11 * time.Millisecond)
	select {
	case ok := <-acquired:
		c.Check(ok, jc.IsTrue)
	case <-time.After(50 * time.Millisecond):
		c.Fatal("acquire failed")
	}
}
================================================
FILE: multireader.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"io"
"sort"
"github.com/juju/errors"
)
// SizeReaderAt combines io.ReaderAt with a Size method.
type SizeReaderAt interface {
// Size returns the size of the data readable
// from the reader.
Size() int64
io.ReaderAt
}
// NewMultiReaderAt concatenates parts into a single SizeReaderAt, in the
// same way io.MultiReader concatenates readers, but with random access.
// The Size of the result is the sum of the parts' sizes.
//
// Note: this implementation was taken from a talk given
// by Brad Fitzpatrick at OSCON 2013.
//
// http://talks.golang.org/2013/oscon-dl.slide#49
// https://github.com/golang/talks/blob/master/2013/oscon-dl/server-compose.go
func NewMultiReaderAt(parts ...SizeReaderAt) SizeReaderAt {
	sources := make([]offsetAndSource, len(parts))
	var total int64
	for i, part := range parts {
		sources[i] = offsetAndSource{off: total, SizeReaderAt: part}
		total += part.Size()
	}
	return &multiReaderAt{parts: sources, size: total}
}
// offsetAndSource pairs one part with the position of its first byte in
// the concatenated stream.
type offsetAndSource struct {
	off int64 // offset of this part within the whole
	SizeReaderAt
}

// multiReaderAt is the SizeReaderAt built by NewMultiReaderAt.
type multiReaderAt struct {
	parts []offsetAndSource // in stream order, offsets non-decreasing
	size  int64             // sum of all part sizes
}
// Size reports the combined size of all parts.
func (m *multiReaderAt) Size() int64 {
	total := m.size
	return total
}
// ReadAt implements io.ReaderAt. It reads len(p) bytes starting at offset
// off of the concatenated parts, drawing from as many parts as needed.
// It returns the number of bytes placed in p. If the stream ends before
// p is full the error is io.ErrUnexpectedEOF; if a part fails, that
// part's error is returned.
//
// Empty parts are skipped without calling their ReadAt, and a part that
// returns io.EOF together with a complete read is not treated as a
// failure. The io.ReaderAt contract permits both behaviours, and neither
// means that the concatenated stream has ended.
func (m *multiReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
	wantN := len(p)

	// Skip past every part lying wholly before the requested offset.
	skipParts := sort.Search(len(m.parts), func(i int) bool {
		// Reports whether parts[i] will contribute any bytes to our output.
		part := m.parts[i]
		return part.off+part.Size() > off
	})
	parts := m.parts[skipParts:]

	// How far to skip into the first remaining part.
	needSkip := off
	if len(parts) > 0 {
		needSkip -= parts[0].off
	}
	for len(parts) > 0 && len(p) > 0 {
		part := parts[0]
		remaining := part.Size() - needSkip
		if remaining <= 0 {
			// Nothing to read here (an empty part). Move on without calling
			// ReadAt, which may legitimately report io.EOF.
			parts = parts[1:]
			needSkip = 0
			continue
		}
		readP := p
		if int64(len(readP)) > remaining {
			readP = readP[:remaining]
		}
		pn, err0 := part.ReadAt(readP, needSkip)
		// Count the bytes even when an error follows: they are already in p.
		n += pn
		p = p[pn:]
		if pn < len(readP) {
			// io.ReaderAt requires an error on a short read. Guard against a
			// misbehaving part rather than loop or re-read from its start.
			if err0 == nil {
				err0 = io.ErrUnexpectedEOF
			}
			return n, err0
		}
		// The read was complete, so io.EOF only marks the end of this part.
		if err0 != nil && err0 != io.EOF {
			return n, err0
		}
		if int64(pn) == remaining {
			parts = parts[1:]
		}
		needSkip = 0
	}
	if n != wantN {
		err = io.ErrUnexpectedEOF
	}
	return n, err
}
// NewMultiReaderSeeker returns an io.ReadSeeker that presents readers,
// one after another, as a single seekable stream. Each reader's size is
// found by seeking it to its end, and reads seek each reader as needed.
// NewMultiReaderSeeker panics if any reader fails to seek to its end.
func NewMultiReaderSeeker(readers ...io.ReadSeeker) io.ReadSeeker {
	parts := make([]SizeReaderAt, 0, len(readers))
	for _, rs := range readers {
		part, err := newSizeReaderAt(rs)
		if err != nil {
			panic(err)
		}
		parts = append(parts, part)
	}
	return &readSeeker{r: NewMultiReaderAt(parts...)}
}
// newSizeReaderAt wraps r as a SizeReaderAt. It measures r by seeking to
// its end, which leaves r positioned there.
//
// The result does not meet the io.ReaderAt promise of concurrency safety,
// because ReadAt seeks and reads the shared r. That is acceptable: it is
// only used behind an io.ReadSeeker, which need not be thread-safe.
func newSizeReaderAt(r io.ReadSeeker) (SizeReaderAt, error) {
	end, err := r.Seek(0, io.SeekEnd)
	if err != nil {
		return nil, err
	}
	sr := &sizeReaderAt{r: r, size: end, off: end}
	return sr, nil
}
// sizeReaderAt adapts an io.ReadSeeker to a SizeReaderAt.
type sizeReaderAt struct {
	r    io.ReadSeeker
	size int64 // total size, measured by seeking to the end
	off  int64 // tracked position of r, used to skip redundant seeks
}
// ReadAt implements SizeReaderAt.ReadAt. It seeks only when off differs
// from the tracked position. It then fills buf with io.ReadFull and
// returns that call's count and error.
func (r *sizeReaderAt) ReadAt(buf []byte, off int64) (n int, err error) {
	if off != r.off {
		if _, seekErr := r.r.Seek(off, io.SeekStart); seekErr != nil {
			return 0, seekErr
		}
		r.off = off
	}
	n, err = io.ReadFull(r.r, buf)
	r.off += int64(n)
	return n, err
}
// Size implements SizeReaderAt.Size, reporting the size measured when the
// reader was wrapped.
func (r *sizeReaderAt) Size() int64 {
	measured := r.size
	return measured
}
// readSeeker adapts a SizeReaderAt to an io.ReadSeeker by keeping track
// of a current offset.
type readSeeker struct {
	r   SizeReaderAt
	off int64 // position of the next Read; may lie beyond r.Size()
}
// Seek implements io.Seeker. whence must be io.SeekStart, io.SeekCurrent
// or io.SeekEnd. Any other value is rejected instead of being silently
// treated as io.SeekStart. A resulting negative position is also an error,
// and on any error the current offset is left unchanged. Seeking past the
// end is allowed; later Reads then report io.EOF.
func (r *readSeeker) Seek(off int64, whence int) (int64, error) {
	switch whence {
	case io.SeekStart:
	case io.SeekCurrent:
		off += r.off
	case io.SeekEnd:
		off += r.r.Size()
	default:
		return 0, errors.New("invalid whence")
	}
	if off < 0 {
		return 0, errors.New("negative position")
	}
	r.off = off
	return off, nil
}
// Read implements io.Reader. It reads from the current offset, advances
// the offset by the bytes read, and reports running out of data
// (io.ErrUnexpectedEOF from the underlying reader) as io.EOF.
func (r *readSeeker) Read(buf []byte) (int, error) {
	n, err := r.r.ReadAt(buf, r.off)
	r.off += int64(n)
	if err != io.ErrUnexpectedEOF {
		return n, err
	}
	return n, io.EOF
}
================================================
FILE: multireader_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"io"
"io/ioutil"
"strings"
"testing/iotest"
jc "github.com/juju/testing/checkers"
"github.com/juju/utils/v4"
gc "gopkg.in/check.v1"
)
// multiReaderSeekerSuite tests NewMultiReaderSeeker.
type multiReaderSeekerSuite struct{}

var _ = gc.Suite(new(multiReaderSeekerSuite))
// TestSequentialRead checks that reading straight through yields the
// concatenation of every part.
func (*multiReaderSeekerSuite) TestSequentialRead(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	data, err := ioutil.ReadAll(newMultiStringReader(parts))
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, strings.Join(parts, ""))
}
// TestSeekStart seeks a fresh reader to every absolute offset, including
// the end, and checks that the rest of the data is read correctly.
func (*multiReaderSeekerSuite) TestSeekStart(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	all := strings.Join(parts, "")
	total := int64(len(all))
	for off := int64(0); off <= total; off++ {
		c.Logf("-- offset %d", off)
		r := newMultiStringReader(parts)
		pos, err := r.Seek(off, 0)
		c.Assert(err, gc.IsNil)
		c.Assert(pos, gc.Equals, off)
		rest, err := ioutil.ReadAll(r)
		c.Assert(err, gc.IsNil)
		c.Assert(string(rest), gc.Equals, all[off:])
	}
}
// TestSeekEnd seeks a fresh reader backwards from the end by every
// possible distance and checks the remaining data.
func (*multiReaderSeekerSuite) TestSeekEnd(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	all := strings.Join(parts, "")
	total := int64(len(all))
	for back := int64(0); back <= total; back++ {
		r := newMultiStringReader(parts)
		want := total - back
		pos, err := r.Seek(-back, 2)
		c.Assert(err, gc.IsNil)
		c.Assert(pos, gc.Equals, want)
		rest, err := ioutil.ReadAll(r)
		c.Assert(err, gc.IsNil)
		c.Assert(string(rest), gc.Equals, all[want:])
	}
}
// TestSeekCur checks relative seeks. For every pair of positions it seeks
// absolutely to the first, then relatively to the second, and reads the
// rest of the data.
func (*multiReaderSeekerSuite) TestSeekCur(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	all := strings.Join(parts, "")
	total := int64(len(all))
	for from := int64(0); from <= total; from++ {
		for to := int64(0); to <= total; to++ {
			r := newMultiStringReader(parts)
			pos, err := r.Seek(from, 0)
			c.Assert(err, jc.ErrorIsNil)
			c.Assert(pos, gc.Equals, from)
			pos, err = r.Seek(to-from, 1)
			c.Assert(err, gc.IsNil)
			c.Assert(pos, gc.Equals, to)
			rest, err := ioutil.ReadAll(r)
			c.Assert(err, gc.IsNil)
			c.Assert(string(rest), gc.Equals, all[to:])
		}
	}
}
// TestSeekAfterRead reads everything one byte at a time, then seeks back
// from the end and reads the tail again.
func (*multiReaderSeekerSuite) TestSeekAfterRead(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	all := strings.Join(parts, "")
	r := newMultiStringReader(parts)

	data, err := ioutil.ReadAll(iotest.OneByteReader(r))
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, all)

	pos, err := r.Seek(-8, 2)
	c.Assert(err, gc.IsNil)
	c.Assert(pos, gc.Equals, int64(len(all)-8))
	data, err = ioutil.ReadAll(r)
	c.Assert(err, gc.IsNil)
	c.Assert(string(data), gc.Equals, "hreefour")
}
// TestSeekNegative checks that each whence rejects a seek resulting in a
// negative position, and that the reader can still be rewound afterwards.
func (*multiReaderSeekerSuite) TestSeekNegative(c *gc.C) {
	r := newMultiStringReader([]string{"one", "two"})
	for _, bad := range []struct {
		off    int64
		whence int
	}{{-1, 0}, {-7, 2}, {-1, 1}} {
		_, err := r.Seek(bad.off, bad.whence)
		c.Assert(err, gc.ErrorMatches, "negative position")
		pos, err := r.Seek(0, 0)
		c.Assert(err, gc.IsNil)
		c.Assert(pos, gc.Equals, int64(0))
	}
}
// TestSeekPastEnd checks that seeking beyond the end is allowed, that
// reads there return io.EOF, and that seeking back into range recovers.
func (*multiReaderSeekerSuite) TestSeekPastEnd(c *gc.C) {
	r := newMultiStringReader([]string{"one", "two"})
	oneByte := make([]byte, 1)

	pos, err := r.Seek(100, 0)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(pos, gc.Equals, int64(100))
	got, err := r.Read(oneByte)
	c.Assert(got, gc.Equals, 0)
	c.Assert(err, gc.Equals, io.EOF)

	pos, err = r.Seek(-5, 1)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(pos, gc.Equals, int64(95))
	got, err = r.Read(oneByte)
	c.Assert(got, gc.Equals, 0)
	c.Assert(err, gc.Equals, io.EOF)

	pos, err = r.Seek(-94, 1)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(pos, gc.Equals, int64(1))
	rest, err := ioutil.ReadAll(r)
	c.Assert(err, gc.IsNil)
	c.Assert(string(rest), gc.Equals, "netwo")
}
// multiReaderAtSuite tests NewMultiReaderAt.
type multiReaderAtSuite struct{}

var _ = gc.Suite(new(multiReaderAtSuite))
// TestReadComplete reads the whole concatenation in a single ReadAt.
func (*multiReaderAtSuite) TestReadComplete(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	all := strings.Join(parts, "")
	buf := make([]byte, len(all))
	n, err := newMultistringReaderAt(parts).ReadAt(buf, 0)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, len(buf))
	c.Assert(string(buf), gc.Equals, all)
}
// TestReadPartial reads a range spanning several parts, starting and
// ending in the middle of a part.
func (*multiReaderAtSuite) TestReadPartial(c *gc.C) {
	parts := []string{"one", "two", "three", "four"}
	total := len(strings.Join(parts, ""))
	buf := make([]byte, total-4)
	n, err := newMultistringReaderAt(parts).ReadAt(buf, 2)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(n, gc.Equals, len(buf))
	c.Assert(string(buf), gc.Equals, "etwothreefo")
}
// newMultiStringReader returns utils.NewMultiReaderSeeker over one
// strings.Reader per element of parts.
func newMultiStringReader(parts []string) io.ReadSeeker {
	var readers []io.ReadSeeker
	for _, p := range parts {
		readers = append(readers, strings.NewReader(p))
	}
	return utils.NewMultiReaderSeeker(readers...)
}
// stringReader gives a *strings.Reader an explicit Size method so that it
// satisfies utils.SizeReaderAt (strings.Reader only gained Size in Go 1.5).
// Size reports the unread length. That equals the full length here because
// these readers are only accessed through ReadAt, which never consumes.
type stringReader struct {
	*strings.Reader
}

func (r stringReader) Size() int64 {
	return int64(r.Reader.Len())
}
// newMultistringReaderAt returns utils.NewMultiReaderAt over one
// stringReader per element of parts.
func newMultistringReaderAt(parts []string) io.ReaderAt {
	readers := make([]utils.SizeReaderAt, 0, len(parts))
	for _, p := range parts {
		readers = append(readers, stringReader{strings.NewReader(p)})
	}
	return utils.NewMultiReaderAt(readers...)
}
================================================
FILE: naturalsort.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"fmt"
"sort"
"strconv"
"unicode"
)
// SortStringsNaturally sorts s in place into natural order, where runs of
// digits compare by numeric value (so "a2" comes before "a10"). It returns
// s for convenience.
func SortStringsNaturally(s []string) []string {
	var order sort.Interface = naturally(s)
	sort.Sort(order)
	return s
}
// naturally implements sort.Interface over strings in natural order; see
// its Less method for the comparison rules.
type naturally []string

// Len reports the number of strings.
func (n naturally) Len() int { return len(n) }

// Swap exchanges the strings at indices a and b.
func (n naturally) Swap(a, b int) { n[a], n[b] = n[b], n[a] }
// Less reports whether n[a] sorts before n[b]. It consumes both strings
// one component at a time, where a component is a non-numeric prefix and
// the number formed by the digits that follow it (see splitAtNumber).
// Prefixes are compared as plain strings, then numbers numerically; when
// both match, the comparison continues with the remainders. If n[a] runs
// out of input while n[b] still has some, n[a] sorts first.
func (n naturally) Less(a, b int) bool {
	left, right := n[a], n[b]
	for right != "" {
		// right still has input, so an exhausted left sorts before it.
		if left == "" {
			return true
		}
		lPrefix, lNum, lRest := splitAtNumber(left)
		rPrefix, rNum, rRest := splitAtNumber(right)
		switch {
		case lPrefix != rPrefix:
			return lPrefix < rPrefix
		case lNum != rNum:
			return lNum < rNum
		}
		// Identical so far: carry on with the remainders.
		left, right = lRest, rRest
	}
	// right is exhausted, so left cannot sort before it.
	return false
}
// splitAtNumber splits str at its first digit. It returns the prefix
// before that digit, the integer value of the run of digits starting
// there, and the remainder of the string after the run. If str has no
// digits, it returns str, -1 and "".
//
// Two cases that previously panicked are handled:
//   - A run of digits too large for an int is clamped to the maximum int
//     (the value strconv.Atoi reports alongside ErrRange), so all such
//     numbers compare equal on this component.
//   - unicode.IsDigit also accepts non-ASCII decimal digits (e.g. U+0663)
//     that strconv.Atoi cannot parse. In that case the whole of str is
//     returned as the prefix with number -1, so it compares as text.
func splitAtNumber(str string) (string, int, string) {
	i := indexOfDigit(str)
	if i == -1 {
		// no numbers
		return str, -1, ""
	}
	j := i + indexOfNonDigit(str[i:])
	n, err := strconv.Atoi(str[i:j])
	if err != nil {
		numErr, ok := err.(*strconv.NumError)
		switch {
		case ok && numErr.Err == strconv.ErrRange:
			// n already holds the clamped maximum value.
		case ok && numErr.Err == strconv.ErrSyntax:
			return str, -1, ""
		default:
			panic(fmt.Sprintf("parsing number %v: %v", str[i:j], err)) // should never happen
		}
	}
	return str[:i], n, str[j:]
}
// indexOfDigit returns the byte index of the first rune in str for which
// unicode.IsDigit is true, or -1 if there is none.
func indexOfDigit(str string) int {
	pos := -1
	for i, r := range str {
		if unicode.IsDigit(r) {
			pos = i
			break
		}
	}
	return pos
}
// indexOfNonDigit returns the byte index of the first rune in str that
// unicode.IsDigit rejects, or len(str) if every rune is a digit.
func indexOfNonDigit(str string) int {
	end := len(str)
	for i, r := range str {
		if !unicode.IsDigit(r) {
			end = i
			break
		}
	}
	return end
}
================================================
FILE: naturalsort_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"math/rand"
gc "gopkg.in/check.v1"
"github.com/juju/testing"
"github.com/juju/utils/v4"
)
// naturalSortSuite holds the SortStringsNaturally tests.
type naturalSortSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(new(naturalSortSuite))
// TestEmpty checks that an empty, non-nil slice sorts to itself.
func (s *naturalSortSuite) TestEmpty(c *gc.C) {
	want := []string{}
	checkCorrectSort(c, want)
}
// TestAlpha checks purely alphabetic strings.
func (s *naturalSortSuite) TestAlpha(c *gc.C) {
	want := []string{"abc", "bac", "cba"}
	checkCorrectSort(c, want)
}
// TestNumVsString checks that a number sorts before a letter.
func (s *naturalSortSuite) TestNumVsString(c *gc.C) {
	want := []string{"1", "a"}
	checkCorrectSort(c, want)
}
// TestStringVsStringNum checks that a string sorts before itself with a
// numeric suffix.
func (s *naturalSortSuite) TestStringVsStringNum(c *gc.C) {
	want := []string{"a", "a1"}
	checkCorrectSort(c, want)
}
// TestCommonPrefix checks strings that share a leading component.
func (s *naturalSortSuite) TestCommonPrefix(c *gc.C) {
	want := []string{"a1", "a1a", "a1b", "a2b", "a2c"}
	checkCorrectSort(c, want)
}
// TestDifferentNumberLengths checks that numbers compare by value rather
// than by digit count.
func (s *naturalSortSuite) TestDifferentNumberLengths(c *gc.C) {
	want := []string{"a1a", "a2", "a22a", "a333", "a333a", "a333b"}
	checkCorrectSort(c, want)
}
// TestZeroPadding checks that leading zeros do not affect numeric order.
func (s *naturalSortSuite) TestZeroPadding(c *gc.C) {
	want := []string{"a1", "a002", "a3"}
	checkCorrectSort(c, want)
}
// TestMixed checks a mixture of leading numbers, suffixes and separators.
func (s *naturalSortSuite) TestMixed(c *gc.C) {
	want := []string{"1a", "a1", "a1/1", "a10", "a100"}
	checkCorrectSort(c, want)
}
// TestSeveralNumericParts checks strings containing multiple numeric
// components.
func (s *naturalSortSuite) TestSeveralNumericParts(c *gc.C) {
	want := []string{
		"x",
		"x1",
		"x1-g0",
		"x1-g1",
		"x1-g2",
		"x1-g10",
		"x2",
		"x2-g0",
		"x2-g2",
		"x11-g0",
		"x11-g0-0",
		"x11-g0-1",
		"x11-g0-10",
		"x11-g0-11",
		"x11-g0-20",
		"x11-g0-100",
		"x11-g10-1",
		"x11-g10-10",
		"xx1",
		"xx10",
	}
	checkCorrectSort(c, want)
}
// TestUnitNameLike checks strings shaped like unit names.
func (s *naturalSortSuite) TestUnitNameLike(c *gc.C) {
	want := []string{"a1/1", "a1/2", "a1/7", "a1/11", "a1/100"}
	checkCorrectSort(c, want)
}
// TestMachineIdLike checks strings shaped like machine and container IDs.
func (s *naturalSortSuite) TestMachineIdLike(c *gc.C) {
	want := []string{
		"1",
		"1/lxc/0",
		"1/lxc/1",
		"1/lxc/2",
		"1/lxc/10",
		"1/lxd/0",
		"1/lxd/1",
		"1/lxd/10",
		"2",
		"11",
		"11/lxc/6",
		"11/lxc/60",
		"20",
		"21",
	}
	checkCorrectSort(c, want)
}
// TestIPs checks dotted IPv4-style strings, with and without zero padding.
func (s *naturalSortSuite) TestIPs(c *gc.C) {
	want := []string{
		"1.1.10.122",
		"001.001.010.123",
		"001.002.010.123",
		"100.001.010.123",
		"100.1.10.124",
		"100.2.10.124",
	}
	checkCorrectSort(c, want)
}
// checkCorrectSort verifies that SortStringsNaturally restores expected,
// which must already be in natural order, from a reversed copy and from
// five shuffled copies.
func checkCorrectSort(c *gc.C, expected []string) {
	checkSort(c, expected, reverse)
	const shuffles = 5
	for n := 0; n < shuffles; n++ {
		checkSort(c, expected, shuffle)
	}
}
// checkSort applies xform to a copy of expected, sorts the copy naturally
// and checks that it equals expected. The unsorted input is included in
// any failure message.
func checkSort(c *gc.C, expected []string, xform func([]string)) {
	scrambled := copyStrSlice(expected)
	xform(scrambled)
	before := copyStrSlice(scrambled)
	sorted := utils.SortStringsNaturally(scrambled)
	c.Check(sorted, gc.DeepEquals, expected, gc.Commentf("input was: %#v", before))
}
// copyStrSlice returns an independent, always non-nil copy of in.
func copyStrSlice(in []string) []string {
	return append(make([]string, 0, len(in)), in...)
}
// shuffle permutes a in place with the Fisher-Yates algorithm, using the
// global math/rand source.
// See https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#Modern_method
func shuffle(a []string) {
	for hi := len(a) - 1; hi > 0; hi-- {
		pick := rand.Intn(hi + 1)
		a[hi], a[pick] = a[pick], a[hi]
	}
}
// reverse reverses a in place.
func reverse(a []string) {
	for lo, hi := 0, len(a)-1; lo < hi; lo, hi = lo+1, hi-1 {
		a[lo], a[hi] = a[hi], a[lo]
	}
}
================================================
FILE: network.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"fmt"
"net"
"github.com/juju/loggo/v2"
)
var (
	// logger reports failures to look up network interfaces or their
	// addresses.
	logger = loggo.GetLogger("juju.utils")
)
// GetIPv4Address returns the first IPv4 address in addresses, without its
// prefix length. Each entry's String() must be in CIDR notation, as
// produced by (*net.Interface).Addrs. The first entry that fails to parse
// before a match is found aborts the search with the parse error. If no
// IPv4 address is present the error is "no addresses match".
func GetIPv4Address(addresses []net.Addr) (string, error) {
	for _, a := range addresses {
		ip, _, err := net.ParseCIDR(a.String())
		if err != nil {
			return "", err
		}
		if v4 := ip.To4(); v4 != nil {
			return v4.String(), nil
		}
	}
	return "", fmt.Errorf("no addresses match")
}
// GetIPv6Address returns the first IPv6 address in addresses that is not
// link-local (fe80::/10), without its prefix length. Entries are parsed
// as CIDR strings, as produced by (*net.Interface).Addrs. The first entry
// that fails to parse before a match aborts the search with that error.
// IPv4 entries are skipped. If nothing qualifies the error is
// "no addresses match".
func GetIPv6Address(addresses []net.Addr) (string, error) {
	_, linkLocal, _ := net.ParseCIDR("fe80::/10")
	for _, a := range addresses {
		ip, _, err := net.ParseCIDR(a.String())
		if err != nil {
			return "", err
		}
		isV4 := ip.To4() != nil
		if isV4 || linkLocal.Contains(ip) {
			continue
		}
		return ip.String(), nil
	}
	return "", fmt.Errorf("no addresses match")
}
// GetAddressForInterface returns the first IPv4 address of the network
// interface named interfaceName. Failures to find the interface or list
// its addresses are logged and returned. An interface without an IPv4
// address yields GetIPv4Address's error.
func GetAddressForInterface(interfaceName string) (string, error) {
	ni, lookupErr := net.InterfaceByName(interfaceName)
	if lookupErr != nil {
		logger.Errorf("cannot find network interface %q: %v", interfaceName, lookupErr)
		return "", lookupErr
	}
	addrs, addrsErr := ni.Addrs()
	if addrsErr != nil {
		logger.Errorf("cannot get addresses for network interface %q: %v", interfaceName, addrsErr)
		return "", addrsErr
	}
	return GetIPv4Address(addrs)
}
// GetV4OrV6AddressForInterface returns an address of the network interface
// named interfaceName, preferring its first IPv4 address. If it has none,
// it falls back to the first non-link-local IPv6 address. Failures to find
// the interface or list its addresses are logged and returned.
func GetV4OrV6AddressForInterface(interfaceName string) (string, error) {
	ni, lookupErr := net.InterfaceByName(interfaceName)
	if lookupErr != nil {
		logger.Errorf("cannot find network interface %q: %v", interfaceName, lookupErr)
		return "", lookupErr
	}
	addrs, addrsErr := ni.Addrs()
	if addrsErr != nil {
		logger.Errorf("cannot get addresses for network interface %q: %v", interfaceName, addrsErr)
		return "", addrsErr
	}
	ip, v4Err := GetIPv4Address(addrs)
	if v4Err != nil {
		return GetIPv6Address(addrs)
	}
	return ip, nil
}
================================================
FILE: network_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"net"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// networkSuite holds the address-selection tests.
type networkSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(new(networkSuite))
// fakeAddress is a net.Addr whose String method returns a fixed value,
// letting tests supply arbitrary (including malformed) address strings.
type fakeAddress struct {
	address string
}

// Network implements net.Addr. The functions under test only call String.
func (fake fakeAddress) Network() string { return "ignored" }

// String implements net.Addr.
func (fake fakeAddress) String() string { return fake.address }
// makeAddresses wraps each value in a *fakeAddress so that it can be
// passed where a []net.Addr is expected. No values yields a nil slice.
func makeAddresses(values ...string) []net.Addr {
	var out []net.Addr
	for _, v := range values {
		out = append(out, &fakeAddress{address: v})
	}
	return out
}
// TestGetIPv4Address covers parse failures, no IPv4 match, and selection
// of the IPv4 entry regardless of its position.
func (*networkSuite) TestGetIPv4Address(c *gc.C) {
	type testCase struct {
		addresses   []net.Addr
		expected    string
		errorString string
	}
	cases := []testCase{{
		addresses:   makeAddresses("complete", "nonsense"),
		errorString: "invalid CIDR address: complete",
	}, {
		addresses:   makeAddresses("fe80::90cf:9dff:fe6e:ece/64"),
		errorString: "no addresses match",
	}, {
		addresses: makeAddresses("fe80::90cf:9dff:fe6e:ece/64", "10.0.3.1/24"),
		expected:  "10.0.3.1",
	}, {
		addresses: makeAddresses("10.0.3.1/24", "fe80::90cf:9dff:fe6e:ece/64"),
		expected:  "10.0.3.1",
	}}
	for _, tc := range cases {
		ip, err := utils.GetIPv4Address(tc.addresses)
		if tc.errorString != "" {
			c.Check(err, gc.ErrorMatches, tc.errorString)
			c.Check(ip, gc.Equals, "")
			continue
		}
		c.Check(err, gc.IsNil)
		c.Check(ip, gc.Equals, tc.expected)
	}
}
// TestGetIPv6Address covers parse failures, link-local and IPv4-only
// inputs, and selection of a global IPv6 entry regardless of its position.
func (*networkSuite) TestGetIPv6Address(c *gc.C) {
	type testCase struct {
		addresses   []net.Addr
		expected    string
		errorString string
	}
	cases := []testCase{{
		addresses:   makeAddresses("complete", "nonsense"),
		errorString: "invalid CIDR address: complete",
	}, {
		addresses:   makeAddresses("fe80::90cf:9dff:fe6e:ece/64"),
		errorString: "no addresses match",
	}, {
		addresses:   makeAddresses("fe80::90cf:9dff:fe6e:ece/64", "10.0.3.1/24"),
		errorString: "no addresses match",
	}, {
		addresses:   makeAddresses("10.0.3.1/24"),
		errorString: "no addresses match",
	}, {
		addresses: makeAddresses("10.0.3.1/24", "2001:db8::90cf:9dff:fe6e:ece/64"),
		expected:  "2001:db8::90cf:9dff:fe6e:ece",
	}, {
		addresses: makeAddresses("2001:db8::90cf:9dff:fe6e:ece/64", "10.0.3.1/24"),
		expected:  "2001:db8::90cf:9dff:fe6e:ece",
	}}
	for _, tc := range cases {
		ip, err := utils.GetIPv6Address(tc.addresses)
		if tc.errorString != "" {
			c.Check(err, gc.ErrorMatches, tc.errorString)
			c.Check(ip, gc.Equals, "")
			continue
		}
		c.Check(err, gc.IsNil)
		c.Check(ip, gc.Equals, tc.expected)
	}
}
================================================
FILE: os.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
// Operating system names recognized by Go (GOOS values).
const (
	// OSWindows is Microsoft Windows.
	OSWindows = "windows"
	// OSDarwin is macOS/Darwin.
	OSDarwin = "darwin"
	// OSDragonfly is DragonFly BSD.
	OSDragonfly = "dragonfly"
	// OSFreebsd is FreeBSD.
	OSFreebsd = "freebsd"
	// OSLinux is Linux.
	OSLinux = "linux"
	// OSNacl is Native Client.
	OSNacl = "nacl"
	// OSNetbsd is NetBSD.
	OSNetbsd = "netbsd"
	// OSOpenbsd is OpenBSD.
	OSOpenbsd = "openbsd"
	// OSSolaris is Solaris.
	OSSolaris = "solaris"
)
// OSUnix lists the unix-like operating systems recognized by Go, i.e.
// every name above except OSWindows.
// See http://golang.org/src/path/filepath/path_unix.go.
var OSUnix = []string{
	OSDarwin, OSDragonfly, OSFreebsd, OSLinux,
	OSNacl, OSNetbsd, OSOpenbsd, OSSolaris,
}
// OSIsUnix reports whether os is listed in OSUnix. The comparison is exact
// and case-sensitive.
func OSIsUnix(os string) bool {
	found := false
	for _, candidate := range OSUnix {
		if candidate == os {
			found = true
			break
		}
	}
	return found
}
================================================
FILE: os_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// osSuite holds the OS-name tests.
type osSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(new(osSuite))
// TestOSIsUnixKnown checks that every entry of OSUnix is reported as unix.
func (osSuite) TestOSIsUnixKnown(c *gc.C) {
	for _, name := range utils.OSUnix {
		c.Logf("checking %q", name)
		c.Check(utils.OSIsUnix(name), jc.IsTrue)
	}
}
// TestOSIsUnixWindows checks that Windows is not unix.
func (osSuite) TestOSIsUnixWindows(c *gc.C) {
	c.Check(utils.OSIsUnix("windows"), jc.IsFalse)
}
// TestOSIsUnixUnknown checks that an empty name is not unix.
func (osSuite) TestOSIsUnixUnknown(c *gc.C) {
	c.Check(utils.OSIsUnix(""), jc.IsFalse)
}
================================================
FILE: package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hooks every gocheck suite registered in this package into
// the standard "go test" runner.
func TestPackage(tt *testing.T) {
	gc.TestingT(tt)
}
================================================
FILE: parallel/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package parallel_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hooks every gocheck suite registered in this package into
// the standard "go test" runner.
func TestPackage(tt *testing.T) {
	gc.TestingT(tt)
}
================================================
FILE: parallel/parallel.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// The parallel package provides utilities for running tasks
// concurrently.
package parallel
import (
"fmt"
"sync"
)
// Run manages a set of functions executing concurrently on a bounded pool
// of worker goroutines.
type Run struct {
	mu      sync.Mutex        // guards running in Do
	results chan Errors       // each worker sends its collected errors on exit
	max     int               // maximum number of worker goroutines
	running int               // number of workers started so far
	work    chan func() error // functions waiting for a worker; closed by Wait
}
// Errors collects the errors returned by functions during a parallel run.
type Errors []error

// Error summarises errs as the first error's message, followed by a count
// of any further errors.
func (errs Errors) Error() string {
	if len(errs) == 0 {
		return "no error"
	}
	first := errs[0].Error()
	if len(errs) == 1 {
		return first
	}
	return fmt.Sprintf("%s (and %d more)", first, len(errs)-1)
}
// NewRun returns a Run that executes functions passed to Do concurrently,
// with at most max of them running at once. It panics if max is below 1.
func NewRun(max int) *Run {
	if max <= 0 {
		panic("parameter max must be >= 1")
	}
	r := &Run{max: max}
	r.results = make(chan Errors)
	r.work = make(chan func() error)
	return r
}
// Do arranges for f to run concurrently. If an idle worker is waiting, f
// goes straight to it. Otherwise Do starts a new worker (while fewer than
// max exist) and blocks until some worker accepts f. Do may be called
// concurrently with itself, but not concurrently with Wait.
func (r *Run) Do(f func() error) {
	// Fast path: hand f to an idle worker without blocking.
	select {
	case r.work <- f:
		return
	default:
	}
	r.mu.Lock()
	canGrow := r.running < r.max
	if canGrow {
		r.running++
	}
	r.mu.Unlock()
	if canGrow {
		go r.runner()
	}
	r.work <- f
}
// Wait marks the parallel instance as complete and waits for all the functions
// to complete. If any errors were encountered, it returns an Errors value
// describing all the errors in arbitrary order.
func (r *Run) Wait() error {
close(r.work)
var errs Errors
for i := 0; i < r.running; i++ {
errs = append(errs, <-r.results...)
}
if len(errs) == 0 {
return nil
}
// TODO(rog) sort errors by original order of Do request?
return errs
}
func (r *Run) runner() {
var errs Errors
for f := range r.work {
if err := f(); err != nil {
errs = append(errs, err)
}
}
r.results <- errs
}
================================================
FILE: parallel/parallel_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package parallel_test
import (
"sort"
"sync"
"sync/atomic"
stdtesting "testing"
"time"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/parallel"
)
type parallelSuite struct {
testing.IsolationSuite
}
var _ = gc.Suite(¶llelSuite{})
// TestParallelMaxPar checks that every function runs and that exactly
// maxConcurrentRunnersPar of them overlap at peak.
func (*parallelSuite) TestParallelMaxPar(c *gc.C) {
	const (
		totalDo                 = 10
		maxConcurrentRunnersPar = 3
	)
	var (
		mu       sync.Mutex
		peak     int // most functions observed running at once
		inFlight int // functions currently running
		started  int // functions that have begun running
	)
	task := func() error {
		mu.Lock()
		started++
		inFlight++
		if inFlight > peak {
			peak = inFlight
		}
		mu.Unlock()

		// Hold the slot long enough for the other runners to overlap.
		time.Sleep(time.Second / 10)

		mu.Lock()
		inFlight--
		mu.Unlock()
		return nil
	}
	run := parallel.NewRun(maxConcurrentRunnersPar)
	for i := 0; i < totalDo; i++ {
		run.Do(task)
	}
	err := run.Wait()
	if inFlight != 0 {
		c.Errorf("%d functions still running", inFlight)
	}
	if started != totalDo {
		c.Errorf("all functions not executed; want %d got %d", totalDo, started)
	}
	c.Check(err, gc.IsNil)
	if peak != maxConcurrentRunnersPar {
		c.Errorf("wrong number of do's ran at once; want %d got %d", maxConcurrentRunnersPar, peak)
	}
}
// nothing is a no-op task, used to measure Run's own overhead.
func nothing() error { return nil }

// BenchmarkRunSingle measures creating a Run, doing one task and waiting.
func BenchmarkRunSingle(b *stdtesting.B) {
	for n := 0; n < b.N; n++ {
		run := parallel.NewRun(1)
		run.Do(nothing)
		run.Wait()
	}
}

// BenchmarkRun1000p100 measures 1000 tasks with at most 100 at once.
func BenchmarkRun1000p100(b *stdtesting.B) {
	for n := 0; n < b.N; n++ {
		run := parallel.NewRun(100)
		for task := 0; task < 1000; task++ {
			run.Do(nothing)
		}
		run.Wait()
	}
}
// TestConcurrentDo checks that Do may be called from many goroutines at
// once and that every submitted function runs.
func (*parallelSuite) TestConcurrentDo(c *gc.C) {
	const callers = 100
	run := parallel.NewRun(3)
	var count int32
	var wg sync.WaitGroup
	wg.Add(callers)
	for i := 0; i < callers; i++ {
		go func() {
			defer wg.Done()
			run.Do(func() error {
				atomic.AddInt32(&count, 1)
				return nil
			})
		}()
	}
	// Every Do call must have returned before Wait may be called.
	wg.Wait()
	c.Assert(run.Wait(), gc.IsNil)
	c.Assert(atomic.LoadInt32(&count), gc.Equals, int32(callers))
}
// intError is an error that carries an integer identifying which task
// produced it; its message is always "error".
type intError int

func (e intError) Error() string { return "error" }
// TestParallelError checks that Wait returns an Errors value holding
// exactly the errors returned by the failing functions.
func (*parallelSuite) TestParallelError(c *gc.C) {
	const (
		totalDo = 10
		errDo   = 5
	)
	parallelRun := parallel.NewRun(6)
	for i := 0; i < totalDo; i++ {
		i := i
		if i >= errDo {
			parallelRun.Do(func() error {
				return intError(i)
			})
		} else {
			parallelRun.Do(func() error {
				return nil
			})
		}
	}
	err := parallelRun.Wait()
	// Assert (not Check): the type assertion below would panic on nil.
	c.Assert(err, gc.NotNil)
	errs, ok := err.(parallel.Errors)
	c.Assert(ok, gc.Equals, true, gc.Commentf("unexpected error type %T", err))
	c.Check(len(errs), gc.Equals, totalDo-errDo)
	ints := make([]int, len(errs))
	for i, e := range errs {
		n, ok := e.(intError)
		c.Assert(ok, gc.Equals, true, gc.Commentf("unexpected error %#v", e))
		ints[i] = int(n)
	}
	// Errors come back in arbitrary order.
	sort.Ints(ints)
	for i, n := range ints {
		c.Check(n, gc.Equals, i+errDo)
	}
}
// TestZeroWorkerPanics checks that NewRun rejects a non-positive limit.
func (*parallelSuite) TestZeroWorkerPanics(c *gc.C) {
	c.Check(func() { parallel.NewRun(0) }, gc.PanicMatches, "parameter max must be >= 1")
}
================================================
FILE: parallel/try.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package parallel
import (
"errors"
"io"
"sync"
"gopkg.in/tomb.v1"
)
var (
	// ErrStopped is returned by Start when the Try is dying. It is also the
	// final error when the Try is killed before any error was recorded.
	ErrStopped = errors.New("try was stopped")
	// ErrClosed is returned by Start once the Try has been closed.
	ErrClosed = errors.New("try was closed")
)
// Try represents an attempt made concurrently
// by a number of goroutines.
type Try struct {
	// tomb tracks the lifetime of the loop goroutine started by NewTry.
	tomb tomb.Tomb
	// closeMutex serializes Close so that close is closed only once.
	closeMutex sync.Mutex
	// close is closed by Close.
	close chan struct{}
	// limiter holds one token per free run slot when maxParallel > 0;
	// it is nil (no limit) otherwise.
	limiter chan struct{}
	// start carries wrapped functions from Start to loop.
	start chan func()
	// result carries each started function's outcome to loop.
	result chan result
	// combineErrors folds each new error into the accumulated one.
	combineErrors func(err0, err1 error) error
	// maxParallel is the concurrency limit; values <= 0 mean unlimited.
	maxParallel int
	// endResult is the successful value, written by the NewTry goroutine
	// before the tomb is marked done and read by Result.
	endResult io.Closer
}
// NewTry returns a Try that runs functions concurrently until one of them
// succeeds; that function's value is then available from Result. A
// positive maxParallel caps how many functions run at once; zero or a
// negative value means no cap.
//
// combineErrors(oldErr, newErr) is called for each failure to compute the
// error that Result reports. On the first call oldErr is nil; afterwards
// it is whatever combineErrors last returned. A nil combineErrors keeps
// the most recent error.
func NewTry(maxParallel int, combineErrors func(err0, err1 error) error) *Try {
	t := &Try{
		maxParallel: maxParallel,
		close:       make(chan struct{}, 1),
		result:      make(chan result),
		start:       make(chan func()),
	}
	t.combineErrors = combineErrors
	if t.combineErrors == nil {
		t.combineErrors = chooseLastError
	}
	if maxParallel > 0 {
		// One token per slot: Start takes a token, and a finished
		// function gives it back.
		t.limiter = make(chan struct{}, maxParallel)
		for len(t.limiter) < cap(t.limiter) {
			t.limiter <- struct{}{}
		}
	}
	go func() {
		defer t.tomb.Done()
		val, err := t.loop()
		// Written before the deferred Done, so Result (which waits on the
		// tomb) observes it.
		t.endResult = val
		t.tomb.Kill(err)
	}()
	return t
}
// chooseLastError is the default combineErrors: it discards the previously
// accumulated error and keeps the newest one.
func chooseLastError(_, latest error) error {
	return latest
}
// result is what a started function reports to loop. Start builds it with
// a positional literal, so the field order must not change.
type result struct {
	// val is the function's returned value.
	val io.Closer
	// err is the function's returned error; nil means success.
	err error
}
// loop runs in the goroutine started by NewTry and coordinates the Try.
// It launches the functions handed over by Start, counts those still
// running, and folds their errors together with t.combineErrors.
//
// It returns the value of the first function that reports a nil error.
// Otherwise it returns a nil value and the accumulated error. This happens
// once the Try is closed and nothing is running, or when the tomb starts
// dying; in the dying case it returns ErrStopped if no error was recorded.
func (t *Try) loop() (io.Closer, error) {
	// err is the result of combining every error received so far.
	var err error
	// close shadows the builtin. It is set to nil after t.close fires so
	// that the closed channel is not selected again.
	close := t.close
	// nrunning counts started functions whose result has not arrived.
	nrunning := 0
	for {
		// NOTE(review): t.start is still received after Close, so a Start
		// racing with Close may still launch a function.
		select {
		case f := <-t.start:
			nrunning++
			go f()
		case r := <-t.result:
			if r.err == nil {
				// The first success wins.
				return r.val, r.err
			}
			err = t.combineErrors(err, r.err)
			nrunning--
			// After Close, finish as soon as the last function fails.
			if close == nil && nrunning == 0 {
				return nil, err
			}
		case <-t.tomb.Dying():
			// Killed: report the accumulated error, or ErrStopped if none.
			if err == nil {
				return nil, ErrStopped
			}
			return nil, err
		case <-close:
			close = nil
			// Nothing outstanding: finish immediately.
			if nrunning == 0 {
				return nil, err
			}
		}
	}
}
// Start requests the given function to be started, waiting if necessary
// until fewer than maxParallel functions are running. It returns an error
// if the function has not been started: ErrClosed if the Try has been
// closed, and ErrStopped if the Try is finishing. Once Close has returned,
// every later call to Start returns ErrClosed.
//
// The function should listen on the stop channel and return if it
// receives a value, though this is advisory only - the Try does not
// wait for all started functions to return before completing.
//
// If the function returns a nil error but some earlier try was
// successful (that is, the returned value is being discarded),
// its returned value, if non-nil, will be closed by calling its
// Close method.
func (t *Try) Start(try func(stop <-chan struct{}) (io.Closer, error)) error {
	// Check for Close first. Otherwise the selects below could choose a
	// ready t.limiter or t.start case over the already-closed t.close at
	// random, and start a function after Close has returned.
	select {
	case <-t.close:
		return ErrClosed
	default:
	}
	if t.limiter != nil {
		// Wait for availability slot.
		select {
		case <-t.limiter:
		case <-t.tomb.Dying():
			return ErrStopped
		case <-t.close:
			return ErrClosed
		}
	}
	// releaseSlot gives back the slot reserved above. It is needed both
	// when the function finishes and when it is never started; without the
	// second case, each failed Start would permanently shrink the limit.
	releaseSlot := func() {
		if t.limiter != nil {
			t.limiter <- struct{}{}
		}
	}
	dying := t.tomb.Dying()
	f := func() {
		val, err := try(dying)
		// Signal availability slot is now free.
		releaseSlot()
		// Deliver result.
		select {
		case t.result <- result{val, err}:
		case <-dying:
			// The Try is finishing, so the value will not be returned to
			// anyone. Close it (best-effort) so it is not leaked. A nil
			// value is skipped, because calling Close on it would panic.
			if err == nil && val != nil {
				val.Close()
			}
		}
	}
	select {
	case t.start <- f:
		return nil
	case <-dying:
		releaseSlot()
		return ErrStopped
	case <-t.close:
		releaseSlot()
		return ErrClosed
	}
}
// Close closes the Try: functions passed to Start are no longer to be
// started. The Try terminates once all outstanding functions have
// completed, or earlier if one succeeds. Close may safely be called more
// than once.
func (t *Try) Close() {
	t.closeMutex.Lock()
	defer t.closeMutex.Unlock()
	select {
	case <-t.close:
		// Already closed; closing again would panic.
		return
	default:
	}
	close(t.close)
}
// Dead returns a channel that is closed when the Try completes.
func (t *Try) Dead() <-chan struct{} { return t.tomb.Dead() }

// Wait waits for the Try to complete and returns the same error that
// Result returns.
func (t *Try) Wait() error {
	_, err := t.Result()
	return err
}

// Result waits for the Try to complete and returns the value of the first
// function started by Start that returned a nil error.
//
// If no function succeeded, the value is nil and the error is:
//   - the error accumulated with combineErrors, when that is non-nil;
//   - otherwise ErrStopped, if the Try was stopped with Kill;
//   - otherwise nil, e.g. when the Try was closed with no recorded errors.
func (t *Try) Result() (io.Closer, error) {
	// endResult may only be read after the tomb reports completion.
	if err := t.tomb.Wait(); err != nil {
		return t.endResult, err
	}
	return t.endResult, nil
}

// Kill stops the Try. The stop channel handed to every started function
// is the tomb's dying channel, so those functions are signalled to stop.
func (t *Try) Kill() { t.tomb.Kill(nil) }
================================================
FILE: parallel/try_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package parallel_test
import (
"errors"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/parallel"
)
const (
	// shortWait is how long a test waits to confirm that something does
	// NOT happen.
	shortWait = 50 * time.Millisecond
	// longWait bounds how long a test waits for something that should
	// happen.
	longWait = 10 * time.Second
)
// result is a trivial io.Closer whose Close does nothing, so tests can
// compare returned values directly.
type result string

func (result) Close() error { return nil }
// trySuite holds the gocheck tests for Try.
type trySuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&trySuite{})
// tryFunc builds a function for Try.Start that ignores its stop channel,
// sleeps for delay, and then returns val and err.
func tryFunc(delay time.Duration, val io.Closer, err error) func(<-chan struct{}) (io.Closer, error) {
	attempt := func(_ <-chan struct{}) (io.Closer, error) {
		time.Sleep(delay)
		return val, err
	}
	return attempt
}
// TestOneSuccess checks that a single successful function supplies the result.
func (*trySuite) TestOneSuccess(c *gc.C) {
	tr := parallel.NewTry(0, nil)
	tr.Start(tryFunc(0, result("hello"), nil))
	got, err := tr.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(got, gc.Equals, result("hello"))
}
// TestOneFailure checks that a failure alone does not end an open Try,
// and that after Close the failure becomes the Try's error.
func (*trySuite) TestOneFailure(c *gc.C) {
	tr := parallel.NewTry(0, nil)
	expectErr := errors.New("foo")
	c.Assert(tr.Start(tryFunc(0, nil, expectErr)), gc.IsNil)

	select {
	case <-time.After(shortWait):
	case <-tr.Dead():
		c.Fatalf("try died before it should")
	}

	tr.Close()
	select {
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for Try to complete")
	case <-tr.Dead():
	}

	val, err := tr.Result()
	c.Assert(val, gc.IsNil)
	c.Assert(err, gc.Equals, expectErr)
}
// TestStartReturnsErrorAfterClose checks that Start refuses new work once
// the Try is closed, and that the earlier failure is still reported.
func (*trySuite) TestStartReturnsErrorAfterClose(c *gc.C) {
	tr := parallel.NewTry(0, nil)
	expectErr := errors.New("foo")
	c.Assert(tr.Start(tryFunc(0, nil, expectErr)), gc.IsNil)

	tr.Close()
	err := tr.Start(tryFunc(0, result("goodbye"), nil))
	c.Assert(err, gc.Equals, parallel.ErrClosed)

	// Give the first attempt time to deliver its result before stopping.
	time.Sleep(shortWait)
	tr.Kill()
	c.Assert(tr.Wait(), gc.Equals, expectErr)
}
// TestOutOfOrderResults checks that the first function to finish wins,
// not the first one started.
func (*trySuite) TestOutOfOrderResults(c *gc.C) {
	tr := parallel.NewTry(0, nil)
	slow := tryFunc(50*time.Millisecond, result("first"), nil)
	fast := tryFunc(10*time.Millisecond, result("second"), nil)
	tr.Start(slow)
	tr.Start(fast)
	winner, err := tr.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(winner, gc.Equals, result("second"))
}
// TestMaxParallel checks that no more than maxParallel functions run at
// once and that the limit is actually reached.
func (*trySuite) TestMaxParallel(c *gc.C) {
	tr := parallel.NewTry(3, nil)
	var (
		mu      sync.Mutex
		running int
		peak    int
	)
	work := func(<-chan struct{}) (io.Closer, error) {
		mu.Lock()
		running++
		if running > peak {
			peak = running
		}
		c.Check(running, gc.Not(jc.GreaterThan), 3)
		mu.Unlock()

		time.Sleep(20 * time.Millisecond)

		mu.Lock()
		running--
		mu.Unlock()
		return result("hello"), nil
	}
	for i := 0; i < 10; i++ {
		tr.Start(work)
	}
	got, err := tr.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(got, gc.Equals, result("hello"))
	mu.Lock()
	defer mu.Unlock()
	c.Assert(peak, gc.Equals, 3)
}
// TestStartBlocksForMaxParallel checks three things about Start:
//   - it blocks once maxParallel functions are running;
//   - it unblocks as those functions finish;
//   - a Start still blocked when the Try is closed returns ErrClosed.
func (*trySuite) TestStartBlocksForMaxParallel(c *gc.C) {
	try := parallel.NewTry(3, nil)
	// started receives one value each time a Start call returns.
	started := make(chan struct{})
	// Each started function blocks until it receives from begin.
	begin := make(chan struct{})
	go func() {
		for i := 0; i < 6; i++ {
			err := try.Start(func(<-chan struct{}) (io.Closer, error) {
				<-begin
				return nil, fmt.Errorf("an error")
			})
			started <- struct{}{}
			// The first five Starts fit (three at first, two more after
			// two functions are unblocked). The sixth is still blocked
			// when Close is called.
			if i < 5 {
				c.Check(err, gc.IsNil)
			} else {
				c.Check(err, gc.Equals, parallel.ErrClosed)
			}
		}
		close(started)
	}()
	// Check we can start the first three.
	timeout := time.After(longWait)
	for i := 0; i < 3; i++ {
		select {
		case <-started:
		case <-timeout:
			c.Fatalf("timed out")
		}
	}
	// Check we block when going above maxParallel.
	timeout = time.After(shortWait)
	select {
	case <-started:
		c.Fatalf("Start did not block")
	case <-timeout:
	}
	// Unblock two attempts.
	begin <- struct{}{}
	begin <- struct{}{}
	// Check we can start another two.
	timeout = time.After(longWait)
	for i := 0; i < 2; i++ {
		select {
		case <-started:
		case <-timeout:
			c.Fatalf("timed out")
		}
	}
	// Check we block again when going above maxParallel.
	timeout = time.After(shortWait)
	select {
	case <-started:
		c.Fatalf("Start did not block")
	case <-timeout:
	}
	// Close the Try - the last request should be discarded,
	// unblocking last remaining Start request.
	try.Close()
	timeout = time.After(longWait)
	select {
	case <-started:
	case <-timeout:
		c.Fatalf("Start did not unblock after Close")
	}
	// Ensure all checks are completed: the goroutine closes started
	// after its final Check.
	select {
	case _, ok := <-started:
		c.Assert(ok, gc.Equals, false)
	case <-timeout:
		c.Fatalf("Start goroutine did not finish")
	}
}
// TestAllConcurrent checks that, with no limit, every started function is
// running at the same time. Each one blocks until the test replies to it.
func (*trySuite) TestAllConcurrent(c *gc.C) {
	const n = 10
	tr := parallel.NewTry(0, nil)
	started := make(chan chan struct{})
	for i := 0; i < n; i++ {
		tr.Start(func(<-chan struct{}) (io.Closer, error) {
			reply := make(chan struct{})
			started <- reply
			<-reply
			return result("hello"), nil
		})
	}
	timeout := time.After(longWait)
	for received := 0; received < n; received++ {
		select {
		case <-timeout:
			c.Fatalf("timed out")
		case reply := <-started:
			reply <- struct{}{}
		}
	}
}
// gradedError is an error with an integer importance.
type gradedError int

func (e gradedError) Error() string {
	return fmt.Sprintf("error with importance %d", int(e))
}

// gradedErrorCombine keeps the more important of two gradedErrors. On a
// tie it keeps the earlier one, and a nil previous error is always
// replaced.
func gradedErrorCombine(err0, err1 error) error {
	if err0 != nil && err0.(gradedError) >= err1.(gradedError) {
		return err0
	}
	return err1
}
// multiError accumulates the grades of several gradedErrors.
type multiError struct {
	errs []int
}

func (e *multiError) Error() string {
	return fmt.Sprint(e.errs)
}
// TestErrorCombine checks that combineErrors sees every failure.
func (*trySuite) TestErrorCombine(c *gc.C) {
	collect := func(err0, err1 error) error {
		var acc *multiError
		if err0 == nil {
			acc = &multiError{}
		} else {
			acc = err0.(*multiError)
		}
		acc.errs = append(acc.errs, int(err1.(gradedError)))
		return acc
	}
	// Use maxParallel=1 to guarantee that all errors are processed sequentially.
	tr := parallel.NewTry(1, collect)
	inputs := []gradedError{3, 2, 4, 0, 5, 5, 3}
	for _, g := range inputs {
		g := g
		tr.Start(func(<-chan struct{}) (io.Closer, error) {
			return nil, g
		})
	}
	tr.Close()
	val, err := tr.Result()
	c.Assert(val, gc.IsNil)
	grades := err.(*multiError).errs
	sort.Ints(grades)
	c.Assert(grades, gc.DeepEquals, []int{0, 2, 3, 3, 4, 5, 5})
}
// TestTriesAreStopped checks that, once one function succeeds, the stop
// channel given to the others fires.
func (*trySuite) TestTriesAreStopped(c *gc.C) {
	tr := parallel.NewTry(0, nil)
	stopped := make(chan struct{})
	waitForStop := func(stop <-chan struct{}) (io.Closer, error) {
		<-stop
		stopped <- struct{}{}
		return nil, parallel.ErrStopped
	}
	tr.Start(waitForStop)
	tr.Start(tryFunc(0, result("hello"), nil))
	val, err := tr.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, result("hello"))
	select {
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for stop")
	case <-stopped:
	}
}
// TestCloseTwice checks that a repeated Close is harmless, and that an
// empty closed Try finishes with a nil value and a nil error.
func (*trySuite) TestCloseTwice(c *gc.C) {
	tr := parallel.NewTry(0, nil)
	for i := 0; i < 2; i++ {
		tr.Close()
	}
	val, err := tr.Result()
	c.Assert(val, gc.IsNil)
	c.Assert(err, gc.IsNil)
}
// closeResult is an io.Closer that records being closed by closing its
// closed channel. Closing it twice panics.
type closeResult struct {
	closed chan struct{}
}

func (cr *closeResult) Close() error {
	close(cr.closed)
	return nil
}
// TestExtraResultsAreClosed checks that successful values arriving after
// the winner are closed, while the winning value is left open.
func (*trySuite) TestExtraResultsAreClosed(c *gc.C) {
	const n = 4
	tr := parallel.NewTry(0, nil)
	begin := make([]chan struct{}, n)
	results := make([]*closeResult, n)
	for i := 0; i < n; i++ {
		gate := make(chan struct{})
		res := &closeResult{make(chan struct{})}
		begin[i] = gate
		results[i] = res
		tr.Start(func(<-chan struct{}) (io.Closer, error) {
			<-gate
			return res, nil
		})
	}
	begin[0] <- struct{}{}
	val, err := tr.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, results[0])

	timeout := time.After(shortWait)
	for i := 1; i < n; i++ {
		begin[i] <- struct{}{}
		select {
		case <-timeout:
			c.Fatalf("timed out waiting for close")
		case <-results[i].closed:
		}
	}
	select {
	case <-time.After(shortWait):
	case <-results[0].closed:
		c.Fatalf("result was inappropriately closed")
	}
}
// TestEverything mixes staggered starts, failures, and a slow success.
// It checks that one of the two quick successes wins, with no error.
func (*trySuite) TestEverything(c *gc.C) {
	try := parallel.NewTry(5, gradedErrorCombine)
	tries := []struct {
		startAt time.Duration // delay before calling Start
		wait    time.Duration // how long the function runs
		val     result
		err     error
	}{{
		// Fails after 30ms.
		wait: 30 * time.Millisecond,
		err:  gradedError(3),
	}, {
		// Succeeds at roughly 30ms.
		startAt: 10 * time.Millisecond,
		wait:    20 * time.Millisecond,
		val:     result("result 1"),
	}, {
		// Also succeeds at roughly 30ms; either may win.
		startAt: 20 * time.Millisecond,
		wait:    10 * time.Millisecond,
		val:     result("result 2"),
	}, {
		// Too slow to ever win.
		startAt: 20 * time.Millisecond,
		wait:    5 * time.Second,
		val:     "delayed result",
	}, {
		// Fails almost immediately.
		startAt: 5 * time.Millisecond,
		err:     gradedError(4),
	}}
	for _, t := range tries {
		t := t
		go func() {
			time.Sleep(t.startAt)
			try.Start(tryFunc(t.wait, t.val, t.err))
		}()
	}
	val, err := try.Result()
	if val != result("result 1") && val != result("result 2") {
		c.Errorf(`expected "result 1" or "result 2" got %#v`, val)
	}
	c.Assert(err, gc.IsNil)
}
================================================
FILE: password.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"crypto/rand"
"crypto/sha512"
"encoding/base64"
"fmt"
"io"
"golang.org/x/crypto/pbkdf2"
)
// CompatSalt is the hard-coded salt that Juju 1.16 and older used when
// hashing the passwords of all users and agents.
var CompatSalt = "\x75\x82\x81\xca"

// randomPasswordBytes is the number of random bytes behind each password
// produced by RandomPassword. 18 is a multiple of 3, so the base64 form
// has no padding.
const randomPasswordBytes = 18

// MinAgentPasswordLength is the length of a generated agent password,
// which is the minimum agent passwords should have. The length is required
// because agent passwords are assumed to carry enough entropy that extra
// rounds of iterated hashing are unnecessary.
var MinAgentPasswordLength = base64.StdEncoding.EncodedLen(randomPasswordBytes)
// RandomBytes returns n cryptographically random bytes read from
// crypto/rand. If the read fails, the returned error wraps the underlying
// error, so callers can inspect it with errors.Is or errors.As.
func RandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return nil, fmt.Errorf("cannot read random bytes: %w", err)
	}
	return buf, nil
}
// RandomPassword returns a new password: randomPasswordBytes random bytes,
// base64-encoded with the standard encoding.
func RandomPassword() (string, error) {
	raw, err := RandomBytes(randomPasswordBytes)
	if err != nil {
		return "", err
	}
	password := base64.StdEncoding.EncodeToString(raw)
	return password, nil
}
// RandomSalt generates random base64 data suitable for use as a password
// salt. The pbkdf2 guideline is to use 8 bytes of salt. This function uses
// 12 raw bytes, which encode to 16 base64 characters with no padding
// (the alternative, 6 raw bytes, would give 8 characters).
func RandomSalt() (string, error) {
	const saltBytes = 12
	raw, err := RandomBytes(saltBytes)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(raw), nil
}
// FastInsecureHash, when true, makes UserPasswordHash use a single pbkdf2
// iteration instead of the secure default. Hashes produced that way do not
// match normal ones, so change it only in tests, to make them run faster.
var FastInsecureHash bool
// UserPasswordHash returns a base64-encoded, one-way pbkdf2 (SHA-512)
// hash of password. It uses 8192 iterations, so brute-forcing it by trying
// candidate passwords is computationally expensive. It panics if salt is
// empty.
func UserPasswordHash(password string, salt string) string {
	if salt == "" {
		panic("salt is not allowed to be empty")
	}
	rounds := 8192
	if FastInsecureHash {
		rounds = 1
	}
	// 18 bytes: MongoDB takes the MD5 sum of the password anyway, so
	// longer keys gain nothing. 18 also encodes to base64 without padding.
	const keyLen = 18
	key := pbkdf2.Key([]byte(password), []byte(salt), rounds, keyLen, sha512.New)
	return base64.StdEncoding.EncodeToString(key)
}
// AgentPasswordHash returns the base64 encoding of the first 18 bytes of
// the SHA-512 digest of password. A single fast hash is not suitable for
// user passwords, which have limited entropy (see UserPasswordHash). Agent
// passwords are long random strings, so brute-force search is already
// impractical. The faster hash lets thousands of agents log in within a
// reasonable time when the state machines restart.
func AgentPasswordHash(password string) string {
	digest := sha512.Sum512([]byte(password))
	return base64.StdEncoding.EncodeToString(digest[:18])
}
================================================
FILE: password_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// passwordSuite tests the password and random-value helpers.
type passwordSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&passwordSuite{})
// base64Chars matches a non-empty run of standard base64 characters.
// Base64 *can* end with '=' padding characters, but every test here
// explicitly *doesn't* want them because they are wasteful, so '=' is
// deliberately left out of the pattern.
var base64Chars = "^[A-Za-z0-9+/]+$"
// TestRandomBytes checks the length of RandomBytes' output and that the
// bytes are not all identical.
func (*passwordSuite) TestRandomBytes(c *gc.C) {
	b, err := utils.RandomBytes(16)
	c.Assert(err, gc.IsNil)
	c.Assert(b, gc.HasLen, 16)
	allSame := true
	for _, x := range b[1:] {
		if x != b[0] {
			allSame = false
			break
		}
	}
	if allSame {
		c.Errorf("all same bytes in result of RandomBytes")
	}
}
// TestRandomPassword checks that generated passwords are long enough and
// are unpadded base64.
func (*passwordSuite) TestRandomPassword(c *gc.C) {
	pw, err := utils.RandomPassword()
	c.Assert(err, gc.IsNil)
	if n := len(pw); n < 18 {
		c.Errorf("password too short: %q", pw)
	}
	c.Assert(pw, gc.Matches, base64Chars)
}
// TestRandomSalt checks that generated salts are long enough and are
// unpadded base64.
func (*passwordSuite) TestRandomSalt(c *gc.C) {
	s, err := utils.RandomSalt()
	c.Assert(err, gc.IsNil)
	if n := len(s); n < 12 {
		c.Errorf("salt too short: %q", s)
	}
	// base64Chars excludes '=', so this also rejects padding.
	c.Assert(s, gc.Matches, base64Chars)
}
// testPasswords and testSalts are combined pairwise by
// TestUserPasswordHash; every pair must hash to a distinct value. The
// empty password is included deliberately, and CompatSalt covers the
// legacy fixed salt.
var testPasswords = []string{"", "a", "a longer password than i would usually bother with"}
var testSalts = []string{"abcd", "abcdefgh", "abcdefghijklmnop", utils.CompatSalt}
// TestUserPasswordHash checks four properties of UserPasswordHash over
// every password/salt pair: the hash has a fixed length, uses unpadded
// base64, is unique, and is deterministic.
func (*passwordSuite) TestUserPasswordHash(c *gc.C) {
	seen := map[string]bool{}
	for i := range testPasswords {
		password := testPasswords[i]
		for j := range testSalts {
			salt := testSalts[j]
			c.Logf("test %d, %d %s %s", i, j, password, salt)
			hashed := utils.UserPasswordHash(password, salt)
			c.Logf("hash %q", hashed)
			c.Assert(len(hashed), gc.Equals, 24)
			c.Assert(seen[hashed], gc.Equals, false)
			c.Assert(hashed, gc.Matches, base64Chars)
			seen[hashed] = true
			// Hashing the same inputs again must give the same answer.
			c.Assert(utils.UserPasswordHash(password, salt), gc.Equals, hashed)
		}
	}
}
// TestAgentPasswordHash checks that agent hashes differ from their
// passwords, never collide with earlier passwords or hashes, and are 24
// characters of unpadded base64.
func (*passwordSuite) TestAgentPasswordHash(c *gc.C) {
	seen := map[string]bool{}
	for n := 0; n < 1000; n++ {
		pw, err := utils.RandomPassword()
		c.Assert(err, gc.IsNil)
		c.Assert(seen[pw], jc.IsFalse)
		seen[pw] = true

		h := utils.AgentPasswordHash(pw)
		c.Assert(h, gc.Not(gc.Equals), pw)
		c.Assert(seen[h], jc.IsFalse)
		seen[h] = true
		c.Assert(len(h), gc.Equals, 24)
		c.Assert(h, gc.Matches, base64Chars)
	}
}
================================================
FILE: proxy/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package proxy_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hands control to gocheck so that the suites registered with
// gc.Suite in this package run under "go test".
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: proxy/proxy.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package proxy
import (
"fmt"
"os"
"strings"
"github.com/juju/collections/set"
)
// Lower-case names of the proxy environment variables. The functions in
// this file derive the upper-case spellings with strings.ToUpper.
const (
	// Remove the likelihood of errors by mistyping string values.
	http_proxy  = "http_proxy"
	https_proxy = "https_proxy"
	ftp_proxy   = "ftp_proxy"
	no_proxy    = "no_proxy"
)
// Settings holds the values for the HTTP, HTTPS and FTP proxies as well as the
// no_proxy value found by DetectProxies.
// AutoNoProxy is filled with addresses of controllers, which should never
// be proxied; FullNoProxy merges it with NoProxy.
type Settings struct {
	// Http is the http_proxy value.
	Http string
	// Https is the https_proxy value.
	Https string
	// Ftp is the ftp_proxy value.
	Ftp string
	// NoProxy is the comma-separated no_proxy value.
	NoProxy string
	// AutoNoProxy is a comma-separated list merged into NoProxy by FullNoProxy.
	AutoNoProxy string
}
// getSetting returns the value of the environment variable key. If that
// is empty, it falls back to the upper-case spelling of key.
func getSetting(key string) string {
	if value := os.Getenv(key); value != "" {
		return value
	}
	return os.Getenv(strings.ToUpper(key))
}
// DetectProxies returns the proxy settings found in the environment.
// For each setting, the lower-case variable takes precedence over the
// upper-case one. AutoNoProxy is always left empty.
func DetectProxies() Settings {
	var s Settings
	s.Http = getSetting(http_proxy)
	s.Https = getSetting(https_proxy)
	s.Ftp = getSetting(ftp_proxy)
	s.NoProxy = getSetting(no_proxy)
	return s
}
// AsScriptEnvironment returns shell lines of the form "export key=value",
// separated by newlines. Each non-empty proxy value produces two lines,
// one lower-case and one upper-case. no_proxy uses FullNoProxy. Values
// are inserted verbatim, without shell quoting.
func (s *Settings) AsScriptEnvironment() string {
	pairs := []struct{ name, value string }{
		{http_proxy, s.Http},
		{https_proxy, s.Https},
		{ftp_proxy, s.Ftp},
		{no_proxy, s.FullNoProxy()},
	}
	var lines []string
	for _, p := range pairs {
		if p.value == "" {
			continue
		}
		lines = append(lines,
			"export "+p.name+"="+p.value,
			"export "+strings.ToUpper(p.name)+"="+p.value)
	}
	return strings.Join(lines, "\n")
}
// AsEnvironmentValues returns "key=value" strings suitable for a command's
// environment. Each non-empty proxy value produces two entries, one
// lower-case and one upper-case. no_proxy uses FullNoProxy. The result is
// never nil, even when empty.
func (s *Settings) AsEnvironmentValues() []string {
	entries := []struct{ name, value string }{
		{http_proxy, s.Http},
		{https_proxy, s.Https},
		{ftp_proxy, s.Ftp},
		{no_proxy, s.FullNoProxy()},
	}
	lines := make([]string, 0, 2*len(entries))
	for _, e := range entries {
		if e.value == "" {
			continue
		}
		lines = append(lines, e.name+"="+e.value, strings.ToUpper(e.name)+"="+e.value)
	}
	return lines
}
// AsSystemdDefaultEnv returns a systemd [Manager] configuration snippet of
// the form:
//
//	DefaultEnvironment="http_proxy=...." "HTTP_PROXY=..." ...
//
// Every entry comes from AsEnvironmentValues, is wrapped in double quotes
// without escaping, and is followed by a space. The snippet ends with a
// newline.
func (s *Settings) AsSystemdDefaultEnv() string {
	var b strings.Builder
	b.WriteString(`# To allow juju to control the global systemd proxy settings,
# create symbolic links to this file from within /etc/systemd/system.conf.d/
# and /etc/systemd/users.conf.d/.
[Manager]
DefaultEnvironment=`)
	for _, value := range s.AsEnvironmentValues() {
		b.WriteString(`"`)
		b.WriteString(value)
		b.WriteString(`" `)
	}
	b.WriteString("\n")
	return b.String()
}
// SetEnvironmentValues writes the stored proxy values into the process
// environment, setting both the lower-case and upper-case variants:
//
//	http_proxy, HTTP_PROXY
//	https_proxy, HTTPS_PROXY
//	ftp_proxy, FTP_PROXY
//	no_proxy, NO_PROXY (from FullNoProxy)
//
// Empty values are written as empty strings, not unset. Errors from
// os.Setenv are ignored.
func (s *Settings) SetEnvironmentValues() {
	for _, kv := range []struct{ name, value string }{
		{http_proxy, s.Http},
		{https_proxy, s.Https},
		{ftp_proxy, s.Ftp},
		{no_proxy, s.FullNoProxy()},
	} {
		os.Setenv(kv.name, kv.value)
		os.Setenv(strings.ToUpper(kv.name), kv.value)
	}
}
// FullNoProxy merges the comma-separated NoProxy and AutoNoProxy lists
// into one sorted, de-duplicated, comma-separated list. Surrounding
// whitespace is trimmed from each entry, and empty entries are dropped.
// Without this, inputs such as "a,,b", "a," or "a, b" would produce an
// empty host (a leading comma in the result) or a space-prefixed host.
func (s *Settings) FullNoProxy() string {
	noProxySet := set.NewStrings()
	for _, list := range []string{s.NoProxy, s.AutoNoProxy} {
		for _, entry := range strings.Split(list, ",") {
			if entry = strings.TrimSpace(entry); entry != "" {
				noProxySet.Add(entry)
			}
		}
	}
	return strings.Join(noProxySet.SortedValues(), ",")
}
================================================
FILE: proxy/proxy_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package proxy_test
import (
"os"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/proxy"
)
// proxySuite tests proxy detection and formatting. The embedded
// IsolationSuite provides PatchEnvironment, used to isolate the tests from
// the caller's environment.
type proxySuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&proxySuite{})
// TestDetectNoSettings checks that an empty environment yields empty Settings.
func (s *proxySuite) TestDetectNoSettings(c *gc.C) {
	// Blank every variable DetectProxies reads, in both cases, so a proxy
	// configured in the developer's environment cannot leak in.
	for _, name := range []string{
		"http_proxy", "HTTP_PROXY",
		"https_proxy", "HTTPS_PROXY",
		"ftp_proxy", "FTP_PROXY",
		"no_proxy", "NO_PROXY",
	} {
		s.PatchEnvironment(name, "")
	}
	c.Assert(proxy.DetectProxies(), gc.DeepEquals, proxy.Settings{})
}
// TestDetectPrimary checks detection when only the lower-case variables
// are set.
func (s *proxySuite) TestDetectPrimary(c *gc.C) {
	for _, kv := range [][2]string{
		{"http_proxy", "http://user@10.0.0.1"},
		{"HTTP_PROXY", ""},
		{"https_proxy", "https://user@10.0.0.1"},
		{"HTTPS_PROXY", ""},
		{"ftp_proxy", "ftp://user@10.0.0.1"},
		{"FTP_PROXY", ""},
		{"no_proxy", "10.0.3.1,localhost"},
		{"NO_PROXY", ""},
	} {
		s.PatchEnvironment(kv[0], kv[1])
	}
	c.Assert(proxy.DetectProxies(), gc.DeepEquals, proxy.Settings{
		Http:    "http://user@10.0.0.1",
		Https:   "https://user@10.0.0.1",
		Ftp:     "ftp://user@10.0.0.1",
		NoProxy: "10.0.3.1,localhost",
	})
}
// TestDetectFallback checks that the upper-case variables are used when
// the lower-case ones are empty.
func (s *proxySuite) TestDetectFallback(c *gc.C) {
	for _, kv := range [][2]string{
		{"http_proxy", ""},
		{"HTTP_PROXY", "http://user@10.0.0.2"},
		{"https_proxy", ""},
		{"HTTPS_PROXY", "https://user@10.0.0.2"},
		{"ftp_proxy", ""},
		{"FTP_PROXY", "ftp://user@10.0.0.2"},
		{"no_proxy", ""},
		{"NO_PROXY", "10.0.3.1,localhost"},
	} {
		s.PatchEnvironment(kv[0], kv[1])
	}
	c.Assert(proxy.DetectProxies(), gc.DeepEquals, proxy.Settings{
		Http:    "http://user@10.0.0.2",
		Https:   "https://user@10.0.0.2",
		Ftp:     "ftp://user@10.0.0.2",
		NoProxy: "10.0.3.1,localhost",
	})
}
// TestDetectPrimaryPreference checks that lower-case variables win when
// both spellings are set.
func (s *proxySuite) TestDetectPrimaryPreference(c *gc.C) {
	for _, kv := range [][2]string{
		{"http_proxy", "http://user@10.0.0.1"},
		{"https_proxy", "https://user@10.0.0.1"},
		{"ftp_proxy", "ftp://user@10.0.0.1"},
		{"no_proxy", "10.0.3.1,localhost"},
		{"HTTP_PROXY", "http://user@10.0.0.2"},
		{"HTTPS_PROXY", "https://user@10.0.0.2"},
		{"FTP_PROXY", "ftp://user@10.0.0.2"},
		{"NO_PROXY", "localhost"},
	} {
		s.PatchEnvironment(kv[0], kv[1])
	}
	c.Assert(proxy.DetectProxies(), gc.DeepEquals, proxy.Settings{
		Http:    "http://user@10.0.0.1",
		Https:   "https://user@10.0.0.1",
		Ftp:     "ftp://user@10.0.0.1",
		NoProxy: "10.0.3.1,localhost",
	})
}
// TestAsScriptEnvironmentEmpty checks that empty settings produce no script.
func (s *proxySuite) TestAsScriptEnvironmentEmpty(c *gc.C) {
	var proxies proxy.Settings
	c.Assert(proxies.AsScriptEnvironment(), gc.Equals, "")
}
// TestAsScriptEnvironmentOneValue checks the two export lines for one value.
func (s *proxySuite) TestAsScriptEnvironmentOneValue(c *gc.C) {
	proxies := proxy.Settings{Http: "some-value"}
	expected := "export http_proxy=some-value\n" +
		"export HTTP_PROXY=some-value"
	c.Assert(proxies.AsScriptEnvironment(), gc.Equals, expected)
}
// TestAsScriptEnvironmentAllValue checks the order and content of all export lines.
func (s *proxySuite) TestAsScriptEnvironmentAllValue(c *gc.C) {
	proxies := proxy.Settings{
		Http:    "some-value",
		Https:   "special",
		Ftp:     "who uses this?",
		NoProxy: "10.0.3.1,localhost",
	}
	expected := "export http_proxy=some-value\n" +
		"export HTTP_PROXY=some-value\n" +
		"export https_proxy=special\n" +
		"export HTTPS_PROXY=special\n" +
		"export ftp_proxy=who uses this?\n" +
		"export FTP_PROXY=who uses this?\n" +
		"export no_proxy=10.0.3.1,localhost\n" +
		"export NO_PROXY=10.0.3.1,localhost"
	c.Assert(proxies.AsScriptEnvironment(), gc.Equals, expected)
}
// TestAsEnvironmentValuesEmpty checks that empty settings produce no values.
func (s *proxySuite) TestAsEnvironmentValuesEmpty(c *gc.C) {
	var proxies proxy.Settings
	c.Assert(proxies.AsEnvironmentValues(), gc.HasLen, 0)
}
// TestAsEnvironmentValuesOneValue checks both spellings for a single value.
func (s *proxySuite) TestAsEnvironmentValuesOneValue(c *gc.C) {
	proxies := proxy.Settings{Http: "some-value"}
	c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, []string{
		"http_proxy=some-value",
		"HTTP_PROXY=some-value",
	})
}
// TestAsEnvironmentValuesAllValue checks the order and content of every entry.
func (s *proxySuite) TestAsEnvironmentValuesAllValue(c *gc.C) {
	proxies := proxy.Settings{
		Http:    "some-value",
		Https:   "special",
		Ftp:     "who uses this?",
		NoProxy: "10.0.3.1,localhost",
	}
	c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, []string{
		"http_proxy=some-value", "HTTP_PROXY=some-value",
		"https_proxy=special", "HTTPS_PROXY=special",
		"ftp_proxy=who uses this?", "FTP_PROXY=who uses this?",
		"no_proxy=10.0.3.1,localhost", "NO_PROXY=10.0.3.1,localhost",
	})
}
// TestAsSystemdDefaultEnv checks the full systemd snippet, including the
// trailing space after each quoted entry.
func (s *proxySuite) TestAsSystemdDefaultEnv(c *gc.C) {
	proxies := proxy.Settings{
		Http:    "some-value",
		Https:   "special",
		Ftp:     "who uses this?",
		NoProxy: "10.0.3.1,localhost",
	}
	expected := "# To allow juju to control the global systemd proxy settings,\n" +
		"# create symbolic links to this file from within /etc/systemd/system.conf.d/\n" +
		"# and /etc/systemd/users.conf.d/.\n" +
		"[Manager]\n" +
		`DefaultEnvironment="http_proxy=some-value" "HTTP_PROXY=some-value" ` +
		`"https_proxy=special" "HTTPS_PROXY=special" ` +
		`"ftp_proxy=who uses this?" "FTP_PROXY=who uses this?" ` +
		`"no_proxy=10.0.3.1,localhost" "NO_PROXY=10.0.3.1,localhost" ` + "\n"
	c.Assert(proxies.AsSystemdDefaultEnv(), gc.DeepEquals, expected)
}
func (s *proxySuite) TestSetEnvironmentValues(c *gc.C) {
	// Seed every proxy variable so that clearing can be observed.
	for _, name := range []string{
		"http_proxy", "HTTP_PROXY",
		"https_proxy", "HTTPS_PROXY",
		"ftp_proxy", "FTP_PROXY",
		"no_proxy", "NO_PROXY",
	} {
		s.PatchEnvironment(name, "initial")
	}
	proxySettings := proxy.Settings{
		Http:  "http proxy",
		Https: "https proxy",
		// Ftp left blank to show clearing env.
		NoProxy: "10.0.3.1,localhost",
	}
	proxySettings.SetEnvironmentValues()
	c.Assert(proxy.DetectProxies(), gc.DeepEquals, proxySettings)
	for _, check := range []struct{ name, want string }{
		{"http_proxy", "http proxy"},
		{"HTTP_PROXY", "http proxy"},
		{"https_proxy", "https proxy"},
		{"HTTPS_PROXY", "https proxy"},
		{"ftp_proxy", ""},
		{"FTP_PROXY", ""},
		{"no_proxy", "10.0.3.1,localhost"},
		{"NO_PROXY", "10.0.3.1,localhost"},
	} {
		c.Assert(os.Getenv(check.name), gc.Equals, check.want)
	}
}
func (s *proxySuite) TestAutoNoProxy(c *gc.C) {
	proxies := proxy.Settings{NoProxy: "10.0.3.1,localhost"}
	// Without AutoNoProxy only the explicit NoProxy list is emitted.
	c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, []string{
		"no_proxy=10.0.3.1,localhost",
		"NO_PROXY=10.0.3.1,localhost",
	})
	// Once AutoNoProxy is set, its addresses are combined with the explicit list.
	proxies.AutoNoProxy = "10.0.3.1,10.0.3.2"
	c.Assert(proxies.AsEnvironmentValues(), gc.DeepEquals, []string{
		"no_proxy=10.0.3.1,10.0.3.2,localhost",
		"NO_PROXY=10.0.3.1,10.0.3.2,localhost",
	})
}
================================================
FILE: randomstring.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"math/rand"
"sync"
"time"
)
// LowerAlpha holds the lowercase ASCII letters. It, UpperAlpha and Digits
// are convenient values (alone or combined) for RandomString's validRunes.
var LowerAlpha = []rune("abcdefghijklmnopqrstuvwxyz")

// UpperAlpha holds the uppercase ASCII letters.
var UpperAlpha = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

// Digits holds the ASCII decimal digits.
var Digits = []rune("0123456789")
// randomStringRand is the random source shared by every RandomString call.
// A *rand.Rand is not safe for concurrent use, so it is only accessed while
// holding randomStringMu. It is seeded from the clock at package
// initialisation, so output differs between process runs.
var (
	randomStringMu   sync.Mutex
	randomStringRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)

// RandomString returns a string of n runes, each chosen at random from
// validRunes. It is safe to call from multiple goroutines.
//
// A zero or negative n yields the empty string. RandomString panics if n is
// positive and validRunes is empty, because no such string can be built.
//
// The output comes from math/rand and is not suitable for secrets such as
// passwords or tokens.
func RandomString(n int, validRunes []rune) string {
	// make would panic on a negative length; treat it like zero instead.
	if n <= 0 {
		return ""
	}
	if len(validRunes) == 0 {
		panic("utils.RandomString: validRunes must not be empty")
	}
	runes := make([]rune, n)
	randomStringMu.Lock()
	defer randomStringMu.Unlock()
	for i := range runes {
		runes[i] = validRunes[randomStringRand.Intn(len(validRunes))]
	}
	return string(runes)
}
================================================
FILE: randomstring_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Copyright 2015 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/utils/v4"
gc "gopkg.in/check.v1"
)
// randomStringSuite holds the tests for utils.RandomString.
type randomStringSuite struct {
	testing.IsolationSuite
}

// Register the suite with gocheck.
var _ = gc.Suite(&randomStringSuite{})
// validChars is the alphabet the RandomString tests draw from.
var validChars = []rune("thisissorandom")

// length is the number of runes the RandomString tests request.
var length = 7
func (randomStringSuite) TestLength(c *gc.C) {
	// The result must have exactly the requested length.
	c.Assert(utils.RandomString(length, validChars), gc.HasLen, length)
}
func (randomStringSuite) TestContentInValidRunes(c *gc.C) {
	allowed := string(validChars)
	// Every rune of the output must come from the permitted alphabet.
	for _, r := range utils.RandomString(length, validChars) {
		c.Assert(allowed, jc.Contains, string(r))
	}
}
================================================
FILE: registry/export_test.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package registry
// DescriptionFromVersions exposes descriptionFromVersions to the external
// test package.
var DescriptionFromVersions = descriptionFromVersions
================================================
FILE: registry/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package registry_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestAll runs the package's gocheck suites under "go test".
func TestAll(t *testing.T) { gc.TestingT(t) }
================================================
FILE: registry/registry.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package registry
import (
"fmt"
"reflect"
"sort"
"github.com/juju/errors"
)
// TypedNameVersion is a registry of objects keyed by a (name, version)
// pair. Every object is converted to the registry's required type when it
// is registered, so anything later returned by Get can safely be
// type-asserted to that type.
type TypedNameVersion struct {
	// requiredType is the type all registered objects are converted to.
	requiredType reflect.Type
	// versions maps a registered name to its per-version objects.
	versions map[string]Versions
}
// NewTypedNameVersion returns an empty registry whose entries will all be
// converted to requiredType.
func NewTypedNameVersion(requiredType reflect.Type) *TypedNameVersion {
	r := &TypedNameVersion{requiredType: requiredType}
	r.versions = make(map[string]Versions)
	return r
}
// Description summarises one registered name and the versions available
// for it.
type Description struct {
	// Name is the name the objects were registered under.
	Name string
	// Versions lists the registered version numbers; descriptionFromVersions
	// produces them in ascending order.
	Versions []int
}
// Versions maps a version number to the object registered for it.
type Versions = map[int]any
// Register records obj under the given name and version.
//
// obj is converted to the registry's required type before being stored, so
// values returned by Get can be type-asserted to that type. An error is
// returned, and nothing is recorded, when:
//   - obj is nil;
//   - the registry has no required type (e.g. a zero-value TypedNameVersion);
//   - obj cannot be converted to the required type;
//   - an object is already registered for the same name and version.
func (r *TypedNameVersion) Register(name string, version int, obj any) error {
	fullname := fmt.Sprintf("%s(%d)", name, version)
	// reflect.TypeOf(nil) is nil and calling ConvertibleTo on it panics,
	// so untyped nil must be rejected explicitly.
	if obj == nil {
		return fmt.Errorf("cannot register nil object %q", fullname)
	}
	// ConvertibleTo(nil) would also panic.
	if r.requiredType == nil {
		return fmt.Errorf("cannot register object %q: registry has no required type", fullname)
	}
	if !reflect.TypeOf(obj).ConvertibleTo(r.requiredType) {
		return fmt.Errorf("object of type %T cannot be converted to type %s.%s", obj, r.requiredType.PkgPath(), r.requiredType.Name())
	}
	obj = reflect.ValueOf(obj).Convert(r.requiredType).Interface()
	// Allow a zero-value map to be used.
	if r.versions == nil {
		r.versions = make(map[string]Versions, 1)
	}
	versions := r.versions[name]
	if versions == nil {
		versions = make(Versions, 1)
		r.versions[name] = versions
	}
	if _, exists := versions[version]; exists {
		return fmt.Errorf("object %q already registered", fullname)
	}
	versions[version] = obj
	return nil
}
// descriptionFromVersions turns a Versions map into a Description whose
// Versions field lists the map's keys in ascending order. It backs List.
func descriptionFromVersions(name string, versions Versions) Description {
	desc := Description{Name: name, Versions: make([]int, 0, len(versions))}
	for v := range versions {
		desc.Versions = append(desc.Versions, v)
	}
	sort.Ints(desc.Versions)
	return desc
}
// List describes every registered name, ordered by name, with each name's
// versions in ascending order. The result is non-nil even when empty.
func (r *TypedNameVersion) List() []Description {
	names := make([]string, 0, len(r.versions))
	for name := range r.versions {
		names = append(names, name)
	}
	sort.Strings(names)
	result := make([]Description, 0, len(names))
	for _, name := range names {
		result = append(result, descriptionFromVersions(name, r.versions[name]))
	}
	return result
}
// Get returns the object registered for name and version. If there is no
// such object it returns nil and an error satisfying errors.IsNotFound.
func (r *TypedNameVersion) Get(name string, version int) (any, error) {
	// Indexing nil maps is safe and reports "not found".
	obj, found := r.versions[name][version]
	if !found {
		return nil, errors.NotFoundf("%s(%d)", name, version)
	}
	return obj, nil
}
================================================
FILE: registry/registry_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package registry_test
import (
"reflect"
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/registry"
)
type registrySuite struct {
testing.IsolationSuite
}
var _ = gc.Suite(®istrySuite{})
// Factory is the required type of the test registries.
type Factory func() (any, error)

// nilFactory is a Factory-compatible function that returns nothing.
func nilFactory() (any, error) { return nil, nil }

// factoryType is the reflect.Type of Factory.
var factoryType = reflect.TypeOf(Factory(nil))
// testFacade is a minimal facade that records whether it was called.
type testFacade struct {
	version string
	called  bool
}

// stringVal wraps the string produced by testFacade.TestMethod.
type stringVal struct {
	value string
}

// TestMethod marks the facade as called and reports its version.
func (t *testFacade) TestMethod() stringVal {
	t.called = true
	return stringVal{value: "called " + t.version}
}
func (s *registrySuite) TestDescriptionFromVersions(c *gc.C) {
	versions := registry.Versions{0: nilFactory}
	// check asserts the description currently built from versions.
	check := func(want ...int) {
		c.Check(registry.DescriptionFromVersions("name", versions), gc.DeepEquals,
			registry.Description{Name: "name", Versions: want})
	}
	check(0)
	versions[2] = nilFactory
	check(0, 2)
}
func (s *registrySuite) TestDescriptionFromVersionsAreSorted(c *gc.C) {
	versions := make(registry.Versions)
	for _, v := range []int{10, 5, 0, 18, 6, 4} {
		versions[v] = nilFactory
	}
	got := registry.DescriptionFromVersions("name", versions)
	c.Check(got, gc.DeepEquals, registry.Description{
		Name:     "name",
		Versions: []int{0, 4, 5, 6, 10, 18},
	})
}
func (s *registrySuite) TestRegisterAndList(c *gc.C) {
	reg := registry.NewTypedNameVersion(factoryType)
	err := reg.Register("name", 0, nilFactory)
	c.Assert(err, gc.IsNil)
	want := []registry.Description{{Name: "name", Versions: []int{0}}}
	c.Check(reg.List(), gc.DeepEquals, want)
}
func (s *registrySuite) TestRegisterAndListMultiple(c *gc.C) {
	reg := registry.NewTypedNameVersion(factoryType)
	// Register out of name order to show that List sorts by name.
	for _, entry := range []struct {
		name    string
		version int
	}{{"other", 0}, {"name", 0}, {"third", 2}} {
		c.Assert(reg.Register(entry.name, entry.version, nilFactory), gc.IsNil)
	}
	c.Check(reg.List(), gc.DeepEquals, []registry.Description{
		{Name: "name", Versions: []int{0}},
		{Name: "other", Versions: []int{0}},
		{Name: "third", Versions: []int{2}},
	})
}
func (s *registrySuite) TestRegisterWrongType(c *gc.C) {
	reg := registry.NewTypedNameVersion(factoryType)
	// A string is not convertible to Factory.
	c.Check(reg.Register("other", 0, "notAFactory"), gc.ErrorMatches,
		`object of type string cannot be converted to type .*registry_test.Factory`)
}
func (s *registrySuite) TestRegisterAlreadyPresent(c *gc.C) {
	reg := registry.NewTypedNameVersion(factoryType)
	// returning builds a factory yielding the given string.
	returning := func(v string) func() (any, error) {
		return func() (any, error) { return v, nil }
	}
	c.Assert(reg.Register("name", 0, returning("orig")), gc.IsNil)
	err := reg.Register("name", 0, returning("broken"))
	c.Check(err, gc.ErrorMatches, `object "name\(0\)" already registered`)
	// The first registration must have survived the rejected one.
	obj, err := reg.Get("name", 0)
	c.Assert(err, gc.IsNil)
	val, err := obj.(Factory)()
	c.Assert(err, gc.IsNil)
	c.Check(val, gc.Equals, "orig")
}
func (s *registrySuite) TestGet(c *gc.C) {
	reg := registry.NewTypedNameVersion(factoryType)
	c.Assert(reg.Register("name", 0, func() (any, error) { return 10, nil }), gc.IsNil)
	obj, err := reg.Get("name", 0)
	c.Assert(err, gc.IsNil)
	c.Assert(obj, gc.NotNil)
	// The stored value was converted to Factory during Register.
	result, err := obj.(Factory)()
	c.Assert(err, gc.IsNil)
	c.Check(result, gc.Equals, 10)
}
func (s *registrySuite) TestGetUnknown(c *gc.C) {
	// Nothing has been registered at all.
	obj, err := registry.NewTypedNameVersion(factoryType).Get("name", 0)
	c.Check(err, jc.Satisfies, errors.IsNotFound)
	c.Check(err, gc.ErrorMatches, `name\(0\) not found`)
	c.Check(obj, gc.IsNil)
}
func (s *registrySuite) TestGetUnknownVersion(c *gc.C) {
	reg := registry.NewTypedNameVersion(factoryType)
	c.Assert(reg.Register("name", 0, nilFactory), gc.IsNil)
	// The name exists but version 1 does not.
	obj, err := reg.Get("name", 1)
	c.Check(err, jc.Satisfies, errors.IsNotFound)
	c.Check(err, gc.ErrorMatches, `name\(1\) not found`)
	c.Check(obj, gc.IsNil)
}
================================================
FILE: relativeurl.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"strings"
"github.com/juju/errors"
)
// RelativeURLPath computes a relative URL path that, resolved against
// basePath with url.URL.ResolveReference, lexically yields targPath. On
// success the result is never empty and is always relative, even when the
// two paths share no elements.
//
// Both paths are assumed to be normalized (free of "." and ".." elements).
// An error is returned if either path does not start with "/".
func RelativeURLPath(basePath, targPath string) (string, error) {
	switch {
	case !strings.HasPrefix(basePath, "/"):
		return "", errors.New("non-absolute base URL")
	case !strings.HasPrefix(targPath, "/"):
		return "", errors.New("non-absolute target URL")
	}
	base := strings.Split(basePath, "/")
	targ := strings.Split(targPath, "/")

	// The final element of each path plays no part in computing the ".."
	// steps; keep the target's final element to append at the end.
	last := targ[len(targ)-1]
	base = base[:len(base)-1]
	targ = targ[:len(targ)-1]

	// Length of the shared leading run of elements.
	common := 0
	for common < len(base) && common < len(targ) && base[common] == targ[common] {
		common++
	}

	parts := make([]string, 0, len(base)-common+len(targ)-common+1)
	for range base[common:] {
		parts = append(parts, "..")
	}
	parts = append(parts, targ[common:]...)
	parts = append(parts, last)

	if rel := strings.Join(parts, "/"); rel != "" {
		return rel, nil
	}
	// Only an empty final element with nothing before it gets here: the
	// target is slash-terminated at the base's own level, so "." is right.
	return ".", nil
}
================================================
FILE: relativeurl_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"net/url"
jujutesting "github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// relativeURLSuite holds the tests for utils.RelativeURLPath.
type relativeURLSuite struct {
	jujutesting.LoggingSuite
}

// Register the suite with gocheck.
var _ = gc.Suite(&relativeURLSuite{})
// relativeURLTests lists inputs for RelativeURLPath together with either the
// expected relative path or the expected error. Every (base, target) pair is
// distinct; the root-base cases at the end cover a base of "/".
var relativeURLTests = []struct {
	base        string
	target      string
	expect      string
	expectError string
}{{
	expectError: "non-absolute base URL",
}, {
	base:        "/foo",
	expectError: "non-absolute target URL",
}, {
	base:        "foo",
	expectError: "non-absolute base URL",
}, {
	base:        "/foo",
	target:      "foo",
	expectError: "non-absolute target URL",
}, {
	base:   "/foo",
	target: "/bar",
	expect: "bar",
}, {
	base:   "/foo/",
	target: "/bar",
	expect: "../bar",
}, {
	base:   "/bar",
	target: "/foo/",
	expect: "foo/",
}, {
	base:   "/foo/",
	target: "/bar/",
	expect: "../bar/",
}, {
	base:   "/foo/bar",
	target: "/bar/",
	expect: "../bar/",
}, {
	base:   "/foo/bar/",
	target: "/bar/",
	expect: "../../bar/",
}, {
	base:   "/foo/bar/baz",
	target: "/foo/targ",
	expect: "../targ",
}, {
	base:   "/foo/bar/baz/frob",
	target: "/foo/bar/one/two/",
	expect: "../one/two/",
}, {
	base:   "/foo/bar/baz/",
	target: "/foo/targ",
	expect: "../../targ",
}, {
	base:   "/foo/bar/baz/frob/",
	target: "/foo/bar/one/two/",
	expect: "../../one/two/",
}, {
	base:   "/foo/bar",
	target: "/foot/bar",
	expect: "../foot/bar",
}, {
	base:   "/foo/bar/baz/frob",
	target: "/foo/bar",
	expect: "../../bar",
}, {
	base:   "/foo/bar/baz/frob/",
	target: "/foo/bar",
	expect: "../../../bar",
}, {
	base:   "/foo/bar/baz/frob/",
	target: "/foo/bar/",
	expect: "../../",
}, {
	base:   "/foo/bar/baz",
	target: "/foo/bar/other",
	expect: "other",
}, {
	base:   "/foo/bar/",
	target: "/foo/bar/",
	expect: ".",
}, {
	base:   "/foo/bar",
	target: "/foo/bar",
	expect: "bar",
}, {
	base:   "/foo/bar",
	target: "/foo/",
	expect: ".",
}, {
	base:   "/foo",
	target: "/",
	expect: ".",
}, {
	base:   "/foo/",
	target: "/",
	expect: "../",
}, {
	base:   "/foo/bar",
	target: "/",
	expect: "../",
}, {
	base:   "/foo/bar/",
	target: "/",
	expect: "../../",
}, {
	base:   "/",
	target: "/",
	expect: ".",
}, {
	base:   "/",
	target: "/foo",
	expect: "foo",
}, {
	base:   "/",
	target: "/foo/",
	expect: "foo/",
}}
func (*relativeURLSuite) TestRelativeURL(c *gc.C) {
	for i, test := range relativeURLTests {
		c.Logf("test %d: %q %q", i, test.base, test.target)
		wantErr := test.expectError != ""
		if !wantErr {
			// Validate the table entry itself: resolving the expected
			// relative path against the base must give back the target.
			resolved := (&url.URL{Path: test.base}).ResolveReference(&url.URL{Path: test.expect})
			c.Check(resolved.Path, gc.Equals, test.target, gc.Commentf("resolve reference failure (%q + %q != %q)", test.base, test.expect, test.target))
		}
		result, err := utils.RelativeURLPath(test.base, test.target)
		if wantErr {
			c.Assert(err, gc.ErrorMatches, test.expectError)
			c.Assert(result, gc.Equals, "")
			continue
		}
		c.Assert(err, gc.IsNil)
		c.Check(result, gc.Equals, test.expect)
	}
}
================================================
FILE: setenv.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"strings"
)
// Setenv sets an environment variable entry in env, a slice in the format
// returned by os.Environ and accepted by exec.Cmd.Env. entry must have the
// form "name=value"; if it contains no "=", env is returned unchanged.
//
// Every existing element for the same name is overwritten in place, so the
// backing array of env is modified. All of them are overwritten, not just
// the first, because os/exec uses the last value when a name appears more
// than once, and a stale duplicate would otherwise win. If no element for
// the name exists, entry is appended.
//
// The resulting slice is returned.
func Setenv(env []string, entry string) []string {
	eq := strings.Index(entry, "=")
	if eq < 0 {
		return env
	}
	// Include the "=" so that "foo" does not match "foobar=...".
	prefix := entry[:eq+1]
	replaced := false
	for i, e := range env {
		if strings.HasPrefix(e, prefix) {
			env[i] = entry
			replaced = true
		}
	}
	if !replaced {
		env = append(env, entry)
	}
	return env
}
================================================
FILE: setenv_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// SetenvSuite holds the tests for utils.Setenv.
type SetenvSuite struct{}

// Register the suite with gocheck.
var _ = gc.Suite(&SetenvSuite{})
// setenvTests pairs the entry passed to Setenv with the environment
// expected afterwards; TestSetenv starts each case from
// {"foo=bar", "arble="}.
var setenvTests = []struct {
	set    string
	expect []string
}{
	{set: "foo=1", expect: []string{"foo=1", "arble="}},
	{set: "foo=", expect: []string{"foo=", "arble="}},
	{set: "arble=23", expect: []string{"foo=bar", "arble=23"}},
	{set: "zaphod=42", expect: []string{"foo=bar", "arble=", "zaphod=42"}},
	{set: "bar", expect: []string{"foo=bar", "arble="}},
}
func (*SetenvSuite) TestSetenv(c *gc.C) {
	initial := []string{"foo=bar", "arble="}
	for n, test := range setenvTests {
		c.Logf("test %d", n)
		// Setenv may write into its argument, so give each case a private
		// copy of the initial environment.
		env := append([]string(nil), initial...)
		c.Check(utils.Setenv(env, test.set), gc.DeepEquals, test.expect)
	}
}
================================================
FILE: shell/bash.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"strings"
)
// BashRenderer renders paths and commands for the bash shell. Its
// Unix-generic behavior comes from the embedded unixRenderer.
type BashRenderer struct {
	unixRenderer
}
// RenderScript implements ScriptWriter. It prefixes the commands with a
// bash shebang line and a blank line, then joins everything with newlines.
// The caller's slice is not modified.
func (*BashRenderer) RenderScript(commands []string) []byte {
	lines := make([]string, 0, len(commands)+2)
	lines = append(lines, "#!/usr/bin/env bash", "")
	lines = append(lines, commands...)
	return []byte(strings.Join(lines, "\n"))
}
================================================
FILE: shell/bash_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test
import (
"os"
"time"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/shell"
)
// bashSuite exercises shell.BashRenderer against a fixed directory and file.
type bashSuite struct {
	testing.IsolationSuite

	dirname  string
	filename string
	renderer *shell.BashRenderer
}

// Register the suite with gocheck.
var _ = gc.Suite(&bashSuite{})
func (s *bashSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	const dir = `/some/dir`
	s.dirname = dir
	s.filename = dir + `/file`
	s.renderer = &shell.BashRenderer{}
}
func (s bashSuite) TestExeSuffix(c *gc.C) {
	// Executables need no filename suffix under bash.
	c.Check(s.renderer.ExeSuffix(), gc.Equals, "")
}
func (s bashSuite) TestShQuote(c *gc.C) {
	c.Check(s.renderer.Quote("abc"), gc.Equals, `'abc'`)
}
func (s bashSuite) TestChmod(c *gc.C) {
	want := []string{"chmod 0644 '/some/dir/file'"}
	c.Check(s.renderer.Chmod(s.filename, 0644), jc.DeepEquals, want)
}
func (s bashSuite) TestWriteFile(c *gc.C) {
	commands := s.renderer.WriteFile(s.filename, []byte("something\nhere\n"))
	// The content is written through a heredoc with a quoted delimiter.
	want := `cat > '/some/dir/file' << 'EOF'
something
here
EOF`
	c.Check(commands, jc.DeepEquals, []string{want})
}
func (s bashSuite) TestMkdir(c *gc.C) {
	want := []string{`mkdir '/some/dir'`}
	c.Check(s.renderer.Mkdir(s.dirname), jc.DeepEquals, want)
}
func (s bashSuite) TestMkdirAll(c *gc.C) {
	// -p makes mkdir create missing parents.
	want := []string{`mkdir -p '/some/dir'`}
	c.Check(s.renderer.MkdirAll(s.dirname), jc.DeepEquals, want)
}
func (s bashSuite) TestChown(c *gc.C) {
	want := []string{"chown x:y '/a/b/c'"}
	c.Check(s.renderer.Chown("/a/b/c", "x", "y"), jc.DeepEquals, want)
}
func (s bashSuite) TestTouchDefault(c *gc.C) {
	// A nil timestamp means no explicit -t argument.
	want := []string{"touch '/a/b/c'"}
	c.Check(s.renderer.Touch("/a/b/c", nil), jc.DeepEquals, want)
}
func (s bashSuite) TestTouchTimestamp(c *gc.C) {
	stamp := time.Date(2015, time.March, 14, 12, 26, 38, 0, time.UTC)
	want := []string{"touch -t 201503141226.38 '/a/b/c'"}
	c.Check(s.renderer.Touch("/a/b/c", &stamp), jc.DeepEquals, want)
}
func (s bashSuite) TestRedirectFD(c *gc.C) {
	// stderr (FD 2) is sent to stdout (FD 1).
	want := []string{"exec 2>&1"}
	c.Check(s.renderer.RedirectFD("stdout", "stderr"), jc.DeepEquals, want)
}
func (s bashSuite) TestRedirectOutput(c *gc.C) {
	// ">>" appends to the file.
	want := []string{"exec >> '/a/b/c'"}
	c.Check(s.renderer.RedirectOutput("/a/b/c"), jc.DeepEquals, want)
}
func (s bashSuite) TestRedirectOutputReset(c *gc.C) {
	// ">" truncates the file first.
	want := []string{"exec > '/a/b/c'"}
	c.Check(s.renderer.RedirectOutputReset("/a/b/c"), jc.DeepEquals, want)
}
func (s bashSuite) TestScriptFilename(c *gc.C) {
	c.Check(s.renderer.ScriptFilename("spam", "/ham/eggs"), gc.Equals, "/ham/eggs/spam.sh")
}
func (s bashSuite) TestScriptPermissions(c *gc.C) {
	// Scripts must be executable by everyone and writable by the owner.
	c.Check(s.renderer.ScriptPermissions(), gc.Equals, os.FileMode(0o755))
}
================================================
FILE: shell/command.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"os"
"time"
)
// CommandRenderer provides methods that may be used to generate shell
// commands for a variety of shell and filesystem operations. Each method
// returns the generated commands as a slice, which may be empty when the
// operation has no equivalent in the target shell.
type CommandRenderer interface {
	// Chown returns a shell command for changing the ownership of
	// a file or directory. This mirrors the behavior of os.Chown,
	// though user and group may also be given as names rather than
	// only as numeric IDs.
	Chown(name, user, group string) []string

	// Chmod returns a shell command that sets the given file's
	// permissions. The result is equivalent to os.Chmod.
	Chmod(path string, perm os.FileMode) []string

	// WriteFile returns a shell command that writes the provided
	// content to a file. The command is functionally equivalent to
	// os.WriteFile with permissions from the current umask.
	WriteFile(filename string, data []byte) []string

	// Mkdir returns a shell command for creating a directory. The
	// command is functionally equivalent to os.Mkdir using permissions
	// appropriate for a directory.
	Mkdir(dirname string) []string

	// MkdirAll returns a shell command for creating a directory and
	// all missing parent directories. The command is functionally
	// equivalent to os.MkdirAll using permissions appropriate for
	// a directory.
	MkdirAll(dirname string) []string

	// Touch returns a shell command that updates the access and
	// modification times (atime and mtime) of the named file. If the
	// provided timestamp is nil then the current time is used. If the
	// file does not exist then it is created. If UTC is desired then
	// Time.UTC() should be called before calling Touch.
	Touch(filename string, timestamp *time.Time) []string
}
================================================
FILE: shell/interface_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
// Compile-time checks that every concrete renderer satisfies Renderer.
var (
	_ Renderer = (*BashRenderer)(nil)
	_ Renderer = (*PowershellRenderer)(nil)
	_ Renderer = (*WinCmdRenderer)(nil)
)
================================================
FILE: shell/output.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"strconv"
"strings"
)
// OutputRenderer exposes the Renderer methods that relate to shell output.
//
// The methods all accept strings to identify their file descriptor
// arguments. While the interpretation of these values is up to the
// renderer, it will likely conform to the result of calling ResolveFD.
// If an FD arg is not recognized then no commands will be returned.
// Unless otherwise specified, the default file descriptor is stdout
// (FD 1). This applies to the empty string.
type OutputRenderer interface {
	// RedirectFD returns a shell command that redirects the src
	// file descriptor to the dst one.
	RedirectFD(dst, src string) []string

	// TODO(ericsnow) Add CopyFD and CreateFD?
	// TODO(ericsnow) Support passing the src FD as an arg?

	// RedirectOutput will cause all subsequent output from the shell
	// (or script) to be appended to the given file. Only stdout is
	// redirected (use RedirectFD to redirect stderr or other FDs).
	//
	// The file should already exist (so a call to Touch may be
	// necessary before calling RedirectOutput). If the file should have
	// specific permissions or a specific owner then Chmod and Chown
	// should be called before calling RedirectOutput.
	RedirectOutput(filename string) []string

	// RedirectOutputReset will cause all subsequent output from the
	// shell (or script) to be written to the given file. The file
	// will be reset (truncated to 0) before anything is written. Only
	// stdout is redirected (use RedirectFD to redirect stderr or
	// other FDs).
	//
	// The file should already exist (so a call to Touch may be
	// necessary before calling RedirectOutputReset). If the file should
	// have specific permissions or a specific owner then Chmod and
	// Chown should be called before calling RedirectOutputReset.
	RedirectOutputReset(filename string) []string
}
// ResolveFD converts the file descriptor name to the corresponding int.
// Names are matched case-insensitively:
//   - "stdout" and "out" match stdout (FD 1);
//   - "stderr" and "err" match stderr (FD 2);
//   - "stdin" and "in" match stdin (FD 0);
//   - the empty string defaults to stdout.
//
// Any other name made up only of decimal digits is parsed as the FD
// number, provided it fits in an int. Signs are not accepted, so negative
// values are rejected. If there should be an upper bound then the caller
// should check it on the result. If the name is not recognized, -1 and
// false are returned.
func ResolveFD(name string) (int, bool) {
	switch strings.ToLower(name) {
	case "stdout", "out", "":
		return 1, true
	case "stderr", "err":
		return 2, true
	case "stdin", "in":
		return 0, true
	}
	const maxInt = int(^uint(0) >> 1)
	fd, err := strconv.ParseUint(name, 10, 64)
	// Reject values that would wrap to a negative int on conversion.
	if err != nil || fd > uint64(maxInt) {
		return -1, false
	}
	return int(fd), true
}
================================================
FILE: shell/package_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// Test runs the package's gocheck suites under "go test".
func Test(t *testing.T) { gc.TestingT(t) }
================================================
FILE: shell/powershell.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"encoding/base64"
"fmt"
"os"
"golang.org/x/text/encoding/unicode"
"github.com/juju/errors"
"github.com/juju/utils/v4"
)
// PowershellRenderer renders paths and commands for Windows PowerShell. Its
// Windows-generic behavior comes from the embedded windowsRenderer.
type PowershellRenderer struct {
	windowsRenderer
}
// Quote implements Renderer by delegating to utils.WinPSQuote.
func (pr *PowershellRenderer) Quote(str string) string {
	quoted := utils.WinPSQuote(str)
	return quoted
}
// Chmod implements Renderer. No command is generated for PowerShell, so the
// result is always nil.
func (pr *PowershellRenderer) Chmod(path string, perm os.FileMode) []string {
	// TODO(ericsnow) Is this necessary? Should we use Set-Acl?
	var commands []string
	return commands
}
// WriteFile implements Renderer. It emits one Set-Content command that
// passes data through a double-quoted here-string (@"..."@). Because data is
// followed by its own newline, data ending in "\n" yields an empty line
// before the closing "@.
//
// NOTE(review): double-quoted here-strings are expandable in PowerShell, so
// "$" and backtick sequences in data are likely interpreted rather than
// written literally, and a data line starting with "@ would end the
// here-string early. A single-quoted here-string (@'...'@) would avoid the
// expansion, but switching changes the generated command; confirm callers
// and tests first.
func (pr *PowershellRenderer) WriteFile(filename string, data []byte) []string {
	cmd := fmt.Sprintf("Set-Content %s @\"\n%s\n\"@", pr.Quote(filename), data)
	return []string{cmd}
}
// Mkdir implements Renderer. The path is converted with FromSlash before
// being quoted.
func (pr *PowershellRenderer) Mkdir(dirname string) []string {
	native := pr.FromSlash(dirname)
	return []string{"mkdir " + pr.Quote(native)}
}
// MkdirAll implements Renderer. It emits exactly the same command as Mkdir.
// NOTE(review): this assumes PowerShell's mkdir already creates missing
// parent directories — confirm.
func (pr *PowershellRenderer) MkdirAll(dirname string) []string {
	commands := pr.Mkdir(dirname)
	return commands
}
// ScriptFilename implements ScriptWriter. The script is named name+".ps1"
// and placed in dirname using the renderer's Join.
func (pr *PowershellRenderer) ScriptFilename(name, dirname string) string {
	base := name + ".ps1"
	return pr.Join(dirname, base)
}
// psRemoteWrapper is the command template used to run a base64 encoded
// script with powershell.exe; its single %s is replaced by the encoded
// script (see NewPSEncodedCommand). WinRM executes commands using cmd by
// default, so commands sent over WinRM are handed explicitly to
// powershell.exe.
//
// The flags are:
//
//	-Sta             - run PowerShell in a single-threaded apartment.
//	-NonInteractive  - never prompt, so a prompt cannot stall the script.
//	-ExecutionPolicy - set the execution policy (RemoteSigned) for this
//	                   invocation, regardless of the system's default.
//	-EncodedCommand  - accept the script base64 encoded, which spares us
//	                   from quoting/escaping shell special characters.
const psRemoteWrapper = "powershell.exe -Sta -NonInteractive -ExecutionPolicy RemoteSigned -EncodedCommand %s"
// newEncodedPSScript encodes script as UTF-16LE (without a byte order mark)
// and returns the standard base64 form of those bytes, which is what
// powershell.exe -EncodedCommand expects.
func newEncodedPSScript(script string) (string, error) {
	utf16le := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)
	raw, err := utf16le.NewEncoder().String(script)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString([]byte(raw)), nil
}
// NewPSEncodedCommand returns a powershell.exe command line that runs
// script via -EncodedCommand, with the script UTF-16LE and base64 encoded.
// The result can be used both locally and remotely over WinRM.
func NewPSEncodedCommand(script string) (string, error) {
	encoded, err := newEncodedPSScript(script)
	if err != nil {
		return "", errors.Annotatef(err, "Cannot construct powershell command for remote execution")
	}
	return fmt.Sprintf(psRemoteWrapper, encoded), nil
}
================================================
FILE: shell/powershell_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/shell"
)
// powershellSuite exercises shell.PowershellRenderer against a fixed
// Windows directory and file.
type powershellSuite struct {
	testing.IsolationSuite

	dirname  string
	filename string
	renderer *shell.PowershellRenderer
}

// Register the suite with gocheck.
var _ = gc.Suite(&powershellSuite{})
func (s *powershellSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	const dir = `C:\some\dir`
	s.dirname = dir
	s.filename = dir + `\file`
	s.renderer = &shell.PowershellRenderer{}
}
func (s powershellSuite) TestExeSuffix(c *gc.C) {
	// Windows executables carry the .exe suffix.
	c.Check(s.renderer.ExeSuffix(), gc.Equals, ".exe")
}
func (s powershellSuite) TestShQuote(c *gc.C) {
	c.Check(s.renderer.Quote("abc"), gc.Equals, `'abc'`)
}
func (s powershellSuite) TestChmod(c *gc.C) {
	// PowerShell rendering emits no chmod equivalent.
	c.Check(s.renderer.Chmod(s.filename, 0644), gc.HasLen, 0)
}
func (s powershellSuite) TestWriteFile(c *gc.C) {
	data := []byte("something\nhere\n")
	commands := s.renderer.WriteFile(s.filename, data)
	// WriteFile emits "@\"\n" + data + "\n\"@", so the trailing newline of
	// data produces an empty line before the closing "@. The expectation is
	// spelled with escapes so that empty line cannot be lost to reformatting
	// of a raw literal.
	expected := "Set-Content 'C:\\some\\dir\\file' @\"\nsomething\nhere\n\n\"@"
	c.Check(commands, jc.DeepEquals, []string{expected})
}
func (s powershellSuite) TestMkdir(c *gc.C) {
	want := []string{`mkdir 'C:\some\dir'`}
	c.Check(s.renderer.Mkdir(s.dirname), jc.DeepEquals, want)
}
func (s powershellSuite) TestMkdirAll(c *gc.C) {
	// MkdirAll renders the same command as Mkdir.
	want := []string{`mkdir 'C:\some\dir'`}
	c.Check(s.renderer.MkdirAll(s.dirname), jc.DeepEquals, want)
}
func (s powershellSuite) TestNewPSEncodedCommand(c *gc.C) {
	// The expected payload decodes (UTF-16LE) to a leading newline, a tab,
	// the command and a trailing newline. The script is spelled with escapes
	// so that those exact bytes cannot be altered by reformatting of a raw
	// literal.
	script := "\n\tGet-WmiObject win32_processor\n"
	expected := "powershell.exe -Sta -NonInteractive -ExecutionPolicy RemoteSigned -EncodedCommand CgAJAEcAZQB0AC0AVwBtAGkATwBiAGoAZQBjAHQAIAB3AGkAbgAzADIAXwBwAHIAbwBjAGUAcwBzAG8AcgAKAA=="
	out, err := shell.NewPSEncodedCommand(script)
	c.Assert(err, gc.IsNil)
	c.Assert(out, gc.Equals, expected)
}
================================================
FILE: shell/renderer.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"runtime"
"strings"
"github.com/juju/errors"
"github.com/juju/utils/v4"
"github.com/juju/utils/v4/filepath"
)
// PathRenderer produces paths, quoted strings and executable names that
// suit one particular shell environment.
type PathRenderer interface {
	filepath.Renderer

	// Quote returns str wrapped in quotation marks, with any characters
	// the shell treats specially escaped, so that the shell reads the
	// result as one single string.
	Quote(str string) string

	// ExeSuffix returns the file-name suffix carried by executables.
	ExeSuffix() string
}
// Renderer bundles path, command and output rendering: everything needed
// to generate shell-compatible paths and commands for a single shell.
type Renderer interface {
	PathRenderer
	CommandRenderer
	OutputRenderer
}
// NewRenderer returns the Renderer that matches name.
//
// name may be a shell ("bash", "ps"/"powershell", "cmd"/"batch"/"bat"),
// an operating system ("windows" or any OS accepted by utils.OSIsUnix)
// or a distro ("ubuntu"), and is matched case-insensitively. Shell names
// take priority over OS names, which take priority over distros. An
// empty name selects the renderer for runtime.GOOS. A NotFound error is
// returned when nothing matches.
func NewRenderer(name string) (Renderer, error) {
	key := runtime.GOOS
	if name != "" {
		key = strings.ToLower(name)
	}

	switch {
	// Known shell names.
	case key == "bash":
		return &BashRenderer{}, nil
	case key == "ps", key == "powershell":
		return &PowershellRenderer{}, nil
	case key == "cmd", key == "batch", key == "bat":
		return &WinCmdRenderer{}, nil
	// Operating systems.
	case key == "windows":
		return &PowershellRenderer{}, nil
	case utils.OSIsUnix(key):
		return &BashRenderer{}, nil
	// Distros.
	case key == "ubuntu":
		return &BashRenderer{}, nil
	}
	return nil, errors.NotFoundf("renderer for %q", key)
}
================================================
FILE: shell/renderer_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test
import (
"runtime"
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
"github.com/juju/utils/v4/shell"
)
// rendererSuite checks which concrete renderer NewRenderer selects.
type rendererSuite struct {
	testing.IsolationSuite

	// unix and windows are sample renderers used only for type comparison.
	unix    *shell.BashRenderer
	windows *shell.PowershellRenderer
}
// Register the renderer-selection tests with gocheck.
var _ = gc.Suite(new(rendererSuite))
// SetUpTest creates the sample renderers used for type comparisons.
func (s *rendererSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	s.unix, s.windows = &shell.BashRenderer{}, &shell.PowershellRenderer{}
}
// checkRenderer asserts that renderer has the concrete type named by
// expected ("powershell" or "bash"); any other kind is a test error.
func (s rendererSuite) checkRenderer(c *gc.C, renderer shell.Renderer, expected string) {
	var want interface{}
	switch expected {
	case "powershell":
		want = s.windows
	case "bash":
		want = s.unix
	default:
		c.Errorf("unknown kind %q", expected)
		return
	}
	c.Check(renderer, gc.FitsTypeOf, want)
}
// TestNewRendererDefault checks that an empty name falls back to the
// renderer for the current runtime.GOOS.
func (s rendererSuite) TestNewRendererDefault(c *gc.C) {
	renderer, err := shell.NewRenderer("")
	c.Assert(err, jc.ErrorIsNil)

	kind := "bash"
	if runtime.GOOS == "windows" {
		kind = "powershell"
	}
	s.checkRenderer(c, renderer, kind)
}
// TestNewRendererGOOS checks that the current runtime.GOOS is accepted
// explicitly by name.
func (s rendererSuite) TestNewRendererGOOS(c *gc.C) {
	renderer, err := shell.NewRenderer(runtime.GOOS)
	c.Assert(err, jc.ErrorIsNil)

	kind := "bash"
	if runtime.GOOS == "windows" {
		kind = "powershell"
	}
	s.checkRenderer(c, renderer, kind)
}
// TestNewRendererWindows checks that "windows" maps to PowerShell.
func (s rendererSuite) TestNewRendererWindows(c *gc.C) {
	r, err := shell.NewRenderer("windows")
	c.Assert(err, jc.ErrorIsNil)
	s.checkRenderer(c, r, "powershell")
}
// TestNewRendererUnix checks that every unix-like OS maps to bash.
func (s rendererSuite) TestNewRendererUnix(c *gc.C) {
	for _, goos := range utils.OSUnix {
		c.Logf("trying %q", goos)
		r, err := shell.NewRenderer(goos)
		c.Assert(err, jc.ErrorIsNil)
		s.checkRenderer(c, r, "bash")
	}
}
// TestNewRendererDistros checks that known distro names map to bash.
func (s rendererSuite) TestNewRendererDistros(c *gc.C) {
	for _, name := range []string{"ubuntu"} {
		c.Logf("trying %q", name)
		r, err := shell.NewRenderer(name)
		c.Assert(err, jc.ErrorIsNil)
		s.checkRenderer(c, r, "bash")
	}
}
// TestNewRendererUnknown checks that a name matching no known shell,
// operating system or distro yields a NotFound error.
//
// An empty name cannot be used here: NewRenderer treats "" as
// runtime.GOOS, which is always supported, so no error would be returned.
func (s rendererSuite) TestNewRendererUnknown(c *gc.C) {
	_, err := shell.NewRenderer("<unknown OS>")
	c.Check(err, jc.Satisfies, errors.IsNotFound)
}
================================================
FILE: shell/script.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"fmt"
"os"
"github.com/juju/utils/v4"
)
// DumpFileOnErrorScript returns a bash snippet that installs an EXIT trap.
// When the shell exits with a non-zero status and filename exists, the
// trap copies the file's contents to stderr and then exits with the same
// status. filename is shell-quoted before being embedded in the snippet.
func DumpFileOnErrorScript(filename string) string {
	quoted := utils.ShQuote(filename)
	const template = `
dump_file() {
code=$?
if [ $code -ne 0 -a -e %s ]; then
cat %s >&2
fi
exit $code
}
trap dump_file EXIT
`
	// Drop the leading newline that only exists for readability above.
	return fmt.Sprintf(template[1:], quoted, quoted)
}
// ScriptRenderer turns a sequence of shell commands into the content of a
// shell script.
type ScriptRenderer interface {
	// RenderScript returns the script content that runs commands.
	RenderScript(commands []string) []byte
}
// A ScriptWriter provides the functionality necessary to render and
// write a sequence of shell commands to a shell script that is ready
// to be run.
type ScriptWriter interface {
	ScriptRenderer

	// Chmod returns shell commands that set the given file's
	// permissions. The result is equivalent to os.Chmod.
	Chmod(path string, perm os.FileMode) []string

	// WriteFile returns shell commands that write the provided
	// content to a file. The commands are functionally equivalent to
	// ioutil.WriteFile with permissions from the current umask.
	WriteFile(filename string, data []byte) []string

	// ScriptFilename generates a filename appropriate for a script
	// from the provided file and directory names.
	ScriptFilename(name, dirname string) string

	// ScriptPermissions returns the permissions appropriate for a script.
	ScriptPermissions() os.FileMode
}
// WriteScript returns the shell commands that save script as an executable
// script file. The file path is built by renderer.ScriptFilename from
// dirname and name (which adds the shell's script suffix). The content
// comes from renderer.RenderScript (which adds any preamble such as a
// shebang line). The file mode is set to renderer.ScriptPermissions.
func WriteScript(renderer ScriptWriter, name, dirname string, script []string) []string {
	path := renderer.ScriptFilename(name, dirname)
	mode := renderer.ScriptPermissions()
	content := renderer.RenderScript(script)

	var cmds []string
	cmds = append(cmds, renderer.WriteFile(path, content)...)
	return append(cmds, renderer.Chmod(path, mode)...)
}
================================================
FILE: shell/script_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test
import (
"bytes"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/shell"
)
// scriptSuite tests the script helpers in the shell package.
type scriptSuite struct {
	testing.IsolationSuite
}
// Register the script helper tests with gocheck.
var _ = gc.Suite(new(scriptSuite))
// TestDumpFileOnErrorScriptOutput checks the exact script text, including
// quoting of a filename containing spaces.
func (*scriptSuite) TestDumpFileOnErrorScriptOutput(c *gc.C) {
	got := shell.DumpFileOnErrorScript("a b c")
	want := `
dump_file() {
code=$?
if [ $code -ne 0 -a -e 'a b c' ]; then
cat 'a b c' >&2
fi
exit $code
}
trap dump_file EXIT
`[1:]
	c.Assert(got, gc.Equals, want)
}
// TestDumpFileOnErrorScript runs the generated trap under bash and checks
// that the file is dumped only on a failing exit while it exists.
func (*scriptSuite) TestDumpFileOnErrorScript(c *gc.C) {
	logFile := filepath.Join(c.MkDir(), "log.txt")
	err := ioutil.WriteFile(logFile, []byte("abc"), 0644)
	c.Assert(err, gc.IsNil)

	dumpScript := shell.DumpFileOnErrorScript(logFile)
	c.Logf("%s", dumpScript)

	// runBash feeds the dump script followed by command to bash on stdin
	// and returns whatever bash wrote to stdout and stderr.
	runBash := func(command string) (string, string) {
		var outBuf, errBuf bytes.Buffer
		cmd := exec.Command("/bin/bash", "-s")
		cmd.Stdin = strings.NewReader(dumpScript + command)
		cmd.Stdout, cmd.Stderr = &outBuf, &errBuf
		// The exit status is ignored on purpose: some cases exit non-zero
		// by design, and only the captured output is checked.
		_ = cmd.Run()
		return outBuf.String(), errBuf.String()
	}

	// A successful exit dumps nothing.
	stdout, stderr := runBash("exit 0")
	c.Assert(stdout, gc.Equals, "")
	c.Assert(stderr, gc.Equals, "")

	// A failing exit dumps the file to stderr.
	stdout, stderr = runBash("exit 1")
	c.Assert(stdout, gc.Equals, "")
	c.Assert(stderr, gc.Equals, "abc")

	// Once the file is gone, a failing exit dumps nothing.
	err = os.Remove(logFile)
	c.Assert(err, gc.IsNil)
	stdout, stderr = runBash("exit 1")
	c.Assert(stdout, gc.Equals, "")
	c.Assert(stderr, gc.Equals, "")
}
// TestWriteScriptUnix checks the heredoc and chmod commands produced for
// a bash script.
func (*scriptSuite) TestWriteScriptUnix(c *gc.C) {
	script := `
exec a-command
exec another-command
`
	lines := strings.Split(script, "\n")
	got := shell.WriteScript(&shell.BashRenderer{}, "spam", "/ham/eggs", lines)

	heredoc := `
cat > '/ham/eggs/spam.sh' << 'EOF'
#!/usr/bin/env bash
exec a-command
exec another-command
EOF`[1:]
	want := []string{heredoc, "chmod 0755 '/ham/eggs/spam.sh'"}
	c.Check(got, jc.DeepEquals, want)
}
// TestWriteScriptWindows checks the single Set-Content command produced
// for a PowerShell script (PowerShell renders no chmod).
func (*scriptSuite) TestWriteScriptWindows(c *gc.C) {
	script := `
exec a-command
exec another-command
`
	lines := strings.Split(script, "\n")
	got := shell.WriteScript(&shell.PowershellRenderer{}, "spam", `C:\ham\eggs`, lines)

	setContent := `Set-Content 'C:\ham\eggs\spam.ps1' @"
exec a-command
exec another-command
"@`
	c.Check(got, jc.DeepEquals, []string{setContent})
}
================================================
FILE: shell/unix.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"fmt"
"os"
"time"
"github.com/juju/utils/v4"
"github.com/juju/utils/v4/filepath"
)
// unixRenderer holds the rendering behaviour shared by "unix" shells.
// Path handling comes from the embedded filepath.UnixRenderer.
type unixRenderer struct {
	filepath.UnixRenderer
}
// Quote implements Renderer using utils.ShQuote.
//
// NOTE: sh-style quoting *may* not be correct for *all* unix shells.
func (unixRenderer) Quote(str string) string {
	quoted := utils.ShQuote(str)
	return quoted
}
// ExeSuffix implements Renderer. Unix executables have no file-name
// suffix, so the empty string is returned.
func (unixRenderer) ExeSuffix() string {
	const noSuffix = ""
	return noSuffix
}
// Mkdir implements Renderer. It returns a single "mkdir" command for the
// quoted directory name; parent directories are not created.
func (ur unixRenderer) Mkdir(dirname string) []string {
	quoted := ur.Quote(dirname)
	cmd := fmt.Sprintf("mkdir %s", quoted)
	return []string{cmd}
}
// MkdirAll implements Renderer. It returns a single "mkdir -p" command,
// which also creates any missing parent directories.
func (ur unixRenderer) MkdirAll(dirname string) []string {
	quoted := ur.Quote(dirname)
	cmd := fmt.Sprintf("mkdir -p %s", quoted)
	return []string{cmd}
}
// Chmod implements Renderer. It returns a single "chmod" command that sets
// the mode of path, written as a (at least) four-digit octal number.
//
// The low twelve bits of perm (the permission bits plus the classic
// 04000/02000/01000 bits) are passed through unchanged. Go's own
// os.ModeSetuid, os.ModeSetgid and os.ModeSticky flags, which live in the
// high bits of os.FileMode, are translated to their chmod equivalents.
// Any other type bits (for example os.ModeDir from an os.Stat result) are
// dropped, because chmod cannot accept them.
func (ur unixRenderer) Chmod(path string, perm os.FileMode) []string {
	mode := uint32(perm) & 07777
	if perm&os.ModeSetuid != 0 {
		mode |= 04000
	}
	if perm&os.ModeSetgid != 0 {
		mode |= 02000
	}
	if perm&os.ModeSticky != 0 {
		mode |= 01000
	}
	path = ur.Quote(path)
	return []string{
		fmt.Sprintf("chmod %04o %s", mode, path),
	}
}
// Chown implements Renderer. It returns a single "chown owner:group"
// command for the quoted path. owner and group are inserted unquoted.
func (ur unixRenderer) Chown(path, owner, group string) []string {
	quoted := ur.Quote(path)
	cmd := fmt.Sprintf("chown %s:%s %s", owner, group, quoted)
	return []string{cmd}
}
// Touch implements Renderer. It returns a single "touch" command for the
// quoted path. When timestamp is non-nil, it is passed as "-t" in
// YYYYMMDDhhmm.ss form, formatted in the timestamp's own location.
func (ur unixRenderer) Touch(path string, timestamp *time.Time) []string {
	quoted := ur.Quote(path)
	opt := ""
	if timestamp != nil {
		opt = "-t " + timestamp.Format("200601021504.05") + " "
	}
	return []string{fmt.Sprintf("touch %s%s", opt, quoted)}
}
// WriteFile implements Renderer. It returns a single command that writes
// data to filename through a heredoc. The delimiter is quoted ('EOF'), so
// the shell performs no expansion inside data.
//
// NOTE: a line of data consisting of exactly "EOF" would end the heredoc
// early. printf would be an alternative approach.
func (ur unixRenderer) WriteFile(filename string, data []byte) []string {
	target := ur.Quote(filename)
	heredoc := fmt.Sprintf("cat > %s << 'EOF'\n%s\nEOF", target, data)
	return []string{heredoc}
}
// outFD resolves name with ResolveFD and reports whether it is usable as
// an output descriptor, i.e. resolves to a descriptor greater than zero.
// On failure it returns (-1, false).
func (unixRenderer) outFD(name string) (int, bool) {
	if fd, ok := ResolveFD(name); ok && fd > 0 {
		return fd, true
	}
	return -1, false
}
// RedirectFD implements OutputRenderer. It returns "exec SRC>&DST", which
// makes descriptor src a duplicate of descriptor dst for the rest of the
// script. If either name does not resolve to an output descriptor, no
// commands are returned.
func (ur unixRenderer) RedirectFD(dst, src string) []string {
	var dstFD, srcFD int
	var ok bool
	if dstFD, ok = ur.outFD(dst); !ok {
		return nil
	}
	if srcFD, ok = ur.outFD(src); !ok {
		return nil
	}
	return []string{fmt.Sprintf("exec %d>&%d", srcFD, dstFD)}
}
// RedirectOutput implements OutputRenderer. The returned "exec >>" command
// appends the script's subsequent stdout to filename.
func (ur unixRenderer) RedirectOutput(filename string) []string {
	target := ur.Quote(filename)
	return []string{"exec >> " + target}
}
// RedirectOutputReset implements OutputRenderer. The returned "exec >"
// command truncates filename and sends the script's subsequent stdout
// to it.
func (ur unixRenderer) RedirectOutputReset(filename string) []string {
	target := ur.Quote(filename)
	return []string{"exec > " + target}
}
// ScriptFilename implements ScriptWriter. Unix scripts are named
// "<name>.sh" inside dirname.
func (ur *unixRenderer) ScriptFilename(name, dirname string) string {
	base := name + ".sh"
	return ur.Join(dirname, base)
}
// ScriptPermissions implements ScriptWriter. Scripts are made readable and
// executable by everyone and writable by the owner (rwxr-xr-x).
func (ur *unixRenderer) ScriptPermissions() os.FileMode {
	const scriptMode os.FileMode = 0755
	return scriptMode
}
================================================
FILE: shell/win.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"os"
"strings"
"time"
"github.com/juju/utils/v4/filepath"
)
// windowsRenderer holds the rendering behaviour shared by Windows shells.
// Path handling comes from the embedded filepath.WindowsRenderer.
type windowsRenderer struct {
	filepath.WindowsRenderer
}
// ExeSuffix implements Renderer. Windows executables end in ".exe".
func (w *windowsRenderer) ExeSuffix() string {
	const suffix = ".exe"
	return suffix
}
// ScriptPermissions implements ScriptWriter. The same 0755 mode used for
// unix scripts is reported for Windows scripts.
func (w *windowsRenderer) ScriptPermissions() os.FileMode {
	const scriptMode os.FileMode = 0755
	return scriptMode
}
// RenderScript implements ScriptRenderer (and hence ScriptWriter). The
// script content is simply the commands joined with newlines; no preamble
// is added.
func (w *windowsRenderer) RenderScript(commands []string) []byte {
	return []byte(strings.Join(commands, "\n"))
}
// Chown implements Renderer. It is not supported for Windows shells and
// always panics with "not supported".
func (windowsRenderer) Chown(path, owner, group string) []string {
	// TODO(ericsnow) Use ???
	panic("not supported")
}
// Touch implements Renderer. It is not supported for Windows shells and
// always panics with "not supported".
func (windowsRenderer) Touch(path string, timestamp *time.Time) []string {
	// TODO(ericsnow) Use ???
	panic("not supported")
}
// RedirectFD implements OutputRenderer. It is not supported for Windows
// shells and always panics with "not supported".
func (windowsRenderer) RedirectFD(dst, src string) []string {
	// TODO(ericsnow) Use ???
	panic("not supported")
}
// RedirectOutput implements OutputRenderer. It is not supported for
// Windows shells and always panics with "not supported".
func (windowsRenderer) RedirectOutput(filename string) []string {
	// TODO(ericsnow) Use ???
	panic("not supported")
}
// RedirectOutputReset implements OutputRenderer. It is not supported for
// Windows shells and always panics with "not supported".
func (windowsRenderer) RedirectOutputReset(filename string) []string {
	// TODO(ericsnow) Use ???
	panic("not supported")
}
================================================
FILE: shell/wincmd.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell
import (
"bytes"
"fmt"
"os"
"github.com/juju/utils/v4"
)
// WinCmdRenderer renders paths and commands for the Windows cmd.exe shell,
// building on the behaviour shared by all Windows shells.
type WinCmdRenderer struct {
	windowsRenderer
}
// Quote implements Renderer by delegating to utils.WinCmdQuote, which
// yields a cmd.exe-escaped double-quoted string (e.g. ^"abc^").
func (wcr *WinCmdRenderer) Quote(str string) string {
	quoted := utils.WinCmdQuote(str)
	return quoted
}
// Chmod implements Renderer. No permission commands are rendered for
// cmd.exe, so the result is always nil.
func (wcr *WinCmdRenderer) Chmod(path string, perm os.FileMode) []string {
	// TODO(ericsnow) Is this necessary? Should we use icacls?
	var none []string
	return none
}
// WriteFile implements Renderer. Each newline-separated line of data is
// appended to filename with its own "@echo" command, so a trailing newline
// in data yields a final command with an empty echo.
//
// NOTE(review): in cmd.exe an empty "@echo" prints the echo state rather
// than a blank line, and characters special to cmd in data are not
// escaped — confirm whether callers rely on this.
func (wcr *WinCmdRenderer) WriteFile(filename string, data []byte) []string {
	target := wcr.Quote(filename)
	lines := bytes.Split(data, []byte("\n"))
	commands := make([]string, 0, len(lines))
	for _, line := range lines {
		commands = append(commands, fmt.Sprintf(">>%s @echo %s", target, line))
	}
	return commands
}
// Mkdir implements Renderer. It returns a single cmd.exe "mkdir" command
// for dirname.
//
// Forward slashes are converted to backslashes before the name is quoted.
// Quoting must be the last transformation applied, otherwise separators
// produced by FromSlash would bypass the quoter's escaping, and "C:/a"
// would render differently from `C:\a`.
func (wcr *WinCmdRenderer) Mkdir(dirname string) []string {
	dirname = wcr.Quote(wcr.FromSlash(dirname))
	return []string{
		fmt.Sprintf(`mkdir %s`, dirname),
	}
}
// MkdirAll implements Renderer. It returns a single cmd.exe "mkdir"
// command for dirname. It presumably relies on cmd.exe command extensions
// to create missing parent directories (see the TODO below).
//
// Forward slashes are converted to backslashes before the name is quoted,
// so that quoting is the last transformation applied and "C:/a" renders
// the same as `C:\a`.
func (wcr *WinCmdRenderer) MkdirAll(dirname string) []string {
	dirname = wcr.Quote(wcr.FromSlash(dirname))
	// TODO(ericsnow) Wrap in "setlocal enableextensions...endlocal"?
	return []string{
		fmt.Sprintf(`mkdir %s`, dirname),
	}
}
// ScriptFilename implements ScriptWriter. cmd.exe scripts are named
// "<name>.bat" inside dirname.
func (wcr *WinCmdRenderer) ScriptFilename(name, dirname string) string {
	base := name + ".bat"
	return wcr.Join(dirname, base)
}
================================================
FILE: shell/wincmd_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package shell_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/shell"
)
// Register the cmd.exe renderer tests with gocheck.
var _ = gc.Suite(new(winCmdSuite))
// winCmdSuite exercises shell.WinCmdRenderer using a fixed Windows-style
// directory and file path.
type winCmdSuite struct {
	testing.IsolationSuite

	// dirname and filename are the paths handed to the renderer.
	dirname, filename string
	// renderer is recreated for every test by SetUpTest.
	renderer *shell.WinCmdRenderer
}
// SetUpTest prepares a fresh renderer and the Windows paths used by each test.
func (s *winCmdSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)
	dir := `C:\some\dir`
	s.dirname = dir
	s.filename = dir + `\file`
	s.renderer = &shell.WinCmdRenderer{}
}
// TestExeSuffix checks that cmd.exe executables carry the ".exe" suffix.
func (s winCmdSuite) TestExeSuffix(c *gc.C) {
	c.Check(s.renderer.ExeSuffix(), gc.Equals, ".exe")
}
// TestShQuote checks cmd.exe quoting of a simple word.
func (s winCmdSuite) TestShQuote(c *gc.C) {
	c.Check(s.renderer.Quote("abc"), gc.Equals, `^"abc^"`)
}
// TestChmod checks that no permission commands are rendered for cmd.exe.
func (s winCmdSuite) TestChmod(c *gc.C) {
	c.Check(s.renderer.Chmod(s.filename, 0644), gc.HasLen, 0)
}
// TestWriteFile checks that each line of content becomes one "@echo"
// append command, including the empty line after the trailing newline.
func (s winCmdSuite) TestWriteFile(c *gc.C) {
	got := s.renderer.WriteFile(s.filename, []byte("something\nhere\n"))
	want := []string{
		`>>^"C:\\some\\dir\\file^" @echo something`,
		`>>^"C:\\some\\dir\\file^" @echo here`,
		`>>^"C:\\some\\dir\\file^" @echo `,
	}
	c.Check(got, jc.DeepEquals, want)
}
// TestMkdir checks the command rendered for creating one directory.
func (s winCmdSuite) TestMkdir(c *gc.C) {
	want := []string{`mkdir ^"C:\\some\\dir^"`}
	c.Check(s.renderer.Mkdir(s.dirname), jc.DeepEquals, want)
}
// TestMkdirAll checks the command rendered for creating a directory tree.
func (s winCmdSuite) TestMkdirAll(c *gc.C) {
	want := []string{`mkdir ^"C:\\some\\dir^"`}
	c.Check(s.renderer.MkdirAll(s.dirname), jc.DeepEquals, want)
}
================================================
FILE: size.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"math"
"strconv"
"strings"
"unicode"
"github.com/juju/errors"
)
// ParseSize parses str as a size in mebibytes.
//
// str must be a non-negative decimal number, optionally followed by a
// multiplier suffix: M, G, T, P, E, Z or Y. Each suffix may also be
// written with a trailing "B" or "iB", so "G", "GB" and "GiB" are
// equivalent. Each suffix is 1024 times the previous one. If the suffix
// is not specified, "M" is implied. Fractional results are rounded up to
// the next whole mebibyte.
//
// An error is returned in these cases:
//   - the suffix is not recognised;
//   - the value is not a finite non-negative number (this includes "NaN"
//     and "Inf", which strconv.ParseFloat would otherwise accept);
//   - the size is too large to be represented as a uint64.
func ParseSize(str string) (MB uint64, err error) {
	input := str
	// Find the first character that is neither a digit nor a period;
	// everything from there on is the multiplier suffix.
	i := strings.IndexFunc(str, func(r rune) bool {
		return r != '.' && !unicode.IsDigit(r)
	})
	var multiplier float64 = 1
	if i > 0 {
		suffix := str[i:]
		multiplier = 0
	findSuffix:
		for j := 0; j < len(sizeSuffixes); j++ {
			base := string(sizeSuffixes[j])
			// M, MB, or MiB are all valid.
			switch suffix {
			case base, base + "B", base + "iB":
				multiplier = float64(sizeSuffixMultiplier(j))
				// A bare "break" would only leave the switch; stop the
				// search as soon as the suffix has been recognised.
				break findSuffix
			}
		}
		if multiplier == 0 {
			return 0, errors.Errorf("invalid multiplier suffix %q, expected one of %s", suffix, []byte(sizeSuffixes))
		}
		str = str[:i]
	}
	val, err := strconv.ParseFloat(str, 64)
	// A string that starts with a letter (i == 0) is handed to ParseFloat
	// whole, so "NaN" and "Inf" forms must be rejected explicitly.
	if err != nil || val < 0 || math.IsNaN(val) || math.IsInf(val, 0) {
		return 0, errors.Errorf("expected a non-negative number, got %q", str)
	}
	val = math.Ceil(val * multiplier)
	// Converting a float64 that is out of range for uint64 yields an
	// implementation-specific value, so reject such sizes explicitly.
	if val >= math.Ldexp(1, 64) {
		return 0, errors.Errorf("size %q is too large", input)
	}
	return uint64(val), nil
}
// sizeSuffixes lists the multiplier suffixes accepted by ParseSize in
// increasing order; the suffix at index i scales a value by 1024^i MiB
// (see sizeSuffixMultiplier).
var sizeSuffixes = "MGTPEZY"
// sizeSuffixMultiplier returns the number of mebibytes represented by the
// i'th suffix in sizeSuffixes, i.e. 1024^i (1 for "M", 1024 for "G", ...).
//
// The result is a uint64 rather than an int: the values for "E", "Z" and
// "Y" (2^40, 2^50 and 2^60) do not fit in a 32-bit int, and there the
// shift would silently produce 0.
func sizeSuffixMultiplier(i int) uint64 {
	return uint64(1) << uint(i*10)
}
// SizeTracker counts the bytes handed to its Write method, which otherwise
// discards them.
//
// Combine it with io.MultiWriter to count bytes written, or with
// io.TeeReader to count bytes read. The counter is updated without any
// synchronisation, so a SizeTracker must not be used concurrently.
type SizeTracker struct {
	// size is the running total of bytes seen by Write.
	size int64
}

// Size reports how many bytes have been written so far.
func (st SizeTracker) Size() int64 {
	total := st.size
	return total
}

// Write implements io.Writer. It adds len(data) to the running total and
// never fails.
func (st *SizeTracker) Write(data []byte) (n int, err error) {
	n = len(data)
	st.size += int64(n)
	return
}
================================================
FILE: size_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"io"
"io/ioutil"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/testing/filetesting"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// Register the size helper tests with gocheck.
var _ = gc.Suite(new(sizeSuite))
// sizeSuite tests ParseSize and SizeTracker.
type sizeSuite struct {
	testing.IsolationSuite
}
// TestParseSize is a table test covering invalid input, every suffix form
// and fractional values.
func (*sizeSuite) TestParseSize(c *gc.C) {
	type test struct {
		in  string
		out uint64
		err string
	}
	cases := []test{
		{in: "", err: `expected a non-negative number, got ""`},
		{in: "-1", err: `expected a non-negative number, got "-1"`},
		{in: "1MZ", err: `invalid multiplier suffix "MZ", expected one of MGTPEZY`},
		{in: "0", out: 0},
		{in: "123", out: 123},
		{in: "1M", out: 1},
		{in: "0.5G", out: 512},
		{in: "0.5GB", out: 512},
		{in: "0.5GiB", out: 512},
		{in: "0.5T", out: 524288},
		{in: "0.5P", out: 536870912},
		{in: "0.0009765625E", out: 1073741824},
		{in: "1Z", out: 1125899906842624},
		{in: "1Y", out: 1152921504606846976},
	}
	for i, tc := range cases {
		c.Logf("test %d: %+v", i, tc)
		size, err := utils.ParseSize(tc.in)
		if tc.err == "" {
			c.Assert(err, gc.IsNil)
			c.Assert(size, gc.Equals, tc.out)
			continue
		}
		c.Assert(err, gc.NotNil)
		c.Assert(err, gc.ErrorMatches, tc.err)
	}
}
// TestSizingReaderOkay checks byte counting through io.TeeReader when EOF
// arrives in a separate Read call.
func (*sizeSuite) TestSizingReaderOkay(c *gc.C) {
	const expected = "some data"
	stub := &testing.Stub{}
	var st utils.SizeTracker
	tee := io.TeeReader(filetesting.NewStubReader(stub, expected), &st)

	data, err := ioutil.ReadAll(tee)
	c.Assert(err, jc.ErrorIsNil)

	stub.CheckCallNames(c, "Read", "Read")
	c.Check(string(data), gc.Equals, expected)
	c.Check(st.Size(), gc.Equals, int64(len(expected)))
}
// TestSizingReaderMixedEOF checks byte counting when the final data and
// io.EOF are returned by the same Read call.
func (*sizeSuite) TestSizingReaderMixedEOF(c *gc.C) {
	const expected = "some data"
	stub := &testing.Stub{}
	source := &filetesting.StubReader{
		Stub:       stub,
		ReturnRead: &fakeStream{data: expected},
	}
	var st utils.SizeTracker
	tee := io.TeeReader(source, &st)

	data, err := ioutil.ReadAll(tee)
	c.Assert(err, jc.ErrorIsNil)

	// Only one Read happens because the EOF was mixed with the data.
	stub.CheckCallNames(c, "Read")
	c.Check(string(data), gc.Equals, expected)
	c.Check(st.Size(), gc.Equals, int64(len(expected)))
}
// TestSizingWriter checks byte counting through io.MultiWriter.
func (*sizeSuite) TestSizingWriter(c *gc.C) {
	const expected = "some data"
	stub := &testing.Stub{}
	writer, buffer := filetesting.NewStubWriter(stub)
	var st utils.SizeTracker
	multi := io.MultiWriter(writer, &st)

	n, err := multi.Write([]byte(expected))
	c.Assert(err, jc.ErrorIsNil)

	stub.CheckCallNames(c, "Write")
	c.Check(n, gc.Equals, len(expected))
	c.Check(buffer.String(), gc.Equals, expected)
	c.Check(st.Size(), gc.Equals, int64(len(expected)))
}
// fakeStream is an io.Reader over a fixed string. It reports io.EOF in the
// same call that returns the final bytes, rather than in a separate
// zero-length read.
type fakeStream struct {
	data string
	pos  uint64
}

// Read implements io.Reader.
func (f *fakeStream) Read(data []byte) (int, error) {
	n := copy(data, f.data[f.pos:])
	f.pos += uint64(n)
	var err error
	if f.pos >= uint64(len(f.data)) {
		err = io.EOF
	}
	return n, err
}
================================================
FILE: ssh/authorisedkeys.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"fmt"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"github.com/juju/errors"
"github.com/juju/loggo/v2"
"golang.org/x/crypto/ssh"
"github.com/juju/utils/v4"
)
// logger reports debug progress and warnings from the authorised keys
// functions in this file.
var logger = loggo.GetLogger("juju.utils.ssh")
// ListMode selects the output format of ListKeys and ListKeysFromFile.
type ListMode bool

var (
	// FullKeys lists each authorised key line in full.
	FullKeys ListMode = true
	// Fingerprints lists each key's fingerprint, followed by its comment
	// in parentheses when it has one.
	Fingerprints ListMode = false
)
const (
	// defaultAuthKeysFile is the file name, inside the user's ~/.ssh
	// directory, used by AddKeys, DeleteKeys, ReplaceKeys and ListKeys.
	defaultAuthKeysFile = "authorized_keys"
)
// AuthorisedKey holds the parts of one authorized_keys entry, as produced
// by ParseAuthorisedKey.
type AuthorisedKey struct {
	// Type is the key type reported by the parsed public key.
	Type string
	// Key is the public key in marshalled (wire) form.
	Key []byte
	// Comment is the free-text comment following the key, if any.
	Comment string
}
// authKeysDir returns the ".ssh" directory inside the (normalised) home
// directory of username, as reported by utils.UserHomeDir.
func authKeysDir(username string) (string, error) {
	home, err := utils.UserHomeDir(username)
	if err == nil {
		home, err = utils.NormalizePath(home)
	}
	if err != nil {
		return "", err
	}
	return filepath.Join(home, ".ssh"), nil
}
// ParseAuthorisedKey parses one non-comment line from an authorized_keys
// file (see "man sshd") and returns its type, marshalled key and comment.
//
// Input containing a newline is rejected with a NotValid error, since only
// a single line is expected. Lines that ssh.ParseAuthorizedKey cannot
// parse yield an "invalid authorized_key" error.
func ParseAuthorisedKey(line string) (*AuthorisedKey, error) {
	if strings.Contains(line, "\n") {
		return nil, errors.NotValidf("newline in authorized_key %q", line)
	}
	pub, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(line))
	if err != nil {
		return nil, errors.Errorf("invalid authorized_key %q", line)
	}
	ak := &AuthorisedKey{Comment: comment}
	ak.Type = pub.Type()
	ak.Key = pub.Marshal()
	return ak, nil
}
// ConcatAuthorisedKeys joins two authorised key lists into a single list
// that ssh programs can read. If either list is empty the other is
// returned unchanged. Otherwise b is appended to a, with a newline
// separator inserted only when a does not already end in one.
func ConcatAuthorisedKeys(a, b string) string {
	switch {
	case a == "":
		return b
	case b == "":
		return a
	case a[len(a)-1] == '\n':
		return a + b
	default:
		return a + "\n" + b
	}
}
// SplitAuthorisedKeys splits keyData into lines and returns the key lines.
// Each line is trimmed of spaces and carriage returns. Blank lines and
// lines starting with '#' are dropped. nil is returned when no keys remain.
func SplitAuthorisedKeys(keyData string) []string {
	var keys []string
	for _, line := range strings.Split(keyData, "\n") {
		line = strings.Trim(line, " \r")
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		keys = append(keys, line)
	}
	return keys
}
// readAuthorisedKeys returns the non-blank lines of filename inside
// username's ~/.ssh directory. Lines are returned verbatim, comments
// included. A missing file yields an empty, non-nil slice.
func readAuthorisedKeys(username, filename string) ([]string, error) {
	dir, err := authKeysDir(username)
	if err != nil {
		return nil, err
	}
	path := filepath.Join(dir, filename)
	logger.Debugf("reading authorised keys file %s", path)

	content, err := ioutil.ReadFile(path)
	switch {
	case os.IsNotExist(err):
		return []string{}, nil
	case err != nil:
		return nil, errors.Annotate(err, "reading ssh authorised keys file")
	}

	var lines []string
	for _, line := range strings.Split(string(content), "\n") {
		// Skip lines holding nothing but spaces and carriage returns.
		if strings.Trim(line, " \r") == "" {
			continue
		}
		lines = append(lines, line)
	}
	return lines, nil
}
// writeAuthorisedKeys atomically writes keys, one per line with a trailing
// newline, to filename inside username's ~/.ssh directory. It creates the
// directory (mode 0755) if needed.
//
// An existing file keeps its permission bits; a new file gets 0644. On
// non-Windows systems the file is then chowned to username, or to the
// current user when username is empty.
func writeAuthorisedKeys(username, filename string, keys []string) error {
	keyDir, err := authKeysDir(username)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(keyDir, os.FileMode(0755)); err != nil {
		return errors.Annotate(err, "cannot create ssh key directory")
	}
	content := []byte(strings.Join(keys, "\n") + "\n")

	sshKeyFile := filepath.Join(keyDir, filename)
	perms := os.FileMode(0644)
	if info, statErr := os.Stat(sshKeyFile); statErr == nil {
		perms = info.Mode().Perm()
	}
	logger.Debugf("writing authorised keys file %s", sshKeyFile)
	if err := utils.AtomicWriteFile(sshKeyFile, content, perms); err != nil {
		return err
	}

	// TODO (wallyworld) - what to do on windows (if anything)
	if runtime.GOOS == "windows" {
		return nil
	}

	// Make the file owned by the requested user.
	// TODO(dimitern) - no need to use user.Current() if username
	// is "" - it will use the current user anyway.
	var owner *user.User
	if username == "" {
		owner, err = user.Current()
	} else {
		owner, err = user.Lookup(username)
	}
	if err != nil {
		return err
	}
	// os.Chown takes numeric ids, but user.User stores them as strings
	// (because they are not numeric on Windows).
	uid, err := strconv.Atoi(owner.Uid)
	if err != nil {
		return err
	}
	gid, err := strconv.Atoi(owner.Gid)
	if err != nil {
		return err
	}
	return os.Chown(sshKeyFile, uid, gid)
}
// keysMutex serialises the exported authorised keys operations in this
// file. Updates read the file, modify the key list and write it back, so
// only one caller at a time may run any of the Add, Delete, Replace or
// List functions. Being an in-process sync.Mutex, it does not guard
// against other processes editing the same file.
var keysMutex sync.Mutex
// AddKeys appends newKeys to user's authorized_keys file.
// Nothing is written, and an error is returned, if *any* of the supplied
// keys is invalid or duplicates an existing key.
func AddKeys(user string, newKeys ...string) error {
	keysMutex.Lock()
	defer keysMutex.Unlock()

	current, err := readAuthorisedKeys(user, defaultAuthKeysFile)
	if err != nil {
		return err
	}
	return addKeys(user, defaultAuthKeysFile, newKeys, current)
}
// DeleteKeys removes the keys identified by keyIds from user's
// authorized_keys file. Each id may be a key comment or a fingerprint.
// An error is returned if *any* id matches no key, or if the deletion
// would leave the file empty.
func DeleteKeys(user string, keyIds ...string) error {
	keysMutex.Lock()
	defer keysMutex.Unlock()

	current, err := readAuthorisedKeys(user, defaultAuthKeysFile)
	if err != nil {
		return err
	}
	return deleteKeys(user, defaultAuthKeysFile, current, keyIds, false)
}
// ReplaceKeys replaces every key in user's authorized_keys file with
// newKeys. Lines that are not parseable keys (such as comments) are kept,
// ahead of the new keys.
func ReplaceKeys(user string, newKeys ...string) error {
	keysMutex.Lock()
	defer keysMutex.Unlock()

	lines, err := readAuthorisedKeys(user, defaultAuthKeysFile)
	if err != nil {
		return err
	}
	var kept []string
	for _, line := range lines {
		if _, _, fpErr := KeyFingerprint(line); fpErr != nil {
			kept = append(kept, line)
		}
	}
	return writeAuthorisedKeys(user, defaultAuthKeysFile, append(kept, newKeys...))
}
// ListKeys returns the keys in user's authorized_keys file, either in
// full or as fingerprints (with comments), according to mode.
func ListKeys(user string, mode ListMode) ([]string, error) {
	keysMutex.Lock()
	defer keysMutex.Unlock()

	lines, err := readAuthorisedKeys(user, defaultAuthKeysFile)
	if err != nil {
		return nil, err
	}
	return listKeys(lines, mode)
}
// JujuCommentPrefix starts the comment of every ssh key that Juju adds to
// the authorised keys list (see EnsureJujuComment). This allows Juju to
// know which keys have been added externally; such keys will always be
// retained by Juju when updating the authorised keys file.
const JujuCommentPrefix = "Juju:"
// EnsureJujuComment returns key with its comment marked as Juju-managed.
//
// The cases are handled as follows:
//   - A key with no comment gets the comment JujuCommentPrefix+"sshkey".
//   - A key whose comment lacks JujuCommentPrefix has the prefix inserted
//     before the last occurrence of its comment; anything after that
//     comment text is dropped.
//   - Keys that already carry the prefix are returned unchanged.
//   - Keys that fail to parse are also returned unchanged, after a
//     warning is logged.
func EnsureJujuComment(key string) string {
	ak, err := ParseAuthorisedKey(key)
	if err != nil {
		// Just return an invalid key as is.
		logger.Warningf("invalid Juju ssh key %s: %v", key, err)
		return key
	}
	if ak.Comment == "" {
		return key + " " + JujuCommentPrefix + "sshkey"
	}
	if strings.HasPrefix(ak.Comment, JujuCommentPrefix) {
		return key
	}
	commentIndex := strings.LastIndex(key, ak.Comment)
	return key[:commentIndex] + JujuCommentPrefix + ak.Comment
}
// AddKeysToFile appends newKeys to the named keys file in user's ~/.ssh
// directory. Nothing is written, and an error is returned, if *any* of the
// supplied keys is invalid or duplicates an existing key.
func AddKeysToFile(user, file string, newKeys []string) error {
	keysMutex.Lock()
	defer keysMutex.Unlock()

	current, err := readAuthorisedKeys(user, file)
	if err != nil {
		return err
	}
	return addKeys(user, file, newKeys, current)
}
// DeleteKeysFromFile removes the keys identified by keyIds from the named
// keys file in user's ~/.ssh directory. Each id may be a key comment or a
// fingerprint, and an error is returned if *any* id matches no key.
//
// Unlike DeleteKeys, this version can delete ALL keys from the target file.
func DeleteKeysFromFile(user, file string, keyIds []string) error {
	keysMutex.Lock()
	defer keysMutex.Unlock()

	current, err := readAuthorisedKeys(user, file)
	if err != nil {
		return err
	}
	return deleteKeys(user, file, current, keyIds, true)
}
// ListKeysFromFile returns the keys in the named keys file in user's
// ~/.ssh directory, in the form selected by mode:
//   - FullKeys: each key line in full.
//   - Fingerprints: each key's fingerprint followed by its comment in
//     parentheses (when it has one).
//
// Lines that are not valid keys are skipped.
func ListKeysFromFile(user, file string, mode ListMode) ([]string, error) {
	keysMutex.Lock()
	defer keysMutex.Unlock()
	keyData, err := readAuthorisedKeys(user, file)
	if err != nil {
		return nil, err
	}
	return listKeys(keyData, mode)
}
// addKeys appends newKeys to existingKeys and writes the result to file
// for user.
//
// Every new key must parse, must carry a comment, and must not share a
// fingerprint or a comment with an existing key or with another key in
// newKeys. Otherwise an error is returned and nothing is written.
// Existing lines that are not parseable keys are kept as they are; a
// warning is logged for each one that is not a '#' comment.
func addKeys(user, file string, newKeys, existingKeys []string) error {
	// Index the existing keys once rather than once per new key.
	seenFingerprints := make(map[string]bool)
	seenComments := make(map[string]bool)
	for _, key := range existingKeys {
		fingerprint, comment, err := KeyFingerprint(key)
		if err != nil {
			// Only log a warning if the unrecognised key line is not a comment.
			if !strings.HasPrefix(strings.TrimSpace(key), "#") {
				logger.Warningf("invalid existing ssh key %q: %v", key, err)
			}
			continue
		}
		seenFingerprints[fingerprint] = true
		seenComments[comment] = true
	}

	for _, newKey := range newKeys {
		fingerprint, comment, err := KeyFingerprint(newKey)
		if err != nil {
			return err
		}
		if comment == "" {
			return errors.Errorf("cannot add ssh key without comment")
		}
		if seenFingerprints[fingerprint] {
			return errors.Errorf("cannot add duplicate ssh key: %v", fingerprint)
		}
		if seenComments[comment] {
			return errors.Errorf("cannot add ssh key with duplicate comment: %v", comment)
		}
		// Record this key as well, so duplicates within newKeys are caught.
		seenFingerprints[fingerprint] = true
		seenComments[comment] = true
	}

	// Build a fresh slice so the caller's existingKeys is never modified
	// through a shared backing array.
	sshKeys := make([]string, 0, len(existingKeys)+len(newKeys))
	sshKeys = append(sshKeys, existingKeys...)
	sshKeys = append(sshKeys, newKeys...)
	return writeAuthorisedKeys(user, file, sshKeys)
}
// deleteKeys removes the keys identified by keyIdsToDelete from
// existingKeys and writes the remaining lines to file for user.
//
// Each key id may be either a key fingerprint or a key comment. Lines
// that are not recognised as keys (for example comments) are always kept.
// The surviving lines are written in their original order, so repeated
// deletions do not shuffle the file.
//
// Errors are returned in two cases:
//   - any key id matches no key;
//   - nothing would remain, unless deleteAll is true.
func deleteKeys(user, file string, existingKeys, keyIdsToDelete []string, deleteAll bool) error {
	// Fingerprint each existing line once, remembering the result by
	// position so the original order can be reproduced when writing.
	fingerprints := make([]string, len(existingKeys))
	recognised := make([]bool, len(existingKeys))
	// remaining holds the fingerprints of keys not (yet) deleted;
	// keyComments maps each key comment to its key's fingerprint.
	remaining := make(map[string]bool)
	keyComments := make(map[string]string)
	for i, key := range existingKeys {
		fingerprint, comment, err := KeyFingerprint(key)
		if err != nil {
			logger.Debugf("keeping unrecognised existing ssh key %q: %v", key, err)
			continue
		}
		fingerprints[i] = fingerprint
		recognised[i] = true
		remaining[fingerprint] = true
		if comment != "" {
			keyComments[comment] = fingerprint
		}
	}

	for _, keyId := range keyIdsToDelete {
		// Assume keyId is a fingerprint; otherwise treat it as a comment.
		fingerprint := keyId
		ok := remaining[keyId]
		if !ok {
			fingerprint, ok = keyComments[keyId]
		}
		if !ok {
			return errors.Errorf("cannot delete non existent key: %v", keyId)
		}
		delete(remaining, fingerprint)
	}

	var keysToWrite []string
	for i, key := range existingKeys {
		if !recognised[i] || remaining[fingerprints[i]] {
			keysToWrite = append(keysToWrite, key)
		}
	}
	if len(keysToWrite) == 0 && !deleteAll {
		return errors.Errorf("cannot delete all keys")
	}
	return writeAuthorisedKeys(user, file, keysToWrite)
}
// listKeys formats the recognised keys in existingKeys according to mode.
// With FullKeys each key line is returned unchanged; otherwise each entry is
// the key's fingerprint followed, when present, by its comment in parentheses.
// Lines that cannot be parsed as keys are skipped, with a warning unless they
// are '#' comment lines. The error result is currently always nil.
func listKeys(existingKeys []string, mode ListMode) ([]string, error) {
	var keys []string
	for _, key := range existingKeys {
		fingerprint, comment, err := KeyFingerprint(key)
		if err != nil {
			// Only log a warning if the unrecognised key line is not a comment.
			// The emptiness check avoids an index panic on blank lines.
			if key == "" || key[0] != '#' {
				logger.Warningf("ignoring invalid ssh key %q: %v", key, err)
			}
			continue
		}
		if mode == FullKeys {
			keys = append(keys, key)
			continue
		}
		shortKey := fingerprint
		if comment != "" {
			shortKey += fmt.Sprintf(" (%s)", comment)
		}
		keys = append(keys, shortKey)
	}
	return keys, nil
}
================================================
FILE: ssh/authorisedkeys_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"encoding/base64"
"strings"
gitjujutesting "github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
sshtesting "github.com/juju/utils/v4/ssh/testing"
)
// AuthorisedKeysKeysSuite groups the authorised_keys tests. It embeds
// FakeHomeSuite, presumably so that key files are written under a throwaway
// home directory rather than the real one — TODO confirm.
type AuthorisedKeysKeysSuite struct {
	gitjujutesting.FakeHomeSuite
}
const (
	// testSSHUser is empty so that the ssh tests act on the current user.
	testSSHUser = ""

	// Names of the authorized keys files exercised by the tests.
	authKeysFile         = "authorized_keys"
	alternativeKeysFile2 = authKeysFile + "2"
	alternativeKeysFile3 = authKeysFile + "3"
)

var _ = gc.Suite(&AuthorisedKeysKeysSuite{})
// writeAuthKeysFile writes keys to the named authorized keys file of the
// test user, failing the test on any error.
func writeAuthKeysFile(c *gc.C, keys []string, file string) {
	c.Assert(ssh.WriteAuthorisedKeys(testSSHUser, file, keys), jc.ErrorIsNil)
}
func (s *AuthorisedKeysKeysSuite) TestListKeys(c *gc.C) {
	// Only the key that carries a comment gets a parenthesised suffix.
	withComment := sshtesting.ValidKeyOne.Key + " user@host"
	withoutComment := sshtesting.ValidKeyTwo.Key
	writeAuthKeysFile(c, []string{withComment, withoutComment}, authKeysFile)

	listed, err := ssh.ListKeys(testSSHUser, ssh.Fingerprints)
	c.Assert(err, jc.ErrorIsNil)
	expected := []string{
		sshtesting.ValidKeyOne.Fingerprint + " (user@host)",
		sshtesting.ValidKeyTwo.Fingerprint,
	}
	c.Assert(listed, gc.DeepEquals, expected)
}
func (s *AuthorisedKeysKeysSuite) TestListKeysFull(c *gc.C) {
	// With FullKeys the listing reproduces each line as written.
	written := []string{
		sshtesting.ValidKeyOne.Key + " user@host",
		sshtesting.ValidKeyTwo.Key + " anotheruser@host",
	}
	writeAuthKeysFile(c, written, authKeysFile)

	listed, err := ssh.ListKeys(testSSHUser, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(listed, gc.DeepEquals, written)
}
func (s *AuthorisedKeysKeysSuite) TestAddNewKey(c *gc.C) {
	// Adding to an empty keys file leaves exactly that key behind.
	newKey := sshtesting.ValidKeyOne.Key + " user@host"
	c.Assert(ssh.AddKeys(testSSHUser, newKey), jc.ErrorIsNil)

	listed, err := ssh.ListKeys(testSSHUser, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(listed, gc.DeepEquals, []string{newKey})
}
func (s *AuthorisedKeysKeysSuite) TestAddMoreKeys(c *gc.C) {
	// New keys are appended after the ones already on disk.
	initial := sshtesting.ValidKeyOne.Key + " user@host"
	writeAuthKeysFile(c, []string{initial}, authKeysFile)
	added := []string{
		sshtesting.ValidKeyTwo.Key + " anotheruser@host",
		sshtesting.ValidKeyThree.Key + " yetanotheruser@host",
	}
	c.Assert(ssh.AddKeys(testSSHUser, added...), jc.ErrorIsNil)

	listed, err := ssh.ListKeys(testSSHUser, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	expected := append([]string{initial}, added...)
	c.Assert(listed, gc.DeepEquals, expected)
}
func (s *AuthorisedKeysKeysSuite) TestAddDuplicateKey(c *gc.C) {
	// A batch containing an already-present key is rejected, naming the
	// offending fingerprint.
	c.Assert(ssh.AddKeys(testSSHUser, sshtesting.ValidKeyOne.Key+" user@host"), jc.ErrorIsNil)

	err := ssh.AddKeys(testSSHUser,
		sshtesting.ValidKeyOne.Key+" user@host",
		sshtesting.ValidKeyTwo.Key+" yetanotheruser@host",
	)
	c.Assert(err, gc.ErrorMatches, "cannot add duplicate ssh key: "+sshtesting.ValidKeyOne.Fingerprint)
}
func (s *AuthorisedKeysKeysSuite) TestAddDuplicateComment(c *gc.C) {
	// A different key reusing an existing comment is rejected.
	c.Assert(ssh.AddKeys(testSSHUser, sshtesting.ValidKeyOne.Key+" user@host"), jc.ErrorIsNil)

	err := ssh.AddKeys(testSSHUser,
		sshtesting.ValidKeyTwo.Key+" user@host",
		sshtesting.ValidKeyThree.Key+" yetanotheruser@host",
	)
	c.Assert(err, gc.ErrorMatches, "cannot add ssh key with duplicate comment: user@host")
}
func (s *AuthorisedKeysKeysSuite) TestAddKeyWithoutComment(c *gc.C) {
	// Every key in the batch needs a comment, even if the others have one.
	err := ssh.AddKeys(testSSHUser,
		sshtesting.ValidKeyOne.Key+" user@host",
		sshtesting.ValidKeyTwo.Key,
	)
	c.Assert(err, gc.ErrorMatches, "cannot add ssh key without comment")
}
func (s *AuthorisedKeysKeysSuite) TestAddKeepsUnrecognised(c *gc.C) {
	// Lines that are not valid keys survive an add, in their original place.
	writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key, "invalid-key"}, authKeysFile)
	added := sshtesting.ValidKeyTwo.Key + " anotheruser@host"
	c.Assert(ssh.AddKeys(testSSHUser, added), jc.ErrorIsNil)

	onDisk, err := ssh.ReadAuthorisedKeys(testSSHUser, authKeysFile)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(onDisk, gc.DeepEquals, []string{sshtesting.ValidKeyOne.Key, "invalid-key", added})
}
func (s *AuthorisedKeysKeysSuite) TestDeleteKeys(c *gc.C) {
	// Keys can be deleted by comment and by fingerprint in a single call.
	byComment := sshtesting.ValidKeyOne.Key + " user@host"
	byFingerprint := sshtesting.ValidKeyTwo.Key
	survivor := sshtesting.ValidKeyThree.Key + " anotheruser@host"
	writeAuthKeysFile(c, []string{byComment, byFingerprint, survivor}, authKeysFile)

	err := ssh.DeleteKeys(testSSHUser, "user@host", sshtesting.ValidKeyTwo.Fingerprint)
	c.Assert(err, jc.ErrorIsNil)

	listed, err := ssh.ListKeys(testSSHUser, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(listed, gc.DeepEquals, []string{survivor})
}
func (s *AuthorisedKeysKeysSuite) TestDeleteKeysKeepsUnrecognised(c *gc.C) {
	// Unparseable lines are preserved on delete, ahead of the remaining keys.
	target := sshtesting.ValidKeyOne.Key + " user@host"
	writeAuthKeysFile(c, []string{target, sshtesting.ValidKeyTwo.Key, "invalid-key"}, authKeysFile)
	c.Assert(ssh.DeleteKeys(testSSHUser, "user@host"), jc.ErrorIsNil)

	onDisk, err := ssh.ReadAuthorisedKeys(testSSHUser, authKeysFile)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(onDisk, gc.DeepEquals, []string{"invalid-key", sshtesting.ValidKeyTwo.Key})
}
func (s *AuthorisedKeysKeysSuite) TestDeleteNonExistentComment(c *gc.C) {
	writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key + " user@host"}, authKeysFile)
	// No key carries this comment, so the delete must fail.
	err := ssh.DeleteKeys(testSSHUser, "someone@host")
	c.Assert(err, gc.ErrorMatches, "cannot delete non existent key: someone@host")
}
func (s *AuthorisedKeysKeysSuite) TestDeleteNonExistentFingerprint(c *gc.C) {
	writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key + " user@host"}, authKeysFile)
	// ValidKeyTwo was never written, so its fingerprint is unknown.
	missing := sshtesting.ValidKeyTwo.Fingerprint
	err := ssh.DeleteKeys(testSSHUser, missing)
	c.Assert(err, gc.ErrorMatches, "cannot delete non existent key: "+missing)
}
func (s *AuthorisedKeysKeysSuite) TestDeleteLastKeyForbidden(c *gc.C) {
	writeAuthKeysFile(c, []string{
		sshtesting.ValidKeyOne.Key + " user@host",
		sshtesting.ValidKeyTwo.Key + " yetanotheruser@host",
	}, authKeysFile)
	// DeleteKeys refuses to remove every key from the file.
	err := ssh.DeleteKeys(testSSHUser, "user@host", sshtesting.ValidKeyTwo.Fingerprint)
	c.Assert(err, gc.ErrorMatches, "cannot delete all keys")
}
func (s *AuthorisedKeysKeysSuite) TestReplaceKeys(c *gc.C) {
	writeAuthKeysFile(c, []string{
		sshtesting.ValidKeyOne.Key + " user@host",
		sshtesting.ValidKeyTwo.Key,
	}, authKeysFile)
	// The replacement deliberately has no comment. Existing keys may lack
	// one, and ReplaceKeys is used to rewrite the entire authorized_keys
	// file when adding new keys, so it must cope with comment-less keys.
	replacement := sshtesting.ValidKeyThree.Key
	c.Assert(ssh.ReplaceKeys(testSSHUser, replacement), jc.ErrorIsNil)

	listed, err := ssh.ListKeys(testSSHUser, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(listed, gc.DeepEquals, []string{replacement})
}
func (s *AuthorisedKeysKeysSuite) TestReplaceKeepsUnrecognised(c *gc.C) {
	writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key, "invalid-key"}, authKeysFile)
	replacement := sshtesting.ValidKeyTwo.Key + " anotheruser@host"
	c.Assert(ssh.ReplaceKeys(testSSHUser, replacement), jc.ErrorIsNil)

	// The valid key is replaced but the unparseable line is retained.
	onDisk, err := ssh.ReadAuthorisedKeys(testSSHUser, authKeysFile)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(onDisk, gc.DeepEquals, []string{"invalid-key", replacement})
}
func (s *AuthorisedKeysKeysSuite) TestEnsureJujuComment(c *gc.C) {
	sshKey := sshtesting.ValidKeyOne.Key
	cases := []struct {
		key, expected string
	}{
		// Unparseable input comes back unchanged.
		{"invalid-key", "invalid-key"},
		// A missing comment becomes "Juju:sshkey".
		{sshKey, sshKey + " Juju:sshkey"},
		// An existing comment gains the "Juju:" prefix exactly once.
		{sshKey + " user@host", sshKey + " Juju:user@host"},
		{sshKey + " Juju:user@host", sshKey + " Juju:user@host"},
		{sshKey + " " + sshKey[3:5], sshKey + " Juju:" + sshKey[3:5]},
	}
	for _, tc := range cases {
		c.Assert(ssh.EnsureJujuComment(tc.key), gc.Equals, tc.expected)
	}
}
func (s *AuthorisedKeysKeysSuite) TestSplitAuthorisedKeys(c *gc.C) {
	sshKey := sshtesting.ValidKeyOne.Key
	cases := []struct {
		keyData  string
		expected []string
	}{
		// Empty input yields no lines at all.
		{"", nil},
		// Blank lines and '#' comment lines (even indented ones) are dropped...
		{sshKey, []string{sshKey}},
		{sshKey + "\n", []string{sshKey}},
		{sshKey + "\n\n", []string{sshKey}},
		{sshKey + "\n#comment\n", []string{sshKey}},
		{sshKey + "\n #comment\n", []string{sshKey}},
		// ...but other unparseable lines are kept.
		{sshKey + "\ninvalid\n", []string{sshKey, "invalid"}},
	}
	for _, tc := range cases {
		c.Assert(ssh.SplitAuthorisedKeys(tc.keyData), gc.DeepEquals, tc.expected)
	}
}
// b64decode decodes the standard base64 string s, failing the test if it
// is malformed.
func b64decode(c *gc.C, s string) []byte {
	decoded, err := base64.StdEncoding.DecodeString(s)
	c.Assert(err, jc.ErrorIsNil)
	return decoded
}
func (s *AuthorisedKeysKeysSuite) TestParseAuthorisedKey(c *gc.C) {
	// The decoded key blob of ValidKeyOne, shared by the success cases.
	validKey := b64decode(c, strings.Fields(sshtesting.ValidKeyOne.Key)[1])
	cases := []struct {
		line    string
		key     []byte
		comment string
		err     string
	}{{
		line: sshtesting.ValidKeyOne.Key,
		key:  validKey,
	}, {
		line:    sshtesting.ValidKeyOne.Key + " a b c",
		key:     validKey,
		comment: "a b c",
	}, {
		line: "ssh-xsa blah",
		err:  "invalid authorized_key \"ssh-xsa blah\"",
	}, {
		// options should be skipped
		line: `no-pty,principals="\"",command="\!" ` + sshtesting.ValidKeyOne.Key,
		key:  validKey,
	}, {
		line: "ssh-rsa",
		err:  "invalid authorized_key \"ssh-rsa\"",
	}, {
		line: sshtesting.ValidKeyOne.Key + " line1\nline2",
		err:  "newline in authorized_key \".*",
	}}
	for i, tc := range cases {
		c.Logf("test %d: %s", i, tc.line)
		ak, err := ssh.ParseAuthorisedKey(tc.line)
		if tc.err != "" {
			c.Assert(err, gc.ErrorMatches, tc.err)
			continue
		}
		c.Assert(err, jc.ErrorIsNil)
		c.Assert(ak, gc.Not(gc.IsNil))
		c.Assert(ak.Key, gc.DeepEquals, tc.key)
		c.Assert(ak.Comment, gc.Equals, tc.comment)
	}
}
func (s *AuthorisedKeysKeysSuite) TestConcatAuthorisedKeys(c *gc.C) {
	// The two sides are joined by a single newline; an empty side is skipped.
	cases := []struct{ a, b, result string }{
		{"a", "", "a"},
		{"", "b", "b"},
		{"a", "b", "a\nb"},
		{"a\n", "b", "a\nb"},
	}
	for _, tc := range cases {
		c.Check(ssh.ConcatAuthorisedKeys(tc.a, tc.b), gc.Equals, tc.result)
	}
}
func (s *AuthorisedKeysKeysSuite) TestAddKeysToFileToDifferentFiles(c *gc.C) {
	// Keys added to one file must not show up in another.
	addAndList := func(file, key string) []string {
		c.Assert(ssh.AddKeysToFile(testSSHUser, file, []string{key}), jc.ErrorIsNil)
		listed, err := ssh.ListKeysFromFile(testSSHUser, file, ssh.FullKeys)
		c.Assert(err, jc.ErrorIsNil)
		return listed
	}
	key1 := sshtesting.ValidKeyOne.Key + " user@host"
	c.Assert(addAndList(alternativeKeysFile2, key1), gc.DeepEquals, []string{key1})
	key2 := sshtesting.ValidKeyTwo.Key + " user@host"
	c.Assert(addAndList(alternativeKeysFile3, key2), gc.DeepEquals, []string{key2})
}
func (s *AuthorisedKeysKeysSuite) TestAddKeysToFileMultipleKeys(c *gc.C) {
	// Several keys can be added to an alternative file in one call.
	key1 := sshtesting.ValidKeyOne.Key + " user@host"
	key2 := sshtesting.ValidKeyTwo.Key + " alice@host"
	c.Assert(ssh.AddKeysToFile(testSSHUser, alternativeKeysFile2, []string{key1, key2}), jc.ErrorIsNil)

	listed, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(listed, jc.DeepEquals, []string{key1, key2})
}
func (s *AuthorisedKeysKeysSuite) TestDeleteAllKeysFromFile(c *gc.C) {
	writeAuthKeysFile(c, []string{sshtesting.ValidKeyOne.Key + " user@host"}, alternativeKeysFile2)
	// Unlike DeleteKeys, DeleteKeysFromFile may remove the last key.
	err := ssh.DeleteKeysFromFile(testSSHUser, alternativeKeysFile2, []string{sshtesting.ValidKeyOne.Fingerprint})
	c.Assert(err, jc.ErrorIsNil)

	remaining, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(remaining, gc.HasLen, 0)
}
func (s *AuthorisedKeysKeysSuite) TestDeleteSomeKeysFromFile(c *gc.C) {
	keep1 := sshtesting.ValidKeyOne.Key + " user@host"
	drop := sshtesting.ValidKeyTwo.Key + " alice@host"
	keep2 := sshtesting.ValidKeyThree.Key + " bob@host"
	writeAuthKeysFile(c, []string{keep1, drop, keep2}, alternativeKeysFile2)

	err := ssh.DeleteKeysFromFile(testSSHUser, alternativeKeysFile2, []string{sshtesting.ValidKeyTwo.Fingerprint})
	c.Assert(err, jc.ErrorIsNil)

	// Only the targeted key disappears; ordering is not asserted.
	remaining, err := ssh.ListKeysFromFile(testSSHUser, alternativeKeysFile2, ssh.FullKeys)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(remaining, gc.HasLen, 2)
	c.Assert(remaining, jc.SameContents, []string{keep1, keep2})
}
================================================
FILE: ssh/clientkeys.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
"github.com/juju/collections/set"
"golang.org/x/crypto/ssh"
"github.com/juju/utils/v4"
)
const (
	// clientKeyName is the base filename of the private key that
	// LoadClientKeys generates when no usable key pair exists.
	clientKeyName = "juju_id_ed25519"

	// PublicKeySuffix is the file extension for public key files.
	PublicKeySuffix = ".pub"
)

// clientKeysMutex guards clientKeys.
var clientKeysMutex sync.Mutex

// clientKeys is a cached map of private key filenames
// to ssh.Signers. The private keys are those loaded
// from the client key directory, passed to LoadClientKeys.
var clientKeys map[string]ssh.Signer
// LoadClientKeys loads the client SSH keys from the
// specified directory, and caches them as a process-wide
// global. If the directory does not exist, it is created;
// if the directory did not exist, or contains no keys, it
// is populated with a new key pair.
//
// If the directory exists, then all pairs of files where one
// has the same name as the other + ".pub" will be loaded as
// private/public key pairs.
//
// Calls to LoadClientKeys will clear the previously loaded
// keys, and recompute the keys.
//
// If generating a new key pair fails, the directory is removed only when
// it did not exist before the call; a pre-existing directory is left alone.
func LoadClientKeys(dir string) error {
	clientKeysMutex.Lock()
	defer clientKeysMutex.Unlock()
	dir, err := utils.NormalizePath(dir)
	if err != nil {
		return err
	}
	_, statErr := os.Stat(dir)
	if statErr == nil {
		keys, err := loadClientKeys(dir)
		if err != nil {
			return err
		}
		if len(keys) > 0 {
			clientKeys = keys
			return nil
		}
		// Directory exists but contains no keys;
		// fall through and create one.
	}
	if err := os.MkdirAll(dir, 0700); err != nil {
		return err
	}
	keyfile, key, err := generateClientKey(dir)
	if err != nil {
		// Only clean up a directory we know we created: an existing
		// directory may hold files that do not belong to us.
		if os.IsNotExist(statErr) {
			os.RemoveAll(dir)
		}
		return err
	}
	clientKeys = map[string]ssh.Signer{keyfile: key}
	return nil
}
// ClearClientKeys forgets any client keys cached by LoadClientKeys.
func ClearClientKeys() {
	clientKeysMutex.Lock()
	clientKeys = nil
	clientKeysMutex.Unlock()
}
// generateClientKey creates a new key pair in dir. It writes the private
// key to clientKeyName and the public key alongside it with PublicKeySuffix,
// and returns the private key path and its signer. If writing the public
// key fails, the freshly written private key is removed again.
func generateClientKey(dir string) (keyfile string, key ssh.Signer, err error) {
	privateText, publicText, err := GenerateKey("juju-client-key")
	if err != nil {
		return "", nil, err
	}
	signer, err := ssh.ParsePrivateKey([]byte(privateText))
	if err != nil {
		return "", nil, err
	}
	keyfile = filepath.Join(dir, clientKeyName)
	if err = ioutil.WriteFile(keyfile, []byte(privateText), 0600); err != nil {
		return "", nil, err
	}
	if err = ioutil.WriteFile(keyfile+PublicKeySuffix, []byte(publicText), 0600); err != nil {
		// Don't leave a private key behind without its public half.
		os.Remove(keyfile)
		return "", nil, err
	}
	return keyfile, signer, nil
}
// loadClientKeys parses every private key in dir that has a matching public
// key file, returning a map from private key path to signer. Read errors are
// returned as is; parse errors are wrapped (with %w) with the file name.
func loadClientKeys(dir string) (map[string]ssh.Signer, error) {
	// Named pubFiles so the publicKeyFiles function is not shadowed.
	pubFiles, err := publicKeyFiles(dir)
	if err != nil {
		return nil, err
	}
	keys := make(map[string]ssh.Signer, len(pubFiles))
	for _, pubFile := range pubFiles {
		// publicKeyFiles only returns names ending in PublicKeySuffix whose
		// private counterpart exists, so trimming the suffix yields that file.
		privFile := strings.TrimSuffix(pubFile, PublicKeySuffix)
		data, err := ioutil.ReadFile(privFile)
		if err != nil {
			return nil, err
		}
		signer, err := ssh.ParsePrivateKey(data)
		if err != nil {
			return nil, fmt.Errorf("parsing key file %q: %w", privFile, err)
		}
		keys[privFile] = signer
	}
	return keys, nil
}
// privateKeys returns the signers cached by LoadClientKeys, in no
// particular order. It returns nil when nothing is cached.
func privateKeys() (signers []ssh.Signer) {
	clientKeysMutex.Lock()
	defer clientKeysMutex.Unlock()
	for _, signer := range clientKeys {
		signers = append(signers, signer)
	}
	return signers
}
// PrivateKeyFiles returns the filenames of private SSH keys loaded by
// LoadClientKeys. The result is never nil and its order is unspecified.
func PrivateKeyFiles() []string {
	clientKeysMutex.Lock()
	defer clientKeysMutex.Unlock()
	names := make([]string, len(clientKeys))
	i := 0
	for name := range clientKeys {
		names[i] = name
		i++
	}
	return names
}
// PublicKeyFiles returns the filenames of public SSH keys loaded by
// LoadClientKeys: each private key filename with PublicKeySuffix appended.
func PublicKeyFiles() []string {
	// PrivateKeyFiles hands back a fresh slice, so it can be edited in place.
	files := PrivateKeyFiles()
	for i := range files {
		files[i] += PublicKeySuffix
	}
	return files
}
// publicKeyFiles returns the paths of the public SSH keys in clientKeysDir:
// every file ending with PublicKeySuffix whose private counterpart (the same
// name without the suffix) is also present. An empty directory name yields
// no keys and no error.
func publicKeyFiles(clientKeysDir string) ([]string, error) {
	if clientKeysDir == "" {
		return nil, nil
	}
	d, err := os.Open(clientKeysDir)
	if err != nil {
		return nil, err
	}
	names, err := d.Readdirnames(-1)
	d.Close()
	if err != nil {
		return nil, err
	}
	present := set.NewStrings(names...)
	var pubkeys []string
	for _, name := range names {
		if strings.HasSuffix(name, PublicKeySuffix) && present.Contains(strings.TrimSuffix(name, PublicKeySuffix)) {
			pubkeys = append(pubkeys, filepath.Join(d.Name(), name))
		}
	}
	return pubkeys, nil
}
================================================
FILE: ssh/clientkeys_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"io/ioutil"
"os"
gitjujutesting "github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
"github.com/juju/utils/v4/ssh"
)
// ClientKeysSuite tests loading and caching of client ssh keys. It embeds
// FakeHomeSuite so the "~/.juju/ssh" paths used below resolve inside the
// suite's home directory (presumably a temporary one — TODO confirm).
type ClientKeysSuite struct {
	gitjujutesting.FakeHomeSuite
}

var _ = gc.Suite(&ClientKeysSuite{})
func (s *ClientKeysSuite) SetUpTest(c *gc.C) {
	s.FakeHomeSuite.SetUpTest(c)
	// Forget any cached client keys once the test finishes.
	s.AddCleanup(func(*gc.C) {
		ssh.ClearClientKeys()
	})
	// Reuse a single generated key across tests, undoing the patch afterwards.
	restorer := overrideGenerateKey()
	s.AddCleanup(func(*gc.C) {
		restorer.Restore()
	})
}
// checkFiles asserts that obtained holds the same paths as expected, in any
// order, after normalising each expected path with utils.NormalizePath (so
// entries may be written relative to "~"). Unlike before, the caller's
// expected slice is left unmodified.
func checkFiles(c *gc.C, obtained, expected []string) {
	normalised := make([]string, len(expected))
	for i, path := range expected {
		p, err := utils.NormalizePath(path)
		c.Assert(err, jc.ErrorIsNil)
		normalised[i] = p
	}
	c.Assert(obtained, jc.SameContents, normalised)
}
// checkPublicKeyFiles asserts that the currently loaded public key files are
// exactly expected (in any order).
func checkPublicKeyFiles(c *gc.C, expected ...string) {
	checkFiles(c, ssh.PublicKeyFiles(), expected)
}
// checkPrivateKeyFiles asserts that the currently loaded private key files
// are exactly expected (in any order).
func checkPrivateKeyFiles(c *gc.C, expected ...string) {
	checkFiles(c, ssh.PrivateKeyFiles(), expected)
}
func (s *ClientKeysSuite) TestPublicKeyFiles(c *gc.C) {
	const dir = "~/.juju/ssh"
	writeFile := func(name, content string) {
		err := ioutil.WriteFile(gitjujutesting.HomePath(".juju", "ssh", name), []byte(content), 0600)
		c.Assert(err, jc.ErrorIsNil)
	}

	// LoadClientKeys creates the directory and populates it with a key pair.
	c.Assert(ssh.LoadClientKeys(dir), jc.ErrorIsNil)
	checkPublicKeyFiles(c, dir+"/juju_id_ed25519.pub")

	// All files ending with .pub in the client key dir get picked up...
	priv, pub, err := ssh.GenerateKey("whatever")
	c.Assert(err, jc.ErrorIsNil)
	writeFile("whatever.pub", pub)
	c.Assert(ssh.LoadClientKeys(dir), jc.ErrorIsNil)
	// ...but only once the corresponding private key exists.
	checkPublicKeyFiles(c, dir+"/juju_id_ed25519.pub")

	writeFile("whatever", priv)
	c.Assert(ssh.LoadClientKeys(dir), jc.ErrorIsNil)
	checkPublicKeyFiles(c, dir+"/juju_id_ed25519.pub", dir+"/whatever.pub")
}
func (s *ClientKeysSuite) TestPrivateKeyFiles(c *gc.C) {
	const dir = "~/.juju/ssh"
	writeFile := func(name, content string) {
		err := ioutil.WriteFile(gitjujutesting.HomePath(".juju", "ssh", name), []byte(content), 0600)
		c.Assert(err, jc.ErrorIsNil)
	}

	// Create/load client keys. They are cached in memory: files added to
	// the directory are not considered until LoadClientKeys runs again.
	c.Assert(ssh.LoadClientKeys(dir), jc.ErrorIsNil)
	checkPrivateKeyFiles(c, dir+"/juju_id_ed25519")

	priv, pub, err := ssh.GenerateKey("whatever")
	c.Assert(err, jc.ErrorIsNil)

	// A lone private key is ignored until its public key exists.
	writeFile("whatever", priv)
	c.Assert(ssh.LoadClientKeys(dir), jc.ErrorIsNil)
	checkPrivateKeyFiles(c, dir+"/juju_id_ed25519")

	// With both halves present, nothing changes until the next reload.
	writeFile("whatever.pub", pub)
	checkPublicKeyFiles(c, dir+"/juju_id_ed25519.pub")
	checkPrivateKeyFiles(c, dir+"/juju_id_ed25519")

	c.Assert(ssh.LoadClientKeys(dir), jc.ErrorIsNil)
	checkPublicKeyFiles(c, dir+"/juju_id_ed25519.pub", dir+"/whatever.pub")
	checkPrivateKeyFiles(c, dir+"/juju_id_ed25519", dir+"/whatever")
}
func (s *ClientKeysSuite) TestLoadClientKeysDirExists(c *gc.C) {
	// An existing but empty directory still gets a freshly generated key pair.
	c.Assert(os.MkdirAll(gitjujutesting.HomePath(".juju", "ssh"), 0755), jc.ErrorIsNil)
	c.Assert(ssh.LoadClientKeys("~/.juju/ssh"), jc.ErrorIsNil)
	checkPrivateKeyFiles(c, "~/.juju/ssh/juju_id_ed25519")
}
================================================
FILE: ssh/export_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"sync/atomic"
gc "gopkg.in/check.v1"
"github.com/juju/testing"
)
// Aliases exposing unexported helpers to the external ssh_test package.
var (
	ReadAuthorisedKeys  = readAuthorisedKeys
	WriteAuthorisedKeys = writeAuthorisedKeys
	InitDefaultClient   = initDefaultClient
	TestCopyReader      = copyReader
	TestNewCmd          = newCmd
)

// Addresses of package-level variables, so tests can overwrite them
// (e.g. with testing.PatchValue).
var (
	DefaultIdentities  = &defaultIdentities
	SSHDial            = &sshDial
	ED25519GenerateKey = &ed25519GenerateKey
)
// ReadLineWriter exports the unexported readLineWriter type so that external
// tests can supply their own terminal to PatchTerminal.
type ReadLineWriter readLineWriter
// PatchTerminal makes getTerminal yield rlw for the duration of the test.
// At cleanup it asserts that every terminal handed out was released by
// calling the cleanup function returned with it.
func PatchTerminal(s *testing.CleanupSuite, rlw ReadLineWriter) {
	var outstanding int64
	release := func() {
		atomic.AddInt64(&outstanding, -1)
	}
	s.PatchValue(&getTerminal, func() (readLineWriter, func(), error) {
		atomic.AddInt64(&outstanding, 1)
		return rlw, release, nil
	})
	s.AddCleanup(func(c *gc.C) {
		c.Assert(atomic.LoadInt64(&outstanding), gc.Equals, int64(0))
	})
}
// PatchNilTerminal makes getTerminal return a nil terminal, a no-op cleanup
// and no error for the duration of the test.
func PatchNilTerminal(s *testing.CleanupSuite) {
	noTerminal := func() (readLineWriter, func(), error) {
		return nil, func() {}, nil
	}
	s.PatchValue(&getTerminal, noTerminal)
}
================================================
FILE: ssh/fakes_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"bytes"
"io"
"io/ioutil"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
)
// fakeClient is a test double for an ssh client. It records which methods
// were called and with what arguments, and returns canned results.
type fakeClient struct {
	calls      []string
	hostArg    string
	commandArg []string
	optionsArg *ssh.Options
	copyArgs   []string
	err        error
	cmd        *ssh.Cmd
	impl       fakeCommandImpl
}

// checkCalls asserts the recorded arguments and the ordered list of calls.
func (cl *fakeClient) checkCalls(c *gc.C, host string, command []string, options *ssh.Options, copyArgs []string, calls ...string) {
	c.Check(cl.hostArg, gc.Equals, host)
	c.Check(cl.commandArg, jc.DeepEquals, command)
	c.Check(cl.optionsArg, gc.Equals, options)
	c.Check(cl.copyArgs, jc.DeepEquals, copyArgs)
	c.Check(cl.calls, jc.DeepEquals, calls)
}

// Command records its arguments and returns the canned cmd, or, when none
// is set, a command backed by the fake implementation in cl.impl.
func (cl *fakeClient) Command(host string, command []string, options *ssh.Options) *ssh.Cmd {
	cl.calls = append(cl.calls, "Command")
	cl.hostArg, cl.commandArg, cl.optionsArg = host, command, options
	if cl.cmd != nil {
		return cl.cmd
	}
	return ssh.TestNewCmd(&cl.impl)
}

// Copy records its arguments and returns the canned error.
func (cl *fakeClient) Copy(args []string, options *ssh.Options) error {
	cl.calls = append(cl.calls, "Copy")
	cl.copyArgs, cl.optionsArg = args, options
	return cl.err
}
// bufferWriter is an in-memory io.WriteCloser: writes accumulate in the
// embedded bytes.Buffer and Close does nothing.
type bufferWriter struct {
	bytes.Buffer
}

// Close satisfies io.Closer; there is nothing to release.
func (w *bufferWriter) Close() error { return nil }
// fakeCommandImpl is a fake command implementation. It records each call
// and the stdio streams it was given, returns err from every fallible
// method, and serves canned data and "raw" streams from its pipe methods.
type fakeCommandImpl struct {
	calls     []string
	stdinArg  io.Reader
	stdoutArg io.Writer
	stderrArg io.Writer
	stdinData bufferWriter
	err       error

	// Returned as the second value of the corresponding pipe methods.
	stdinRaw  io.Reader
	stdoutRaw io.Writer
	stderrRaw io.Writer

	// Read back through the pipes returned by StdoutPipe and StderrPipe.
	stdoutData bytes.Buffer
	stderrData bytes.Buffer
}

// checkCalls asserts the stdio streams passed to SetStdio and the ordered
// list of calls.
func (ci *fakeCommandImpl) checkCalls(c *gc.C, stdin io.Reader, stdout, stderr io.Writer, calls ...string) {
	c.Check(ci.stdinArg, gc.Equals, stdin)
	c.Check(ci.stdoutArg, gc.Equals, stdout)
	c.Check(ci.stderrArg, gc.Equals, stderr)
	c.Check(ci.calls, jc.DeepEquals, calls)
}

// checkStdin asserts what was written through the pipe from StdinPipe.
func (ci *fakeCommandImpl) checkStdin(c *gc.C, data string) {
	c.Check(ci.stdinData.String(), gc.Equals, data)
}

// recordCall appends name to the list of calls made.
func (ci *fakeCommandImpl) recordCall(name string) {
	ci.calls = append(ci.calls, name)
}

func (ci *fakeCommandImpl) Start() error {
	ci.recordCall("Start")
	return ci.err
}

func (ci *fakeCommandImpl) Wait() error {
	ci.recordCall("Wait")
	return ci.err
}

func (ci *fakeCommandImpl) Kill() error {
	ci.recordCall("Kill")
	return ci.err
}

func (ci *fakeCommandImpl) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
	ci.recordCall("SetStdio")
	ci.stdinArg, ci.stdoutArg, ci.stderrArg = stdin, stdout, stderr
}

func (ci *fakeCommandImpl) StdinPipe() (io.WriteCloser, io.Reader, error) {
	ci.recordCall("StdinPipe")
	return &ci.stdinData, ci.stdinRaw, ci.err
}

func (ci *fakeCommandImpl) StdoutPipe() (io.ReadCloser, io.Writer, error) {
	ci.recordCall("StdoutPipe")
	return ioutil.NopCloser(&ci.stdoutData), ci.stdoutRaw, ci.err
}

func (ci *fakeCommandImpl) StderrPipe() (io.ReadCloser, io.Writer, error) {
	ci.recordCall("StderrPipe")
	return ioutil.NopCloser(&ci.stderrData), ci.stderrRaw, ci.err
}
================================================
FILE: ssh/fingerprint.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"bytes"
"crypto/md5"
"fmt"
"github.com/juju/errors"
)
// KeyFingerprint returns the fingerprint and comment for the specified key
// in authorized_key format. Fingerprints are generated according to RFC4716:
// the MD5 digest of the decoded key blob, written as colon-separated pairs of
// lowercase hex digits. See http://www.ietf.org/rfc/rfc4716.txt, section 4.
func KeyFingerprint(key string) (fingerprint, comment string, err error) {
	ak, err := ParseAuthorisedKey(key)
	if err != nil {
		return "", "", errors.Errorf("generating key fingerprint: %v", err)
	}
	sum := md5.Sum(ak.Key)
	var buf bytes.Buffer
	for i, b := range sum {
		if i > 0 {
			buf.WriteByte(':')
		}
		fmt.Fprintf(&buf, "%02x", b)
	}
	return buf.String(), ak.Comment, nil
}
================================================
FILE: ssh/fingerprint_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
sshtesting "github.com/juju/utils/v4/ssh/testing"
)
// FingerprintSuite tests KeyFingerprint.
type FingerprintSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&FingerprintSuite{})
func (s *FingerprintSuite) TestKeyFingerprint(c *gc.C) {
	// Each known test key must produce its recorded fingerprint.
	known := []sshtesting.SSHKey{sshtesting.ValidKeyOne, sshtesting.ValidKeyTwo, sshtesting.ValidKeyThree}
	for _, k := range known {
		got, _, err := ssh.KeyFingerprint(k.Key)
		c.Assert(err, jc.ErrorIsNil)
		c.Assert(got, gc.Equals, k.Fingerprint)
	}
}
func (s *FingerprintSuite) TestKeyFingerprintError(c *gc.C) {
	// The parse failure is reported with a "generating key fingerprint" prefix.
	if _, _, err := ssh.KeyFingerprint("invalid key"); true {
		c.Assert(err, gc.ErrorMatches, `generating key fingerprint: invalid authorized_key "invalid key"`)
	}
}
================================================
FILE: ssh/generate.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"fmt"
"strings"
"github.com/juju/errors"
"golang.org/x/crypto/ssh"
)
// ed25519GenerateKey allows for tests to patch out ed25519 key generation.
// GenerateKey calls through this variable instead of ed25519.GenerateKey.
var ed25519GenerateKey = ed25519.GenerateKey
// GenerateKey makes an Ed25519 SSH key with no passphrase.
// The private key is returned PEM-encoded in the OpenSSH private key format
// ("-----BEGIN OPENSSH PRIVATE KEY-----"), as produced by
// ssh.MarshalPrivateKey, which also records comment inside it.
// The public key is suitable to be added into an authorized_keys file,
// and has the comment passed in as the comment part of the key.
func GenerateKey(comment string) (private, public string, err error) {
	// Only the private key is kept; the public half is re-derived from it
	// below by PublicKey.
	_, privateKey, err := ed25519GenerateKey(rand.Reader)
	if err != nil {
		return "", "", errors.Trace(err)
	}
	pemBlock, err := ssh.MarshalPrivateKey(privateKey, comment)
	if err != nil {
		return "", "", errors.Trace(err)
	}
	identity := pem.EncodeToMemory(pemBlock)
	public, err = PublicKey(identity, comment)
	if err != nil {
		return "", "", errors.Trace(err)
	}
	return string(identity), public, nil
}
// PublicKey returns the public key for any private key that
// ssh.ParsePrivateKey accepts. The result is a single authorized_keys line,
// terminated by a newline, carrying comment as its comment part. When
// comment is empty, no trailing separator space is emitted.
func PublicKey(privateKey []byte, comment string) (string, error) {
	signer, err := ssh.ParsePrivateKey(privateKey)
	if err != nil {
		return "", errors.Annotate(err, "failed to load key")
	}
	// MarshalAuthorizedKey appends a newline; strip it so a comment can follow.
	authKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
	if comment == "" {
		return authKey + "\n", nil
	}
	return fmt.Sprintf("%s %s\n", authKey, comment), nil
}
================================================
FILE: ssh/generate_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"crypto/dsa"
"crypto/ed25519"
"io"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
)
// GenerateSuite tests ssh key generation.
type GenerateSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&GenerateSuite{})

var (
	// pregeneratedKey caches the private key created by the first patched
	// key-generation call in overrideGenerateKey, for reuse by later calls.
	pregeneratedKey ed25519.PrivateKey
)
// overrideGenerateKey patches ssh.ED25519GenerateKey so that the first call
// generates a real key pair and every later call reuses the same private
// key, which is saved and shared between tests. The public key returned is
// always derived from that private key, so callers get a matching pair.
// The returned Restorer undoes the patch.
func overrideGenerateKey() testing.Restorer {
	return testing.PatchValue(ssh.ED25519GenerateKey, func(random io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) {
		if pregeneratedKey == nil {
			_, private, err := generateED25519Key(random)
			if err != nil {
				return nil, nil, err
			}
			pregeneratedKey = private
		}
		return pregeneratedKey.Public().(ed25519.PublicKey), pregeneratedKey, nil
	})
}
// generateED25519Key generates an Ed25519 key pair by passing random to
// ed25519.GenerateKey as its entropy source. Ed25519 keys have a fixed size,
// so there is no key length to choose.
func generateED25519Key(random io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) {
	return ed25519.GenerateKey(random)
}
// generateDSAKey creates a DSA key (L1024N160 parameters) using random as
// the entropy source. On any failure it returns a nil key and the error.
func generateDSAKey(random io.Reader) (*dsa.PrivateKey, error) {
	key := new(dsa.PrivateKey)
	err := dsa.GenerateParameters(&key.Parameters, random, dsa.L1024N160)
	if err == nil {
		err = dsa.GenerateKey(key, random)
	}
	if err != nil {
		return nil, err
	}
	return key, nil
}
func (s *GenerateSuite) TestGenerate(c *gc.C) {
	restorer := overrideGenerateKey()
	defer restorer.Restore()

	private, public, err := ssh.GenerateKey("some-comment")
	c.Check(err, jc.ErrorIsNil)
	// The private key is an OpenSSH-format PEM block.
	c.Check(private, jc.HasPrefix, "-----BEGIN OPENSSH PRIVATE KEY-----\n")
	c.Check(private, jc.HasSuffix, "-----END OPENSSH PRIVATE KEY-----\n")
	// The public key is an authorized_keys line ending in the comment.
	c.Check(public, jc.HasPrefix, "ssh-ed25519 ")
	c.Check(public, jc.HasSuffix, " some-comment\n")
}
================================================
FILE: ssh/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage hooks the gocheck (gopkg.in/check.v1) suites registered in
// this package into the standard "go test" runner.
func TestPackage(t *testing.T) {
gc.TestingT(t)
}
================================================
FILE: ssh/run.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"bytes"
"os/exec"
"strings"
"syscall"
"time"
"github.com/juju/clock"
"github.com/juju/errors"
utilexec "github.com/juju/utils/v4/exec"
)
// ExecParams are used for the parameters for ExecuteCommandOnMachine.
type ExecParams struct {
// IdentityFile, when non-empty, is passed to Options.SetIdentities so
// the SSH client tries it as a private key.
IdentityFile string
// Host is the target host; StartCommandOnMachine rejects an empty value.
Host string
// Command is written, followed by a newline, to the stdin of
// "/bin/bash -s" on the remote host.
Command string
// Timeout is how long ExecuteCommandOnMachine waits before cancelling
// the command.
Timeout time.Duration
}
// StartCommandOnMachine starts params.Command on params.Host without
// waiting for it to finish. The command text, followed by a newline, is
// fed to "/bin/bash -s" run over SSH, and both output streams are
// captured in the returned RunningCmd. Use its Wait or WaitWithCancel
// method to collect the result.
func StartCommandOnMachine(params ExecParams) (*RunningCmd, error) {
	if params.Host == "" {
		return nil, errors.Errorf("missing host address")
	}
	logger.Debugf("execute on %s", params.Host)

	var opts Options
	if params.IdentityFile != "" {
		opts.SetIdentities(params.IdentityFile)
	}

	// "bash -s" reads the script to execute from its standard input.
	running := &RunningCmd{
		SSHCmd: Command(params.Host, []string{"/bin/bash", "-s"}, &opts),
	}
	sshCmd := running.SSHCmd
	sshCmd.Stdout = &running.Stdout
	sshCmd.Stderr = &running.Stderr
	sshCmd.Stdin = strings.NewReader(params.Command + "\n")
	if err := sshCmd.Start(); err != nil {
		return nil, errors.Trace(err)
	}
	return running, nil
}
// RunningCmd represents a command that has been started.
type RunningCmd struct {
// SSHCmd is the command that was started.
SSHCmd *Cmd
// Stdout and Stderr are the output streams the command is using.
// Their contents are copied into the result by Wait.
Stdout bytes.Buffer
Stderr bytes.Buffer
}
// Wait blocks until the command completes and returns its exit code
// together with everything captured on stdout and stderr. A process that
// exits with a non-zero status is reported through result.Code, not as
// an error. An error is returned only when no exit status could be
// determined, and the captured output is returned in that case as well.
func (cmd *RunningCmd) Wait() (utilexec.ExecResponse, error) {
	waitErr := cmd.SSHCmd.Wait()
	logger.Debugf("command.Wait finished (err: %v)", waitErr)
	code, err := getExitCode(waitErr)
	result := utilexec.ExecResponse{
		Stdout: cmd.Stdout.Bytes(),
		Stderr: cmd.Stderr.Bytes(),
	}
	if err != nil {
		return result, errors.Trace(err)
	}
	result.Code = code
	return result, nil
}
// TODO(ericsnow) Add RunningCmd.WaitAbortable(abortChan <-chan error) ...
// based on WaitWithTimeout and update WaitWithTimeout to use it. We
// could make it WaitAbortable(abortChans ...<-chan error), which would
// require using reflect.Select(). Then that could simply replace Wait().
// It may make more sense, however, to have a helper function:
// Wait(cmd T, abortChans ...<-chan error) ...
// Cancelled is an error indicating that a command timed out.
// RunningCmd.WaitWithCancel returns it when its cancel channel is closed
// before the command completes. ExecuteCommandOnMachine closes that
// channel when its timeout elapses, which explains the message text.
var Cancelled = errors.New("command timed out")
// WaitWithCancel waits for the command to complete and returns the
// result. If cancel is closed before the command completes, the command
// is killed and Cancelled is returned, together with whatever output was
// captured up to that point.
func (cmd *RunningCmd) WaitWithCancel(cancel <-chan struct{}) (utilexec.ExecResponse, error) {
	var result utilexec.ExecResponse
	done := make(chan error, 1)
	go func() {
		defer close(done)
		waitResult, err := cmd.Wait()
		// result is only read after a receive from done, which orders
		// this write before that read.
		result = waitResult
		done <- err
	}()
	select {
	case err := <-done:
		return result, errors.Trace(err)
	case <-cancel:
		logger.Infof("killing the command due to cancellation")
		if err := cmd.SSHCmd.Kill(); err != nil {
			// The command may already have exited. Record the failure
			// instead of silently discarding it.
			logger.Debugf("failed to kill command: %v", err)
		}
		<-done // Ensure that the original cmd.Wait() call completed.
		// Finalize cmd.SSHCmd, if necessary. The goroutine above has
		// already waited, so this second Wait is likely to report an
		// error, which is deliberately ignored.
		_ = cmd.SSHCmd.Wait()
		return result, Cancelled
	}
}
// getExitCode converts the error from waiting on a command into an exit
// code. A nil error means exit code 0. An *exec.ExitError (after
// unwrapping with errors.Cause) for a process that exited normally yields
// its exit status with a nil error. In every other case -1 is returned
// along with the (unwrapped) error.
func getExitCode(err error) (int, error) {
	if err == nil {
		return 0, nil
	}
	err = errors.Cause(err)
	exitErr, isExitErr := err.(*exec.ExitError)
	if !isExitErr {
		return -1, err
	}
	sys := exitErr.ProcessState.Sys()
	waitStatus, isWaitStatus := sys.(syscall.WaitStatus)
	switch {
	case !isWaitStatus:
		logger.Errorf("unexpected type %T from ProcessState.Sys()", sys)
	case waitStatus.Exited():
		// A non-zero return code isn't considered an error here.
		return waitStatus.ExitStatus(), nil
	}
	return -1, err
}
// ExecuteCommandOnMachine executes args.Command on args.Host using ssh,
// passing the command through /bin/bash. If the command does not finish
// within args.Timeout, it is killed and Cancelled is returned. In either
// case, any output captured before completion or cancellation is
// returned in the response.
func ExecuteCommandOnMachine(args ExecParams) (utilexec.ExecResponse, error) {
	var result utilexec.ExecResponse
	cmd, err := StartCommandOnMachine(args)
	if err != nil {
		return result, errors.Trace(err)
	}
	cancel := make(chan struct{})
	// Use a stoppable timer rather than a goroutine blocked on After:
	// that goroutine (and its timer) lived for the whole timeout even
	// when the command finished early.
	timer := clock.WallClock.AfterFunc(args.Timeout, func() {
		close(cancel)
	})
	defer timer.Stop()
	result, err = cmd.WaitWithCancel(cancel)
	if err != nil {
		return result, errors.Trace(err)
	}
	return result, nil
}
================================================
FILE: ssh/run_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"io/ioutil"
"os"
"path/filepath"
"runtime"
"time"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
)
const (
// shortWait is the timeout for tests that expect the command to time out.
shortWait = 50 * time.Millisecond
// longWait is the timeout for tests that expect the command to complete.
longWait = 10 * time.Second
)
// ExecuteSSHCommandSuite tests ExecuteCommandOnMachine against fake
// "ssh" scripts placed at the front of $PATH.
type ExecuteSSHCommandSuite struct {
testing.IsolationSuite
// originalPath is the value of $PATH captured in SetUpSuite.
originalPath string
// testbin is a per-test directory prepended to $PATH.
testbin string
// fakessh is the path of the fake "ssh" script inside testbin.
fakessh string
}
var _ = gc.Suite(&ExecuteSSHCommandSuite{})
// SetUpSuite records $PATH before the embedded IsolationSuite's suite
// setup runs, so SetUpTest can restore it. Presumably the isolation setup
// alters the environment; confirm against github.com/juju/testing.
func (s *ExecuteSSHCommandSuite) SetUpSuite(c *gc.C) {
s.originalPath = os.Getenv("PATH")
s.IsolationSuite.SetUpSuite(c)
}
// SetUpTest skips on Windows, because these tests rely on OpenSSH. It then
// restores the $PATH saved in SetUpSuite and prepends a fresh directory
// that will hold the fake "ssh" script written by fakeSSH.
func (s *ExecuteSSHCommandSuite) SetUpTest(c *gc.C) {
if runtime.GOOS == "windows" {
c.Skip("issue 1403084: Tests use OpenSSH only")
}
s.IsolationSuite.SetUpTest(c)
// Restore $PATH only after the isolation setup has run, so the real
// system tools remain reachable behind the fake ssh.
err := os.Setenv("PATH", s.originalPath)
c.Assert(err, jc.ErrorIsNil)
s.testbin = c.MkDir()
s.fakessh = filepath.Join(s.testbin, "ssh")
s.PatchEnvPathPrepend(s.testbin)
}
// fakeSSH writes cmd as the contents of the executable "ssh" script at
// s.fakessh, which SetUpTest placed on $PATH.
func (s *ExecuteSSHCommandSuite) fakeSSH(c *gc.C, cmd string) {
	script := []byte(cmd)
	c.Assert(ioutil.WriteFile(s.fakessh, script, 0755), jc.ErrorIsNil)
}
// TestCaptureOutput checks that the command is sent on stdin and that
// both output streams are captured. echoSSH echoes stdin to stdout and
// its own arguments to stderr.
func (s *ExecuteSSHCommandSuite) TestCaptureOutput(c *gc.C) {
	s.fakeSSH(c, echoSSH)
	params := ssh.ExecParams{
		Host:    "hostname",
		Command: "sudo apt-get update\nsudo apt-get upgrade",
		Timeout: longWait,
	}
	response, err := ssh.ExecuteCommandOnMachine(params)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(response.Code, gc.Equals, 0)

	wantStdout := "sudo apt-get update\nsudo apt-get upgrade\n"
	wantStderr := "-o PasswordAuthentication no -o ServerAliveInterval 30 hostname /bin/bash -s\n"
	c.Assert(string(response.Stdout), gc.Equals, wantStdout)
	c.Assert(string(response.Stderr), gc.Equals, wantStderr)
}
// TestIdentityFile checks that ExecParams.IdentityFile reaches ssh as
// "-i <file>". echoSSH reports its arguments on stderr.
func (s *ExecuteSSHCommandSuite) TestIdentityFile(c *gc.C) {
	s.fakeSSH(c, echoSSH)
	params := ssh.ExecParams{
		IdentityFile: "identity-file",
		Host:         "hostname",
		Timeout:      longWait,
	}
	response, err := ssh.ExecuteCommandOnMachine(params)
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(string(response.Stderr), jc.Contains, " -i identity-file ")
}
// TestTimoutCaptureOutput checks that a command outliving its timeout is
// reported as timed out, while the output it produced before being
// killed is still returned.
func (s *ExecuteSSHCommandSuite) TestTimoutCaptureOutput(c *gc.C) {
	s.fakeSSH(c, slowSSH)
	params := ssh.ExecParams{
		IdentityFile: "identity-file",
		Host:         "hostname",
		Command:      "ignored",
		Timeout:      shortWait,
	}
	response, err := ssh.ExecuteCommandOnMachine(params)
	c.Check(err, gc.ErrorMatches, "command timed out")
	c.Assert(response.Code, gc.Equals, 0)
	c.Assert(string(response.Stdout), gc.Equals, "stdout\n")
	c.Assert(string(response.Stderr), gc.Equals, "stderr\n")
}
// TestCapturesReturnCode checks that a non-zero exit status is reported
// through response.Code rather than as an error.
func (s *ExecuteSSHCommandSuite) TestCapturesReturnCode(c *gc.C) {
	s.fakeSSH(c, passthroughSSH)
	params := ssh.ExecParams{
		IdentityFile: "identity-file",
		Host:         "hostname",
		Command:      "echo stdout; exit 42",
		Timeout:      longWait,
	}
	response, err := ssh.ExecuteCommandOnMachine(params)
	c.Check(err, jc.ErrorIsNil)
	c.Assert(response.Code, gc.Equals, 42)
	c.Assert(string(response.Stdout), gc.Equals, "stdout\n")
	c.Assert(string(response.Stderr), gc.Equals, "")
}
// echoSSH outputs the command args to stderr, and copies stdin to stdout.
// This lets tests inspect both the ssh arguments and the script sent.
var echoSSH = `#!/bin/bash
# Write the args to stderr
echo "$*" >&2
cat /dev/stdin
`
// slowSSH writes one line to stderr and one to stdout, then sleeps for 5
// seconds, so that a short timeout fires while the output is available.
var slowSSH = `#!/bin/bash
echo "stderr" >&2
echo "stdout"
sleep 5s
`
// passthroughSSH is a fake ssh that ignores its arguments and executes
// the script it reads on stdin. The shebang runs "bash -s", which reads
// commands from standard input.
var passthroughSSH = `#!/bin/bash -s`
================================================
FILE: ssh/ssh.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package ssh contains utilities for dealing with SSH connections,
// key management, and so on. All SSH-based command executions in
// Juju should use the Command/ScpCommand functions in this package.
package ssh
import (
"bytes"
"io"
"os/exec"
"syscall"
"github.com/juju/errors"
"github.com/juju/utils/v4"
)
// StrictHostChecksOption defines the possible values taken by
// Options.SetStrictHostKeyChecking().
type StrictHostChecksOption int
const (
// StrictHostChecksDefault configures the default,
// implementation-specific, behaviour.
//
// For the OpenSSH implementation, this elides the
// StrictHostKeyChecking option, which means the
// user's personal configuration will be used.
//
// For the go.crypto implementation, the default is
// the equivalent of "ask".
StrictHostChecksDefault StrictHostChecksOption = iota
// StrictHostChecksNo disables strict host key checking.
StrictHostChecksNo
// StrictHostChecksYes enables strict host key checking.
// Target hosts must appear in the known_hosts file or
// connections will fail.
StrictHostChecksYes
// StrictHostChecksAsk causes the client to ask the user about
// hosts that don't appear in the known_hosts file.
StrictHostChecksAsk
)
// Options is a client-implementation independent SSH options set.
// The zero value is ready to use; configure it with the setter methods.
type Options struct {
// proxyCommand specifies the command to
// execute to proxy SSH traffic through.
proxyCommand []string
// ssh server port; zero means use the default (22)
port int
// no PTY forced by default
allocatePTY bool
// password authentication is disallowed by default
passwordAuthAllowed bool
// identities is a sequence of paths to private key/identity files
// to use when attempting to login. A client implementation may attempt
// with additional identities, but must give preference to these
identities []string
// knownHostsFile is a path to a file in which to save the host's
// fingerprint.
knownHostsFile string
// strictHostKeyChecking sets that the host being connected to must
// exist in the known_hosts file, and with a matching public key.
strictHostKeyChecking StrictHostChecksOption
// hostKeyAlgorithms sets the host key types that the client will
// accept from the server, in order of preference. By default the
// client implementation will specify a set of reasonable types.
hostKeyAlgorithms []string
}
// SetProxyCommand sets the command (program followed by its arguments)
// through which SSH traffic is proxied. The arguments are copied, so the
// caller may reuse its slice afterwards.
func (o *Options) SetProxyCommand(command ...string) {
	proxy := make([]string, 0, len(command))
	o.proxyCommand = append(proxy, command...)
}
// SetPort selects the SSH server port to connect to. A port of zero
// leaves the client's default (22) in effect.
func (o *Options) SetPort(port int) {
	o.port = port
}
// EnablePTY requests that a pseudo-TTY be allocated for the session.
// A PTY is needed, for example, to answer sudo prompts on the target
// host. No PTY is forced unless this is called.
func (o *Options) EnablePTY() {
	o.allocatePTY = true
}
// SetKnownHostsFile selects the file in which host fingerprints are
// recorded and checked. When unset, the client's own default applies:
// ~/.ssh/known_hosts for OpenSSH, or GoCryptoKnownHostsFile() for the
// go.crypto client.
func (o *Options) SetKnownHostsFile(file string) {
	o.knownHostsFile = file
}
// SetStrictHostKeyChecking sets the desired host key checking
// behaviour. It takes one of the StrictHostChecksOption constants. The
// zero value, StrictHostChecksDefault, leaves the choice to the client
// implementation.
func (o *Options) SetStrictHostKeyChecking(value StrictHostChecksOption) {
	o.strictHostKeyChecking = value
}
// AllowPasswordAuthentication lets the SSH client prompt the user for a
// password. Without this call, password authentication is disallowed.
func (o *Options) AllowPasswordAuthentication() {
	o.passwordAuthAllowed = true
}
// SetIdentities records the private key/identity files to use for login,
// in order of preference. Client implementations may try additional
// identities, but must prefer these. The paths are copied, so the
// caller's slice may be reused.
func (o *Options) SetIdentities(identityFiles ...string) {
	files := make([]string, 0, len(identityFiles))
	o.identities = append(files, identityFiles...)
}
// SetHostKeyAlgorithms sets the host key types that the client will
// accept from the server, in order of preference. If not specified,
// the client implementation may choose its own defaults.
//
// The algorithms are copied, matching SetProxyCommand and SetIdentities,
// so later changes to the caller's slice do not affect these options.
func (o *Options) SetHostKeyAlgorithms(algos ...string) {
	// Clone while preserving nil: golang.org/x/crypto/ssh treats a nil
	// HostKeyAlgorithms as "use defaults", which an empty non-nil slice
	// would not.
	o.hostKeyAlgorithms = append(algos[:0:0], algos...)
}
// Client is an interface for SSH clients to implement
type Client interface {
// Command returns a *Cmd for executing a command
// on the specified host. Each Cmd is executed
// within its own SSH session.
//
// Host is specified in the format [user@]host.
Command(host string, command []string, options *Options) *Cmd
// Copy copies file(s) between local and remote host(s).
// Paths are specified in the scp format, [[user@]host:]path.
// NOTE(review): earlier docs referred to an "extraArgs" parameter
// that this signature does not have; presumably extra scp arguments
// are passed through args verbatim. Confirm with the implementations.
Copy(args []string, options *Options) error
}
// Cmd represents a command to be (or being) executed
// on a remote host.
type Cmd struct {
// Stdin, Stdout and Stderr are handed to the implementation by Start
// (via SetStdio), so set them before calling Start.
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
// impl is the client-specific command that does the actual work.
impl command
}
// newCmd wraps an implementation-specific command in a Cmd that has no
// standard streams attached yet.
func newCmd(impl command) *Cmd {
	cmd := new(Cmd)
	cmd.impl = impl
	return cmd
}
// CombinedOutput runs the command and returns everything it wrote to
// stdout and stderr, interleaved in one buffer, along with the result of
// Run. It refuses to run if Stdout or Stderr is already set.
func (c *Cmd) CombinedOutput() ([]byte, error) {
	switch {
	case c.Stdout != nil:
		return nil, errors.New("ssh: Stdout already set")
	case c.Stderr != nil:
		return nil, errors.New("ssh: Stderr already set")
	}
	combined := new(bytes.Buffer)
	c.Stdout, c.Stderr = combined, combined
	runErr := c.Run()
	return combined.Bytes(), runErr
}
// Output runs the command and returns what it wrote to stdout, along
// with the result of Run. It refuses to run if Stdout is already set.
func (c *Cmd) Output() ([]byte, error) {
	if c.Stdout != nil {
		return nil, errors.New("ssh: Stdout already set")
	}
	stdout := new(bytes.Buffer)
	c.Stdout = stdout
	runErr := c.Run()
	return stdout.Bytes(), runErr
}
// Run starts the command and waits for it to complete. When Wait reports
// an *exec.ExitError for a process that exited normally, the exit status
// is returned as a utils.RcPassthroughError. Any other error is returned
// unchanged.
func (c *Cmd) Run() error {
	if err := c.Start(); err != nil {
		return err
	}
	err := c.Wait()
	if exitError, ok := err.(*exec.ExitError); ok && exitError != nil {
		// Check the assertion, as getExitCode does, rather than
		// panicking on an unexpected ProcessState.Sys() type.
		if status, ok := exitError.ProcessState.Sys().(syscall.WaitStatus); ok && status.Exited() {
			return utils.NewRcPassthroughError(status.ExitStatus())
		}
	}
	return err
}
// Start hands the configured Stdin, Stdout and Stderr to the underlying
// implementation and starts the command without waiting for it. An error
// is returned if the command could not be started.
func (c *Cmd) Start() error {
	impl := c.impl
	impl.SetStdio(c.Stdin, c.Stdout, c.Stderr)
	return impl.Start()
}
// Wait blocks until a command started with Start completes and returns
// the implementation's result as an error.
func (c *Cmd) Wait() error {
	impl := c.impl
	return impl.Wait()
}
// Kill asks the implementation to terminate a command previously started
// with Start.
func (c *Cmd) Kill() error {
	impl := c.impl
	return impl.Kill()
}
// StdinPipe creates a pipe connected to the command's stdin. The
// returned writer feeds the command, and the read end is stored in
// c.Stdin. On error, c.Stdin is left unchanged.
func (c *Cmd) StdinPipe() (io.WriteCloser, error) {
	writer, reader, err := c.impl.StdinPipe()
	if err == nil {
		c.Stdin = reader
		return writer, nil
	}
	return nil, err
}
// StdoutPipe creates a pipe connected to the command's stdout. The
// returned reader yields the command's output, and the write end is
// stored in c.Stdout. On error, c.Stdout is left unchanged.
func (c *Cmd) StdoutPipe() (io.ReadCloser, error) {
	reader, writer, err := c.impl.StdoutPipe()
	if err == nil {
		c.Stdout = writer
		return reader, nil
	}
	return nil, err
}
// StderrPipe creates a pipe connected to the command's stderr. The
// returned reader yields the command's error output, and the write end
// is stored in c.Stderr. On error, c.Stderr is left unchanged.
func (c *Cmd) StderrPipe() (io.ReadCloser, error) {
	reader, writer, err := c.impl.StderrPipe()
	if err == nil {
		c.Stderr = writer
		return reader, nil
	}
	return nil, err
}
// command is an implementation-specific representation of a
// command prepared to execute against a specific host.
// Cmd delegates to it. Each *Pipe method returns the end of the pipe
// for the caller, followed by the end that Cmd stores in its matching
// Stdin/Stdout/Stderr field.
type command interface {
Start() error
Wait() error
Kill() error
SetStdio(stdin io.Reader, stdout, stderr io.Writer)
StdinPipe() (io.WriteCloser, io.Reader, error)
StdoutPipe() (io.ReadCloser, io.Writer, error)
StderrPipe() (io.ReadCloser, io.Writer, error)
}
// DefaultClient is the default SSH client for the process.
//
// If the OpenSSH client is found in $PATH, then it will be
// used for DefaultClient; otherwise, DefaultClient will use
// an embedded client based on go.crypto/ssh.
// If neither client could be created, DefaultClient would remain nil
// (NewGoCryptoClient currently never fails).
var DefaultClient Client
// chosenClient holds the type of SSH client created for
// DefaultClient, so that we can log it in Command or Copy.
var chosenClient string
// init selects DefaultClient when the package is loaded.
func init() {
initDefaultClient()
}
// initDefaultClient sets DefaultClient (and chosenClient, for logging) to
// the OpenSSH client when one can be created, and otherwise to the
// embedded go.crypto client. If both constructors fail, nothing is set.
func initDefaultClient() {
	openSSH, err := NewOpenSSHClient()
	if err == nil {
		DefaultClient, chosenClient = openSSH, "OpenSSH"
		return
	}
	goCrypto, err := NewGoCryptoClient()
	if err == nil {
		DefaultClient, chosenClient = goCrypto, "go.crypto (embedded)"
	}
}
// Command is a short-cut for DefaultClient.Command. It logs which client
// implementation is in use before delegating.
func Command(host string, command []string, options *Options) *Cmd {
	logger.Debugf("using %s ssh client", chosenClient)
	client := DefaultClient
	return client.Command(host, command, options)
}
// Copy is a short-cut for DefaultClient.Copy. It logs which client
// implementation is in use before delegating.
func Copy(args []string, options *Options) error {
	logger.Debugf("using %s ssh client", chosenClient)
	client := DefaultClient
	return client.Copy(args, options)
}
// CopyReader streams everything read from r into filename on host over
// SSH, using DefaultClient.
func CopyReader(host, filename string, r io.Reader, options *Options) error {
	logger.Debugf("using %s ssh client", chosenClient)
	client := DefaultClient
	return copyReader(client, host, filename, r, options)
}
// copyReader writes the contents of r to filename on host by running
// "cat - > filename" through client, with r as the remote stdin.
//
// NOTE(review): filename is inserted into the remote shell command
// without quoting, so it must not contain whitespace or shell
// metacharacters.
func copyReader(client Client, host, filename string, r io.Reader, options *Options) error {
	remote := []string{"cat - > " + filename}
	cmd := client.Command(host, remote, options)
	cmd.Stdin = r
	err := cmd.Run()
	return errors.Trace(err)
}
================================================
FILE: ssh/ssh_gocrypto.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/exec"
"os/user"
"strconv"
"strings"
"sync"
"time"
"github.com/juju/clock"
"github.com/juju/errors"
"github.com/juju/mutex/v2"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
"golang.org/x/crypto/ssh/terminal"
"github.com/juju/utils/v4"
)
// sshDefaultPort is used by GoCryptoClient.Command when Options does not
// specify a port.
const sshDefaultPort = 22
// GoCryptoClient is an implementation of Client that
// uses the embedded go.crypto/ssh SSH client.
//
// GoCryptoClient is intentionally limited in the
// functionality that it enables, as it is currently
// intended to be used only for non-interactive command
// execution.
type GoCryptoClient struct {
// signers are the authentication keys; when empty, Command falls back
// to privateKeys().
signers []ssh.Signer
}
// NewGoCryptoClient creates a GoCryptoClient that authenticates with the
// given signers. It currently always returns a nil error.
//
// If no signers are given, each Command uses the keys returned by
// privateKeys() instead (described previously as the private key
// generated by LoadClientKeys).
func NewGoCryptoClient(signers ...ssh.Signer) (*GoCryptoClient, error) {
	client := &GoCryptoClient{signers: signers}
	return client, nil
}
// Command implements Client.Command. The command words are flattened into
// a single shell string with utils.CommandString. The port, proxy
// command, known_hosts file, host key checking mode and host key
// algorithms are taken from options, which may be nil.
func (c *GoCryptoClient) Command(host string, command []string, options *Options) *Cmd {
	shellCommand := utils.CommandString(command...)

	signers := c.signers
	if len(signers) == 0 {
		signers = privateKeys()
	}
	user, host := splitUserHost(host)

	impl := &goCryptoCommand{
		signers: signers,
		user:    user,
		command: shellCommand,
	}
	port := sshDefaultPort
	if options != nil {
		if options.port != 0 {
			port = options.port
		}
		impl.proxyCommand = options.proxyCommand
		impl.knownHostsFile = options.knownHostsFile
		impl.strictHostKeyChecking = options.strictHostKeyChecking
		impl.hostKeyAlgorithms = options.hostKeyAlgorithms
	}
	impl.addr = net.JoinHostPort(host, strconv.Itoa(port))

	logger.Tracef(`running (equivalent of): ssh "%s@%s" -p %d '%s'`, user, host, port, shellCommand)
	return &Cmd{impl: impl}
}
// Copy implements Client.Copy. The go.crypto client has no scp support,
// so Copy ignores its arguments and always returns an error.
func (c *GoCryptoClient) Copy(args []string, options *Options) error {
	err := errors.Errorf("scp command is not implemented (OpenSSH scp not available in PATH)")
	return err
}
// goCryptoCommand is the go.crypto implementation of the command
// interface. The connection and session are created lazily by
// ensureSession and cleared again by Close.
type goCryptoCommand struct {
signers []ssh.Signer
// user is the remote login; when empty, ensureSession uses the local
// user's name.
user string
// addr is the "host:port" to connect to.
addr string
// command is the flattened shell command; empty means start a shell.
command string
// proxyCommand may be shared with the caller's Options.
proxyCommand []string
knownHostsFile string
strictHostKeyChecking StrictHostChecksOption
hostKeyAlgorithms []string
// stdin, stdout and stderr are recorded by SetStdio and attached to
// the session when it is created.
stdin io.Reader
stdout io.Writer
stderr io.Writer
client *ssh.Client
sess *ssh.Session
}
// sshDial dials SSH servers directly. It is a variable, presumably so
// that tests can replace it.
var sshDial = ssh.Dial

// sshDialWithProxy connects to the SSH server at addr ("host:port").
// Without a proxy command it dials addr directly over TCP. Otherwise it
// starts proxyCommand, replacing the %h, %p and %r tokens in each argument
// with the target host, port and config.User, and runs SSH over the
// proxy's stdin/stdout. The proxy's stderr goes to os.Stderr.
// proxyCommand itself is never modified.
var sshDialWithProxy = func(addr string, proxyCommand []string, config *ssh.ClientConfig) (*ssh.Client, error) {
	if len(proxyCommand) == 0 {
		return sshDial("tcp", addr, config)
	}
	// User has specified a proxy. Create a pipe and
	// redirect the proxy command's stdin/stdout to it.
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		host = addr
	}
	// Expand the tokens into a new slice. proxyCommand is shared with the
	// caller's Options (GoCryptoClient.Command does not copy it), so
	// editing it in place would leak this host's values into later
	// connections made with the same Options.
	args := make([]string, len(proxyCommand))
	for i, arg := range proxyCommand {
		arg = strings.Replace(arg, "%h", host, -1)
		if port != "" {
			arg = strings.Replace(arg, "%p", port, -1)
		}
		arg = strings.Replace(arg, "%r", config.User, -1)
		args[i] = arg
	}
	client, server := net.Pipe()
	logger.Tracef(`executing proxy command %q`, args)
	cmd := exec.Command(args[0], args[1:]...)
	cmd.Stdin = server
	cmd.Stdout = server
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		client.Close()
		server.Close()
		return nil, err
	}
	conn, chans, reqs, err := ssh.NewClientConn(client, addr, config)
	if err != nil {
		// Don't leave the proxy running after a failed handshake. Reap it
		// in the background so a misbehaving proxy cannot block us here.
		client.Close()
		server.Close()
		_ = cmd.Process.Kill()
		go func() { _ = cmd.Wait() }()
		return nil, err
	}
	return ssh.NewClient(conn, chans, reqs), nil
}
// ensureSession returns the command's SSH session, creating it on first
// use. Creating it means dialling c.addr (through c.proxyCommand if one
// is set), opening a session, and wiring the session's stdio to the
// streams recorded by SetStdio, with stdin wrapped by WrapStdin. It fails
// if no signers are configured, or if c.user is empty and the local user
// cannot be determined.
func (c *goCryptoCommand) ensureSession() (*ssh.Session, error) {
	if c.sess != nil {
		return c.sess, nil
	}
	if len(c.signers) == 0 {
		return nil, errors.Errorf("no private keys available")
	}
	if c.user == "" {
		current, err := user.Current()
		if err != nil {
			return nil, errors.Errorf("getting current user: %v", err)
		}
		c.user = current.Username
	}

	keys := ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
		return c.signers, nil
	})
	config := &ssh.ClientConfig{
		User:              c.user,
		HostKeyCallback:   c.hostKeyCallback,
		HostKeyAlgorithms: c.hostKeyAlgorithms,
		Auth:              []ssh.AuthMethod{keys},
	}

	client, err := sshDialWithProxy(c.addr, c.proxyCommand, config)
	if err != nil {
		return nil, err
	}
	sess, err := client.NewSession()
	if err != nil {
		client.Close()
		return nil, err
	}
	c.client, c.sess = client, sess
	sess.Stdin = WrapStdin(c.stdin)
	sess.Stdout = c.stdout
	sess.Stderr = c.stderr
	return sess, nil
}
// Start opens the session if necessary and starts c.command on it. If
// c.command is empty, the remote shell is started instead.
func (c *goCryptoCommand) Start() error {
	sess, err := c.ensureSession()
	switch {
	case err != nil:
		return err
	case c.command == "":
		return sess.Shell()
	default:
		return sess.Start(c.command)
	}
}
// Close closes the open session and its client connection, then forgets
// both. It is a no-op when no session is open. If both closes fail, the
// session's error takes precedence.
func (c *goCryptoCommand) Close() error {
	if c.sess == nil {
		return nil
	}
	sessErr := c.sess.Close()
	clientErr := c.client.Close()
	c.sess, c.client = nil, nil
	if sessErr != nil {
		return sessErr
	}
	return clientErr
}
// Wait blocks until the remote command finishes, then closes the session
// and connection, ignoring any error from closing. It fails if the
// command was never started.
func (c *goCryptoCommand) Wait() error {
	sess := c.sess
	if sess == nil {
		return errors.Errorf("command has not been started")
	}
	waitErr := sess.Wait()
	_ = c.Close()
	return waitErr
}
// Kill sends SIGKILL to the remote command through the session. It fails
// if the command was never started.
func (c *goCryptoCommand) Kill() error {
	if sess := c.sess; sess != nil {
		return sess.Signal(ssh.SIGKILL)
	}
	return errors.Errorf("command has not been started")
}
// SetStdio records the streams to attach when the session is created.
// Streams set after the session exists are not applied to it.
func (c *goCryptoCommand) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
	c.stdin, c.stdout, c.stderr = stdin, stdout, stderr
}
// StdinPipe opens the session if necessary and returns a writer feeding
// the remote stdin, along with the session's Stdin field as it stands
// after the pipe has been requested.
func (c *goCryptoCommand) StdinPipe() (io.WriteCloser, io.Reader, error) {
	sess, err := c.ensureSession()
	if err != nil {
		return nil, nil, err
	}
	pipe, pipeErr := sess.StdinPipe()
	return pipe, sess.Stdin, pipeErr
}
// StdoutPipe opens the session if necessary and returns a reader of the
// remote stdout (wrapped with a no-op Close), along with the session's
// Stdout field.
func (c *goCryptoCommand) StdoutPipe() (io.ReadCloser, io.Writer, error) {
	sess, err := c.ensureSession()
	if err != nil {
		return nil, nil, err
	}
	pipe, pipeErr := sess.StdoutPipe()
	return io.NopCloser(pipe), sess.Stdout, pipeErr
}
// StderrPipe opens the session if necessary and returns a reader of the
// remote stderr (wrapped with a no-op Close), along with the session's
// Stderr field.
func (c *goCryptoCommand) StderrPipe() (io.ReadCloser, io.Writer, error) {
	sess, err := c.ensureSession()
	if err != nil {
		return nil, nil, err
	}
	pipe, pipeErr := sess.StderrPipe()
	return io.NopCloser(pipe), sess.Stderr, pipeErr
}
// hostKeyCallback is the ssh.HostKeyCallback used by goCryptoCommand.
// It accepts key if it is already recorded for hostname in the
// known_hosts file (c.knownHostsFile, or GoCryptoKnownHostsFile() when
// that is empty). Otherwise c.strictHostKeyChecking decides:
// StrictHostChecksNo adds the key with a warning, StrictHostChecksDefault
// and StrictHostChecksAsk prompt on the terminal (and fail without one),
// and any other value rejects the key. An accepted key is appended to the
// known_hosts file while holding the "juju-ssh-client" mutex, unless the
// file is os.DevNull.
func (c *goCryptoCommand) hostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
knownHostsFile := c.knownHostsFile
if knownHostsFile == "" {
knownHostsFile = GoCryptoKnownHostsFile()
if knownHostsFile == "" {
return errors.New("known_hosts file not configured")
}
}
// If stdin is a terminal, it stays in raw mode for the rest of this
// callback and is restored on return. Messages go to the terminal when
// there is one, and to the logger otherwise.
var printError func(string) error
term, cleanupTerm, err := getTerminal()
if err != nil {
return errors.Trace(err)
} else if term != nil {
defer cleanupTerm()
printError = func(message string) error {
_, err := fmt.Fprintln(term, message)
return err
}
} else {
printError = func(message string) error {
logger.Errorf("%s", message)
return nil
}
}
matched, err := checkHostKey(hostname, remote, key, knownHostsFile, printError)
if err != nil || matched {
return errors.Trace(err)
}
// We did not find a matching key, so what we do next depends on the
// strict host key checking configuration.
var warnAdd bool
switch c.strictHostKeyChecking {
case StrictHostChecksNo:
// Don't ask, just add.
warnAdd = true
case StrictHostChecksDefault, StrictHostChecksAsk:
message := fmt.Sprintf(`The authenticity of host '%s (%s)' can't be established.
%s key fingerprint is %s.
`,
hostname,
remote,
key.Type(),
ssh.FingerprintSHA256(key),
)
if term == nil {
// If we're not running in a terminal,
// we can't ask the user if they want
// to accept.
logger.Errorf("%s", message)
return errors.New("not running in a terminal, cannot prompt for verification")
}
// Prompt user, asking if they trust the key.
// Only "yes" or "no" (case-insensitive) ends the loop.
fmt.Fprint(term, message+"Are you sure you want to continue connecting (yes/no)? ")
for {
line, err := term.ReadLine()
if err != nil {
return errors.Trace(err)
}
var yes bool
switch strings.ToLower(line) {
case "yes":
yes = true
case "no":
return errors.New("Host key verification failed.")
default:
fmt.Fprint(term, "Please type 'yes' or 'no': ")
}
// This break is outside the switch, so it leaves the for loop.
if yes {
break
}
}
default:
return errors.Errorf(
`no %s host key is known for %s and you have requested strict checking`,
key.Type(), hostname,
)
}
if knownHostsFile != os.DevNull {
// Make sure no other process modifies the file.
releaser, err := mutex.Acquire(mutex.Spec{
Name: "juju-ssh-client",
Clock: clock.WallClock,
Delay: time.Second,
})
if err != nil {
return errors.Trace(err)
}
defer releaser.Release()
// Re-read the file under the mutex to keep entries added meanwhile,
// then write it atomically, so that readers which do not take the
// mutex (such as checkHostKey above) never see a partial file.
knownHostsData, err := ioutil.ReadFile(knownHostsFile)
if err != nil && !os.IsNotExist(err) {
return errors.Trace(err)
}
buf := bytes.NewBuffer(knownHostsData)
if len(knownHostsData) > 0 && !bytes.HasSuffix(knownHostsData, []byte("\n")) {
buf.WriteRune('\n')
}
buf.WriteString(knownhosts.Line([]string{hostname}, key))
buf.WriteRune('\n')
if err := utils.AtomicWriteFile(knownHostsFile, buf.Bytes(), 0600); err != nil {
return errors.Trace(err)
}
}
// Failure to print this warning is ignored; the key has been accepted.
if warnAdd {
printError(fmt.Sprintf(
"Warning: permanently added '%s' (%s) to the list of known hosts.",
hostname, key.Type(),
))
}
return nil
}
// readLineWriter is the subset of terminal functionality that
// hostKeyCallback needs: writing messages and reading the user's reply.
type readLineWriter interface {
io.Writer
ReadLine() (string, error)
}
// getTerminal returns a line-oriented terminal on os.Stdin when stdin is
// a terminal. It switches the terminal into raw mode and returns a
// cleanup function that restores the previous state. When stdin is not a
// terminal, all three results are nil.
var getTerminal = func() (readLineWriter, func(), error) {
	fd := int(os.Stdin.Fd())
	if !terminal.IsTerminal(fd) {
		return nil, nil, nil
	}
	oldState, err := terminal.MakeRaw(fd)
	if err != nil {
		return nil, nil, errors.Trace(err)
	}
	restore := func() { terminal.Restore(fd, oldState) }
	return terminal.NewTerminal(os.Stdin, ""), restore, nil
}
// checkHostKey checks the given (hostname, address, public key) tuple
// against the local known-hosts database, if it exists, and returns a
// boolean indicating whether a match was found, and any errors encountered.
//
// (false, nil) means the host is not known yet, or the file is missing.
// If the host is known with different key(s), a warning is shown through
// printError and the *knownhosts.KeyError is returned (traced). A failure
// of printError itself is returned instead of that error.
func checkHostKey(
hostname string,
remote net.Addr,
key ssh.PublicKey,
knownHostsFile string,
printError func(string) error,
) (bool, error) {
// NOTE(axw) the knownhosts code is incomplete, but enough for
// our limited use cases. We do not support parsing a known_hosts
// file managed by OpenSSH (due to hashed hosts, etc.), but that
// is OK since this client exists only to support systems that
// do not have access to OpenSSH.
callback, err := knownhosts.New(knownHostsFile)
if err != nil {
if os.IsNotExist(err) {
// The known_hosts file does not exist.
return false, nil
}
return false, errors.Trace(err)
}
err = callback(hostname, remote, key)
switch err := err.(type) {
case nil:
// Known host with matching key.
return true, nil
case *knownhosts.KeyError:
if len(err.Want) == 0 {
// Unknown host.
return false, nil
}
head := fmt.Sprintf(`
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
Someone could be eavesdropping on you right now (man-in-the-middle attack)!
It is also possible that a host key has just been changed.
The fingerprint for the %s key sent by the remote host is
%s.
Please contact your system administrator.
Add correct host key in %s to get rid of this message.
`[1:], key.Type(), ssh.FingerprintSHA256(key), knownHostsFile)
// Prefer to name a known key of the same type as the offered one
// (the last such entry wins). Otherwise list every known key.
var typeKey *knownhosts.KnownKey
for i, knownKey := range err.Want {
if knownKey.Key.Type() == key.Type() {
typeKey = &err.Want[i]
}
}
var tail string
if typeKey != nil {
tail = fmt.Sprintf(
"Offending %s key in %s:%d",
typeKey.Key.Type(),
typeKey.Filename,
typeKey.Line,
)
} else {
tail = "Host was previously using different host key algorithms:"
for _, knownKey := range err.Want {
tail += fmt.Sprintf(
"\n - %s key in %s:%d",
knownKey.Key.Type(),
knownKey.Filename,
knownKey.Line,
)
}
}
if err := printError(head + tail); err != nil {
// Not being able to display the warning
// should be considered fatal.
return false, errors.Annotate(
err, "failed to print host key mismatch warning",
)
}
}
// Mismatched keys, and any other verification error from the
// callback, are returned here.
return false, errors.Trace(err)
}
// splitUserHost splits a "[user@]host" string at its first "@". If there
// is no "@", user is empty and the whole string is the host.
func splitUserHost(s string) (user, host string) {
	if u, h, found := strings.Cut(s, "@"); found {
		return u, h
	}
	return "", s
}
var (
	// goCryptoKnownHostsMutex guards goCryptoKnownHostsFile.
	goCryptoKnownHostsMutex sync.Mutex
	// goCryptoKnownHostsFile is the default known_hosts path for the
	// go.crypto client. An empty value means none is configured.
	goCryptoKnownHostsFile string
)

// GoCryptoKnownHostsFile returns the known_hosts file used by default
// by the golang.org/x/crypto/ssh-based client, or "" if none is set.
func GoCryptoKnownHostsFile() string {
	goCryptoKnownHostsMutex.Lock()
	defer goCryptoKnownHostsMutex.Unlock()
	return goCryptoKnownHostsFile
}

// SetGoCryptoKnownHostsFile sets the known_hosts file used by default
// by the golang.org/x/crypto/ssh-based client. It is safe to call
// concurrently with GoCryptoKnownHostsFile.
func SetGoCryptoKnownHostsFile(file string) {
	goCryptoKnownHostsMutex.Lock()
	defer goCryptoKnownHostsMutex.Unlock()
	goCryptoKnownHostsFile = file
}
================================================
FILE: ssh/ssh_gocrypto_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/exec"
"path/filepath"
"regexp"
"sync"
"time"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
cryptossh "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/testdata"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
)
var (
// testCommand is the command the tests ask the client to run.
testCommand = []string{"echo", "$abc"}
// testCommandFlat is the exec payload the fake server expects:
// testCommand as flattened by the client (utils.CommandString).
testCommandFlat = `echo "\$abc"`
)
// sshServer is a minimal in-test SSH server driven by run.
type sshServer struct {
// cfg configures the server side of the handshake.
cfg *cryptossh.ServerConfig
// listener supplies the single connection that run accepts.
listener net.Listener
// client is set by run once the handshake completes; it is used to
// receive channel-open requests from the connecting client.
client *cryptossh.Client
}
// run accepts one SSH connection on s.listener and serves one "session"
// channel on it. It answers a single "exec" request for testCommandFlat
// with the output "abc value\n" and exit status 0. Protocol deviations are
// reported on errorCh. errorCh is closed on every return path, so readers
// ranging over it always terminate. If done is closed before a session is
// opened, run returns early.
func (s *sshServer) run(errorCh chan error, done chan bool) {
	// Registered first so that it runs last, after the session goroutine
	// (awaited below) can no longer send on errorCh.
	defer close(errorCh)

	netconn, err := s.listener.Accept()
	if err != nil {
		errorCh <- fmt.Errorf("accepting connection: %w", err)
		return
	}
	defer netconn.Close()
	conn, chans, reqs, err := cryptossh.NewServerConn(netconn, s.cfg)
	if err != nil {
		errorCh <- fmt.Errorf("getting ssh server connection: %w", err)
		return
	}
	s.client = cryptossh.NewClient(conn, chans, reqs)

	var wg sync.WaitGroup
	defer wg.Wait()

	sessionChannels := s.client.HandleChannelOpen("session")
	select {
	case <-done:
		return
	case newChannel := <-sessionChannels:
		if sCh := newChannel.ChannelType(); sCh != "session" {
			errorCh <- fmt.Errorf("unexpected session channel %q", sCh)
			return
		}
		channel, reqs, err := newChannel.Accept()
		if err != nil {
			errorCh <- fmt.Errorf("accepting session connection: %w", err)
			return
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer channel.Close()
			for req := range reqs {
				if req.Type != "exec" {
					errorCh <- fmt.Errorf("unexpected request type: %q", req.Type)
					return
				}
				if !req.WantReply {
					errorCh <- fmt.Errorf("no reply wanted for request %+v", req)
					return
				}
				// The payload is an SSH string: a big-endian uint32 length
				// followed by that many bytes. Validate it instead of
				// panicking on a short or malformed payload.
				if len(req.Payload) < 4 {
					errorCh <- fmt.Errorf("malformed exec payload %q", req.Payload)
					return
				}
				n := binary.BigEndian.Uint32(req.Payload[:4])
				if uint64(n) > uint64(len(req.Payload)-4) {
					errorCh <- fmt.Errorf("malformed exec payload %q", req.Payload)
					return
				}
				command := string(req.Payload[4 : 4+n])
				if command != testCommandFlat {
					errorCh <- fmt.Errorf("unexpected request command: %q", command)
					return
				}
				if err := req.Reply(true, nil); err != nil {
					errorCh <- fmt.Errorf("error sending reply: %w", err)
					return
				}
				if _, err := channel.Write([]byte("abc value\n")); err != nil {
					errorCh <- fmt.Errorf("error writing output: %w", err)
					return
				}
				_, err := channel.SendRequest("exit-status", false, cryptossh.Marshal(&struct{ n uint32 }{0}))
				if err != nil {
					errorCh <- fmt.Errorf("error sending request: %w", err)
				}
				return
			}
		}()
	}
}
// newClient builds a GoCryptoClient whose signer comes from a key produced by
// ssh.GenerateKey. It also returns the public half of that key, so tests can
// compare it with what the server receives.
func newClient(c *gc.C) (*ssh.GoCryptoClient, cryptossh.PublicKey) {
	privateKeyPEM, _, err := ssh.GenerateKey("test-client")
	c.Assert(err, jc.ErrorIsNil)

	signer, err := cryptossh.ParsePrivateKey([]byte(privateKeyPEM))
	c.Assert(err, jc.ErrorIsNil)

	goClient, err := ssh.NewGoCryptoClient(signer)
	c.Assert(err, jc.ErrorIsNil)

	return goClient, signer.PublicKey()
}
// SSHGoCryptoCommandSuite exercises GoCryptoClient against an in-process SSH
// server. The test* maps hold the keys from x/crypto/ssh/testdata, indexed by
// key type. SetUpSuite adds an extra "cert" entry to testPrivateKeys and
// testSigners.
type SSHGoCryptoCommandSuite struct {
	testing.IsolationSuite

	client         ssh.Client
	knownHostsFile string

	testPrivateKeys map[string]any
	testSigners     map[string]cryptossh.Signer
	testPublicKeys  map[string]cryptossh.PublicKey
}

var _ = gc.Suite(&SSHGoCryptoCommandSuite{})
// SetUpSuite parses every key in x/crypto/ssh/testdata into the suite's
// private-key, signer and public-key maps. It then builds an ecdsa
// certificate signed by the ed25519 key and registers it under "cert".
func (s *SSHGoCryptoCommandSuite) SetUpSuite(c *gc.C) {
	s.IsolationSuite.SetUpSuite(c)

	keyCount := len(testdata.PEMBytes)
	s.testPrivateKeys = make(map[string]any, keyCount)
	s.testSigners = make(map[string]cryptossh.Signer, keyCount)
	s.testPublicKeys = make(map[string]cryptossh.PublicKey, keyCount)
	for keyType, pemBytes := range testdata.PEMBytes {
		privateKey, err := cryptossh.ParseRawPrivateKey(pemBytes)
		c.Assert(err, jc.ErrorIsNil)
		signer, err := cryptossh.NewSignerFromKey(privateKey)
		c.Assert(err, jc.ErrorIsNil)
		s.testPrivateKeys[keyType] = privateKey
		s.testSigners[keyType] = signer
		s.testPublicKeys[keyType] = signer.PublicKey()
	}

	// Create a cert and sign it for use in tests.
	testCert := &cryptossh.Certificate{
		// Nonce and Reserved must be non-nil to pass reflect.DeepEqual after
		// a marshal & parse round trip.
		Nonce:           []byte{},
		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
		ValidAfter:      0,                              // unix epoch
		ValidBefore:     cryptossh.CertTimeInfinity,     // end of currently representable time
		Reserved:        []byte{},
		Key:             s.testPublicKeys["ecdsa"],
		SignatureKey:    s.testPublicKeys["ed25519"],
		Permissions: cryptossh.Permissions{
			CriticalOptions: map[string]string{},
			Extensions:      map[string]string{},
		},
	}
	err := testCert.SignCert(rand.Reader, s.testSigners["ed25519"])
	c.Assert(err, jc.ErrorIsNil)

	s.testPrivateKeys["cert"] = s.testPrivateKeys["ecdsa"]
	certSigner, err := cryptossh.NewCertSigner(testCert, s.testSigners["ecdsa"])
	c.Assert(err, jc.ErrorIsNil)
	s.testSigners["cert"] = certSigner
}
// SetUpTest stubs out key generation, points the Go crypto client at a
// fresh per-test known_hosts path, and installs the nil terminal patch.
func (s *SSHGoCryptoCommandSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)

	restorer := overrideGenerateKey()
	s.AddCleanup(func(*gc.C) {
		restorer.Restore()
	})

	knownHosts := filepath.Join(c.MkDir(), "known_hosts")
	s.knownHostsFile = knownHosts
	ssh.SetGoCryptoKnownHostsFile(knownHosts)

	ssh.PatchNilTerminal(&s.CleanupSuite)
}
// newServer returns an sshServer listening on an ephemeral loopback port.
// The server uses serverConfig with the suite's ed25519 key added as its
// host key. newServer also returns that host key's public half.
func (s *SSHGoCryptoCommandSuite) newServer(c *gc.C, serverConfig cryptossh.ServerConfig) (*sshServer, cryptossh.PublicKey) {
	serverConfig.AddHostKey(s.testSigners["ed25519"])

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	c.Assert(err, jc.ErrorIsNil)
	c.Logf("Server listening on %s", listener.Addr().String())

	server := &sshServer{
		cfg:      &serverConfig,
		listener: listener,
	}
	return server, s.testPublicKeys["ed25519"]
}
// TestNewGoCryptoClient checks that a client can be built both with no
// explicit signers and with a signer parsed from a generated key.
func (s *SSHGoCryptoCommandSuite) TestNewGoCryptoClient(c *gc.C) {
	_, err := ssh.NewGoCryptoClient()
	c.Assert(err, jc.ErrorIsNil)

	pemKey, _, err := ssh.GenerateKey("test-client")
	c.Assert(err, jc.ErrorIsNil)
	signer, err := cryptossh.ParsePrivateKey([]byte(pemKey))
	c.Assert(err, jc.ErrorIsNil)

	_, err = ssh.NewGoCryptoClient(signer)
	c.Assert(err, jc.ErrorIsNil)
}
// TestClientNoKeys checks two things. First, a client created without
// signers fails when no client keys are loaded. Second, once client keys are
// loaded, the command reaches the (stubbed) dial step.
func (s *SSHGoCryptoCommandSuite) TestClientNoKeys(c *gc.C) {
	client, err := ssh.NewGoCryptoClient()
	c.Assert(err, jc.ErrorIsNil)
	cmd := client.Command("0.1.2.3", []string{"echo", "123"}, nil)
	_, err = cmd.Output()
	c.Assert(err, gc.ErrorMatches, "no private keys available")

	// Load client keys from an empty directory. Presumably this generates a
	// key there, via the stubbed key generator installed in SetUpTest.
	defer ssh.ClearClientKeys()
	err = ssh.LoadClientKeys(c.MkDir())
	c.Assert(err, jc.ErrorIsNil)

	// Stub out dialling so the failure is deterministic and needs no network.
	s.PatchValue(ssh.SSHDial, func(network, address string, cfg *cryptossh.ClientConfig) (*cryptossh.Client, error) {
		return nil, errors.New("ssh.Dial failed")
	})
	cmd = client.Command("0.1.2.3", []string{"echo", "123"}, nil)
	_, err = cmd.Output()
	// The stubbed dial error must surface unchanged.
	c.Assert(err, gc.ErrorMatches, "ssh.Dial failed")
}
// waitForServer blocks until the test server reports on errorCh and returns
// what it received. A closed errorCh (the server finished without error)
// yields nil. If nothing arrives within testing.LongWait, the test is failed
// via c.Fatal.
func waitForServer(c *gc.C, errorCh chan error) error {
	select {
	case err := <-errorCh:
		return err
	case <-time.After(testing.LongWait):
		c.Fatal("timed out waiting for ssh server")
		return nil
	}
}
// TestCommand runs testCommand against the in-process server with
// StrictHostChecksNo. It checks that:
//   - the server sees the client's public key;
//   - the command output is returned;
//   - the previously unknown server host key is written to known_hosts.
func (s *SSHGoCryptoCommandSuite) TestCommand(c *gc.C) {
client, clientKey := newClient(c)
server, serverKey := s.newServer(c, cryptossh.ServerConfig{})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
cmd := client.Command("127.0.0.1", testCommand, &opts)
// The callback is installed on server.cfg before server.run starts, so the
// handshake performed in run uses it.
checkedKey := false
server.cfg.PublicKeyCallback = func(conn cryptossh.ConnMetadata, pubkey cryptossh.PublicKey) (*cryptossh.Permissions, error) {
c.Check(pubkey, gc.DeepEquals, clientKey)
checkedKey = true
return nil, nil
}
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
out, err := cmd.Output()
c.Assert(err, jc.ErrorIsNil)
c.Assert(string(out), gc.Equals, "abc value\n")
c.Assert(checkedKey, jc.IsTrue)
// With StrictHostChecksNo, the unknown host key is accepted and recorded.
knownHosts, err := ioutil.ReadFile(s.knownHostsFile)
c.Assert(err, jc.ErrorIsNil)
c.Assert(string(knownHosts), gc.Equals, fmt.Sprintf(
"[127.0.0.1]:%d %s",
serverPort,
cryptossh.MarshalAuthorizedKey(serverKey)),
)
c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil)
}
// TestCopy checks that GoCryptoClient.Copy reports that scp is unsupported
// when OpenSSH's scp cannot be found.
func (s *SSHGoCryptoCommandSuite) TestCopy(c *gc.C) {
	client, err := ssh.NewGoCryptoClient()
	c.Assert(err, jc.ErrorIsNil)

	copyArgs := []string{"0.1.2.3:b", c.MkDir()}
	err = client.Copy(copyArgs, nil)
	c.Assert(err, gc.ErrorMatches, `scp command is not implemented \(OpenSSH scp not available in PATH\)`)
}
// TestProxyCommand checks that the ProxyCommand option is honoured. The
// connection goes through a wrapper script around the real netcat; the
// wrapper records the arguments it was invoked with in "<wrapper>.args".
func (s *SSHGoCryptoCommandSuite) TestProxyCommand(c *gc.C) {
	realNetcat, err := exec.LookPath("nc")
	if err != nil {
		// c.Skip takes a plain reason string and does not format it, so
		// build the message here to include the lookup error.
		c.Skip(fmt.Sprintf("skipping test, couldn't find netcat: %v", err))
		return
	}
	netcat := filepath.Join(c.MkDir(), "nc")
	err = ioutil.WriteFile(netcat, []byte("#!/bin/sh\necho $0 \"$@\" > $0.args && exec "+realNetcat+" \"$@\""), 0755)
	c.Assert(err, jc.ErrorIsNil)
	client, _ := newClient(c)
	server, _ := s.newServer(c, cryptossh.ServerConfig{})
	var opts ssh.Options
	port := server.listener.Addr().(*net.TCPAddr).Port
	opts.SetProxyCommand(netcat, "-q0", "%h", "%p")
	opts.SetPort(port)
	cmd := client.Command("127.0.0.1", testCommand, &opts)
	// Accept any client key; this test only cares about the proxy path.
	server.cfg.PublicKeyCallback = func(_ cryptossh.ConnMetadata, pubkey cryptossh.PublicKey) (*cryptossh.Permissions, error) {
		return nil, nil
	}
	errorCh := make(chan error, 1)
	done := make(chan bool)
	defer close(done)
	go server.run(errorCh, done)
	out, err := cmd.Output()
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(string(out), gc.Equals, "abc value\n")
	// Ensure the proxy command was executed with the appropriate arguments.
	data, err := ioutil.ReadFile(netcat + ".args")
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(string(data), gc.Equals, fmt.Sprintf("%s -q0 127.0.0.1 %v\n", netcat, port))
	c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil)
}
// TestStrictHostChecksYes checks that StrictHostChecksYes refuses a server
// whose host key is not in known_hosts, and that the known_hosts file is not
// created as a side effect.
func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksYes(c *gc.C) {
server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksYes)
client, _ := newClient(c)
cmd := client.Command("127.0.0.1", testCommand, &opts)
_, err := cmd.Output()
c.Assert(err, gc.ErrorMatches, fmt.Sprintf(
"ssh: handshake failed: no ssh-ed25519 host key is known for 127.0.0.1:%d and you have requested strict checking",
serverPort,
))
_, err = os.Stat(s.knownHostsFile)
c.Assert(err, jc.Satisfies, os.IsNotExist)
// The handshake is aborted by the client, so the server side presumably
// fails too; wait for it to finish but ignore its result.
_ = waitForServer(c, errorCh)
}
// TestStrictHostChecksAskNonTerminal checks that StrictHostChecksAsk fails
// without prompting when no terminal is available (SetUpTest installs the
// nil terminal patch). It also checks that nothing is written to known_hosts.
func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskNonTerminal(c *gc.C) {
server, _ := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksAsk)
client, _ := newClient(c)
cmd := client.Command("127.0.0.1", testCommand, &opts)
_, err := cmd.Output()
c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: not running in a terminal, cannot prompt for verification")
_, err = os.Stat(s.knownHostsFile)
c.Assert(err, jc.Satisfies, os.IsNotExist)
// The server's result is deliberately ignored; the handshake was aborted.
_ = waitForServer(c, errorCh)
}
// TestStrictHostChecksAskTerminalYes checks the interactive confirmation path
// of StrictHostChecksAsk. The scripted terminal first answers with an empty
// line, which triggers the "Please type 'yes' or 'no'" re-prompt, and then
// answers "yes". The host key must then be accepted and written to
// known_hosts.
func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalYes(c *gc.C) {
var readLineWriter mockReadLineWriter
ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter)
readLineWriter.addLine("")
readLineWriter.addLine("yes")
server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksAsk)
client, _ := newClient(c)
cmd := client.Command("127.0.0.1", testCommand, &opts)
_, err := cmd.Output()
c.Assert(err, jc.ErrorIsNil)
knownHosts, err := ioutil.ReadFile(s.knownHostsFile)
c.Assert(err, jc.ErrorIsNil)
c.Assert(string(knownHosts), gc.Equals, fmt.Sprintf(
"[127.0.0.1]:%d %s",
serverPort,
cryptossh.MarshalAuthorizedKey(serverKey),
))
// The prompt text written to the terminal, including the re-prompt.
c.Assert(readLineWriter.written.String(), gc.Equals, fmt.Sprintf(`
The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be established.
ssh-ed25519 key fingerprint is %[2]s.
Are you sure you want to continue connecting (yes/no)? Please type 'yes' or 'no': `[1:],
serverPort, cryptossh.FingerprintSHA256(serverKey)))
c.Assert(waitForServer(c, errorCh), jc.ErrorIsNil)
}
// TestStrictHostChecksAskTerminalNo checks that answering "no" at the
// StrictHostChecksAsk prompt aborts the connection and leaves known_hosts
// uncreated.
func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksAskTerminalNo(c *gc.C) {
var readLineWriter mockReadLineWriter
ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter)
readLineWriter.addLine("no")
server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksAsk)
client, _ := newClient(c)
cmd := client.Command("127.0.0.1", testCommand, &opts)
_, err := cmd.Output()
c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: Host key verification failed.")
_, err = os.Stat(s.knownHostsFile)
c.Assert(err, jc.Satisfies, os.IsNotExist)
// Only the initial prompt is written; "no" is accepted without a re-prompt.
c.Assert(readLineWriter.written.String(), gc.Equals, fmt.Sprintf(`
The authenticity of host '127.0.0.1:%[1]d (127.0.0.1:%[1]d)' can't be established.
ssh-ed25519 key fingerprint is %[2]s.
Are you sure you want to continue connecting (yes/no)? `[1:],
serverPort, cryptossh.FingerprintSHA256(serverKey)))
// The server's result is deliberately ignored; the handshake was aborted.
_ = waitForServer(c, errorCh)
}
// TestStrictHostChecksNoMismatch checks that a known_hosts entry with a
// different key of the same type (ed25519) is still enforced under
// StrictHostChecksNo. The connection must fail with a key-mismatch error
// after the "REMOTE HOST IDENTIFICATION HAS CHANGED" warning is written.
func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksNoMismatch(c *gc.C) {
var readLineWriter mockReadLineWriter
ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter)
server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
// Write a mismatching key to the known_hosts file. Even with
// StrictHostChecksNo, we should be verifying against an existing
// host key.
_, alternativeKey, err := generateED25519Key(rand.Reader)
c.Assert(err, jc.ErrorIsNil)
alternativePublicKey, err := cryptossh.NewPublicKey(alternativeKey.Public())
c.Assert(err, jc.ErrorIsNil)
err = ioutil.WriteFile(s.knownHostsFile, []byte(fmt.Sprintf(
"[127.0.0.1]:%d %s",
serverPort,
cryptossh.MarshalAuthorizedKey(alternativePublicKey),
)), 0600)
c.Assert(err, jc.ErrorIsNil)
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
client, _ := newClient(c)
cmd := client.Command("127.0.0.1", testCommand, &opts)
_, err = cmd.Output()
c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: knownhosts: key mismatch")
// The warning is matched as a regexp: the known_hosts path varies per test,
// and the fingerprint is quoted with regexp.QuoteMeta.
c.Assert(readLineWriter.written.String(), gc.Matches, fmt.Sprintf(`
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
Someone could be eavesdropping on you right now \(man-in-the-middle attack\)!
It is also possible that a host key has just been changed.
The fingerprint for the ssh-ed25519 key sent by the remote host is
%s.
Please contact your system administrator.
Add correct host key in .*/known_hosts to get rid of this message.
Offending ssh-ed25519 key in .*/known_hosts:1
`[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey))))
// The server's result is deliberately ignored; the handshake was aborted.
_ = waitForServer(c, errorCh)
}
// TestStrictHostChecksDifferentKeyTypes checks the case where known_hosts
// holds only a key of a different type (ssh-dss) for the host. The connection
// must still fail under StrictHostChecksNo. The warning must list the
// previously used key algorithm instead of an "Offending ... key" line.
func (s *SSHGoCryptoCommandSuite) TestStrictHostChecksDifferentKeyTypes(c *gc.C) {
var readLineWriter mockReadLineWriter
ssh.PatchTerminal(&s.CleanupSuite, &readLineWriter)
server, serverKey := s.newServer(c, cryptossh.ServerConfig{NoClientAuth: true})
serverPort := server.listener.Addr().(*net.TCPAddr).Port
errorCh := make(chan error, 1)
done := make(chan bool)
defer close(done)
go server.run(errorCh, done)
// Write a mismatching key to the known_hosts file with a different
// key type. Even with StrictHostChecksNo, we should be verifying
// against an existing host key.
dsaKey, err := generateDSAKey(rand.Reader)
c.Assert(err, jc.ErrorIsNil)
alternativePublicKey, err := cryptossh.NewPublicKey(&dsaKey.PublicKey)
c.Assert(err, jc.ErrorIsNil)
err = ioutil.WriteFile(s.knownHostsFile, []byte(fmt.Sprintf(
"[127.0.0.1]:%d %s",
serverPort,
cryptossh.MarshalAuthorizedKey(alternativePublicKey),
)), 0600)
c.Assert(err, jc.ErrorIsNil)
var opts ssh.Options
opts.SetPort(serverPort)
opts.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
client, _ := newClient(c)
cmd := client.Command("127.0.0.1", testCommand, &opts)
_, err = cmd.Output()
c.Assert(err, gc.ErrorMatches, "ssh: handshake failed: knownhosts: key mismatch")
// The warning is matched as a regexp: the known_hosts path varies per test,
// and the fingerprint is quoted with regexp.QuoteMeta.
c.Assert(readLineWriter.written.String(), gc.Matches, fmt.Sprintf(`
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
Someone could be eavesdropping on you right now \(man-in-the-middle attack\)!
It is also possible that a host key has just been changed.
The fingerprint for the ssh-ed25519 key sent by the remote host is
%s.
Please contact your system administrator.
Add correct host key in .*/known_hosts to get rid of this message.
Host was previously using different host key algorithms:
- ssh-dss key in .*/known_hosts:1
`[1:], regexp.QuoteMeta(cryptossh.FingerprintSHA256(serverKey))))
// The server's result is deliberately ignored; the handshake was aborted.
_ = waitForServer(c, errorCh)
}
// mockReadLineWriter is a scripted terminal stand-in. ReadLine returns the
// queued lines in order, then io.EOF once they are exhausted. Write captures
// everything written into the written buffer. Every call is recorded on the
// embedded testing.Stub.
type mockReadLineWriter struct {
	testing.Stub
	lines   []string
	written bytes.Buffer
}

// addLine queues line to be returned by a later ReadLine call.
func (m *mockReadLineWriter) addLine(line string) {
	m.lines = append(m.lines, line)
}

// ReadLine pops the next queued line, or returns io.EOF if none remain.
func (m *mockReadLineWriter) ReadLine() (string, error) {
	m.MethodCall(m, "ReadLine")
	if len(m.lines) > 0 {
		next := m.lines[0]
		m.lines = m.lines[1:]
		return next, nil
	}
	return "", io.EOF
}

// Write records data in the written buffer.
func (m *mockReadLineWriter) Write(data []byte) (int, error) {
	m.MethodCall(m, "Write", data)
	return m.written.Write(data)
}
================================================
FILE: ssh/ssh_openssh.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"strings"
"github.com/juju/errors"
"github.com/juju/utils/v4"
)
// defaultIdentities lists the identity files that ssh would otherwise try on
// its own. ssh stops trying them implicitly once any -i option is given.
// For that reason opensshOptions adds every one of them that exists on disk
// whenever other identities are specified.
var defaultIdentities = []string{
	"~/.ssh/identity",
	"~/.ssh/id_rsa",
	"~/.ssh/id_dsa",
	"~/.ssh/id_ecdsa",
	"~/.ssh/id_ed25519",
}
// opensshCommandKind says which OpenSSH executable a set of flags is being
// built for, because ssh and scp spell some flags differently (e.g. the port).
type opensshCommandKind int

const (
	// sshKind builds flags for the ssh executable.
	sshKind opensshCommandKind = 0
	// scpKind builds flags for the scp executable.
	scpKind opensshCommandKind = 1
)
// sshpassWrap prefixes cmd and args with "sshpass -e" when the SSHPASS
// environment variable is non-empty and an sshpass executable is on $PATH.
// In that case it returns sshpass's resolved path and the new argument list
// as a freshly allocated slice. Otherwise cmd and args are returned unchanged.
func sshpassWrap(cmd string, args []string) (string, []string) {
	if os.Getenv("SSHPASS") == "" {
		return cmd, args
	}
	sshpass, err := exec.LookPath("sshpass")
	if err != nil {
		return cmd, args
	}
	wrapped := make([]string, 0, len(args)+2)
	wrapped = append(wrapped, "-e", cmd)
	wrapped = append(wrapped, args...)
	return sshpass, wrapped
}
// OpenSSHClient is an implementation of Client that runs the ssh and scp
// executables found in $PATH.
type OpenSSHClient struct{}

// NewOpenSSHClient returns a new OpenSSHClient. It fails with the lookup
// error of the first program (ssh, then scp) that is not found in $PATH.
func NewOpenSSHClient() (*OpenSSHClient, error) {
	for _, program := range []string{"ssh", "scp"} {
		if _, err := exec.LookPath(program); err != nil {
			return nil, err
		}
	}
	return &OpenSSHClient{}, nil
}
// opensshOptions translates options into command-line flags for the OpenSSH
// executable selected by commandKind. A nil options is treated as a zero
// Options.
//
// Flags are emitted in this fixed order:
//   - StrictHostKeyChecking (only for yes/no/ask)
//   - ProxyCommand
//   - "PasswordAuthentication no" (unless password auth was allowed)
//   - "ServerAliveInterval 30"
//   - "-t -t" (PTY)
//   - UserKnownHostsFile
//   - HostKeyAlgorithms
//   - the -i identities
//   - the port (-P for scp, -p for ssh)
//
// The tests compare the exact argument sequence, so the order matters.
func opensshOptions(options *Options, commandKind opensshCommandKind) []string {
if options == nil {
options = &Options{}
}
var args []string
var hostChecks string
switch options.strictHostKeyChecking {
case StrictHostChecksYes:
hostChecks = "yes"
case StrictHostChecksNo:
hostChecks = "no"
case StrictHostChecksAsk:
hostChecks = "ask"
default:
// StrictHostChecksUnset and invalid values are handled the
// same way (the option doesn't get included).
}
if hostChecks != "" {
args = append(args, "-o", "StrictHostKeyChecking "+hostChecks)
}
if len(options.proxyCommand) > 0 {
// The proxy command is re-joined into a single quoted option value.
args = append(args, "-o", "ProxyCommand "+utils.CommandString(options.proxyCommand...))
}
if !options.passwordAuthAllowed {
args = append(args, "-o", "PasswordAuthentication no")
}
// We must set ServerAliveInterval or the server may
// think we've become unresponsive on long running
// command executions such as "apt-get upgrade".
args = append(args, "-o", "ServerAliveInterval 30")
if options.allocatePTY {
args = append(args, "-t", "-t") // twice to force
}
if options.knownHostsFile != "" {
args = append(args, "-o", "UserKnownHostsFile "+utils.CommandString(options.knownHostsFile))
}
if len(options.hostKeyAlgorithms) > 0 {
args = append(args, "-o", "HostKeyAlgorithms "+utils.CommandString(strings.Join(options.hostKeyAlgorithms, ",")))
}
// Copy the caller's identities so appending below never aliases
// options.identities' backing array.
identities := append([]string{}, options.identities...)
if pk := PrivateKeyFiles(); len(pk) > 0 {
// Add client keys as implicit identities
identities = append(identities, pk...)
}
// If any identities are specified, the
// default ones must be explicitly specified.
if len(identities) > 0 {
for _, identity := range defaultIdentities {
path, err := utils.NormalizePath(identity)
if err != nil {
logger.Warningf("failed to normalize path %q: %v", identity, err)
continue
}
// Only default identities that exist on disk are passed.
if _, err := os.Stat(path); err == nil {
identities = append(identities, path)
}
}
}
for _, identity := range identities {
args = append(args, "-i", identity)
}
if options.port != 0 {
port := fmt.Sprint(options.port)
if commandKind == scpKind {
// scp uses -P instead of -p (-p means preserve).
args = append(args, "-P", port)
} else {
args = append(args, "-p", port)
}
}
return args
}
// Command implements Client.Command. It builds an ssh invocation of the form
// "ssh <flags> <host> <command...>", wrapped with sshpass when sshpassWrap
// decides to. The returned Cmd has not been started.
func (c *OpenSSHClient) Command(host string, command []string, options *Options) *Cmd {
	sshArgs := append(opensshOptions(options, sshKind), host)
	sshArgs = append(sshArgs, command...)

	program, finalArgs := sshpassWrap("ssh", sshArgs)
	logger.Tracef("running: %s %s", program, utils.CommandString(finalArgs...))

	impl := &opensshCmd{exec.Command(program, finalArgs...)}
	return &Cmd{impl: impl}
}
// Copy implements Client.Copy. It runs scp with flags derived from
// userOptions followed by args; PTY allocation is never requested for scp.
// If scp fails, any (trimmed) text it wrote to stderr is appended to the
// returned error in parentheses.
func (c *OpenSSHClient) Copy(args []string, userOptions *Options) error {
	var options Options
	if userOptions != nil {
		options = *userOptions
		options.allocatePTY = false // doesn't make sense for scp
	}
	scpArgs := opensshOptions(&options, scpKind)
	scpArgs = append(scpArgs, args...)
	bin, scpArgs := sshpassWrap("scp", scpArgs)
	cmd := exec.Command(bin, scpArgs...)
	var stderrBuf bytes.Buffer
	cmd.Stderr = &stderrBuf
	// Log the full argument list actually executed, as Command does, rather
	// than only the caller-supplied arguments.
	logger.Tracef("running: %s %s", bin, utils.CommandString(scpArgs...))
	if err := cmd.Run(); err != nil {
		if msg := strings.TrimSpace(stderrBuf.String()); len(msg) > 0 {
			return errors.Errorf("%v (%v)", err, msg)
		}
		return err
	}
	return nil
}
// opensshCmd adapts an *exec.Cmd that runs an OpenSSH binary to the command
// implementation used by Cmd.
type opensshCmd struct {
	*exec.Cmd
}

// SetStdio assigns the process's standard input, output and error streams.
func (c *opensshCmd) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
	c.Stdin = stdin
	c.Stdout = stdout
	c.Stderr = stderr
}

// StdinPipe returns a pipe connected to the process's stdin. It also returns
// the embedded Cmd's Stdin field as it stands after the pipe is created.
func (c *opensshCmd) StdinPipe() (io.WriteCloser, io.Reader, error) {
	pipe, err := c.Cmd.StdinPipe()
	if err == nil {
		return pipe, c.Stdin, nil
	}
	return nil, nil, err
}

// StdoutPipe returns a pipe connected to the process's stdout. It also
// returns the embedded Cmd's Stdout field as it stands after the pipe is
// created.
func (c *opensshCmd) StdoutPipe() (io.ReadCloser, io.Writer, error) {
	pipe, err := c.Cmd.StdoutPipe()
	if err == nil {
		return pipe, c.Stdout, nil
	}
	return nil, nil, err
}

// StderrPipe returns a pipe connected to the process's stderr. It also
// returns the embedded Cmd's Stderr field as it stands after the pipe is
// created.
func (c *opensshCmd) StderrPipe() (io.ReadCloser, io.Writer, error) {
	pipe, err := c.Cmd.StderrPipe()
	if err == nil {
		return pipe, c.Stderr, nil
	}
	return nil, nil, err
}

// Kill kills the running process. It returns an error if the process has not
// been started.
func (c *opensshCmd) Kill() error {
	if c.Process != nil {
		return c.Process.Kill()
	}
	return errors.Errorf("process has not been started")
}
================================================
FILE: ssh/ssh_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package ssh_test
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
"github.com/juju/utils/v4/ssh"
)
// echoCommand is the echo binary used by the fake ssh/scp scripts.
const echoCommand = "/bin/echo"

// echoScript is the body of the fake ssh/scp scripts. It prints the script's
// own path and arguments, and also saves them to "<script>.args".
const echoScript = "#!/bin/sh\n" + echoCommand + " $0 \"$@\" | /usr/bin/tee $0.args"
// SSHCommandSuite tests OpenSSHClient by putting fake ssh/scp scripts first
// on $PATH. The scripts echo their own argument lists.
type SSHCommandSuite struct {
	testing.IsolationSuite

	originalPath string
	// testbin is the directory holding the fake executables.
	testbin string
	// fakessh and fakescp are the paths of the fake ssh and scp scripts.
	fakessh string
	fakescp string
	client  ssh.Client
}

var _ = gc.Suite(&SSHCommandSuite{})
// SetUpTest installs fake ssh and scp scripts at the front of $PATH, creates
// the OpenSSH client under test, and clears the default identities so they do
// not leak into expected command lines.
func (s *SSHCommandSuite) SetUpTest(c *gc.C) {
	s.IsolationSuite.SetUpTest(c)

	s.testbin = c.MkDir()
	s.fakessh = filepath.Join(s.testbin, "ssh")
	s.fakescp = filepath.Join(s.testbin, "scp")
	for _, script := range []string{s.fakessh, s.fakescp} {
		writeErr := ioutil.WriteFile(script, []byte(echoScript), 0755)
		c.Assert(writeErr, jc.ErrorIsNil)
	}
	s.PatchEnvPathPrepend(s.testbin)

	var err error
	s.client, err = ssh.NewOpenSSHClient()
	c.Assert(err, jc.ErrorIsNil)
	s.PatchValue(ssh.DefaultIdentities, nil)
}
// command builds an ssh command that runs args on localhost with default options.
func (s *SSHCommandSuite) command(args ...string) *ssh.Cmd {
	var noOptions *ssh.Options
	return s.commandOptions(args, noOptions)
}
// commandOptions builds an ssh command that runs args on localhost with opts.
func (s *SSHCommandSuite) commandOptions(args []string, opts *ssh.Options) *ssh.Cmd {
	const host = "localhost"
	return s.client.Command(host, args, opts)
}
// assertCommandArgs runs cmd and checks that its whitespace-trimmed output
// (the argv echoed by the fake script) equals expected.
func (s *SSHCommandSuite) assertCommandArgs(c *gc.C, cmd *ssh.Cmd, expected string) {
	output, err := cmd.Output()
	c.Assert(err, jc.ErrorIsNil)
	got := strings.TrimSpace(string(output))
	c.Assert(got, gc.Equals, expected)
}
// TestDefaultClient checks the choice of default client. With the fake ssh/scp
// on $PATH the OpenSSH client is chosen; with an empty $PATH the client falls
// back to GoCryptoClient.
func (s *SSHCommandSuite) TestDefaultClient(c *gc.C) {
	ssh.InitDefaultClient()
	c.Assert(ssh.DefaultClient, gc.FitsTypeOf, &ssh.OpenSSHClient{})

	s.PatchEnvironment("PATH", "")
	ssh.InitDefaultClient()
	c.Assert(ssh.DefaultClient, gc.FitsTypeOf, &ssh.GoCryptoClient{})
}
// TestCommandSSHPass checks that ssh is wrapped with "sshpass -e" only when
// both $SSHPASS is set and an sshpass executable is on $PATH.
func (s *SSHCommandSuite) TestCommandSSHPass(c *gc.C) {
	withoutSSHPass := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123",
		s.fakessh, echoCommand)

	// First create a fake sshpass, but don't set $SSHPASS.
	fakesshpass := filepath.Join(s.testbin, "sshpass")
	err := ioutil.WriteFile(fakesshpass, []byte(echoScript), 0755)
	c.Assert(err, jc.ErrorIsNil)
	s.assertCommandArgs(c, s.command(echoCommand, "123"), withoutSSHPass)

	// Now set $SSHPASS: the command is run via sshpass.
	s.PatchEnvironment("SSHPASS", "anyoldthing")
	s.assertCommandArgs(c, s.command(echoCommand, "123"),
		fmt.Sprintf("%s -e ssh -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123",
			fakesshpass, echoCommand),
	)

	// Finally, remove sshpass from $PATH: ssh is run directly again.
	err = os.Remove(fakesshpass)
	c.Assert(err, jc.ErrorIsNil)
	s.assertCommandArgs(c, s.command(echoCommand, "123"), withoutSSHPass)
}
// TestCommand checks the flags emitted for a command with default options.
func (s *SSHCommandSuite) TestCommand(c *gc.C) {
	expected := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123",
		s.fakessh, echoCommand)
	cmd := s.command(echoCommand, "123")
	s.assertCommandArgs(c, cmd, expected)
}
// TestCommandEnablePTY checks that EnablePTY adds "-t -t".
func (s *SSHCommandSuite) TestCommandEnablePTY(c *gc.C) {
	opts := &ssh.Options{}
	opts.EnablePTY()
	expected := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -t -t localhost %s 123",
		s.fakessh, echoCommand)
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, opts), expected)
}
// TestCommandSetKnownHostsFile checks that a known_hosts path containing a
// space is quoted in the UserKnownHostsFile option.
func (s *SSHCommandSuite) TestCommandSetKnownHostsFile(c *gc.C) {
	opts := &ssh.Options{}
	opts.SetKnownHostsFile("/tmp/known hosts")
	expected := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -o UserKnownHostsFile \"/tmp/known hosts\" localhost %s 123",
		s.fakessh, echoCommand)
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, opts), expected)
}
// TestSetStrictHostKeyChecking checks that yes/no/ask emit a
// StrictHostKeyChecking option, while the default and unknown values emit
// nothing.
func (s *SSHCommandSuite) TestSetStrictHostKeyChecking(c *gc.C) {
	// "%%s" survives this Sprintf as a "%s" placeholder for the optional flag.
	commandPattern := fmt.Sprintf("%s%%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123",
		s.fakessh, echoCommand)
	cases := []struct {
		option ssh.StrictHostChecksOption
		value  string
	}{
		{option: ssh.StrictHostChecksNo, value: "no"},
		{option: ssh.StrictHostChecksYes, value: "yes"},
		{option: ssh.StrictHostChecksAsk, value: "ask"},
		{option: ssh.StrictHostChecksDefault, value: ""},
		{option: ssh.StrictHostChecksOption(999), value: ""},
	}
	for _, tc := range cases {
		flag := ""
		if tc.value != "" {
			flag = " -o StrictHostKeyChecking " + tc.value
		}
		opts := &ssh.Options{}
		opts.SetStrictHostKeyChecking(tc.option)
		cmd := s.commandOptions([]string{echoCommand, "123"}, opts)
		s.assertCommandArgs(c, cmd, fmt.Sprintf(commandPattern, flag))
	}
}
// TestCommandAllowPasswordAuthentication checks that allowing password
// authentication drops the "PasswordAuthentication no" option.
func (s *SSHCommandSuite) TestCommandAllowPasswordAuthentication(c *gc.C) {
	opts := &ssh.Options{}
	opts.AllowPasswordAuthentication()
	expected := fmt.Sprintf("%s -o ServerAliveInterval 30 localhost %s 123",
		s.fakessh, echoCommand)
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, opts), expected)
}
// TestCommandIdentities checks that each identity becomes an -i flag, in order.
func (s *SSHCommandSuite) TestCommandIdentities(c *gc.C) {
	opts := &ssh.Options{}
	opts.SetIdentities("x", "y")
	expected := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -i x -i y localhost %s 123",
		s.fakessh, echoCommand)
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, opts), expected)
}
// TestCommandPort checks that ssh receives the port via lowercase -p.
func (s *SSHCommandSuite) TestCommandPort(c *gc.C) {
	opts := &ssh.Options{}
	opts.SetPort(2022)
	expected := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -p 2022 localhost %s 123",
		s.fakessh, echoCommand)
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, opts), expected)
}
// TestCopy checks the scp argument list built by Copy. EnablePTY has no
// effect, so no "-t" flags appear. The port is passed with uppercase -P.
// Extra scp arguments are passed through in the order given, whether trailing
// or interspersed.
func (s *SSHCommandSuite) TestCopy(c *gc.C) {
	opts := &ssh.Options{}
	opts.EnablePTY()
	opts.AllowPasswordAuthentication()
	opts.SetIdentities("x", "y")
	opts.SetPort(2022)

	cases := []struct {
		args     []string
		expected string
	}{{
		args:     []string{"/tmp/blah", "foo@bar.com:baz"},
		expected: s.fakescp + " -o ServerAliveInterval 30 -i x -i y -P 2022 /tmp/blah foo@bar.com:baz\n",
	}, {
		args:     []string{"/tmp/blah", "foo@bar.com:baz", "-r", "-v"},
		expected: s.fakescp + " -o ServerAliveInterval 30 -i x -i y -P 2022 /tmp/blah foo@bar.com:baz -r -v\n",
	}, {
		args:     []string{"-r", "/tmp/blah", "-v", "foo@bar.com:baz"},
		expected: s.fakescp + " -o ServerAliveInterval 30 -i x -i y -P 2022 -r /tmp/blah -v foo@bar.com:baz\n",
	}}
	for _, tc := range cases {
		err := s.client.Copy(tc.args, opts)
		c.Assert(err, jc.ErrorIsNil)
		recorded, err := ioutil.ReadFile(s.fakescp + ".args")
		c.Assert(err, jc.ErrorIsNil)
		c.Assert(string(recorded), gc.Equals, tc.expected)
	}
}
// TestCommandClientKeys checks that loaded client keys are appended as
// implicit -i identities after the explicitly requested ones.
func (s *SSHCommandSuite) TestCommandClientKeys(c *gc.C) {
	restorer := overrideGenerateKey()
	defer restorer.Restore()

	keysDir := c.MkDir()
	defer ssh.ClearClientKeys()
	err := ssh.LoadClientKeys(keysDir)
	c.Assert(err, jc.ErrorIsNil)
	clientKey := filepath.Join(keysDir, "juju_id_ed25519")

	opts := &ssh.Options{}
	opts.SetIdentities("x", "y")
	expected := fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -i x -i y -i %s localhost %s 123",
		s.fakessh, clientKey, echoCommand)
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, opts), expected)
}
// TestCommandError checks that a non-zero exit status from ssh surfaces as
// an rc-passthrough error.
func (s *SSHCommandSuite) TestCommandError(c *gc.C) {
	// Replace the fake ssh with a script that always exits with status 42.
	err := ioutil.WriteFile(s.fakessh, []byte("#!/bin/sh\nexit 42"), 0755)
	c.Assert(err, jc.ErrorIsNil)

	opts := &ssh.Options{}
	err = s.client.Command("ignored", []string{echoCommand, "foo"}, opts).Run()
	c.Assert(utils.IsRcPassthroughError(err), jc.IsTrue)
}
// TestCommandDefaultIdentities checks that default identities are passed
// only when other identities are requested, and then only if they exist on
// disk.
func (s *SSHCommandSuite) TestCommandDefaultIdentities(c *gc.C) {
	var opts ssh.Options
	tempdir := c.MkDir()
	def1 := filepath.Join(tempdir, "def1")
	def2 := filepath.Join(tempdir, "def2")
	s.PatchValue(ssh.DefaultIdentities, []string{def1, def2})
	// If no identities are specified, then the defaults aren't added.
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts),
		fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 localhost %s 123",
			s.fakessh, echoCommand),
	)
	// If identities are specified, then the defaults must be added explicitly.
	// Only the defaults that exist on disk are added (def2 here, not def1).
	err := ioutil.WriteFile(def2, nil, 0644)
	c.Assert(err, jc.ErrorIsNil)
	opts.SetIdentities("x", "y")
	s.assertCommandArgs(c, s.commandOptions([]string{echoCommand, "123"}, &opts),
		fmt.Sprintf("%s -o PasswordAuthentication no -o ServerAliveInterval 30 -i x -i y -i %s localhost %s 123",
			s.fakessh, def2, echoCommand),
	)
}
// TestCopyReader checks that copying from a reader issues a single
// "cat - > <path>" command whose stdin is the reader.
func (s *SSHCommandSuite) TestCopyReader(c *gc.C) {
	fake := &fakeClient{}
	source := bytes.NewBufferString("")

	err := ssh.TestCopyReader(fake, "foo@bar.com:baz", "/tmp/blah", source, nil)
	c.Assert(err, jc.ErrorIsNil)

	fake.checkCalls(c, "foo@bar.com:baz", []string{"cat - > /tmp/blah"}, nil, nil, "Command")
	fake.impl.checkCalls(c, source, nil, nil, "SetStdio", "Start", "Wait")
}
================================================
FILE: ssh/stream.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"bytes"
"io"
)
// stripCR implements an io.Reader wrapper that removes carriage return bytes.
type stripCR struct {
	reader io.Reader
}

// StripCRReader returns a new io.Reader wrapper that strips carriage returns
// ('\r') from everything read through it. A nil reader yields nil.
func StripCRReader(reader io.Reader) io.Reader {
	if reader == nil {
		return nil
	}
	return &stripCR{reader: reader}
}

// byteEmpty and byteCR are the byte-slice forms of "" and "\r".
var byteEmpty = []byte{}
var byteCR = []byte{'\r'}

// Read implements io.Reader. It reads directly into bufOut and compacts the
// chunk in place, dropping every '\r', so no scratch buffer is allocated per
// call.
//
// If a chunk consists solely of carriage returns, Read reads again rather
// than returning (0, nil), which the io.Reader contract discourages.
func (s *stripCR) Read(bufOut []byte) (int, error) {
	for {
		n, err := s.reader.Read(bufOut)
		kept := 0
		chunk := bufOut[:n]
		for len(chunk) > 0 {
			i := bytes.Index(chunk, byteCR)
			if i < 0 {
				kept += copy(bufOut[kept:], chunk)
				break
			}
			// copy handles the overlap: the destination never lies
			// after the source within bufOut.
			kept += copy(bufOut[kept:], chunk[:i])
			chunk = chunk[i+1:]
		}
		if kept > 0 || n == 0 || err != nil {
			return kept, err
		}
	}
}
================================================
FILE: ssh/stream_test.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh_test
import (
"io/ioutil"
"strings"
"testing/iotest"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/ssh"
)
// SSHStreamSuite holds the tests for the carriage-return stripping reader.
type SSHStreamSuite struct {
	testing.IsolationSuite
}

// Register the suite with gocheck.
var _ = gc.Suite(new(SSHStreamSuite))
func (s *SSHStreamSuite) TestNewStripCRNil(c *gc.C) {
	// A nil source must produce a nil reader, not a wrapper around nil.
	c.Assert(ssh.StripCRReader(nil), gc.IsNil)
}
func (s *SSHStreamSuite) TestStripCR(c *gc.C) {
	// A CRLF line ending collapses to a bare LF.
	src := strings.NewReader("One\r\nTwo")
	output, err := ioutil.ReadAll(ssh.StripCRReader(src))
	c.Assert(err, jc.ErrorIsNil)
	c.Check(string(output), gc.Equals, "One\nTwo")
}
func (s *SSHStreamSuite) TestStripCROneByte(c *gc.C) {
	// Reading the stripped stream one byte at a time still removes every CR.
	stripped := ssh.StripCRReader(strings.NewReader("One\r\r\rTwo"))
	output, err := ioutil.ReadAll(iotest.OneByteReader(stripped))
	c.Assert(err, jc.ErrorIsNil)
	c.Check(string(output), gc.Equals, "OneTwo")
}
// TestStripCRError checks that an error raised while reading through the
// stripping reader (here iotest's timeout) reaches the caller.
func (s *SSHStreamSuite) TestStripCRError(c *gc.C) {
	reader := ssh.StripCRReader(strings.NewReader("One\r\r\rTwo"))
	_, err := ioutil.ReadAll(iotest.TimeoutReader(reader))
	// ErrorMatches fails cleanly on a nil error. Calling err.Error() directly
	// would panic and hide the real failure.
	c.Assert(err, gc.ErrorMatches, "timeout")
}
================================================
FILE: ssh/stream_wrapper_unix.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package ssh
import (
"io"
)
// WrapStdin is the non-Windows variant: stdin needs no carriage-return
// handling here, so reader is handed back exactly as received.
func WrapStdin(reader io.Reader) io.Reader {
	wrapped := reader
	return wrapped
}
================================================
FILE: ssh/stream_wrapper_windows.go
================================================
// Copyright 2016 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package ssh
import (
"io"
)
// WrapStdin is the Windows variant: it routes stdin through StripCRReader so
// carriage returns are dropped from whatever is read.
func WrapStdin(reader io.Reader) io.Reader {
	stripped := StripCRReader(reader)
	return stripped
}
================================================
FILE: ssh/testing/keys.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package testing
// SSHKey pairs a public key line (e.g. "ssh-rsa AAAA...") with its expected
// fingerprint, for use as a test fixture.
//
// The fixtures in this package build SSHKey with positional (unkeyed)
// literals, so the field order below must not change.
type SSHKey struct {
// Key is the full public key line.
Key string
// Fingerprint is the colon-separated hex fingerprint expected for Key.
Fingerprint string
}
// Test fixtures of SSH public keys.
//
// ValidKeyOne..ValidKeyFour are single ssh-rsa keys paired with a 16-byte,
// colon-separated hex fingerprint (presumably MD5; confirm against the
// fingerprint code under test). The multi-key strings join keys with "\n".
var (
ValidKeyOne = SSHKey{
`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDEX/dPu4PmtvgK3La9zioCEDrJ` +
`yUr6xEIK7Pr+rLgydcqWTU/kt7w7gKjOw4vvzgHfjKl09CWyvgb+y5dCiTk` +
`9MxI+erGNhs3pwaoS+EavAbawB7iEqYyTep3YaJK+4RJ4OX7ZlXMAIMrTL+` +
`UVrK89t56hCkFYaAgo3VY+z6rb/b3bDBYtE1Y2tS7C3au73aDgeb9psIrSV` +
`86ucKBTl5X62FnYiyGd++xCnLB6uLximM5OKXfLzJQNS/QyZyk12g3D8y69` +
`Xw1GzCSKX1u1+MQboyf0HJcG2ryUCLHdcDVppApyHx2OLq53hlkQ/yxdflD` +
`qCqAE4j+doagSsIfC1T2T`,
"86:ed:1b:cd:26:a0:a3:4c:27:35:49:60:95:b7:0f:68",
}
ValidKeyTwo = SSHKey{
`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDNC6zK8UMazlVgp8en8N7m7H/Y6` +
`DoMWbmPFjXYRXu6iQJJ18hCtsfMe63E5/PBaOjDT8am0Sx3Eqn4ZzpWMj+z` +
`knTcSd8xnMHYYxH2HStRWC1akTe4tTno2u2mqzjKd8f62URPtIocYCNRBls` +
`9yjnq9SogI5EXgcx6taQcrIFcIK0SlthxxcMVSlLpnbReujW65JHtiMqoYA` +
`OIALyO+Rkmtvb/ObmViDnwCKCN1up/xWt6J10MrAUtpI5b4prqG7FOqVMM/` +
`zdgrVg6rUghnzdYeQ8QMyEv4mVSLzX0XIPcxorkl9q06s5mZmAzysEbKZCO` +
`aXcLeNlXx/nkmuWslYCJ`,
"2f:fb:b0:65:68:c8:4e:a6:1b:a6:4b:8d:14:0b:40:79",
}
ValidKeyThree = SSHKey{
`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCpGj1JMjGjAFt5wjARbIORyjQ/c` +
`ZAiDyDHe/w8qmLKUG2KTs6586QqqM6DKPZiYesrzXqvZsWYV4B6OjLM1sxq` +
`WjeDIl56PSnJ0+KP8pUV9KTkkKtRXxAoNg/II4l69e05qGffj9AcQ/7JPxx` +
`eL14Ulvh/a69r3uVkw1UGVk9Bwm4eCOSCqKalYLA1k5da6crEAXn9hiXLGs` +
`S9dOn3Lsqj5tK31aaUncue+a3iKb7R5LRFflDizzNS+h8tPuANQflOjOhR0` +
`Vas0BsurgISseZZ0NIMISyWhZpr0eOBWA/YruN9r++kYPOnDy0eMaOVGLO7` +
`SQwJ/6QHvf73yksJTncz`,
"1d:cf:ab:66:8a:f6:77:fb:4c:b2:59:6f:12:cf:cb:2f",
}
ValidKeyFour = SSHKey{
`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCSEDMH5RyjGtEMIqM2RiPYYQgUK` +
`9wdHCo1/AXkuQ7m1iVjHhACp8Oawf2Grn7hO4e0JUn5FaEZOnDj/9HB2VPw` +
`EDGBwSN1caVC3yrTVkqQcsxBY9nTV+spQQMsePOdUZALcoEilvAcLRETbyn` +
`rybaS2bfzpqbA9MEEaKQKLKGdgqiMdNXAj5I/ik/BPp0ziOMlMl1A1zilnS` +
`UXubs1U49WWV0A70vAASvZVTXr3zrPAmstH+9Ik6FdpeE99um08FXxKYWqZ` +
`6rZF1M6L1/SqC7ediYdVgRCoti85kKhi7fZBzwrGcCnxer+D0GFz++KDSNS` +
`iAnVZxyXhmBrwnR6Q/v7`,
"37:99:ab:96:c4:e8:f8:0b:0d:04:3e:1e:ee:66:e8:9e",
}
// ValidKeyMulti holds two well-formed ssh-rsa keys separated by a newline.
ValidKeyMulti = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDW+8zWO6qqXrHlcMK7obliuYp7D` +
`vZBsK6rHlnbeV5Hh38Qn0GUX4Ahm6XeQ/NSx53wqkBQDGOJFY3s4w1a/hbd` +
`PyLM2/yFXCYsj5FRf01JmUjAzWhuJMH9ViqzD//l4v8cR/pHC2B8PD6abKd` +
`mIH+yLI9Cl3C4ICMKteG54egsUyboBOVKCDIKmWRLAak6sE5DPpqKF53NvD` +
`cuDufWtaCfVAOrq6NW8wSQ7PAvfDh8gsG5uvZjY3gcWl9yI3EJVGFHcdxcv` +
`4LtQI8mKdeg3JoufnEmeBJTZMoo83Gru5Z7tjv8J4JTUeQpd9uCCED1JAMe` +
`cJSKgQ2gZMTbTshobpHr` + "\n" +
`ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDSgfrzyGpE5eLiXusvLcxEmoE6e` +
`SMUDvTW1dd2BZgfvUVwq+toQdZ6C0C1JmbC3X563n8fmKVUAQGo5JavzABG` +
`Kpy90L3cwoGCFtb+A28YsT+bfuP+LdnCbFXm9c3DPJQx6Dch8prnDtzRjRV` +
`CorbPvm35NY73liUXVF6g58Owlx5rWtb8OnoTh5KQps9JTSfyNckdV9bFxP` +
`7bZvMyRYW5X33KaA+CQGpTNAKDHruSuKdAdaS6rBIZRvzzzSCF28BWwFL7Z` +
`ghQo0ADlUMnqIeQ58nwRImZHpmvadsZi47aMKFeykk4JQUQlwjbM0xGi0uj` +
`+hlaqGYbNo0Evcjn23cj`
// PartValidKeyMulti holds one well-formed key followed by a malformed one.
PartValidKeyMulti = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDZRvG2miYVkbWOr2I+9xHWXqALb` +
`eBcyxAlYtbjxBRwrq8oFOw9vtIIZSO0r1FM6+JHzKhLSiPCMR/PK78ZqPgZ` +
`fia8Y7cEZKaUWLtZUAl0RF9w8EtsA/2gpuLZErjcoIx6fzfEYFCJcLgcQSc` +
`RlKG8VZT6tWIjvoLj9ki6unkG5YGmapkT60afhf3/vd7pCJO/uyszkQ9qU8` +
`odUDTTlwftpJtUb8xGmzpEZJTgk1lbZKlZm5pVXwjNEodH7Je88RBzR7PBB` +
`Jct+vf8wVJ/UEFXCnamvHLanJTcJIi/I5qRlKns65Bwb8M0HszPYmvTfFRD` +
`ZLi3sPUmw6PJCJ0SgATd` + "\n" +
`ssh-rsa bad key`
// MultiInvalid holds only malformed key lines.
MultiInvalid = `ssh-rsa bad key` + "\n" +
`ssh-rsa also bad`
// EmptyKeyMulti is an empty key list.
EmptyKeyMulti = ""
)
================================================
FILE: symlink/export_test.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package symlink
// GetLongPathAsString exports the unexported getLongPathAsString so the
// external symlink_test package can call it.
var GetLongPathAsString = getLongPathAsString
================================================
FILE: symlink/symlink.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package symlink
import (
"fmt"
"os"
"path/filepath"
"github.com/juju/utils/v4"
)
// Replace points link at newpath.
//
// A new symlink to newpath is first created under a random temporary name in
// link's directory, so a failure there leaves the old link untouched. It is
// then renamed over link. On POSIX systems that rename replaces an existing
// link atomically. If it fails (on Windows symlinks may not be overwritten,
// and a non-directory cannot be renamed over a directory), whatever exists at
// link is removed and the rename is retried. The temporary link is removed on
// any failure after it was created.
func Replace(link, newpath string) error {
	dstDir := filepath.Dir(link)
	id, err := utils.NewUUID()
	if err != nil {
		return err
	}
	tmpFile := filepath.Join(dstDir, "tmpfile"+id.String())
	// Create the new symlink before removing the old one. This way, if New()
	// fails, we still have a link to the old tools.
	if err := New(newpath, tmpFile); err != nil {
		return fmt.Errorf("cannot create symlink: %w", err)
	}
	// Fast, atomic path: rename straight over the existing link.
	if err := os.Rename(tmpFile, link); err == nil {
		return nil
	}
	// Fallback: remove the existing entry first, then rename. Lstat (not
	// Stat) is used so a dangling symlink at link is still detected.
	if _, err := os.Lstat(link); err == nil {
		if err := os.RemoveAll(link); err != nil {
			_ = os.Remove(tmpFile) // best-effort cleanup; the removal error is what matters
			return err
		}
	}
	if err := os.Rename(tmpFile, link); err != nil {
		_ = os.Remove(tmpFile) // best-effort cleanup; the rename error is what matters
		return fmt.Errorf("cannot update tools symlink: %w", err)
	}
	return nil
}
================================================
FILE: symlink/symlink_posix.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build linux || darwin
// +build linux darwin
package symlink
import (
"os"
"github.com/juju/errors"
)
// New creates newname as a symbolic link pointing at oldname. On Linux and
// Darwin this simply delegates to os.Symlink and returns its error.
func New(oldname, newname string) error {
	if err := os.Symlink(oldname, newname); err != nil {
		return err
	}
	return nil
}
// Read returns the destination stored in the symbolic link at link. On Linux
// and Darwin this delegates to os.Readlink.
func Read(link string) (string, error) {
	target, err := os.Readlink(link)
	return target, err
}
// IsSymlink reports whether path itself (not what it points to) is a
// symbolic link. It uses os.Lstat, and any Lstat error, such as a missing
// path, is returned traced together with false.
func IsSymlink(path string) (bool, error) {
	info, err := os.Lstat(path)
	if err != nil {
		return false, errors.Trace(err)
	}
	isLink := info.Mode()&os.ModeSymlink != 0
	return isLink, nil
}
// getLongPathAsString exists for compatibility with the Windows
// implementation. On Linux and Darwin it returns path unchanged and never
// fails.
func getLongPathAsString(path string) (string, error) {
	longPath := path
	return longPath, nil
}
================================================
FILE: symlink/symlink_test.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package symlink_test
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
"github.com/juju/utils/v4/symlink"
)
// SymlinkSuite groups the gocheck tests for the symlink package.
type SymlinkSuite struct{}

var _ = gc.Suite(new(SymlinkSuite))

// Test hooks the gocheck suites into "go test".
func Test(t *testing.T) { gc.TestingT(t) }
// TestReplace creates a link to one directory, replaces it with a link to a
// second directory, and checks symlink.Read reports each target in turn.
func (*SymlinkSuite) TestReplace(c *gc.C) {
	target, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	targetSecond, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	link := filepath.Join(target, "link")
	_, err = os.Stat(target)
	c.Assert(err, gc.IsNil)
	_, err = os.Stat(targetSecond)
	c.Assert(err, gc.IsNil)
	err = symlink.New(target, link)
	c.Assert(err, gc.IsNil)
	linkTarget, err := symlink.Read(link)
	c.Assert(err, gc.IsNil)
	c.Assert(linkTarget, gc.Equals, filepath.FromSlash(target))
	err = symlink.Replace(link, targetSecond)
	c.Assert(err, gc.IsNil)
	linkTarget, err = symlink.Read(link)
	c.Assert(err, gc.IsNil)
	c.Assert(linkTarget, gc.Equals, filepath.FromSlash(targetSecond))
}
func (*SymlinkSuite) TestIsSymlinkFile(c *gc.C) {
	// A link that points at a regular file is reported as a symlink.
	base, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	file := filepath.Join(base, "file")
	c.Assert(ioutil.WriteFile(file, []byte("TOP SECRET"), 0644), gc.IsNil)
	_, err = os.Stat(file)
	c.Assert(err, gc.IsNil)
	link := filepath.Join(base, "link")
	c.Assert(symlink.New(file, link), gc.IsNil)
	isLink, err := symlink.IsSymlink(link)
	c.Assert(err, gc.IsNil)
	c.Assert(isLink, jc.IsTrue)
}
func (*SymlinkSuite) TestIsSymlinkFolder(c *gc.C) {
	// A link that points at a directory is reported as a symlink.
	dir, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	_, err = os.Stat(dir)
	c.Assert(err, gc.IsNil)
	link := filepath.Join(dir, "link")
	c.Assert(symlink.New(dir, link), gc.IsNil)
	isLink, err := symlink.IsSymlink(link)
	c.Assert(err, gc.IsNil)
	c.Assert(isLink, jc.IsTrue)
}
func (*SymlinkSuite) TestIsSymlinkFalseFile(c *gc.C) {
	// A plain regular file is not a symlink.
	file := filepath.Join(c.MkDir(), "file")
	c.Assert(ioutil.WriteFile(file, []byte("TOP SECRET"), 0644), gc.IsNil)
	_, err := os.Stat(file)
	c.Assert(err, gc.IsNil)
	isLink, err := symlink.IsSymlink(file)
	c.Assert(err, gc.IsNil)
	c.Assert(isLink, jc.IsFalse)
}
func (*SymlinkSuite) TestIsSymlinkFalseFolder(c *gc.C) {
	// A plain directory is not a symlink.
	dir, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	_, err = os.Stat(dir)
	c.Assert(err, gc.IsNil)
	isLink, err := symlink.IsSymlink(dir)
	c.Assert(err, gc.IsNil)
	c.Assert(isLink, jc.IsFalse)
}
func (*SymlinkSuite) TestIsSymlinkFileDoesNotExist(c *gc.C) {
	// A missing path yields false plus the platform's "no such file" error.
	missing := filepath.Join(c.MkDir(), "file")
	isLink, err := symlink.IsSymlink(missing)
	c.Assert(err, gc.ErrorMatches, ".*"+utils.NoSuchFileErrRegexp)
	c.Assert(isLink, jc.IsFalse)
}
================================================
FILE: symlink/symlink_windows.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
// Author: Robert Tingirica
package symlink
import (
"os"
"strings"
"syscall"
"unicode/utf16"
"unsafe"
"github.com/juju/errors"
)
const (
// SYMBOLIC_LINK_FLAG_DIRECTORY is passed to CreateSymbolicLinkW by New when
// the link target is a directory.
SYMBOLIC_LINK_FLAG_DIRECTORY = 1
// GENERIC_EXECUTION is misnamed: 33554432 == 0x02000000, the value of the
// Win32 flag FILE_FLAG_BACKUP_SEMANTICS, not an access right. Read passes it
// as CreateFile's flags-and-attributes argument, where that flag is needed
// to open a handle to a directory. The original note said using
// syscall.GENERIC_EXECUTE in that position gave "Access denied".
GENERIC_EXECUTION = 33554432
// (TODO): bogdanteleaga or anybody else:
// Remove this once we upgrade to a go version that has it in the syscall
// package.
// NOTE(review): newer Go releases may already define
// syscall.FILE_ATTRIBUTE_REPARSE_POINT; confirm before removing.
FILE_ATTRIBUTE_REPARSE_POINT = 0x00000400
)
//sys createSymbolicLink(symlinkname *uint16, targetname *uint16, flags uint32) (err error) = CreateSymbolicLinkW
//sys getFinalPathNameByHandle(handle Handle, buf *uint16, buflen uint32, flags uint32) (n uint32, err error) = GetFinalPathNameByHandleW
// New creates newname as a symbolic link to oldname.
// If there is an error, it will be of type *os.LinkError.
//
// oldname must exist: it is stat'ed to decide whether a directory link is
// needed, and it is stored in its long (non-8.3) form.
func New(oldname, newname string) error {
	linkErr := func(err error) error {
		return &os.LinkError{Op: "symlink", Old: oldname, New: newname, Err: err}
	}
	fi, err := os.Stat(oldname)
	if err != nil {
		return linkErr(err)
	}
	var flag uint32
	if fi.IsDir() {
		flag = SYMBOLIC_LINK_FLAG_DIRECTORY
	}
	targetp, err := getLongPath(oldname)
	if err != nil {
		return linkErr(err)
	}
	linkp, err := syscall.UTF16PtrFromString(newname)
	if err != nil {
		return linkErr(err)
	}
	if err := createSymbolicLink(linkp, &targetp[0], flag); err != nil {
		return linkErr(err)
	}
	return nil
}
// Read returns the destination of the named symbolic link. The link is
// opened and Windows is asked for the final path of that handle.
// If there is an error, it will be of type *os.PathError.
//
// Extended-length prefixes in the result are removed: `\\?\C:\x` becomes
// `C:\x` and `\\?\UNC\server\share\x` becomes `\\server\share\x`.
func Read(link string) (string, error) {
	pathErr := func(err error) error {
		return &os.PathError{Op: "readlink", Path: link, Err: err}
	}
	linkp, err := getLongPath(link)
	if err != nil {
		// Wrapped like every other failure so the documented type holds.
		return "", pathErr(err)
	}
	h, err := syscall.CreateFile(
		&linkp[0],
		syscall.GENERIC_READ,
		syscall.FILE_SHARE_READ,
		nil,
		syscall.OPEN_EXISTING,
		// Flags-and-attributes: 33554432 == 0x02000000 is
		// FILE_FLAG_BACKUP_SEMANTICS, required to open directory handles.
		GENERIC_EXECUTION,
		0)
	if err != nil {
		return "", pathErr(err)
	}
	defer syscall.CloseHandle(h)

	pathw := make([]uint16, syscall.MAX_PATH)
	n, err := getFinalPathNameByHandle(h, &pathw[0], uint32(len(pathw)), 0)
	if err != nil {
		return "", pathErr(err)
	}
	if n > uint32(len(pathw)) {
		// Buffer too small: n is the required size, so retry once with it.
		pathw = make([]uint16, n)
		n, err = getFinalPathNameByHandle(h, &pathw[0], uint32(len(pathw)), 0)
		if err != nil {
			return "", pathErr(err)
		}
		if n > uint32(len(pathw)) {
			return "", pathErr(errors.New("link length too long"))
		}
	}
	ret := string(utf16.Decode(pathw[0:n]))

	const (
		uncPrefix  = `\\?\UNC\`
		longPrefix = `\\?\`
	)
	switch {
	case strings.HasPrefix(ret, uncPrefix):
		// Stripping only `\\?\` here would yield a relative-looking "UNC\..."
		// path. Restore the conventional `\\server\share` form instead.
		return `\\` + ret[len(uncPrefix):], nil
	case strings.HasPrefix(ret, longPrefix):
		return ret[len(longPrefix):], nil
	}
	retp, err := getLongPath(ret)
	if err != nil {
		return "", pathErr(err)
	}
	return syscall.UTF16ToString(retp), nil
}
func IsSymlink(path string) (bool, error) {
var fa syscall.Win32FileAttributeData
namep, err := syscall.UTF16PtrFromString(path)
if err != nil {
return false, errors.Trace(err)
}
err = syscall.GetFileAttributesEx(namep, syscall.GetFileExInfoStandard, (*byte)(unsafe.Pointer(&fa)))
if err != nil {
return false, errors.Trace(err)
}
return fa.FileAttributes&FILE_ATTRIBUTE_REPARSE_POINT != 0, nil
}
// getLongPath expands a short (8.3) path such as c:\Progra~1\foo to its long
// form using GetLongPathNameW, and returns the result as UTF-16.
//
// The returned slice is cut to the length GetLongPathNameW reports (the
// terminating NUL is excluded from len). Callers such as New and Read pass
// &result[0] to Win32 APIs that expect a NUL-terminated string. They rely on
// the NUL written by GetLongPathNameW still being present in the backing
// array just past len.
func getLongPath(path string) ([]uint16, error) {
pathp, err := syscall.UTF16FromString(path)
if err != nil {
return nil, err
}
// First attempt reuses the input buffer as the output buffer.
// GetLongPathNameW permits the same buffer for both.
longp := pathp
n, err := syscall.GetLongPathName(&pathp[0], &longp[0], uint32(len(longp)))
if err != nil {
return nil, err
}
// A result larger than the buffer is the required size (including the NUL),
// so allocate exactly that and ask again.
if n > uint32(len(longp)) {
longp = make([]uint16, n)
n, err = syscall.GetLongPathName(&pathp[0], &longp[0], uint32(len(longp)))
if err != nil {
return nil, err
}
}
longp = longp[:n]
return longp, nil
}
// getLongPathAsString is getLongPath with the UTF-16 result converted back
// to a Go string.
func getLongPathAsString(path string) (string, error) {
	buf, err := getLongPath(path)
	if err != nil {
		return "", err
	}
	longPath := syscall.UTF16ToString(buf)
	return longPath, nil
}
================================================
FILE: symlink/symlink_windows_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package symlink_test
import (
"io/ioutil"
"os"
"path/filepath"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/symlink"
)
func (*SymlinkSuite) TestLongPath(c *gc.C) {
	// The 8.3 short name of Program Files must expand to its long form.
	const (
		shortName = `C:\PROGRA~1`
		longName  = `C:\Program Files`
	)
	expanded, err := symlink.GetLongPathAsString(shortName)
	c.Assert(err, gc.IsNil)
	c.Assert(expanded, gc.Equals, longName)
}
func (*SymlinkSuite) TestCreateSymLink(c *gc.C) {
	// A directory link created with New reads back as its target.
	dir, err := symlink.GetLongPathAsString(c.MkDir())
	c.Assert(err, gc.IsNil)
	_, err = os.Stat(dir)
	c.Assert(err, gc.IsNil)
	link := filepath.Join(dir, "link")
	c.Assert(symlink.New(dir, link), gc.IsNil)
	resolved, err := symlink.Read(link)
	c.Assert(err, gc.IsNil)
	c.Assert(resolved, gc.Equals, filepath.FromSlash(dir))
}
func (*SymlinkSuite) TestReadData(c *gc.C) {
	// Reading through a file symlink returns the target file's contents.
	root := c.MkDir()
	subdir := filepath.Join(root, "sub")
	c.Assert(os.Mkdir(subdir, 0700), gc.IsNil)
	want := []byte("data")
	source := filepath.Join(subdir, "foo")
	c.Assert(ioutil.WriteFile(source, want, 0644), gc.IsNil)
	linkPath := filepath.Join(root, "bar")
	c.Assert(symlink.New(source, linkPath), gc.IsNil)
	got, err := ioutil.ReadFile(linkPath)
	c.Assert(err, gc.IsNil)
	c.Assert(string(got), gc.Equals, string(want))
}
================================================
FILE: symlink/zsymlink_windows_386.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
// mksyscall_windows.pl -l32 symlink/symlink_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package symlink
import "unsafe"
import "syscall"
// Lazily loaded kernel32.dll and the two procedures the wrappers below call.
// Each procedure's address is resolved on first use via Addr().
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procCreateSymbolicLinkW = modkernel32.NewProc("CreateSymbolicLinkW")
procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW")
)
// createSymbolicLink calls CreateSymbolicLinkW with the given UTF-16 link
// name, target name and flags. A zero return value means failure: the error
// reported by the call is returned, or EINVAL if none was reported.
func createSymbolicLink(symlinkname *uint16, targetname *uint16, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procCreateSymbolicLinkW.Addr(), 3, uintptr(unsafe.Pointer(symlinkname)), uintptr(unsafe.Pointer(targetname)), uintptr(flags))
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
// getFinalPathNameByHandle calls GetFinalPathNameByHandleW, letting it write
// up to buflen UTF-16 code units into buf. n is the call's raw return value.
// A zero n means failure: the error reported by the call is returned, or
// EINVAL if none was reported.
func getFinalPathNameByHandle(handle syscall.Handle, buf *uint16, buflen uint32, flags uint32) (n uint32, err error) {
r0, _, e1 := syscall.Syscall6(procGetFinalPathNameByHandleW.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(flags), 0, 0)
n = uint32(r0)
if n == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
================================================
FILE: symlink/zsymlink_windows_amd64.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
// mksyscall_windows.pl symlink/symlink_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package symlink
import "unsafe"
import "syscall"
// Lazily loaded kernel32.dll and the two procedures the wrappers below call.
// Each procedure's address is resolved on first use via Addr().
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procCreateSymbolicLinkW = modkernel32.NewProc("CreateSymbolicLinkW")
procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW")
)
// createSymbolicLink calls CreateSymbolicLinkW with the given UTF-16 link
// name, target name and flags. A zero return value means failure: the error
// reported by the call is returned, or EINVAL if none was reported.
func createSymbolicLink(symlinkname *uint16, targetname *uint16, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procCreateSymbolicLinkW.Addr(), 3, uintptr(unsafe.Pointer(symlinkname)), uintptr(unsafe.Pointer(targetname)), uintptr(flags))
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
// getFinalPathNameByHandle calls GetFinalPathNameByHandleW, letting it write
// up to buflen UTF-16 code units into buf. n is the call's raw return value.
// A zero n means failure: the error reported by the call is returned, or
// EINVAL if none was reported.
func getFinalPathNameByHandle(handle syscall.Handle, buf *uint16, buflen uint32, flags uint32) (n uint32, err error) {
r0, _, e1 := syscall.Syscall6(procGetFinalPathNameByHandleW.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(flags), 0, 0)
n = uint32(r0)
if n == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
================================================
FILE: systemerrmessages_unix.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
//go:build !windows
// +build !windows
package utils
// The following are strings/regex-es which match common Unix error messages
// that may be returned in case of failed calls to the system.
// Any extra leading/trailing regex-es are left to be added by the developer.
const (
	// NoSuchFileErrRegexp matches the ENOENT message text.
	NoSuchFileErrRegexp = `no such file or directory`
	// MkdirFailErrRegexp matches a message ending in "not a directory",
	// preceded by at least one arbitrary character and a space.
	MkdirFailErrRegexp = `.* not a directory`
	// NoSuchUserErrRegexp matches "user: unknown user" followed by an
	// optional name made of lowercase letters, digits, '_' or '-'.
	NoSuchUserErrRegexp = `user: unknown user [a-z0-9_-]*`
)
================================================
FILE: systemerrmessages_windows.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
// The following are strings/regex-es which match common Windows error messages
// that may be returned in case of failed calls to the system.
// Any extra leading/trailing regex-es are left to be added by the developer.
const (
	// NoSuchFileErrRegexp matches the Windows "cannot find the file" and
	// "cannot find the path" messages.
	NoSuchFileErrRegexp = `The system cannot find the (file|path) specified\.`
	// MkdirFailErrRegexp matches a failed mkdir whose cause is one of the
	// messages matched by NoSuchFileErrRegexp.
	MkdirFailErrRegexp = `mkdir .*` + NoSuchFileErrRegexp
	// NoSuchUserErrRegexp matches the Windows message for an account name
	// that could not be mapped to a security ID.
	NoSuchUserErrRegexp = `No mapping between account names and security IDs was done\.`
)
================================================
FILE: tailer/export_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package tailer
// BufferSize points at the package's read-buffer size so tests can patch it.
var BufferSize = &bufferSize

// NewTestTailer exposes newTailer so tests can pick the poll interval.
var NewTestTailer = newTailer
================================================
FILE: tailer/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package tailer_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage runs every gocheck suite registered in this package.
func TestPackage(t *testing.T) { gc.TestingT(t) }
================================================
FILE: tailer/tailer.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package tailer
import (
"bufio"
"bytes"
"io"
"os"
"time"
"gopkg.in/tomb.v1"
)
// defaultBufferSize is the initial value of bufferSize, in bytes.
const defaultBufferSize = 4096

// polltime is the interval between polls used by NewTailer.
const polltime = time.Second

// delimiter terminates every line.
const delimiter = '\n'

// bufferSize is the read-buffer size actually used. It is a variable so that
// tests can patch it (see BufferSize in export_test.go).
var bufferSize = defaultBufferSize

// delimiters is delimiter as a byte slice, for bytes.LastIndex.
var delimiters = []byte{delimiter}
// TailerFilterFunc decides if a line shall be tailed (func is nil or
// returns true) or shall be omitted (func returns false). The line passed in
// includes its trailing newline delimiter.
type TailerFilterFunc func(line []byte) bool
// Tailer reads an input line by line and tails the lines into the passed
// Writer. The lines have to be terminated with a newline; an unterminated
// trailing fragment is held back until its newline arrives.
type Tailer struct {
tomb tomb.Tomb
// readSeeker is the source; reader buffers it for line reads.
readSeeker io.ReadSeeker
reader *bufio.Reader
// NOTE(review): writeCloser is never assigned in this file and looks unused.
writeCloser io.WriteCloser
// writer buffers output and is flushed after every poll.
writer *bufio.Writer
filter TailerFilterFunc
polltime time.Duration
}
// NewTailer starts a Tailer that reads lines from readSeeker, starting at its
// current position, and writes each line accepted by filter (all lines when
// filter is nil) to writer. It polls for new data once per polltime.
func NewTailer(readSeeker io.ReadSeeker, writer io.Writer, filter TailerFilterFunc) *Tailer {
	t := newTailer(readSeeker, writer, filter, polltime)
	return t
}
// newTailer starts a Tailer like NewTailer but lets the caller choose the
// time between polls (used by tests). The read buffer size comes from the
// package-level bufferSize variable. The background goroutine runs loop and
// records its result in the tomb.
func newTailer(readSeeker io.ReadSeeker, writer io.Writer,
	filter TailerFilterFunc, polltime time.Duration) *Tailer {
	t := new(Tailer)
	t.readSeeker = readSeeker
	t.reader = bufio.NewReaderSize(readSeeker, bufferSize)
	t.writer = bufio.NewWriter(writer)
	t.filter = filter
	t.polltime = polltime
	go func() {
		defer t.tomb.Done()
		err := t.loop()
		t.tomb.Kill(err)
	}()
	return t
}
// Stop kills the tailer's tomb with a nil reason, blocks until the tailer's
// goroutine has finished, and returns what the tomb's Wait reports.
func (t *Tailer) Stop() error {
	t.tomb.Kill(nil)
	err := t.tomb.Wait()
	return err
}
// Wait blocks until the tailer has stopped, whether by Stop or because of an
// error, and returns the reason recorded in its tomb.
func (t *Tailer) Wait() error {
	err := t.tomb.Wait()
	return err
}
// Dead returns a channel that is closed once the tailer has stopped.
func (t *Tailer) Dead() <-chan struct{} {
	dead := t.tomb.Dead()
	return dead
}
// Err returns the reason currently recorded in the tailer's tomb, without
// waiting for the tailer to stop.
func (t *Tailer) Err() error {
	err := t.tomb.Err()
	return err
}
// loop is the body of the tailer's goroutine. Each time the poll timer fires
// (immediately at first, then every t.polltime), it writes every line that
// readLine returns from the current read position to the writer until EOF,
// flushes the writer, and re-arms the timer. It returns nil once the tomb
// starts dying, or the first non-EOF read error or any write/flush error.
func (t *Tailer) loop() error {
	// TODO(mue) 2013-12-06
	// Handling of read-seeker/files being truncated during
	// tailing is currently missing!
	timer := time.NewTimer(0)
	// Release the timer whichever way the loop exits.
	defer timer.Stop()
	for {
		select {
		case <-t.tomb.Dying():
			return nil
		case <-timer.C:
			for {
				line, readErr := t.readLine()
				// On error line is nil, so this write is a no-op.
				if _, writeErr := t.writer.Write(line); writeErr != nil {
					return writeErr
				}
				if readErr != nil {
					if readErr != io.EOF {
						return readErr
					}
					break
				}
			}
			if writeErr := t.writer.Flush(); writeErr != nil {
				return writeErr
			}
			timer.Reset(t.polltime)
		}
	}
}
// SeekLastLines sets the read position of readSeeker so that the last
// `lines` newline-terminated lines accepted by filter (all lines when filter
// is nil) are the next to be read. An unterminated final line is not
// counted. If fewer matching lines exist, the position is set to the start.
// With lines == 0 the position is simply left at the end. Any Seek or read
// error, including one from the final positioning, is returned.
//
// The data is scanned backwards in chunks of bufferSize bytes. A partial
// line at the front of a chunk is kept and prepended to the next, earlier
// chunk.
func SeekLastLines(readSeeker io.ReadSeeker, lines uint, filter TailerFilterFunc) error {
	offset, err := readSeeker.Seek(0, io.SeekEnd)
	if err != nil {
		return err
	}
	if lines == 0 {
		// We are done, just seeking to the end is sufficient.
		return nil
	}
	seekPos := int64(0)
	found := uint(0)
	// Start empty: buffer only ever holds real data (the chunk just read plus
	// the partial line left over from the previous iteration).
	buffer := make([]byte, 0, bufferSize)
SeekLoop:
	for offset > 0 {
		// buffer contains the data left over from the
		// previous iteration.
		space := cap(buffer) - len(buffer)
		if space < bufferSize {
			// Grow buffer.
			newBuffer := make([]byte, len(buffer), cap(buffer)*2)
			copy(newBuffer, buffer)
			buffer = newBuffer
			space = cap(buffer) - len(buffer)
		}
		if int64(space) > offset {
			// Use exactly the right amount of space if there's
			// only a small amount remaining.
			space = int(offset)
		}
		// Copy data remaining from last time to the end of the buffer,
		// so we can read into the right place.
		copy(buffer[space:cap(buffer)], buffer)
		buffer = buffer[0 : len(buffer)+space]
		offset -= int64(space)
		if _, err := readSeeker.Seek(offset, io.SeekStart); err != nil {
			return err
		}
		if _, err := io.ReadFull(readSeeker, buffer[0:space]); err != nil {
			return err
		}
		// Find the end of the last line in the buffer.
		// This will discard any unterminated line at the end
		// of the file.
		end := bytes.LastIndex(buffer, delimiters)
		if end == -1 {
			// No end of line found - discard incomplete
			// line and continue looking. If this happens
			// at the beginning of the file, we don't care
			// because we're going to stop anyway.
			buffer = buffer[:0]
			continue
		}
		end++
		for {
			start := bytes.LastIndex(buffer[0:end-1], delimiters)
			if start == -1 && offset >= 0 {
				break
			}
			start++
			if filter == nil || filter(buffer[start:end]) {
				found++
				if found >= lines {
					seekPos = offset + int64(start)
					break SeekLoop
				}
			}
			end = start
		}
		// Leave the last line in buffer, as we don't know whether
		// it's complete or not.
		buffer = buffer[0:end]
	}
	// Final positioning. Its error was previously ignored, which could leave
	// the reader at an arbitrary position while reporting success.
	_, err = readSeeker.Seek(seekPos, io.SeekStart)
	return err
}
// readLine returns the next line accepted by isValid, including its trailing
// delimiter, even if it is larger than the reader buffer. When the line fits
// in the buffer, the returned slice aliases the bufio.Reader's internal
// buffer and is only valid until the next read.
//
// At EOF, a trailing unterminated fragment is not returned. Instead the
// underlying ReadSeeker is moved back by the fragment's length so the
// fragment is read again (as part of a complete line) on a later call, and
// io.EOF is returned. If that Seek fails, its error is returned instead,
// because continuing would silently drop the fragment.
func (t *Tailer) readLine() ([]byte, error) {
	for {
		slice, err := t.reader.ReadSlice(delimiter)
		if err == nil {
			if t.isValid(slice) {
				return slice, nil
			}
			continue
		}
		// Copy: ReadSlice's result is overwritten by the next read.
		line := append([]byte(nil), slice...)
		for err == bufio.ErrBufferFull {
			slice, err = t.reader.ReadSlice(delimiter)
			line = append(line, slice...)
		}
		switch err {
		case nil:
			if t.isValid(line) {
				return line, nil
			}
		case io.EOF:
			// EOF without delimiter, step back.
			if len(line) > 0 {
				if _, seekErr := t.readSeeker.Seek(-int64(len(line)), io.SeekCurrent); seekErr != nil {
					return nil, seekErr
				}
			}
			return nil, err
		default:
			return nil, err
		}
	}
}
// isValid reports whether line should be tailed: always when no filter is
// set, otherwise whatever the filter returns. It does not itself inspect the
// line's content (an empty line is valid unless the filter rejects it).
func (t *Tailer) isValid(line []byte) bool {
	return t.filter == nil || t.filter(line)
}
================================================
FILE: tailer/tailer_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package tailer_test
import (
"bufio"
"bytes"
"fmt"
"io"
"sync"
"time"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/tailer"
)
// tailerSuite holds the gocheck tests for the tailer package.
type tailerSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(new(tailerSuite))
// alphabetData is 26 newline-terminated lines, one per NATO-alphabet word
// written twice. Several test cases slice it to build expected output.
var alphabetData = []string{
"alpha alpha\n",
"bravo bravo\n",
"charlie charlie\n",
"delta delta\n",
"echo echo\n",
"foxtrott foxtrott\n",
"golf golf\n",
"hotel hotel\n",
"india india\n",
"juliet juliet\n",
"kilo kilo\n",
"lima lima\n",
"mike mike\n",
"november november\n",
"oscar oscar\n",
"papa papa\n",
"quebec quebec\n",
"romeo romeo\n",
"sierra sierra\n",
"tango tango\n",
"uniform uniform\n",
"victor victor\n",
"whiskey whiskey\n",
"x-ray x-ray\n",
"yankee yankee\n",
"zulu zulu\n",
}
// tests is the table driven by TestTailer.
var tests = []struct {
// description is logged before the case runs.
description string
// data is handed to startReadSeeker as the content to serve.
data []string
// initialLinesWritten is passed to startReadSeeker. Presumably the number
// of data entries visible before sigc is signalled; confirm there.
initialLinesWritten int
// initialLinesRequested is the lines argument to SeekLastLines.
initialLinesRequested uint
// bufferSize is patched into tailer.BufferSize; 0 means 4096.
bufferSize int
// filter is used both by SeekLastLines and by the tailer.
filter tailer.TailerFilterFunc
// injector, if set, builds a callback given to assertCollected while the
// appended data is collected (used to stop the tailer or inject errors).
injector func(*tailer.Tailer, *readSeeker) func([]string)
// initialCollectedData is expected before sigc is signalled.
initialCollectedData []string
// appendedCollectedData is expected after sigc is signalled.
appendedCollectedData []string
// fromStart skips SeekLastLines so tailing starts at offset 0.
fromStart bool
// err is the expected tailer.Err pattern; "" means Stop must return nil.
err string
}{{
description: "lines are longer than buffer size",
data: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
"0123456789012345678901234567890123456789012345678901\n",
},
initialLinesWritten: 1,
initialLinesRequested: 1,
bufferSize: 5,
initialCollectedData: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
},
appendedCollectedData: []string{
"0123456789012345678901234567890123456789012345678901\n",
},
}, {
description: "lines are longer than buffer size, missing termination of last line",
data: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
"0123456789012345678901234567890123456789012345678901\n",
"the quick brown fox ",
},
initialLinesWritten: 1,
initialLinesRequested: 1,
bufferSize: 5,
initialCollectedData: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
},
appendedCollectedData: []string{
"0123456789012345678901234567890123456789012345678901\n",
},
}, {
description: "lines are longer than buffer size, last line is terminated later",
data: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
"0123456789012345678901234567890123456789012345678901\n",
"the quick brown fox ",
"jumps over the lazy dog\n",
},
initialLinesWritten: 1,
initialLinesRequested: 1,
bufferSize: 5,
initialCollectedData: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
},
appendedCollectedData: []string{
"0123456789012345678901234567890123456789012345678901\n",
"the quick brown fox jumps over the lazy dog\n",
},
}, {
description: "missing termination of last line",
data: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
"0123456789012345678901234567890123456789012345678901\n",
"the quick brown fox ",
},
initialLinesWritten: 1,
initialLinesRequested: 1,
initialCollectedData: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
},
appendedCollectedData: []string{
"0123456789012345678901234567890123456789012345678901\n",
},
}, {
description: "last line is terminated later",
data: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
"0123456789012345678901234567890123456789012345678901\n",
"the quick brown fox ",
"jumps over the lazy dog\n",
},
initialLinesWritten: 1,
initialLinesRequested: 1,
initialCollectedData: []string{
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\n",
},
appendedCollectedData: []string{
"0123456789012345678901234567890123456789012345678901\n",
"the quick brown fox jumps over the lazy dog\n",
},
}, {
description: "more lines already written than initially requested",
data: alphabetData,
initialLinesWritten: 5,
initialLinesRequested: 3,
initialCollectedData: []string{
"charlie charlie\n",
"delta delta\n",
"echo echo\n",
},
appendedCollectedData: alphabetData[5:],
}, {
description: "less lines already written than initially requested",
data: alphabetData,
initialLinesWritten: 3,
initialLinesRequested: 5,
initialCollectedData: []string{
"alpha alpha\n",
"bravo bravo\n",
"charlie charlie\n",
},
appendedCollectedData: alphabetData[3:],
}, {
description: "lines are longer than buffer size, more lines already written than initially requested",
data: alphabetData,
initialLinesWritten: 5,
initialLinesRequested: 3,
bufferSize: 5,
initialCollectedData: []string{
"charlie charlie\n",
"delta delta\n",
"echo echo\n",
},
appendedCollectedData: alphabetData[5:],
}, {
description: "ignore current lines",
data: alphabetData,
initialLinesWritten: 5,
bufferSize: 5,
appendedCollectedData: alphabetData[5:],
}, {
description: "start from the start",
data: alphabetData,
initialLinesWritten: 5,
bufferSize: 5,
appendedCollectedData: alphabetData,
fromStart: true,
}, {
description: "lines are longer than buffer size, less lines already written than initially requested",
data: alphabetData,
initialLinesWritten: 3,
initialLinesRequested: 5,
bufferSize: 5,
initialCollectedData: []string{
"alpha alpha\n",
"bravo bravo\n",
"charlie charlie\n",
},
appendedCollectedData: alphabetData[3:],
}, {
description: "filter lines which contain the char 'e'",
data: alphabetData,
initialLinesWritten: 10,
initialLinesRequested: 3,
filter: func(line []byte) bool {
return bytes.Contains(line, []byte{'e'})
},
initialCollectedData: []string{
"echo echo\n",
"hotel hotel\n",
"juliet juliet\n",
},
appendedCollectedData: []string{
"mike mike\n",
"november november\n",
"quebec quebec\n",
"romeo romeo\n",
"sierra sierra\n",
"whiskey whiskey\n",
"yankee yankee\n",
},
}, {
description: "stop tailing after 10 collected lines",
data: alphabetData,
initialLinesWritten: 5,
initialLinesRequested: 3,
injector: func(t *tailer.Tailer, rs *readSeeker) func([]string) {
return func(lines []string) {
if len(lines) == 10 {
t.Stop()
}
}
},
initialCollectedData: []string{
"charlie charlie\n",
"delta delta\n",
"echo echo\n",
},
appendedCollectedData: alphabetData[5:],
}, {
description: "generate an error after 10 collected lines",
data: alphabetData,
initialLinesWritten: 5,
initialLinesRequested: 3,
injector: func(t *tailer.Tailer, rs *readSeeker) func([]string) {
return func(lines []string) {
if len(lines) == 10 {
rs.setError(fmt.Errorf("ouch after 10 lines"))
}
}
},
initialCollectedData: []string{
"charlie charlie\n",
"delta delta\n",
"echo echo\n",
},
appendedCollectedData: alphabetData[5:],
err: "ouch after 10 lines",
}, {
description: "more lines already written than initially requested, some empty, unfiltered",
data: []string{
"one one\n",
"two two\n",
"\n",
"\n",
"three three\n",
"four four\n",
"\n",
"\n",
"five five\n",
"six six\n",
},
initialLinesWritten: 3,
initialLinesRequested: 2,
initialCollectedData: []string{
"two two\n",
"\n",
},
appendedCollectedData: []string{
"\n",
"three three\n",
"four four\n",
"\n",
"\n",
"five five\n",
"six six\n",
},
}, {
description: "more lines already written than initially requested, some empty, those filtered",
data: []string{
"one one\n",
"two two\n",
"\n",
"\n",
"three three\n",
"four four\n",
"\n",
"\n",
"five five\n",
"six six\n",
},
initialLinesWritten: 3,
initialLinesRequested: 2,
filter: func(line []byte) bool {
return len(bytes.TrimSpace(line)) > 0
},
initialCollectedData: []string{
"one one\n",
"two two\n",
},
appendedCollectedData: []string{
"three three\n",
"four four\n",
"five five\n",
"six six\n",
},
}}
// TestTailer runs every table entry in tests: it optionally seeks to the
// last requested lines, tails the simulated file through a pipe, checks
// the initially collected lines, lets the writer goroutine append the
// rest, and finally checks either a clean Stop or the expected error.
func (s *tailerSuite) TestTailer(c *gc.C) {
	for i, test := range tests {
		c.Logf("Test #%d) %s", i, test.description)
		// A zero buffer size in the table selects the default.
		bufferSize := test.bufferSize
		if bufferSize == 0 {
			bufferSize = 4096
		}
		s.PatchValue(tailer.BufferSize, bufferSize)
		reader, writer := io.Pipe()
		sigc := make(chan struct{}, 1)
		rs := startReadSeeker(c, test.data, test.initialLinesWritten, sigc)
		if !test.fromStart {
			err := tailer.SeekLastLines(rs, test.initialLinesRequested, test.filter)
			c.Assert(err, gc.IsNil)
		}
		// Named tl so the tailer package stays addressable in this scope.
		tl := tailer.NewTestTailer(rs, writer, test.filter, 2*time.Millisecond)
		linec := startReading(c, tl, reader, writer)

		assertCollected(c, linec, test.initialCollectedData, nil)
		// Let the writer goroutine append the remaining lines.
		sigc <- struct{}{}

		// The optional injector may stop the tailer early or make the
		// reader fail part-way through.
		var injection func([]string)
		if test.injector != nil {
			injection = test.injector(tl, rs)
		}
		assertCollected(c, linec, test.appendedCollectedData, injection)

		if test.err != "" {
			c.Assert(tl.Err(), gc.ErrorMatches, test.err)
			continue
		}
		c.Assert(tl.Stop(), gc.IsNil)
	}
}
// startReading starts a goroutine that reads newline-terminated lines
// from reader in the background and forwards each one on the returned
// channel, which the assertions consume. The channel is closed when the
// reader reports io.EOF or any other error; an unexpected error also
// marks the test as failed.
//
// A second goroutine closes writer once tailer has stopped or failed.
// Components that use a Tailer can close their writer the same way.
func startReading(c *gc.C, tailer *tailer.Tailer, reader *io.PipeReader, writer *io.PipeWriter) chan string {
	linec := make(chan string)
	go func() {
		defer close(linec)
		br := bufio.NewReader(reader)
		for {
			line, err := br.ReadString('\n')
			switch err {
			case nil:
				linec <- line
			case io.EOF:
				return
			default:
				// A pipe reader keeps returning the same error, so we
				// must stop here instead of spinning forever.
				c.Errorf("unexpected error reading tailed lines: %v", err)
				return
			}
		}
	}()
	go func() {
		tailer.Wait()
		writer.Close()
	}()
	return linec
}
// assertCollected reads lines from the string channel linec. It compares if
// those are the one passed with compare until a timeout. If the timeout is
// reached earlier than all lines are collected the assertion fails. The
// injection function allows to interrupt the processing with a function
// generating an error or a regular stopping during the tailing. In case the
// linec is closed due to stopping or an error only the values so far care
// compared. Checking the reason for termination is done in the test.
func assertCollected(c *gc.C, linec chan string, compare []string, injection func([]string)) {
if len(compare) == 0 {
return
}
timeout := time.After(10 * time.Second)
lines := []string{}
for {
select {
case line, ok := <-linec:
if ok {
lines = append(lines, line)
if injection != nil {
injection(lines)
}
if len(lines) == len(compare) {
// All data received.
c.Assert(lines, gc.DeepEquals, compare)
return
}
} else {
// linec closed after stopping or error.
c.Assert(lines, gc.DeepEquals, compare[:len(lines)])
return
}
case <-timeout:
if injection == nil {
c.Fatalf("timeout during tailer collection")
}
return
}
}
}
// startReadSeeker returns a ReadSeeker for the Tailer. It simulates
// reading and seeking inside a file and also simulating an error.
// The goroutine waits for a signal that it can start writing the
// appended lines.
func startReadSeeker(c *gc.C, data []string, initialLeg int, sigc chan struct{}) *readSeeker {
// Write initial lines into the buffer.
var rs readSeeker
var i int
for i = 0; i < initialLeg; i++ {
rs.write(data[i])
}
go func() {
<-sigc
for ; i < len(data); i++ {
time.Sleep(5 * time.Millisecond)
rs.write(data[i])
}
}()
return &rs
}
// readSeeker is an in-memory io.ReadSeeker used to feed the Tailer in
// tests. All methods are guarded by mux, so data can be appended with
// write while another goroutine reads. After setError, every Read returns
// that error.
type readSeeker struct {
	mux    sync.Mutex
	buffer []byte // everything written so far
	pos    int    // current read offset into buffer
	err    error  // if non-nil, returned by every Read
}

// write appends s to the end of the buffer without moving the read offset.
func (r *readSeeker) write(s string) {
	r.mux.Lock()
	defer r.mux.Unlock()
	r.buffer = append(r.buffer, []byte(s)...)
}

// setError makes every subsequent Read fail with err.
func (r *readSeeker) setError(err error) {
	r.mux.Lock()
	defer r.mux.Unlock()
	r.err = err
}

// Read implements io.Reader. It returns the configured error if one is
// set, io.EOF once the offset has reached the end of the data written so
// far, and otherwise copies as much buffered data as fits into p.
func (r *readSeeker) Read(p []byte) (n int, err error) {
	r.mux.Lock()
	defer r.mux.Unlock()
	if r.err != nil {
		return 0, r.err
	}
	if r.pos >= len(r.buffer) {
		return 0, io.EOF
	}
	n = copy(p, r.buffer[r.pos:])
	r.pos += n
	return n, nil
}

// Seek implements io.Seeker. Negative resulting offsets and offsets that
// do not fit in 31 bits are rejected; seeking past the end is allowed.
func (r *readSeeker) Seek(offset int64, whence int) (ret int64, err error) {
	r.mux.Lock()
	defer r.mux.Unlock()
	var newPos int64
	switch whence {
	case io.SeekStart:
		newPos = offset
	case io.SeekCurrent:
		newPos = int64(r.pos) + offset
	case io.SeekEnd:
		newPos = int64(len(r.buffer)) + offset
	default:
		return 0, fmt.Errorf("invalid whence: %d", whence)
	}
	if newPos < 0 {
		return 0, fmt.Errorf("negative position: %d", newPos)
	}
	// Keep the offset representable as int on every platform.
	if newPos >= 1<<31 {
		return 0, fmt.Errorf("position out of range: %d", newPos)
	}
	r.pos = int(newPos)
	return newPos, nil
}
================================================
FILE: tar/tar.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package tar provides convenience helpers on top of archive/tar
// for tarring and untarring files with behavior closer to that of
// the GNU tar command.
package tar
import (
"archive/tar"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/juju/errors"
"github.com/juju/collections/set"
"github.com/juju/utils/v4/symlink"
)
// FindFile scans the tar stream tarFile for the first entry whose name is
// exactly filename. It returns that entry's header and a reader positioned
// at the entry's contents; the reader is the underlying tar.Reader, so it
// reads from tarFile. If no entry matches, an errors.NotFound error is
// returned. Errors reading the archive are returned traced.
func FindFile(tarFile io.Reader, filename string) (*tar.Header, io.Reader, error) {
	reader := tar.NewReader(tarFile)
	for {
		header, err := reader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, nil, errors.Trace(err)
		}
		if header.Name == filename {
			return header, reader, nil
		}
	}
	// filename must not be used as the format string: a name containing
	// '%' would otherwise be mangled in the error message.
	return nil, nil, errors.NotFoundf("%s", filename)
}
// TarFiles writes a tar stream holding the paths in fileList to target.
// Directories are included recursively. The prefix strip is removed from
// the beginning of every stored path, much like the -C option of GNU tar.
//
// On success it returns the base64-encoded SHA-1 digest of the tar stream
// and a nil error. On failure it returns an empty string and the error.
// SHA-1 in base64 is used because that is the hash format of RFC 3230
// Digest headers in HTTP responses.
//
// It is not safe to modify the files while this function runs. For each
// regular file, only the bytes up to the size observed when the file was
// stat'ed are archived.
func TarFiles(fileList []string, target io.Writer, strip string) (shaSum string, err error) {
	digest := sha1.New()
	if err = tarAndHashFiles(fileList, target, strip, digest); err != nil {
		return "", err
	}
	shaSum = base64.StdEncoding.EncodeToString(digest.Sum(nil))
	return shaSum, nil
}
// tarAndHashFiles writes a tar archive of fileList to target while feeding
// exactly the same bytes into hashw. The tar writer is closed on return,
// which finalizes the archive. A close failure is reported only when no
// earlier error occurred. Errors are wrapped with %w so callers can still
// inspect the underlying cause.
func tarAndHashFiles(fileList []string, target io.Writer, strip string, hashw io.Writer) (err error) {
	tarw := tar.NewWriter(io.MultiWriter(target, hashw))
	defer func() {
		if closeErr := tarw.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing tar writer: %w", closeErr)
		}
	}()
	for _, ent := range fileList {
		if werr := writeContents(ent, strip, tarw); werr != nil {
			return fmt.Errorf("write to tar file failed: %w", werr)
		}
	}
	return nil
}
// writeContents adds fileName to tarw.
//   - Regular files are stored with their contents.
//   - Symlinks are stored as link entries.
//   - Directories are stored followed by all of their entries, recursively.
//   - Other file types get a header only.
//
// The prefix strip is removed from each stored name, which is then
// converted to forward slashes.
func writeContents(fileName, strip string, tarw *tar.Writer) error {
	// Lstat before opening anything, so that:
	//   - symlinks are archived as links, not followed;
	//   - only paths whose contents we need are opened (opening a FIFO,
	//     for instance, can block indefinitely).
	fInfo, err := os.Lstat(fileName)
	if err != nil {
		return err
	}
	link := ""
	if fInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
		// NOTE(review): EvalSymlinks records the fully resolved target
		// rather than the raw link text (os.Readlink) that GNU tar
		// stores. It is kept for compatibility with existing archives.
		link, err = filepath.EvalSymlinks(fileName)
		if err != nil {
			return fmt.Errorf("cannot dereference symlink: %w", err)
		}
	}
	h, err := tar.FileInfoHeader(fInfo, link)
	if err != nil {
		return fmt.Errorf("cannot create tar header for %q: %w", fileName, err)
	}
	h.Name = filepath.ToSlash(strings.TrimPrefix(fileName, strip))
	if err := tarw.WriteHeader(h); err != nil {
		return fmt.Errorf("cannot write header for %q: %w", fileName, err)
	}
	switch {
	case fInfo.Mode().IsRegular():
		return writeFileData(fileName, fInfo.Size(), tarw)
	case fInfo.IsDir():
		return writeDirEntries(fileName, strip, tarw)
	default:
		// Symlinks and special files carry no data.
		return nil
	}
}

// writeFileData copies the first size bytes of the regular file fileName
// into tarw. The copy is limited to size, the length recorded in the
// header, because archive/tar rejects extra bytes with ErrWriteTooLong.
func writeFileData(fileName string, size int64, tarw *tar.Writer) error {
	f, err := os.Open(fileName)
	if err != nil {
		return err
	}
	defer f.Close()
	if _, err := io.CopyN(tarw, f, size); err != nil {
		return fmt.Errorf("failed to write %q: %w", fileName, err)
	}
	return nil
}

// writeDirEntries archives every entry of the directory dirName by
// calling writeContents on each one.
func writeDirEntries(dirName, strip string, tarw *tar.Writer) error {
	f, err := os.Open(dirName)
	if err != nil {
		return err
	}
	defer f.Close()
	for {
		// Readdirnames returns at most 100 names per call. Once the
		// directory is exhausted it returns io.EOF and no names.
		names, err := f.Readdirnames(100)
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("error reading directory %q: %w", dirName, err)
		}
		for _, name := range names {
			if err := writeContents(filepath.Join(dirName, name), strip, tarw); err != nil {
				return err
			}
		}
	}
}
// createAndFill creates (or truncates) the file at filePath and copies all
// of content into it. It then sets the file's mode to os.FileMode(mode),
// syncs the data, and closes the file. Any failure is returned wrapped
// (with %w) in a message naming the failing step.
func createAndFill(filePath string, mode int64, content io.Reader) error {
	fh, err := os.Create(filePath)
	if err != nil {
		return fmt.Errorf("some of the tar contents cannot be written to disk: %w", err)
	}
	// Releases the handle on the early-return paths below. On success the
	// explicit Close at the end runs first, so its error can be reported;
	// this second Close then returns an error that is deliberately ignored.
	defer fh.Close()
	if _, err := io.Copy(fh, content); err != nil {
		return fmt.Errorf("failed while reading tar contents: %w", err)
	}
	// Chmod the open handle rather than re-resolving the path.
	if err := fh.Chmod(os.FileMode(mode)); err != nil {
		return fmt.Errorf("cannot set proper mode on file %q: %w", filePath, err)
	}
	if err := fh.Sync(); err != nil {
		return fmt.Errorf("failed to sync contents of file %v: %w", filePath, err)
	}
	if err := fh.Close(); err != nil {
		return fmt.Errorf("failed to close file %v: %w", filePath, err)
	}
	return nil
}
// UntarFiles extracts the tar stream tarFile into outputFolder.
// Directories, symlinks and regular files are created; entries of any
// other type (for example PAX global headers) are skipped. Parent
// directories that the archive does not list are created with mode 0755.
//
// Entries whose names resolve outside outputFolder (for example "../x")
// are rejected with an error instead of being written. The check is
// lexical only: it does not stop a symlink entry from redirecting a later
// entry elsewhere.
func UntarFiles(tarFile io.Reader, outputFolder string) error {
	tr := tar.NewReader(tarFile)
	// Ensure we still make directories for any files where we haven't
	// already seen the directory (for example, juju backup generates
	// files like this).
	seenDirs := set.NewStrings()
	maybeMkParentDir := func(path string) error {
		dirName := filepath.Dir(path)
		if seenDirs.Contains(dirName) {
			return nil
		}
		if err := os.MkdirAll(dirName, os.FileMode(0755)); err != nil {
			return fmt.Errorf("cannot create parent directory for %q: %w", path, err)
		}
		seenDirs.Add(dirName)
		return nil
	}
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			// end of tar archive
			return nil
		}
		if err != nil {
			return fmt.Errorf("failed while reading tar header: %w", err)
		}
		fullPath := filepath.Join(outputFolder, hdr.Name)
		// Archive contents are untrusted: never write outside outputFolder
		// (the "zip slip" path traversal).
		if !isWithinDir(outputFolder, fullPath) {
			return fmt.Errorf("cannot extract %q: path is outside %q", hdr.Name, outputFolder)
		}
		switch hdr.Typeflag {
		case tar.TypeDir:
			if err = os.MkdirAll(fullPath, os.FileMode(hdr.Mode)); err != nil {
				return fmt.Errorf("cannot extract directory %q: %w", fullPath, err)
			}
			seenDirs.Add(fullPath)
		case tar.TypeSymlink:
			if err = maybeMkParentDir(fullPath); err != nil {
				return err
			}
			if err = symlink.New(hdr.Linkname, fullPath); err != nil {
				return fmt.Errorf("cannot extract symlink %q to %q: %w", hdr.Linkname, fullPath, err)
			}
		// TypeRegA is deprecated but still accepted for old archives.
		case tar.TypeReg, tar.TypeRegA:
			if err = maybeMkParentDir(fullPath); err != nil {
				return err
			}
			if err = createAndFill(fullPath, hdr.Mode, tr); err != nil {
				return fmt.Errorf("cannot extract file %q: %w", fullPath, err)
			}
		}
	}
}

// isWithinDir reports whether path, once cleaned, is dir itself or lies
// beneath it.
func isWithinDir(dir, path string) bool {
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
================================================
FILE: tar/tar_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package tar
import (
"archive/tar"
"bytes"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
stdtesting "testing"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
)
// TestPackage hooks the gocheck suites of this package into "go test".
func TestPackage(t *stdtesting.T) { gc.TestingT(t) }
var _ = gc.Suite(&TarSuite{})

// TarSuite tests tarring and untarring. cwd is a fresh scratch directory
// for each test, and testFiles lists the top-level paths created by
// createTestFiles.
type TarSuite struct {
	testing.IsolationSuite

	cwd       string
	testFiles []string
}
// SetUpTest gives every test its own scratch directory before running the
// embedded isolation setup.
func (t *TarSuite) SetUpTest(c *gc.C) {
	scratch := c.MkDir()
	t.cwd = scratch
	t.IsolationSuite.SetUpTest(c)
}
// createTestFiles builds the fixture tree below in t.cwd and records its
// top-level entries in t.testFiles. Symlink targets are absolute paths.
//
//	TarDirectoryEmpty/
//	TarDirectoryPopulated/
//	    TarSubFile1                         (contains "TarSubFile1")
//	    TarSubLink -> TarSubFile1
//	    TarDirectoryPopulatedSubDirectory/
//	TarLink -> TarDirectoryPopulated
//	TarFile1                                (contains "TarFile1")
//	TarFile2                                (contains "TarFile2")
//
// Any setup failure aborts the test at once. Otherwise later steps would
// run against a half-built tree, or call methods on a nil file handle.
func (t *TarSuite) createTestFiles(c *gc.C) {
	mkdir := func(path string) {
		err := os.Mkdir(path, os.FileMode(0755))
		c.Assert(err, gc.IsNil)
	}
	mklink := func(target, link string) {
		err := os.Symlink(target, link)
		c.Assert(err, gc.IsNil)
	}
	writeFile := func(path, content string) {
		// Same permissions os.Create uses (0666 before umask).
		err := os.WriteFile(path, []byte(content), 0666)
		c.Assert(err, gc.IsNil)
	}

	tarDirE := filepath.Join(t.cwd, "TarDirectoryEmpty")
	mkdir(tarDirE)
	tarDirP := filepath.Join(t.cwd, "TarDirectoryPopulated")
	mkdir(tarDirP)
	tarlink1 := filepath.Join(t.cwd, "TarLink")
	mklink(tarDirP, tarlink1)
	tarSubFile1 := filepath.Join(tarDirP, "TarSubFile1")
	writeFile(tarSubFile1, "TarSubFile1")
	mklink(tarSubFile1, filepath.Join(tarDirP, "TarSubLink"))
	mkdir(filepath.Join(tarDirP, "TarDirectoryPopulatedSubDirectory"))
	tarFile1 := filepath.Join(t.cwd, "TarFile1")
	writeFile(tarFile1, "TarFile1")
	tarFile2 := filepath.Join(t.cwd, "TarFile2")
	writeFile(tarFile2, "TarFile2")

	t.testFiles = []string{tarDirE, tarDirP, tarlink1, tarFile1, tarFile2}
}
// removeTestFiles deletes every top-level fixture path, recursively.
func (t *TarSuite) removeTestFiles(c *gc.C) {
	for _, path := range t.testFiles {
		c.Assert(os.RemoveAll(path), gc.IsNil)
	}
}
// expectedTarContents describes one entry expected in an archive (or in
// an extracted tree). Name is relative to the archive root. Body is the
// expected file content; empty means "do not check the content".
type expectedTarContents struct {
	Name string
	Body string
}

// testExpectedTarContents lists every entry that createTestFiles produces.
var testExpectedTarContents = []expectedTarContents{
	{Name: "TarDirectoryEmpty", Body: ""},
	{Name: "TarDirectoryPopulated", Body: ""},
	{Name: "TarLink", Body: ""},
	{Name: "TarDirectoryPopulated/TarSubFile1", Body: "TarSubFile1"},
	{Name: "TarDirectoryPopulated/TarSubLink", Body: ""},
	{Name: "TarDirectoryPopulated/TarDirectoryPopulatedSubDirectory", Body: ""},
	{Name: "TarFile1", Body: "TarFile1"},
	{Name: "TarFile2", Body: "TarFile2"},
}
// assertTarContents checks that the tar stream read from tarFile contains
// every entry in expectedContents. Expected names are looked up after
// stripping any leading path separator. An entry's body is compared only
// when the expectation has a non-empty Body.
func (t *TarSuite) assertTarContents(c *gc.C, expectedContents []expectedTarContents,
	tarFile io.Reader) {
	tr := tar.NewReader(tarFile)
	tarContents := make(map[string]string)
	// Iterate through the files in the archive.
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			// end of tar archive
			break
		}
		c.Assert(err, gc.IsNil)
		buf, err := ioutil.ReadAll(tr)
		c.Assert(err, gc.IsNil)
		tarContents[hdr.Name] = string(buf)
	}
	// Log the full picture once, rather than on every iteration below.
	c.Log(tarContents)
	c.Log(expectedContents)
	for _, expectedContent := range expectedContents {
		fullExpectedContent := strings.TrimPrefix(expectedContent.Name, string(os.PathSeparator))
		body, ok := tarContents[fullExpectedContent]
		c.Logf("checking for presence of %q on tar file", fullExpectedContent)
		c.Assert(ok, gc.Equals, true, gc.Commentf("%q missing from tar file", fullExpectedContent))
		if expectedContent.Body != "" {
			c.Log("Also checking the file contents")
			c.Assert(body, gc.Equals, expectedContent.Body)
		}
	}
}
func (t *TarSuite) assertFilesWhereUntared(c *gc.C,
expectedContents []expectedTarContents,
tarOutputFolder string) {
tarContents := make(map[string]string)
var walkFn filepath.WalkFunc
walkFn = func(path string, finfo os.FileInfo, err error) error {
if err != nil {
return err
}
fileName := strings.TrimPrefix(path, tarOutputFolder)
fileName = strings.TrimPrefix(fileName, string(os.PathSeparator))
c.Log(fileName)
if fileName == "" {
return nil
}
if finfo.IsDir() || finfo.Mode()&os.ModeSymlink == os.ModeSymlink {
tarContents[fileName] = ""
} else {
readable, err := os.Open(path)
if err != nil {
return err
}
defer readable.Close()
buf, err := ioutil.ReadAll(readable)
c.Assert(err, gc.IsNil)
tarContents[fileName] = string(buf)
}
return nil
}
filepath.Walk(tarOutputFolder, walkFn)
for _, expectedContent := range expectedContents {
fullExpectedContent := strings.TrimPrefix(expectedContent.Name, string(os.PathSeparator))
expectedPath := filepath.Join(tarOutputFolder, fullExpectedContent)
_, err := os.Lstat(expectedPath)
c.Assert(err, gc.Equals, nil)
body, ok := tarContents[fullExpectedContent]
c.Log(fmt.Sprintf("checking for presence of %q on untar files", fullExpectedContent))
c.Assert(ok, gc.Equals, true)
if expectedContent.Body != "" {
c.Log("Also checking the file contents")
c.Assert(body, gc.Equals, expectedContent.Body)
}
}
}
// shaSumFile returns the base64-encoded SHA-1 digest of everything read
// from fileToSum. This is the same digest format that TarFiles returns.
func shaSumFile(c *gc.C, fileToSum io.Reader) string {
	h := sha1.New()
	_, err := io.Copy(h, fileToSum)
	c.Assert(err, gc.IsNil)
	sum := h.Sum(nil)
	return base64.StdEncoding.EncodeToString(sum)
}
// Tar
// TestTarFiles checks that TarFiles returns the digest of the stream it
// wrote, and that the archive holds every fixture entry.
func (t *TarSuite) TestTarFiles(c *gc.C) {
	t.createTestFiles(c)
	var archive bytes.Buffer
	shaSum, err := TarFiles(t.testFiles, &archive, fmt.Sprintf("%s/", t.cwd))
	c.Check(err, gc.IsNil)
	archived := archive.Bytes()
	c.Assert(shaSum, gc.Equals, shaSumFile(c, bytes.NewBuffer(archived)))
	// Check the archive only after the originals are gone.
	t.removeTestFiles(c)
	t.assertTarContents(c, testExpectedTarContents, bytes.NewBuffer(archived))
}
// TestSymlinksTar checks that a symlink is archived as a single link
// entry whose Linkname is its target.
func (t *TarSuite) TestSymlinksTar(c *gc.C) {
	tarDirP := filepath.Join(t.cwd, "TarDirectory")
	err := os.Mkdir(tarDirP, os.FileMode(0755))
	c.Check(err, gc.IsNil)
	tarlink1 := filepath.Join(t.cwd, "TarLink")
	err = os.Symlink(tarDirP, tarlink1)
	c.Check(err, gc.IsNil)

	var archive bytes.Buffer
	_, err = TarFiles([]string{tarDirP, tarlink1}, &archive, fmt.Sprintf("%s/", t.cwd))
	c.Check(err, gc.IsNil)

	tr := tar.NewReader(bytes.NewReader(archive.Bytes()))
	symlinks := 0
	for {
		hdr, err := tr.Next()
		if err == io.EOF {
			break
		}
		c.Assert(err, gc.IsNil)
		if hdr.Typeflag != tar.TypeSymlink {
			continue
		}
		symlinks++
		c.Assert(hdr.Linkname, gc.Equals, tarDirP)
	}
	c.Assert(symlinks, gc.Equals, 1)
}
// UnTar
// TestUnTarFilesUncompressed archives the fixture tree, removes it, and
// checks that UntarFiles recreates every entry in a fresh directory.
func (t *TarSuite) TestUnTarFilesUncompressed(c *gc.C) {
	t.createTestFiles(c)
	var outputTar bytes.Buffer
	trimPath := fmt.Sprintf("%s/", t.cwd)
	_, err := TarFiles(t.testFiles, &outputTar, trimPath)
	c.Assert(err, gc.IsNil)
	t.removeTestFiles(c)
	outputDir := filepath.Join(t.cwd, "TarOuputFolder")
	err = os.Mkdir(outputDir, os.FileMode(0755))
	c.Assert(err, gc.IsNil)
	// Fail on the extraction error itself, not on a later, vaguer
	// content mismatch.
	err = UntarFiles(&outputTar, outputDir)
	c.Assert(err, jc.ErrorIsNil)
	t.assertFilesWhereUntared(c, testExpectedTarContents, outputDir)
}
// TestFindFileFound checks that FindFile returns a reader positioned at
// the contents of a nested entry.
func (t *TarSuite) TestFindFileFound(c *gc.C) {
	t.createTestFiles(c)
	var archive bytes.Buffer
	_, err := TarFiles(t.testFiles, &archive, fmt.Sprintf("%s/", t.cwd))
	c.Assert(err, gc.IsNil)
	// Remove the originals so the content can only come from the archive.
	t.removeTestFiles(c)

	_, entry, err := FindFile(&archive, "TarDirectoryPopulated/TarSubFile1")
	c.Assert(err, gc.IsNil)
	content, err := ioutil.ReadAll(entry)
	c.Assert(err, gc.IsNil)
	c.Check(string(content), gc.Equals, "TarSubFile1")
}
// TestFindFileNotFound checks the error FindFile returns for a name that
// is not in the archive.
func (t *TarSuite) TestFindFileNotFound(c *gc.C) {
	t.createTestFiles(c)
	var archive bytes.Buffer
	_, err := TarFiles(t.testFiles, &archive, fmt.Sprintf("%s/", t.cwd))
	c.Assert(err, gc.IsNil)
	t.removeTestFiles(c)

	_, _, err = FindFile(&archive, "does_not_exist")
	c.Check(err, gc.ErrorMatches, "does_not_exist not found")
}
// TestUntarFilesHeadersIgnored checks that an archive holding only a PAX
// global header extracts to nothing.
func (t *TarSuite) TestUntarFilesHeadersIgnored(c *gc.C) {
	var archive bytes.Buffer
	tw := tar.NewWriter(&archive)
	err := tw.WriteHeader(&tar.Header{
		Name:     "pax_global_header",
		Typeflag: tar.TypeXGlobalHeader,
	})
	c.Assert(err, gc.IsNil)
	c.Assert(tw.Flush(), gc.IsNil)

	c.Assert(UntarFiles(&archive, t.cwd), jc.ErrorIsNil)

	// Only the (empty) output directory itself may be visited.
	err = filepath.Walk(t.cwd, func(path string, finfo os.FileInfo, err error) error {
		if path == t.cwd {
			return err
		}
		return fmt.Errorf("unexpected file: %v", path)
	})
	c.Assert(err, gc.IsNil)
}
// TestUntarFilesWithMissingDirectories checks that UntarFiles creates the
// parent directories of file and symlink entries even when the archive
// does not list those directories itself.
func (t *TarSuite) TestUntarFilesWithMissingDirectories(c *gc.C) {
	var archive bytes.Buffer
	tw := tar.NewWriter(&archive)
	body := []byte("file contents")
	err := tw.WriteHeader(&tar.Header{
		Name:     "missingdir/otherdir/file",
		Typeflag: tar.TypeReg,
		Mode:     0700,
		Size:     int64(len(body)),
	})
	c.Assert(err, jc.ErrorIsNil)
	_, err = tw.Write(body)
	c.Assert(err, jc.ErrorIsNil)
	err = tw.WriteHeader(&tar.Header{
		Name:     "missingdir/otherdir/link",
		Typeflag: tar.TypeSymlink,
		Linkname: "viginti",
	})
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(tw.Flush(), jc.ErrorIsNil)

	c.Assert(UntarFiles(&archive, t.cwd), jc.ErrorIsNil)

	var names []string
	err = filepath.Walk(t.cwd, func(path string, _ os.FileInfo, _ error) error {
		names = append(names, strings.TrimPrefix(path, t.cwd))
		return nil
	})
	c.Assert(err, jc.ErrorIsNil)
	c.Assert(names, gc.DeepEquals, []string{
		"",
		"/missingdir",
		"/missingdir/otherdir",
		"/missingdir/otherdir/file",
		"/missingdir/otherdir/link",
	})
}
================================================
FILE: timer.go
================================================
// Copyright 2015 Canonical Ltd.
// Copyright 2015 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"math/rand"
"time"
"github.com/juju/clock"
)
// Countdown implements a timer that calls a provided function after an
// internally stored duration. The steps, as well as the min and max
// durations, are set at initialization and depend on the particular
// implementation.
//
// TODO(katco): 2016-08-09: This type is deprecated: lp:1611427
type Countdown interface {
	// Reset stops the timer and resets its duration to the minimum one.
	// Start must be called to start the timer again.
	Reset()

	// Start starts the internal timer.
	// When the timer fires, unless Reset was called in the meantime,
	// Func is called, and the duration is increased for the next call.
	Start()
}
// NewBackoffTimer returns a BackoffTimer configured by config. The first
// Start waits config.Min. Each later Start waits Factor times longer than
// the previous one, up to config.Max. With config.Jitter set, a small
// random adjustment is applied to each increased duration.
//
// TODO(katco): 2016-08-09: This type is deprecated: lp:1611427
func NewBackoffTimer(config BackoffTimerConfig) *BackoffTimer {
	t := &BackoffTimer{config: config}
	t.currentDuration = config.Min
	return t
}
// BackoffTimer implements Countdown.
// A backoff timer starts at min and gets multiplied by factor
// until it reaches max. Jitter determines whether a small
// randomization is added to the duration.
//
// TODO(katco): 2016-08-09: This type is deprecated: lp:1611427
type BackoffTimer struct {
	// config holds the settings passed to NewBackoffTimer.
	config          BackoffTimerConfig
	// timer is the most recently scheduled clock timer. It is nil until
	// the first Start; Start and Reset stop it when it is set.
	timer           clock.Timer
	// currentDuration is the delay that the next Start will schedule.
	currentDuration time.Duration
}
// BackoffTimerConfig is a helper struct for backoff timer
// that encapsulates config information.
//
// TODO(katco): 2016-08-09: This type is deprecated: lp:1611427
type BackoffTimerConfig struct {
	// The minimum duration after which Func is called. It is the initial
	// duration, and Reset lowers a larger duration back to it.
	Min time.Duration
	// The maximum duration after which Func is called. Each increased
	// duration is capped at this value.
	Max time.Duration
	// Determines whether a small randomization (up to ±3% of each
	// increased duration) is applied to the duration.
	Jitter bool
	// The factor by which the duration is multiplied every time Start
	// is called.
	Factor int64
	// Func is the function that will be called when the countdown reaches 0.
	Func func()
	// Clock provides the AfterFunc function used to call Func.
	// It is exposed here so it's easier to mock it in tests.
	Clock clock.Clock
}
// Start implements the Countdown interface. It cancels any pending
// execution, schedules config.Func to run after the current duration,
// and then lengthens the duration that the next Start will use.
func (t *BackoffTimer) Start() {
	if prev := t.timer; prev != nil {
		prev.Stop()
	}
	t.timer = t.config.Clock.AfterFunc(t.currentDuration, t.config.Func)
	// Back off: every signal lengthens the following delay.
	t.increaseDuration()
}
// Reset implements the Countdown interface. It stops any pending timer
// execution and sets the duration back to config.Min, so the next Start
// begins the backoff sequence again. Start must be called to re-arm the
// timer.
func (t *BackoffTimer) Reset() {
	if t.timer != nil {
		t.timer.Stop()
	}
	// Assign unconditionally. The duration can drift below Min (a Factor
	// of 0, or a Factor of 1 combined with jitter), and Reset must still
	// restore Min as the Countdown contract promises.
	t.currentDuration = t.config.Min
}
// increaseDuration advances currentDuration by one backoff step:
//  1. Multiply it by config.Factor.
//  2. If config.Jitter is set, adjust the result by a random amount of
//     up to ±3%.
//  3. Cap the result at config.Max.
//
// If the multiplication would overflow int64, the duration saturates at
// config.Max instead of wrapping around to a bogus (possibly negative)
// value.
func (t *BackoffTimer) increaseDuration() {
	current := int64(t.currentDuration)
	factor := t.config.Factor
	product := current * factor
	if current != 0 && product/current != factor {
		// Overflow: the true value exceeds any representable Max.
		t.currentDuration = t.config.Max
		return
	}
	nextDuration := time.Duration(product)
	if t.config.Jitter {
		// Get a factor in [-1; 1).
		randFactor := (rand.Float64() * 2) - 1
		jitter := float64(nextDuration) * randFactor * 0.03
		nextDuration += time.Duration(jitter)
	}
	if nextDuration > t.config.Max {
		nextDuration = t.config.Max
	}
	t.currentDuration = nextDuration
}
================================================
FILE: timer_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Copyright 2015 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"math"
"time"
gc "gopkg.in/check.v1"
"github.com/juju/clock"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
"github.com/juju/utils/v4"
)
// TestStdTimer is the clock.Timer that mockClock.AfterFunc hands back.
// It records Stop and Reset calls on a shared stub instead of driving a
// real timer.
type TestStdTimer struct {
	stdStub *testing.Stub
}

// Stop records a "Stop" call and always reports success.
func (tm *TestStdTimer) Stop() bool {
	tm.stdStub.AddCall("Stop")
	return true
}

// Reset records a "Reset" call with its duration and always reports
// success.
func (tm *TestStdTimer) Reset(d time.Duration) bool {
	tm.stdStub.AddCall("Reset", d)
	return true
}

// Chan must never be used by BackoffTimer.
func (tm *TestStdTimer) Chan() <-chan time.Time {
	panic("should not be called")
}
// timerSuite exercises BackoffTimer against mockClock.
//   - timer is the value under test.
//   - afterFuncCalls and properFuncCalled are updated by mockClock.
//   - stub records calls made on the returned timers.
//   - min, max and factor hold the configuration that setup uses.
type timerSuite struct {
	baseSuite testing.CleanupSuite
	timer     *utils.BackoffTimer

	afterFuncCalls   int64
	properFuncCalled bool
	stub             *testing.Stub

	min    time.Duration
	max    time.Duration
	factor int64
}

var _ = gc.Suite(&timerSuite{})
// mockClock is the clock.Clock given to the timer under test. Its
// AfterFunc runs the callback synchronously, counts the call, and checks
// that the callback set properFuncCalled.
type mockClock struct {
	stub             *testing.Stub
	c                *gc.C
	afterFuncCalls   *int64
	properFuncCalled *bool
}

// Now and After are not used by these tests; they exist only to satisfy
// the clock.Clock interface.
func (m *mockClock) Now() time.Time                         { return time.Now() }
func (m *mockClock) After(d time.Duration) <-chan time.Time { return time.After(d) }

// NewTimer must never be used by BackoffTimer.
func (m *mockClock) NewTimer(d time.Duration) clock.Timer {
	panic("should not be called")
}

// AfterFunc counts the call and invokes f immediately. It then asserts
// that f set properFuncCalled, proving that the configured Func was
// passed through, and clears the flag for the next call.
func (m *mockClock) AfterFunc(d time.Duration, f func()) clock.Timer {
	*m.afterFuncCalls++
	f()
	m.c.Assert(*m.properFuncCalled, jc.IsTrue)
	*m.properFuncCalled = false
	return &TestStdTimer{m.stub}
}
// SetUpTest clears per-test state; each test builds its own stub and
// timer through setup.
func (s *timerSuite) SetUpTest(c *gc.C) {
	s.baseSuite.SetUpTest(c)
	s.timer, s.stub = nil, nil
}
// setup creates a fresh stub and a BackoffTimer that runs on mockClock.
// The timer starts at 2s, doubles on every Start, is capped at 16s, and
// has no jitter.
func (s *timerSuite) setup(c *gc.C) {
	s.afterFuncCalls = 0
	s.stub = &testing.Stub{}
	s.min, s.max, s.factor = 2*time.Second, 16*time.Second, 2

	// mockClock.AfterFunc runs its callback immediately and then asserts
	// that properFuncCalled was set. Together with the Func below, this
	// proves the timer hands exactly this function to the clock.
	clk := &mockClock{
		stub:             s.stub,
		c:                c,
		afterFuncCalls:   &s.afterFuncCalls,
		properFuncCalled: &s.properFuncCalled,
	}
	s.timer = utils.NewBackoffTimer(utils.BackoffTimerConfig{
		Min:    s.min,
		Max:    s.max,
		Jitter: false,
		Factor: s.factor,
		Func:   func() { s.properFuncCalled = true },
		Clock:  clk,
	})
}
// TestStart checks that one Start schedules exactly one AfterFunc and
// advances the duration by one backoff step.
func (s *timerSuite) TestStart(c *gc.C) {
	s.setup(c)
	s.timer.Start()
	const afterFuncCalls, steps = 1, 1
	s.testStart(c, afterFuncCalls, steps)
}
// TestMultipleStarts checks that each Start after the first stops the
// previous timer, and that the duration keeps backing off.
func (s *timerSuite) TestMultipleStarts(c *gc.C) {
	s.setup(c)
	for n := int64(1); n <= 3; n++ {
		s.timer.Start()
		if n > 1 {
			// Only Starts that replace an existing timer call Stop.
			s.checkStopCalls(c, int(n-1))
		}
		s.testStart(c, n, n)
	}
}
// TestResetNoStart checks that Reset on a never-started timer leaves the
// duration at its minimum.
func (s *timerSuite) TestResetNoStart(c *gc.C) {
	s.setup(c)
	s.timer.Reset()
	c.Assert(utils.ExposeBackoffTimerDuration(s.timer), gc.Equals, s.min)
}
// TestResetAndStart interleaves long runs of Start with Reset calls. It
// checks that Reset returns the duration to the minimum, that the backoff
// sequence restarts after each Reset, and that Stop is called the expected
// number of times.
func (s *timerSuite) TestResetAndStart(c *gc.C) {
	s.setup(c)
	// Reset before any Start: no timer exists yet, so nothing is stopped.
	s.timer.Reset()
	currentDuration := utils.ExposeBackoffTimerDuration(s.timer)
	c.Assert(currentDuration, gc.Equals, s.min)
	// These variables are used to track the number
	// of afterFuncCalls(signalCallsNo) and the number
	// of Stop calls(resetStopCallsNo + signalCallsNo)
	// The "-1" in the checks below is the very first Start: it has no
	// previous timer to stop.
	resetStopCallsNo := 0
	signalCallsNo := 0
	signalCallsNo++
	s.timer.Start()
	s.testStart(c, 1, 1)
	resetStopCallsNo++
	s.timer.Reset()
	s.checkStopCalls(c, resetStopCallsNo+signalCallsNo-1)
	currentDuration = utils.ExposeBackoffTimerDuration(s.timer)
	c.Assert(currentDuration, gc.Equals, s.min)
	// After a Reset the backoff sequence restarts, so step i of this run
	// must again see s.min*factor^i (capped at s.max).
	for i := 1; i < 200; i++ {
		signalCallsNo++
		s.timer.Start()
		s.testStart(c, int64(signalCallsNo), int64(i))
		s.checkStopCalls(c, resetStopCallsNo+signalCallsNo-1)
	}
	resetStopCallsNo++
	s.timer.Reset()
	s.checkStopCalls(c, signalCallsNo+resetStopCallsNo-1)
	for i := 1; i < 100; i++ {
		signalCallsNo++
		s.timer.Start()
		s.testStart(c, int64(signalCallsNo), int64(i))
		s.checkStopCalls(c, resetStopCallsNo+signalCallsNo-1)
	}
	resetStopCallsNo++
	s.timer.Reset()
	s.checkStopCalls(c, signalCallsNo+resetStopCallsNo-1)
}
// testStart asserts that AfterFunc has been called afterFuncCalls times,
// and that the timer's duration is now s.min*s.factor^durationFactor,
// capped at s.max.
func (s *timerSuite) testStart(c *gc.C, afterFuncCalls int64, durationFactor int64) {
	c.Assert(s.afterFuncCalls, gc.Equals, afterFuncCalls)
	c.Logf("iteration %d", afterFuncCalls)
	multiplier := math.Pow(float64(s.factor), float64(durationFactor))
	want := time.Duration(multiplier) * s.min
	// Large exponents overflow to non-positive values; like anything
	// above the cap, they must equal s.max.
	if want <= 0 || want > s.max {
		want = s.max
	}
	got := utils.ExposeBackoffTimerDuration(s.timer)
	c.Assert(got, gc.Equals, want)
}
// checkStopCalls asserts that the stub recorded exactly number calls, all
// of them to Stop.
func (s *timerSuite) checkStopCalls(c *gc.C, number int) {
	expected := make([]testing.StubCall, number)
	for i := range expected {
		expected[i] = testing.StubCall{FuncName: "Stop"}
	}
	s.stub.CheckCalls(c, expected)
}
================================================
FILE: trivial.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"io"
"os"
"strings"
"unicode"
)
// TODO(ericsnow) Move the quoting helpers into the shell package?
// ShQuote quotes s so that when read by bash, no metacharacters
// within s will be interpreted as such.
func ShQuote(s string) string {
// single-quote becomes single-quote, double-quote, single-quote, double-quote, single-quote
return `'` + strings.Replace(s, `'`, `'"'"'`, -1) + `'`
}
// WinPSQuote quotes s so that PowerShell interprets none of its
// metacharacters. Note that any single quote in s is replaced by a double
// quote, so the quoted text is not always identical to s.
func WinPSQuote(s string) string {
	// See http://ss64.com/ps/syntax-esc.html#quotes.
	// A single quote inside a single-quoted PowerShell string would need
	// doubling, whereas double quotes are literal there. The single-quoted
	// wrapper then prevents any expansion.
	return "'" + strings.ReplaceAll(s, "'", `"`) + "'"
}
// WinCmdQuote quotes s so that when read by cmd.exe, no metacharacters
// within s will be interpreted as such. The program receiving the
// command line then sees s as a single, unmodified argument.
func WinCmdQuote(s string) string {
	// See http://blogs.msdn.com/b/twistylittlepassagesallalike/archive/2011/04/23/everyone-quotes-arguments-the-wrong-way.aspx.
	// The steps are:
	//   1. Quote s for the target program's argument parser.
	//   2. Escape the result for cmd.exe itself.
	quoted := winCmdQuote(s)
	return winCmdEscapeMeta(quoted)
}

// winCmdQuote wraps s in double quotes using the MSVC runtime
// (CommandLineToArgvW) parsing rules from the article cited in
// WinCmdQuote. Under those rules, backslashes are literal unless they
// precede a double quote, so:
//   - a run of n backslashes before a '"' becomes 2n+1 backslashes, then
//     the quote;
//   - a run at the end of s becomes 2n backslashes, since it precedes
//     the closing quote;
//   - any other run is copied unchanged.
func winCmdQuote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('"')
	backslashes := 0
	for _, r := range s {
		switch r {
		case '\\':
			backslashes++
			continue
		case '"':
			b.WriteString(strings.Repeat(`\`, 2*backslashes+1))
		default:
			b.WriteString(strings.Repeat(`\`, backslashes))
		}
		backslashes = 0
		b.WriteRune(r)
	}
	b.WriteString(strings.Repeat(`\`, 2*backslashes))
	b.WriteByte('"')
	return b.String()
}

// winCmdEscapeMeta prefixes every cmd.exe metacharacter in str, including
// '"', with the '^' escape character.
func winCmdEscapeMeta(str string) string {
	const meta = `()%!^"<>&|`
	var b strings.Builder
	b.Grow(len(str))
	for _, c := range str {
		if strings.ContainsRune(meta, c) {
			b.WriteByte('^')
		}
		b.WriteRune(c)
	}
	return b.String()
}
// CommandString joins args into a single space-separated shell command
// string. Within each argument, '"', '$' and '\' are backslash-escaped.
// An argument is wrapped in double quotes only if it contains whitespace
// or one of those escaped characters.
func CommandString(args ...string) string {
	var out strings.Builder
	for i, arg := range args {
		if i != 0 {
			out.WriteByte(' ')
		}
		var escaped strings.Builder
		quote := false
		for _, r := range arg {
			switch {
			case unicode.IsSpace(r):
				quote = true
			case r == '"', r == '$', r == '\\':
				quote = true
				escaped.WriteByte('\\')
			}
			escaped.WriteRune(r)
		}
		if !quote {
			out.WriteString(escaped.String())
			continue
		}
		out.WriteByte('"')
		out.WriteString(escaped.String())
		out.WriteByte('"')
	}
	return out.String()
}
// Gzip compresses the given data.
func Gzip(data []byte) []byte {
var buf bytes.Buffer
w := gzip.NewWriter(&buf)
if _, err := w.Write(data); err != nil {
// Compression should never fail unless it fails
// to write to the underlying writer, which is a bytes.Buffer
// that never fails.
panic(err)
}
if err := w.Close(); err != nil {
panic(err)
}
return buf.Bytes()
}
// Gunzip decompresses gzip-formatted data. It returns an error if the
// header is invalid, or if the stream is corrupt or truncated.
func Gunzip(data []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(data))
	if err == nil {
		// Reading through to EOF also verifies the trailing checksum.
		return io.ReadAll(zr)
	}
	return nil, err
}
// ReadSHA256 consumes source and returns the hex-encoded SHA-256 digest
// of its contents, together with the number of bytes read. On a read
// error it returns "", 0 and the error.
func ReadSHA256(source io.Reader) (string, int64, error) {
	h := sha256.New()
	n, err := io.Copy(h, source)
	if err != nil {
		return "", 0, err
	}
	return hex.EncodeToString(h.Sum(nil)), n, nil
}

// ReadFileSHA256 is ReadSHA256 applied to the contents of the named file.
func ReadFileSHA256(filename string) (string, int64, error) {
	file, err := os.Open(filename)
	if err != nil {
		return "", 0, err
	}
	defer func() {
		// Read-only handle: there is nothing useful to do with a close error.
		_ = file.Close()
	}()
	return ReadSHA256(file)
}
================================================
FILE: trivial_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"bytes"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/juju/testing"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// utilsSuite holds the tests for the general-purpose helpers in package
// utils.
type utilsSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&utilsSuite{})

// TestCompression checks that Gzip shrinks repetitive input and that
// Gunzip restores it, both from Gzip's own output and from a fixed gzip
// stream produced independently.
func (*utilsSuite) TestCompression(c *gc.C) {
	original := []byte(strings.Repeat("some data to be compressed\n", 100))
	fixture := []byte{
		0x1f, 0x8b, 0x08, 0x00, 0x33, 0xb5, 0xf6, 0x50,
		0x00, 0x03, 0xed, 0xc9, 0xb1, 0x0d, 0x00, 0x20,
		0x08, 0x45, 0xc1, 0xde, 0x29, 0x58, 0x0d, 0xe5,
		0x97, 0x04, 0x23, 0xee, 0x1f, 0xa7, 0xb0, 0x7b,
		0xd7, 0x5e, 0x57, 0xca, 0xc2, 0xaf, 0xdb, 0x2d,
		0x9b, 0xb2, 0x55, 0xb9, 0x8f, 0xba, 0x15, 0xa3,
		0x29, 0x8a, 0xa2, 0x28, 0x8a, 0xa2, 0x28, 0xea,
		0x67, 0x3d, 0x71, 0x71, 0x6e, 0xbf, 0x8c, 0x0a,
		0x00, 0x00,
	}

	compressed := utils.Gzip(original)
	c.Assert(len(compressed) < len(original), gc.Equals, true)

	for _, stream := range [][]byte{compressed, fixture} {
		decompressed, err := utils.Gunzip(stream)
		c.Assert(err, gc.IsNil)
		c.Assert(decompressed, gc.DeepEquals, original)
	}
}
// checkQuoting applies shQuote to each key of tests and checks that the
// result equals the corresponding value.
func checkQuoting(c *gc.C, shQuote func(string) string, tests map[string]string) {
	for input, want := range tests {
		c.Logf("- checking %q -", input)
		c.Check(shQuote(input), gc.Equals, want)
	}
}

// TestWinCmdQuote checks quoting for the Windows cmd.exe shell.
func (*utilsSuite) TestWinCmdQuote(c *gc.C) {
	expectations := map[string]string{
		"":                 `^"^"`,
		"a":                `^"a^"`,
		"'a'":              `^"'a'^"`,
		`"a`:               `^"\^"a^"`,
		`a"`:               `^"a\^"^"`,
		`"a"`:              `^"\^"a\^"^"`,
		"abc > xyz 2>&1 &": `^"abc ^> xyz 2^>^&1 ^&^"`,
	}
	checkQuoting(c, utils.WinCmdQuote, expectations)
}

// TestWinPSQuote checks quoting for Windows PowerShell.
func (*utilsSuite) TestWinPSQuote(c *gc.C) {
	expectations := map[string]string{
		"":                 "''",
		"a":                `'a'`,
		`"a"`:              `'"a"'`,
		"'a":               `'"a'`,
		"a'":               `'a"'`,
		"'a'":              `'"a"'`,
		"abc > xyz 2>&1 &": "'abc > xyz 2>&1 &'",
	}
	checkQuoting(c, utils.WinPSQuote, expectations)
}
// TestCommandString checks how CommandString joins, quotes and escapes its
// arguments.
func (*utilsSuite) TestCommandString(c *gc.C) {
	for i, test := range []struct {
		args     []string
		expected string
	}{
		{nil, ""},
		{[]string{"a"}, "a"},
		{[]string{"a$"}, `"a\$"`},
		{[]string{""}, ""},
		{[]string{"\\"}, `"\\"`},
		{[]string{"a", "'b'"}, "a 'b'"},
		{[]string{"a b"}, `"a b"`},
		{[]string{"a", `"b"`}, `a "\"b\""`},
		{[]string{"a", `"b\"`}, `a "\"b\\\""`},
		{[]string{"a\n"}, "\"a\n\""},
	} {
		c.Logf("test %d: %q", i, test.args)
		c.Assert(utils.CommandString(test.args...), gc.Equals, test.expected)
	}
}
// TestReadSHA256AndReadFileSHA256 checks both hashing helpers against known
// SHA-256 digests. It also checks that hashing a file gives the same digest
// and size as hashing the same content from memory.
func (*utilsSuite) TestReadSHA256AndReadFileSHA256(c *gc.C) {
	type digestCase struct {
		content string
		sha256  string
	}
	cases := []digestCase{
		{content: "", sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"},
		{content: "some content", sha256: "290f493c44f5d63d06b374d0a5abd292fae38b92cab2fae5efefe1b0e9347f56"},
		{content: "foo", sha256: "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae"},
		{content: "Foo", sha256: "1cbec737f863e4922cee63cc2ebbfaafcd1cff8b790d8cfd2e6a5d550b648afa"},
		{content: "multi\nline\ntext\nhere", sha256: "c384f11c0294280792a44d9d6abb81f9fd991904cb7eb851a88311b04114231e"},
	}
	dir := c.MkDir()
	for i, tc := range cases {
		c.Logf("test %d: %q -> %q", i, tc.content, tc.sha256)

		// Hash from memory.
		memHash, memSize, err := utils.ReadSHA256(bytes.NewBufferString(tc.content))
		c.Check(err, gc.IsNil)
		c.Check(memHash, gc.Equals, tc.sha256)
		c.Check(int(memSize), gc.Equals, len(tc.content))

		// Hash the same content from a file.
		name := filepath.Join(dir, fmt.Sprintf("sha256-%d", i))
		err = ioutil.WriteFile(name, []byte(tc.content), 0644)
		c.Check(err, gc.IsNil)
		fileHash, fileSize, err := utils.ReadFileSHA256(name)
		c.Check(err, gc.IsNil)
		c.Check(fileHash, gc.Equals, memHash)
		c.Check(fileSize, gc.Equals, memSize)
	}
}
================================================
FILE: uptime/uptime_nix.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
//
//go:build !windows
// +build !windows
package uptime
import (
"syscall"
)
// Uptime returns the number of seconds since the system booted, as
// reported by the sysinfo system call.
//
// NOTE(review): syscall.Sysinfo is provided for Linux, while this file is
// built for every non-Windows target (!windows). Confirm that other Unix
// targets are not expected to build this package.
func Uptime() (int64, error) {
	var info syscall.Sysinfo_t
	if err := syscall.Sysinfo(&info); err != nil {
		return 0, err
	}
	return int64(info.Uptime), nil
}
================================================
FILE: uptime/uptime_windows.go
================================================
// Copyright 2014 Cloudbase Solutions SRL
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package uptime
import (
"fmt"
)
//sys getTickCount64() (uptime uint64, err error) = GetTickCount64
// Uptime returns the number of whole seconds since the system booted. The
// value is derived from GetTickCount64, which reports milliseconds.
//
// On failure, the returned error wraps the underlying system error, so
// callers can inspect it with errors.Is or errors.As.
func Uptime() (int64, error) {
	uptime, err := getTickCount64()
	if err != nil {
		// %w formats the same text as %v, but also keeps err reachable
		// through errors.Unwrap.
		return 0, fmt.Errorf("Failed to get uptime. Error number: %w", err)
	}
	return int64(uptime) / 1000, nil
}
================================================
FILE: uptime/zuptime_windows_386.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
// mksyscall_windows.pl -l32 uptime_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package uptime
import "syscall"
// Lazily resolved handles to kernel32.dll and its GetTickCount64 export.
// syscall.NewLazyDLL / NewProc defer loading the DLL and looking up the
// symbol until first use.
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetTickCount64 = modkernel32.NewProc("GetTickCount64")
)
// getTickCount64 wraps kernel32!GetTickCount64, which returns the number of
// milliseconds since the system started, as a 64-bit value.
//
// Hand-edited after generation. On windows/386, a 64-bit return value is
// passed in the EDX:EAX register pair. syscall.Syscall reports EAX as its
// first result and EDX as its second, so both halves must be combined.
// Using only the first result (as the generator did) truncates the uptime
// to 32 bits, which wraps roughly every 49.7 days.
//
// A zero result is reported as an error: the errno captured by
// syscall.Syscall if non-zero, otherwise syscall.EINVAL.
func getTickCount64() (uptime uint64, err error) {
	lo, hi, e1 := syscall.Syscall(procGetTickCount64.Addr(), 0, 0, 0, 0)
	uptime = uint64(uint32(lo)) | uint64(uint32(hi))<<32
	if uptime == 0 {
		if e1 != 0 {
			err = error(e1)
		} else {
			err = syscall.EINVAL
		}
	}
	return
}
================================================
FILE: uptime/zuptime_windows_amd64.go
================================================
// Copyright 2014 Canonical Ltd.
// Copyright 2014 Cloudbase Solutions SRL
// Licensed under the LGPLv3, see LICENCE file for details.
// mksyscall_windows.pl uptime_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package uptime
import "syscall"
// Lazily resolved handles to kernel32.dll and its GetTickCount64 export.
// syscall.NewLazyDLL / NewProc defer loading the DLL and looking up the
// symbol until first use.
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetTickCount64 = modkernel32.NewProc("GetTickCount64")
)
// getTickCount64 wraps kernel32!GetTickCount64, which returns the number of
// milliseconds since the system started. On amd64 the whole 64-bit result
// fits in the first result register, so r0 holds the complete value.
// A zero result is reported as an error: the errno captured by
// syscall.Syscall if non-zero, otherwise syscall.EINVAL.
func getTickCount64() (uptime uint64, err error) {
r0, _, e1 := syscall.Syscall(procGetTickCount64.Addr(), 0, 0, 0, 0)
uptime = uint64(r0)
if uptime == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
================================================
FILE: username.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"os"
"os/user"
"github.com/juju/errors"
)
// ResolveSudo returns the original username if sudo was used. If username
// is "root" and the SUDO_USER environment variable is non-empty, it returns
// SUDO_USER; otherwise it returns username unchanged.
func ResolveSudo(username string) string {
	return resolveSudo(username, os.Getenv)
}

// resolveSudo implements ResolveSudo, reading environment variables through
// getenvFunc so that a caller (presumably tests) can supply a fake
// environment.
func resolveSudo(username string, getenvFunc func(string) string) string {
	if username == "root" {
		// Running as root probably means sudo was used; prefer the
		// invoking user when sudo recorded one.
		if sudoUser := getenvFunc("SUDO_USER"); sudoUser != "" {
			return sudoUser
		}
	}
	return username
}
// EnvUsername returns the value of the USER environment variable, or "" if
// it is unset. The error is always nil; the signature matches the username
// functions accepted by ResolveUsername.
func EnvUsername() (string, error) {
	name := os.Getenv("USER")
	return name, nil
}
// OSUsername returns the username of the current OS user, as reported by
// os/user.Current (which looks the user up by the process's UID).
func OSUsername() (string, error) {
	current, err := user.Current()
	if err != nil {
		return "", errors.Trace(err)
	}
	return current.Username, nil
}

// ResolveUsername returns the username determined by the provided
// functions, tried in order:
//   - an error from any of them is returned immediately;
//   - an empty result means "not found here", so the next function is tried;
//   - the first non-empty result is passed through resolveSudo (if non-nil),
//     whose result replaces it unless that result is empty.
//
// If no function yields a username, an errors.NotFound error is returned.
func ResolveUsername(resolveSudo func(string) string, usernameFuncs ...func() (string, error)) (string, error) {
	for _, lookup := range usernameFuncs {
		name, err := lookup()
		switch {
		case err != nil:
			return "", errors.Trace(err)
		case name == "":
			continue
		}
		if resolveSudo == nil {
			return name, nil
		}
		if original := resolveSudo(name); original != "" {
			return original, nil
		}
		return name, nil
	}
	return "", errors.NotFoundf("username")
}
// LocalUsername determines the current username on the local host. It
// consults the USER environment variable first, then the OS user
// database. A result of "root" is mapped back to SUDO_USER when that
// variable is set.
func LocalUsername() (string, error) {
	username, err := ResolveUsername(ResolveSudo, EnvUsername, OSUsername)
	if err != nil {
		// Report only the variables that were actually consulted. Dumping
		// the whole environment would copy unrelated, possibly secret,
		// values (tokens, passwords) into error messages and logs.
		return "", errors.Annotatef(err,
			"cannot get current user from the environment (USER=%q, SUDO_USER=%q)",
			os.Getenv("USER"), os.Getenv("SUDO_USER"))
	}
	return username, nil
}
================================================
FILE: username_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"github.com/juju/errors"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// usernameSuite holds the tests for username resolution.
type usernameSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&usernameSuite{})

// TestResolveUsername drives ResolveUsername with fake username sources and
// a fake SUDO_USER lookup. It covers source precedence, sudo resolution and
// error propagation.
func (s *usernameSuite) TestResolveUsername(c *gc.C) {
	type scenario struct {
		userEnv  string
		sudoEnv  string
		userOS   string
		expected string
		err      string
	}
	scenarios := []scenario{
		{userEnv: "someone", sudoEnv: "notroot", userOS: "other", expected: "someone"},
		{userOS: "other", expected: "other"},
		{userEnv: "root", expected: "root"},
		{userEnv: "root", sudoEnv: "other", expected: "other"},
		{err: "failed to determine username for namespace: oh noes"},
	}

	run := func(sc scenario) (string, error) {
		if sc.err != "" {
			return "", errors.New(sc.err)
		}
		var sources []func() (string, error)
		for _, name := range []string{sc.userEnv, sc.userOS} {
			if name == "" {
				continue
			}
			name := name
			sources = append(sources, func() (string, error) { return name, nil })
		}
		sudo := func(username string) string {
			return utils.ResolveSudoByFunc(username, func(string) string {
				return sc.sudoEnv
			})
		}
		return utils.ResolveUsername(sudo, sources...)
	}

	for i, sc := range scenarios {
		c.Logf("test %d: %v", i, sc)
		username, err := run(sc)
		if sc.err != "" {
			c.Check(err, gc.ErrorMatches, sc.err)
			continue
		}
		if c.Check(err, jc.ErrorIsNil) {
			c.Check(username, gc.Equals, sc.expected)
		}
	}
}
================================================
FILE: uuid.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"regexp"
"strings"
)
// UUID represents a universally unique identifier of 16 octets.
type UUID [16]byte

// Regular-expression pieces for the canonical textual form of a UUID (see
// RFC 4122): 32 lowercase hexadecimal digits in groups of 8-4-4-4-12,
// separated by hyphens. Only the layout is checked. The version and variant
// are not, so any UUID version is accepted, even though this package only
// generates version 4 UUIDs.
// http://www.ietf.org/rfc/rfc4122.txt
var (
	block1 = "[0-9a-f]{8}"
	block2 = "[0-9a-f]{4}"
	block3 = "[0-9a-f]{4}"
	block4 = "[0-9a-f]{4}"
	block5 = "[0-9a-f]{12}"

	// UUIDSnippet is an unanchored regular expression matching a UUID in
	// canonical lowercase form, suitable for embedding in larger patterns.
	UUIDSnippet = block1 + "-" + block2 + "-" + block3 + "-" + block4 + "-" + block5
	validUUID   = regexp.MustCompile("^" + UUIDSnippet + "$")
)

// UUIDFromString parses s, which must be in the canonical lowercase form
// accepted by IsValidUUIDString, and returns the corresponding UUID.
func UUIDFromString(s string) (UUID, error) {
	if !IsValidUUIDString(s) {
		return UUID{}, fmt.Errorf("invalid UUID: %q", s)
	}
	s = strings.Replace(s, "-", "", 4)
	raw, err := hex.DecodeString(s)
	if err != nil {
		return UUID{}, err
	}
	var uuid UUID
	copy(uuid[:], raw)
	return uuid, nil
}

// IsValidUUIDString reports whether s is a UUID in canonical textual form:
// lowercase hex digits grouped 8-4-4-4-12 and separated by hyphens. Any
// version and variant are accepted; uppercase hex digits are rejected.
func IsValidUUIDString(s string) bool {
	return validUUID.MatchString(s)
}

// MustNewUUID is like NewUUID but panics if an error occurs.
func MustNewUUID() UUID {
	uuid, err := NewUUID()
	if err != nil {
		panic(err)
	}
	return uuid
}

// NewUUID generates a new version 4 (random) UUID from crypto/rand.
func NewUUID() (UUID, error) {
	var uuid UUID
	if _, err := io.ReadFull(rand.Reader, uuid[:]); err != nil {
		return UUID{}, err
	}
	// RFC 4122 section 4.4: the version (4) occupies the high nibble of
	// octet 6, and the variant occupies the two most significant bits of
	// octet 8 (binary 10). Every other bit stays random, so keep the low
	// six bits of octet 8, not just the low four.
	uuid[6] = 0x40 | (uuid[6] & 0x0f)
	uuid[8] = 0x80 | (uuid[8] & 0x3f)
	return uuid, nil
}

// Copy returns a copy of the UUID.
func (uuid UUID) Copy() UUID {
	uuidCopy := uuid
	return uuidCopy
}

// Raw returns a copy of the UUID bytes.
func (uuid UUID) Raw() [16]byte {
	return [16]byte(uuid)
}

// String returns the canonical lowercase hexadecimal representation, with
// hyphens separating the 8-4-4-4-12 groups.
func (uuid UUID) String() string {
	return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
}
================================================
FILE: uuid_test.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils_test
import (
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4"
)
// uuidSuite holds the tests for the UUID helpers.
type uuidSuite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&uuidSuite{})

// TestUUID checks that NewUUID produces valid, distinct UUIDs and that Copy
// and Raw return copies that are independent of the original.
func (*uuidSuite) TestUUID(c *gc.C) {
	uuid, err := utils.NewUUID()
	c.Assert(err, gc.IsNil)
	uuidCopy := uuid.Copy()
	uuidRaw := uuid.Raw()
	uuidStr := uuid.String()
	c.Assert(uuidRaw, gc.HasLen, 16)
	c.Assert(uuidStr, jc.Satisfies, utils.IsValidUUIDString)

	uuid[0] = 0x00
	uuidCopy[0] = 0xFF
	c.Assert(uuid, gc.Not(gc.DeepEquals), uuidCopy)

	uuidRaw[0] = 0xFF
	// Compare values of the same type. DeepEquals between a UUID and a
	// [16]byte is always false, which would make this check vacuous.
	c.Assert(uuid.Raw(), gc.Not(gc.DeepEquals), uuidRaw)
	c.Assert(uuid[0], gc.Equals, byte(0x00))

	nextUUID, err := utils.NewUUID()
	c.Assert(err, gc.IsNil)
	c.Assert(uuid, gc.Not(gc.DeepEquals), nextUUID)
}
// TestIsValidUUIDFailsWhenNotValid checks IsValidUUIDString against a mix of
// well-formed and malformed inputs.
func (*uuidSuite) TestIsValidUUIDFailsWhenNotValid(c *gc.C) {
	type validityCase struct {
		input    string
		expected bool
	}
	cases := []validityCase{
		{input: utils.UUID{}.String(), expected: true},
		{input: "", expected: false},
		{input: "blah", expected: false},
		{input: "blah-9f484882-2f18-4fd2-967d-db9663db7bea", expected: false},
		{input: "9f484882-2f18-4fd2-967d-db9663db7bea-blah", expected: false},
		{input: "9f484882-2f18-4fd2-967d-db9663db7bea", expected: true},
	}
	for i := range cases {
		c.Logf("Running test %d", i)
		c.Check(utils.IsValidUUIDString(cases[i].input), gc.Equals, cases[i].expected)
	}
}

// TestUUIDFromString checks that UUIDFromString rejects malformed input and
// round-trips a valid UUID string.
func (*uuidSuite) TestUUIDFromString(c *gc.C) {
	_, err := utils.UUIDFromString("blah")
	c.Assert(err, gc.ErrorMatches, `invalid UUID: "blah"`)

	const validUUID = "9f484882-2f18-4fd2-967d-db9663db7bea"
	parsed, err := utils.UUIDFromString(validUUID)
	c.Assert(err, gc.IsNil)
	c.Assert(parsed.String(), gc.Equals, validUUID)
}
================================================
FILE: voyeur/package_test.go
================================================
// Copyright 2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package voyeur
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage runs this package's gocheck suites under "go test".
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: voyeur/value.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// Package voyeur implements a concurrency-safe value that can be watched for
// changes.
package voyeur
import (
"sync"
)
// Value represents a shared value that can be watched for changes. Its
// methods may be called concurrently. The zero Value is ready to use and is
// equivalent to the result of NewValue(nil).
type Value struct {
	// mu guards every field below.
	mu sync.RWMutex
	// wait is broadcast after Set, Close and Watcher.Close. Its Locker is
	// set lazily, by init, to the read side of mu.
	wait sync.Cond

	val     any
	version int // incremented by every Set; 0 means no value yet
	closed  bool
}

// NewValue creates a new Value holding initial. If initial is nil, the
// Value starts with no value, and watchers wait until one is set.
func NewValue(initial any) *Value {
	v := &Value{}
	v.init()
	if initial == nil {
		return v
	}
	v.val, v.version = initial, 1
	return v
}

// needsInit reports whether the condition variable still lacks its Locker,
// as for a zero Value.
func (v *Value) needsInit() bool {
	return v.wait.L == nil
}

// init attaches the condition variable to mu's read locker if that has not
// been done yet. Callers hold mu for writing, or own the Value exclusively
// (as in NewValue).
func (v *Value) init() {
	if !v.needsInit() {
		return
	}
	v.wait.L = v.mu.RLocker()
}

// mutate runs change with mu held for writing, after making sure the
// condition variable is initialised. After releasing the lock it wakes
// every waiting watcher.
func (v *Value) mutate(change func()) {
	v.mu.Lock()
	v.init()
	change()
	v.mu.Unlock()
	v.wait.Broadcast()
}

// Set sets the shared value to val and wakes all watchers.
func (v *Value) Set(val any) {
	v.mutate(func() {
		v.val = val
		v.version++
	})
}

// Close closes the Value, unblocking any outstanding watchers. Close always
// returns nil and may be called more than once.
func (v *Value) Close() error {
	v.mutate(func() {
		v.closed = true
	})
	return nil
}

// Closed reports whether the value has been closed.
func (v *Value) Closed() bool {
	v.mu.RLock()
	closed := v.closed
	v.mu.RUnlock()
	return closed
}

// Get returns the current value. It is nil if no value has been set.
func (v *Value) Get() any {
	v.mu.RLock()
	current := v.val
	v.mu.RUnlock()
	return current
}
// Watch returns a new Watcher for v. A new Watcher has not seen any value
// yet, so its first Next returns immediately if v already holds one.
func (v *Value) Watch() *Watcher {
	w := &Watcher{}
	w.value = v
	return w
}

// Watcher represents a single watcher of a shared value. Next and Value
// update and read the Watcher's own fields without synchronisation, so they
// should be called from a single goroutine; Close may be called concurrently
// with Next.
type Watcher struct {
	value   *Value // the Value being watched
	version int    // version of the value last returned by Next
	current any    // value captured by the last successful Next
	closed  bool   // set by Close; guarded by value.mu
}
// Next blocks until there is a new value to be retrieved from the value that is
// being watched. It also unblocks when the value or the Watcher itself is
// closed. Next returns false if the value or the Watcher itself have been
// closed.
//
// A new value is detected by comparing the Value's version counter with the
// version this Watcher last saw. Several Sets between calls are therefore
// coalesced, and Next reports only the most recent value. If an unseen value
// is available, Next returns true even when the Value or Watcher is closed.
func (w *Watcher) Next() bool {
val := w.value
val.mu.RLock()
defer val.mu.RUnlock()
// A zero Value has no Locker on its condition variable yet. Setting one
// requires the write lock, so the read lock is briefly swapped for it. State
// may change while no lock is held; the loop below re-checks everything.
if val.needsInit() {
val.mu.RUnlock()
val.mu.Lock()
val.init()
val.mu.Unlock()
val.mu.RLock()
}
// Each iteration re-checks the state under the read lock. Wait returns after
// a Broadcast, which comes from:
//   - Value.Set, which increments the version;
//   - Value.Close, which sets the Value's closed flag;
//   - Watcher.Close, which sets that Watcher's closed flag.
// Closing a *different* Watcher on the same Value also wakes this one without
// changing anything it checks. In that case it simply waits again, so the
// loop may run more than twice.
for {
if w.version != val.version {
w.version = val.version
w.current = val.val
return true
}
if val.closed || w.closed {
return false
}
// Wait releases the lock until triggered and then reacquires the lock,
// thus avoiding a deadlock.
val.wait.Wait()
}
}
// Close closes the Watcher without closing the underlying Value. Afterwards,
// Next on this Watcher returns false once no unseen value remains. Close may
// be called concurrently with Next.
func (w *Watcher) Close() {
	v := w.value
	v.mu.Lock()
	v.init()
	w.closed = true
	v.mu.Unlock()
	// Wake every waiter on the Value; each one re-checks its own state.
	v.wait.Broadcast()
}

// Value returns the value captured by the most recent call to Next that
// returned true, or nil if there has been none.
func (w *Watcher) Value() any {
	return w.current
}
================================================
FILE: voyeur/value_test.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package voyeur
import (
"fmt"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
)
// suite holds the gocheck tests for package voyeur.
type suite struct {
	testing.IsolationSuite
}

var _ = gc.Suite(&suite{})

func ExampleWatcher_Next() {
	v := NewValue(nil)

	// step only makes the example's output deterministic; ordinary use of a
	// watcher does not need it.
	step := make(chan bool)
	go func() {
		for i := 0; i < 3; i++ {
			v.Set(fmt.Sprintf("value%d", i))
			step <- true
		}
		v.Close()
	}()

	w := v.Watch()
	for w.Next() {
		fmt.Println(w.Value())
		<-step
	}
	// output:
	// value0
	// value1
	// value2
}
// TestValueGetSet checks that Get returns what Set stored.
func (s *suite) TestValueGetSet(c *gc.C) {
	const expected = "12345"
	v := NewValue(nil)
	v.Set(expected)
	c.Assert(v.Get(), gc.Equals, expected)
	c.Assert(v.Closed(), jc.IsFalse)
}

// TestValueInitial checks that NewValue's argument is the initial value.
func (s *suite) TestValueInitial(c *gc.C) {
	const expected = "12345"
	v := NewValue(expected)
	c.Assert(v.Get(), gc.Equals, expected)
	c.Assert(v.Closed(), jc.IsFalse)
}

// TestValueClose checks that closing keeps the value and can be repeated.
func (s *suite) TestValueClose(c *gc.C) {
	const expected = "12345"
	v := NewValue(expected)
	c.Assert(v.Close(), gc.IsNil)
	c.Assert(v.Closed(), jc.IsTrue)
	c.Assert(v.Get(), gc.Equals, expected)
	// Closing a second time must also succeed.
	c.Assert(v.Close(), gc.IsNil)
}
// TestWatcher checks that a watcher sees successive values, may advance
// without reading a value, and reports false once the Value is closed.
func (s *suite) TestWatcher(c *gc.C) {
	vals := []string{"one", "two", "three"}
	// Blocking on step makes the scheduler let the setter goroutine run, so
	// the results are predictable. Normal watchers do not need this.
	step := make(chan bool)
	v := NewValue(nil)
	go func() {
		for _, val := range vals {
			v.Set(val)
			step <- true
		}
		v.Close()
	}()
	w := v.Watch()

	c.Assert(w.Next(), jc.IsTrue)
	c.Assert(w.Value(), gc.Equals, vals[0])
	// Value can be read repeatedly without advancing.
	c.Assert(w.Value(), gc.Equals, vals[0])
	<-step

	// Advance without reading the value.
	c.Assert(w.Next(), jc.IsTrue)
	<-step

	c.Assert(w.Next(), jc.IsTrue)
	c.Assert(w.Value(), gc.Equals, vals[2])
	<-step

	c.Assert(w.Next(), jc.IsFalse)
}
// TestDoubleSet checks that when several values are set between two calls to
// Next, the watcher sees only the most recent one.
func (s *suite) TestDoubleSet(c *gc.C) {
	vals := []string{"one", "two", "three"}
	// blocking on the channel forces the scheduler to let the other goroutine
	// run for a bit, so we get predictable results. This is not necessary for
	// normal use of the watcher.
	ch := make(chan bool)
	v := NewValue(nil)
	go func() {
		v.Set(vals[0])
		ch <- true
		v.Set(vals[1])
		v.Set(vals[2])
		ch <- true
		v.Close()
		ch <- true
	}()
	w := v.Watch()
	c.Assert(w.Next(), jc.IsTrue)
	c.Assert(w.Value(), gc.Equals, vals[0])
	<-ch
	// Wait for the second send, which happens only after both vals[1] and
	// vals[2] have been set. Without it, Next could run between the two Sets
	// and observe vals[1].
	<-ch
	c.Assert(w.Next(), jc.IsTrue)
	c.Assert(w.Value(), gc.Equals, vals[2])
	// Take the final send, which follows Close, so the goroutine can exit;
	// then the watcher must report the close.
	<-ch
	c.Assert(w.Next(), jc.IsFalse)
}
// TestTwoReceivers checks that two independent watchers on the same Value
// each see every value, and then see the close.
func (s *suite) TestTwoReceivers(c *gc.C) {
	vals := []string{"one", "two", "three"}
	// blocking on the channel forces the scheduler to let the other goroutine
	// run for a bit, so we get predictable results. This is not necessary for
	// normal use of the watcher.
	ch := make(chan bool)
	v := NewValue(nil)
	watcher := func() {
		w := v.Watch()
		x := 0
		for w.Next() {
			// Use Check, not Assert, off the test goroutine. Assert stops only
			// the calling goroutine on failure, which would leave the test
			// blocked on ch forever instead of reporting the failure.
			c.Check(w.Value(), gc.Equals, vals[x])
			x++
			<-ch
		}
		c.Check(x, gc.Equals, len(vals))
		<-ch
	}
	go watcher()
	go watcher()
	for _, val := range vals {
		v.Set(val)
		ch <- true
		ch <- true
	}
	v.Close()
	ch <- true
	ch <- true
}
// TestCloseWatcher checks that closing a Watcher ends its Next loop without
// closing the underlying Value.
func (s *suite) TestCloseWatcher(c *gc.C) {
	vals := []string{"one", "two", "three"}
	// blocking on the channel forces the scheduler to let the other goroutine
	// run for a bit, so we get predictable results. This is not necessary for
	// normal use of the watcher.
	ch := make(chan bool)
	v := NewValue(nil)
	w := v.Watch()
	go func() {
		x := 0
		for w.Next() {
			// Use Check, not Assert, off the test goroutine. Assert stops only
			// this goroutine on failure, leaving the test blocked on ch.
			c.Check(w.Value(), gc.Equals, vals[x])
			x++
			<-ch
		}
		// the value will only get set once before the watcher is closed
		c.Check(x, gc.Equals, 1)
		<-ch
	}()
	v.Set(vals[0])
	ch <- true
	w.Close()
	ch <- true
	// prove the value is not closed, even though the watcher is
	c.Assert(v.Closed(), jc.IsFalse)
}
// TestWatchZeroValue checks that watching a zero Value works, and that Next
// wakes once a value is set.
func (s *suite) TestWatchZeroValue(c *gc.C) {
	var v Value
	ready := make(chan bool)
	go func() {
		w := v.Watch()
		ready <- true
		ready <- w.Next()
	}()
	<-ready
	v.Set(struct{}{})
	c.Assert(<-ready, jc.IsTrue)
}
================================================
FILE: yaml.go
================================================
// Copyright 2012, 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"io/ioutil"
"os"
"path/filepath"
"github.com/juju/errors"
"gopkg.in/yaml.v2"
)
// WriteYaml marshals obj as YAML to a temporary file in the same directory
// as path, then atomically replaces path with the temporary file. The file
// left at path has mode 0644. If any step fails, the temporary file is
// removed, so no partially written file is left on disk.
func WriteYaml(path string, obj any) error {
	data, err := yaml.Marshal(obj)
	if err != nil {
		return errors.Trace(err)
	}
	dir := filepath.Dir(path)
	f, err := os.CreateTemp(dir, "juju")
	if err != nil {
		return errors.Trace(err)
	}
	tmp := f.Name()
	if _, err := f.Write(data); err != nil {
		_ = f.Close()      // don't leak file handle
		_ = os.Remove(tmp) // don't leak half written files on disk
		return errors.Trace(err)
	}
	if err := f.Sync(); err != nil {
		_ = f.Close()      // don't leak file handle
		_ = os.Remove(tmp) // don't leak half written files on disk
		return errors.Trace(err)
	}
	// Explicitly close the file before moving it. This is needed on Windows
	// where the OS will not allow us to move a file that still has an open
	// file handle. Must check the error on close because filesystems can delay
	// reporting errors until the file is closed.
	if err := f.Close(); err != nil {
		_ = os.Remove(tmp) // don't leak half written files on disk
		return errors.Trace(err)
	}
	// os.CreateTemp creates files with mode 0600, but this function has a
	// contract that files will be world readable (0644) after replacement.
	if err := os.Chmod(tmp, 0644); err != nil {
		_ = os.Remove(tmp) // remove file with incorrect permissions.
		return errors.Trace(err)
	}
	if err := ReplaceFile(tmp, path); err != nil {
		// The replacement presumably did not happen, so the temporary file
		// is still in dir; remove it rather than leaving it behind.
		_ = os.Remove(tmp)
		// Returned unwrapped, as before, so callers see the original error.
		return err
	}
	return nil
}
// ReadYaml unmarshals the YAML in the file at path into obj, as
// yaml.Unmarshal does. If the file cannot be read, the error is returned
// unwrapped. In particular, a missing file yields an error for which
// os.IsNotExist reports true.
func ReadYaml(path string, obj any) error {
	data, err := ioutil.ReadFile(path)
	if err == nil {
		return yaml.Unmarshal(data, obj)
	}
	// Do not wrap: callers check this error with os.IsNotExist.
	return err
}
// ConformYAML ensures all keys of any nested maps are strings. This is
// necessary because YAML unmarshals nested maps as map[any]any, which
// cannot be serialized by json or bson.
//
// It walks input recursively:
//   - map[any]any and map[string]any values become map[string]any;
//   - []any values are copied element by element;
//   - anything else is returned unchanged.
//
// It returns an error ("map keyed with non-string value") if any nested map
// has a non-string key. cf. gopkg.in/juju/charm.v4/actions.go cleanse
func ConformYAML(input any) (any, error) {
	switch typedInput := input.(type) {
	case map[string]any:
		newMap := make(map[string]any, len(typedInput))
		for key, value := range typedInput {
			newValue, err := ConformYAML(value)
			if err != nil {
				return nil, err
			}
			newMap[key] = newValue
		}
		return newMap, nil
	case map[any]any:
		// Convert the keys and the nested values in a single pass, rather
		// than copying the map and then walking the copy again.
		newMap := make(map[string]any, len(typedInput))
		for key, value := range typedInput {
			typedKey, ok := key.(string)
			if !ok {
				return nil, errors.New("map keyed with non-string value")
			}
			newValue, err := ConformYAML(value)
			if err != nil {
				return nil, err
			}
			newMap[typedKey] = newValue
		}
		return newMap, nil
	case []any:
		newSlice := make([]any, len(typedInput))
		for i, sliceValue := range typedInput {
			newSliceValue, err := ConformYAML(sliceValue)
			if err != nil {
				// Propagate the nested error instead of replacing it.
				return nil, err
			}
			newSlice[i] = newSliceValue
		}
		return newSlice, nil
	default:
		return input, nil
	}
}
================================================
FILE: yaml_test.go
================================================
// Copyright 2015 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package utils
import (
"io/ioutil"
"os"
"path/filepath"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
)
// yamlSuite holds the tests for WriteYaml and ReadYaml.
type yamlSuite struct{}

var _ = gc.Suite(&yamlSuite{})

// TestYamlRoundTrip checks that a value written with WriteYaml reads back
// unchanged with ReadYaml.
func (*yamlSuite) TestYamlRoundTrip(c *gc.C) {
	type T struct {
		A int    `yaml:"a"`
		B bool   `yaml:"deleted"`
		C string `yaml:"omitempty"`
		D string
	}
	want := T{A: 1, B: true, C: "", D: ""}

	tmp, err := ioutil.TempFile(c.MkDir(), "yaml")
	c.Assert(err, gc.IsNil)
	path := tmp.Name()
	tmp.Close()

	c.Assert(WriteYaml(path, want), gc.IsNil)
	var got T
	c.Assert(ReadYaml(path, &got), gc.IsNil)
	c.Assert(want, gc.Equals, got)
}
// TestReadYamlReturnsNotFound checks that reading a missing file gives an
// error that os.IsNotExist recognises.
func (*yamlSuite) TestReadYamlReturnsNotFound(c *gc.C) {
	// The contract for ReadYaml requires it returns an error
	// that can be inspected by os.IsNotExist. Notably, we cannot
	// use juju/errors gift wrapping.
	f, err := ioutil.TempFile(c.MkDir(), "yaml")
	c.Assert(err, gc.IsNil)
	path := f.Name()
	// Close before removing: an open handle leaks a descriptor, and on
	// Windows it prevents the file from being removed at all.
	c.Assert(f.Close(), gc.IsNil)
	err = os.Remove(path)
	c.Assert(err, gc.IsNil)
	err = ReadYaml(path, nil)
	// assert that the error is reported as NotExist
	c.Assert(os.IsNotExist(err), gc.Equals, true)
}
// TestWriteYamlMissingDirectory checks that WriteYaml fails when the target
// directory does not exist, since it creates its temporary file alongside
// the target.
func (*yamlSuite) TestWriteYamlMissingDirectory(c *gc.C) {
	target := filepath.Join(c.MkDir(), "missing", "filename")
	payload := struct{ A, B int }{1, 2}
	c.Assert(WriteYaml(target, payload), gc.NotNil)
}

// TestWriteYamlWriteGarbage checks that WriteYaml reports values that
// cannot be marshalled to YAML.
func (*yamlSuite) TestWriteYamlWriteGarbage(c *gc.C) {
	c.Skip("https://github.com/go-yaml/yaml/issues/144")
	target := filepath.Join(c.MkDir(), "f")
	payload := struct{ A, B [10]bool }{}
	c.Assert(WriteYaml(target, payload), gc.NotNil)
}
// ConformSuite holds the tests for ConformYAML.
type ConformSuite struct{}

var _ = gc.Suite(&ConformSuite{})

// TestConformYAML checks that ConformYAML converts nested map[any]any values
// to map[string]any, including inside slices, and that it rejects maps with
// non-string keys.
func (s *ConformSuite) TestConformYAML(c *gc.C) {
	type conformCase struct {
		description       string
		inputInterface    any
		expectedInterface map[string]any
		expectedError     string
	}
	conformTests := []conformCase{{
		description: "An interface requiring no changes.",
		inputInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[string]any{
				"foo1": "val1",
				"foo2": "val2"}},
		expectedInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[string]any{
				"foo1": "val1",
				"foo2": "val2"}},
	}, {
		description: "Substitute a single inner map[i]i.",
		inputInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[any]any{
				"foo1": "val1",
				"foo2": "val2"}},
		expectedInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[string]any{
				"foo1": "val1",
				"foo2": "val2"}},
	}, {
		description: "Substitute nested inner map[i]i.",
		inputInterface: map[string]any{
			"key1a": "val1a",
			"key2a": "val2a",
			"key3a": map[any]any{
				"key1b": "val1b",
				"key2b": map[any]any{
					"key1c": "val1c"}}},
		expectedInterface: map[string]any{
			"key1a": "val1a",
			"key2a": "val2a",
			"key3a": map[string]any{
				"key1b": "val1b",
				"key2b": map[string]any{
					"key1c": "val1c"}}},
	}, {
		description: "Substitute nested map[i]i within []i.",
		inputInterface: map[string]any{
			"key1a": "val1a",
			"key2a": []any{5, "foo", map[string]any{
				"key1b": "val1b",
				"key2b": map[any]any{
					"key1c": "val1c"}}}},
		expectedInterface: map[string]any{
			"key1a": "val1a",
			"key2a": []any{5, "foo", map[string]any{
				"key1b": "val1b",
				"key2b": map[string]any{
					"key1c": "val1c"}}}},
	}, {
		description: "An inner map[any]any with an int key.",
		inputInterface: map[string]any{
			"key1": "value1",
			"key2": "value2",
			"key3": map[any]any{
				"foo1": "val1",
				5:      "val2"}},
		expectedError: "map keyed with non-string value",
	}, {
		description: "An inner []any containing a map[i]i with an int key.",
		inputInterface: map[string]any{
			"key1a": "val1b",
			"key2a": "val2b",
			"key3a": []any{"foo1", 5, map[any]any{
				"key1b": "val1b",
				"key2b": map[any]any{
					"key1c": "val1c",
					5:       "val2c"}}}},
		expectedError: "map keyed with non-string value",
	}}

	for i, test := range conformTests {
		c.Logf("test %d: %s", i, test.description)
		conformed, err := ConformYAML(test.inputInterface)
		if test.expectedError != "" {
			c.Check(err, gc.ErrorMatches, test.expectedError)
			continue
		}
		if c.Check(err, jc.ErrorIsNil) {
			c.Check(conformed, jc.DeepEquals, test.expectedInterface)
		}
	}
}
================================================
FILE: zfile_windows.go
================================================
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
// mksyscall_windows.pl -l32 file_windows.go
// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT
package utils
import "unsafe"
import "syscall"
// Lazily resolved handles to kernel32.dll and its MoveFileExW export.
// syscall.NewLazyDLL / NewProc defer loading the DLL and looking up the
// symbol until first use.
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procMoveFileExW = modkernel32.NewProc("MoveFileExW")
)
// moveFileEx wraps kernel32!MoveFileExW, moving lpExistingFileName to
// lpNewFileName according to the MOVEFILE_* bits in dwFlags. Both names are
// pointers to UTF-16 strings, as required by the W variant of the API.
// NOTE(review): presumably they are produced with syscall.UTF16PtrFromString
// at the call site — confirm.
// A zero return from the API is reported as an error: the errno captured by
// syscall.Syscall if non-zero, otherwise syscall.EINVAL.
func moveFileEx(lpExistingFileName *uint16, lpNewFileName *uint16, dwFlags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procMoveFileExW.Addr(), 3, uintptr(unsafe.Pointer(lpExistingFileName)), uintptr(unsafe.Pointer(lpNewFileName)), uintptr(dwFlags))
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = syscall.EINVAL
}
}
return
}
================================================
FILE: zip/package_test.go
================================================
// Copyright 2011-2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package zip_test
import (
"testing"
gc "gopkg.in/check.v1"
)
// TestPackage runs this package's gocheck suites under "go test".
func TestPackage(t *testing.T) {
	gc.TestingT(t)
}
================================================
FILE: zip/zip.go
================================================
// Copyright 2011-2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package zip
import (
"archive/zip"
"bytes"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
)
// FindAll returns the cleaned path of every entry in the supplied zip
// reader, in archive order.
func FindAll(reader *zip.Reader) ([]string, error) {
	return Find(reader, "*")
}

// Find returns the cleaned path (as by path.Clean) of every entry in the
// supplied zip reader whose base name matches pattern, interpreted as in
// path.Match. Matches are returned in archive order. The only possible
// error is a malformed pattern.
func Find(reader *zip.Reader, pattern string) ([]string, error) {
	// Validate the pattern once up front against a non-empty name, so that a
	// malformed pattern is reported even when the archive has no entries.
	if _, err := path.Match(pattern, "check"); err != nil {
		return nil, err
	}
	var found []string
	for _, entry := range reader.File {
		name := path.Clean(entry.Name)
		// The pattern is known to be valid, so Match cannot fail here.
		ok, _ := path.Match(pattern, path.Base(name))
		if !ok {
			continue
		}
		found = append(found, name)
	}
	return found, nil
}
// ExtractAll extracts the supplied zip reader to the target path, overwriting
// existing files and directories only where necessary.
func ExtractAll(reader *zip.Reader, targetRoot string) error {
return Extract(reader, targetRoot, "")
}
// Extract extracts files from the supplied zip reader, from the (internal, slash-
// separated) source path into the (external, OS-specific) target path. If the
// source path does not reference a directory, the referenced file will be written
// directly to the target path.
//
// An empty source path (or ".") selects the whole archive. A source path that
// climbs above the archive root (".." or "../...") is rejected before anything
// is written.
//
// Extraction stops at the first entry that fails; entries already written are
// left in place. The returned error names the failing entry and wraps the
// underlying cause, so callers can inspect it with errors.Is / errors.As.
func Extract(reader *zip.Reader, targetRoot, sourceRoot string) error {
	sourceRoot = path.Clean(sourceRoot)
	if sourceRoot == "." {
		sourceRoot = ""
	}
	if !isSanePath(sourceRoot) {
		return fmt.Errorf("cannot extract files rooted at %q", sourceRoot)
	}
	// Named x rather than extractor so the local does not shadow the type.
	x := extractor{targetRoot: targetRoot, sourceRoot: sourceRoot}
	for _, zipFile := range reader.File {
		if err := x.extract(zipFile); err != nil {
			// %w keeps the message identical to %v while preserving the cause.
			return fmt.Errorf("cannot extract %q: %w", path.Clean(zipFile.Name), err)
		}
	}
	return nil
}
// extractor carries the configuration of one Extract call. Entries under
// sourceRoot (a cleaned, slash-separated archive path, or "" for the whole
// archive) are written beneath targetRoot (an OS-specific path).
type extractor struct {
	targetRoot string
	sourceRoot string
}

// targetPath returns the target path for a given zip file and whether
// it should be extracted.
//
// Archive names are sanitised before use. A leading "/" and any leading ".."
// elements are dropped, so names such as "/etc/x", "../../x" or a bare ".."
// cannot address anything above targetRoot. As a final guard, an entry whose
// joined path still falls outside targetRoot is not extracted. That can
// happen with OS-specific separators path.Clean does not understand, such as
// "..\x" on Windows.
func (x extractor) targetPath(zipFile *zip.File) (string, bool) {
	cleanPath := path.Clean(zipFile.Name)
	if cleanPath == x.sourceRoot {
		return x.targetRoot, true
	}
	cleanPath = strings.TrimPrefix(cleanPath, "/")
	// path.Clean leaves ".." elements only at the start of a relative path,
	// so stripping them there is enough.
	for strings.HasPrefix(cleanPath, "../") {
		cleanPath = cleanPath[len("../"):]
	}
	// A bare ".." has no trailing slash and would otherwise resolve to
	// targetRoot's parent; treat it like the other stripped prefixes.
	if cleanPath == ".." {
		cleanPath = ""
	}
	if x.sourceRoot != "" {
		mustPrefix := x.sourceRoot + "/"
		if !strings.HasPrefix(cleanPath, mustPrefix) {
			return "", false
		}
		cleanPath = cleanPath[len(mustPrefix):]
	}
	target := filepath.Join(x.targetRoot, filepath.FromSlash(cleanPath))
	rel, err := filepath.Rel(x.targetRoot, target)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", false
	}
	return target, true
}
// extract writes a single archive entry beneath x.targetRoot.
//
// Entries outside x.sourceRoot are skipped. Missing parent directories are
// created first. Directories, symlinks and regular files are dispatched to
// their writers; any other file type is an error.
func (x extractor) extract(zipFile *zip.File) error {
	dest, wanted := x.targetPath(zipFile)
	if !wanted {
		return nil
	}
	if err := os.MkdirAll(filepath.Dir(dest), 0777); err != nil {
		return err
	}
	mode := zipFile.Mode()
	perm, kind := mode&os.ModePerm, mode&os.ModeType
	switch kind {
	case 0:
		return x.writeFile(dest, zipFile, perm)
	case os.ModeDir:
		return x.writeDir(dest, perm)
	case os.ModeSymlink:
		return x.writeSymlink(dest, zipFile)
	default:
		return fmt.Errorf("unknown file type %d", kind)
	}
}
// writeDir ensures a directory with permissions modePerm exists at targetPath.
//
// An existing directory is kept (merged into) and only chmod'ed if its
// permissions differ. Anything else at targetPath is removed first, as is the
// case where Lstat fails for a reason other than non-existence.
//
// NOTE(review): a newly created directory goes through os.MkdirAll, whose
// permissions are filtered by the process umask, whereas an existing one is
// chmod'ed to exactly modePerm.
func (x extractor) writeDir(targetPath string, modePerm os.FileMode) error {
	info, statErr := os.Lstat(targetPath)
	if statErr == nil && info.Mode().IsDir() {
		if info.Mode()&os.ModePerm == modePerm {
			return nil
		}
		return os.Chmod(targetPath, modePerm)
	}
	if !os.IsNotExist(statErr) {
		if err := os.RemoveAll(targetPath); err != nil {
			return err
		}
	}
	return os.MkdirAll(targetPath, modePerm)
}
// writeFile replaces whatever is at targetPath with a new regular file
// holding zipFile's contents, created with permissions modePerm (subject to
// the process umask).
//
// The data is synced before the file is closed, and the Close error is
// reported.
func (x extractor) writeFile(targetPath string, zipFile *zip.File, modePerm os.FileMode) error {
	if _, statErr := os.Lstat(targetPath); !os.IsNotExist(statErr) {
		if err := os.RemoveAll(targetPath); err != nil {
			return err
		}
	}
	// O_EXCL: the path was just cleared, so refuse to reuse anything that
	// reappeared there in the meantime.
	out, err := os.OpenFile(targetPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, modePerm)
	if err != nil {
		return err
	}
	// Safety net for the error paths. On success the explicit Close below
	// has already run and this second call's error is discarded.
	defer out.Close()
	if err = copyTo(out, zipFile); err == nil {
		err = out.Sync()
	}
	if err == nil {
		err = out.Close()
	}
	return err
}
// writeSymlink validates the link target stored in zipFile (see
// checkSymlink), replaces whatever currently occupies targetPath, and
// creates the symlink there.
func (x extractor) writeSymlink(targetPath string, zipFile *zip.File) error {
	linkTo, err := x.checkSymlink(targetPath, zipFile)
	if err != nil {
		return err
	}
	_, statErr := os.Lstat(targetPath)
	if os.IsNotExist(statErr) {
		return os.Symlink(linkTo, targetPath)
	}
	if err := os.RemoveAll(targetPath); err != nil {
		return err
	}
	return os.Symlink(linkTo, targetPath)
}
// checkSymlink reads the link target stored as zipFile's contents and
// returns it, provided it is relative and resolves (lexically, from the
// directory containing targetPath) to somewhere inside x.targetRoot.
func (x extractor) checkSymlink(targetPath string, zipFile *zip.File) (string, error) {
	var content bytes.Buffer
	if err := copyTo(&content, zipFile); err != nil {
		return "", err
	}
	link := content.String()
	if filepath.IsAbs(link) {
		return "", fmt.Errorf("symlink %q is absolute", link)
	}
	resolved := filepath.Join(filepath.Dir(targetPath), link)
	rel, err := filepath.Rel(x.targetRoot, resolved)
	switch {
	case err != nil:
		// Not tested, because I don't know how to trigger this condition.
		return "", fmt.Errorf("symlink %q not comprehensible", link)
	case !isSanePath(rel):
		return "", fmt.Errorf("symlink %q leads out of scope", link)
	}
	return link, nil
}
// copyTo streams the decompressed contents of zipFile into writer. The entry
// reader is always closed; only open and copy errors are reported.
func copyTo(writer io.Writer, zipFile *zip.File) error {
	src, err := zipFile.Open()
	if err != nil {
		return err
	}
	defer src.Close()
	_, err = io.Copy(writer, src)
	return err
}
// isSanePath reports whether the already-cleaned relative path stays within
// its base directory, i.e. it is neither ".." nor starts with a ".." element.
// Cleaned paths can only contain ".." as leading elements, so checking the
// prefix is sufficient.
//
// The path may be slash-separated (an archive path from path.Clean) or use
// the OS separator (the result of filepath.Rel). It is normalised to "/"
// first so the check also holds on Windows, where filepath.Rel yields "..\x".
func isSanePath(path string) bool {
	p := filepath.ToSlash(path)
	return p != ".." && !strings.HasPrefix(p, "../")
}
================================================
FILE: zip/zip_test.go
================================================
// Copyright 2011-2014 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
package zip_test
import (
stdzip "archive/zip"
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"sort"
"github.com/juju/testing"
jc "github.com/juju/testing/checkers"
ft "github.com/juju/testing/filetesting"
gc "gopkg.in/check.v1"
"github.com/juju/utils/v4/zip"
)
// ZipSuite holds the gocheck tests for the zip package. The embedded
// IsolationSuite supplies, among other things, the AddCleanup hook that
// makeZip uses to close archive files.
type ZipSuite struct{ testing.IsolationSuite }

// Register the suite with gocheck so that TestPackage runs it.
var _ = gc.Suite(&ZipSuite{})
// makeZip creates the given entries in a scratch directory and archives that
// directory's contents with the external zip(1) command. --symlinks stores
// links as links. It returns a reader over the resulting archive.
//
// The scratch directory is removed before makeZip returns. The archive file
// stays open until the cleanup registered via AddCleanup closes it.
func (s *ZipSuite) makeZip(c *gc.C, entries ...ft.Entry) *stdzip.Reader {
	basePath := c.MkDir()
	for _, entry := range entries {
		entry.Create(c, basePath)
	}
	defer os.RemoveAll(basePath)
	outPath := filepath.Join(c.MkDir(), "test.zip")
	// Chain with "&&" rather than ";". If the cd ever fails (for instance
	// because %q's Go-style quoting is not valid shell quoting for some
	// path), zip must not go on to archive the test's working directory.
	cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("cd %q && zip --fifo --symlinks -r %q .", basePath, outPath))
	output, err := cmd.CombinedOutput()
	c.Assert(err, gc.IsNil, gc.Commentf("Command output: %s", output))
	file, err := os.Open(outPath)
	c.Assert(err, gc.IsNil)
	s.AddCleanup(func(c *gc.C) {
		err := file.Close()
		c.Assert(err, gc.IsNil)
	})
	fileInfo, err := file.Stat()
	c.Assert(err, gc.IsNil)
	reader, err := stdzip.NewReader(file, fileInfo.Size())
	c.Assert(err, gc.IsNil)
	return reader
}
// TestFind checks that Find matches patterns against base names only and
// returns cleaned full paths, and that FindAll agrees with Find("*").
func (s *ZipSuite) TestFind(c *gc.C) {
	archive := s.makeZip(c,
		ft.File{"some-file", "", 0644},
		ft.File{"another-file", "", 0644},
		ft.Symlink{"some-symlink", "some-file"},
		ft.Dir{"some-dir", 0755},
		ft.Dir{"some-dir/another-dir", 0755},
		ft.File{"some-dir/another-file", "", 0644},
	)
	cases := []struct {
		pattern string
		expect  []string
	}{
		{"", nil},
		{"no-matches", nil},
		{"some-file", []string{"some-file"}},
		{"another-file", []string{"another-file", "some-dir/another-file"}},
		{"some-*", []string{"some-file", "some-symlink", "some-dir"}},
		{"another-*", []string{"another-file", "some-dir/another-dir", "some-dir/another-file"}},
		{"*", []string{
			"some-file",
			"another-file",
			"some-symlink",
			"some-dir",
			"some-dir/another-dir",
			"some-dir/another-file",
		}},
	}
	for i, tc := range cases {
		c.Logf("test %d: %q", i, tc.pattern)
		got, err := zip.Find(archive, tc.pattern)
		c.Assert(err, gc.IsNil)
		// Archive order is not part of the contract; compare sorted.
		sort.Strings(tc.expect)
		sort.Strings(got)
		c.Check(got, jc.DeepEquals, tc.expect)
	}

	c.Logf("test $spanish-inquisition: FindAll")
	want, err := zip.Find(archive, "*")
	c.Assert(err, gc.IsNil)
	got, err := zip.FindAll(archive)
	c.Assert(err, gc.IsNil)
	sort.Strings(want)
	sort.Strings(got)
	c.Check(got, jc.DeepEquals, want)
}
// TestFindError checks that a malformed pattern is reported as an error
// rather than silently matching nothing.
func (s *ZipSuite) TestFindError(c *gc.C) {
	archive := s.makeZip(c, ft.File{"some-file", "", 0644})
	// A character class opening with "]" is malformed for path.Match.
	_, err := zip.Find(archive, "[]")
	c.Assert(err, gc.ErrorMatches, "syntax error in pattern")
}
// TestExtractAll checks that files with differing permissions, directories
// and relative symlinks at several depths all survive a round trip through
// ExtractAll unchanged.
func (s *ZipSuite) TestExtractAll(c *gc.C) {
	want := []ft.Entry{
		ft.File{"some-file", "content 1", 0644},
		ft.File{"another-file", "content 2", 0640},
		ft.Symlink{"some-symlink", "some-file"},
		ft.Dir{"some-dir", 0750},
		ft.File{"some-dir/another-file", "content 3", 0644},
		ft.Dir{"some-dir/another-dir", 0755},
		ft.Symlink{"some-dir/another-dir/another-symlink", "../../another-file"},
	}
	archive := s.makeZip(c, want...)
	dest := c.MkDir()
	c.Assert(zip.ExtractAll(archive, dest), gc.IsNil)
	for i := range want {
		c.Logf("test %d: %#v", i, want[i])
		want[i].Check(c, dest)
	}
}
// TestExtractAllOverwriteFiles checks that an existing regular file can be
// replaced by a file, a directory or a symlink from the archive.
func (s *ZipSuite) TestExtractAllOverwriteFiles(c *gc.C) {
	const name = "some-file"
	replacements := []ft.Entry{
		ft.File{name, "content", 0644},
		ft.Dir{name, 0751},
		ft.Symlink{name, "wherever"},
	}
	for i, replacement := range replacements {
		c.Logf("test %d: %#v", i, replacement)
		root := c.MkDir()
		ft.File{name, "original", 0}.Create(c, root)
		err := zip.ExtractAll(s.makeZip(c, replacement), root)
		c.Check(err, gc.IsNil)
		replacement.Check(c, root)
	}
}
// TestExtractAllOverwriteSymlinks checks that an existing symlink can be
// replaced by a file, a directory or another symlink. It also checks that
// the file the old link pointed at is left untouched.
func (s *ZipSuite) TestExtractAllOverwriteSymlinks(c *gc.C) {
	const name = "some-symlink"
	replacements := []ft.Entry{
		ft.File{name, "content", 0644},
		ft.Dir{name, 0751},
		ft.Symlink{name, "wherever"},
	}
	for i, replacement := range replacements {
		c.Logf("test %d: %#v", i, replacement)
		root := c.MkDir()
		original := ft.File{"original", "content", 0644}
		original.Create(c, root)
		ft.Symlink{name, "original"}.Create(c, root)
		err := zip.ExtractAll(s.makeZip(c, replacement), root)
		c.Check(err, gc.IsNil)
		replacement.Check(c, root)
		// Replacing the link must not write through it.
		original.Check(c, root)
	}
}
// TestExtractAllOverwriteDirs checks that an existing directory can be
// replaced by a file or a symlink, or re-permissioned by a directory entry.
func (s *ZipSuite) TestExtractAllOverwriteDirs(c *gc.C) {
	const name = "some-dir"
	replacements := []ft.Entry{
		ft.File{name, "content", 0644},
		ft.Dir{name, 0751},
		ft.Symlink{name, "wherever"},
	}
	for i, replacement := range replacements {
		c.Logf("test %d: %#v", i, replacement)
		root := c.MkDir()
		ft.Dir{name, 0}.Create(c, root)
		err := zip.ExtractAll(s.makeZip(c, replacement), root)
		c.Check(err, gc.IsNil)
		replacement.Check(c, root)
	}
}
// TestExtractAllMergeDirs checks that extracting into an existing directory
// merges with its contents. Pre-existing entries survive alongside the new
// ones, and the directory itself takes the archived permissions.
func (s *ZipSuite) TestExtractAllMergeDirs(c *gc.C) {
	dest := c.MkDir()
	ft.Dir{"dir", 0755}.Create(c, dest)
	existing := []ft.Entry{
		ft.Dir{"dir/original-dir", 0751},
		ft.File{"dir/original-file", "content 1", 0600},
		ft.Symlink{"dir/original-symlink", "original-file"},
	}
	for _, e := range existing {
		e.Create(c, dest)
	}
	incoming := []ft.Entry{
		ft.Dir{"dir", 0751},
		ft.Dir{"dir/merge-dir", 0750},
		ft.File{"dir/merge-file", "content 2", 0640},
		ft.Symlink{"dir/merge-symlink", "merge-file"},
	}
	archive := s.makeZip(c, incoming...)
	c.Assert(zip.ExtractAll(archive, dest), gc.IsNil)
	for i, e := range append(existing, incoming...) {
		c.Logf("test %d: %#v", i, e)
		e.Check(c, dest)
	}
}
// TestExtractAllSymlinkErrors checks that ExtractAll refuses symlinks that
// are absolute or that resolve outside the target directory.
func (s *ZipSuite) TestExtractAllSymlinkErrors(c *gc.C) {
	type symlinkCase struct {
		content   []ft.Entry
		expectErr string
	}
	cases := []symlinkCase{{
		content: []ft.Entry{
			ft.Symlink{"symlink", "/blah"},
		},
		expectErr: `cannot extract "symlink": symlink "/blah" is absolute`,
	}, {
		content: []ft.Entry{
			ft.Symlink{"symlink", "../blah"},
		},
		expectErr: `cannot extract "symlink": symlink "../blah" leads out of scope`,
	}, {
		content: []ft.Entry{
			ft.Dir{"dir", 0755},
			ft.Symlink{"dir/symlink", "../../blah"},
		},
		expectErr: `cannot extract "dir/symlink": symlink "../../blah" leads out of scope`,
	}}
	for i, tc := range cases {
		c.Logf("test %d: %s", i, tc.expectErr)
		dest := c.MkDir()
		archive := s.makeZip(c, tc.content...)
		c.Check(zip.ExtractAll(archive, dest), gc.ErrorMatches, tc.expectErr)
	}
}
// TestExtractDir checks that Extract with a directory source writes only
// that directory's subtree, re-rooted at the target path. Siblings, and
// names that merely share the source's prefix, are ignored.
func (s *ZipSuite) TestExtractDir(c *gc.C) {
	archive := s.makeZip(c,
		ft.File{"bad-file", "xxx", 0644},
		ft.Dir{"bad-dir", 0755},
		ft.Symlink{"bad-symlink", "bad-file"},
		ft.Dir{"some-dir", 0751},
		ft.File{"some-dir-bad-lol", "xxx", 0644},
		ft.File{"some-dir/some-file", "content 1", 0644},
		ft.File{"some-dir/another-file", "content 2", 0600},
		ft.Dir{"some-dir/another-dir", 0750},
		ft.Symlink{"some-dir/another-dir/some-symlink", "../some-file"},
	)
	parent := c.MkDir()
	dest := filepath.Join(parent, "random-dir")
	c.Assert(zip.Extract(archive, dest, "some-dir"), gc.IsNil)
	expected := []ft.Entry{
		ft.Dir{"random-dir", 0751},
		ft.File{"random-dir/some-file", "content 1", 0644},
		ft.File{"random-dir/another-file", "content 2", 0600},
		ft.Dir{"random-dir/another-dir", 0750},
		ft.Symlink{"random-dir/another-dir/some-symlink", "../some-file"},
	}
	for i, e := range expected {
		c.Logf("test %d: %#v", i, e)
		e.Check(c, parent)
	}
	// Nothing else may appear: the parent holds only the target directory,
	// and the target holds only some-dir's three direct children.
	infos, err := ioutil.ReadDir(parent)
	c.Check(err, gc.IsNil)
	c.Check(infos, gc.HasLen, 1)
	infos, err = ioutil.ReadDir(dest)
	c.Check(err, gc.IsNil)
	c.Check(infos, gc.HasLen, 3)
}
// TestExtractSingleFile checks that a source path naming a file writes that
// file directly to the target path, ignoring entries that only share its
// prefix.
func (s *ZipSuite) TestExtractSingleFile(c *gc.C) {
	archive := s.makeZip(c,
		ft.Dir{"dir", 0755},
		ft.Dir{"dir/dir", 0755},
		ft.File{"dir/dir/some-file", "content 1", 0644},
		ft.File{"dir/dir/some-file-wtf", "content 2", 0644},
	)
	parent := c.MkDir()
	dest := filepath.Join(parent, "just-the-one-file")
	c.Assert(zip.Extract(archive, dest, "dir/dir/some-file"), gc.IsNil)
	infos, err := ioutil.ReadDir(parent)
	c.Check(err, gc.IsNil)
	c.Check(infos, gc.HasLen, 1)
	ft.File{"just-the-one-file", "content 1", 0644}.Check(c, parent)
}
// TestClosesFile extracts an executable script and runs it immediately.
// Presumably this guards against the extracted file's writer being left
// open, which would stop it being executed (e.g. ETXTBSY on Linux).
func (s *ZipSuite) TestClosesFile(c *gc.C) {
	archive := s.makeZip(c, ft.File{"f", "echo hullo!", 0755})
	dest := c.MkDir()
	c.Assert(zip.ExtractAll(archive, dest), gc.IsNil)
	script := filepath.Join(dest, "f")
	run := exec.Command("/bin/sh", "-c", script)
	var stdout bytes.Buffer
	run.Stdout = &stdout
	c.Assert(run.Run(), gc.IsNil)
	c.Assert(stdout.String(), gc.Equals, "hullo!\n")
}
// TestExtractSymlinkErrors checks that Extract with a source path refuses
// symlinks that are absolute or that escape the target. This includes a
// lone symlink extracted as the target itself whose relative link then
// points outside it.
func (s *ZipSuite) TestExtractSymlinkErrors(c *gc.C) {
	type symlinkCase struct {
		content   []ft.Entry
		source    string
		expectErr string
	}
	cases := []symlinkCase{{
		content: []ft.Entry{
			ft.Dir{"dir", 0755},
			ft.Symlink{"dir/symlink", "/blah"},
		},
		source:    "dir",
		expectErr: `cannot extract "dir/symlink": symlink "/blah" is absolute`,
	}, {
		content: []ft.Entry{
			ft.Dir{"dir", 0755},
			ft.Symlink{"dir/symlink", "../blah"},
		},
		source:    "dir",
		expectErr: `cannot extract "dir/symlink": symlink "../blah" leads out of scope`,
	}, {
		content: []ft.Entry{
			ft.Symlink{"symlink", "blah"},
		},
		source:    "symlink",
		expectErr: `cannot extract "symlink": symlink "blah" leads out of scope`,
	}}
	for i, tc := range cases {
		c.Logf("test %d: %s", i, tc.expectErr)
		dest := c.MkDir()
		archive := s.makeZip(c, tc.content...)
		c.Check(zip.Extract(archive, dest, tc.source), gc.ErrorMatches, tc.expectErr)
	}
}
// TestExtractSourceError checks that a source path climbing above the
// archive root is rejected up front.
func (s *ZipSuite) TestExtractSourceError(c *gc.C) {
	archive := s.makeZip(c, ft.Dir{"dir", 0755})
	dest := c.MkDir()
	err := zip.Extract(archive, dest, "../lol")
	c.Assert(err, gc.ErrorMatches, `cannot extract files rooted at "../lol"`)
}