Update aws-sdk-go library to latest release

From version 1.0.5 to version 1.15.74.
Simply because the old version was really old.
This commit is contained in:
Bryan Boreham
2018-11-13 15:08:30 +00:00
parent 730d5e3b62
commit e49c2dcb84
2482 changed files with 1295782 additions and 334246 deletions

View File

@@ -1,6 +0,0 @@
source 'https://rubygems.org'
gem 'yard', git: 'git://github.com/lsegal/yard'
gem 'yard-go', git: 'git://github.com/lsegal/yard-go'
gem 'rdiscount'

View File

@@ -1,98 +0,0 @@
LINTIGNOREDOT='awstesting/integration.+should not use dot imports'
LINTIGNOREDOC='service/[^/]+/(api|service|waiters)\.go:.+(comment on exported|should have comment or be unexported)'
LINTIGNORECONST='service/[^/]+/(api|service|waiters)\.go:.+(type|struct field|const|func) ([^ ]+) should be ([^ ]+)'
LINTIGNORESTUTTER='service/[^/]+/(api|service)\.go:.+(and that stutters)'
LINTIGNOREINFLECT='service/[^/]+/(api|service)\.go:.+method .+ should be '
LINTIGNOREDEPS='vendor/.+\.go'
SDK_WITH_VENDOR_PKGS=$(shell go list ./... | grep -v "/vendor/src")
SDK_ONLY_PKGS=$(shell go list ./... | grep -v "/vendor/")
all: generate unit
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " api_info to print a list of services and versions"
@echo " docs to build SDK documentation"
@echo " build to go build the SDK"
@echo " unit to run unit tests"
@echo " integration to run integration tests"
@echo " verify to verify tests"
@echo " lint to lint the SDK"
@echo " vet to vet the SDK"
@echo " generate to go generate and make services"
@echo " gen-test to generate protocol tests"
@echo " gen-services to generate services"
@echo " get-deps to go get the SDK dependencies"
@echo " get-deps-unit to get the SDK's unit test dependencies"
@echo " get-deps-integ to get the SDK's integration test dependencies"
@echo " get-deps-verify to get the SDK's verification dependencies"
generate: gen-test gen-endpoints gen-services
gen-test: gen-protocol-test
gen-services:
go generate ./service
gen-protocol-test:
go generate ./private/protocol/...
gen-endpoints:
go generate ./private/endpoints
build:
@echo "go build SDK and vendor packages"
@go build $(SDK_WITH_VENDOR_PKGS)
unit: get-deps-unit build verify
@echo "go test SDK and vendor packages"
@go test $(SDK_WITH_VENDOR_PKGS)
integration: get-deps-integ
go test -tags=integration ./awstesting/integration/customizations/...
gucumber ./awstesting/integration/smoke
verify: get-deps-verify lint vet
lint:
@echo "go lint SDK and vendor packages"
@lint=`golint ./...`; \
lint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOT} -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS}`; \
echo "$$lint"; \
if [ "$$lint" != "" ]; then exit 1; fi
vet:
go tool vet -all -shadow $(shell ls -d */ | grep -v vendor)
get-deps: get-deps-unit get-deps-integ get-deps-verify
@echo "go get SDK dependencies"
@go get -v $(SDK_ONLY_PKGS)
get-deps-unit:
@echo "go get SDK unit testing dependancies"
go get github.com/stretchr/testify
go get github.com/smartystreets/goconvey
get-deps-integ: get-deps-unit
@echo "go get SDK integration testing dependencies"
go get github.com/lsegal/gucumber/cmd/gucumber
get-deps-verify:
@echo "go get SDK verification utilities"
go get github.com/golang/lint/golint
bench:
@echo "go bench SDK packages"
@go test -run NONE -bench . -benchmem -tags 'bench' $(SDK_ONLY_PKGS)
bench-protocol:
@echo "go bench SDK protocol marshallers"
@go test -run NONE -bench . -benchmem -tags 'bench' ./private/protocol/...
docs:
@echo "generate SDK docs"
rm -rf doc && bundle install && bundle exec yard
api_info:
@go run private/model/cli/api-info/api-info.go

View File

@@ -1,3 +0,0 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

View File

@@ -1,93 +0,0 @@
# AWS SDK for Go
[![API Reference](http://img.shields.io/badge/api-reference-blue.svg)](http://docs.aws.amazon.com/sdk-for-go/api)
[![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go)
[![Apache V2 License](http://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
aws-sdk-go is the official AWS SDK for the Go programming language.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for information about the latest bug fixes, updates, and features added to the SDK.
## Installing
If you are using Go 1.5 with the `GO15VENDOREXPERIMENT=1` vendoring flag you can use the following to get the SDK as the SDK's runtime dependancies are vendored in the `vendor` folder.
$ go get -u github.com/aws/aws-sdk-go
Otherwise you'll need to tell Go to get the SDK and all of its dependancies.
$ go get -u github.com/aws/aws-sdk-go/...
## Configuring Credentials
Before using the SDK, ensure that you've configured credentials. The best
way to configure credentials on a development machine is to use the
`~/.aws/credentials` file, which might look like:
```
[default]
aws_access_key_id = AKID1234567890
aws_secret_access_key = MY-SECRET-KEY
```
You can learn more about the credentials file from this
[blog post](http://blogs.aws.amazon.com/security/post/Tx3D6U6WSFGOK2H/A-New-and-Standardized-Way-to-Manage-Credentials-in-the-AWS-SDKs).
Alternatively, you can set the following environment variables:
```
AWS_ACCESS_KEY_ID=AKID1234567890
AWS_SECRET_ACCESS_KEY=MY-SECRET-KEY
```
## Using the Go SDK
To use a service in the SDK, create a service variable by calling the `New()`
function. Once you have a service client, you can call API operations which each
return response data and a possible error.
To list a set of instance IDs from EC2, you could run:
```go
package main
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
)
func main() {
// Create an EC2 service object in the "us-west-2" region
// Note that you can also configure your region globally by
// exporting the AWS_REGION environment variable
svc := ec2.New(session.New(), &aws.Config{Region: aws.String("us-west-2")})
// Call the DescribeInstances Operation
resp, err := svc.DescribeInstances(nil)
if err != nil {
panic(err)
}
// resp has all of the response data, pull out instance IDs:
fmt.Println("> Number of reservation sets: ", len(resp.Reservations))
for idx, res := range resp.Reservations {
fmt.Println(" > Number of instances: ", len(res.Instances))
for _, inst := range resp.Reservations[idx].Instances {
fmt.Println(" - Instance ID: ", *inst.InstanceId)
}
}
}
```
You can find more information and operations in our
[API documentation](http://docs.aws.amazon.com/sdk-for-go/api/).
## License
This SDK is distributed under the
[Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0),
see LICENSE.txt and NOTICE.txt for more information.

86
vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go generated vendored Normal file
View File

@@ -0,0 +1,86 @@
// Package arn provides a parser for interacting with Amazon Resource Names.
package arn
import (
"errors"
"strings"
)
const (
arnDelimiter = ":"
arnSections = 6
arnPrefix = "arn:"
// zero-indexed
sectionPartition = 1
sectionService = 2
sectionRegion = 3
sectionAccountID = 4
sectionResource = 5
// errors
invalidPrefix = "arn: invalid prefix"
invalidSections = "arn: not enough sections"
)
// ARN captures the individual fields of an Amazon Resource Name.
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
type ARN struct {
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
// (Beijing) region is "aws-cn".
Partition string
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
// namespaces, see
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
Service string
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
// component might be omitted.
Region string
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
// ARNs for some resources don't require an account number, so this component might be omitted.
AccountID string
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
// resource name itself. Some services allows paths for resource names, as described in
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
Resource string
}
// Parse parses an ARN into its constituent parts.
//
// Some example ARNs:
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
// arn:aws:iam::123456789012:user/David
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
func Parse(arn string) (ARN, error) {
if !strings.HasPrefix(arn, arnPrefix) {
return ARN{}, errors.New(invalidPrefix)
}
sections := strings.SplitN(arn, arnDelimiter, arnSections)
if len(sections) != arnSections {
return ARN{}, errors.New(invalidSections)
}
return ARN{
Partition: sections[sectionPartition],
Service: sections[sectionService],
Region: sections[sectionRegion],
AccountID: sections[sectionAccountID],
Resource: sections[sectionResource],
}, nil
}
// String returns the canonical representation of the ARN
func (arn ARN) String() string {
return arnPrefix +
arn.Partition + arnDelimiter +
arn.Service + arnDelimiter +
arn.Region + arnDelimiter +
arn.AccountID + arnDelimiter +
arn.Resource
}

90
vendor/github.com/aws/aws-sdk-go/aws/arn/arn_test.go generated vendored Normal file
View File

@@ -0,0 +1,90 @@
// +build go1.7
package arn
import (
"errors"
"testing"
)
func TestParseARN(t *testing.T) {
cases := []struct {
input string
arn ARN
err error
}{
{
input: "invalid",
err: errors.New(invalidPrefix),
},
{
input: "arn:nope",
err: errors.New(invalidSections),
},
{
input: "arn:aws:ecr:us-west-2:123456789012:repository/foo/bar",
arn: ARN{
Partition: "aws",
Service: "ecr",
Region: "us-west-2",
AccountID: "123456789012",
Resource: "repository/foo/bar",
},
},
{
input: "arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment",
arn: ARN{
Partition: "aws",
Service: "elasticbeanstalk",
Region: "us-east-1",
AccountID: "123456789012",
Resource: "environment/My App/MyEnvironment",
},
},
{
input: "arn:aws:iam::123456789012:user/David",
arn: ARN{
Partition: "aws",
Service: "iam",
Region: "",
AccountID: "123456789012",
Resource: "user/David",
},
},
{
input: "arn:aws:rds:eu-west-1:123456789012:db:mysql-db",
arn: ARN{
Partition: "aws",
Service: "rds",
Region: "eu-west-1",
AccountID: "123456789012",
Resource: "db:mysql-db",
},
},
{
input: "arn:aws:s3:::my_corporate_bucket/exampleobject.png",
arn: ARN{
Partition: "aws",
Service: "s3",
Region: "",
AccountID: "",
Resource: "my_corporate_bucket/exampleobject.png",
},
},
}
for _, tc := range cases {
t.Run(tc.input, func(t *testing.T) {
spec, err := Parse(tc.input)
if tc.arn != spec {
t.Errorf("Expected %q to parse as %v, but got %v", tc.input, tc.arn, spec)
}
if err == nil && tc.err != nil {
t.Errorf("Expected err to be %v, but got nil", tc.err)
} else if err != nil && tc.err == nil {
t.Errorf("Expected err to be nil, but got %v", err)
} else if err != nil && tc.err != nil && err.Error() != tc.err.Error() {
t.Errorf("Expected err to be %v, but got %v", tc.err, err)
}
})
}
}

View File

@@ -14,13 +14,13 @@ package awserr
// if err != nil {
// if awsErr, ok := err.(awserr.Error); ok {
// // Get error details
// log.Println("Error:", err.Code(), err.Message())
// log.Println("Error:", awsErr.Code(), awsErr.Message())
//
// // Prints out full error message, including original error if there was one.
// log.Println("Error:", err.Error())
// log.Println("Error:", awsErr.Error())
//
// // Get original error
// if origErr := err.Err(); origErr != nil {
// if origErr := awsErr.OrigErr(); origErr != nil {
// // operate on original error.
// }
// } else {
@@ -42,15 +42,55 @@ type Error interface {
OrigErr() error
}
// BatchError is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Deprecated: Replaced with BatchedErrors. Only defined for backwards
// compatibility.
type BatchError interface {
// Satisfy the generic error interface.
error
// Returns the short phrase depicting the classification of the error.
Code() string
// Returns the error details message.
Message() string
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// BatchedErrors is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Replaces BatchError
type BatchedErrors interface {
// Satisfy the base Error interface.
Error
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// New returns an Error object described by the code, message, and origErr.
//
// If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned.
func New(code, message string, origErr error) Error {
if e, ok := origErr.(Error); ok && e != nil {
return e
var errs []error
if origErr != nil {
errs = append(errs, origErr)
}
return newBaseError(code, message, origErr)
return newBaseError(code, message, errs)
}
// NewBatchError returns an BatchedErrors with a collection of errors as an
// array of errors.
func NewBatchError(code, message string, errs []error) BatchedErrors {
return newBaseError(code, message, errs)
}
// A RequestFailure is an interface to extract request failure information from
@@ -63,9 +103,9 @@ func New(code, message string, origErr error) Error {
// output, err := s3manage.Upload(svc, input, opts)
// if err != nil {
// if reqerr, ok := err.(RequestFailure); ok {
// log.Printf("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
// log.Println("Request failed", reqerr.Code(), reqerr.Message(), reqerr.RequestID())
// } else {
// log.Printf("Error:", err.Error()
// log.Println("Error:", err.Error())
// }
// }
//

View File

@@ -31,23 +31,27 @@ type baseError struct {
// Optional original error this error is based off of. Allows building
// chained errors.
origErr error
errs []error
}
// newBaseError returns an error object for the code, message, and err.
// newBaseError returns an error object for the code, message, and errors.
//
// code is a short no whitespace phrase depicting the classification of
// the error that is being created.
//
// message is the free flow string containing detailed information about the error.
// message is the free flow string containing detailed information about the
// error.
//
// origErr is the error object which will be nested under the new error to be returned.
func newBaseError(code, message string, origErr error) *baseError {
return &baseError{
// origErrs is the error objects which will be nested under the new errors to
// be returned.
func newBaseError(code, message string, origErrs []error) *baseError {
b := &baseError{
code: code,
message: message,
origErr: origErr,
errs: origErrs,
}
return b
}
// Error returns the string representation of the error.
@@ -56,7 +60,12 @@ func newBaseError(code, message string, origErr error) *baseError {
//
// Satisfies the error interface.
func (b baseError) Error() string {
return SprintError(b.code, b.message, "", b.origErr)
size := len(b.errs)
if size > 0 {
return SprintError(b.code, b.message, "", errorList(b.errs))
}
return SprintError(b.code, b.message, "", nil)
}
// String returns the string representation of the error.
@@ -75,10 +84,28 @@ func (b baseError) Message() string {
return b.message
}
// OrigErr returns the original error if one was set. Nil is returned if no error
// was set.
// OrigErr returns the original error if one was set. Nil is returned if no
// error was set. This only returns the first element in the list. If the full
// list is needed, use BatchedErrors.
func (b baseError) OrigErr() error {
return b.origErr
switch len(b.errs) {
case 0:
return nil
case 1:
return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
"multiple errors occurred", b.errs)
}
}
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (b baseError) OrigErrs() []error {
return b.errs
}
// So that the Error interface type can be included as an anonymous field
@@ -94,8 +121,8 @@ type requestError struct {
requestID string
}
// newRequestError returns a wrapped error with additional information for request
// status code, and service requestID.
// newRequestError returns a wrapped error with additional information for
// request status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
@@ -133,3 +160,35 @@ func (r requestError) StatusCode() int {
func (r requestError) RequestID() string {
return r.requestID
}
// OrigErrs returns the original errors if one was set. An empty slice is
// returned if no error was set.
func (r requestError) OrigErrs() []error {
if b, ok := r.awsError.(BatchedErrors); ok {
return b.OrigErrs()
}
return []error{r.OrigErr()}
}
// An error list that satisfies the golang interface
type errorList []error
// Error returns the string representation of the error.
//
// Satisfies the error interface.
func (e errorList) Error() string {
msg := ""
// How do we want to handle the array size being zero
if size := len(e); size > 0 {
for i := 0; i < size; i++ {
msg += fmt.Sprintf("%s", e[i].Error())
// We check the next index to see if it is within the slice.
// If it is, then we append a newline. We do this, because unit tests
// could be broken with the additional '\n'
if i+1 < size {
msg += "\n"
}
}
}
return msg
}

View File

@@ -3,6 +3,7 @@ package awsutil
import (
"io"
"reflect"
"time"
)
// Copy deeply copies a src structure to dst. Useful for copying request and
@@ -49,7 +50,14 @@ func rcopy(dst, src reflect.Value, root bool) {
} else {
e := src.Type().Elem()
if dst.CanSet() && !src.IsNil() {
dst.Set(reflect.New(e))
if _, ok := src.Interface().(*time.Time); !ok {
dst.Set(reflect.New(e))
} else {
tempValue := reflect.New(e)
tempValue.Elem().Set(src.Elem())
// Sets time.Time's unexported values
dst.Set(tempValue)
}
}
if src.Elem().IsValid() {
// Keep the current root state since the depth hasn't changed

View File

@@ -5,10 +5,11 @@ import (
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
@@ -36,11 +37,19 @@ func ExampleCopy() {
// }
}
func TestCopy(t *testing.T) {
func TestCopy1(t *testing.T) {
type Bar struct {
a *int
B *int
c int
D int
}
type Foo struct {
A int
B []*string
C map[string]*int
D *time.Time
E *Bar
}
// Create the initial value
@@ -48,6 +57,9 @@ func TestCopy(t *testing.T) {
str2 := "bye bye"
int1 := 1
int2 := 2
intPtr1 := 1
intPtr2 := 2
now := time.Now()
f1 := &Foo{
A: 1,
B: []*string{&str1, &str2},
@@ -55,6 +67,13 @@ func TestCopy(t *testing.T) {
"A": &int1,
"B": &int2,
},
D: &now,
E: &Bar{
&intPtr1,
&intPtr2,
2,
3,
},
}
// Do the copy
@@ -62,19 +81,60 @@ func TestCopy(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.D, f1.D; !v1.Equal(*v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.B, f1.E.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.D, f1.E.D; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// But pointers are not!
str3 := "nothello"
int3 := 57
f2.A = 100
f2.B[0] = &str3
f2.C["B"] = &int3
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
*f2.B[0] = str3
*f2.C["B"] = int3
*f2.D = time.Now()
f2.E.a = &int3
*f2.E.B = int3
f2.E.c = 5
f2.E.D = 5
if v1, v2 := f2.A, f1.A; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B, f1.B; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.C, f1.C; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.D, f1.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.a, f1.E.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.B, f1.E.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.c, f1.E.c; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.D, f1.E.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
}
func TestCopyNestedWithUnexported(t *testing.T) {
@@ -93,10 +153,18 @@ func TestCopyNestedWithUnexported(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values match
assert.Equal(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.B.a, f1.B.a)
assert.Equal(t, f2.B.B, f2.B.B)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.a, f1.B.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.B, f2.B.B; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyIgnoreNilMembers(t *testing.T) {
@@ -107,34 +175,56 @@ func TestCopyIgnoreNilMembers(t *testing.T) {
}
f := &Foo{}
assert.Nil(t, f.A)
assert.Nil(t, f.B)
assert.Nil(t, f.C)
if v1 := f.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
assert.Nil(t, f2.B)
assert.Nil(t, f2.C)
if v1 := f2.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
assert.Nil(t, f3.B)
assert.Nil(t, f3.C)
if v1 := f3.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
if v1, v2 := "hello", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
if v1, v2 := "", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyReader(t *testing.T) {
@@ -142,13 +232,21 @@ func TestCopyReader(t *testing.T) {
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte("hello world"), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte(""), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyDifferentStructs(t *testing.T) {
@@ -194,17 +292,39 @@ func TestCopyDifferentStructs(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
assert.Equal(t, int1, *f1.unexportedPtr)
assert.Nil(t, f2.unexportedPtr)
assert.Equal(t, int2, *f1.ExportedPtr)
assert.Equal(t, int2, *f2.ExportedPtr)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "unique", f1.SrcUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 1, f1.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 0, f2.DstUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "", f2.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int1, *f1.unexportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1 := f2.unexportedPtr; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1, v2 := int2, *f1.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int2, *f2.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func ExampleCopyOf() {

View File

@@ -5,7 +5,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
@@ -24,6 +23,8 @@ func TestDeepEqual(t *testing.T) {
}
for i, c := range cases {
assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
if awsutil.DeepEqual(c.a, c.b) != c.equal {
t.Errorf("%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
}
}
}

View File

@@ -106,8 +106,8 @@ func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTer
if indexStar || index != nil {
nextvals = []reflect.Value{}
for _, value := range values {
value := reflect.Indirect(value)
for _, valItem := range values {
value := reflect.Indirect(valItem)
if value.Kind() != reflect.Slice {
continue
}

View File

@@ -1,10 +1,10 @@
package awsutil_test
import (
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
@@ -50,8 +50,12 @@ func TestValueAtPathSuccess(t *testing.T) {
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
}
}
@@ -78,12 +82,18 @@ func TestValueAtPathFailure(t *testing.T) {
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if c.errContains != "" {
assert.Contains(t, err.Error(), c.errContains, "case %d, expected error, %s", i, c.path)
if !strings.Contains(err.Error(), c.errContains) {
t.Errorf("case %v, expected error, %v", i, c.path)
}
continue
} else {
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
@@ -92,51 +102,81 @@ func TestSetValueAtPathSuccess(t *testing.T) {
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
if e, a := "test1", s.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test2", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test3", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
if e, a := "test0", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test0", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s2 Struct
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
if e, a := "test0", s2.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
if e, a := []Struct{{}}, s2.A; !awsutil.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
str := "foo"
s3 := Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", str)
assert.Equal(t, "foo", s3.B.B.C)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
assert.Equal(t, "foo", s3.B.B.C)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s4 struct{ Name *string }
awsutil.SetValueAtPath(&s4, "Name", str)
assert.Equal(t, str, *s4.Name)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{Name: &str}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", &str)
assert.Equal(t, str, *s4.Name)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
}

View File

@@ -61,6 +61,12 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
case reflect.Slice:
strtype := v.Type().String()
if strtype == "[]uint8" {
fmt.Fprintf(buf, "<binary> len %d", v.Len())
break
}
nl, id, id2 := "", "", ""
if v.Len() > 3 {
nl, id, id2 = "\n", strings.Repeat(" ", indent), strings.Repeat(" ", indent+2)
@@ -91,6 +97,10 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")
default:
if !v.IsValid() {
fmt.Fprint(buf, "<invalid value>")
return
}
format := "%v"
switch v.Interface().(type) {
case string:

View File

@@ -2,8 +2,6 @@ package client
import (
"fmt"
"io/ioutil"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
@@ -12,9 +10,17 @@ import (
// A Config provides configuration to a service client instance.
type Config struct {
Config *aws.Config
Handlers request.Handlers
Endpoint, SigningRegion string
Config *aws.Config
Handlers request.Handlers
Endpoint string
SigningRegion string
SigningName string
// States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors
// to determine if the signin name can be overriden based on metadata the
// service has.
SigningNameDerived bool
}
// ConfigProvider provides a generic way for a service client to receive
@@ -23,6 +29,13 @@ type ConfigProvider interface {
ClientConfig(serviceName string, cfgs ...*aws.Config) Config
}
// ConfigNoResolveEndpointProvider same as ConfigProvider except it will not
// resolve the endpoint automatically. The service client's endpoint must be
// provided via the aws.Config.Endpoint field.
type ConfigNoResolveEndpointProvider interface {
ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) Config
}
// A Client implements the base client request and response handling
// used by all service clients.
type Client struct {
@@ -38,7 +51,7 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
svc := &Client{
Config: cfg,
ClientInfo: info,
Handlers: handlers,
Handlers: handlers.Copy(),
}
switch retryer, ok := cfg.Retryer.(request.Retryer); {
@@ -78,43 +91,6 @@ func (c *Client) AddDebugHandlers() {
return
}
c.Handlers.Send.PushFront(logRequest)
c.Handlers.Send.PushBack(logResponse)
}
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if logBody {
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.Body.Seek(r.BodyStart, 0)
r.HTTPRequest.Body = ioutil.NopCloser(r.Body)
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
func logResponse(r *request.Request) {
var msg = "no reponse data"
if r.HTTPResponse != nil {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, _ := httputil.DumpResponse(r.HTTPResponse, logBody)
msg = string(dumpedBody)
} else if r.Error != nil {
msg = r.Error.Error()
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg, r.ClientInfo.ServiceName, r.Operation.Name, msg))
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
}

View File

@@ -0,0 +1,78 @@
package client
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
)
func pushBackTestHandler(name string, list *request.HandlerList) *bool {
called := false
(*list).PushBackNamed(request.NamedHandler{
Name: name,
Fn: func(r *request.Request) {
called = true
},
})
return &called
}
func pushFrontTestHandler(name string, list *request.HandlerList) *bool {
called := false
(*list).PushFrontNamed(request.NamedHandler{
Name: name,
Fn: func(r *request.Request) {
called = true
},
})
return &called
}
func TestNewClient_CopyHandlers(t *testing.T) {
handlers := request.Handlers{}
firstCalled := pushBackTestHandler("first", &handlers.Send)
secondCalled := pushBackTestHandler("second", &handlers.Send)
var clientHandlerCalled *bool
c := New(aws.Config{}, metadata.ClientInfo{}, handlers,
func(c *Client) {
clientHandlerCalled = pushFrontTestHandler("client handler", &c.Handlers.Send)
},
)
if e, a := 2, handlers.Send.Len(); e != a {
t.Errorf("expect %d original handlers, got %d", e, a)
}
if e, a := 3, c.Handlers.Send.Len(); e != a {
t.Errorf("expect %d client handlers, got %d", e, a)
}
handlers.Send.Run(nil)
if !*firstCalled {
t.Errorf("expect first handler to of been called")
}
*firstCalled = false
if !*secondCalled {
t.Errorf("expect second handler to of been called")
}
*secondCalled = false
if *clientHandlerCalled {
t.Errorf("expect client handler to not of been called, but was")
}
c.Handlers.Send.Run(nil)
if !*firstCalled {
t.Errorf("expect client's first handler to of been called")
}
if !*secondCalled {
t.Errorf("expect client's second handler to of been called")
}
if !*clientHandlerCalled {
t.Errorf("expect client's client handler to of been called")
}
}

View File

@@ -1,11 +1,11 @@
package client
import (
"math"
"math/rand"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
@@ -15,11 +15,11 @@ import (
// the MaxRetries method:
//
// type retryer struct {
// service.DefaultRetryer
// client.DefaultRetryer
// }
//
// // This implementation always has 100 max retries
// func (d retryer) MaxRetries() uint { return 100 }
// func (d retryer) MaxRetries() int { return 100 }
type DefaultRetryer struct {
NumMaxRetries int
}
@@ -32,14 +32,85 @@ func (d DefaultRetryer) MaxRetries() int {
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
delay := int(math.Pow(2, float64(r.RetryCount))) * (rand.Intn(30) + 30)
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
if delay, ok := getRetryDelay(r); ok {
return delay
}
minTime = 500
}
retryCount := r.RetryCount
if throttle && retryCount > 8 {
retryCount = 8
} else if retryCount > 13 {
retryCount = 13
}
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond
}
// ShouldRetry returns if the request should be retried.
// ShouldRetry returns true if the request should be retried.
func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
if r.HTTPResponse.StatusCode >= 500 {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable != nil {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable()
return r.IsErrorRetryable() || d.shouldThrottle(r)
}
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return true
}
// This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) {
return 0, false
}
delayStr := r.HTTPResponse.Header.Get("Retry-After")
if len(delayStr) == 0 {
return 0, false
}
delay, err := strconv.Atoi(delayStr)
if err != nil {
return 0, false
}
return time.Duration(delay) * time.Second, true
}
// Will look at the status code to see if the retry header pertains to
// the status code.
func canUseRetryAfterHeader(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 503:
default:
return false
}
return true
}

View File

@@ -0,0 +1,189 @@
package client
import (
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestRetryThrottleStatusCodes(t *testing.T) {
cases := []struct {
expectThrottle bool
expectRetry bool
r request.Request
}{
{
false,
false,
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 502},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 504},
},
},
{
false,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
},
}
d := DefaultRetryer{NumMaxRetries: 10}
for i, c := range cases {
throttle := d.shouldThrottle(&c.r)
retry := d.ShouldRetry(&c.r)
if e, a := c.expectThrottle, throttle; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
if e, a := c.expectRetry, retry; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
}
}
func TestCanUseRetryAfter(t *testing.T) {
cases := []struct {
r request.Request
e bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
true,
},
}
for i, c := range cases {
a := canUseRetryAfterHeader(&c.r)
if c.e != a {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestGetRetryDelay(t *testing.T) {
cases := []struct {
r request.Request
e time.Duration
equal bool
ok bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429, Header: http.Header{"Retry-After": []string{"3600"}}},
},
3600 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
120 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
1 * time.Second,
false,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}},
},
0 * time.Second,
true,
false,
},
}
for i, c := range cases {
a, ok := getRetryDelay(&c.r)
if c.ok != ok {
t.Errorf("%d: expected %v, but received %v", i, c.ok, ok)
}
if (c.e != a) == c.equal {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestRetryDelay(t *testing.T) {
r := request.Request{}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.HTTPResponse = &http.Response{StatusCode: 500, Header: http.Header{"Retry-After": []string{""}}}
rTemp.RetryCount = i
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.RetryCount = i
rTemp.HTTPResponse = &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}}
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
}

184
vendor/github.com/aws/aws-sdk-go/aws/client/logger.go generated vendored Normal file
View File

@@ -0,0 +1,184 @@
package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http/httputil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
const logReqMsg = `DEBUG: Request %s/%s Details:
---[ REQUEST POST-SIGN ]-----------------------------
%s
-----------------------------------------------------`
const logReqErrMsg = `DEBUG ERROR: Request %s/%s:
---[ REQUEST DUMP ERROR ]-----------------------------
%s
------------------------------------------------------`
type logWriter struct {
// Logger is what we will use to log the payload of a response.
Logger aws.Logger
// buf stores the contents of what has been read
buf *bytes.Buffer
}
func (logger *logWriter) Write(b []byte) (int, error) {
return logger.buf.Write(b)
}
type teeReaderCloser struct {
// io.Reader will be a tee reader that is used during logging.
// This structure will read from a body and write the contents to a logger.
io.Reader
// Source is used just to close when we are done reading.
Source io.ReadCloser
}
func (reader *teeReaderCloser) Close() error {
return reader.Source.Close()
}
// LogHTTPRequestHandler is a SDK request handler to log the HTTP request sent
// to a service. Will include the HTTP request body if the LogLevel of the
// request matches LogDebugWithHTTPBody.
var LogHTTPRequestHandler = request.NamedHandler{
Name: "awssdk.client.LogRequest",
Fn: logRequest,
}
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
bodySeekable := aws.IsReaderSeekable(r.Body)
b, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
if logBody {
if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
}
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
// LogHTTPRequestHeaderHandler is a SDK request handler to log the HTTP request sent
// to a service. Will only log the HTTP request's headers. The request payload
// will not be read.
var LogHTTPRequestHeaderHandler = request.NamedHandler{
Name: "awssdk.client.LogRequestHeader",
Fn: logRequestHeader,
}
func logRequestHeader(r *request.Request) {
b, err := httputil.DumpRequestOut(r.HTTPRequest, false)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
---[ RESPONSE ]--------------------------------------
%s
-----------------------------------------------------`
const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
---[ RESPONSE DUMP ERROR ]-----------------------------
%s
-----------------------------------------------------`
// LogHTTPResponseHandler is a SDK request handler to log the HTTP response
// received from a service. Will include the HTTP response body if the LogLevel
// of the request matches LogDebugWithHTTPBody.
var LogHTTPResponseHandler = request.NamedHandler{
Name: "awssdk.client.LogResponse",
Fn: logResponse,
}
func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody {
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
}
}
handlerFn := func(req *request.Request) {
b, err := httputil.DumpResponse(req.HTTPResponse, false)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(fmt.Sprintf(logRespMsg,
req.ClientInfo.ServiceName, req.Operation.Name, string(b)))
if logBody {
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(string(b))
}
}
const handlerName = "awsdk.client.LogResponse.ResponseBody"
r.Handlers.Unmarshal.SetBackNamed(request.NamedHandler{
Name: handlerName, Fn: handlerFn,
})
r.Handlers.UnmarshalError.SetBackNamed(request.NamedHandler{
Name: handlerName, Fn: handlerFn,
})
}
// LogHTTPResponseHeaderHandler is a SDK request handler to log the HTTP
// response received from a service. Will only log the HTTP response's headers.
// The response payload will not be read.
var LogHTTPResponseHeaderHandler = request.NamedHandler{
Name: "awssdk.client.LogResponseHeader",
Fn: logResponseHeader,
}
func logResponseHeader(r *request.Request) {
if r.Config.Logger == nil {
return
}
b, err := httputil.DumpResponse(r.HTTPResponse, false)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}

View File

@@ -0,0 +1,222 @@
package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
type mockCloser struct {
closed bool
}
func (closer *mockCloser) Read(b []byte) (int, error) {
return 0, io.EOF
}
func (closer *mockCloser) Close() error {
closer.closed = true
return nil
}
func TestTeeReaderCloser(t *testing.T) {
expected := "FOO"
buf := bytes.NewBuffer([]byte(expected))
lw := bytes.NewBuffer(nil)
c := &mockCloser{}
closer := teeReaderCloser{
io.TeeReader(buf, lw),
c,
}
b := make([]byte, len(expected))
_, err := closer.Read(b)
closer.Close()
if expected != lw.String() {
t.Errorf("Expected %q, but received %q", expected, lw.String())
}
if err != nil {
t.Errorf("Expected 'nil', but received %v", err)
}
if !c.closed {
t.Error("Expected 'true', but received 'false'")
}
}
func TestLogWriter(t *testing.T) {
expected := "FOO"
lw := &logWriter{nil, bytes.NewBuffer(nil)}
lw.Write([]byte(expected))
if expected != lw.buf.String() {
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
}
}
func TestLogRequest(t *testing.T) {
cases := []struct {
Body io.ReadSeeker
ExpectBody []byte
LogLevel aws.LogLevelType
}{
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
ExpectBody: []byte("body content"),
},
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
logW := bytes.NewBuffer(nil)
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.SetReaderBody(c.Body)
req.Build()
logRequest(req)
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
func TestLogResponse(t *testing.T) {
cases := []struct {
Body *bytes.Buffer
ExpectBody []byte
ReadBody bool
LogLevel aws.LogLevelType
}{
{
Body: bytes.NewBuffer([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebug,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ReadBody: true,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
var logW bytes.Buffer
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: &logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.HTTPResponse = &http.Response{
StatusCode: 200,
Status: "OK",
Header: http.Header{
"ABC": []string{"123"},
},
Body: ioutil.NopCloser(c.Body),
}
logResponse(req)
req.Handlers.Unmarshal.Run(req)
if c.ReadBody {
if e, a := len(c.ExpectBody), c.Body.Len(); e != a {
t.Errorf("%d, expect orginal body not to of been read", i)
}
}
if logW.Len() == 0 {
t.Errorf("%d, expect HTTP Response headers to be logged", i)
}
b, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !bytes.Equal(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
type bufLogger struct {
w *bytes.Buffer
}
func (l *bufLogger) Log(args ...interface{}) {
fmt.Fprintln(l.w, args...)
}
func testHandlers() request.Handlers {
var handlers request.Handlers
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
return handlers
}

View File

@@ -3,6 +3,7 @@ package metadata
// ClientInfo wraps immutable data from the client.Client structure.
type ClientInfo struct {
ServiceName string
ServiceID string
APIVersion string
Endpoint string
SigningName string

View File

@@ -5,21 +5,39 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
)
// UseServiceDefaultRetries instructs the config to use the service's own default
// number of retries. This will be the default action if Config.MaxRetries
// is nil also.
// UseServiceDefaultRetries instructs the config to use the service's own
// default number of retries. This will be the default action if
// Config.MaxRetries is nil also.
const UseServiceDefaultRetries = -1
// RequestRetryer is an alias for a type that implements the request.Retryer interface.
// RequestRetryer is an alias for a type that implements the request.Retryer
// interface.
type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the {defaults.DefaultConfig} structure.
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// }))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
type Config struct {
// The credentials object to use when signing requests. Defaults to
// a chain of credential providers to search for credentials in environment
// Enables verbose error printing of all credential chain errors.
// Should be used when wanting to see all errors while attempting to
// retrieve credentials.
CredentialsChainVerboseErrors *bool
// The credentials object to use when signing requests. Defaults to a
// chain of credential providers to search for credentials in environment
// variables, shared credential file, and EC2 Instance Roles.
Credentials *credentials.Credentials
@@ -27,17 +45,28 @@ type Config struct {
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
// Note: You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The resolver to use for looking up endpoints for AWS service clients
// to use based on region.
EndpointResolver endpoints.Resolver
// EnforceShouldRetryCheck is used in the AfterRetryHandler to always call
// ShouldRetry regardless of whether or not if request.Retryable is set.
// This will utilize ShouldRetry method of custom retryers. If EnforceShouldRetryCheck
// is not set, then ShouldRetry will only be called if request.Retryable is nil.
// Proper handling of the request.Retryable field is important when setting this field.
EnforceShouldRetryCheck *bool
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
// Regions and Endpoints.
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
@@ -58,14 +87,15 @@ type Config struct {
Logger Logger
// The maximum number of times that a request will be retried for failures.
// Defaults to -1, which defers the max retry setting to the service specific
// configuration.
// Defaults to -1, which defers the max retry setting to the service
// specific configuration.
MaxRetries *int
// Retryer guides how HTTP requests should be retried in case of recoverable failures.
// Retryer guides how HTTP requests should be retried in case of
// recoverable failures.
//
// When nil or the value does not implement the request.Retryer interface,
// the request.DefaultRetryer will be used.
// the client.DefaultRetryer will be used.
//
// When both Retryer and MaxRetries are non-nil, the former is used and
// the latter ignored.
@@ -77,8 +107,8 @@ type Config struct {
//
Retryer RequestRetryer
// Disables semantic parameter validation, which validates input for missing
// required fields and/or other semantic request input errors.
// Disables semantic parameter validation, which validates input for
// missing required fields and/or other semantic request input errors.
DisableParamValidation *bool
// Disables the computation of request and response checksums, e.g.,
@@ -86,27 +116,155 @@ type Config struct {
DisableComputeChecksums *bool
// Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will
// use virtual hosted bucket addressing when possible
// i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client
// will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
// Note: This configuration option is specific to the Amazon S3 service.
//
// See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// for Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
// header to PUT requests over 2MB of content. 100-Continue instructs the
// HTTP client not to send the body until the service responds with a
// `continue` status. This is useful to prevent sending the request body
// until after the request is authenticated, and validated.
//
// http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectPUT.html
//
// 100-Continue is only enabled for Go 1.6 and above. See `http.Transport`'s
// `ExpectContinueTimeout` for information on adjusting the continue wait
// timeout. https://golang.org/pkg/net/http/#Transport
//
// You should use this flag to disble 100-Continue if you experience issues
// with proxies or third party S3 compatible services.
S3Disable100Continue *bool
// Set this to `true` to enable S3 Accelerate feature. For all operations
// compatible with S3 Accelerate will use the accelerate endpoint for
// requests. Requests not compatible will fall back to normal S3 requests.
//
// The bucket must be enable for accelerate to be used with S3 client with
// accelerate enabled. If the bucket is not enabled for accelerate an error
// will be returned. The bucket name must be DNS compatible to also work
// with accelerate.
S3UseAccelerate *bool
// S3DisableContentMD5Validation config option is temporarily disabled,
// For S3 GetObject API calls, #1837.
//
// Set this to `true` to disable the S3 service client from automatically
// adding the ContentMD5 to S3 Object Put and Upload API calls. This option
// will also disable the SDK from performing object ContentMD5 validation
// on GetObject API calls.
S3DisableContentMD5Validation *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only
// meaningful if you're not already using a custom HTTP client with the
// SDK. Enabled by default.
//
// Must be set and provided to the session.NewSession() in order to disable
// the EC2Metadata overriding the timeout for default credentials chain.
//
// Example:
// sess := session.Must(session.NewSession(aws.NewConfig()
// .WithEC2MetadataDiableTimeoutOverride(true)))
//
// svc := s3.New(sess)
//
EC2MetadataDisableTimeoutOverride *bool
// Instructs the endpoint to be generated for a service client to
// be the dual stack endpoint. The dual stack endpoint will support
// both IPv4 and IPv6 addressing.
//
// Setting this for a service which does not support dual stack will fail
// to make requets. It is not recommended to set this value on the session
// as it will apply to all service clients created with the session. Even
// services which don't support dual stack endpoints.
//
// If the Endpoint config value is also provided the UseDualStack flag
// will be ignored.
//
// Only supported with.
//
// sess := session.Must(session.NewSession())
//
// svc := s3.New(sess, &aws.Config{
// UseDualStack: aws.Bool(true),
// })
UseDualStack *bool
// SleepDelay is an override for the func the SDK will call when sleeping
// during the lifecycle of a request. Specifically this will be used for
// request delays. This value should only be used for testing. To adjust
// the delay of a request see the aws/client.DefaultRetryer and
// aws/request.Retryer.
//
// SleepDelay will prevent any Context from being used for canceling retry
// delay of an API operation. It is recommended to not use SleepDelay at all
// and specify a Retryer instead.
SleepDelay func(time.Duration)
// DisableRestProtocolURICleaning will not clean the URL path when making rest protocol requests.
// Will default to false. This would only be used for empty directory names in s3 requests.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// DisableRestProtocolURICleaning: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// EnableEndpointDiscovery: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("/foo/bar/moo"),
// })
EnableEndpointDiscovery *bool
}
// NewConfig returns a new Config pointer that can be chained with builder methods to
// set multiple configuration values inline without using pointers.
// NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers.
//
// svc := s3.New(aws.NewConfig().WithRegion("us-west-2").WithMaxRetries(10))
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// ))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
func NewConfig() *Config {
return &Config{}
}
// WithCredentialsChainVerboseErrors sets a config verbose errors boolean and returning
// a Config pointer.
func (c *Config) WithCredentialsChainVerboseErrors(verboseErrs bool) *Config {
c.CredentialsChainVerboseErrors = &verboseErrs
return c
}
// WithCredentials sets a config Credentials value returning a Config pointer
// for chaining.
func (c *Config) WithCredentials(creds *credentials.Credentials) *Config {
@@ -121,6 +279,13 @@ func (c *Config) WithEndpoint(endpoint string) *Config {
return c
}
// WithEndpointResolver sets a config EndpointResolver value returning a
// Config pointer for chaining.
func (c *Config) WithEndpointResolver(resolver endpoints.Resolver) *Config {
c.EndpointResolver = resolver
return c
}
// WithRegion sets a config Region value returning a Config pointer for
// chaining.
func (c *Config) WithRegion(region string) *Config {
@@ -184,6 +349,43 @@ func (c *Config) WithS3ForcePathStyle(force bool) *Config {
return c
}
// WithS3Disable100Continue sets a config S3Disable100Continue value returning
// a Config pointer for chaining.
func (c *Config) WithS3Disable100Continue(disable bool) *Config {
c.S3Disable100Continue = &disable
return c
}
// WithS3UseAccelerate sets a config S3UseAccelerate value returning a Config
// pointer for chaining.
func (c *Config) WithS3UseAccelerate(enable bool) *Config {
c.S3UseAccelerate = &enable
return c
}
// WithS3DisableContentMD5Validation sets a config
// S3DisableContentMD5Validation value returning a Config pointer for chaining.
func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
c.S3DisableContentMD5Validation = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config
// pointer for chaining.
func (c *Config) WithUseDualStack(enable bool) *Config {
c.UseDualStack = &enable
return c
}
// WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value
// returning a Config pointer for chaining.
func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config {
c.EC2MetadataDisableTimeoutOverride = &enable
return c
}
// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
@@ -191,6 +393,12 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
return c
}
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
func (c *Config) WithEndpointDiscovery(t bool) *Config {
c.EnableEndpointDiscovery = &t
return c
}
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
@@ -203,6 +411,10 @@ func mergeInConfig(dst *Config, other *Config) {
return
}
if other.CredentialsChainVerboseErrors != nil {
dst.CredentialsChainVerboseErrors = other.CredentialsChainVerboseErrors
}
if other.Credentials != nil {
dst.Credentials = other.Credentials
}
@@ -211,6 +423,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.Endpoint = other.Endpoint
}
if other.EndpointResolver != nil {
dst.EndpointResolver = other.EndpointResolver
}
if other.Region != nil {
dst.Region = other.Region
}
@@ -251,9 +467,41 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3ForcePathStyle = other.S3ForcePathStyle
}
if other.S3Disable100Continue != nil {
dst.S3Disable100Continue = other.S3Disable100Continue
}
if other.S3UseAccelerate != nil {
dst.S3UseAccelerate = other.S3UseAccelerate
}
if other.S3DisableContentMD5Validation != nil {
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
}
if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack
}
if other.EC2MetadataDisableTimeoutOverride != nil {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
}
if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
if other.DisableRestProtocolURICleaning != nil {
dst.DisableRestProtocolURICleaning = other.DisableRestProtocolURICleaning
}
if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
}
if other.EnableEndpointDiscovery != nil {
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
}
}
// Copy will return a shallow copy of the Config object. If any additional

71
vendor/github.com/aws/aws-sdk-go/aws/context.go generated vendored Normal file
View File

@@ -0,0 +1,71 @@
package aws
import (
"time"
)
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
// API methods with Go v1.6 and a Context type such as golang.org/x/net/context.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
Done() <-chan struct{}
// Err returns a non-nil error value after Done is closed. Err returns
// Canceled if the context was canceled or DeadlineExceeded if the
// context's deadline passed. No other values for Err are defined.
// After Done is closed, successive calls to Err return the same value.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
Value(key interface{}) interface{}
}
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

41
vendor/github.com/aws/aws-sdk-go/aws/context_1_6.go generated vendored Normal file
View File

@@ -0,0 +1,41 @@
// +build !go1.7
package aws
import "time"
// An emptyCtx is a copy of the Go 1.7 context.emptyCtx type. This is copied to
// provide a 1.6 and 1.5 safe version of context that is compatible with Go
// 1.7's Context.
//
// An emptyCtx is never canceled, has no values, and has no deadline. It is not
// struct{}, since vars of this type must have distinct addresses.
type emptyCtx int
func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (*emptyCtx) Done() <-chan struct{} {
return nil
}
func (*emptyCtx) Err() error {
return nil
}
func (*emptyCtx) Value(key interface{}) interface{} {
return nil
}
func (e *emptyCtx) String() string {
switch e {
case backgroundCtx:
return "aws.BackgroundContext"
}
return "unknown empty Context"
}
var (
backgroundCtx = new(emptyCtx)
)

9
vendor/github.com/aws/aws-sdk-go/aws/context_1_7.go generated vendored Normal file
View File

@@ -0,0 +1,9 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)

37
vendor/github.com/aws/aws-sdk-go/aws/context_test.go generated vendored Normal file
View File

@@ -0,0 +1,37 @@
package aws_test
import (
"fmt"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestSleepWithContext(t *testing.T) {
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
if err != nil {
t.Errorf("expect context to not be canceled, got %v", err)
}
}
func TestSleepWithContext_Canceled(t *testing.T) {
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
expectErr := fmt.Errorf("context canceled")
ctx.Error = expectErr
close(ctx.DoneCh)
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
if err == nil {
t.Fatalf("expect error, did not get one")
}
if e, a := expectErr, err; e != a {
t.Errorf("expect %v error, got %v", e, a)
}
}

View File

@@ -2,7 +2,7 @@ package aws
import "time"
// String returns a pointer to of the string value passed in.
// String returns a pointer to the string value passed in.
func String(v string) *string {
return &v
}
@@ -61,7 +61,7 @@ func StringValueMap(src map[string]*string) map[string]string {
return dst
}
// Bool returns a pointer to of the bool value passed in.
// Bool returns a pointer to the bool value passed in.
func Bool(v bool) *bool {
return &v
}
@@ -120,7 +120,7 @@ func BoolValueMap(src map[string]*bool) map[string]bool {
return dst
}
// Int returns a pointer to of the int value passed in.
// Int returns a pointer to the int value passed in.
func Int(v int) *int {
return &v
}
@@ -179,7 +179,7 @@ func IntValueMap(src map[string]*int) map[string]int {
return dst
}
// Int64 returns a pointer to of the int64 value passed in.
// Int64 returns a pointer to the int64 value passed in.
func Int64(v int64) *int64 {
return &v
}
@@ -238,7 +238,7 @@ func Int64ValueMap(src map[string]*int64) map[string]int64 {
return dst
}
// Float64 returns a pointer to of the float64 value passed in.
// Float64 returns a pointer to the float64 value passed in.
func Float64(v float64) *float64 {
return &v
}
@@ -297,7 +297,7 @@ func Float64ValueMap(src map[string]*float64) map[string]float64 {
return dst
}
// Time returns a pointer to of the time.Time value passed in.
// Time returns a pointer to the time.Time value passed in.
func Time(v time.Time) *time.Time {
return &v
}
@@ -311,6 +311,36 @@ func TimeValue(v *time.Time) time.Time {
return time.Time{}
}
// SecondsTimeValue converts an int64 pointer to a time.Time value
// representing seconds since Epoch or time.Time{} if the pointer is nil.
func SecondsTimeValue(v *int64) time.Time {
if v != nil {
return time.Unix((*v / 1000), 0)
}
return time.Time{}
}
// MillisecondsTimeValue converts an int64 pointer to a time.Time value
// representing milliseconds sinch Epoch or time.Time{} if the pointer is nil.
func MillisecondsTimeValue(v *int64) time.Time {
if v != nil {
return time.Unix(0, (*v * 1000000))
}
return time.Time{}
}
// TimeUnixMilli returns a Unix timestamp in milliseconds from "January 1, 1970 UTC".
// The result is undefined if the Unix time cannot be represented by an int64.
// Which includes calling TimeUnixMilli on a zero Time is undefined.
//
// This utility is useful for service API's such as CloudWatch Logs which require
// their unix time values to be in milliseconds.
//
// See Go stdlib https://golang.org/pkg/time/#Time.UnixNano for more information.
func TimeUnixMilli(t time.Time) int64 {
return t.UnixNano() / int64(time.Millisecond/time.Nanosecond)
}
// TimeSlice converts a slice of time.Time values into a slice of
// time.Time pointers
func TimeSlice(src []time.Time) []*time.Time {

View File

@@ -1,10 +1,9 @@
package aws
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var testCasesStringSlice = [][]string{
@@ -18,14 +17,22 @@ func TestStringSlice(t *testing.T) {
continue
}
out := StringSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := StringValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -39,22 +46,34 @@ func TestStringValueSlice(t *testing.T) {
continue
}
out := StringValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != "" {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := StringSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != "" {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := *in[i], *out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -70,14 +89,22 @@ func TestStringMap(t *testing.T) {
continue
}
out := StringMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := StringValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -91,14 +118,22 @@ func TestBoolSlice(t *testing.T) {
continue
}
out := BoolSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := BoolValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -110,22 +145,34 @@ func TestBoolValueSlice(t *testing.T) {
continue
}
out := BoolValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := BoolSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -141,14 +188,22 @@ func TestBoolMap(t *testing.T) {
continue
}
out := BoolMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := BoolValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -162,14 +217,22 @@ func TestIntSlice(t *testing.T) {
continue
}
out := IntSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := IntValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -181,22 +244,34 @@ func TestIntValueSlice(t *testing.T) {
continue
}
out := IntValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := IntSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -212,14 +287,22 @@ func TestIntMap(t *testing.T) {
continue
}
out := IntMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := IntValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -233,14 +316,22 @@ func TestInt64Slice(t *testing.T) {
continue
}
out := Int64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Int64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -252,22 +343,34 @@ func TestInt64ValueSlice(t *testing.T) {
continue
}
out := Int64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := Int64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -283,14 +386,22 @@ func TestInt64Map(t *testing.T) {
continue
}
out := Int64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Int64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -304,14 +415,22 @@ func TestFloat64Slice(t *testing.T) {
continue
}
out := Float64Slice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Float64ValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -323,22 +442,34 @@ func TestFloat64ValueSlice(t *testing.T) {
continue
}
out := Float64ValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if out[i] != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := Float64Slice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if *(out2[i]) != 0 {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -354,14 +485,22 @@ func TestFloat64Map(t *testing.T) {
continue
}
out := Float64Map(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := Float64ValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -375,14 +514,22 @@ func TestTimeSlice(t *testing.T) {
continue
}
out := TimeSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := TimeValueSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
@@ -394,22 +541,34 @@ func TestTimeValueSlice(t *testing.T) {
continue
}
out := TimeValueSlice(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
if in[i] == nil {
assert.Empty(t, out[i], "Unexpected value at idx %d", idx)
if !out[i].IsZero() {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, *(in[i]), out[i], "Unexpected value at idx %d", idx)
if e, a := *(in[i]), out[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
out2 := TimeSlice(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out2 {
if in[i] == nil {
assert.Empty(t, *(out2[i]), "Unexpected value at idx %d", idx)
if !(*(out2[i])).IsZero() {
t.Errorf("Unexpected value at idx %d", idx)
}
} else {
assert.Equal(t, in[i], out2[i], "Unexpected value at idx %d", idx)
if e, a := in[i], out2[i]; e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
}
@@ -425,13 +584,58 @@ func TestTimeMap(t *testing.T) {
continue
}
out := TimeMap(in)
assert.Len(t, out, len(in), "Unexpected len at idx %d", idx)
if e, a := len(out), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
for i := range out {
assert.Equal(t, in[i], *(out[i]), "Unexpected value at idx %d", idx)
if e, a := in[i], *(out[i]); e != a {
t.Errorf("Unexpected value at idx %d", idx)
}
}
out2 := TimeValueMap(out)
assert.Len(t, out2, len(in), "Unexpected len at idx %d", idx)
assert.Equal(t, in, out2, "Unexpected value at idx %d", idx)
if e, a := len(out2), len(in); e != a {
t.Errorf("Unexpected len at idx %d", idx)
}
if e, a := in, out2; !reflect.DeepEqual(e, a) {
t.Errorf("Unexpected value at idx %d", idx)
}
}
}
type TimeValueTestCase struct {
in int64
outSecs time.Time
outMillis time.Time
}
var testCasesTimeValue = []TimeValueTestCase{
{
in: int64(1501558289000),
outSecs: time.Unix(1501558289, 0),
outMillis: time.Unix(1501558289, 0),
},
{
in: int64(1501558289001),
outSecs: time.Unix(1501558289, 0),
outMillis: time.Unix(1501558289, 1*1000000),
},
}
func TestSecondsTimeValue(t *testing.T) {
for idx, testCase := range testCasesTimeValue {
out := SecondsTimeValue(&testCase.in)
if e, a := testCase.outSecs, out; e != a {
t.Errorf("Unexpected value for time value at %d", idx)
}
}
}
func TestMillisecondsTimeValue(t *testing.T) {
for idx, testCase := range testCasesTimeValue {
out := MillisecondsTimeValue(&testCase.in)
if e, a := testCase.outMillis, out; e != a {
t.Errorf("Unexpected value for time value at %d", idx)
}
}
}

View File

@@ -3,16 +3,16 @@ package corehandlers
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"runtime"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
@@ -24,76 +24,156 @@ type lener interface {
// BuildContentLengthHandler builds the content length of a request based on the body,
// or will use the HTTPRequest.Header's "Content-Length" if defined. If unable
// to determine request body length and no "Content-Length" was specified it will panic.
//
// The Content-Length will only be added to the request if the length of the body
// is greater than 0. If the body is empty or the current `Content-Length`
// header is <= 0, the header will also be stripped.
var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLengthHandler", Fn: func(r *request.Request) {
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ := strconv.ParseInt(slength, 10, 64)
r.HTTPRequest.ContentLength = length
return
}
var length int64
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
}
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
}}
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// SendHandler is a request handler to send service request using HTTP client.
var SendHandler = request.NamedHandler{Name: "core.SendHandler", Fn: func(r *request.Request) {
var err error
r.HTTPResponse, err = r.Config.HTTPClient.Do(r.HTTPRequest)
if err != nil {
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other url redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ = strconv.ParseInt(slength, 10, 64)
} else {
if r.Body != nil {
var err error
length, err = aws.SeekerLen(r.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to get request body's length", err)
return
}
}
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
}
if length > 0 {
r.HTTPRequest.ContentLength = length
r.HTTPRequest.Header.Set("Content-Length", fmt.Sprintf("%d", length))
} else {
r.HTTPRequest.ContentLength = 0
r.HTTPRequest.Header.Del("Content-Length")
}
}}
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// ValidateReqSigHandler is a request handler to ensure that the request's
// signature doesn't expire before it is sent. This can happen when a request
// is built and signed significantly before it is sent. Or significant delays
// occur when retrying requests that would cause the signature to expire.
var ValidateReqSigHandler = request.NamedHandler{
Name: "core.ValidateReqSigHandler",
Fn: func(r *request.Request) {
// Unsigned requests are not signed
if r.Config.Credentials == credentials.AnonymousCredentials {
return
}
signedTime := r.Time
if !r.LastSignedAt.IsZero() {
signedTime = r.LastSignedAt
}
// 10 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) {
return
}
fmt.Println("request expired, resigning")
r.Sign()
},
}
// SendHandler is a request handler to send service request using HTTP client.
var SendHandler = request.NamedHandler{
Name: "core.SendHandler",
Fn: func(r *request.Request) {
sender := sendFollowRedirects
if r.DisableFollowRedirects {
sender = sendWithoutFollowRedirects
}
if request.NoBody == r.HTTPRequest.Body {
// Strip off the request body if the NoBody reader was used as a
// place holder for a request body. This prevents the SDK from
// making requests with a request body when it would be invalid
// to do so.
//
// Use a shallow copy of the http.Request to ensure the race condition
// of transport on Body will not trigger
reqOrig, reqCopy := r.HTTPRequest, *r.HTTPRequest
reqCopy.Body = nil
r.HTTPRequest = &reqCopy
defer func() {
r.HTTPRequest = reqOrig
}()
}
var err error
r.HTTPResponse, err = sender(r)
if err != nil {
handleSendError(r, err)
}
},
}
func sendFollowRedirects(r *request.Request) (*http.Response, error) {
return r.Config.HTTPClient.Do(r.HTTPRequest)
}
func sendWithoutFollowRedirects(r *request.Request) (*http.Response, error) {
transport := r.Config.HTTPClient.Transport
if transport == nil {
transport = http.DefaultTransport
}
return transport.RoundTrip(r.HTTPRequest)
}
func handleSendError(r *request.Request, err error) {
// Prevent leaking if an HTTPResponse was returned. Clean up
// the body.
if r.HTTPResponse != nil {
r.HTTPResponse.Body.Close()
}
// Capture the case where url.Error is returned for error processing
// response. e.g. 301 without location header comes back as string
// error and r.HTTPResponse is nil. Other URL redirect errors will
// comeback in a similar method.
if e, ok := err.(*url.Error); ok && e.Err != nil {
if s := reStatusCode.FindStringSubmatch(e.Err.Error()); s != nil {
code, _ := strconv.ParseInt(s[1], 10, 64)
r.HTTPResponse = &http.Response{
StatusCode: int(code),
Status: http.StatusText(int(code)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
return
}
}
if r.HTTPResponse == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable = aws.Bool(true) // network errors are retryable
// Override the error with a context canceled error, if that was canceled.
ctx := r.Context()
select {
case <-ctx.Done():
r.Error = awserr.New(request.CanceledErrorCode,
"request context canceled", ctx.Err())
r.Retryable = aws.Bool(false)
default:
}
}
// ValidateResponseHandler is a request handler to validate service response.
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
@@ -107,13 +187,22 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH
var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) {
// If one of the other handlers already set the retry state
// we don't want to override it based on the service's state
if r.Retryable == nil {
if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) {
r.Retryable = aws.Bool(r.ShouldRetry(r))
}
if r.WillRetry() {
r.RetryDelay = r.RetryRules(r)
r.Config.SleepDelay(r.RetryDelay)
if sleepFn := r.Config.SleepDelay; sleepFn != nil {
// Support SleepDelay for backwards compatibility and testing
sleepFn(r.RetryDelay)
} else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil {
r.Error = awserr.New(request.CanceledErrorCode,
"request context canceled", err)
r.Retryable = aws.Bool(false)
return
}
// when the expired token exception occurs the credentials
// need to be expired locally so that the next request to

View File

@@ -0,0 +1,64 @@
// +build go1.10
package corehandlers_test
import (
"crypto/tls"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/service/s3"
"golang.org/x/net/http2"
)
func TestSendHandler_HEADNoBody(t *testing.T) {
TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err := awstesting.CreateTLSBundleFiles()
if err != nil {
panic(err)
}
defer awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile)
endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
transport := http.DefaultTransport.(*http.Transport)
// test server's certificate is self-signed certificate
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
http2.ConfigureTransport(transport)
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{
HTTPClient: &http.Client{},
Endpoint: aws.String(endpoint),
Region: aws.String("mock-region"),
Credentials: credentials.AnonymousCredentials,
S3ForcePathStyle: aws.Bool(true),
},
})
svc := s3.New(sess)
req, _ := svc.HeadObjectRequest(&s3.HeadObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
if e, a := request.NoBody, req.HTTPRequest.Body; e != a {
t.Fatalf("expect %T request body, got %T", e, a)
}
err = req.Send()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := http.StatusOK, req.HTTPResponse.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}

View File

@@ -1,12 +1,15 @@
package corehandlers_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -14,6 +17,8 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestValidateEndpointHandler(t *testing.T) {
@@ -26,7 +31,9 @@ func TestValidateEndpointHandler(t *testing.T) {
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
@@ -39,8 +46,12 @@ func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
err := req.Build()
assert.Error(t, err)
assert.Equal(t, aws.ErrMissingRegion, err)
if err == nil {
t.Errorf("expect error, got none")
}
if e, a := aws.ErrMissingRegion, err; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
type mockCredsProvider struct {
@@ -50,7 +61,7 @@ type mockCredsProvider struct {
func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
m.retrieveCalled = true
return credentials.Value{}, nil
return credentials.Value{ProviderName: "mockCredsProvider"}, nil
}
func (m *mockCredsProvider) IsExpired() bool {
@@ -69,25 +80,125 @@ func TestAfterRetryRefreshCreds(t *testing.T) {
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *request.Request) {
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
r.HTTPResponse = &http.Response{StatusCode: 400, Body: ioutil.NopCloser(bytes.NewBuffer([]byte{}))}
})
svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
assert.True(t, svc.Config.Credentials.IsExpired(), "Expect to start out expired")
assert.False(t, credProvider.retrieveCalled)
if !svc.Config.Credentials.IsExpired() {
t.Errorf("Expect to start out expired")
}
if credProvider.retrieveCalled {
t.Errorf("expect not called")
}
req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
req.Send()
assert.True(t, svc.Config.Credentials.IsExpired())
assert.False(t, credProvider.retrieveCalled)
if !svc.Config.Credentials.IsExpired() {
t.Errorf("Expect to start out expired")
}
if credProvider.retrieveCalled {
t.Errorf("expect not called")
}
_, err := svc.Config.Credentials.Get()
assert.NoError(t, err)
assert.True(t, credProvider.retrieveCalled)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if !credProvider.retrieveCalled {
t.Errorf("expect not called")
}
}
func TestAfterRetryWithContextCanceled(t *testing.T) {
c := awstesting.NewClient()
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
close(ctx.DoneCh)
ctx.Error = fmt.Errorf("context canceled")
corehandlers.AfterRetryHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error but didn't receive one")
}
aerr := req.Error.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q, error code got %q", e, a)
}
}
func TestAfterRetryWithContext(t *testing.T) {
c := awstesting.NewClient()
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
corehandlers.AfterRetryHandler.Fn(req)
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
if e, a := 1, req.RetryCount; e != a {
t.Errorf("expect retry count to be %d, got %d", e, a)
}
}
func TestSendWithContextCanceled(t *testing.T) {
c := awstesting.NewClient(&aws.Config{
SleepDelay: func(dur time.Duration) {
t.Errorf("SleepDelay should not be called")
},
})
req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{}, 0)}
req.SetContext(ctx)
req.Error = fmt.Errorf("some error")
req.Retryable = aws.Bool(true)
req.HTTPResponse = &http.Response{
StatusCode: 500,
}
close(ctx.DoneCh)
ctx.Error = fmt.Errorf("context canceled")
corehandlers.SendHandler.Fn(req)
if req.Error == nil {
t.Fatalf("expect error but didn't receive one")
}
aerr := req.Error.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %q, error code got %q", e, a)
}
}
type testSendHandlerTransport struct{}
@@ -108,6 +219,180 @@ func TestSendHandlerError(t *testing.T) {
r.Send()
assert.Error(t, r.Error)
assert.NotNil(t, r.HTTPResponse)
if r.Error == nil {
t.Errorf("expect error, got none")
}
if r.HTTPResponse == nil {
t.Errorf("expect response, got none")
}
}
func TestSendWithoutFollowRedirects(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/original":
w.Header().Set("Location", "/redirected")
w.WriteHeader(301)
case "/redirected":
t.Fatalf("expect not to redirect, but was")
}
}))
svc := awstesting.NewClient(&aws.Config{
DisableSSL: aws.Bool(true),
Endpoint: aws.String(server.URL),
})
svc.Handlers.Clear()
svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
r := svc.NewRequest(&request.Operation{
Name: "Operation",
HTTPPath: "/original",
}, nil, nil)
r.DisableFollowRedirects = true
err := r.Send()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 301, r.HTTPResponse.StatusCode; e != a {
t.Errorf("expect %d status code, got %d", e, a)
}
}
func TestValidateReqSigHandler(t *testing.T) {
cases := []struct {
Req *request.Request
Resign bool
}{
{
Req: &request.Request{
Config: aws.Config{Credentials: credentials.AnonymousCredentials},
Time: time.Now().Add(-15 * time.Minute),
},
Resign: false,
},
{
Req: &request.Request{
Time: time.Now().Add(-15 * time.Minute),
},
Resign: true,
},
{
Req: &request.Request{
Time: time.Now().Add(-1 * time.Minute),
},
Resign: false,
},
}
for i, c := range cases {
resigned := false
c.Req.Handlers.Sign.PushBack(func(r *request.Request) {
resigned = true
})
corehandlers.ValidateReqSigHandler.Fn(c.Req)
if c.Req.Error != nil {
t.Errorf("expect no error, got %v", c.Req.Error)
}
if e, a := c.Resign, resigned; e != a {
t.Errorf("%d, expect %v to be %v", i, e, a)
}
}
}
func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Header["Content-Length"]
if e, a := hasContentLength, ok; e != a {
t.Errorf("expect %v to be %v", e, a)
}
if hasContentLength {
if e, a := contentLength, r.ContentLength; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
r.Body.Close()
authHeader := r.Header.Get("Authorization")
if hasContentLength {
if e, a := "content-length", authHeader; !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
} else {
if e, a := "content-length", authHeader; strings.Contains(a, e) {
t.Errorf("expect %v to not be in %v", e, a)
}
}
if e, a := contentLength, int64(len(b)); e != a {
t.Errorf("expect %v to be %v", e, a)
}
}))
return server
}
func TestBuildContentLength_ZeroBody(t *testing.T) {
server := setupContentLengthTestServer(t, false, 0)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
_, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}
func TestBuildContentLength_NegativeBody(t *testing.T) {
server := setupContentLengthTestServer(t, false, 0)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
req, _ := svc.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
})
req.HTTPRequest.Header.Set("Content-Length", "-1")
if req.Error != nil {
t.Errorf("expect no error, got %v", req.Error)
}
}
func TestBuildContentLength_WithBody(t *testing.T) {
server := setupContentLengthTestServer(t, true, 1024)
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
})
_, err := svc.PutObject(&s3.PutObjectInput{
Bucket: aws.String("bucketname"),
Key: aws.String("keyname"),
Body: bytes.NewReader(make([]byte, 1024)),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
}

View File

@@ -1,144 +1,17 @@
package corehandlers
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
import "github.com/aws/aws-sdk-go/aws/request"
// ValidateParametersHandler is a request handler to validate the input parameters.
// Validating parameters only has meaning if done prior to the request being sent.
var ValidateParametersHandler = request.NamedHandler{Name: "core.ValidateParametersHandler", Fn: func(r *request.Request) {
if r.ParamsFilled() {
v := validator{errors: []string{}}
v.validateAny(reflect.ValueOf(r.Params), "")
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}}
// A validator validates values. Collects validations errors which occurs.
type validator struct {
errors []string
}
// validateAny will validate any struct, slice or map type. All validations
// are also performed recursively for nested types.
func (v *validator) validateAny(value reflect.Value, path string) {
value = reflect.Indirect(value)
if !value.IsValid() {
if !r.ParamsFilled() {
return
}
switch value.Kind() {
case reflect.Struct:
v.validateStruct(value, path)
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
v.validateAny(value.Index(i), path+fmt.Sprintf("[%d]", i))
}
case reflect.Map:
for _, n := range value.MapKeys() {
v.validateAny(value.MapIndex(n), path+fmt.Sprintf("[%q]", n.String()))
if v, ok := r.Params.(request.Validator); ok {
if err := v.Validate(); err != nil {
r.Error = err
}
}
}
// validateStruct will validate the struct value's fields. If the structure has
// nested types those types will be validated also.
func (v *validator) validateStruct(value reflect.Value, path string) {
prefix := "."
if path == "" {
prefix = ""
}
for i := 0; i < value.Type().NumField(); i++ {
f := value.Type().Field(i)
if strings.ToLower(f.Name[0:1]) == f.Name[0:1] {
continue
}
fvalue := value.FieldByName(f.Name)
err := validateField(f, fvalue, validateFieldRequired, validateFieldMin)
if err != nil {
v.errors = append(v.errors, fmt.Sprintf("%s: %s", err.Error(), path+prefix+f.Name))
continue
}
v.validateAny(fvalue, path+prefix+f.Name)
}
}
type validatorFunc func(f reflect.StructField, fvalue reflect.Value) error
func validateField(f reflect.StructField, fvalue reflect.Value, funcs ...validatorFunc) error {
for _, fn := range funcs {
if err := fn(f, fvalue); err != nil {
return err
}
}
return nil
}
// Validates that a field has a valid value provided for required fields.
func validateFieldRequired(f reflect.StructField, fvalue reflect.Value) error {
if f.Tag.Get("required") == "" {
return nil
}
switch fvalue.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return fmt.Errorf("missing required parameter")
}
default:
if !fvalue.IsValid() {
return fmt.Errorf("missing required parameter")
}
}
return nil
}
// Validates that if a value is provided for a field, that value must be at
// least a minimum length.
func validateFieldMin(f reflect.StructField, fvalue reflect.Value) error {
minStr := f.Tag.Get("min")
if minStr == "" {
return nil
}
min, _ := strconv.ParseInt(minStr, 10, 64)
kind := fvalue.Kind()
if kind == reflect.Ptr {
if fvalue.IsNil() {
return nil
}
fvalue = fvalue.Elem()
}
switch fvalue.Kind() {
case reflect.String:
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
case reflect.Slice, reflect.Map:
if fvalue.IsNil() {
return nil
}
if int64(fvalue.Len()) < min {
return fmt.Errorf("field too short, minimum length %d", min)
}
// TODO min can also apply to number minimum value.
}
return nil
}
}}

View File

@@ -1,17 +1,18 @@
package corehandlers_test
import (
"fmt"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/stretchr/testify/require"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kinesis"
)
var testSvc = func() *client.Client {
@@ -26,18 +27,75 @@ var testSvc = func() *client.Client {
}()
type StructShape struct {
_ struct{} `type:"structure"`
RequiredList []*ConditionalStructShape `required:"true"`
RequiredMap map[string]*ConditionalStructShape `required:"true"`
RequiredBool *bool `required:"true"`
OptionalStruct *ConditionalStructShape
hiddenParameter *string
_ struct{}
}
func (s *StructShape) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "StructShape"}
if s.RequiredList == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredList"))
}
if s.RequiredMap == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredMap"))
}
if s.RequiredBool == nil {
invalidParams.Add(request.NewErrParamRequired("RequiredBool"))
}
if s.RequiredList != nil {
for i, v := range s.RequiredList {
if v == nil {
continue
}
if err := v.Validate(); err != nil {
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredList", i), err.(request.ErrInvalidParams))
}
}
}
if s.RequiredMap != nil {
for i, v := range s.RequiredMap {
if v == nil {
continue
}
if err := v.Validate(); err != nil {
invalidParams.AddNested(fmt.Sprintf("%s[%v]", "RequiredMap", i), err.(request.ErrInvalidParams))
}
}
}
if s.OptionalStruct != nil {
if err := s.OptionalStruct.Validate(); err != nil {
invalidParams.AddNested("OptionalStruct", err.(request.ErrInvalidParams))
}
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
type ConditionalStructShape struct {
_ struct{} `type:"structure"`
Name *string `required:"true"`
_ struct{}
}
func (s *ConditionalStructShape) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "ConditionalStructShape"}
if s.Name == nil {
invalidParams.Add(request.NewErrParamRequired("Name"))
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
func TestNoErrors(t *testing.T) {
@@ -53,7 +111,9 @@ func TestNoErrors(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.NoError(t, req.Error)
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
}
func TestMissingRequiredParameters(t *testing.T) {
@@ -61,9 +121,33 @@ func TestMissingRequiredParameters(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList\n- missing required parameter: RequiredMap\n- missing required parameter: RequiredBool", req.Error.(awserr.Error).Message())
if req.Error == nil {
t.Fatalf("expect error")
}
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
if e, a := 3, len(errs); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList.", errs[0].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap.", errs[1].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredBool.", errs[2].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "InvalidParameter: 3 validation error(s) found.\n- missing required field, StructShape.RequiredList.\n- missing required field, StructShape.RequiredMap.\n- missing required field, StructShape.RequiredBool.\n", req.Error.Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestNestedMissingRequiredParameters(t *testing.T) {
@@ -80,41 +164,89 @@ func TestNestedMissingRequiredParameters(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, input, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Error(t, req.Error)
assert.Equal(t, "InvalidParameter", req.Error.(awserr.Error).Code())
assert.Equal(t, "3 validation errors:\n- missing required parameter: RequiredList[0].Name\n- missing required parameter: RequiredMap[\"key2\"].Name\n- missing required parameter: OptionalStruct.Name", req.Error.(awserr.Error).Message())
if req.Error == nil {
t.Fatalf("expect error")
}
if e, a := "InvalidParameter", req.Error.(awserr.Error).Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "3 validation error(s) found.", req.Error.(awserr.Error).Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
errs := req.Error.(awserr.BatchedErrors).OrigErrs()
if e, a := 3, len(errs); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredList[0].Name.", errs[0].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.RequiredMap[key2].Name.", errs[1].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ParamRequiredError: missing required field, StructShape.OptionalStruct.Name.", errs[2].Error(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
type testInput struct {
StringField string `min:"5"`
PtrStrField *string `min:"2"`
StringField *string `min:"5"`
ListField []string `min:"3"`
MapField map[string]string `min:"4"`
}
func (s testInput) Validate() error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
if s.StringField != nil && len(*s.StringField) < 5 {
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
}
if s.ListField != nil && len(s.ListField) < 3 {
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
}
if s.MapField != nil && len(s.MapField) < 4 {
invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
}
if invalidParams.Len() > 0 {
return invalidParams
}
return nil
}
var testsFieldMin = []struct {
err awserr.Error
in testInput
}{
{
err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 5: StringField", nil),
in: testInput{StringField: "abcd"},
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd")},
},
{
err: awserr.New("InvalidParameter", "2 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField", nil),
in: testInput{StringField: "abcd", ListField: []string{"a", "b"}},
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}},
},
{
err: awserr.New("InvalidParameter", "3 validation errors:\n- field too short, minimum length 5: StringField\n- field too short, minimum length 3: ListField\n- field too short, minimum length 4: MapField", nil),
in: testInput{StringField: "abcd", ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
},
{
err: awserr.New("InvalidParameter", "1 validation errors:\n- field too short, minimum length 2: PtrStrField", nil),
in: testInput{StringField: "abcde", PtrStrField: aws.String("v")},
err: func() awserr.Error {
invalidParams := request.ErrInvalidParams{Context: "testInput"}
invalidParams.Add(request.NewErrParamMinLen("StringField", 5))
invalidParams.Add(request.NewErrParamMinLen("ListField", 3))
invalidParams.Add(request.NewErrParamMinLen("MapField", 4))
return invalidParams
}(),
in: testInput{StringField: aws.String("abcd"), ListField: []string{"a", "b"}, MapField: map[string]string{"a": "a", "b": "b"}},
},
{
err: nil,
in: testInput{StringField: "abcde", PtrStrField: aws.String("value"),
in: testInput{StringField: aws.String("abcde"),
ListField: []string{"a", "b", "c"}, MapField: map[string]string{"a": "a", "b": "b", "c": "c", "d": "d"}},
},
}
@@ -124,6 +256,31 @@ func TestValidateFieldMinParameter(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
require.Equal(t, c.err, req.Error, "%d case failed", i)
if e, a := c.err, req.Error; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
func BenchmarkValidateAny(b *testing.B) {
input := &kinesis.PutRecordsInput{
StreamName: aws.String("stream"),
}
for i := 0; i < 100; i++ {
record := &kinesis.PutRecordsRequestEntry{
Data: make([]byte, 10000),
PartitionKey: aws.String("partition"),
}
input.Records = append(input.Records, record)
}
req, _ := kinesis.New(unit.Session).PutRecordsRequest(input)
b.ResetTimer()
for i := 0; i < b.N; i++ {
corehandlers.ValidateParametersHandler.Fn(req)
if err := req.Error; err != nil {
b.Fatalf("validation failed: %v", err)
}
}
}

View File

@@ -0,0 +1,37 @@
package corehandlers
import (
"os"
"runtime"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version
// to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec_env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent.
//
// If the environment variable AWS_EXECUTION_ENV is set, its value will be
// appended to the user agent string.
var AddHostExecEnvUserAgentHander = request.NamedHandler{
Name: "core.AddHostExecEnvUserAgentHander",
Fn: func(r *request.Request) {
v := os.Getenv(execEnvVar)
if len(v) == 0 {
return
}
request.AddToUserAgent(r, execEnvUAKey+"/"+v)
},
}

View File

@@ -0,0 +1,40 @@
package corehandlers
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestAddHostExecEnvUserAgentHander(t *testing.T) {
cases := []struct {
ExecEnv string
Expect string
}{
{ExecEnv: "Lambda", Expect: "exec_env/Lambda"},
{ExecEnv: "", Expect: ""},
{ExecEnv: "someThingCool", Expect: "exec_env/someThingCool"},
}
for i, c := range cases {
os.Clearenv()
os.Setenv(execEnvVar, c.ExecEnv)
req := &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
}
AddHostExecEnvUserAgentHander.Fn(req)
if err := req.Error; err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.Expect, req.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("%d, expect %v user agent, got %v", i, e, a)
}
}
}

View File

@@ -8,8 +8,12 @@ var (
// ErrNoValidProvidersFoundInChain Is returned when there are no valid
// providers in the ChainProvider.
//
// @readonly
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", "no valid providers in chain", nil)
// This has been deprecated. For verbose error messaging set
// aws.Config.CredentialsChainVerboseErrors to true.
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,
nil)
)
// A ChainProvider will search for a provider which returns credentials
@@ -28,25 +32,28 @@ var (
//
// Example of ChainProvider to be used with an EnvProvider and EC2RoleProvider.
// In this example EnvProvider will first check if any credentials are available
// vai the environment variables. If there are none ChainProvider will check
// via the environment variables. If there are none ChainProvider will check
// the next Provider in the list, EC2RoleProvider in this case. If EC2RoleProvider
// does not return any credentials ChainProvider will return the error
// ErrNoValidProvidersFoundInChain
//
// creds := NewChainCredentials(
// []Provider{
// &EnvProvider{},
// &EC2RoleProvider{
// creds := credentials.NewChainCredentials(
// []credentials.Provider{
// &credentials.EnvProvider{},
// &ec2rolecreds.EC2RoleProvider{
// Client: ec2metadata.New(sess),
// },
// })
//
// // Usage of ChainCredentials with aws.Config
// svc := ec2.New(&aws.Config{Credentials: creds})
// svc := ec2.New(session.Must(session.NewSession(&aws.Config{
// Credentials: creds,
// })))
//
type ChainProvider struct {
Providers []Provider
curr Provider
Providers []Provider
curr Provider
VerboseErrors bool
}
// NewChainCredentials returns a pointer to a new Credentials object
@@ -63,17 +70,23 @@ func NewChainCredentials(providers []Provider) *Credentials {
// If a provider is found it will be cached and any calls to IsExpired()
// will return the expired state of the cached provider.
func (c *ChainProvider) Retrieve() (Value, error) {
var errs []error
for _, p := range c.Providers {
if creds, err := p.Retrieve(); err == nil {
creds, err := p.Retrieve()
if err == nil {
c.curr = p
return creds, nil
}
errs = append(errs, err)
}
c.curr = nil
// TODO better error reporting. maybe report error for each failed retrieve?
return Value{}, ErrNoValidProvidersFoundInChain
var err error
err = ErrNoValidProvidersFoundInChain
if c.VerboseErrors {
err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
}
return Value{}, err
}
// IsExpired will returned the expired state of the currently cached provider

View File

@@ -7,6 +7,54 @@ import (
"github.com/stretchr/testify/assert"
)
type secondStubProvider struct {
creds Value
expired bool
err error
}
func (s *secondStubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "secondStubProvider"
return s.creds, s.err
}
func (s *secondStubProvider) IsExpired() bool {
return s.expired
}
func TestChainProviderWithNames(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&secondStubProvider{
creds: Value{
AccessKeyID: "AKIF",
SecretAccessKey: "NOSECRET",
SessionToken: "",
},
},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "secondStubProvider", creds.ProviderName, "Expect provider name to match")
// Also check credentials
assert.Equal(t, "AKIF", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "NOSECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
}
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
@@ -56,18 +104,51 @@ func TestChainProviderWithNoProvider(t *testing.T) {
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t, ErrNoValidProvidersFoundInChain, err, "Expect no providers error returned")
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
}
func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
VerboseErrors: true,
Providers: []Provider{
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
_, err := p.Retrieve()
assert.Equal(t,
awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs),
err,
"Expect no providers error returned")
}

View File

@@ -14,7 +14,7 @@
//
// Example of using the environment variable credentials.
//
// creds := NewEnvCredentials()
// creds := credentials.NewEnvCredentials()
//
// // Retrieve the credentials value
// credValue, err := creds.Get()
@@ -26,7 +26,7 @@
// This may be helpful to proactively expire credentials and refresh them sooner
// than they would naturally expire on their own.
//
// creds := NewCredentials(&EC2RoleProvider{})
// creds := credentials.NewCredentials(&ec2rolecreds.EC2RoleProvider{})
// creds.Expire()
// credsValue, err := creds.Get()
// // New credentials will be retrieved instead of from cache.
@@ -43,7 +43,7 @@
// func (m *MyProvider) Retrieve() (Value, error) {...}
// func (m *MyProvider) IsExpired() bool {...}
//
// creds := NewCredentials(&MyProvider{})
// creds := credentials.NewCredentials(&MyProvider{})
// credValue, err := creds.Get()
//
package credentials
@@ -60,10 +60,10 @@ import (
// when making service API calls. For example, when accessing public
// s3 buckets.
//
// svc := s3.New(&aws.Config{Credentials: AnonymousCredentials})
// svc := s3.New(session.Must(session.NewSession(&aws.Config{
// Credentials: credentials.AnonymousCredentials,
// })))
// // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.
@@ -76,6 +76,9 @@ type Value struct {
// AWS Session Token
SessionToken string
// Provider used to get credentials
ProviderName string
}
// A Provider is the interface for any component which will provide credentials
@@ -85,7 +88,7 @@ type Value struct {
// The Provider should not need to implement its own mutexes, because
// that will be managed by Credentials.
type Provider interface {
// Refresh returns nil if it successfully retrieved the value.
// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve() (Value, error)
@@ -94,6 +97,27 @@ type Provider interface {
IsExpired() bool
}
// An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible
// due to an error.
type ErrorProvider struct {
// The error to be returned from Retrieve
Err error
// The provider name to set on the Retrieved returned Value
ProviderName string
}
// Retrieve will always return the error that the ErrorProvider was created with.
func (p ErrorProvider) Retrieve() (Value, error) {
return Value{ProviderName: p.ProviderName}, p.Err
}
// IsExpired will always return not expired.
func (p ErrorProvider) IsExpired() bool {
return false
}
// A Expiry provides shared expiration logic to be used by credentials
// providers to implement expiry functionality.
//
@@ -132,13 +156,14 @@ func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
// IsExpired returns if the credentials are expired.
func (e *Expiry) IsExpired() bool {
if e.CurrentTime == nil {
e.CurrentTime = time.Now
curTime := e.CurrentTime
if curTime == nil {
curTime = time.Now
}
return e.expiration.Before(e.CurrentTime())
return e.expiration.Before(curTime())
}
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
@@ -152,7 +177,8 @@ func (e *Expiry) IsExpired() bool {
type Credentials struct {
creds Value
forceRefresh bool
m sync.Mutex
m sync.RWMutex
provider Provider
}
@@ -175,6 +201,17 @@ func NewCredentials(provider Provider) *Credentials {
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
// Check the cached credentials first with just the read lock.
c.m.RLock()
if !c.isExpired() {
creds := c.creds
c.m.RUnlock()
return creds, nil
}
c.m.RUnlock()
// Credentials are expired need to retrieve the credentials taking the full
// lock.
c.m.Lock()
defer c.m.Unlock()
@@ -208,8 +245,8 @@ func (c *Credentials) Expire() {
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.Lock()
defer c.m.Unlock()
c.m.RLock()
defer c.m.RUnlock()
return c.isExpired()
}

View File

@@ -0,0 +1,90 @@
// +build go1.9
package credentials
import (
"fmt"
"strconv"
"sync"
"testing"
"time"
)
func BenchmarkCredentials_Get(b *testing.B) {
stub := &stubProvider{}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, c := range cases {
b.Run(strconv.Itoa(c), func(b *testing.B) {
creds := NewCredentials(stub)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func() {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
}
wg.Done()
}()
}
b.ResetTimer()
wg.Wait()
})
}
}
func BenchmarkCredentials_Get_Expire(b *testing.B) {
p := &blockProvider{}
expRates := []int{10000, 1000, 100}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, expRate := range expRates {
for _, c := range cases {
b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) {
creds := NewCredentials(p)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func(id int) {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
// periodically expire creds to cause rwlock
if id == 0 && j%expRate == 0 {
creds.Expire()
}
}
wg.Done()
}(i)
}
b.ResetTimer()
wg.Wait()
})
}
}
}
type blockProvider struct {
creds Value
expired bool
err error
}
func (s *blockProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "blockProvider"
time.Sleep(time.Millisecond)
return s.creds, s.err
}
func (s *blockProvider) IsExpired() bool {
return s.expired
}

View File

@@ -2,6 +2,7 @@ package credentials
import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
@@ -15,6 +16,7 @@ type stubProvider struct {
func (s *stubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "stubProvider"
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
@@ -60,3 +62,38 @@ func TestCredentialsExpire(t *testing.T) {
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
}
type MockProvider struct {
Expiry
}
func (*MockProvider) Retrieve() (Value, error) {
return Value{}, nil
}
func TestCredentialsGetWithProviderName(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, creds.ProviderName, "stubProvider", "Expected provider name to match")
}
func TestCredentialsIsExpired_Race(t *testing.T) {
creds := NewChainCredentials([]Provider{&MockProvider{}})
starter := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
<-starter
for {
creds.IsExpired()
}
}()
}
close(starter)
time.Sleep(10 * time.Second)
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"encoding/json"
"fmt"
"path"
"strings"
"time"
@@ -12,8 +11,12 @@ import (
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// ProviderName provides a name of EC2Role provider
const ProviderName = "EC2RoleProvider"
// A EC2RoleProvider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
@@ -85,17 +88,17 @@ func NewCredentialsWithClient(client *ec2metadata.EC2Metadata, options ...func(*
func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
credsList, err := requestCredList(m.Client)
if err != nil {
return credentials.Value{}, err
return credentials.Value{ProviderName: ProviderName}, err
}
if len(credsList) == 0 {
return credentials.Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
return credentials.Value{ProviderName: ProviderName}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
roleCreds, err := requestCred(m.Client, credsName)
if err != nil {
return credentials.Value{}, err
return credentials.Value{ProviderName: ProviderName}, err
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
@@ -104,10 +107,11 @@ func (m *EC2RoleProvider) Retrieve() (credentials.Value, error) {
AccessKeyID: roleCreds.AccessKeyID,
SecretAccessKey: roleCreds.SecretAccessKey,
SessionToken: roleCreds.Token,
ProviderName: ProviderName,
}, nil
}
// A ec2RoleCredRespBody provides the shape for unmarshalling credential
// A ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
type ec2RoleCredRespBody struct {
// Success State
@@ -121,14 +125,14 @@ type ec2RoleCredRespBody struct {
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials"
const iamSecurityCredsPath = "iam/security-credentials/"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
resp, err := client.GetMetadata(iamSecurityCredsPath)
if err != nil {
return nil, awserr.New("EC2RoleRequestError", "failed to list EC2 Roles", err)
return nil, awserr.New("EC2RoleRequestError", "no EC2 instance role found", err)
}
credsList := []string{}
@@ -138,7 +142,7 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
}
if err := s.Err(); err != nil {
return nil, awserr.New("SerializationError", "failed to read list of EC2 Roles", err)
return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err)
}
return credsList, nil
@@ -149,11 +153,11 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName))
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
fmt.Sprintf("failed to get %s EC2 instance role credentials", credsName),
err)
}
@@ -161,7 +165,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil {
return ec2RoleCredRespBody{},
awserr.New("SerializationError",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName),
err)
}

View File

@@ -13,7 +13,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
const credsRespTmpl = `{
@@ -34,7 +34,7 @@ const credsFailRespTmpl = `{
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
@@ -55,7 +55,7 @@ func TestEC2RoleProvider(t *testing.T) {
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
@@ -71,7 +71,7 @@ func TestEC2RoleProviderFailAssume(t *testing.T) {
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
creds, err := p.Retrieve()
@@ -92,7 +92,7 @@ func TestEC2RoleProviderIsExpired(t *testing.T) {
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
@@ -117,7 +117,7 @@ func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
ExpiryWindow: time.Hour * 1,
}
p.CurrentTime = func() time.Time {
@@ -143,7 +143,7 @@ func BenchmarkEC3RoleProvider(b *testing.B) {
defer server.Close()
p := &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
Client: ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}),
}
_, err := p.Retrieve()
if err != nil {

View File

@@ -0,0 +1,198 @@
// Package endpointcreds provides support for retrieving credentials from an
// arbitrary HTTP endpoint.
//
// The credentials endpoint Provider can receive both static and refreshable
// credentials that will expire. Credentials are static when an "Expiration"
// value is not provided in the endpoint's response.
//
// Static credentials will never expire once they have been retrieved. The format
// of the static credentials response:
// {
// "AccessKeyId" : "MUA...",
// "SecretAccessKey" : "/7PC5om....",
// }
//
// Refreshable credentials will expire within the "ExpiryWindow" of the Expiration
// value in the response. The format of the refreshable credentials response:
// {
// "AccessKeyId" : "MUA...",
// "SecretAccessKey" : "/7PC5om....",
// "Token" : "AQoDY....=",
// "Expiration" : "2016-02-25T06:03:31Z"
// }
//
// Errors should be returned in the following format and only returned with 400
// or 500 HTTP status codes.
// {
// "code": "ErrorCode",
// "message": "Helpful error message."
// }
package endpointcreds
import (
"encoding/json"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
// ProviderName is the name of the credentials provider.
const ProviderName = `CredentialsEndpointProvider`
// Provider satisfies the credentials.Provider interface, and is a client to
// retrieve credentials from an arbitrary endpoint.
type Provider struct {
staticCreds bool
credentials.Expiry
// Requires a AWS Client to make HTTP requests to the endpoint with.
// the Endpoint the request will be made to is provided by the aws.Config's
// Endpoint value.
Client *client.Client
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
}
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
// from arbitrary endpoint.
func NewProviderClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) credentials.Provider {
p := &Provider{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: "CredentialsEndpoint",
Endpoint: endpoint,
},
handlers,
),
}
p.Client.Handlers.Unmarshal.PushBack(unmarshalHandler)
p.Client.Handlers.UnmarshalError.PushBack(unmarshalError)
p.Client.Handlers.Validate.Clear()
p.Client.Handlers.Validate.PushBack(validateEndpointHandler)
for _, option := range options {
option(p)
}
return p
}
// NewCredentialsClient returns a Credentials wrapper for retrieving credentials
// from an arbitrary endpoint concurrently. The client will request the
func NewCredentialsClient(cfg aws.Config, handlers request.Handlers, endpoint string, options ...func(*Provider)) *credentials.Credentials {
return credentials.NewCredentials(NewProviderClient(cfg, handlers, endpoint, options...))
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *Provider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) Retrieve() (credentials.Value, error) {
resp, err := p.getCredentials()
if err != nil {
return credentials.Value{ProviderName: ProviderName},
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
}
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
} else {
p.staticCreds = true
}
return credentials.Value{
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.Token,
ProviderName: ProviderName,
}, nil
}
type getCredentialsOutput struct {
Expiration *time.Time
AccessKeyID string
SecretAccessKey string
Token string
}
type errorOutput struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
op := &request.Operation{
Name: "GetCredentials",
HTTPMethod: "GET",
}
out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}
return out, req.Send()
}
func validateEndpointHandler(r *request.Request) {
if len(r.ClientInfo.Endpoint) == 0 {
r.Error = aws.ErrMissingEndpoint
}
}
func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
out := r.Data.(*getCredentialsOutput)
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil {
r.Error = awserr.New("SerializationError",
"failed to decode endpoint credentials",
err,
)
}
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
var errOut errorOutput
if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil {
r.Error = awserr.New("SerializationError",
"failed to decode endpoint credentials",
err,
)
}
// Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error
r.Error = awserr.New(errOut.Code, errOut.Message, nil)
}

View File

@@ -0,0 +1,218 @@
package endpointcreds_test
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
func TestRetrieveRefreshableCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "else", r.URL.Query().Get("something"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
)
creds, err := client.Retrieve()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
func TestRetrieveStaticCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no SessionToken, got %#v", v)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
}
func TestFailedRetrieveCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(400)
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"Code": "Error",
"Message": "Message",
})
if err != nil {
fmt.Println("failed to write error", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
if err == nil {
t.Errorf("expect error, got none")
}
aerr := err.(awserr.Error)
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "failed to load credentials", aerr.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
aerr = aerr.OrigErr().(awserr.Error)
if e, a := "Error", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "Message", aerr.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
func TestAuthorizationToken(t *testing.T) {
const expectAuthToken = "Basic abc123"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectAuthToken, r.Header.Get("Authorization"); e != a {
t.Fatalf("expect %v, got %v", e, a)
}
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
func(p *endpointcreds.Provider) {
p.AuthorizationToken = expectAuthToken
},
)
creds, err := client.Retrieve()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}

View File

@@ -6,17 +6,16 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
)
// EnvProviderName provides a name of Env provider
const EnvProviderName = "EnvProvider"
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)
@@ -26,6 +25,7 @@ var (
// Environment variables used:
//
// * Access Key ID: AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY
//
// * Secret Access Key: AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY
type EnvProvider struct {
retrieved bool
@@ -52,11 +52,11 @@ func (e *EnvProvider) Retrieve() (Value, error) {
}
if id == "" {
return Value{}, ErrAccessKeyIDNotFound
return Value{ProviderName: EnvProviderName}, ErrAccessKeyIDNotFound
}
if secret == "" {
return Value{}, ErrSecretAccessKeyNotFound
return Value{ProviderName: EnvProviderName}, ErrSecretAccessKeyNotFound
}
e.retrieved = true
@@ -64,6 +64,7 @@ func (e *EnvProvider) Retrieve() (Value, error) {
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: os.Getenv("AWS_SESSION_TOKEN"),
ProviderName: EnvProviderName,
}, nil
}

View File

@@ -1,12 +0,0 @@
[default]
aws_access_key_id = accessKey
aws_secret_access_key = secret
aws_session_token = token
[no_token]
aws_access_key_id = accessKey
aws_secret_access_key = secret
[with_colon]
aws_access_key_id: accessKey
aws_secret_access_key: secret

View File

@@ -0,0 +1,5 @@
// +build !go1.8
// Package plugincreds provides usage of Go plugins for providing credentials
// to the SDK. Only available with Go 1.8 and above.
package plugincreds

View File

@@ -0,0 +1,211 @@
// +build go1.8
// Package plugincreds implements a credentials provider sourced from a Go
// plugin. This package allows you to use a Go plugin to retrieve AWS credentials
// for the SDK to use for service API calls.
//
// As of Go 1.8 plugins are only supported on the Linux platform.
//
// Plugin Symbol Name
//
// The "GetAWSSDKCredentialProvider" is the symbol name that will be used to
// lookup the credentials provider getter from the plugin. If you want to use a
// custom symbol name you should use GetPluginProviderFnsByName to lookup the
// symbol by a custom name.
//
// This symbol is a function that returns two additional functions. One to
// retrieve the credentials, and another to determine if the credentials have
// expired.
//
// Plugin Symbol Signature
//
// The plugin credential provider requires the symbol to match the
// following signature.
//
// func() (RetrieveFn func() (key, secret, token string, err error), IsExpiredFn func() bool)
//
// Plugin Implementation Exmaple
//
// The following is an example implementation of a SDK credential provider using
// the plugin provider in this package. See the SDK's example/aws/credential/plugincreds/plugin
// folder for a runnable example of this.
//
// package main
//
// func main() {}
//
// var myCredProvider provider
//
// // Build: go build -o plugin.so -buildmode=plugin plugin.go
// func init() {
// // Initialize a mock credential provider with stubs
// myCredProvider = provider{"a","b","c"}
// }
//
// // GetAWSSDKCredentialProvider is the symbol SDK will lookup and use to
// // get the credential provider's retrieve and isExpired functions.
// func GetAWSSDKCredentialProvider() (func() (key, secret, token string, err error), func() bool) {
// return myCredProvider.Retrieve, myCredProvider.IsExpired
// }
//
// // mock implementation of a type that returns retrieves credentials and
// // returns if they have expired.
// type provider struct {
// key, secret, token string
// }
//
// func (p provider) Retrieve() (key, secret, token string, err error) {
// return p.key, p.secret, p.token, nil
// }
//
// func (p *provider) IsExpired() bool {
// return false;
// }
//
// Configuring SDK for Plugin Credentials
//
// To configure the SDK to use a plugin's credential provider you'll need to first
// open the plugin file using the plugin standard library package. Once you have
// a handle to the plugin you can use the NewCredentials function of this package
// to create a new credentials.Credentials value that can be set as the
// credentials loader of a Session or Config. See the SDK's example/aws/credential/plugincreds
// folder for a runnable example of this.
//
// // Open plugin, and load it into the process.
// p, err := plugin.Open("somefile.so")
// if err != nil {
// return nil, err
// }
//
// // Create a new Credentials value which will source the provider's Retrieve
// // and IsExpired functions from the plugin.
// creds, err := plugincreds.NewCredentials(p)
// if err != nil {
// return nil, err
// }
//
// // Example to configure a Session with the newly created credentials that
// // will be sourced using the plugin's functionality.
// sess := session.Must(session.NewSession(&aws.Config{
// Credentials: creds,
// }))
package plugincreds
import (
"fmt"
"plugin"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// ProviderSymbolName the symbol name the SDK will use to lookup the plugin
// provider value from.
const ProviderSymbolName = `GetAWSSDKCredentialProvider`
// ProviderName is the name this credentials provider will label any returned
// credentials Value with.
const ProviderName = `PluginCredentialsProvider`
const (
// ErrCodeLookupSymbolError failed to lookup symbol
ErrCodeLookupSymbolError = "LookupSymbolError"
// ErrCodeInvalidSymbolError symbol invalid
ErrCodeInvalidSymbolError = "InvalidSymbolError"
// ErrCodePluginRetrieveNil Retrieve function was nil
ErrCodePluginRetrieveNil = "PluginRetrieveNilError"
// ErrCodePluginIsExpiredNil IsExpired Function was nil
ErrCodePluginIsExpiredNil = "PluginIsExpiredNilError"
// ErrCodePluginProviderRetrieve plugin provider's retrieve returned error
ErrCodePluginProviderRetrieve = "PluginProviderRetrieveError"
)
// Provider is the credentials provider that will use the plugin provided
// Retrieve and IsExpired functions to retrieve credentials.
type Provider struct {
RetrieveFn func() (key, secret, token string, err error)
IsExpiredFn func() bool
}
// NewCredentials returns a new Credentials loader using the plugin provider.
// If the symbol isn't found or is invalid in the plugin an error will be
// returned.
func NewCredentials(p *plugin.Plugin) (*credentials.Credentials, error) {
retrieve, isExpired, err := GetPluginProviderFns(p)
if err != nil {
return nil, err
}
return credentials.NewCredentials(Provider{
RetrieveFn: retrieve,
IsExpiredFn: isExpired,
}), nil
}
// Retrieve will return the credentials Value if they were successfully retrieved
// from the underlying plugin provider. An error will be returned otherwise.
func (p Provider) Retrieve() (credentials.Value, error) {
creds := credentials.Value{
ProviderName: ProviderName,
}
k, s, t, err := p.RetrieveFn()
if err != nil {
return creds, awserr.New(ErrCodePluginProviderRetrieve,
"failed to retrieve credentials with plugin provider", err)
}
creds.AccessKeyID = k
creds.SecretAccessKey = s
creds.SessionToken = t
return creds, nil
}
// IsExpired will return the expired state of the underlying plugin provider.
func (p Provider) IsExpired() bool {
return p.IsExpiredFn()
}
// GetPluginProviderFns returns the plugin's Retrieve and IsExpired functions
// returned by the plugin's credential provider getter.
//
// Uses ProviderSymbolName as the symbol name when lookup up the symbol. If you
// want to use a different symbol name, use GetPluginProviderFnsByName.
func GetPluginProviderFns(p *plugin.Plugin) (func() (key, secret, token string, err error), func() bool, error) {
return GetPluginProviderFnsByName(p, ProviderSymbolName)
}
// GetPluginProviderFnsByName returns the plugin's Retrieve and IsExpired functions
// returned by the plugin's credential provider getter.
//
// Same as GetPluginProviderFns, but takes a custom symbolName to lookup with.
func GetPluginProviderFnsByName(p *plugin.Plugin, symbolName string) (func() (key, secret, token string, err error), func() bool, error) {
sym, err := p.Lookup(symbolName)
if err != nil {
return nil, nil, awserr.New(ErrCodeLookupSymbolError,
fmt.Sprintf("failed to lookup %s plugin provider symbol", symbolName), err)
}
fn, ok := sym.(func() (func() (key, secret, token string, err error), func() bool))
if !ok {
return nil, nil, awserr.New(ErrCodeInvalidSymbolError,
fmt.Sprintf("symbol %T, does not match the 'func() (func() (key, secret, token string, err error), func() bool)' type", sym), nil)
}
retrieveFn, isExpiredFn := fn()
if retrieveFn == nil {
return nil, nil, awserr.New(ErrCodePluginRetrieveNil,
"the plugin provider retrieve function cannot be nil", nil)
}
if isExpiredFn == nil {
return nil, nil, awserr.New(ErrCodePluginIsExpiredNil,
"the plugin provider isExpired function cannot be nil", nil)
}
return retrieveFn, isExpiredFn, nil
}

View File

@@ -0,0 +1,71 @@
// +build go1.8,awsinclude
package plugincreds
import (
"fmt"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
func TestProvider_Passthrough(t *testing.T) {
p := Provider{
RetrieveFn: func() (string, string, string, error) {
return "key", "secret", "token", nil
},
IsExpiredFn: func() bool {
return false
},
}
actual, err := p.Retrieve()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expect := credentials.Value{
AccessKeyID: "key",
SecretAccessKey: "secret",
SessionToken: "token",
ProviderName: ProviderName,
}
if expect != actual {
t.Errorf("expect %+v credentials, got %+v", expect, actual)
}
}
func TestProvider_Error(t *testing.T) {
expectErr := fmt.Errorf("expect error")
p := Provider{
RetrieveFn: func() (string, string, string, error) {
return "", "", "", expectErr
},
IsExpiredFn: func() bool {
return false
},
}
actual, err := p.Retrieve()
if err == nil {
t.Fatalf("expect error, got none")
}
aerr := err.(awserr.Error)
if e, a := ErrCodePluginProviderRetrieve, aerr.Code(); e != a {
t.Errorf("expect %s error code, got %s", e, a)
}
if e, a := expectErr, aerr.OrigErr(); e != a {
t.Errorf("expect %v cause error, got %v", e, a)
}
expect := credentials.Value{
ProviderName: ProviderName,
}
if expect != actual {
t.Errorf("expect %+v credentials, got %+v", expect, actual)
}
}

View File

@@ -3,17 +3,17 @@ package credentials
import (
"fmt"
"os"
"path/filepath"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/ini"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
// SharedCredsProviderName provides a name of SharedCreds provider
const SharedCredsProviderName = "SharedCredentialsProvider"
var (
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
//
// @readonly
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
)
@@ -55,12 +55,12 @@ func (p *SharedCredentialsProvider) Retrieve() (Value, error) {
filename, err := p.filename()
if err != nil {
return Value{}, err
return Value{ProviderName: SharedCredsProviderName}, err
}
creds, err := loadProfile(filename, p.profile())
if err != nil {
return Value{}, err
return Value{ProviderName: SharedCredsProviderName}, err
}
p.retrieved = true
@@ -76,36 +76,38 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.Load(filename)
config, err := ini.OpenFile(filename)
if err != nil {
return Value{}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile, err := config.GetSection(profile)
if err != nil {
return Value{}, awserr.New("SharedCredsLoad", "failed to get profile", err)
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
id, err := iniProfile.GetKey("aws_access_key_id")
if err != nil {
return Value{}, awserr.New("SharedCredsAccessKey",
iniProfile, ok := config.GetSection(profile)
if !ok {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
}
id := iniProfile.String("aws_access_key_id")
if len(id) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err)
nil)
}
secret, err := iniProfile.GetKey("aws_secret_access_key")
if err != nil {
return Value{}, awserr.New("SharedCredsSecret",
secret := iniProfile.String("aws_secret_access_key")
if len(secret) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}
// Default to empty string if not found
token := iniProfile.Key("aws_session_token")
token := iniProfile.String("aws_session_token")
return Value{
AccessKeyID: id.String(),
SecretAccessKey: secret.String(),
SessionToken: token.String(),
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
ProviderName: SharedCredsProviderName,
}, nil
}
@@ -113,22 +115,23 @@ func loadProfile(filename, profile string) (Value, error) {
//
// Will return an error if the user's home directory path cannot be found.
func (p *SharedCredentialsProvider) filename() (string, error) {
if p.Filename == "" {
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); p.Filename != "" {
return p.Filename, nil
}
homeDir := os.Getenv("HOME") // *nix
if homeDir == "" { // Windows
homeDir = os.Getenv("USERPROFILE")
}
if homeDir == "" {
return "", ErrSharedCredentialsHomeNotFound
}
p.Filename = filepath.Join(homeDir, ".aws", "credentials")
if len(p.Filename) != 0 {
return p.Filename, nil
}
if p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); len(p.Filename) != 0 {
return p.Filename, nil
}
if home := shareddefaults.UserHomeDir(); len(home) == 0 {
// Backwards compatibility of home directly not found error being returned.
// This error is too verbose, failure when opening the file would of been
// a better error to return.
return "", ErrSharedCredentialsHomeNotFound
}
p.Filename = shareddefaults.SharedCredentialsFilename()
return p.Filename, nil
}

View File

@@ -1,9 +1,12 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"path/filepath"
"testing"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/stretchr/testify/assert"
)
func TestSharedCredentialsProvider(t *testing.T) {
@@ -44,6 +47,20 @@ func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T)
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILEAbsPath(t *testing.T) {
os.Clearenv()
wd, err := os.Getwd()
assert.NoError(t, err)
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "example.ini"))
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
os.Clearenv()
os.Setenv("AWS_PROFILE", "no_token")
@@ -81,6 +98,25 @@ func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
assert.Empty(t, creds.SessionToken, "Expect no token")
}
func TestSharedCredentialsProvider_DefaultFilename(t *testing.T) {
os.Clearenv()
os.Setenv("USERPROFILE", "profile_dir")
os.Setenv("HOME", "home_dir")
// default filename and profile
p := SharedCredentialsProvider{}
filename, err := p.filename()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := shareddefaults.SharedCredentialsFilename(), filename; e != a {
t.Errorf("expect %q filename, got %q", e, a)
}
}
func BenchmarkSharedCredentialsProvider(b *testing.B) {
os.Clearenv()

View File

@@ -4,14 +4,15 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
)
// StaticProviderName provides a name of Static provider
const StaticProviderName = "StaticProvider"
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)
// A StaticProvider is a set of credentials which are set pragmatically,
// A StaticProvider is a set of credentials which are set programmatically,
// and will never expire.
type StaticProvider struct {
Value
@@ -27,12 +28,22 @@ func NewStaticCredentials(id, secret, token string) *Credentials {
}})
}
// NewStaticCredentialsFromCreds returns a pointer to a new Credentials object
// wrapping the static credentials value provide. Same as NewStaticCredentials
// but takes the creds Value instead of individual fields
func NewStaticCredentialsFromCreds(creds Value) *Credentials {
return NewCredentials(&StaticProvider{Value: creds})
}
// Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (Value, error) {
if s.AccessKeyID == "" || s.SecretAccessKey == "" {
return Value{}, ErrStaticCredentialsEmpty
return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty
}
if len(s.Value.ProviderName) == 0 {
s.Value.ProviderName = StaticProviderName
}
return s.Value, nil
}

View File

@@ -1,7 +1,81 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
/*
Package stscreds are credential Providers to retrieve STS AWS credentials.
STS provides multiple ways to retrieve credentials which can be used when making
future AWS service API operation calls.
The SDK will ensure that per instance of credentials.Credentials all requests
to refresh the credentials will be synchronized. But, the SDK is unable to
ensure synchronous usage of the AssumeRoleProvider if the value is shared
between multiple Credentials, Sessions or service clients.
Assume Role
To assume an IAM role using STS with the SDK you can create a new Credentials
with the SDKs's stscreds package.
// Initial credentials loaded from SDK's default credential chain. Such as
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
// Role. These credentials will be used to to make the STS Assume Role API.
sess := session.Must(session.NewSession())
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN.
creds := stscreds.NewCredentials(sess, "myRoleArn")
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})
Assume Role with static MFA Token
To assume an IAM role with a MFA token you can either specify a MFA token code
directly or provide a function to prompt the user each time the credentials
need to refresh the role's credentials. Specifying the TokenCode should be used
for short lived operations that will not need to be refreshed, and when you do
not want to have direct control over the user provides their MFA token.
With TokenCode the AssumeRoleProvider will be not be able to refresh the role's
credentials.
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN using the MFA token code provided.
creds := stscreds.NewCredentials(sess, "myRoleArn", func(p *stscreds.AssumeRoleProvider) {
p.SerialNumber = aws.String("myTokenSerialNumber")
p.TokenCode = aws.String("00000000")
})
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})
Assume Role with MFA Token Provider
To assume an IAM role with MFA for longer running tasks where the credentials
may need to be refreshed setting the TokenProvider field of AssumeRoleProvider
will allow the credential provider to prompt for new MFA token code when the
role's credentials need to be refreshed.
The StdinTokenProvider function is available to prompt on stdin to retrieve
the MFA token code from the user. You can also implement custom prompts by
satisfing the TokenProvider function signature.
Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will
have undesirable results as the StdinTokenProvider will not be synchronized. A
single Credentials with an AssumeRoleProvider can be shared safely.
// Create the credentials from AssumeRoleProvider to assume the role
// referenced by the "myRoleARN" ARN. Prompting for MFA token from stdin.
creds := stscreds.NewCredentials(sess, "myRoleArn", func(p *stscreds.AssumeRoleProvider) {
p.SerialNumber = aws.String("myTokenSerialNumber")
p.TokenProvider = stscreds.StdinTokenProvider
})
// Create service client value configured for credentials
// from assumed role.
svc := s3.New(sess, &aws.Config{Credentials: creds})
*/
package stscreds
import (
@@ -9,11 +83,34 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)
// StdinTokenProvider will prompt on stdout and read from stdin for a string value.
// An error is returned if reading from stdin fails.
//
// Use this function go read MFA tokens from stdin. The function makes no attempt
// to make atomic prompts from stdin across multiple gorouties.
//
// Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will
// have undesirable results as the StdinTokenProvider will not be synchronized. A
// single Credentials with an AssumeRoleProvider can be shared safely
//
// Will wait forever until something is provided on the stdin.
func StdinTokenProvider() (string, error) {
var v string
fmt.Printf("Assume Role MFA token code: ")
_, err := fmt.Scanln(&v)
return v, err
}
// ProviderName provides a name of AssumeRole provider
const ProviderName = "AssumeRoleProvider"
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
@@ -24,8 +121,15 @@ type AssumeRoler interface {
var DefaultDuration = time.Duration(15) * time.Minute
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
// keeps track of their expiration time.
//
// This credential provider will be used by the SDKs default credential change
// when shared configuration is enabled, and the shared config or shared credentials
// file configure assume role. See Session docs for how to do this.
//
// AssumeRoleProvider does not provide any synchronization and it is not safe
// to share this value across multiple Credentials, Sessions, or service clients
// without also sharing the same Credentials instance.
type AssumeRoleProvider struct {
credentials.Expiry
@@ -44,6 +148,41 @@ type AssumeRoleProvider struct {
// Optional ExternalID to pass along, defaults to nil if not set.
ExternalID *string
// The policy plain text must be 2048 bytes or shorter. However, an internal
// conversion compresses it into a packed binary format with a separate limit.
// The PackedPolicySize response element indicates by percentage how close to
// the upper size limit the policy is, with 100% equaling the maximum allowed
// size.
Policy *string
// The identification number of the MFA device that is associated with the user
// who is making the AssumeRole call. Specify this value if the trust policy
// of the role being assumed includes a condition that requires MFA authentication.
// The value is either the serial number for a hardware device (such as GAHT12345678)
// or an Amazon Resource Name (ARN) for a virtual device (such as arn:aws:iam::123456789012:mfa/user).
SerialNumber *string
// The value provided by the MFA device, if the trust policy of the role being
// assumed requires MFA (that is, if the policy includes a condition that tests
// for MFA). If the role being assumed requires MFA and if the TokenCode value
// is missing or expired, the AssumeRole call returns an "access denied" error.
//
// If SerialNumber is set and neither TokenCode nor TokenProvider are also
// set an error will be returned.
TokenCode *string
// Async method of providing MFA token code for assuming an IAM role with MFA.
// The value returned by the function will be used as the TokenCode in the Retrieve
// call. See StdinTokenProvider for a provider that prompts and reads from stdin.
//
// This token provider will be called when ever the assumed role's
// credentials need to be refreshed when SerialNumber is also set and
// TokenCode is not set.
//
// If both TokenCode and TokenProvider is set, TokenProvider will be used and
// TokenCode is ignored.
TokenProvider func() (string, error)
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
@@ -62,6 +201,10 @@ type AssumeRoleProvider struct {
//
// Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type.
//
// It is safe to share the returned Credentials with multiple Sessions and
// service clients. All access to the credentials and refreshing them
// will be synchronized.
func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: sts.New(c),
@@ -80,7 +223,11 @@ func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*As
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Takes an AssumeRoler which can be satisfiede by the STS client.
// Takes an AssumeRoler which can be satisfied by the STS client.
//
// It is safe to share the returned Credentials with multiple Sessions and
// service clients. All access to the credentials and refreshing them
// will be synchronized.
func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: svc,
@@ -107,16 +254,36 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Expire as often as AWS permits.
p.Duration = DefaultDuration
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
input := &sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
})
}
if p.Policy != nil {
input.Policy = p.Policy
}
if p.SerialNumber != nil {
if p.TokenCode != nil {
input.SerialNumber = p.SerialNumber
input.TokenCode = p.TokenCode
} else if p.TokenProvider != nil {
input.SerialNumber = p.SerialNumber
code, err := p.TokenProvider()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
input.TokenCode = aws.String(code)
} else {
return credentials.Value{ProviderName: ProviderName},
awserr.New("AssumeRoleTokenNotAvailable",
"assume role with MFA enabled, but neither TokenCode nor TokenProvider are set", nil)
}
}
roleOutput, err := p.Client.AssumeRole(input)
if err != nil {
return credentials.Value{}, err
return credentials.Value{ProviderName: ProviderName}, err
}
// We will proactively generate new credentials before they expire.
@@ -126,5 +293,6 @@ func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
AccessKeyID: *roleOutput.Credentials.AccessKeyId,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
ProviderName: ProviderName,
}, nil
}

View File

@@ -1,6 +1,7 @@
package stscreds
import (
"fmt"
"testing"
"time"
@@ -10,9 +11,13 @@ import (
)
type stubSTS struct {
TestInput func(*sts.AssumeRoleInput)
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
if s.TestInput != nil {
s.TestInput(input)
}
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
@@ -40,6 +45,95 @@ func TestAssumeRoleProvider(t *testing.T) {
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenCode: aws.String("code"),
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenProvider: func() (string, error) {
return "code", nil
},
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
TokenProvider: func() (string, error) {
return "", fmt.Errorf("error occurred")
},
}
creds, err := p.Retrieve()
assert.Error(t, err)
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
}
func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
},
}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
SerialNumber: aws.String("0123456789"),
}
creds, err := p.Retrieve()
assert.Error(t, err)
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{

110
vendor/github.com/aws/aws-sdk-go/aws/crr/cache.go generated vendored Normal file
View File

@@ -0,0 +1,110 @@
package crr
import (
"sync/atomic"
)
// EndpointCache is an LRU cache that holds a series of endpoints
// based on some key. The datastructure makes use of a read write
// mutex to enable asynchronous use.
type EndpointCache struct {
endpoints syncMap
endpointLimit int64
// size is used to count the number elements in the cache.
// The atomic package is used to ensure this size is accurate when
// using multiple goroutines.
size int64
}
// NewEndpointCache will return a newly initialized cache with a limit
// of endpointLimit entries.
func NewEndpointCache(endpointLimit int64) *EndpointCache {
return &EndpointCache{
endpointLimit: endpointLimit,
endpoints: newSyncMap(),
}
}
// get is a concurrent safe get operation that will retrieve an endpoint
// based on endpointKey. A boolean will also be returned to illustrate whether
// or not the endpoint had been found.
func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
endpoint, ok := c.endpoints.Load(endpointKey)
if !ok {
return Endpoint{}, false
}
c.endpoints.Store(endpointKey, endpoint)
return endpoint.(Endpoint), true
}
// Get will retrieve a weighted address based off of the endpoint key. If an endpoint
// should be retrieved, due to not existing or the current endpoint has expired
// the Discoverer object that was passed in will attempt to discover a new endpoint
// and add that to the cache.
func (c *EndpointCache) Get(d Discoverer, endpointKey string, required bool) (WeightedAddress, error) {
var err error
endpoint, ok := c.get(endpointKey)
weighted, found := endpoint.GetValidAddress()
shouldGet := !ok || !found
if required && shouldGet {
if endpoint, err = c.discover(d, endpointKey); err != nil {
return WeightedAddress{}, err
}
weighted, _ = endpoint.GetValidAddress()
} else if shouldGet {
go c.discover(d, endpointKey)
}
return weighted, nil
}
// Add is a concurrent safe operation that will allow new endpoints to be added
// to the cache. If the cache is full, the number of endpoints equal endpointLimit,
// then this will remove the oldest entry before adding the new endpoint.
func (c *EndpointCache) Add(endpoint Endpoint) {
// de-dups multiple adds of an endpoint with a pre-existing key
if iface, ok := c.endpoints.Load(endpoint.Key); ok {
e := iface.(Endpoint)
if e.Len() > 0 {
return
}
}
c.endpoints.Store(endpoint.Key, endpoint)
size := atomic.AddInt64(&c.size, 1)
if size > 0 && size > c.endpointLimit {
c.deleteRandomKey()
}
}
// deleteRandomKey will delete a random key from the cache. If
// no key was deleted false will be returned.
func (c *EndpointCache) deleteRandomKey() bool {
atomic.AddInt64(&c.size, -1)
found := false
c.endpoints.Range(func(key, value interface{}) bool {
found = true
c.endpoints.Delete(key)
return false
})
return found
}
// discover will get and store and endpoint using the Discoverer.
func (c *EndpointCache) discover(d Discoverer, endpointKey string) (Endpoint, error) {
endpoint, err := d.Discover()
if err != nil {
return Endpoint{}, err
}
endpoint.Key = endpointKey
c.Add(endpoint)
return endpoint, nil
}

452
vendor/github.com/aws/aws-sdk-go/aws/crr/cache_test.go generated vendored Normal file
View File

@@ -0,0 +1,452 @@
package crr
import (
"net/url"
"reflect"
"testing"
)
func urlParse(uri string) *url.URL {
u, _ := url.Parse(uri)
return u
}
func TestCacheAdd(t *testing.T) {
cases := []struct {
limit int64
endpoints []Endpoint
validKeys map[string]Endpoint
expectedSize int
}{
{
limit: 5,
endpoints: []Endpoint{
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
expectedSize: 5,
},
{
limit: 2,
endpoints: []Endpoint{
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
expectedSize: 2,
},
}
for _, c := range cases {
cache := NewEndpointCache(c.limit)
for _, endpoint := range c.endpoints {
cache.Add(endpoint)
}
count := 0
endpoints := map[string]Endpoint{}
cache.endpoints.Range(func(key, value interface{}) bool {
count++
endpoints[key.(string)] = value.(Endpoint)
return true
})
if e, a := c.expectedSize, cache.size; int64(e) != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := c.expectedSize, count; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
for k, ep := range endpoints {
endpoint, ok := c.validKeys[k]
if !ok {
t.Errorf("unrecognized key %q in cache", k)
}
if e, a := endpoint, ep; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
}
}
func TestCacheGet(t *testing.T) {
cases := []struct {
addEndpoints []Endpoint
validKeys map[string]Endpoint
limit int64
}{
{
limit: 5,
addEndpoints: []Endpoint{
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
},
{
limit: 2,
addEndpoints: []Endpoint{
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
},
}
for _, c := range cases {
cache := NewEndpointCache(c.limit)
for _, endpoint := range c.addEndpoints {
cache.Add(endpoint)
}
keys := []string{}
cache.endpoints.Range(func(key, value interface{}) bool {
a := value.(Endpoint)
e, ok := c.validKeys[key.(string)]
if !ok {
t.Errorf("unrecognized key %q in cache", key.(string))
}
if !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
keys = append(keys, key.(string))
return true
})
for _, key := range keys {
a, ok := cache.get(key)
if !ok {
t.Errorf("expected key to be present: %q", key)
}
e := c.validKeys[key]
if !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
}
}

99
vendor/github.com/aws/aws-sdk-go/aws/crr/endpoint.go generated vendored Normal file
View File

@@ -0,0 +1,99 @@
package crr
import (
"net/url"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
)
// Endpoint represents an endpoint used in endpoint discovery.
type Endpoint struct {
Key string
Addresses WeightedAddresses
}
// WeightedAddresses represents a list of WeightedAddress.
type WeightedAddresses []WeightedAddress
// WeightedAddress represents an address with a given weight.
type WeightedAddress struct {
URL *url.URL
Expired time.Time
}
// HasExpired will return whether or not the endpoint has expired with
// the exception of a zero expiry meaning does not expire.
func (e WeightedAddress) HasExpired() bool {
return e.Expired.Before(time.Now())
}
// Add will add a given WeightedAddress to the address list of Endpoint.
func (e *Endpoint) Add(addr WeightedAddress) {
e.Addresses = append(e.Addresses, addr)
}
// Len returns the number of valid endpoints where valid means the endpoint
// has not expired.
func (e *Endpoint) Len() int {
validEndpoints := 0
for _, endpoint := range e.Addresses {
if endpoint.HasExpired() {
continue
}
validEndpoints++
}
return validEndpoints
}
// GetValidAddress will return a non-expired weight endpoint
func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
for i := 0; i < len(e.Addresses); i++ {
we := e.Addresses[i]
if we.HasExpired() {
e.Addresses = append(e.Addresses[:i], e.Addresses[i+1:]...)
i--
continue
}
return we, true
}
return WeightedAddress{}, false
}
// Discoverer is an interface used to discovery which endpoint hit. This
// allows for specifics about what parameters need to be used to be contained
// in the Discoverer implementor.
type Discoverer interface {
Discover() (Endpoint, error)
}
// BuildEndpointKey will sort the keys in alphabetical order and then retrieve
// the values in that order. Those values are then concatenated together to form
// the endpoint key.
func BuildEndpointKey(params map[string]*string) string {
keys := make([]string, len(params))
i := 0
for k := range params {
keys[i] = k
i++
}
sort.Strings(keys)
values := make([]string, len(params))
for i, k := range keys {
if params[k] == nil {
continue
}
values[i] = aws.StringValue(params[k])
}
return strings.Join(values, ".")
}

29
vendor/github.com/aws/aws-sdk-go/aws/crr/sync_map.go generated vendored Normal file
View File

@@ -0,0 +1,29 @@
// +build go1.9
package crr
import (
"sync"
)
type syncMap sync.Map
func newSyncMap() syncMap {
return syncMap{}
}
func (m *syncMap) Load(key interface{}) (interface{}, bool) {
return (*sync.Map)(m).Load(key)
}
func (m *syncMap) Store(key interface{}, value interface{}) {
(*sync.Map)(m).Store(key, value)
}
func (m *syncMap) Delete(key interface{}) {
(*sync.Map)(m).Delete(key)
}
func (m *syncMap) Range(f func(interface{}, interface{}) bool) {
(*sync.Map)(m).Range(f)
}

View File

@@ -0,0 +1,48 @@
// +build !go1.9
package crr
import (
"sync"
)
type syncMap struct {
container map[interface{}]interface{}
lock sync.RWMutex
}
func newSyncMap() syncMap {
return syncMap{
container: map[interface{}]interface{}{},
}
}
func (m *syncMap) Load(key interface{}) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
v, ok := m.container[key]
return v, ok
}
func (m *syncMap) Store(key interface{}, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
m.container[key] = value
}
func (m *syncMap) Delete(key interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.container, key)
}
func (m *syncMap) Range(f func(interface{}, interface{}) bool) {
for k, v := range m.container {
if !f(k, v) {
return
}
}
}

View File

@@ -0,0 +1,110 @@
package crr
import (
"reflect"
"testing"
)
func TestRangeDelete(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
m.Delete(key)
return true
})
expectedMap := map[interface{}]interface{}{}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestRangeStore(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
v := value.(int)
m.Store(key, v+1)
return true
})
expectedMap := map[interface{}]interface{}{
0: 1,
1: 11,
2: 21,
3: 31,
4: 41,
5: 51,
6: 61,
7: 71,
8: 81,
9: 91,
}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestRangeGet(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
m.Load(key)
return true
})
expectedMap := map[interface{}]interface{}{
0: 0,
1: 10,
2: 20,
3: 30,
4: 40,
5: 50,
6: 60,
7: 70,
8: 80,
9: 90,
}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}

46
vendor/github.com/aws/aws-sdk-go/aws/csm/doc.go generated vendored Normal file
View File

@@ -0,0 +1,46 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics
// via UDP connection. Using the Start function will enable the reporting of
// metrics on a given port. If Start is called, with different parameters, again,
// a panic will occur.
//
// Pause can be called to pause any metrics publishing on a given port. Sessions
// that have had their handlers modified via InjectHandlers may still be used.
// However, the handlers will act as a no-op meaning no metrics will be published.
//
// Example:
// r, err := csm.Start("clientID", ":31000")
// if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err))
// }
//
// sess, err := session.NewSession(&aws.Config{})
// if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err))
// }
//
// r.InjectHandlers(&sess.Handlers)
//
// client := s3.New(sess)
// resp, err := client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
//
// // Will pause monitoring
// r.Pause()
// resp, err = client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
//
// // Resume monitoring
// r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm

67
vendor/github.com/aws/aws-sdk-go/aws/csm/enable.go generated vendored Normal file
View File

@@ -0,0 +1,67 @@
package csm
import (
"fmt"
"sync"
)
var (
lock sync.Mutex
)
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// Start will start the a long running go routine to capture
// client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different
// client ID or port is passed in.
//
// Example:
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }
// sess := session.NewSession()
// r.InjectHandlers(sess.Handlers)
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
func Start(clientID string, url string) (*Reporter, error) {
lock.Lock()
defer lock.Unlock()
if sender == nil {
sender = newReporter(clientID, url)
} else {
if sender.clientID != clientID {
panic(fmt.Errorf("inconsistent client IDs. %q was expected, but received %q", sender.clientID, clientID))
}
if sender.url != url {
panic(fmt.Errorf("inconsistent URLs. %q was expected, but received %q", sender.url, url))
}
}
if err := connect(url); err != nil {
sender = nil
return nil, err
}
return sender, nil
}
// Get will return a reporter if one exists, if one does not exist, nil will
// be returned.
func Get() *Reporter {
lock.Lock()
defer lock.Unlock()
return sender
}

View File

@@ -0,0 +1,74 @@
package csm
import (
"encoding/json"
"fmt"
"net"
"testing"
)
func startUDPServer(done chan struct{}, fn func([]byte)) (string, error) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return "", err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return "", err
}
buf := make([]byte, 1024)
go func() {
defer conn.Close()
for {
select {
case <-done:
return
default:
}
n, _, err := conn.ReadFromUDP(buf)
fn(buf[:n])
if err != nil {
panic(err)
}
}
}()
return conn.LocalAddr().String(), nil
}
func TestDifferentParams(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic with different parameters")
}
}()
Start("clientID2", ":0")
}
var MetricsCh = make(chan map[string]interface{}, 1)
var Done = make(chan struct{})
func init() {
url, err := startUDPServer(Done, func(b []byte) {
m := map[string]interface{}{}
if err := json.Unmarshal(b, &m); err != nil {
panic(fmt.Sprintf("expected no error, but received %v", err))
}
MetricsCh <- m
})
if err != nil {
panic(err)
}
_, err = Start("clientID", url)
if err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,40 @@
package csm_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
func ExampleStart() {
r, err := csm.Start("clientID", ":31000")
if err != nil {
panic(fmt.Errorf("failed starting CSM: %v", err))
}
sess, err := session.NewSession(&aws.Config{})
if err != nil {
panic(fmt.Errorf("failed loading session: %v", err))
}
r.InjectHandlers(&sess.Handlers)
client := s3.New(sess)
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Pauses monitoring
r.Pause()
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Resume monitoring
r.Continue()
}

53
vendor/github.com/aws/aws-sdk-go/aws/csm/metric.go generated vendored Normal file
View File

@@ -0,0 +1,53 @@
package csm
import (
"strconv"
"time"
)
type metricTime time.Time
func (t metricTime) MarshalJSON() ([]byte, error) {
ns := time.Duration(time.Time(t).UnixNano())
return []byte(strconv.FormatInt(int64(ns/time.Millisecond), 10)), nil
}
type metric struct {
ClientID *string `json:"ClientId,omitempty"`
API *string `json:"Api,omitempty"`
Service *string `json:"Service,omitempty"`
Timestamp *metricTime `json:"Timestamp,omitempty"`
Type *string `json:"Type,omitempty"`
Version *int `json:"Version,omitempty"`
AttemptCount *int `json:"AttemptCount,omitempty"`
Latency *int `json:"Latency,omitempty"`
Fqdn *string `json:"Fqdn,omitempty"`
UserAgent *string `json:"UserAgent,omitempty"`
AttemptLatency *int `json:"AttemptLatency,omitempty"`
SessionToken *string `json:"SessionToken,omitempty"`
Region *string `json:"Region,omitempty"`
AccessKey *string `json:"AccessKey,omitempty"`
HTTPStatusCode *int `json:"HttpStatusCode,omitempty"`
XAmzID2 *string `json:"XAmzId2,omitempty"`
XAmzRequestID *string `json:"XAmznRequestId,omitempty"`
AWSException *string `json:"AwsException,omitempty"`
AWSExceptionMessage *string `json:"AwsExceptionMessage,omitempty"`
SDKException *string `json:"SdkException,omitempty"`
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
DestinationIP *string `json:"DestinationIp,omitempty"`
ConnectionReused *int `json:"ConnectionReused,omitempty"`
AcquireConnectionLatency *int `json:"AcquireConnectionLatency,omitempty"`
ConnectLatency *int `json:"ConnectLatency,omitempty"`
RequestLatency *int `json:"RequestLatency,omitempty"`
DNSLatency *int `json:"DnsLatency,omitempty"`
TCPLatency *int `json:"TcpLatency,omitempty"`
SSLLatency *int `json:"SslLatency,omitempty"`
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
}

View File

@@ -0,0 +1,54 @@
package csm
import (
"sync/atomic"
)
const (
runningEnum = iota
pausedEnum
)
var (
// MetricsChannelSize of metrics to hold in the channel
MetricsChannelSize = 100
)
type metricChan struct {
ch chan metric
paused int64
}
func newMetricChan(size int) metricChan {
return metricChan{
ch: make(chan metric, size),
}
}
func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum)
}
func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum)
}
func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused)
return v == pausedEnum
}
// Push will push metrics to the metric channel if the channel
// is not paused
func (ch *metricChan) Push(m metric) bool {
if ch.IsPaused() {
return false
}
select {
case ch.ch <- m:
return true
default:
return false
}
}

View File

@@ -0,0 +1,72 @@
package csm
import (
"testing"
)
func TestMetricChanPush(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPauseContinue(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
if !ch.IsPaused() {
t.Errorf("expected to be paused, but did not pause properly")
}
ch.Continue()
if ch.IsPaused() {
t.Errorf("expected to be not paused, but did not continue properly")
}
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPushWhenPaused(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to not be pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanNonBlocking(t *testing.T) {
ch := newMetricChan(0)
defer close(ch.ch)
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to be not pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}

242
vendor/github.com/aws/aws-sdk-go/aws/csm/reporter.go generated vendored Normal file
View File

@@ -0,0 +1,242 @@
package csm
import (
"encoding/json"
"net"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint.
type Reporter struct {
clientID string
url string
conn net.Conn
metricsCh metricChan
done chan struct{}
}
var (
sender *Reporter
)
func connect(url string) error {
const network = "udp"
if err := sender.connect(network, url); err != nil {
return err
}
if sender.done == nil {
sender.done = make(chan struct{})
go sender.start()
}
return nil
}
func newReporter(clientID, url string) *Reporter {
return &Reporter{
clientID: clientID,
url: url,
metricsCh: newMetricChan(MetricsChannelSize),
}
}
func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
if rep == nil {
return
}
now := time.Now()
creds, _ := r.Config.Credentials.Get()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Region: r.Config.Region,
Type: aws.String("ApiCallAttempt"),
Version: aws.Int(1),
XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID),
}
if r.HTTPResponse != nil {
m.HTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
setError(&m, awserr)
}
}
rep.metricsCh.Push(m)
}
func setError(m *metric, err awserr.Error) {
msg := err.Error()
code := err.Code()
switch code {
case "RequestError",
"SerializationError",
request.CanceledErrorCode:
m.SDKException = &code
m.SDKExceptionMessage = &msg
default:
m.AWSException = &code
m.AWSExceptionMessage = &msg
}
}
func (rep *Reporter) sendAPICallMetric(r *request.Request) {
if rep == nil {
return
}
now := time.Now()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region,
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
}
// TODO: Probably want to figure something out for logging dropped
// metrics
rep.metricsCh.Push(m)
}
func (rep *Reporter) connect(network, url string) error {
if rep.conn != nil {
rep.conn.Close()
}
conn, err := net.Dial(network, url)
if err != nil {
return awserr.New("UDPError", "Could not connect", err)
}
rep.conn = conn
return nil
}
func (rep *Reporter) close() {
if rep.done != nil {
close(rep.done)
}
rep.metricsCh.Pause()
}
func (rep *Reporter) start() {
defer func() {
rep.metricsCh.Pause()
}()
for {
select {
case <-rep.done:
rep.done = nil
return
case m := <-rep.metricsCh.ch:
// TODO: What to do with this error? Probably should just log
b, err := json.Marshal(m)
if err != nil {
continue
}
rep.conn.Write(b)
}
}
}
// Pause will pause the metric channel preventing any new metrics from
// being added.
func (rep *Reporter) Pause() {
lock.Lock()
defer lock.Unlock()
if rep == nil {
return
}
rep.close()
}
// Continue will reopen the metric channel and allow for monitoring
// to be resumed.
func (rep *Reporter) Continue() {
lock.Lock()
defer lock.Unlock()
if rep == nil {
return
}
if !rep.metricsCh.IsPaused() {
return
}
rep.metricsCh.Continue()
}
// InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent.
//
// Example:
// // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }
//
// sess := session.NewSession()
// r.InjectHandlers(&sess.Handlers)
//
// // create a new service client with our client side metric session
// svc := s3.New(sess)
func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
if rep == nil {
return
}
apiCallHandler := request.NamedHandler{Name: APICallMetricHandlerName, Fn: rep.sendAPICallMetric}
apiCallAttemptHandler := request.NamedHandler{Name: APICallAttemptMetricHandlerName, Fn: rep.sendAPICallAttemptMetric}
handlers.Complete.PushFrontNamed(apiCallHandler)
handlers.Complete.PushFrontNamed(apiCallAttemptHandler)
handlers.AfterRetry.PushFrontNamed(apiCallAttemptHandler)
}
// boolIntValue return 1 for true and 0 for false.
func boolIntValue(b bool) int {
if b {
return 1
}
return 0
}

View File

@@ -0,0 +1,72 @@
package csm
import (
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestMaxRetriesExceeded(t *testing.T) {
md := metadata.ClientInfo{
Endpoint: "http://127.0.0.1",
}
cfg := aws.Config{
Region: aws.String("foo"),
Credentials: credentials.NewStaticCredentials("", "", ""),
}
op := &request.Operation{}
cases := []struct {
name string
httpStatusCode int
expectedMaxRetriesValue int
expectedMetrics int
}{
{
name: "max retry reached",
httpStatusCode: http.StatusBadGateway,
expectedMaxRetriesValue: 1,
},
{
name: "status ok",
httpStatusCode: http.StatusOK,
expectedMaxRetriesValue: 0,
},
}
for _, c := range cases {
r := request.New(cfg, md, defaults.Handlers(), client.DefaultRetryer{NumMaxRetries: 2}, op, nil, nil)
reporter := newReporter("", "")
r.Handlers.Send.Clear()
reporter.InjectHandlers(&r.Handlers)
r.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: c.httpStatusCode,
}
})
r.Send()
for {
m := <-reporter.metricsCh.ch
if *m.Type != "ApiCall" {
// ignore non-ApiCall metrics since MaxRetriesExceeded is only on ApiCall events
continue
}
if val := *m.MaxRetriesExceeded; val != c.expectedMaxRetriesValue {
t.Errorf("%s: expected %d, but received %d", c.name, c.expectedMaxRetriesValue, val)
}
break
}
}
}

View File

@@ -0,0 +1,213 @@
package csm_test
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
func TestReportingMetrics(t *testing.T) {
reporter := csm.Get()
if reporter == nil {
t.Errorf("expected non-nil reporter")
}
sess := session.New()
sess.Handlers.Clear()
reporter.InjectHandlers(&sess.Handlers)
md := metadata.ClientInfo{}
op := &request.Operation{}
r := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 0}, op, nil, nil)
sess.Handlers.Complete.Run(r)
foundAttempt := false
foundCall := false
expectedMetrics := 2
for i := 0; i < expectedMetrics; i++ {
m := <-csm.MetricsCh
for k, v := range m {
switch k {
case "Type":
a := v.(string)
foundCall = foundCall || a == "ApiCall"
foundAttempt = foundAttempt || a == "ApiCallAttempt"
if prefix := "ApiCall"; !strings.HasPrefix(a, prefix) {
t.Errorf("expected 'APICall' prefix, but received %q", a)
}
}
}
}
if !foundAttempt {
t.Errorf("expected attempt event to have occurred")
}
if !foundCall {
t.Errorf("expected call event to have occurred")
}
}
type mockService struct {
*client.Client
}
type input struct{}
type output struct{}
func (s *mockService) Request(i input) *request.Request {
op := &request.Operation{
Name: "foo",
HTTPMethod: "POST",
HTTPPath: "/",
}
o := output{}
req := s.NewRequest(op, &i, &o)
return req
}
func BenchmarkWithCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := session.New(&cfg)
r := csm.Get()
r.InjectHandlers(&sess.Handlers)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithCSMNoUDPConnection(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := session.New(&cfg)
r := csm.Get()
r.Pause()
r.InjectHandlers(&sess.Handlers)
defer r.Pause()
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithoutCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := session.New(&cfg)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}

View File

@@ -8,17 +8,23 @@
package defaults
import (
"fmt"
"net"
"net/http"
"net/url"
"os"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
// A Defaults provides a collection of default values for SDK clients.
@@ -54,7 +60,7 @@ func Config() *aws.Config {
WithMaxRetries(aws.UseServiceDefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep)
WithEndpointResolver(endpoints.DefaultResolver())
}
// Handlers returns the default request handlers.
@@ -66,8 +72,12 @@ func Handlers() request.Handlers {
var handlers request.Handlers
handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
handlers.Validate.AfterEachFn = request.HandlerListStopOnError
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
handlers.Build.PushBackNamed(corehandlers.AddHostExecEnvUserAgentHander)
handlers.Build.AfterEachFn = request.HandlerListStopOnError
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler)
handlers.Send.PushBackNamed(corehandlers.SendHandler)
handlers.AfterRetry.PushBackNamed(corehandlers.AfterRetryHandler)
handlers.ValidateResponse.PushBackNamed(corehandlers.ValidateResponseHandler)
@@ -81,15 +91,117 @@ func Handlers() request.Handlers {
// is available if you need to reset the credentials of an
// existing service client or session's Config.
func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credentials {
endpoint, signingRegion := endpoints.EndpointForRegion(ec2metadata.ServiceName, *cfg.Region, true)
return credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.NewClient(*cfg, handlers, endpoint, signingRegion),
ExpiryWindow: 5 * time.Minute,
},
})
return credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: CredProviders(cfg, handlers),
})
}
// CredProviders returns the slice of providers used in
// the default credential chain.
//
// For applications that need to use some other provider (for example use
// different environment variables for legacy reasons) but still fall back
// on the default chain of providers. This allows that default chaint to be
// automatically updated
func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Provider {
return []credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
RemoteCredProvider(*cfg, handlers),
}
}
const (
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
)
// RemoteCredProvider returns a credentials provider for the default remote
// endpoints such as EC2 or ECS Roles.
func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
if u := os.Getenv(httpProviderEnvVar); len(u) > 0 {
return localHTTPCredProvider(cfg, handlers, u)
}
if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
return httpCredProvider(cfg, handlers, u)
}
return ec2RoleProvider(cfg, handlers)
}
var lookupHostFn = net.LookupHost
func isLoopbackHost(host string) (bool, error) {
ip := net.ParseIP(host)
if ip != nil {
return ip.IsLoopback(), nil
}
// Host is not an ip, perform lookup
addrs, err := lookupHostFn(host)
if err != nil {
return false, err
}
for _, addr := range addrs {
if !net.ParseIP(addr).IsLoopback() {
return false, nil
}
}
return true, nil
}
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
var errMsg string
parsed, err := url.Parse(u)
if err != nil {
errMsg = fmt.Sprintf("invalid URL, %v", err)
} else {
host := aws.URLHostname(parsed)
if len(host) == 0 {
errMsg = "unable to parse host from local HTTP cred provider URL"
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr)
} else if !isLoopback {
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host)
}
}
if len(errMsg) > 0 {
if cfg.Logger != nil {
cfg.Logger.Log("Ignoring, HTTP credential provider", errMsg, err)
}
return credentials.ErrorProvider{
Err: awserr.New("CredentialsEndpointError", errMsg, err),
ProviderName: endpointcreds.ProviderName,
}
}
return httpCredProvider(cfg, handlers, u)
}
func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
},
)
}
func ec2RoleProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
resolver := cfg.EndpointResolver
if resolver == nil {
resolver = endpoints.DefaultResolver()
}
e, _ := resolver.EndpointFor(endpoints.Ec2metadataServiceID, "")
return &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.NewClient(cfg, handlers, e.URL, e.SigningRegion),
ExpiryWindow: 5 * time.Minute,
}
}

View File

@@ -0,0 +1,123 @@
package defaults
import (
"fmt"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestHTTPCredProvider(t *testing.T) {
origFn := lookupHostFn
defer func() { lookupHostFn = origFn }()
lookupHostFn = func(host string) ([]string, error) {
m := map[string]struct {
Addrs []string
Err error
}{
"localhost": {Addrs: []string{"::1", "127.0.0.1"}},
"actuallylocal": {Addrs: []string{"127.0.0.2"}},
"notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}},
"www.example.com": {Addrs: []string{"10.10.10.10"}},
}
h, ok := m[host]
if !ok {
t.Fatalf("unknown host in test, %v", host)
return nil, fmt.Errorf("unknown host")
}
return h.Addrs, h.Err
}
cases := []struct {
Host string
AuthToken string
Fail bool
}{
{Host: "localhost", Fail: false},
{Host: "actuallylocal", Fail: false},
{Host: "127.0.0.1", Fail: false},
{Host: "127.1.1.1", Fail: false},
{Host: "[::1]", Fail: false},
{Host: "www.example.com", Fail: true},
{Host: "169.254.170.2", Fail: true},
{Host: "localhost", Fail: false, AuthToken: "Basic abc123"},
}
defer os.Clearenv()
for i, c := range cases {
u := fmt.Sprintf("http://%s/abc/123", c.Host)
os.Setenv(httpProviderEnvVar, u)
os.Setenv(httpProviderAuthorizationEnvVar, c.AuthToken)
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("%d, expect provider not to be nil, but was", i)
}
if c.Fail {
creds, err := provider.Retrieve()
if err == nil {
t.Fatalf("%d, expect error but got none", i)
} else {
aerr := err.(awserr.Error)
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
t.Errorf("%d, expect %s error code, got %s", i, e, a)
}
}
if e, a := endpointcreds.ProviderName, creds.ProviderName; e != a {
t.Errorf("%d, expect %s provider name got %s", i, e, a)
}
} else {
httpProvider := provider.(*endpointcreds.Provider)
if e, a := u, httpProvider.Client.Endpoint; e != a {
t.Errorf("%d, expect %q endpoint, got %q", i, e, a)
}
if e, a := c.AuthToken, httpProvider.AuthorizationToken; e != a {
t.Errorf("%d, expect %q auth token, got %q", i, e, a)
}
}
}
}
func TestECSCredProvider(t *testing.T) {
defer os.Clearenv()
os.Setenv(shareddefaults.ECSCredsProviderEnvVar, "/abc/123")
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
httpProvider := provider.(*endpointcreds.Provider)
if httpProvider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.170.2/abc/123", httpProvider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}
func TestDefaultEC2RoleProvider(t *testing.T) {
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
ec2Provider := provider.(*ec2rolecreds.EC2RoleProvider)
if ec2Provider == nil {
t.Fatalf("expect provider not to be nil, but was")
}
if e, a := "http://169.254.169.254/latest", ec2Provider.Client.Endpoint; e != a {
t.Errorf("expect %q endpoint, got %q", e, a)
}
}

View File

@@ -0,0 +1,27 @@
package defaults
import (
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
// SharedCredentialsFilename returns the SDK's default file path
// for the shared credentials file.
//
// Builds the shared config file path based on the OS's platform.
//
// - Linux/Unix: $HOME/.aws/credentials
// - Windows: %USERPROFILE%\.aws\credentials
func SharedCredentialsFilename() string {
return shareddefaults.SharedCredentialsFilename()
}
// SharedConfigFilename returns the SDK's default file path for
// the shared config file.
//
// Builds the shared config file path based on the OS's platform.
//
// - Linux/Unix: $HOME/.aws/config
// - Windows: %USERPROFILE%\.aws\config
func SharedConfigFilename() string {
return shareddefaults.SharedConfigFilename()
}

56
vendor/github.com/aws/aws-sdk-go/aws/doc.go generated vendored Normal file
View File

@@ -0,0 +1,56 @@
// Package aws provides the core SDK's utilities and shared types. Use this package's
// utilities to simplify setting and reading API operations parameters.
//
// Value and Pointer Conversion Utilities
//
// This package includes a helper conversion utility for each scalar type the SDK's
// API use. These utilities make getting a pointer of the scalar, and dereferencing
// a pointer easier.
//
// Each conversion utility comes in two forms. Value to Pointer and Pointer to Value.
// The Pointer to value will safely dereference the pointer and return its value.
// If the pointer was nil, the scalar's zero value will be returned.
//
// The value to pointer functions will be named after the scalar type. So get a
// *string from a string value use the "String" function. This makes it easy to
// to get pointer of a literal string value, because getting the address of a
// literal requires assigning the value to a variable first.
//
// var strPtr *string
//
// // Without the SDK's conversion functions
// str := "my string"
// strPtr = &str
//
// // With the SDK's conversion functions
// strPtr = aws.String("my string")
//
// // Convert *string to string value
// str = aws.StringValue(strPtr)
//
// In addition to scalars the aws package also includes conversion utilities for
// map and slice for commonly types used in API parameters. The map and slice
// conversion functions use similar naming pattern as the scalar conversion
// functions.
//
// var strPtrs []*string
// var strs []string = []string{"Go", "Gophers", "Go"}
//
// // Convert []string to []*string
// strPtrs = aws.StringSlice(strs)
//
// // Convert []*string to []string
// strs = aws.StringValueSlice(strPtrs)
//
// SDK Default HTTP Client
//
// The SDK will use the http.DefaultClient if a HTTP client is not provided to
// the SDK's Session, or service client constructor. This means that if the
// http.DefaultClient is modified by other components of your application the
// modifications will be picked up by the SDK as well.
//
// In some cases this might be intended, but it is a better practice to create
// a custom HTTP Client to share explicitly through your application. You can
// configure the SDK to use the custom HTTP Client by setting the HTTPClient
// value of the SDK's Config type when creating a Session or service client.
package aws

View File

@@ -1,17 +1,25 @@
package ec2metadata
import (
"path"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// GetMetadata uses the path provided to request
// GetMetadata uses the path provided to request information from the EC2
// instance metdata service. The content will be returned as a string, or
// error if the request failed.
func (c *EC2Metadata) GetMetadata(p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
HTTPPath: sdkuri.PathJoin("/meta-data", p),
}
output := &metadataOutput{}
@@ -20,6 +28,89 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
return output.Content, req.Send()
}
// GetUserData returns the userdata that was configured for the service. If
// there is no user-data setup for the EC2 instance a "NotFoundError" error
// code will be returned.
func (c *EC2Metadata) GetUserData() (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
HTTPPath: "/user-data",
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
req.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
if r.HTTPResponse.StatusCode == http.StatusNotFound {
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
}
})
return output.Content, req.Send()
}
// GetDynamicData uses the path provided to request information from the EC2
// instance metadata service for dynamic data. The content will be returned
// as a string, or error if the request failed.
func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
op := &request.Operation{
Name: "GetDynamicData",
HTTPMethod: "GET",
HTTPPath: sdkuri.PathJoin("/dynamic", p),
}
output := &metadataOutput{}
req := c.NewRequest(op, nil, output)
return output.Content, req.Send()
}
// GetInstanceIdentityDocument retrieves an identity document describing an
// instance. Error is returned if the request fails or is unable to parse
// the response.
func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument, error) {
resp, err := c.GetDynamicData("instance-identity/document")
if err != nil {
return EC2InstanceIdentityDocument{},
awserr.New("EC2MetadataRequestError",
"failed to get EC2 instance identity document", err)
}
doc := EC2InstanceIdentityDocument{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil {
return EC2InstanceIdentityDocument{},
awserr.New("SerializationError",
"failed to decode EC2 instance identity document", err)
}
return doc, nil
}
// IAMInfo retrieves IAM info from the metadata API
func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) {
resp, err := c.GetMetadata("iam/info")
if err != nil {
return EC2IAMInfo{},
awserr.New("EC2MetadataRequestError",
"failed to get EC2 IAM info", err)
}
info := EC2IAMInfo{}
if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil {
return EC2IAMInfo{},
awserr.New("SerializationError",
"failed to decode EC2 IAM info", err)
}
if info.Code != "Success" {
errMsg := fmt.Sprintf("failed to get EC2 IAM Info (%s)", info.Code)
return EC2IAMInfo{},
awserr.New("EC2MetadataError", errMsg, nil)
}
return info, nil
}
// Region returns the region the instance is running in.
func (c *EC2Metadata) Region() (string, error) {
resp, err := c.GetMetadata("placement/availability-zone")
@@ -41,3 +132,31 @@ func (c *EC2Metadata) Available() bool {
return true
}
// An EC2IAMInfo provides the shape for unmarshaling
// an IAM info from the metadata API
type EC2IAMInfo struct {
Code string
LastUpdated time.Time
InstanceProfileArn string
InstanceProfileID string
}
// An EC2InstanceIdentityDocument provides the shape for unmarshaling
// an instance identity document
type EC2InstanceIdentityDocument struct {
DevpayProductCodes []string `json:"devpayProductCodes"`
AvailabilityZone string `json:"availabilityZone"`
PrivateIP string `json:"privateIp"`
Version string `json:"version"`
Region string `json:"region"`
InstanceID string `json:"instanceId"`
BillingProducts []string `json:"billingProducts"`
InstanceType string `json:"instanceType"`
AccountID string `json:"accountId"`
PendingTime time.Time `json:"pendingTime"`
ImageID string `json:"imageId"`
KernelID string `json:"kernelId"`
RamdiskID string `json:"ramdiskId"`
Architecture string `json:"architecture"`
}

View File

@@ -2,21 +2,53 @@ package ec2metadata_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
const instanceIdentityDocument = `{
"devpayProductCodes" : null,
"availabilityZone" : "us-east-1d",
"privateIp" : "10.158.112.84",
"version" : "2010-08-31",
"region" : "us-east-1",
"instanceId" : "i-1234567890abcdef0",
"billingProducts" : null,
"instanceType" : "t1.micro",
"accountId" : "123456789012",
"pendingTime" : "2015-11-19T16:32:11Z",
"imageId" : "ami-5fb8c835",
"kernelId" : "aki-919dcaf8",
"ramdiskId" : null,
"architecture" : "x86_64"
}`
const validIamInfo = `{
"Code" : "Success",
"LastUpdated" : "2016-03-17T12:27:32Z",
"InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile",
"InstanceProfileId" : "AIPAABCDEFGHIJKLMN123"
}`
const unsuccessfulIamInfo = `{
"Code" : "Failed",
"LastUpdated" : "2016-03-17T12:27:32Z",
"InstanceProfileArn" : "arn:aws:iam::123456789012:instance-profile/my-instance-profile",
"InstanceProfileId" : "AIPAABCDEFGHIJKLMN123"
}`
func initTestServer(path string, resp string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != path {
@@ -29,7 +61,7 @@ func initTestServer(path string, resp string) *httptest.Server {
}
func TestEndpoint(t *testing.T) {
c := ec2metadata.New(session.New())
c := ec2metadata.New(unit.Session)
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
@@ -37,8 +69,12 @@ func TestEndpoint(t *testing.T) {
}
req := c.NewRequest(op, nil, nil)
assert.Equal(t, "http://169.254.169.254/latest", req.ClientInfo.Endpoint)
assert.Equal(t, "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String())
if e, a := "http://169.254.169.254/latest", req.ClientInfo.Endpoint; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetMetadata(t *testing.T) {
@@ -47,12 +83,70 @@ func TestGetMetadata(t *testing.T) {
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetMetadata("some/path")
assert.NoError(t, err)
assert.Equal(t, "success", resp)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "success", resp; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetUserData(t *testing.T) {
server := initTestServer(
"/latest/user-data",
"success", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "success", resp; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetUserData_Error(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reader := strings.NewReader(`<?xml version="1.0" encoding="iso-8859-1"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
<head>
<title>404 - Not Found</title>
</head>
<body>
<h1>404 - Not Found</h1>
</body>
</html>`)
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", fmt.Sprintf("%d", reader.Len()))
w.WriteHeader(http.StatusNotFound)
io.Copy(w, reader)
}))
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
resp, err := c.GetUserData()
if err == nil {
t.Errorf("expect error")
}
if len(resp) != 0 {
t.Errorf("expect empty, got %v", resp)
}
aerr := err.(awserr.Error)
if e, a := "NotFoundError", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestGetRegion(t *testing.T) {
@@ -61,12 +155,16 @@ func TestGetRegion(t *testing.T) {
"us-west-2a", // real response includes suffix
)
defer server.Close()
c := ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
assert.NoError(t, err)
assert.Equal(t, "us-west-2", region)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "us-west-2", region; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataAvailable(t *testing.T) {
@@ -75,15 +173,61 @@ func TestMetadataAvailable(t *testing.T) {
"instance-id",
)
defer server.Close()
c := ec2metadata.New(session.New(), &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
available := c.Available()
if !c.Available() {
t.Errorf("expect available")
}
}
assert.True(t, available)
func TestMetadataIAMInfo_success(t *testing.T) {
server := initTestServer(
"/latest/meta-data/iam/info",
validIamInfo,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "Success", iamInfo.Code; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "arn:aws:iam::123456789012:instance-profile/my-instance-profile", iamInfo.InstanceProfileArn; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "AIPAABCDEFGHIJKLMN123", iamInfo.InstanceProfileID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataIAMInfo_failure(t *testing.T) {
server := initTestServer(
"/latest/meta-data/iam/info",
unsuccessfulIamInfo,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
iamInfo, err := c.IAMInfo()
if err == nil {
t.Errorf("expect error")
}
if e, a := "", iamInfo.Code; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "", iamInfo.InstanceProfileArn; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "", iamInfo.InstanceProfileID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataNotAvailable(t *testing.T) {
c := ec2metadata.New(session.New())
c := ec2metadata.New(unit.Session)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
@@ -95,7 +239,51 @@ func TestMetadataNotAvailable(t *testing.T) {
r.Retryable = aws.Bool(true) // network errors are retryable
})
available := c.Available()
assert.False(t, available)
if c.Available() {
t.Errorf("expect not available")
}
}
func TestMetadataErrorResponse(t *testing.T) {
c := ec2metadata.New(unit.Session)
c.Handlers.Send.Clear()
c.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: ioutil.NopCloser(strings.NewReader("error message text")),
}
r.Retryable = aws.Bool(false) // network errors are retryable
})
data, err := c.GetMetadata("uri/path")
if len(data) != 0 {
t.Errorf("expect empty, got %v", data)
}
if e, a := "error message text", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
}
func TestEC2RoleProviderInstanceIdentity(t *testing.T) {
server := initTestServer(
"/latest/dynamic/instance-identity/document",
instanceIdentityDocument,
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
doc, err := c.GetInstanceIdentityDocument()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := doc.AccountID, "123456789012"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := doc.AvailabilityZone, "us-east-1d"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := doc.Region, "us-east-1"; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

View File

@@ -1,22 +1,32 @@
// Package ec2metadata provides the client for making API calls to the
// EC2 Metadata service.
//
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive).
package ec2metadata
import (
"io/ioutil"
"net"
"bytes"
"errors"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
)
// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
@@ -26,6 +36,7 @@ type EC2Metadata struct {
// New creates a new instance of the EC2Metadata client with a session.
// This client is safe to use across multiple goroutines.
//
//
// Example:
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
@@ -40,22 +51,19 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
// NewClient returns a new EC2Metadata client. Should be used to create
// a client when not using a session. Generally using just New with a session
// is preferred.
//
// If an unmodified HTTP client is provided from the stdlib default, or no client
// the EC2RoleProvider's EC2Metadata HTTP client's timeout will be shortened.
// To disable this set Config.EC2MetadataDisableTimeoutOverride to false. Enabled by default.
func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string, opts ...func(*client.Client)) *EC2Metadata {
// If the default http client is provided, replace it with a custom
// client using default timeouts.
if cfg.HTTPClient == http.DefaultClient {
if !aws.BoolValue(cfg.EC2MetadataDisableTimeoutOverride) && httpClientZero(cfg.HTTPClient) {
// If the http client is unmodified and this feature is not disabled
// set custom timeouts for EC2Metadata requests.
cfg.HTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
},
// use a shorter timeout than default because the metadata
// service is local if it is running, and to fail faster
// if not running on an ec2 instance.
Timeout: 5 * time.Second,
}
}
@@ -64,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
ServiceID: ServiceName,
Endpoint: endpoint,
APIVersion: "latest",
},
@@ -76,6 +85,21 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send
// requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) {
r.Error = awserr.New(
request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
nil)
},
})
}
// Add additional options to the service config
for _, option := range opts {
option(svc.Client)
@@ -84,29 +108,38 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
return svc
}
func httpClientZero(c *http.Client) bool {
return c == nil || (c.Transport == nil && c.CheckRedirect == nil && c.Jar == nil && c.Timeout == 0)
}
type metadataOutput struct {
Content string
}
func unmarshalHandler(r *request.Request) {
defer r.HTTPResponse.Body.Close()
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
return
}
data := r.Data.(*metadataOutput)
data.Content = string(b)
if data, ok := r.Data.(*metadataOutput); ok {
data.Content = b.String()
}
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
_, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
b := &bytes.Buffer{}
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
return
}
// TODO extract the error...
// Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error
r.Error = awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String()))
}
func validateEndpointHandler(r *request.Request) {

View File

@@ -0,0 +1,120 @@
package ec2metadata_test
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
t.Errorf("expect %v, not to equal %v", e, a)
}
if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
http.DefaultClient.Transport = &http.Transport{}
defer func() {
http.DefaultClient.Transport = nil
}()
svc := ec2metadata.New(unit.Session)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
tr := svc.Config.HTTPClient.Transport.(*http.Transport)
if tr == nil {
t.Fatalf("expect transport not to be nil")
}
if tr.Dial != nil {
t.Errorf("expect dial to be nil, was not")
}
}
func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("us-east-1a"))
}))
cfg := aws.NewConfig().WithEndpoint(server.URL)
runEC2MetadataClients(t, cfg, 100)
}
func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("us-east-1a"))
}))
cfg := aws.NewConfig().WithEndpoint(server.URL).WithHTTPClient(&http.Client{
Transport: http.DefaultTransport,
})
runEC2MetadataClients(t, cfg, 100)
}
func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
svc := ec2metadata.New(unit.Session)
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
}
if len(resp) != 0 {
t.Errorf("expect no response, got %v", resp)
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %v error code, got %v", e, a)
}
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expect %v in error message, got %v", e, a)
}
}
func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
for i := 0; i < atOnce; i++ {
go func() {
svc := ec2metadata.New(unit.Session, cfg)
_, err := svc.Region()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
wg.Done()
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,155 @@
package endpoints
import (
"encoding/json"
"fmt"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type modelDefinition map[string]json.RawMessage
// A DecodeModelOptions are the options for how the endpoints model definition
// are decoded.
type DecodeModelOptions struct {
SkipCustomizations bool
}
// Set combines all of the option functions together.
func (d *DecodeModelOptions) Set(optFns ...func(*DecodeModelOptions)) {
for _, fn := range optFns {
fn(d)
}
}
// DecodeModel unmarshals a Regions and Endpoint model definition file into
// a endpoint Resolver. If the file format is not supported, or an error occurs
// when unmarshaling the model an error will be returned.
//
// Casting the return value of this func to a EnumPartitions will
// allow you to get a list of the partitions in the order the endpoints
// will be resolved in.
//
// resolver, err := endpoints.DecodeModel(reader)
//
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
// for _, p := range partitions {
// // ... inspect partitions
// }
func DecodeModel(r io.Reader, optFns ...func(*DecodeModelOptions)) (Resolver, error) {
var opts DecodeModelOptions
opts.Set(optFns...)
// Get the version of the partition file to determine what
// unmarshaling model to use.
modelDef := modelDefinition{}
if err := json.NewDecoder(r).Decode(&modelDef); err != nil {
return nil, newDecodeModelError("failed to decode endpoints model", err)
}
var version string
if b, ok := modelDef["version"]; ok {
version = string(b)
} else {
return nil, newDecodeModelError("endpoints version not found in model", nil)
}
if version == "3" {
return decodeV3Endpoints(modelDef, opts)
}
return nil, newDecodeModelError(
fmt.Sprintf("endpoints version %s, not supported", version), nil)
}
func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resolver, error) {
b, ok := modelDef["partitions"]
if !ok {
return nil, newDecodeModelError("endpoints model missing partitions", nil)
}
ps := partitions{}
if err := json.Unmarshal(b, &ps); err != nil {
return nil, newDecodeModelError("failed to decode endpoints model", err)
}
if opts.SkipCustomizations {
return ps, nil
}
// Customization
for i := 0; i < len(ps); i++ {
p := &ps[i]
custAddEC2Metadata(p)
custAddS3DualStack(p)
custRmIotDataService(p)
custFixAppAutoscalingChina(p)
}
return ps, nil
}
func custAddS3DualStack(p *partition) {
if p.ID != "aws" {
return
}
s, ok := p.Services["s3"]
if !ok {
return
}
s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services["s3"] = s
}
func custAddEC2Metadata(p *partition) {
p.Services["ec2metadata"] = service{
IsRegionalized: boxedFalse,
PartitionEndpoint: "aws-global",
Endpoints: endpoints{
"aws-global": endpoint{
Hostname: "169.254.169.254/latest",
Protocols: []string{"http"},
},
},
}
}
func custRmIotDataService(p *partition) {
delete(p.Services, "data.iot")
}
func custFixAppAutoscalingChina(p *partition) {
if p.ID != "aws-cn" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
const expectHostname = `autoscaling.{region}.amazonaws.com`
if e, a := s.Defaults.Hostname, expectHostname; e != a {
fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a)
return
}
s.Defaults.Hostname = expectHostname + ".cn"
p.Services[serviceName] = s
}
type decodeModelError struct {
awsError
}
func newDecodeModelError(msg string, err error) decodeModelError {
return decodeModelError{
awsError: awserr.New("DecodeEndpointsModelError", msg, err),
}
}

View File

@@ -0,0 +1,174 @@
package endpoints
import (
"strings"
"testing"
)
func TestDecodeEndpoints_V3(t *testing.T) {
const v3Doc = `
{
"version": 3,
"partitions": [
{
"defaults": {
"hostname": "{service}.{region}.{dnsSuffix}",
"protocols": [
"https"
],
"signatureVersions": [
"v4"
]
},
"dnsSuffix": "amazonaws.com",
"partition": "aws",
"partitionName": "AWS Standard",
"regionRegex": "^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$",
"regions": {
"ap-northeast-1": {
"description": "Asia Pacific (Tokyo)"
}
},
"services": {
"acm": {
"endpoints": {
"ap-northeast-1": {}
}
},
"s3": {
"endpoints": {
"ap-northeast-1": {}
}
}
}
}
]
}`
resolver, err := DecodeModel(strings.NewReader(v3Doc))
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
endpoint, err := resolver.EndpointFor("acm", "ap-northeast-1")
if err != nil {
t.Fatalf("failed to resolve endpoint, %v", err)
}
if a, e := endpoint.URL, "https://acm.ap-northeast-1.amazonaws.com"; a != e {
t.Errorf("expected %q URL got %q", e, a)
}
p := resolver.(partitions)[0]
s3Defaults := p.Services["s3"].Defaults
if a, e := s3Defaults.HasDualStack, boxedTrue; a != e {
t.Errorf("expect s3 service to have dualstack enabled")
}
if a, e := s3Defaults.DualStackHostname, "{service}.dualstack.{region}.{dnsSuffix}"; a != e {
t.Errorf("expect s3 dualstack host pattern to be %q, got %q", e, a)
}
ec2metaEndpoint := p.Services["ec2metadata"].Endpoints["aws-global"]
if a, e := ec2metaEndpoint.Hostname, "169.254.169.254/latest"; a != e {
t.Errorf("expect ec2metadata host to be %q, got %q", e, a)
}
}
func TestDecodeEndpoints_NoPartitions(t *testing.T) {
const doc = `{ "version": 3 }`
resolver, err := DecodeModel(strings.NewReader(doc))
if err == nil {
t.Fatalf("expected error")
}
if resolver != nil {
t.Errorf("expect resolver to be nil")
}
}
func TestDecodeEndpoints_UnsupportedVersion(t *testing.T) {
const doc = `{ "version": 2 }`
resolver, err := DecodeModel(strings.NewReader(doc))
if err == nil {
t.Fatalf("expected error decoding model")
}
if resolver != nil {
t.Errorf("expect resolver to be nil")
}
}
func TestDecodeModelOptionsSet(t *testing.T) {
var actual DecodeModelOptions
actual.Set(func(o *DecodeModelOptions) {
o.SkipCustomizations = true
})
expect := DecodeModelOptions{
SkipCustomizations: true,
}
if actual != expect {
t.Errorf("expect %v options got %v", expect, actual)
}
}
func TestCustFixAppAutoscalingChina(t *testing.T) {
const doc = `
{
"version": 3,
"partitions": [{
"defaults" : {
"hostname" : "{service}.{region}.{dnsSuffix}",
"protocols" : [ "https" ],
"signatureVersions" : [ "v4" ]
},
"dnsSuffix" : "amazonaws.com.cn",
"partition" : "aws-cn",
"partitionName" : "AWS China",
"regionRegex" : "^cn\\-\\w+\\-\\d+$",
"regions" : {
"cn-north-1" : {
"description" : "China (Beijing)"
},
"cn-northwest-1" : {
"description" : "China (Ningxia)"
}
},
"services" : {
"application-autoscaling" : {
"defaults" : {
"credentialScope" : {
"service" : "application-autoscaling"
},
"hostname" : "autoscaling.{region}.amazonaws.com",
"protocols" : [ "http", "https" ]
},
"endpoints" : {
"cn-north-1" : { },
"cn-northwest-1" : { }
}
}
}
}]
}`
resolver, err := DecodeModel(strings.NewReader(doc))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
endpoint, err := resolver.EndpointFor(
"application-autoscaling", "cn-northwest-1",
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := `https://autoscaling.cn-northwest-1.amazonaws.com.cn`, endpoint.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

File diff suppressed because it is too large Load Diff

66
vendor/github.com/aws/aws-sdk-go/aws/endpoints/doc.go generated vendored Normal file
View File

@@ -0,0 +1,66 @@
// Package endpoints provides the types and functionality for defining regions
// and endpoints, as well as querying those definitions.
//
// The SDK's Regions and Endpoints metadata is code generated into the endpoints
// package, and is accessible via the DefaultResolver function. This function
// returns a endpoint Resolver will search the metadata and build an associated
// endpoint if one is found. The default resolver will search all partitions
// known by the SDK. e.g AWS Standard (aws), AWS China (aws-cn), and
// AWS GovCloud (US) (aws-us-gov).
// .
//
// Enumerating Regions and Endpoint Metadata
//
// Casting the Resolver returned by DefaultResolver to a EnumPartitions interface
// will allow you to get access to the list of underlying Partitions with the
// Partitions method. This is helpful if you want to limit the SDK's endpoint
// resolving to a single partition, or enumerate regions, services, and endpoints
// in the partition.
//
// resolver := endpoints.DefaultResolver()
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
//
// for _, p := range partitions {
// fmt.Println("Regions for", p.ID())
// for id, _ := range p.Regions() {
// fmt.Println("*", id)
// }
//
// fmt.Println("Services for", p.ID())
// for id, _ := range p.Services() {
// fmt.Println("*", id)
// }
// }
//
// Using Custom Endpoints
//
// The endpoints package also gives you the ability to use your own logic how
// endpoints are resolved. This is a great way to define a custom endpoint
// for select services, without passing that logic down through your code.
//
// If a type implements the Resolver interface it can be used to resolve
// endpoints. To use this with the SDK's Session and Config set the value
// of the type to the EndpointsResolver field of aws.Config when initializing
// the session, or service client.
//
// In addition the ResolverFunc is a wrapper for a func matching the signature
// of Resolver.EndpointFor, converting it to a type that satisfies the
// Resolver interface.
//
//
// myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
// if service == endpoints.S3ServiceID {
// return endpoints.ResolvedEndpoint{
// URL: "s3.custom.endpoint.com",
// SigningRegion: "custom-signing-region",
// }, nil
// }
//
// return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
// }
//
// sess := session.Must(session.NewSession(&aws.Config{
// Region: aws.String("us-west-2"),
// EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
// }))
package endpoints

View File

@@ -0,0 +1,449 @@
package endpoints
import (
"fmt"
"regexp"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Options provide the configuration needed to direct how the
// endpoints will be resolved.
type Options struct {
// DisableSSL forces the endpoint to be resolved as HTTP.
// instead of HTTPS if the service supports it.
DisableSSL bool
// Sets the resolver to resolve the endpoint as a dualstack endpoint
// for the service. If dualstack support for a service is not known and
// StrictMatching is not enabled a dualstack endpoint for the service will
// be returned. This endpoint may not be valid. If StrictMatching is
// enabled only services that are known to support dualstack will return
// dualstack endpoints.
UseDualStack bool
// Enables strict matching of services and regions resolved endpoints.
// If the partition doesn't enumerate the exact service and region an
// error will be returned. This option will prevent returning endpoints
// that look valid, but may not resolve to any real endpoint.
StrictMatching bool
// Enables resolving a service endpoint based on the region provided if the
// service does not exist. The service endpoint ID will be used as the service
// domain name prefix. By default the endpoint resolver requires the service
// to be known when resolving endpoints.
//
// If resolving an endpoint on the partition list the provided region will
// be used to determine which partition's domain name pattern to the service
// endpoint ID with. If both the service and region are unkonwn and resolving
// the endpoint on partition list an UnknownEndpointError error will be returned.
//
// If resolving and endpoint on a partition specific resolver that partition's
// domain name pattern will be used with the service endpoint ID. If both
// region and service do not exist when resolving an endpoint on a specific
// partition the partition's domain pattern will be used to combine the
// endpoint and region together.
//
// This option is ignored if StrictMatching is enabled.
ResolveUnknownService bool
}
// Set combines all of the option functions together.
func (o *Options) Set(optFns ...func(*Options)) {
for _, fn := range optFns {
fn(o)
}
}
// DisableSSLOption sets the DisableSSL options. Can be used as a functional
// option when resolving endpoints.
func DisableSSLOption(o *Options) {
o.DisableSSL = true
}
// UseDualStackOption sets the UseDualStack option. Can be used as a functional
// option when resolving endpoints.
func UseDualStackOption(o *Options) {
o.UseDualStack = true
}
// StrictMatchingOption sets the StrictMatching option. Can be used as a functional
// option when resolving endpoints.
func StrictMatchingOption(o *Options) {
o.StrictMatching = true
}
// ResolveUnknownServiceOption sets the ResolveUnknownService option. Can be used
// as a functional option when resolving endpoints.
func ResolveUnknownServiceOption(o *Options) {
o.ResolveUnknownService = true
}
// A Resolver provides the interface for functionality to resolve endpoints.
// The build in Partition and DefaultResolver return value satisfy this interface.
type Resolver interface {
EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error)
}
// ResolverFunc is a helper utility that wraps a function so it satisfies the
// Resolver interface. This is useful when you want to add additional endpoint
// resolving logic, or stub out specific endpoints with custom values.
type ResolverFunc func(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error)
// EndpointFor wraps the ResolverFunc function to satisfy the Resolver interface.
func (fn ResolverFunc) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return fn(service, region, opts...)
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// AddScheme adds the HTTP or HTTPS schemes to a endpoint URL if there is no
// scheme. If disableSSL is true HTTP will set HTTP instead of the default HTTPS.
//
// If disableSSL is set, it will only set the URL's scheme if the URL does not
// contain a scheme.
func AddScheme(endpoint string, disableSSL bool) string {
if !schemeRE.MatchString(endpoint) {
scheme := "https"
if disableSSL {
scheme = "http"
}
endpoint = fmt.Sprintf("%s://%s", scheme, endpoint)
}
return endpoint
}
// EnumPartitions a provides a way to retrieve the underlying partitions that
// make up the SDK's default Resolver, or any resolver decoded from a model
// file.
//
// Use this interface with DefaultResolver and DecodeModels to get the list of
// Partitions.
type EnumPartitions interface {
Partitions() []Partition
}
// RegionsForService returns a map of regions for the partition and service.
// If either the partition or service does not exist false will be returned
// as the second parameter.
//
// This example shows how to get the regions for DynamoDB in the AWS partition.
// rs, exists := endpoints.RegionsForService(endpoints.DefaultPartitions(), endpoints.AwsPartitionID, endpoints.DynamodbServiceID)
//
// This is equivalent to using the partition directly.
// rs := endpoints.AwsPartition().Services()[endpoints.DynamodbServiceID].Regions()
func RegionsForService(ps []Partition, partitionID, serviceID string) (map[string]Region, bool) {
for _, p := range ps {
if p.ID() != partitionID {
continue
}
if _, ok := p.p.Services[serviceID]; !ok {
break
}
s := Service{
id: serviceID,
p: p.p,
}
return s.Regions(), true
}
return map[string]Region{}, false
}
// PartitionForRegion returns the first partition which includes the region
// passed in. This includes both known regions and regions which match
// a pattern supported by the partition which may include regions that are
// not explicitly known by the partition. Use the Regions method of the
// returned Partition if explicit support is needed.
func PartitionForRegion(ps []Partition, regionID string) (Partition, bool) {
for _, p := range ps {
if _, ok := p.p.Regions[regionID]; ok || p.p.RegionRegex.MatchString(regionID) {
return p, true
}
}
return Partition{}, false
}
// A Partition provides the ability to enumerate the partition's regions
// and services.
type Partition struct {
id string
p *partition
}
// ID returns the identifier of the partition.
func (p Partition) ID() string { return p.id }
// EndpointFor attempts to resolve the endpoint based on service and region.
// See Options for information on configuring how the endpoint is resolved.
//
// If the service cannot be found in the metadata the UnknownServiceError
// error will be returned. This validation will occur regardless if
// StrictMatching is enabled. To enable resolving unknown services set the
// "ResolveUnknownService" option to true. When StrictMatching is disabled
// this option allows the partition resolver to resolve a endpoint based on
// the service endpoint ID provided.
//
// When resolving endpoints you can choose to enable StrictMatching. This will
// require the provided service and region to be known by the partition.
// If the endpoint cannot be strictly resolved an error will be returned. This
// mode is useful to ensure the endpoint resolved is valid. Without
// StrictMatching enabled the endpoint returned my look valid but may not work.
// StrictMatching requires the SDK to be updated if you want to take advantage
// of new regions and services expansions.
//
// Errors that can be returned.
// * UnknownServiceError
// * UnknownEndpointError
func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return p.p.EndpointFor(service, region, opts...)
}
// Regions returns a map of Regions indexed by their ID. This is useful for
// enumerating over the regions in a partition.
func (p Partition) Regions() map[string]Region {
rs := map[string]Region{}
for id, r := range p.p.Regions {
rs[id] = Region{
id: id,
desc: r.Description,
p: p.p,
}
}
return rs
}
// Services returns a map of Service indexed by their ID. This is useful for
// enumerating over the services in a partition.
func (p Partition) Services() map[string]Service {
ss := map[string]Service{}
for id := range p.p.Services {
ss[id] = Service{
id: id,
p: p.p,
}
}
return ss
}
// A Region provides information about a region, and ability to resolve an
// endpoint from the context of a region, given a service.
type Region struct {
id, desc string
p *partition
}
// ID returns the region's identifier.
func (r Region) ID() string { return r.id }
// Description returns the region's description. The region description
// is free text, it can be empty, and it may change between SDK releases.
func (r Region) Description() string { return r.desc }
// ResolveEndpoint resolves an endpoint from the context of the region given
// a service. See Partition.EndpointFor for usage and errors that can be returned.
func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return r.p.EndpointFor(service, r.id, opts...)
}
// Services returns a list of all services that are known to be in this region.
func (r Region) Services() map[string]Service {
ss := map[string]Service{}
for id, s := range r.p.Services {
if _, ok := s.Endpoints[r.id]; ok {
ss[id] = Service{
id: id,
p: r.p,
}
}
}
return ss
}
// A Service provides information about a service, and ability to resolve an
// endpoint from the context of a service, given a region.
type Service struct {
id string
p *partition
}
// ID returns the identifier for the service.
func (s Service) ID() string { return s.id }
// ResolveEndpoint resolves an endpoint from the context of a service given
// a region. See Partition.EndpointFor for usage and errors that can be returned.
func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return s.p.EndpointFor(s.id, region, opts...)
}
// Regions returns a map of Regions that the service is present in.
//
// A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service.
func (s Service) Regions() map[string]Region {
rs := map[string]Region{}
for id := range s.p.Services[s.id].Endpoints {
if r, ok := s.p.Regions[id]; ok {
rs[id] = Region{
id: id,
desc: r.Description,
p: s.p,
}
}
}
return rs
}
// Endpoints returns a map of Endpoints indexed by their ID for all known
// endpoints for a service.
//
// A region is the AWS region the service exists in. Whereas a Endpoint is
// an URL that can be resolved to a instance of a service.
func (s Service) Endpoints() map[string]Endpoint {
es := map[string]Endpoint{}
for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{
id: id,
serviceID: s.id,
p: s.p,
}
}
return es
}
// A Endpoint provides information about endpoints, and provides the ability
// to resolve that endpoint for the service, and the region the endpoint
// represents.
type Endpoint struct {
id string
serviceID string
p *partition
}
// ID returns the identifier for an endpoint.
func (e Endpoint) ID() string { return e.id }
// ServiceID returns the identifier the endpoint belongs to.
func (e Endpoint) ServiceID() string { return e.serviceID }
// ResolveEndpoint resolves an endpoint from the context of a service and
// region the endpoint represents. See Partition.EndpointFor for usage and
// errors that can be returned.
func (e Endpoint) ResolveEndpoint(opts ...func(*Options)) (ResolvedEndpoint, error) {
return e.p.EndpointFor(e.serviceID, e.id, opts...)
}
// A ResolvedEndpoint is an endpoint that has been resolved based on a partition
// service, and region.
type ResolvedEndpoint struct {
// The endpoint URL
URL string
// The region that should be used for signing requests.
SigningRegion string
// The service name that should be used for signing requests.
SigningName string
// States that the signing name for this endpoint was derived from metadata
// passed in, but was not explicitly modeled.
SigningNameDerived bool
// The signing method that should be used for signing requests.
SigningMethod string
}
// So that the Error interface type can be included as an anonymous field
// in the requestError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A EndpointNotFoundError is returned when in StrictMatching mode, and the
// endpoint for the service and region cannot be found in any of the partitions.
type EndpointNotFoundError struct {
awsError
Partition string
Service string
Region string
}
// A UnknownServiceError is returned when the service does not resolve to an
// endpoint. Includes a list of all known services for the partition. Returned
// when a partition does not support the service.
type UnknownServiceError struct {
awsError
Partition string
Service string
Known []string
}
// NewUnknownServiceError builds and returns UnknownServiceError.
func NewUnknownServiceError(p, s string, known []string) UnknownServiceError {
return UnknownServiceError{
awsError: awserr.New("UnknownServiceError",
"could not resolve endpoint for unknown service", nil),
Partition: p,
Service: s,
Known: known,
}
}
// String returns the string representation of the error.
func (e UnknownServiceError) Error() string {
extra := fmt.Sprintf("partition: %q, service: %q",
e.Partition, e.Service)
if len(e.Known) > 0 {
extra += fmt.Sprintf(", known: %v", e.Known)
}
return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
func (e UnknownServiceError) String() string {
return e.Error()
}
// A UnknownEndpointError is returned when in StrictMatching mode and the
// service is valid, but the region does not resolve to an endpoint. Includes
// a list of all known endpoints for the service.
type UnknownEndpointError struct {
awsError
Partition string
Service string
Region string
Known []string
}
// NewUnknownEndpointError builds and returns UnknownEndpointError.
func NewUnknownEndpointError(p, s, r string, known []string) UnknownEndpointError {
return UnknownEndpointError{
awsError: awserr.New("UnknownEndpointError",
"could not resolve endpoint", nil),
Partition: p,
Service: s,
Region: r,
Known: known,
}
}
// String returns the string representation of the error.
func (e UnknownEndpointError) Error() string {
extra := fmt.Sprintf("partition: %q, service: %q, region: %q",
e.Partition, e.Service, e.Region)
if len(e.Known) > 0 {
extra += fmt.Sprintf(", known: %v", e.Known)
}
return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
func (e UnknownEndpointError) String() string {
return e.Error()
}

View File

@@ -0,0 +1,342 @@
package endpoints
import "testing"
func TestEnumDefaultPartitions(t *testing.T) {
resolver := DefaultResolver()
enum, ok := resolver.(EnumPartitions)
if ok != true {
t.Fatalf("resolver must satisfy EnumPartition interface")
}
ps := enum.Partitions()
if a, e := len(ps), len(defaultPartitions); a != e {
t.Errorf("expected %d partitions, got %d", e, a)
}
}
func TestEnumDefaultRegions(t *testing.T) {
expectPart := defaultPartitions[0]
partEnum := defaultPartitions[0].Partition()
regEnum := partEnum.Regions()
if a, e := len(regEnum), len(expectPart.Regions); a != e {
t.Errorf("expected %d regions, got %d", e, a)
}
}
func TestEnumPartitionServices(t *testing.T) {
expectPart := testPartitions[0]
partEnum := testPartitions[0].Partition()
if a, e := partEnum.ID(), "part-id"; a != e {
t.Errorf("expect %q partition ID, got %q", e, a)
}
svcEnum := partEnum.Services()
if a, e := len(svcEnum), len(expectPart.Services); a != e {
t.Errorf("expected %d regions, got %d", e, a)
}
}
func TestEnumRegionServices(t *testing.T) {
p := testPartitions[0].Partition()
rs := p.Regions()
if a, e := len(rs), 2; a != e {
t.Errorf("expect %d regions got %d", e, a)
}
if _, ok := rs["us-east-1"]; !ok {
t.Errorf("expect us-east-1 region to be found, was not")
}
if _, ok := rs["us-west-2"]; !ok {
t.Errorf("expect us-west-2 region to be found, was not")
}
r := rs["us-east-1"]
if a, e := r.ID(), "us-east-1"; a != e {
t.Errorf("expect %q region ID, got %q", e, a)
}
if a, e := r.Description(), "region description"; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
ss := r.Services()
if a, e := len(ss), 1; a != e {
t.Errorf("expect %d services for us-east-1, got %d", e, a)
}
if _, ok := ss["service1"]; !ok {
t.Errorf("expect service1 service to be found, was not")
}
resolved, err := r.ResolveEndpoint("service1")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-east-1.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestEnumServiceRegions(t *testing.T) {
p := testPartitions[0].Partition()
rs := p.Services()["service1"].Regions()
if e, a := 2, len(rs); e != a {
t.Errorf("expect %d regions, got %d", e, a)
}
if _, ok := rs["us-east-1"]; !ok {
t.Errorf("expect region to be found")
}
if _, ok := rs["us-west-2"]; !ok {
t.Errorf("expect region to be found")
}
}
func TestEnumServicesEndpoints(t *testing.T) {
p := testPartitions[0].Partition()
ss := p.Services()
if a, e := len(ss), 5; a != e {
t.Errorf("expect %d regions got %d", e, a)
}
if _, ok := ss["service1"]; !ok {
t.Errorf("expect service1 region to be found, was not")
}
if _, ok := ss["service2"]; !ok {
t.Errorf("expect service2 region to be found, was not")
}
s := ss["service1"]
if a, e := s.ID(), "service1"; a != e {
t.Errorf("expect %q service ID, got %q", e, a)
}
resolved, err := s.ResolveEndpoint("us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-west-2.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestEnumEndpoints(t *testing.T) {
p := testPartitions[0].Partition()
s := p.Services()["service1"]
es := s.Endpoints()
if a, e := len(es), 2; a != e {
t.Errorf("expect %d endpoints for service2, got %d", e, a)
}
if _, ok := es["us-east-1"]; !ok {
t.Errorf("expect us-east-1 to be found, was not")
}
e := es["us-east-1"]
if a, e := e.ID(), "us-east-1"; a != e {
t.Errorf("expect %q endpoint ID, got %q", e, a)
}
if a, e := e.ServiceID(), "service1"; a != e {
t.Errorf("expect %q service ID, got %q", e, a)
}
resolved, err := e.ResolveEndpoint()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service1.us-east-1.amazonaws.com"; a != e {
t.Errorf("expect %q resolved URL, got %q", e, a)
}
}
func TestResolveEndpointForPartition(t *testing.T) {
enum := testPartitions.Partitions()[0]
expected, err := testPartitions.EndpointFor("service1", "us-east-1")
actual, err := enum.EndpointFor("service1", "us-east-1")
if err != nil {
t.Fatalf("unexpected error, %v", err)
}
if expected != actual {
t.Errorf("expect resolved endpoint to be %v, but got %v", expected, actual)
}
}
func TestAddScheme(t *testing.T) {
cases := []struct {
In string
Expect string
DisableSSL bool
}{
{
In: "https://example.com",
Expect: "https://example.com",
},
{
In: "example.com",
Expect: "https://example.com",
},
{
In: "http://example.com",
Expect: "http://example.com",
},
{
In: "example.com",
Expect: "http://example.com",
DisableSSL: true,
},
{
In: "https://example.com",
Expect: "https://example.com",
DisableSSL: true,
},
}
for i, c := range cases {
actual := AddScheme(c.In, c.DisableSSL)
if actual != c.Expect {
t.Errorf("%d, expect URL to be %q, got %q", i, c.Expect, actual)
}
}
}
func TestResolverFunc(t *testing.T) {
var resolver Resolver
resolver = ResolverFunc(func(s, r string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return ResolvedEndpoint{
URL: "https://service.region.dnssuffix.com",
SigningRegion: "region",
SigningName: "service",
}, nil
})
resolved, err := resolver.EndpointFor("service", "region", func(o *Options) {
o.DisableSSL = true
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if a, e := resolved.URL, "https://service.region.dnssuffix.com"; a != e {
t.Errorf("expect %q endpoint URL, got %q", e, a)
}
if a, e := resolved.SigningRegion, "region"; a != e {
t.Errorf("expect %q region, got %q", e, a)
}
if a, e := resolved.SigningName, "service"; a != e {
t.Errorf("expect %q signing name, got %q", e, a)
}
}
func TestOptionsSet(t *testing.T) {
var actual Options
actual.Set(DisableSSLOption, UseDualStackOption, StrictMatchingOption)
expect := Options{
DisableSSL: true,
UseDualStack: true,
StrictMatching: true,
}
if actual != expect {
t.Errorf("expect %v options got %v", expect, actual)
}
}
func TestRegionsForService(t *testing.T) {
ps := DefaultPartitions()
var expect map[string]Region
var serviceID string
for _, s := range ps[0].Services() {
expect = s.Regions()
serviceID = s.ID()
if len(expect) > 0 {
break
}
}
actual, ok := RegionsForService(ps, ps[0].ID(), serviceID)
if !ok {
t.Fatalf("expect regions to be found, was not")
}
if len(actual) == 0 {
t.Fatalf("expect service %s to have regions", serviceID)
}
if e, a := len(expect), len(actual); e != a {
t.Fatalf("expect %d regions, got %d", e, a)
}
for id, r := range actual {
if e, a := id, r.ID(); e != a {
t.Errorf("expect %s region id, got %s", e, a)
}
if _, ok := expect[id]; !ok {
t.Errorf("expect %s region to be found", id)
}
if a, e := r.Description(), expect[id].desc; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
}
}
func TestRegionsForService_NotFound(t *testing.T) {
ps := testPartitions.Partitions()
actual, ok := RegionsForService(ps, ps[0].ID(), "service-not-exists")
if ok {
t.Fatalf("expect no regions to be found, but were")
}
if len(actual) != 0 {
t.Errorf("expect no regions, got %v", actual)
}
}
func TestPartitionForRegion(t *testing.T) {
ps := DefaultPartitions()
expect := ps[len(ps)%2]
var regionID string
for id := range expect.Regions() {
regionID = id
break
}
actual, ok := PartitionForRegion(ps, regionID)
if !ok {
t.Fatalf("expect partition to be found")
}
if e, a := expect.ID(), actual.ID(); e != a {
t.Errorf("expect %s partition, got %s", e, a)
}
}
func TestPartitionForRegion_NotFound(t *testing.T) {
ps := DefaultPartitions()
actual, ok := PartitionForRegion(ps, "regionNotExists")
if ok {
t.Errorf("expect no partition to be found, got %v", actual)
}
}

View File

@@ -0,0 +1,66 @@
package endpoints_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sqs"
)
func ExampleEnumPartitions() {
resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()
for _, p := range partitions {
fmt.Println("Regions for", p.ID())
for id := range p.Regions() {
fmt.Println("*", id)
}
fmt.Println("Services for", p.ID())
for id := range p.Services() {
fmt.Println("*", id)
}
}
}
func ExampleResolverFunc() {
myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == endpoints.S3ServiceID {
return endpoints.ResolvedEndpoint{
URL: "s3.custom.endpoint.com",
SigningRegion: "custom-signing-region",
}, nil
}
return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String("us-west-2"),
EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
}))
// Create the S3 service client with the shared session. This will
// automatically use the S3 custom endpoint configured in the custom
// endpoint resolver wrapping the default endpoint resolver.
s3Svc := s3.New(sess)
// Operation calls will be made to the custom endpoint.
s3Svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("myBucket"),
Key: aws.String("myObjectKey"),
})
// Create the SQS service client with the shared session. This will
// fallback to the default endpoint resolver because the customization
// passes any non S3 service endpoint resolve to the default resolver.
sqsSvc := sqs.New(sess)
// Operation calls will be made to the default endpoint for SQS for the
// region configured.
sqsSvc.ReceiveMessage(&sqs.ReceiveMessageInput{
QueueUrl: aws.String("my-queue-url"),
})
}

