Repository: moriyoshi/s3-sftp-proxy Branch: master Commit: d292467c629f Files: 20 Total size: 64.0 KB Directory structure: gitextract_lb8z7ih5/ ├── CONTRIBUTORS ├── LICENSE ├── README.md ├── bucket.go ├── bucketio.go ├── config.go ├── fakee_unix.go ├── fakee_windows.go ├── io.go ├── logging.go ├── main.go ├── merged_context.go ├── path.go ├── path_test.go ├── phantom_object_map.go ├── phantom_object_map_test.go ├── s3-sftp-proxy.example.toml ├── server.go ├── user.go └── utils.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: CONTRIBUTORS ================================================ Moriyoshi Koizumi Dmitry Chepurovskiy ================================================ FILE: LICENSE ================================================ Copyright (c) 2018 Moriyoshi Koizumi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # s3-sftp-proxy `s3-sftp-proxy` is a tiny program that exposes the resources on your AWS S3 buckets through the SFTP protocol. ## Usage ``` Usage of s3-sftp-proxy: -bind string listen on addr:port -config string configuration file (default "s3-sftp-proxy.toml") -debug turn on debugging output ``` * `-bind` Specifies the local address and port to listen on. This overrides the value of `bind` in the configuration file. If it is not present in the configuration file either, it defaults to `:10022`. * `-config` Specifies the path to the configuration file. It defaults to "./s3-sftp-proxy.toml" if not given. * `-debug` Turns on debug logging. The output will be more verbose. ## Configuration The configuration file is in [TOML](https://github.com/toml-lang/toml) format. Refer to that page for a detailed explanation of the syntax. ### Top level ```toml host_key_file = "./host_key" bind = "localhost:10022" banner = """ Welcome to my SFTP server """ reader_lookback_buffer_size = 1048576 reader_min_chunk_size = 262144 lister_lookback_buffer_size = 100 # buckets and authentication settings follow... ``` * `host_key_file` (required) Specifies the path to the host key file (private key). The host key can be generated with the `ssh-keygen` command: ```sh ssh-keygen -f host_key ``` * `bind` (optional, defaults to `":10022"`) Specifies the local address and port to listen on. * `banner` (optional, defaults to an empty string) A banner is a message text that will be sent to the client when the connection to the server is established, prior to any authentication steps. * `reader_lookback_buffer_size` (optional, defaults to `1048576`) Specifies the size of the buffer used to retain data already read from S3 so that it can be accessed again later. Such a buffer is necessary because the SFTP protocol requires data to be sent or retrieved on a random-access basis (i.e.
each request contains an offset) while data from S3 is actually fetched in a streaming manner. We therefore have to emulate block-storage access for S3 objects, but with reasonable SFTP clients, chances are we don't need to hold the entire object. * `reader_min_chunk_size` (optional, defaults to `262144`) Specifies the amount of data fetched from S3 at once. Increase the value if you experience poor performance. * `lister_lookback_buffer_size` (optional, defaults to `100`) Contrary to what one might expect, SFTP also requires file listings to be retrieved on a random-access basis. This specifies how many already-returned listing entries are retained for that purpose. * `buckets` (required) `buckets` contains records for bucket declarations. See [Bucket Settings](#bucket-settings) for details. * `auth` `auth` contains records for authenticator configurations. See [Authenticator Settings](#authenticator-settings) for details. ### Bucket Settings ```toml [buckets.test] endpoint = "http://endpoint" s3_force_path_style = true disable_ssl = false bucket = "BUCKET" key_prefix = "PREFIX" bucket_url = "s3://BUCKET/PREFIX" profile = "profile" region = "ap-northeast-1" max_object_size = 65536 writable = false readable = true listable = true auth = "test" server_side_encryption = "kms" sse_customer_key = "" sse_kms_key_id = "" keyboard_interactive_auth = false [buckets.test.credentials] aws_access_key_id = "aaa" aws_secret_access_key = "bbb" ``` * `endpoint` (optional) Specifies an S3 endpoint (server) other than AWS. * `s3_force_path_style` (optional) This option should be set to `true` if you use an endpoint other than AWS. Set this to `true` to force the request to use path-style addressing, i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will use virtual hosted bucket addressing when possible (`http://BUCKET.s3.amazonaws.com/KEY`). * `disable_ssl` (optional) Set this to `true` to disable SSL when sending requests. * `bucket` (required when `bucket_url` is unspecified) Specifies the bucket name.
* `key_prefix` (optional; only used when `bucket_url` is unspecified) Specifies the prefix prepended to the file path sent from the client. The key string is derived as follows: `key` = `key_prefix` + `path` * `bucket_url` (required when `bucket` is unspecified) Specifies both the bucket name and prefix in URL form. The URL's scheme must be `s3`; the host part corresponds to `bucket` and the path part to `key_prefix`. `bucket_url` may not be specified together with either `bucket` or `key_prefix`. * `profile` (optional, defaults to the value of `AWS_PROFILE` unless `credentials` is specified) Specifies the credentials profile name. It may not be combined with `credentials`. * `region` (optional, defaults to the value of the `AWS_REGION` environment variable) Specifies the region of the endpoint. * `credentials` (optional) * `credentials.aws_access_key_id` (required) Specifies the AWS access key. * `credentials.aws_secret_access_key` (required) Specifies the AWS secret access key. * `max_object_size` (optional, defaults to unlimited) Specifies the maximum size of an object put to S3. This effectively limits the size of the in-memory buffer used to hold the entire content sent from the client, as we have to calculate an MD5 sum for it before uploading it. * `readable` (optional, defaults to `true`) Specifies whether to allow the client to fetch objects from S3. * `writable` (optional, defaults to `true`) Specifies whether to allow the client to put objects to S3. * `listable` (optional, defaults to `true`) Specifies whether to allow the client to list objects in S3. * `server_side_encryption` (optional, defaults to `"none"`) Specifies which server-side encryption scheme is applied to store the objects. Valid values are: `"none"`, `"aes256"` (encryption with a customer-provided key, SSE-C) and `"kms"`. * `sse_customer_key` (required when `server_side_encryption` is set to `"aes256"`) Specifies the base64-encoded encryption key.
As the cipher is AES-256, the key must be 256 bits (32 bytes) long. * `sse_kms_key_id` (required when `server_side_encryption` is set to `"kms"`) Specifies the CMK ID used for server-side encryption with KMS. * `keyboard_interactive_auth` (optional, defaults to `false`) Enables keyboard-interactive authentication if set to `true`. * `auth` (required) Specifies the name of the authenticator. ### Authenticator Settings ```toml [auth.test] type = "inplace" # authenticator-specific settings follow ``` * `type` (required) Specifies the authenticator implementation type. Currently `"inplace"` is the only valid value. * `users` (required when `type` is `"inplace"`) Contains user records as a dictionary. #### In-place authenticator The in-place authenticator reads credentials embedded directly in the configuration file. The user record looks like the following: ```toml [auth.test] type = "inplace" [auth.test.users.user0] password = "test" public_keys = """ ssh-rsa AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ssh-rsa AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA """ [auth.test.users.user1] password = "test" public_keys = """ ssh-rsa AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA ssh-rsa AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA """ ``` Or ```toml [auth.test] type = "inplace" [auth.test.users] user0 = { password="test", public_keys="..." } user1 = { password="test", public_keys="..." } ``` * (key) (appears as `user0` or `user1` in the above example) Specifies the name of the user. * `password` (optional) Specifies the password in clear text. * `public_keys` (optional) Specifies the public keys authorized for authentication. Multiple keys can be specified by separating them with newlines.
================================================ FILE: bucket.go ================================================ package main import ( "crypto" "encoding/base64" "fmt" "strings" aws "github.com/aws/aws-sdk-go/aws" aws_creds "github.com/aws/aws-sdk-go/aws/credentials" aws_ec2_role_creds "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" aws_ec2_meta "github.com/aws/aws-sdk-go/aws/ec2metadata" aws_session "github.com/aws/aws-sdk-go/aws/session" s3 "github.com/aws/aws-sdk-go/service/s3" "github.com/pkg/errors" ) type ServerSideEncryptionType int const ( ServerSideEncryptionTypeNone = iota ServerSideEncryptionTypeAES256 ServerSideEncryptionTypeKMS ) var sseNameToEnumMap = map[string]ServerSideEncryptionType{ "": ServerSideEncryptionTypeNone, "none": ServerSideEncryptionTypeNone, "aes256": ServerSideEncryptionTypeAES256, "kms": ServerSideEncryptionTypeKMS, } func (v *ServerSideEncryptionType) UnmarshalText(text []byte) error { _v, ok := sseNameToEnumMap[strings.ToLower(string(text))] if !ok { return fmt.Errorf("invalid value for ServerSideEncryption: %s", string(text)) } *v = _v return nil } type ServerSideEncryptionConfig struct { Type ServerSideEncryptionType CustomerKey string CustomerKeyMD5 string KMSKeyId string } func (cfg *ServerSideEncryptionConfig) CustomerAlgorithm() string { if cfg.Type == ServerSideEncryptionTypeAES256 { return "AES256" } else { return "" } } type Perms struct { Readable bool Writable bool Listable bool } type S3Bucket struct { Name string AWSConfig *aws.Config Bucket string KeyPrefix Path MaxObjectSize int64 Users UserStore Perms Perms ServerSideEncryption ServerSideEncryptionConfig KeyboardInteractiveAuthEnabled bool } type S3Buckets struct { Buckets map[string]*S3Bucket UserToBucketMap map[string]*S3Bucket } func (s3bs *S3Buckets) Get(name string) *S3Bucket { b, _ := s3bs.Buckets[name] return b } func (s3b *S3Bucket) S3(sess *aws_session.Session) *s3.S3 { awsCfg := s3b.AWSConfig if awsCfg.Credentials == nil { awsCfg = 
s3b.AWSConfig.WithCredentials(aws_creds.NewChainCredentials( []aws_creds.Provider{ &aws_ec2_role_creds.EC2RoleProvider{ Client: aws_ec2_meta.New(sess), ExpiryWindow: 0, }, &aws_creds.EnvProvider{}, }, )) } return s3.New(sess, awsCfg) } func buildS3Bucket(uStores UserStores, name string, bCfg *S3BucketConfig) (*S3Bucket, error) { awsCfg := aws.NewConfig() if bCfg.Credentials != nil { awsCfg = awsCfg.WithCredentials( aws_creds.NewStaticCredentials( bCfg.Credentials.AWSAccessKeyID, bCfg.Credentials.AWSSecretAccessKey, "", ), ) } else if bCfg.Profile != "" { awsCfg = awsCfg.WithCredentials( aws_creds.NewSharedCredentials( "", // TODO: assumes default bCfg.Profile, ), ) } else { // credentials are retrieved through EC2 metadata on runtime } if bCfg.Endpoint != "" { awsCfg = awsCfg.WithEndpoint(bCfg.Endpoint) } if bCfg.S3ForcePathStyle != nil { awsCfg = awsCfg.WithS3ForcePathStyle(*bCfg.S3ForcePathStyle) } if bCfg.DisableSSL != nil { awsCfg = awsCfg.WithDisableSSL(*bCfg.DisableSSL) } if bCfg.Region != "" { awsCfg = awsCfg.WithRegion(bCfg.Region) } users, ok := uStores[bCfg.Auth] if !ok { return nil, fmt.Errorf("no such auth config: %s", bCfg.Auth) } keyPrefix := SplitIntoPath(bCfg.KeyPrefix) if len(keyPrefix) > 0 && keyPrefix[0] == "" { keyPrefix = keyPrefix[1:] } maxObjectSize := int64(-1) if bCfg.MaxObjectSize != nil { maxObjectSize = *bCfg.MaxObjectSize } var customerKey []byte var customerKeyMD5 string if bCfg.SSECustomerKey != "" { var err error customerKey, err = base64.StdEncoding.DecodeString(bCfg.SSECustomerKey) if err != nil { return nil, errors.Wrapf(err, `invalid base64-encoded string specified for "sse_customer_key"`) } hasher := crypto.MD5.New() hasher.Write(customerKey) customerKeyMD5 = base64.StdEncoding.EncodeToString(hasher.Sum([]byte{})) } else { customerKey = []byte{} } return &S3Bucket{ Name: name, AWSConfig: awsCfg, Bucket: bCfg.Bucket, KeyPrefix: keyPrefix, MaxObjectSize: maxObjectSize, Users: users, Perms: Perms{ Readable: *bCfg.Readable, 
Writable: *bCfg.Writable, Listable: *bCfg.Listable, }, ServerSideEncryption: ServerSideEncryptionConfig{ Type: bCfg.ServerSideEncryption, CustomerKey: string(customerKey), CustomerKeyMD5: customerKeyMD5, KMSKeyId: bCfg.SSEKMSKeyId, }, KeyboardInteractiveAuthEnabled: bCfg.KeyboardInteractiveAuthEnabled, }, nil } func NewS3BucketFromConfig(uStores UserStores, cfg *S3SFTPProxyConfig) (*S3Buckets, error) { buckets := map[string]*S3Bucket{} userToBucketMap := map[string]*S3Bucket{} for name, bCfg := range cfg.Buckets { bucket, err := buildS3Bucket(uStores, name, bCfg) if err != nil { return nil, errors.Wrapf(err, "bucket config %s", name) } for _, user := range bucket.Users.Users { _bucket, ok := userToBucketMap[user.Name] if ok { return nil, fmt.Errorf(`bucket config %s: user "%s" is already assigned to bucket config "%s"`, name, user.Name, _bucket.Name) } userToBucketMap[user.Name] = bucket } buckets[name] = bucket } return &S3Buckets{ Buckets: buckets, UserToBucketMap: userToBucketMap, }, nil } ================================================ FILE: bucketio.go ================================================ package main import ( "bytes" "context" "fmt" "io" "os" "path" "sync" "time" aws "github.com/aws/aws-sdk-go/aws" aws_session "github.com/aws/aws-sdk-go/aws/session" aws_s3 "github.com/aws/aws-sdk-go/service/s3" "github.com/pkg/sftp" // s3crypto "github.com/aws/aws-sdk-go/service/s3/s3crypto" ) var aclPrivate = "private" type ReadDeadlineSettable interface { SetReadDeadline(t time.Time) error } type WriteDeadlineSettable interface { SetWriteDeadline(t time.Time) error } var sseTypes = map[ServerSideEncryptionType]*string{ ServerSideEncryptionTypeKMS: aws.String("aws:kms"), } func nilIfEmpty(s string) *string { if s == "" { return nil } else { return &s } } type S3GetObjectOutputReader struct { Ctx context.Context Goo *aws_s3.GetObjectOutput Log DebugLogger Lookback int MinChunkSize int mtx sync.Mutex spooled []byte spoolOffset int noMore bool } func (oor 
*S3GetObjectOutputReader) Close() error { if oor.Goo.Body != nil { oor.Goo.Body.Close() oor.Goo.Body = nil } return nil } func (oor *S3GetObjectOutputReader) ReadAt(buf []byte, off int64) (int, error) { oor.mtx.Lock() defer oor.mtx.Unlock() F(oor.Log.Debug, "len(buf)=%d, off=%d", len(buf), off) _o, err := castInt64ToInt(off) if err != nil { return 0, err } if _o < oor.spoolOffset { return 0, fmt.Errorf("supplied position is out of range") } s := _o - oor.spoolOffset i := 0 r := len(buf) if s < len(oor.spooled) { // n = max(r, len(oor.spooled)-s) n := r if n > len(oor.spooled)-s { n = len(oor.spooled) - s } copy(buf[i:i+n], oor.spooled[s:s+n]) i += n s += n r -= n } if r == 0 { return i, nil } if oor.noMore { if i == 0 { return 0, io.EOF } else { return i, nil } } F(oor.Log.Debug, "s=%d, len(oor.spooled)=%d, oor.Lookback=%d", s, len(oor.spooled), oor.Lookback) if s <= len(oor.spooled) && s >= oor.Lookback { oor.spooled = oor.spooled[s-oor.Lookback:] oor.spoolOffset += s - oor.Lookback s = oor.Lookback } var e int if len(oor.spooled)+oor.MinChunkSize < s+r { e = s + r } else { e = len(oor.spooled) + oor.MinChunkSize } if cap(oor.spooled) < e { spooled := make([]byte, len(oor.spooled), e) copy(spooled, oor.spooled) oor.spooled = spooled } type readResult struct { n int err error } resultChan := make(chan readResult) go func() { n, err := io.ReadFull(oor.Goo.Body, oor.spooled[len(oor.spooled):e]) resultChan <- readResult{n, err} }() select { case <-oor.Ctx.Done(): oor.Goo.Body.(ReadDeadlineSettable).SetReadDeadline(time.Unix(1, 0)) oor.Log.Debug("canceled") return 0, fmt.Errorf("read operation canceled") case res := <-resultChan: if IsEOF(res.err) { oor.noMore = true } e = len(oor.spooled) + res.n oor.spooled = oor.spooled[:e] if s < e { be := e if be > s+r { be = s + r } copy(buf[i:], oor.spooled[s:be]) return be - s, nil } else { return 0, io.EOF } } } type S3PutObjectWriter struct { Ctx context.Context Bucket string Key Path S3 *aws_s3.S3 ServerSideEncryption 
*ServerSideEncryptionConfig Log interface { DebugLogger ErrorLogger } MaxObjectSize int64 Info *PhantomObjectInfo PhantomObjectMap *PhantomObjectMap mtx sync.Mutex writer *BytesWriter } func (oow *S3PutObjectWriter) Close() error { F(oow.Log.Debug, "S3PutObjectWriter.Close") oow.mtx.Lock() defer oow.mtx.Unlock() phInfo := oow.Info.GetOne() oow.PhantomObjectMap.RemoveByInfoPtr(oow.Info) key := phInfo.Key.String() sse := oow.ServerSideEncryption F(oow.Log.Debug, "PutObject(Bucket=%s, Key=%s, Sse=%v)", oow.Bucket, key, sse) _, err := oow.S3.PutObject( &aws_s3.PutObjectInput{ ACL: &aclPrivate, Body: bytes.NewReader(oow.writer.Bytes()), Bucket: &oow.Bucket, Key: &key, ServerSideEncryption: sseTypes[sse.Type], SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), SSECustomerKey: nilIfEmpty(sse.CustomerKey), SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), SSEKMSKeyId: nilIfEmpty(sse.KMSKeyId), }, ) if err != nil { oow.Log.Debug("=> ", err) F(oow.Log.Error, "failed to put object: %s", err.Error()) } else { oow.Log.Debug("=> OK") } return nil } func (oow *S3PutObjectWriter) WriteAt(buf []byte, off int64) (int, error) { oow.mtx.Lock() defer oow.mtx.Unlock() if oow.MaxObjectSize >= 0 { if int64(len(buf))+off > oow.MaxObjectSize { return 0, fmt.Errorf("file too large: maximum allowed size is %d bytes", oow.MaxObjectSize) } } F(oow.Log.Debug, "len(buf)=%d, off=%d", len(buf), off) n, err := oow.writer.WriteAt(buf, off) oow.Info.SetSize(oow.writer.Size()) return n, err } type ObjectFileInfo struct { _Name string _LastModified time.Time _Size int64 _Mode os.FileMode } func (ofi *ObjectFileInfo) Name() string { return ofi._Name } func (ofi *ObjectFileInfo) ModTime() time.Time { return ofi._LastModified } func (ofi *ObjectFileInfo) Size() int64 { return ofi._Size } func (ofi *ObjectFileInfo) Mode() os.FileMode { return ofi._Mode } func (ofi *ObjectFileInfo) IsDir() bool { return (ofi._Mode & os.ModeDir) != 0 } func (ofi *ObjectFileInfo) Sys() interface{} { return 
BuildFakeFileInfoSys() } type S3ObjectLister struct { DebugLogger Ctx context.Context Bucket string Prefix Path S3 *aws_s3.S3 Lookback int PhantomObjectMap *PhantomObjectMap spoolOffset int spooled []os.FileInfo continuation *string noMore bool } func aclToMode(owner *aws_s3.Owner, grants []*aws_s3.Grant) os.FileMode { var v os.FileMode for _, g := range grants { if g.Grantee != nil { if g.Grantee.ID != nil && *g.Grantee.ID == *owner.ID { switch *g.Permission { case "READ": v |= 0400 case "WRITE": v |= 0200 case "FULL_CONTROL": v |= 0600 } } else if g.Grantee.URI != nil { switch *g.Grantee.URI { case "http://acs.amazonaws.com/groups/global/AuthenticatedUsers": switch *g.Permission { case "READ": v |= 0440 case "WRITE": v |= 0220 case "FULL_CONTROL": v |= 0660 } case "http://acs.amazonaws.com/groups/global/AllUsers": switch *g.Permission { case "READ": v |= 0444 case "WRITE": v |= 0222 case "FULL_CONTROL": v |= 0666 } } } } } return v } func (sol *S3ObjectLister) ListAt(result []os.FileInfo, o int64) (int, error) { _o, err := castInt64ToInt(o) if err != nil { return 0, err } if _o < sol.spoolOffset { return 0, fmt.Errorf("supplied position is out of range") } s := _o - sol.spoolOffset i := 0 if s < len(sol.spooled) { n := len(result) if n > len(sol.spooled)-s { n = len(sol.spooled) - s } copy(result[i:i+n], sol.spooled[s:s+n]) i += n s = len(sol.spooled) } if i >= len(result) { return i, nil } if sol.noMore { if i == 0 { return 0, io.EOF } else { return i, nil } } if s <= len(sol.spooled) && s >= sol.Lookback { sol.spooled = sol.spooled[s-sol.Lookback:] sol.spoolOffset += s - sol.Lookback s = sol.Lookback } if sol.continuation == nil { sol.spooled = append(sol.spooled, &ObjectFileInfo{ _Name: ".", _LastModified: time.Unix(1, 0), _Size: 0, _Mode: 0755 | os.ModeDir, }) sol.spooled = append(sol.spooled, &ObjectFileInfo{ _Name: "..", _LastModified: time.Unix(1, 0), _Size: 0, _Mode: 0755 | os.ModeDir, }) phObjs := sol.PhantomObjectMap.List(sol.Prefix) for _, phInfo := 
range phObjs { _phInfo := phInfo.GetOne() sol.spooled = append(sol.spooled, &ObjectFileInfo{ _Name: _phInfo.Key.Base(), _LastModified: _phInfo.LastModified, _Size: _phInfo.Size, _Mode: 0600, // TODO }) } } prefix := sol.Prefix.String() if prefix != "" { prefix += "/" } F(sol.Debug, "ListObjectsV2WithContext(Bucket=%s, Prefix=%s, Continuation=%v)", sol.Bucket, prefix, sol.continuation) out, err := sol.S3.ListObjectsV2WithContext( sol.Ctx, &aws_s3.ListObjectsV2Input{ Bucket: &sol.Bucket, Prefix: &prefix, MaxKeys: aws.Int64(10000), Delimiter: aws.String("/"), ContinuationToken: sol.continuation, }, ) if err != nil { sol.Debug("=> ", err) return i, err } F(sol.Debug, "=> { CommonPrefixes=len(%d), Contents=len(%d) }", len(out.CommonPrefixes), len(out.Contents)) if sol.continuation == nil { for _, cPfx := range out.CommonPrefixes { sol.spooled = append(sol.spooled, &ObjectFileInfo{ _Name: path.Base(*cPfx.Prefix), _LastModified: time.Unix(1, 0), _Size: 0, _Mode: 0755 | os.ModeDir, }) } } for _, obj := range out.Contents { // if *obj.Key == sol.Prefix { // continue // } sol.spooled = append(sol.spooled, &ObjectFileInfo{ _Name: path.Base(*obj.Key), _LastModified: *obj.LastModified, _Size: *obj.Size, _Mode: 0644, }) } sol.continuation = out.NextContinuationToken if out.NextContinuationToken == nil { sol.noMore = true } var n int if len(sol.spooled)-s > len(result)-i { n = len(result) - i } else { n = len(sol.spooled) - s if sol.noMore { err = io.EOF } } copy(result[i:i+n], sol.spooled[s:s+n]) return i + n, err } type S3ObjectStat struct { DebugLogger Ctx context.Context Bucket string Key Path Root bool S3 *aws_s3.S3 PhantomObjectMap *PhantomObjectMap } func (sos *S3ObjectStat) ListAt(result []os.FileInfo, o int64) (int, error) { F(sos.Debug, "S3ObjectStat.ListAt: len(result)=%d offset=%d", len(result), o) _o, err := castInt64ToInt(o) if err != nil { return 0, err } if len(result) == 0 { return 0, nil } if _o > 0 { return 0, fmt.Errorf("supplied position is out of range") } 
if sos.Key.IsRoot() { result[0] = &ObjectFileInfo{ _Name: "/", _LastModified: time.Time{}, _Size: 0, _Mode: 0755 | os.ModeDir, } } else { phInfo := sos.PhantomObjectMap.Get(sos.Key) if phInfo != nil { _phInfo := phInfo.GetOne() result[0] = &ObjectFileInfo{ _Name: _phInfo.Key.Base(), _LastModified: _phInfo.LastModified, _Size: _phInfo.Size, _Mode: 0600, // TODO } } else { key := sos.Key.String() F(sos.Debug, "GetObjectAclWithContext(Bucket=%s, Key=%s)", sos.Bucket, key) out, err := sos.S3.GetObjectAclWithContext( sos.Ctx, &aws_s3.GetObjectAclInput{ Bucket: &sos.Bucket, Key: &key, }, ) if err == nil { F(sos.Debug, "=> %v", out) F(sos.Debug, "HeadObjectWithContext(Bucket=%s, Key=%s)", sos.Bucket, key) headOut, err := sos.S3.HeadObjectWithContext( sos.Ctx, &aws_s3.HeadObjectInput{ Bucket: &sos.Bucket, Key: &key, }, ) objInfo := ObjectFileInfo{ _Name: sos.Key.Base(), _Mode: aclToMode(out.Owner, out.Grants), } if err == nil { F(sos.Debug, "=> { ContentLength=%d, LastModified=%v }", *headOut.ContentLength, *headOut.LastModified) objInfo._Size = *headOut.ContentLength objInfo._LastModified = *headOut.LastModified } else { sos.Debug("=> ", err) } result[0] = &objInfo } else { sos.Debug("=> ", err) F(sos.Debug, "ListObjectsV2WithContext(Bucket=%s, Prefix=%s)", sos.Bucket, key) out, err := sos.S3.ListObjectsV2WithContext( sos.Ctx, &aws_s3.ListObjectsV2Input{ Bucket: &sos.Bucket, Prefix: &key, MaxKeys: aws.Int64(10000), Delimiter: aws.String("/"), }, ) if err != nil || (!sos.Root && len(out.CommonPrefixes) == 0) { sos.Debug("=> ", err) return 0, os.ErrNotExist } F(sos.Debug, "=> { CommonPrefixes=len(%d), Contents=len(%d) }", len(out.CommonPrefixes), len(out.Contents)) result[0] = &ObjectFileInfo{ _Name: sos.Key.Base(), _LastModified: time.Time{}, _Size: 0, _Mode: 0755 | os.ModeDir, } } } } return 1, nil } type S3BucketIO struct { Ctx context.Context Bucket *S3Bucket ReaderLookbackBufferSize int ReaderMinChunkSize int ListerLookbackBufferSize int PhantomObjectMap 
*PhantomObjectMap Perms Perms ServerSideEncryption *ServerSideEncryptionConfig Now func() time.Time Log interface { ErrorLogger DebugLogger } } func buildKey(s3b *S3Bucket, path string) Path { return s3b.KeyPrefix.Join(SplitIntoPath(path)) } func buildPath(s3b *S3Bucket, key string) (string, bool) { _key := SplitIntoPath(key) if !_key.IsPrefixed(s3b.KeyPrefix) { return "", false } return "/" + _key[len(s3b.KeyPrefix):].String(), true } func (s3io *S3BucketIO) Fileread(req *sftp.Request) (io.ReaderAt, error) { if !s3io.Perms.Readable { return nil, fmt.Errorf("read operation not allowed as per configuration") } sess, err := aws_session.NewSession() if err != nil { return nil, err } s3 := s3io.Bucket.S3(sess) key := buildKey(s3io.Bucket, req.Filepath) phInfo := s3io.PhantomObjectMap.Get(key) if phInfo != nil { return bytes.NewReader(phInfo.Opaque.(*S3PutObjectWriter).writer.Bytes()), nil } keyStr := key.String() ctx := combineContext(s3io.Ctx, req.Context()) F(s3io.Log.Debug, "GetObject(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, keyStr) sse := s3io.ServerSideEncryption goo, err := s3.GetObjectWithContext( ctx, &aws_s3.GetObjectInput{ Bucket: &s3io.Bucket.Bucket, Key: &keyStr, SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), SSECustomerKey: nilIfEmpty(sse.CustomerKey), SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), }, ) if err != nil { return nil, err } return &S3GetObjectOutputReader{ Ctx: ctx, Goo: goo, Log: s3io.Log, Lookback: s3io.ReaderLookbackBufferSize, MinChunkSize: s3io.ReaderMinChunkSize, }, nil } func (s3io *S3BucketIO) Filewrite(req *sftp.Request) (io.WriterAt, error) { if !s3io.Perms.Writable { return nil, fmt.Errorf("write operation not allowed as per configuration") } sess, err := aws_session.NewSession() if err != nil { return nil, err } maxObjectSize := s3io.Bucket.MaxObjectSize if maxObjectSize < 0 { maxObjectSize = int64(^uint(0) >> 1) } key := buildKey(s3io.Bucket, req.Filepath) info := &PhantomObjectInfo{ Key: key, Size: 0, 
LastModified: s3io.Now(), } F(s3io.Log.Debug, "S3PutObjectWriter.New(key=%s)", key) oow := &S3PutObjectWriter{ Ctx: combineContext(s3io.Ctx, req.Context()), Bucket: s3io.Bucket.Bucket, Key: key, S3: s3io.Bucket.S3(sess), ServerSideEncryption: s3io.ServerSideEncryption, Log: s3io.Log, MaxObjectSize: maxObjectSize, PhantomObjectMap: s3io.PhantomObjectMap, Info: info, writer: NewBytesWriter(), } info.Opaque = oow s3io.PhantomObjectMap.Add(info) return oow, nil } func (s3io *S3BucketIO) Filecmd(req *sftp.Request) error { switch req.Method { case "Rename": if !s3io.Perms.Writable { return fmt.Errorf("write operation not allowed as per configuration") } src := buildKey(s3io.Bucket, req.Filepath) dest := buildKey(s3io.Bucket, req.Target) if s3io.PhantomObjectMap.Rename(src, dest) { return nil } sess, err := aws_session.NewSession() if err != nil { return err } srcStr := src.String() destStr := dest.String() copySource := s3io.Bucket.Bucket + "/" + srcStr sse := s3io.ServerSideEncryption F(s3io.Log.Debug, "CopyObject(Bucket=%s, Key=%s, CopySource=%s, Sse=%v)", s3io.Bucket.Bucket, destStr, copySource, sse.Type) _, err = s3io.Bucket.S3(sess).CopyObjectWithContext( combineContext(s3io.Ctx, req.Context()), &aws_s3.CopyObjectInput{ ACL: &aclPrivate, Bucket: &s3io.Bucket.Bucket, CopySource: ©Source, Key: &destStr, ServerSideEncryption: sseTypes[sse.Type], SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), SSECustomerKey: nilIfEmpty(sse.CustomerKey), SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), SSEKMSKeyId: nilIfEmpty(sse.KMSKeyId), }, ) if err != nil { s3io.Log.Debug("=> ", err) return err } F(s3io.Log.Debug, "DeleteObject(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, srcStr) _, err = s3io.Bucket.S3(sess).DeleteObjectWithContext( combineContext(s3io.Ctx, req.Context()), &aws_s3.DeleteObjectInput{ Bucket: &s3io.Bucket.Bucket, Key: &srcStr, }, ) if err != nil { s3io.Log.Debug("=> ", err) return err } case "Remove": if !s3io.Perms.Writable { return fmt.Errorf("write 
operation not allowed as per configuration") } key := buildKey(s3io.Bucket, req.Filepath) if s3io.PhantomObjectMap.Remove(key) != nil { return nil } sess, err := aws_session.NewSession() if err != nil { return err } keyStr := key.String() F(s3io.Log.Debug, "DeleteObject(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, key) _, err = s3io.Bucket.S3(sess).DeleteObjectWithContext( combineContext(s3io.Ctx, req.Context()), &aws_s3.DeleteObjectInput{ Bucket: &s3io.Bucket.Bucket, Key: &keyStr, }, ) if err != nil { s3io.Log.Debug("=> ", err) return err } } return nil } func (s3io *S3BucketIO) Filelist(req *sftp.Request) (sftp.ListerAt, error) { sess, err := aws_session.NewSession() if err != nil { return nil, err } switch req.Method { case "Stat", "ReadLink": if !s3io.Perms.Readable && !s3io.Perms.Listable { return nil, fmt.Errorf("stat operation not allowed as per configuration") } key := buildKey(s3io.Bucket, req.Filepath) return &S3ObjectStat{ DebugLogger: s3io.Log, Ctx: combineContext(s3io.Ctx, req.Context()), Bucket: s3io.Bucket.Bucket, Root: key.Equal(s3io.Bucket.KeyPrefix), Key: key, S3: s3io.Bucket.S3(sess), PhantomObjectMap: s3io.PhantomObjectMap, }, nil case "List": if !s3io.Perms.Listable { return nil, fmt.Errorf("listing operation not allowed as per configuration") } return &S3ObjectLister{ DebugLogger: s3io.Log, Ctx: combineContext(s3io.Ctx, req.Context()), Bucket: s3io.Bucket.Bucket, Prefix: buildKey(s3io.Bucket, req.Filepath), S3: s3io.Bucket.S3(sess), Lookback: s3io.ListerLookbackBufferSize, PhantomObjectMap: s3io.PhantomObjectMap, }, nil default: return nil, fmt.Errorf("unsupported method: %s", req.Method) } } ================================================ FILE: config.go ================================================ package main import ( "fmt" "github.com/BurntSushi/toml" "github.com/pkg/errors" "io/ioutil" "net/url" ) var ( minReaderLookbackBufferSize = 1048576 minReaderMinChunkSize = 262144 minListerLookbackBufferSize = 100 vTrue = true ) type URL struct { 
*url.URL } func (u *URL) UnmarshalText(text []byte) (err error) { u.URL, err = url.Parse(string(text)) return } type AWSCredentialsConfig struct { AWSAccessKeyID string `toml:"aws_access_key_id"` AWSSecretAccessKey string `toml:"aws_secret_access_key"` } type S3BucketConfig struct { Profile string `toml:"profile"` Credentials *AWSCredentialsConfig `toml:"credentials"` Region string `toml:"region"` Endpoint string `toml:"endpoint"` DisableSSL *bool `toml:"disable_ssl"` S3ForcePathStyle *bool `toml:"s3_force_path_style"` Bucket string `toml:"bucket"` KeyPrefix string `toml:"key_prefix"` BucketUrl *URL `toml:"bucket_url"` Auth string `toml:"auth"` MaxObjectSize *int64 `toml:"max_object_size"` Readable *bool `toml:"readble"` Writable *bool `toml:"writable"` Listable *bool `toml:"listable"` ServerSideEncryption ServerSideEncryptionType `toml:"server_side_encryption"` SSECustomerKey string `toml:"sse_customer_key"` SSEKMSKeyId string `toml:"sse_kms_key_id"` KeyboardInteractiveAuthEnabled bool `toml:"keyboard_interactive_auth"` } type AuthUser struct { Password string `toml:"password"` PublicKeys string `toml:"public_keys"` PublicKeyFile string `toml:"public_key_file"` } type AuthConfig struct { Type string `toml:"type"` UserDBFile string `toml:"user_db_file"` Users map[string]AuthUser `toml:"users"` } type S3SFTPProxyConfig struct { Bind string `toml:"bind"` HostKeyFile string `toml:"host_key_file"` Banner string `toml:"banner"` ReaderLookbackBufferSize *int `toml:"reader_lookback_buffer_size"` ReaderMinChunkSize *int `toml:"reader_min_chunk_size"` ListerLookbackBufferSize *int `toml:"lister_lookback_buffer_size"` Buckets map[string]*S3BucketConfig `toml:"buckets"` AuthConfigs map[string]*AuthConfig `toml:"auth"` } func validateAndFixupBucketConfig(bCfg *S3BucketConfig) error { if bCfg.Profile != "" { if bCfg.Credentials != nil { return fmt.Errorf("no credentials may be specified if profile is given") } } if bCfg.BucketUrl != nil { if bCfg.Bucket != "" { return 
fmt.Errorf("bucket may not be specified if bucket_url is given") } if bCfg.KeyPrefix != "" { return fmt.Errorf("root path may not be specified if bucket_url is given") } if bCfg.BucketUrl.Host == "" { return fmt.Errorf("bucket name is empty") } if bCfg.BucketUrl.Scheme != "s3" { return fmt.Errorf("bucket URL scheme must be \"s3\"") } bCfg.Bucket = bCfg.BucketUrl.Host bCfg.KeyPrefix = bCfg.BucketUrl.Path } else { if bCfg.Bucket == "" { return fmt.Errorf("bucket name is empty") } } if bCfg.Auth == "" { return fmt.Errorf("auth is not specified") } if bCfg.Readable == nil { bCfg.Readable = &vTrue } if bCfg.Writable == nil { bCfg.Writable = &vTrue } if bCfg.Listable == nil { bCfg.Listable = &vTrue } return nil } func validateAndFixupAuthConfigInplace(aCfg *AuthConfig) error { if aCfg.UserDBFile != "" { return fmt.Errorf(`user_db_file may not be specified when auth type is "inplace"`) } if aCfg.Users == nil || len(aCfg.Users) == 0 { fmt.Printf("%#v\n", aCfg.Users) return fmt.Errorf(`no "users" present`) } return nil } func validateAndFixupAuthConfig(aCfg *AuthConfig) error { switch aCfg.Type { case "inplace": return validateAndFixupAuthConfigInplace(aCfg) default: return fmt.Errorf("unknown auth type: %s", aCfg.Type) } } func ReadConfig(tomlStr string) (*S3SFTPProxyConfig, error) { cfg := &S3SFTPProxyConfig{ Buckets: map[string]*S3BucketConfig{}, AuthConfigs: map[string]*AuthConfig{}, } _, err := toml.Decode(tomlStr, cfg) if err != nil { return nil, err } if len(cfg.Buckets) == 0 { return nil, fmt.Errorf("no bucket configs are present") } if len(cfg.AuthConfigs) == 0 { return nil, fmt.Errorf("no auth configs are present") } if cfg.HostKeyFile == "" { return nil, fmt.Errorf("no host key file is specified") } if len(cfg.Banner) > 0 && cfg.Banner[len(cfg.Banner)-1] != '\n' { cfg.Banner += "\n" } if cfg.ReaderLookbackBufferSize == nil { cfg.ReaderLookbackBufferSize = &minReaderLookbackBufferSize } else if *cfg.ReaderLookbackBufferSize < minReaderLookbackBufferSize { return 
nil, fmt.Errorf("reader_lookback_buffer_size must be equal to or greater than %d", minReaderMinChunkSize) } if cfg.ReaderMinChunkSize == nil { cfg.ReaderMinChunkSize = &minReaderMinChunkSize } else if *cfg.ReaderMinChunkSize < minReaderMinChunkSize { return nil, fmt.Errorf("reader_min_chunk_size must be equal to or greater than %d", minReaderMinChunkSize) } if cfg.ListerLookbackBufferSize == nil { cfg.ListerLookbackBufferSize = &minListerLookbackBufferSize } else if *cfg.ListerLookbackBufferSize < minListerLookbackBufferSize { return nil, fmt.Errorf("lister_lookback_buffer_size must be equal to or greater than %d", minListerLookbackBufferSize) } for name, bCfg := range cfg.Buckets { err := validateAndFixupBucketConfig(bCfg) if err != nil { return nil, errors.Wrapf(err, `bucket config "%s"`, name) } } for name, aCfg := range cfg.AuthConfigs { err := validateAndFixupAuthConfig(aCfg) if err != nil { return nil, errors.Wrapf(err, `auth config "%s"`, name) } } return cfg, err } func ReadConfigFromFile(tomlFile string) (*S3SFTPProxyConfig, error) { tomlStr, err := ioutil.ReadFile(tomlFile) if err != nil { return nil, errors.Wrapf(err, "failed to open %s", tomlFile) } cfg, err := ReadConfig(string(tomlStr)) if err != nil { return nil, errors.Wrapf(err, "failed to parse %s", tomlFile) } return cfg, nil } ================================================ FILE: fakee_unix.go ================================================ // +build !windows package main import ( "syscall" ) func BuildFakeFileInfoSys() interface{} { return &syscall.Stat_t{Uid: 65534, Gid: 65534} } ================================================ FILE: fakee_windows.go ================================================ // +build windows package main import "syscall" func BuildFakeFileInfoSys() interface{} { return syscall.Win32FileAttributeData{} } ================================================ FILE: io.go ================================================ package main import ( "fmt" "io" "unsafe" ) type 
BytesWriter struct { buf []byte pos int } func NewBytesWriter() *BytesWriter { return &BytesWriter{ buf: []byte{}, } } func castInt64ToInt(n int64) (int, error) { if unsafe.Sizeof(n) == unsafe.Sizeof(int(0)) { return int(n), nil } else { _n := int(n) if int64(_n) < n { return -1, fmt.Errorf("integer overflow detected when converting %#v to int", n) } return _n, nil } } func (bw *BytesWriter) Close() error { return nil } // Resize the buffer capacity so the new size is at least the value of newCap. func (bw *BytesWriter) grow(newCap int) { if cap(bw.buf) >= newCap { return } i := cap(bw.buf) if i < 2 { i = 2 } for i < newCap { i = i + i/2 if i < cap(bw.buf) { panic("allocation failure") } } newBuf := make([]byte, len(bw.buf), i) copy(newBuf, bw.buf) bw.buf = newBuf } func (bw *BytesWriter) Truncate(n int64) error { _n, err := castInt64ToInt(n) if err != nil { return err } bw.buf = bw.buf[0:_n] if bw.pos > _n { bw.pos = _n } return nil } func (bw *BytesWriter) Seek(offset int64, whence int) (int64, error) { _o, err := castInt64ToInt(offset) if err != nil { return -1, err } var newPos int switch whence { case 0: newPos = _o case 1: newPos = bw.pos + _o case 2: newPos = len(bw.buf) + _o } if newPos < len(bw.buf) { bw.grow(newPos) bw.buf = bw.buf[0:newPos] } bw.pos = newPos return int64(newPos), nil } func (bw *BytesWriter) Write(p []byte) (n int, err error) { bw.grow(bw.pos + len(p)) copy(bw.buf[bw.pos:bw.pos+len(p)], p) return len(p), nil } func (bw *BytesWriter) WriteAt(p []byte, offset int64) (n int, err error) { _o, err := castInt64ToInt(offset) if err != nil { return -1, err } req := _o + len(p) if req > len(bw.buf) { bw.grow(req) bw.buf = bw.buf[0:req] } copy(bw.buf[_o:req], p) return len(p), nil } func (bw *BytesWriter) Size() int64 { return int64(len(bw.buf)) } func (bw *BytesWriter) Bytes() []byte { return bw.buf } func IsEOF(e error) bool { return e == io.EOF || e == io.ErrUnexpectedEOF } func IsTimeout(e error) bool { t, ok := e.(interface{ Timeout() bool }) 
if ok { return t.Timeout() } return false } ================================================ FILE: logging.go ================================================ package main type DebugLogger interface { Debug(args ...interface{}) } type InfoLogger interface { Info(args ...interface{}) } type ErrorLogger interface { Error(args ...interface{}) } ================================================ FILE: main.go ================================================ package main import ( "bytes" "context" "flag" "fmt" "io/ioutil" "net" "os" "os/signal" "time" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) var ( configFile string bind string debug bool ) func init() { flag.StringVar(&configFile, "config", "s3-sftp-proxy.toml", "configuration file") flag.StringVar(&bind, "bind", "", "listen on addr:port") flag.BoolVar(&debug, "debug", false, "turn on debugging output") } func buildSSHServerConfig(buckets *S3Buckets, cfg *S3SFTPProxyConfig) (*ssh.ServerConfig, error) { pem, err := ioutil.ReadFile(cfg.HostKeyFile) if err != nil { return nil, errors.Wrapf(err, `failed to open "%s"`, cfg.HostKeyFile) } key, err := ssh.ParseRawPrivateKey(pem) if err != nil { return nil, errors.Wrapf(err, `failed to parse host key "%s"`, cfg.HostKeyFile) } c := &ssh.ServerConfig{ PasswordCallback: func(c ssh.ConnMetadata, passwd []byte) (*ssh.Permissions, error) { bucket, ok := buckets.UserToBucketMap[c.User()] if !ok { return nil, fmt.Errorf("unknown user: %s", c.User()) } u := bucket.Users.Lookup(c.User()) if u.Password != "" && u.Password == string(passwd) { return nil, nil } return nil, fmt.Errorf("passwords do not match") }, PublicKeyCallback: func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { bucket, ok := buckets.UserToBucketMap[c.User()] if !ok { return nil, fmt.Errorf("unknown user: %s", c.User()) } u := bucket.Users.Lookup(c.User()) if u.PublicKeys != nil { keyMarshaled := key.Marshal() for _, herKey := range u.PublicKeys { if herKey.Type() 
== key.Type() && len(herKey.Marshal()) == len(keyMarshaled) && bytes.Compare(herKey.Marshal(), keyMarshaled) == 0 { return &ssh.Permissions{ Extensions: map[string]string{ "pubkey-fp": ssh.FingerprintSHA256(key), }, }, nil } } } return nil, fmt.Errorf("public keys do not match") }, KeyboardInteractiveCallback: func(c ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { bucket, ok := buckets.UserToBucketMap[c.User()] if !ok { return nil, fmt.Errorf("unknown user: %s", c.User()) } if !bucket.KeyboardInteractiveAuthEnabled { return nil, fmt.Errorf("keyboard interactive authentication not enabled") } u := bucket.Users.Lookup(c.User()) if u.Password == "" { return nil, fmt.Errorf("no credentials are present") } answers, err := client(u.Name, "", []string{"Password: "}, []bool{false}) if err != nil { return nil, errors.Wrapf(err, "keyboard interactive conversation failed") } if answers[0] != u.Password { return nil, fmt.Errorf("passwords do not match") } return nil, nil }, BannerCallback: func(c ssh.ConnMetadata) string { return cfg.Banner }, } sgn, err := ssh.NewSignerFromKey(key) if err != nil { return nil, err } c.AddHostKey(sgn) return c, nil } func bail(msg string, status ...interface{}) { os.Stderr.Write([]byte(msg + "\n")) statusCode := 1 if len(status) > 0 { var ok bool statusCode, ok = status[0].(int) if !ok { panic("invalid argument for bail()") } } os.Exit(statusCode) } func main() { flag.Parse() cfg, err := ReadConfigFromFile(configFile) if err != nil { bail(err.Error()) } uStores, err := NewUserStoresFromConfig(cfg) if err != nil { bail(err.Error()) } buckets, err := NewS3BucketFromConfig(uStores, cfg) if err != nil { bail(err.Error()) } sCfg, err := buildSSHServerConfig(buckets, cfg) if err != nil { bail(err.Error()) } _bind := bind if _bind == "" { _bind = cfg.Bind if _bind == "" { _bind = ":10022" } } logger := logrus.New() if debug { logger.SetLevel(logrus.DebugLevel) } lsnr, err := net.Listen("tcp", _bind) if err != nil 
{ bail(err.Error()) } defer lsnr.Close() logger.Info("Listen on ", _bind) ctx, cancel := context.WithCancel(context.Background()) defer cancel() sigChan := make(chan os.Signal) signal.Notify(sigChan, os.Interrupt) errChan := make(chan error) go func() { errChan <- (&Server{ S3Buckets: buckets, ServerConfig: sCfg, Log: logger, ReaderLookbackBufferSize: *cfg.ReaderLookbackBufferSize, ReaderMinChunkSize: *cfg.ReaderMinChunkSize, ListerLookbackBufferSize: *cfg.ListerLookbackBufferSize, PhantomObjectMap: NewPhantomObjectMap(), Now: time.Now, }).RunListenerEventLoop(ctx, lsnr.(*net.TCPListener)) }() outer: for { select { case err = <-errChan: if err != nil { bail(err.Error()) } break outer case <-sigChan: cancel() } } } ================================================ FILE: merged_context.go ================================================ package main import ( "context" "reflect" "time" ) type mergedContext struct { ctxs []context.Context doneChan chan struct{} err error } func (ctxs *mergedContext) Deadline() (time.Time, bool) { retval := time.Time{} retvalAvail := false for _, ctx := range ctxs.ctxs { if dl, ok := ctx.Deadline(); ok { if !retval.IsZero() || retval.After(dl) { retval = dl retvalAvail = true } } } return retval, retvalAvail } func (ctxs *mergedContext) Done() <-chan struct{} { return ctxs.doneChan } func (ctxs *mergedContext) Err() error { return ctxs.err } func (ctxs *mergedContext) Value(key interface{}) interface{} { for _, ctx := range ctxs.ctxs { v := ctx.Value(key) if v != nil { return v } } return nil } func (ctxs *mergedContext) watcher() { if len(ctxs.ctxs) == 2 { go func() { select { case <-ctxs.ctxs[0].Done(): ctxs.err = ctxs.ctxs[0].Err() case <-ctxs.ctxs[1].Done(): ctxs.err = ctxs.ctxs[0].Err() } close(ctxs.doneChan) }() } else { cases := []reflect.SelectCase{} for _, ctx := range ctxs.ctxs { cases = append(cases, reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done()), }) } go func() { chosen, _, _ := 
reflect.Select(cases) ctxs.err = ctxs.ctxs[chosen].Err() close(ctxs.doneChan) }() } } func combineContext(ctxs ...context.Context) context.Context { if len(ctxs) == 1 { return ctxs[0] } ctx := &mergedContext{ ctxs: ctxs, doneChan: make(chan struct{}), err: nil, } ctx.watcher() return ctx } ================================================ FILE: path.go ================================================ package main import "strings" type Path []string func splitIntoPathInner(p Path, path string, state int) Path { s := 0 i := 0 c := 0 for c >= 0 { if i < len(path) { c = int(path[i]) } else { c = -1 } switch state { case 0: if c == '/' { i++ } else { state = 1 s = i } case 1: if c == '/' || c < 0 { p = append(p, path[s:i]) state = 0 } else { i++ } } } return p } func SplitIntoPathAsAbs(path string) Path { if path == "" { return Path{} } return splitIntoPathInner(Path{""}, path, 0) } func SplitIntoPath(path string) Path { if path == "" { return Path{} } return splitIntoPathInner(Path{}, path, 1) } func (p Path) Canonicalize() Path { retval := make(Path, 0, len(p)) for _, c := range p { switch c { case ".": continue case "..": if len(retval) > 0 && retval[len(retval)-1] != "" { retval = retval[:len(retval)-1] } default: retval = append(retval, c) } } return retval } func (p Path) IsEmpty() bool { return len(p) == 0 } func (p Path) IsRoot() bool { return len(p) == 1 && p[0] == "" } func (p Path) IsAbs() bool { return len(p) > 0 && p[0] == "" } func (p Path) Join(another Path) Path { if len(another) > 0 && another[0] == "" { return append(p, another[1:]...) } else { return append(p, another...) 
} } func (p Path) String() string { return strings.Join(p, "/") } func (p Path) IsPrefixed(another Path) bool { if len(p) < len(another) { return false } for i, c := range another { if p[i] != c { return false } } return true } func (p Path) Prefix() Path { if len(p) == 0 { return p } else if len(p) == 1 { if p[0] == "" { return Path{""} } else { return Path{} } } else { return p[:len(p)-1] } } func (p Path) BasePart() Path { if len(p) == 0 { return p } else if len(p) == 1 { if p[0] == "" { return Path{""} } else { return Path{} } } else { return p[len(p)-1:] } } func (p Path) Base() string { if len(p) == 0 { return "" } else if len(p) == 1 { if p[0] == "" { return "/" } else { return "" } } else { return p[len(p)-1] } } func (p Path) Equal(p2 Path) bool { if len(p) != len(p2) { return false } for i := 0; i < len(p); i++ { if p[i] != p2[i] { return false } } return true } ================================================ FILE: path_test.go ================================================ package main import ( "github.com/stretchr/testify/assert" "testing" ) func TestSplitIntoPath(t *testing.T) { assert.Equal(t, Path{}, SplitIntoPath("")) assert.Equal(t, Path{"abc"}, SplitIntoPath("abc")) assert.Equal(t, Path{"abc", "bcd"}, SplitIntoPath("abc/bcd")) assert.Equal(t, Path{"abc", "bcd"}, SplitIntoPath("abc/bcd/")) assert.Equal(t, Path{"abc", "bcd"}, SplitIntoPath("abc/bcd///")) assert.Equal(t, Path{"abc", "bcd", "cde"}, SplitIntoPath("abc/bcd//cde")) assert.Equal(t, Path{""}, SplitIntoPath("/")) assert.Equal(t, Path{""}, SplitIntoPath("//")) assert.Equal(t, Path{"", "abc", "bcd"}, SplitIntoPath("//abc//bcd")) } func TestSplitIntoPathAbsolute(t *testing.T) { assert.Equal(t, Path{}, SplitIntoPathAsAbs("")) assert.Equal(t, Path{"", "abc"}, SplitIntoPathAsAbs("abc")) assert.Equal(t, Path{"", "abc", "bcd"}, SplitIntoPathAsAbs("abc/bcd")) assert.Equal(t, Path{"", "abc", "bcd"}, SplitIntoPathAsAbs("abc/bcd/")) assert.Equal(t, Path{"", "abc", "bcd"}, 
SplitIntoPathAsAbs("abc/bcd///")) assert.Equal(t, Path{"", "abc", "bcd", "cde"}, SplitIntoPathAsAbs("abc/bcd//cde")) assert.Equal(t, Path{""}, SplitIntoPathAsAbs("/")) assert.Equal(t, Path{""}, SplitIntoPathAsAbs("//")) assert.Equal(t, Path{"", "abc", "bcd"}, SplitIntoPathAsAbs("//abc//bcd")) } ================================================ FILE: phantom_object_map.go ================================================ package main import ( "sync" "time" ) type PhantomObjectInfo struct { Key Path LastModified time.Time Size int64 Opaque interface{} Mtx sync.Mutex } func (info *PhantomObjectInfo) GetOne() PhantomObjectInfo { info.Mtx.Lock() defer info.Mtx.Unlock() return *info } func (info *PhantomObjectInfo) setKey(v Path) { info.Mtx.Lock() defer info.Mtx.Unlock() info.Key = v } func (info *PhantomObjectInfo) SetLastModified(v time.Time) { info.Mtx.Lock() defer info.Mtx.Unlock() info.LastModified = v } func (info *PhantomObjectInfo) SetSize(v int64) { info.Mtx.Lock() defer info.Mtx.Unlock() info.Size = v } type phantomObjectInfoMap map[string]*PhantomObjectInfo type PhantomObjectMap struct { perPrefixObjects map[string]phantomObjectInfoMap ptrToPOIMMapMap map[*PhantomObjectInfo]phantomObjectInfoMap mtx sync.Mutex } func (pom *PhantomObjectMap) add(info *PhantomObjectInfo) bool { prefix := info.Key.Prefix().String() m := pom.perPrefixObjects[prefix] if m == nil { m = phantomObjectInfoMap{} pom.perPrefixObjects[prefix] = m } prevInfo := m[info.Key.Base()] m[info.Key.Base()] = info pom.ptrToPOIMMapMap[info] = m if prevInfo != nil { delete(pom.ptrToPOIMMapMap, prevInfo) } return prevInfo == nil } func (pom *PhantomObjectMap) Add(info *PhantomObjectInfo) bool { pom.mtx.Lock() defer pom.mtx.Unlock() return pom.add(info) } func (pom *PhantomObjectMap) remove(key Path) *PhantomObjectInfo { prefix := key.Prefix().String() m := pom.perPrefixObjects[prefix] if m == nil { return nil } info := m[key.Base()] if info == nil { return nil } delete(m, key.Base()) if len(m) == 0 { 
delete(pom.perPrefixObjects, prefix) } delete(pom.ptrToPOIMMapMap, info) return info } func (pom *PhantomObjectMap) Remove(key Path) *PhantomObjectInfo { pom.mtx.Lock() defer pom.mtx.Unlock() return pom.remove(key) } func (pom *PhantomObjectMap) removeByInfoPtr(info *PhantomObjectInfo) bool { m := pom.ptrToPOIMMapMap[info] if m == nil { return false } delete(m, info.Key.Base()) if len(m) == 0 { delete(pom.perPrefixObjects, info.Key.Prefix().String()) } delete(pom.ptrToPOIMMapMap, info) return true } func (pom *PhantomObjectMap) RemoveByInfoPtr(info *PhantomObjectInfo) bool { pom.mtx.Lock() defer pom.mtx.Unlock() return pom.removeByInfoPtr(info) } func (pom *PhantomObjectMap) rename(old, new Path) bool { info := pom.remove(old) if info == nil { return false } info.setKey(new) pom.add(info) return true } func (pom *PhantomObjectMap) Rename(old, new Path) bool { pom.mtx.Lock() defer pom.mtx.Unlock() return pom.rename(old, new) } func (pom *PhantomObjectMap) get(p Path) *PhantomObjectInfo { m := pom.perPrefixObjects[p.Prefix().String()] if m == nil { return nil } return m[p.Base()] } func (pom *PhantomObjectMap) Get(p Path) *PhantomObjectInfo { pom.mtx.Lock() defer pom.mtx.Unlock() return pom.get(p) } func (pom *PhantomObjectMap) List(p Path) []*PhantomObjectInfo { pom.mtx.Lock() defer pom.mtx.Unlock() m := pom.perPrefixObjects[p.String()] retval := make([]*PhantomObjectInfo, 0, len(m)) for _, info := range m { retval = append(retval, info) } return retval } func (pom *PhantomObjectMap) Size() int { pom.mtx.Lock() defer pom.mtx.Unlock() return len(pom.ptrToPOIMMapMap) } func NewPhantomObjectMap() *PhantomObjectMap { return &PhantomObjectMap{ perPrefixObjects: map[string]phantomObjectInfoMap{}, ptrToPOIMMapMap: map[*PhantomObjectInfo]phantomObjectInfoMap{}, } } ================================================ FILE: phantom_object_map_test.go ================================================ package main import ( "github.com/stretchr/testify/assert" "testing" ) func 
TestPhantomObjectMapAdd(t *testing.T) { pom := NewPhantomObjectMap() assert.Equal(t, true, pom.Add(&PhantomObjectInfo{Key: Path{"", "a", "b"}})) assert.Equal(t, 1, pom.Size()) assert.Equal(t, false, pom.Add(&PhantomObjectInfo{Key: Path{"", "a", "b"}})) assert.Equal(t, 1, pom.Size()) assert.Equal(t, true, pom.Add(&PhantomObjectInfo{Key: Path{"", "a", "c"}})) assert.Equal(t, 2, pom.Size()) assert.Equal(t, true, pom.Add(&PhantomObjectInfo{Key: Path{"", "a", "b", "c"}})) assert.Equal(t, 3, pom.Size()) } func TestPhantomObjectMapRemove(t *testing.T) { pom := NewPhantomObjectMap() o1 := &PhantomObjectInfo{Key: Path{"", "a", "b"}} o2 := &PhantomObjectInfo{Key: Path{"", "a", "b"}} o3 := &PhantomObjectInfo{Key: Path{"", "a", "c"}} o4 := &PhantomObjectInfo{Key: Path{"", "a", "b", "c"}} assert.Equal(t, true, pom.Add(o1)) assert.Equal(t, 1, pom.Size()) assert.Equal(t, false, pom.Add(o2)) assert.Equal(t, 1, pom.Size()) assert.Equal(t, true, pom.Add(o3)) assert.Equal(t, 2, pom.Size()) assert.Equal(t, true, pom.Add(o4)) assert.Equal(t, 3, pom.Size()) assert.Equal(t, o3, pom.Remove(Path{"", "a", "c"})) assert.Nil(t, pom.Get(Path{"", "a", "c"})) assert.Equal(t, 2, pom.Size()) assert.Nil(t, pom.Remove(Path{"", "a", "c"})) assert.Equal(t, 2, pom.Size()) assert.Equal(t, o2, pom.Remove(Path{"", "a", "b"})) assert.Equal(t, 1, pom.Size()) assert.Nil(t, pom.Get(Path{"", "a", "b"})) } ================================================ FILE: s3-sftp-proxy.example.toml ================================================ host_key_file = "./host_key" [buckets.test] # endpoint = "http://endpoint" # s3_force_path_style = false # disable_ssl = false bucket_url = "s3://BUCKET/PREFIX" # bucket = BUCKET # key_prefix = PREFIX profile = "xxx" region = "ap-northeast-1" auth = "test" # [buckets.test.credentials] # aws_access_key_id = "aaa" # aws_secret_access_key = "bbb" [auth.test] type = "inplace" [auth.test.users.user01] password = "test" public_keys = """ ... 
""" [auth.test.users.user02] password = "test" public_keys = """ ... """ ================================================ FILE: server.go ================================================ package main import ( "context" "fmt" "io" "net" "sync" "time" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) type Server struct { *ssh.ServerConfig *S3Buckets *PhantomObjectMap ReaderLookbackBufferSize int ReaderMinChunkSize int ListerLookbackBufferSize int Log interface { DebugLogger InfoLogger ErrorLogger } Now func() time.Time } func asHandlers(handlers interface { sftp.FileReader sftp.FileWriter sftp.FileCmder sftp.FileLister }) sftp.Handlers { return sftp.Handlers{handlers, handlers, handlers, handlers} } func (s *Server) HandleChannel(ctx context.Context, bucket *S3Bucket, sshCh ssh.Channel, reqs <-chan *ssh.Request) { defer s.Log.Debug("HandleChannel ended") server := sftp.NewRequestServer( sshCh, asHandlers( &S3BucketIO{ Ctx: ctx, Bucket: bucket, ReaderLookbackBufferSize: s.ReaderLookbackBufferSize, ReaderMinChunkSize: s.ReaderMinChunkSize, ListerLookbackBufferSize: s.ListerLookbackBufferSize, Log: s.Log, PhantomObjectMap: s.PhantomObjectMap, Perms: bucket.Perms, ServerSideEncryption: &bucket.ServerSideEncryption, Now: s.Now, }, ), ) innerCtx, cancel := context.WithCancel(ctx) defer cancel() wg := sync.WaitGroup{} wg.Add(1) go func() { defer s.Log.Debug("HandleChannel.discardRequest ended") defer wg.Done() defer cancel() outer: for { select { case <-innerCtx.Done(): break outer case req := <-reqs: if req == nil { break outer } ok := false if req.Type == "subsystem" && string(req.Payload[4:]) == "sftp" { ok = true } req.Reply(ok, nil) } } }() wg.Add(1) go func() { defer s.Log.Debug("HandleChannel.serve ended") defer wg.Done() defer cancel() go func() { <-innerCtx.Done() server.Close() }() if err := server.Serve(); err != io.EOF { s.Log.Error(err.Error()) } }() wg.Wait() } func (s *Server) HandleClient(ctx context.Context, conn *net.TCPConn) error { defer 
s.Log.Debug("HandleClient ended") defer func() { F(s.Log.Info, "connection from client %s closed", conn.RemoteAddr().String()) conn.Close() }() F(s.Log.Info, "connected from client %s", conn.RemoteAddr().String()) innerCtx, cancel := context.WithCancel(ctx) defer cancel() go func() { <-innerCtx.Done() conn.SetDeadline(time.Unix(1, 0)) }() // Before use, a handshake must be performed on the incoming net.Conn. sconn, chans, reqs, err := ssh.NewServerConn(conn, s.ServerConfig) if err != nil { return err } F(s.Log.Info, "user %s logged in", sconn.User()) bucket, ok := s.UserToBucketMap[sconn.User()] if !ok { return fmt.Errorf("unknown error: no bucket designated to user %s found", sconn.User()) } wg := sync.WaitGroup{} wg.Add(1) go func(reqs <-chan *ssh.Request) { defer wg.Done() defer s.Log.Debug("HandleClient.requestHandler ended") for _ = range reqs { } }(reqs) wg.Add(1) go func(chans <-chan ssh.NewChannel) { defer wg.Done() defer cancel() defer s.Log.Debug("HandleClient.channelHandler ended") for newSSHCh := range chans { if newSSHCh.ChannelType() != "session" { newSSHCh.Reject(ssh.UnknownChannelType, "unknown channel type") F(s.Log.Info, "unknown channel type: %s", newSSHCh.ChannelType()) continue } F(s.Log.Info, "channel: %s", newSSHCh.ChannelType()) sshCh, reqs, err := newSSHCh.Accept() if err != nil { F(s.Log.Error, "could not accept channel: %s", err.Error()) break } wg.Add(1) go func() { defer wg.Done() s.HandleChannel(innerCtx, bucket, sshCh, reqs) }() } }(chans) wg.Wait() return nil } func (s *Server) RunListenerEventLoop(ctx context.Context, lsnr *net.TCPListener) error { defer s.Log.Debug("RunListenerEventLoop ended") wg := sync.WaitGroup{} connChan := make(chan *net.TCPConn) var err error wg.Add(1) go func() { defer s.Log.Debug("RunListenerEventLoop.connHandler ended") defer wg.Done() defer close(connChan) outer: for { var conn *net.TCPConn conn, err = lsnr.AcceptTCP() if err != nil { return } select { case <-ctx.Done(): conn.Close() break outer case 
connChan <- conn: } } }() outer: for { select { case conn := <-connChan: wg.Add(1) go func() { defer wg.Done() err := s.HandleClient(ctx, conn) if err != nil { s.Log.Error(err.Error()) } }() case <-ctx.Done(): lsnr.SetDeadline(time.Unix(1, 0)) break outer } } // drain for _ = range connChan { } wg.Wait() if IsTimeout(err) { err = nil } return err } ================================================ FILE: user.go ================================================ package main import ( "fmt" "github.com/pkg/errors" "golang.org/x/crypto/ssh" "io/ioutil" ) type User struct { Name string Password string PublicKeys []ssh.PublicKey } type UserStore struct { Name string Users []*User usersMap map[string]*User } type UserStores map[string]UserStore func (us *UserStore) Add(u *User) { us.Users = append(us.Users, u) us.usersMap[u.Name] = u } func (us *UserStore) Lookup(name string) *User { u, _ := us.usersMap[name] return u } func parseAuthorizedKeys(pubKeys []ssh.PublicKey, pubKeyFileContent []byte) ([]ssh.PublicKey, error) { for len(pubKeyFileContent) > 0 { var pubKey ssh.PublicKey var err error pubKey, _, _, pubKeyFileContent, err = ssh.ParseAuthorizedKey(pubKeyFileContent) if err != nil { return pubKeys, err } pubKeys = append(pubKeys, pubKey) } return pubKeys, nil } func buildUsersFromAuthConfigInplace(users []*User, aCfg *AuthConfig) ([]*User, error) { for name, params := range aCfg.Users { var pubKeys []ssh.PublicKey if params.PublicKeys != "" { var err error pubKeys, err = parseAuthorizedKeys(pubKeys, []byte(params.PublicKeys)) if err != nil { return users, errors.Wrapf(err, `user "%s"`, name) } } if params.PublicKeyFile != "" { var err error pubKeysFileContent, err := ioutil.ReadFile(params.PublicKeyFile) if err != nil { return users, errors.Wrapf(err, `user "%s"`, name) } pubKeys, err = parseAuthorizedKeys(pubKeys, pubKeysFileContent) if err != nil { return users, errors.Wrapf(err, `user "%s"`, name) } } users = append(users, &User{ Name: name, Password: 
params.Password, PublicKeys: pubKeys, }) } return users, nil } func buildUsersFromAuthConfig(users []*User, aCfg *AuthConfig) ([]*User, error) { switch aCfg.Type { case "inplace": return buildUsersFromAuthConfigInplace(users, aCfg) default: return users, fmt.Errorf("unknown auth config type: %s", aCfg.Type) } } func NewUserStoresFromConfig(cfg *S3SFTPProxyConfig) (UserStores, error) { uStores := UserStores{} for name, aCfg := range cfg.AuthConfigs { var err error var users []*User users, err = buildUsersFromAuthConfig(users, aCfg) if err != nil { return nil, err } usersMap := map[string]*User{} for _, u := range users { usersMap[u.Name] = u } uStores[name] = UserStore{Name: name, Users: users, usersMap: usersMap} } return uStores, nil } ================================================ FILE: utils.go ================================================ package main import "fmt" type PrintlnLike func(...interface{}) func F(p PrintlnLike, f string, args ...interface{}) { p(fmt.Sprintf(f, args...)) }