View File

@@ -0,0 +1,307 @@
package endpoints
import (
"fmt"
"regexp"
"strconv"
"strings"
)
type partitions []partition
func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
var opt Options
opt.Set(opts...)
for i := 0; i < len(ps); i++ {
if !ps[i].canResolveEndpoint(service, region, opt.StrictMatching) {
continue
}
return ps[i].EndpointFor(service, region, opts...)
}
// If loose matching fallback to first partition format to use
// when resolving the endpoint.
if !opt.StrictMatching && len(ps) > 0 {
return ps[0].EndpointFor(service, region, opts...)
}
return ResolvedEndpoint{}, NewUnknownEndpointError("all partitions", service, region, []string{})
}
// Partitions satisfies the EnumPartitions interface and returns a list
// of Partitions representing each partition represented in the SDK's
// endpoints model.
func (ps partitions) Partitions() []Partition {
parts := make([]Partition, 0, len(ps))
for i := 0; i < len(ps); i++ {
parts = append(parts, ps[i].Partition())
}
return parts
}
type partition struct {
ID string `json:"partition"`
Name string `json:"partitionName"`
DNSSuffix string `json:"dnsSuffix"`
RegionRegex regionRegex `json:"regionRegex"`
Defaults endpoint `json:"defaults"`
Regions regions `json:"regions"`
Services services `json:"services"`
}
func (p partition) Partition() Partition {
return Partition{
id: p.ID,
p: &p,
}
}
func (p partition) canResolveEndpoint(service, region string, strictMatch bool) bool {
s, hasService := p.Services[service]
_, hasEndpoint := s.Endpoints[region]
if hasEndpoint && hasService {
return true
}
if strictMatch {
return false
}
return p.RegionRegex.MatchString(region)
}
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
var opt Options
opt.Set(opts...)
s, hasService := p.Services[service]
if !(hasService || opt.ResolveUnknownService) {
// Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in.
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
}
e, hasEndpoint := s.endpointForRegion(region)
if !hasEndpoint && opt.StrictMatching {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
}
defs := []endpoint{p.Defaults, s.Defaults}
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil
}
func serviceList(ss services) []string {
list := make([]string, 0, len(ss))
for k := range ss {
list = append(list, k)
}
return list
}
func endpointList(es endpoints) []string {
list := make([]string, 0, len(es))
for k := range es {
list = append(list, k)
}
return list
}
type regionRegex struct {
*regexp.Regexp
}
func (rr *regionRegex) UnmarshalJSON(b []byte) (err error) {
// Strip leading and trailing quotes
regex, err := strconv.Unquote(string(b))
if err != nil {
return fmt.Errorf("unable to strip quotes from regex, %v", err)
}
rr.Regexp, err = regexp.Compile(regex)
if err != nil {
return fmt.Errorf("unable to unmarshal region regex, %v", err)
}
return nil
}
type regions map[string]region
type region struct {
Description string `json:"description"`
}
type services map[string]service
type service struct {
PartitionEndpoint string `json:"partitionEndpoint"`
IsRegionalized boxedBool `json:"isRegionalized,omitempty"`
Defaults endpoint `json:"defaults"`
Endpoints endpoints `json:"endpoints"`
}
func (s *service) endpointForRegion(region string) (endpoint, bool) {
if s.IsRegionalized == boxedFalse {
return s.Endpoints[s.PartitionEndpoint], region == s.PartitionEndpoint
}
if e, ok := s.Endpoints[region]; ok {
return e, true
}
// Unable to find any matching endpoint, return
// blank that will be used for generic endpoint creation.
return endpoint{}, false
}
type endpoints map[string]endpoint
type endpoint struct {
Hostname string `json:"hostname"`
Protocols []string `json:"protocols"`
CredentialScope credentialScope `json:"credentialScope"`
// Custom fields not modeled
HasDualStack boxedBool `json:"-"`
DualStackHostname string `json:"-"`
// Signature Version not used
SignatureVersions []string `json:"signatureVersions"`
// SSLCommonName not used.
SSLCommonName string `json:"sslCommonName"`
}
const (
defaultProtocol = "https"
defaultSigner = "v4"
)
var (
protocolPriority = []string{"https", "http"}
signerPriority = []string{"v4", "v2"}
)
func getByPriority(s []string, p []string, def string) string {
if len(s) == 0 {
return def
}
for i := 0; i < len(p); i++ {
for j := 0; j < len(s); j++ {
if s[j] == p[i] {
return s[j]
}
}
}
return s[0]
}
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
var merged endpoint
for _, def := range defs {
merged.mergeIn(def)
}
merged.mergeIn(e)
e = merged
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
signingRegion := e.CredentialScope.Region
if len(signingRegion) == 0 {
signingRegion = region
}
signingName := e.CredentialScope.Service
var signingNameDerived bool
if len(signingName) == 0 {
signingName = service
signingNameDerived = true
}
return ResolvedEndpoint{
URL: u,
SigningRegion: signingRegion,
SigningName: signingName,
SigningNameDerived: signingNameDerived,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
}
}
func getEndpointScheme(protocols []string, disableSSL bool) string {
if disableSSL {
return "http"
}
return getByPriority(protocols, protocolPriority, defaultProtocol)
}
func (e *endpoint) mergeIn(other endpoint) {
if len(other.Hostname) > 0 {
e.Hostname = other.Hostname
}
if len(other.Protocols) > 0 {
e.Protocols = other.Protocols
}
if len(other.SignatureVersions) > 0 {
e.SignatureVersions = other.SignatureVersions
}
if len(other.CredentialScope.Region) > 0 {
e.CredentialScope.Region = other.CredentialScope.Region
}
if len(other.CredentialScope.Service) > 0 {
e.CredentialScope.Service = other.CredentialScope.Service
}
if len(other.SSLCommonName) > 0 {
e.SSLCommonName = other.SSLCommonName
}
if other.HasDualStack != boxedBoolUnset {
e.HasDualStack = other.HasDualStack
}
if len(other.DualStackHostname) > 0 {
e.DualStackHostname = other.DualStackHostname
}
}
type credentialScope struct {
Region string `json:"region"`
Service string `json:"service"`
}
type boxedBool int
func (b *boxedBool) UnmarshalJSON(buf []byte) error {
v, err := strconv.ParseBool(string(buf))
if err != nil {
return err
}
if v {
*b = boxedTrue
} else {
*b = boxedFalse
}
return nil
}
const (
boxedBoolUnset boxedBool = iota
boxedFalse
boxedTrue
)

View File

@@ -0,0 +1,337 @@
// +build codegen
package endpoints
import (
"fmt"
"io"
"reflect"
"strings"
"text/template"
"unicode"
)
// A CodeGenOptions are the options for code generating the endpoints into
// Go code from the endpoints model definition.
type CodeGenOptions struct {
// Options for how the model will be decoded.
DecodeModelOptions DecodeModelOptions
}
// Set combines all of the option functions together
func (d *CodeGenOptions) Set(optFns ...func(*CodeGenOptions)) {
for _, fn := range optFns {
fn(d)
}
}
// CodeGenModel given a endpoints model file will decode it and attempt to
// generate Go code from the model definition. Error will be returned if
// the code is unable to be generated, or decoded.
func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGenOptions)) error {
var opts CodeGenOptions
opts.Set(optFns...)
resolver, err := DecodeModel(modelFile, func(d *DecodeModelOptions) {
*d = opts.DecodeModelOptions
})
if err != nil {
return err
}
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil {
return fmt.Errorf("failed to execute template, %v", err)
}
return nil
}
func toSymbol(v string) string {
out := []rune{}
for _, c := range strings.Title(v) {
if !(unicode.IsNumber(c) || unicode.IsLetter(c)) {
continue
}
out = append(out, c)
}
return string(out)
}
func quoteString(v string) string {
return fmt.Sprintf("%q", v)
}
func regionConstName(p, r string) string {
return toSymbol(p) + toSymbol(r)
}
func partitionGetter(id string) string {
return fmt.Sprintf("%sPartition", toSymbol(id))
}
func partitionVarName(id string) string {
return fmt.Sprintf("%sPartition", strings.ToLower(toSymbol(id)))
}
func listPartitionNames(ps partitions) string {
names := []string{}
switch len(ps) {
case 1:
return ps[0].Name
case 2:
return fmt.Sprintf("%s and %s", ps[0].Name, ps[1].Name)
default:
for i, p := range ps {
if i == len(ps)-1 {
names = append(names, "and "+p.Name)
} else {
names = append(names, p.Name)
}
}
return strings.Join(names, ", ")
}
}
func boxedBoolIfSet(msg string, v boxedBool) string {
switch v {
case boxedTrue:
return fmt.Sprintf(msg, "boxedTrue")
case boxedFalse:
return fmt.Sprintf(msg, "boxedFalse")
default:
return ""
}
}
func stringIfSet(msg, v string) string {
if len(v) == 0 {
return ""
}
return fmt.Sprintf(msg, v)
}
func stringSliceIfSet(msg string, vs []string) string {
if len(vs) == 0 {
return ""
}
names := []string{}
for _, v := range vs {
names = append(names, `"`+v+`"`)
}
return fmt.Sprintf(msg, strings.Join(names, ","))
}
func endpointIsSet(v endpoint) bool {
return !reflect.DeepEqual(v, endpoint{})
}
func serviceSet(ps partitions) map[string]struct{} {
set := map[string]struct{}{}
for _, p := range ps {
for id := range p.Services {
set[id] = struct{}{}
}
}
return set
}
var funcMap = template.FuncMap{
"ToSymbol": toSymbol,
"QuoteString": quoteString,
"RegionConst": regionConstName,
"PartitionGetter": partitionGetter,
"PartitionVarName": partitionVarName,
"ListPartitionNames": listPartitionNames,
"BoxedBoolIfSet": boxedBoolIfSet,
"StringIfSet": stringIfSet,
"StringSliceIfSet": stringSliceIfSet,
"EndpointIsSet": endpointIsSet,
"ServicesSet": serviceSet,
}
const v3Tmpl = `
{{ define "defaults" -}}
// Code generated by aws/endpoints/v3model_codegen.go. DO NOT EDIT.
package endpoints
import (
"regexp"
)
{{ template "partition consts" . }}
{{ range $_, $partition := . }}
{{ template "partition region consts" $partition }}
{{ end }}
{{ template "service consts" . }}
{{ template "endpoint resolvers" . }}
{{- end }}
{{ define "partition consts" }}
// Partition identifiers
const (
{{ range $_, $p := . -}}
{{ ToSymbol $p.ID }}PartitionID = {{ QuoteString $p.ID }} // {{ $p.Name }} partition.
{{ end -}}
)
{{- end }}
{{ define "partition region consts" }}
// {{ .Name }} partition's regions.
const (
{{ range $id, $region := .Regions -}}
{{ ToSymbol $id }}RegionID = {{ QuoteString $id }} // {{ $region.Description }}.
{{ end -}}
)
{{- end }}
{{ define "service consts" }}
// Service identifiers
const (
{{ $serviceSet := ServicesSet . -}}
{{ range $id, $_ := $serviceSet -}}
{{ ToSymbol $id }}ServiceID = {{ QuoteString $id }} // {{ ToSymbol $id }}.
{{ end -}}
)
{{- end }}
{{ define "endpoint resolvers" }}
// DefaultResolver returns an Endpoint resolver that will be able
// to resolve endpoints for: {{ ListPartitionNames . }}.
//
// Use DefaultPartitions() to get the list of the default partitions.
func DefaultResolver() Resolver {
return defaultPartitions
}
// DefaultPartitions returns a list of the partitions the SDK is bundled
// with. The available partitions are: {{ ListPartitionNames . }}.
//
// partitions := endpoints.DefaultPartitions
// for _, p := range partitions {
// // ... inspect partitions
// }
func DefaultPartitions() []Partition {
return defaultPartitions.Partitions()
}
var defaultPartitions = partitions{
{{ range $_, $partition := . -}}
{{ PartitionVarName $partition.ID }},
{{ end }}
}
{{ range $_, $partition := . -}}
{{ $name := PartitionGetter $partition.ID -}}
// {{ $name }} returns the Resolver for {{ $partition.Name }}.
func {{ $name }}() Partition {
return {{ PartitionVarName $partition.ID }}.Partition()
}
var {{ PartitionVarName $partition.ID }} = {{ template "gocode Partition" $partition }}
{{ end }}
{{ end }}
{{ define "default partitions" }}
func DefaultPartitions() []Partition {
return []partition{
{{ range $_, $partition := . -}}
// {{ ToSymbol $partition.ID}}Partition(),
{{ end }}
}
}
{{ end }}
{{ define "gocode Partition" -}}
partition{
{{ StringIfSet "ID: %q,\n" .ID -}}
{{ StringIfSet "Name: %q,\n" .Name -}}
{{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}}
RegionRegex: {{ template "gocode RegionRegex" .RegionRegex }},
{{ if EndpointIsSet .Defaults -}}
Defaults: {{ template "gocode Endpoint" .Defaults }},
{{- end }}
Regions: {{ template "gocode Regions" .Regions }},
Services: {{ template "gocode Services" .Services }},
}
{{- end }}
{{ define "gocode RegionRegex" -}}
regionRegex{
Regexp: func() *regexp.Regexp{
reg, _ := regexp.Compile({{ QuoteString .Regexp.String }})
return reg
}(),
}
{{- end }}
{{ define "gocode Regions" -}}
regions{
{{ range $id, $region := . -}}
"{{ $id }}": {{ template "gocode Region" $region }},
{{ end -}}
}
{{- end }}
{{ define "gocode Region" -}}
region{
{{ StringIfSet "Description: %q,\n" .Description -}}
}
{{- end }}
{{ define "gocode Services" -}}
services{
{{ range $id, $service := . -}}
"{{ $id }}": {{ template "gocode Service" $service }},
{{ end }}
}
{{- end }}
{{ define "gocode Service" -}}
service{
{{ StringIfSet "PartitionEndpoint: %q,\n" .PartitionEndpoint -}}
{{ BoxedBoolIfSet "IsRegionalized: %s,\n" .IsRegionalized -}}
{{ if EndpointIsSet .Defaults -}}
Defaults: {{ template "gocode Endpoint" .Defaults -}},
{{- end }}
{{ if .Endpoints -}}
Endpoints: {{ template "gocode Endpoints" .Endpoints }},
{{- end }}
}
{{- end }}
{{ define "gocode Endpoints" -}}
endpoints{
{{ range $id, $endpoint := . -}}
"{{ $id }}": {{ template "gocode Endpoint" $endpoint }},
{{ end }}
}
{{- end }}
{{ define "gocode Endpoint" -}}
endpoint{
{{ StringIfSet "Hostname: %q,\n" .Hostname -}}
{{ StringIfSet "SSLCommonName: %q,\n" .SSLCommonName -}}
{{ StringSliceIfSet "Protocols: []string{%s},\n" .Protocols -}}
{{ StringSliceIfSet "SignatureVersions: []string{%s},\n" .SignatureVersions -}}
{{ if or .CredentialScope.Region .CredentialScope.Service -}}
CredentialScope: credentialScope{
{{ StringIfSet "Region: %q,\n" .CredentialScope.Region -}}
{{ StringIfSet "Service: %q,\n" .CredentialScope.Service -}}
},
{{- end }}
{{ BoxedBoolIfSet "HasDualStack: %s,\n" .HasDualStack -}}
{{ StringIfSet "DualStackHostname: %q,\n" .DualStackHostname -}}
}
{{- end }}
`

View File

@@ -0,0 +1,541 @@
package endpoints
import (
"encoding/json"
"reflect"
"regexp"
"testing"
)
func TestUnmarshalRegionRegex(t *testing.T) {
var input = []byte(`
{
"regionRegex": "^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$"
}`)
p := partition{}
err := json.Unmarshal(input, &p)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectRegexp, err := regexp.Compile(`^(us|eu|ap|sa|ca)\-\w+\-\d+$`)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := expectRegexp.String(), p.RegionRegex.Regexp.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalRegion(t *testing.T) {
var input = []byte(`
{
"aws-global": {
"description": "AWS partition-global endpoint"
},
"us-east-1": {
"description": "US East (N. Virginia)"
}
}`)
rs := regions{}
err := json.Unmarshal(input, &rs)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 2, len(rs); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
r, ok := rs["aws-global"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "AWS partition-global endpoint", r.Description; e != a {
t.Errorf("expect %v, got %v", e, a)
}
r, ok = rs["us-east-1"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "US East (N. Virginia)", r.Description; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalServices(t *testing.T) {
var input = []byte(`
{
"acm": {
"endpoints": {
"us-east-1": {}
}
},
"apigateway": {
"isRegionalized": true,
"endpoints": {
"us-east-1": {},
"us-west-2": {}
}
},
"notRegionalized": {
"isRegionalized": false,
"endpoints": {
"us-east-1": {},
"us-west-2": {}
}
}
}`)
ss := services{}
err := json.Unmarshal(input, &ss)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 3, len(ss); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
s, ok := ss["acm"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 1, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedBoolUnset, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
s, ok = ss["apigateway"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 2, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedTrue, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
s, ok = ss["notRegionalized"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := 2, len(s.Endpoints); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
if e, a := boxedFalse, s.IsRegionalized; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestUnmarshalEndpoints(t *testing.T) {
var inputs = []byte(`
{
"aws-global": {
"hostname": "cloudfront.amazonaws.com",
"protocols": [
"http",
"https"
],
"signatureVersions": [ "v4" ],
"credentialScope": {
"region": "us-east-1",
"service": "serviceName"
},
"sslCommonName": "commonName"
},
"us-east-1": {}
}`)
es := endpoints{}
err := json.Unmarshal(inputs, &es)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := 2, len(es); e != a {
t.Errorf("expect %v len, got %v", e, a)
}
s, ok := es["aws-global"]
if !ok {
t.Errorf("expect found, was not")
}
if e, a := "cloudfront.amazonaws.com", s.Hostname; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := []string{"http", "https"}, s.Protocols; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := []string{"v4"}, s.SignatureVersions; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := (credentialScope{"us-east-1", "serviceName"}), s.CredentialScope; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "commonName", s.SSLCommonName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEndpointResolve(t *testing.T) {
defs := []endpoint{
{
Hostname: "{service}.{region}.{dnsSuffix}",
SignatureVersions: []string{"v2"},
SSLCommonName: "sslCommonName",
},
{
Hostname: "other-hostname",
Protocols: []string{"http"},
CredentialScope: credentialScope{
Region: "signing_region",
Service: "signing_service",
},
},
}
e := endpoint{
Hostname: "{service}.{region}.{dnsSuffix}",
Protocols: []string{"http", "https"},
SignatureVersions: []string{"v4"},
SSLCommonName: "new sslCommonName",
}
resolved := e.resolve("service", "region", "dnsSuffix",
defs, Options{},
)
if e, a := "https://service.region.dnsSuffix", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "signing_service", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "signing_region", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "v4", resolved.SigningMethod; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEndpointMergeIn(t *testing.T) {
expected := endpoint{
Hostname: "other hostname",
Protocols: []string{"http"},
SignatureVersions: []string{"v4"},
SSLCommonName: "ssl common name",
CredentialScope: credentialScope{
Region: "region",
Service: "service",
},
}
actual := endpoint{}
actual.mergeIn(endpoint{
Hostname: "other hostname",
Protocols: []string{"http"},
SignatureVersions: []string{"v4"},
SSLCommonName: "ssl common name",
CredentialScope: credentialScope{
Region: "region",
Service: "service",
},
})
if e, a := expected, actual; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
}
var testPartitions = partitions{
partition{
ID: "part-id",
Name: "partitionName",
DNSSuffix: "amazonaws.com",
RegionRegex: regionRegex{
Regexp: func() *regexp.Regexp {
reg, _ := regexp.Compile("^(us|eu|ap|sa|ca)\\-\\w+\\-\\d+$")
return reg
}(),
},
Defaults: endpoint{
Hostname: "{service}.{region}.{dnsSuffix}",
Protocols: []string{"https"},
SignatureVersions: []string{"v4"},
},
Regions: regions{
"us-east-1": region{
Description: "region description",
},
"us-west-2": region{},
},
Services: services{
"s3": service{},
"service1": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "service1",
},
},
Endpoints: endpoints{
"us-east-1": {},
"us-west-2": {
HasDualStack: boxedTrue,
DualStackHostname: "{service}.dualstack.{region}.{dnsSuffix}",
},
},
},
"service2": service{
Defaults: endpoint{
CredentialScope: credentialScope{
Service: "service2",
},
},
},
"httpService": service{
Defaults: endpoint{
Protocols: []string{"http"},
},
},
"globalService": service{
IsRegionalized: boxedFalse,
PartitionEndpoint: "aws-global",
Endpoints: endpoints{
"aws-global": endpoint{
CredentialScope: credentialScope{
Region: "us-east-1",
},
Hostname: "globalService.amazonaws.com",
},
},
},
},
},
}
func TestResolveEndpoint(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_DisableSSL(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-west-2", DisableSSLOption)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "http://service2.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_UseDualStack(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service1", "us-west-2", UseDualStackOption)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service1.dualstack.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service1", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_HTTPProtocol(t *testing.T) {
resolved, err := testPartitions.EndpointFor("httpService", "us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "http://httpService.us-west-2.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-west-2", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "httpService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_UnknownService(t *testing.T) {
_, err := testPartitions.EndpointFor("unknownservice", "us-west-2")
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownServiceError)
if !ok {
t.Errorf("expect error to be UnknownServiceError")
}
}
func TestResolveEndpoint_ResolveUnknownService(t *testing.T) {
resolved, err := testPartitions.EndpointFor("unknown-service", "us-region-1",
ResolveUnknownServiceOption)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://unknown-service.us-region-1.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-region-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "unknown-service", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_UnknownMatchedRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "us-region-1")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.us-region-1.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-region-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_UnknownRegion(t *testing.T) {
resolved, err := testPartitions.EndpointFor("service2", "unknownregion")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://service2.unknownregion.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "unknownregion", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "service2", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if resolved.SigningNameDerived {
t.Errorf("expect the signing name not to be derived, but was")
}
}
func TestResolveEndpoint_StrictPartitionUnknownEndpoint(t *testing.T) {
_, err := testPartitions[0].EndpointFor("service2", "unknownregion", StrictMatchingOption)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownEndpointError)
if !ok {
t.Errorf("expect error to be UnknownEndpointError")
}
}
func TestResolveEndpoint_StrictPartitionsUnknownEndpoint(t *testing.T) {
_, err := testPartitions.EndpointFor("service2", "us-region-1", StrictMatchingOption)
if err == nil {
t.Errorf("expect error, got none")
}
_, ok := err.(UnknownEndpointError)
if !ok {
t.Errorf("expect error to be UnknownEndpointError")
}
}
func TestResolveEndpoint_NotRegionalized(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "us-west-2")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "globalService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}
func TestResolveEndpoint_AwsGlobal(t *testing.T) {
resolved, err := testPartitions.EndpointFor("globalService", "aws-global")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := "https://globalService.amazonaws.com", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "us-east-1", resolved.SigningRegion; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "globalService", resolved.SigningName; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if !resolved.SigningNameDerived {
t.Errorf("expect the signing name to be derived")
}
}

View File

@@ -5,13 +5,9 @@ import "github.com/aws/aws-sdk-go/aws/awserr"
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
//
// @readonly
ErrMissingRegion = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
//
// @readonly
ErrMissingEndpoint = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)

12
vendor/github.com/aws/aws-sdk-go/aws/jsonvalue.go generated vendored Normal file
View File

@@ -0,0 +1,12 @@
package aws
// JSONValue is a representation of a grab bag type that will be marshaled
// into a json string. This type can be used just like any other map.
//
// Example:
//
// values := aws.JSONValue{
// "Foo": "Bar",
// }
// values["Baz"] = "Qux"
type JSONValue map[string]interface{}

View File

@@ -26,14 +26,14 @@ func (l *LogLevelType) Value() LogLevelType {
// Matches returns true if the v LogLevel is enabled by this LogLevel. Should be
// used with logging sub levels. Is safe to use on nil value LogLevelTypes. If
// LogLevel is nill, will default to LogOff comparison.
// LogLevel is nil, will default to LogOff comparison.
func (l *LogLevelType) Matches(v LogLevelType) bool {
c := l.Value()
return c&v == v
}
// AtLeast returns true if this LogLevel is at least high enough to satisfies v.
// Is safe to use on nil value LogLevelTypes. If LogLevel is nill, will default
// Is safe to use on nil value LogLevelTypes. If LogLevel is nil, will default
// to LogOff comparison.
func (l *LogLevelType) AtLeast(v LogLevelType) bool {
c := l.Value()
@@ -71,6 +71,12 @@ const (
// LogDebugWithRequestErrors states the SDK should log when service requests fail
// to build, send, validate, or unmarshal.
LogDebugWithRequestErrors
// LogDebugWithEventStreamBody states the SDK should log EventStream
// request and response bodys. This should be used to log the EventStream
// wire unmarshaled message content of requests and responses made while
// using the SDK Will also enable LogDebug.
LogDebugWithEventStreamBody
)
// A Logger is a minimalistic interface for the SDK to log messages to. Should
@@ -79,6 +85,20 @@ type Logger interface {
Log(...interface{})
}
// A LoggerFunc is a convenience type to convert a function taking a variadic
// list of arguments and wrap it so the Logger interface can be used.
//
// Example:
// s3.New(sess, &aws.Config{Logger: aws.LoggerFunc(func(args ...interface{}) {
// fmt.Fprintln(os.Stdout, args...)
// })})
type LoggerFunc func(...interface{})
// Log calls the wrapped function with the arguments provided
func (f LoggerFunc) Log(args ...interface{}) {
f(args...)
}
// NewDefaultLogger returns a Logger which will write log messages to stdout, and
// use same formatting runes as the stdlib log.Logger
func NewDefaultLogger() Logger {

View File

@@ -0,0 +1,19 @@
// +build !appengine,!plan9
package request
import (
"net"
"os"
"syscall"
)
func isErrConnectionReset(err error) bool {
if opErr, ok := err.(*net.OpError); ok {
if sysErr, ok := opErr.Err.(*os.SyscallError); ok {
return sysErr.Err == syscall.ECONNRESET
}
}
return false
}

View File

@@ -0,0 +1,11 @@
// +build appengine plan9
package request
import (
"strings"
)
func isErrConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset")
}

View File

@@ -0,0 +1,9 @@
// +build appengine plan9
package request_test
import (
"errors"
)
var stubConnectionResetError = errors.New("connection reset")

View File

@@ -0,0 +1,11 @@
// +build !appengine,!plan9
package request_test
import (
"net"
"os"
"syscall"
)
var stubConnectionResetError = &net.OpError{Err: &os.SyscallError{Syscall: "read", Err: syscall.ECONNRESET}}

View File

@@ -14,10 +14,12 @@ type Handlers struct {
Send HandlerList
ValidateResponse HandlerList
Unmarshal HandlerList
UnmarshalStream HandlerList
UnmarshalMeta HandlerList
UnmarshalError HandlerList
Retry HandlerList
AfterRetry HandlerList
Complete HandlerList
}
// Copy returns of this handler's lists.
@@ -29,10 +31,12 @@ func (h *Handlers) Copy() Handlers {
Send: h.Send.copy(),
ValidateResponse: h.ValidateResponse.copy(),
Unmarshal: h.Unmarshal.copy(),
UnmarshalStream: h.UnmarshalStream.copy(),
UnmarshalError: h.UnmarshalError.copy(),
UnmarshalMeta: h.UnmarshalMeta.copy(),
Retry: h.Retry.copy(),
AfterRetry: h.AfterRetry.copy(),
Complete: h.Complete.copy(),
}
}
@@ -43,16 +47,37 @@ func (h *Handlers) Clear() {
h.Send.Clear()
h.Sign.Clear()
h.Unmarshal.Clear()
h.UnmarshalStream.Clear()
h.UnmarshalMeta.Clear()
h.UnmarshalError.Clear()
h.ValidateResponse.Clear()
h.Retry.Clear()
h.AfterRetry.Clear()
h.Complete.Clear()
}
// A HandlerListRunItem represents an entry in the HandlerList which
// is being run.
type HandlerListRunItem struct {
Index int
Handler NamedHandler
Request *Request
}
// A HandlerList manages zero or more handlers in a list.
type HandlerList struct {
list []NamedHandler
// Called after each request handler in the list is called. If set
// and the func returns true the HandlerList will continue to iterate
// over the request handlers. If false is returned the HandlerList
// will stop iterating.
//
// Should be used if extra logic to be performed between each handler
// in the list. This can be used to terminate a list's iteration
// based on a condition such as error like, HandlerListStopOnError.
// Or for logging like HandlerListLogItem.
AfterEachFn func(item HandlerListRunItem) bool
}
// A NamedHandler is a struct that contains a name and function callback.
@@ -63,14 +88,20 @@ type NamedHandler struct {
// copy creates a copy of the handler list.
func (l *HandlerList) copy() HandlerList {
var n HandlerList
n.list = append([]NamedHandler{}, l.list...)
n := HandlerList{
AfterEachFn: l.AfterEachFn,
}
if len(l.list) == 0 {
return n
}
n.list = append(make([]NamedHandler, 0, len(l.list)), l.list...)
return n
}
// Clear clears the handler list.
func (l *HandlerList) Clear() {
l.list = []NamedHandler{}
l.list = l.list[0:0]
}
// Len returns the number of handlers in the list.
@@ -80,39 +111,142 @@ func (l *HandlerList) Len() int {
// PushBack pushes handler f to the back of the handler list.
func (l *HandlerList) PushBack(f func(*Request)) {
l.list = append(l.list, NamedHandler{"__anonymous", f})
}
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.list = append([]NamedHandler{{"__anonymous", f}}, l.list...)
l.PushBackNamed(NamedHandler{"__anonymous", f})
}
// PushBackNamed pushes named handler f to the back of the handler list.
func (l *HandlerList) PushBackNamed(n NamedHandler) {
if cap(l.list) == 0 {
l.list = make([]NamedHandler, 0, 5)
}
l.list = append(l.list, n)
}
// PushFront pushes handler f to the front of the handler list.
func (l *HandlerList) PushFront(f func(*Request)) {
l.PushFrontNamed(NamedHandler{"__anonymous", f})
}
// PushFrontNamed pushes named handler f to the front of the handler list.
func (l *HandlerList) PushFrontNamed(n NamedHandler) {
l.list = append([]NamedHandler{n}, l.list...)
if cap(l.list) == len(l.list) {
// Allocating new list required
l.list = append([]NamedHandler{n}, l.list...)
} else {
// Enough room to prepend into list.
l.list = append(l.list, NamedHandler{})
copy(l.list[1:], l.list)
l.list[0] = n
}
}
// Remove removes a NamedHandler n
func (l *HandlerList) Remove(n NamedHandler) {
newlist := []NamedHandler{}
for _, m := range l.list {
if m.Name != n.Name {
newlist = append(newlist, m)
l.RemoveByName(n.Name)
}
// RemoveByName removes a NamedHandler by name.
func (l *HandlerList) RemoveByName(name string) {
for i := 0; i < len(l.list); i++ {
m := l.list[i]
if m.Name == name {
// Shift array preventing creating new arrays
copy(l.list[i:], l.list[i+1:])
l.list[len(l.list)-1] = NamedHandler{}
l.list = l.list[:len(l.list)-1]
// decrement list so next check to length is correct
i--
}
}
l.list = newlist
}
// SwapNamed will swap out any existing handlers with the same name as the
// passed in NamedHandler returning true if handlers were swapped. False is
// returned otherwise.
func (l *HandlerList) SwapNamed(n NamedHandler) (swapped bool) {
for i := 0; i < len(l.list); i++ {
if l.list[i].Name == n.Name {
l.list[i].Fn = n.Fn
swapped = true
}
}
return swapped
}
// Swap will swap out all handlers matching the name passed in. The matched
// handlers will be swapped in. True is returned if the handlers were swapped.
func (l *HandlerList) Swap(name string, replace NamedHandler) bool {
var swapped bool
for i := 0; i < len(l.list); i++ {
if l.list[i].Name == name {
l.list[i] = replace
swapped = true
}
}
return swapped
}
// SetBackNamed will replace the named handler if it exists in the handler list.
// If the handler does not exist the handler will be added to the end of the list.
func (l *HandlerList) SetBackNamed(n NamedHandler) {
if !l.SwapNamed(n) {
l.PushBackNamed(n)
}
}
// SetFrontNamed will replace the named handler if it exists in the handler list.
// If the handler does not exist the handler will be added to the beginning of
// the list.
func (l *HandlerList) SetFrontNamed(n NamedHandler) {
if !l.SwapNamed(n) {
l.PushFrontNamed(n)
}
}
// Run executes all handlers in the list with a given request object.
func (l *HandlerList) Run(r *Request) {
for _, f := range l.list {
f.Fn(r)
for i, h := range l.list {
h.Fn(r)
item := HandlerListRunItem{
Index: i, Handler: h, Request: r,
}
if l.AfterEachFn != nil && !l.AfterEachFn(item) {
return
}
}
}
// HandlerListLogItem logs the request handler and the state of the
// request's Error value. Always returns true to continue iterating
// request handlers in a HandlerList.
func HandlerListLogItem(item HandlerListRunItem) bool {
if item.Request.Config.Logger == nil {
return true
}
item.Request.Config.Logger.Log("DEBUG: RequestHandler",
item.Index, item.Handler.Name, item.Request.Error)
return true
}
// HandlerListStopOnError returns false to stop the HandlerList iterating
// over request handlers if Request.Error is not nil. True otherwise
// to continue iterating.
func HandlerListStopOnError(item HandlerListRunItem) bool {
return item.Request.Error == nil
}
// WithAppendUserAgent will add a string to the user agent prefixed with a
// single white space.
func WithAppendUserAgent(s string) Option {
return func(r *Request) {
r.Handlers.Build.PushBack(func(r2 *Request) {
AddToUserAgent(r, s)
})
}
}

View File

@@ -1,12 +1,13 @@
package request_test
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestHandlerList(t *testing.T) {
@@ -18,8 +19,12 @@ func TestHandlerList(t *testing.T) {
r.Data = s
})
l.Run(r)
assert.Equal(t, "a", s)
assert.Equal(t, "a", r.Data)
if e, a := "a", s; e != a {
t.Errorf("expect %q update got %q", e, a)
}
if e, a := "a", r.Data.(string); e != a {
t.Errorf("expect %q data update got %q", e, a)
}
}
func TestMultipleHandlers(t *testing.T) {
@@ -41,7 +46,221 @@ func TestNamedHandlers(t *testing.T) {
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBack(func(r *request.Request) {})
assert.Equal(t, 4, l.Len())
if e, a := 4, l.Len(); e != a {
t.Errorf("expect %d list length, got %d", e, a)
}
l.Remove(named)
assert.Equal(t, 2, l.Len())
if e, a := 2, l.Len(); e != a {
t.Errorf("expect %d list length, got %d", e, a)
}
}
func TestSwapHandlers(t *testing.T) {
firstHandlerCalled := 0
swappedOutHandlerCalled := 0
swappedInHandlerCalled := 0
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
firstHandlerCalled++
}}
named2 := request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedOutHandlerCalled++
}}
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.PushBackNamed(named)
l.SwapNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedInHandlerCalled++
}})
l.Run(&request.Request{})
if e, a := 2, firstHandlerCalled; e != a {
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
}
if n := swappedOutHandlerCalled; n != 0 {
t.Errorf("expect swapped out handler to not be called, was called %d times", n)
}
if e, a := 1, swappedInHandlerCalled; e != a {
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
}
}
func TestSetBackNamed_Exists(t *testing.T) {
firstHandlerCalled := 0
swappedOutHandlerCalled := 0
swappedInHandlerCalled := 0
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
firstHandlerCalled++
}}
named2 := request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedOutHandlerCalled++
}}
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.SetBackNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedInHandlerCalled++
}})
l.Run(&request.Request{})
if e, a := 1, firstHandlerCalled; e != a {
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
}
if n := swappedOutHandlerCalled; n != 0 {
t.Errorf("expect swapped out handler to not be called, was called %d times", n)
}
if e, a := 1, swappedInHandlerCalled; e != a {
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
}
}
func TestSetBackNamed_NotExists(t *testing.T) {
firstHandlerCalled := 0
secondHandlerCalled := 0
swappedInHandlerCalled := 0
l := request.HandlerList{}
named := request.NamedHandler{Name: "Name", Fn: func(r *request.Request) {
firstHandlerCalled++
}}
named2 := request.NamedHandler{Name: "OtherName", Fn: func(r *request.Request) {
secondHandlerCalled++
}}
l.PushBackNamed(named)
l.PushBackNamed(named2)
l.SetBackNamed(request.NamedHandler{Name: "SwapOutName", Fn: func(r *request.Request) {
swappedInHandlerCalled++
}})
l.Run(&request.Request{})
if e, a := 1, firstHandlerCalled; e != a {
t.Errorf("expect first handler to be called %d, was called %d times", e, a)
}
if e, a := 1, secondHandlerCalled; e != a {
t.Errorf("expect second handler to be called %d, was called %d times", e, a)
}
if e, a := 1, swappedInHandlerCalled; e != a {
t.Errorf("expect swapped in handler to be called %d, was called %d times", e, a)
}
}
func TestLoggedHandlers(t *testing.T) {
expectedHandlers := []string{"name1", "name2"}
l := request.HandlerList{}
loggedHandlers := []string{}
l.AfterEachFn = request.HandlerListLogItem
cfg := aws.Config{Logger: aws.LoggerFunc(func(args ...interface{}) {
loggedHandlers = append(loggedHandlers, args[2].(string))
})}
named1 := request.NamedHandler{Name: "name1", Fn: func(r *request.Request) {}}
named2 := request.NamedHandler{Name: "name2", Fn: func(r *request.Request) {}}
l.PushBackNamed(named1)
l.PushBackNamed(named2)
l.Run(&request.Request{Config: cfg})
if !reflect.DeepEqual(expectedHandlers, loggedHandlers) {
t.Errorf("expect handlers executed %v to match logged handlers, %v",
expectedHandlers, loggedHandlers)
}
}
func TestStopHandlers(t *testing.T) {
l := request.HandlerList{}
stopAt := 1
l.AfterEachFn = func(item request.HandlerListRunItem) bool {
return item.Index != stopAt
}
called := 0
l.PushBackNamed(request.NamedHandler{Name: "name1", Fn: func(r *request.Request) {
called++
}})
l.PushBackNamed(request.NamedHandler{Name: "name2", Fn: func(r *request.Request) {
called++
}})
l.PushBackNamed(request.NamedHandler{Name: "name3", Fn: func(r *request.Request) {
t.Fatalf("third handler should not be called")
}})
l.Run(&request.Request{})
if e, a := 2, called; e != a {
t.Errorf("expect %d handlers called, got %d", e, a)
}
}
func BenchmarkNewRequest(b *testing.B) {
svc := s3.New(unit.Session)
for i := 0; i < b.N; i++ {
r, _ := svc.GetObjectRequest(nil)
if r == nil {
b.Fatal("r should not be nil")
}
}
}
func BenchmarkHandlersCopy(b *testing.B) {
handlers := request.Handlers{}
handlers.Validate.PushBack(func(r *request.Request) {})
handlers.Validate.PushBack(func(r *request.Request) {})
handlers.Build.PushBack(func(r *request.Request) {})
handlers.Build.PushBack(func(r *request.Request) {})
handlers.Send.PushBack(func(r *request.Request) {})
handlers.Send.PushBack(func(r *request.Request) {})
handlers.Unmarshal.PushBack(func(r *request.Request) {})
handlers.Unmarshal.PushBack(func(r *request.Request) {})
for i := 0; i < b.N; i++ {
h := handlers.Copy()
if e, a := handlers.Validate.Len(), h.Validate.Len(); e != a {
b.Fatalf("expected %d handlers got %d", e, a)
}
}
}
func BenchmarkHandlersPushBack(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
h.Validate.PushBack(func(r *request.Request) {})
}
}
func BenchmarkHandlersPushFront(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
}
}
func BenchmarkHandlersClear(b *testing.B) {
handlers := request.Handlers{}
for i := 0; i < b.N; i++ {
h := handlers.Copy()
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Validate.PushFront(func(r *request.Request) {})
h.Clear()
}
}

View File

@@ -0,0 +1,24 @@
package request
import (
"io"
"net/http"
"net/url"
)
func copyHTTPRequest(r *http.Request, body io.ReadCloser) *http.Request {
req := new(http.Request)
*req = *r
req.URL = &url.URL{}
*req.URL = *r.URL
req.Body = body
req.Header = http.Header{}
for k, v := range r.Header {
for _, vv := range v {
req.Header.Add(k, vv)
}
}
return req
}

View File

@@ -0,0 +1,34 @@
package request
import (
"bytes"
"io/ioutil"
"net/http"
"net/url"
"sync"
"testing"
)
func TestRequestCopyRace(t *testing.T) {
origReq := &http.Request{URL: &url.URL{}, Header: http.Header{}}
origReq.Header.Set("Header", "OrigValue")
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
req := copyHTTPRequest(origReq, ioutil.NopCloser(&bytes.Buffer{}))
req.Header.Set("Header", "Value")
go func() {
req2 := copyHTTPRequest(req, ioutil.NopCloser(&bytes.Buffer{}))
req2.Header.Add("Header", "Value2")
}()
_ = req.Header.Get("Header")
wg.Done()
}()
_ = origReq.Header.Get("Header")
}
origReq.Header.Get("Header")
wg.Wait()
}

View File

@@ -0,0 +1,37 @@
// +build go1.5
package request_test
import (
"errors"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/mock"
"github.com/stretchr/testify/assert"
)
func TestRequestCancelRetry(t *testing.T) {
c := make(chan struct{})
reqNum := 0
s := mock.NewMockClient(aws.NewConfig().WithMaxRetries(10))
s.Handlers.Validate.Clear()
s.Handlers.Unmarshal.Clear()
s.Handlers.UnmarshalMeta.Clear()
s.Handlers.UnmarshalError.Clear()
s.Handlers.Send.PushFront(func(r *request.Request) {
reqNum++
r.Error = errors.New("net/http: request canceled")
})
out := &testData{}
r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out)
r.HTTPRequest.Cancel = c
close(c)
err := r.Send()
assert.True(t, strings.Contains(err.Error(), "canceled"))
assert.Equal(t, 1, reqNum)
}

Some files were not shown because too many files have changed in this diff Show More