diff --git a/probe/awsecs/client.go b/probe/awsecs/client.go index cb1eeb9de..b9d3a10e3 100644 --- a/probe/awsecs/client.go +++ b/probe/awsecs/client.go @@ -12,8 +12,15 @@ import ( "github.com/aws/aws-sdk-go/service/ecs" ) -// a wrapper around an AWS client that makes all the needed calls and just exposes the final results -type ecsClient struct { +// A wrapper around an AWS client that makes all the needed calls and just exposes the final results. +// We create an interface so we can mock for testing +type ecsClient interface { + // Returns a ecsInfo struct containing data needed for a report. + getInfo([]string) ecsInfo +} + +// actual implementation +type ecsClientImpl struct { client *ecs.ECS cluster string taskCache map[string]ecsTask // Keys are task ARNs. @@ -72,7 +79,7 @@ func newClient(cluster string) (*ecsClient, error) { return nil, err } - return &ecsClient{ + return &ecsClientImpl{ client: ecs.New(sess, &aws.Config{Region: aws.String(region)}), cluster: cluster, taskCache: map[string]ecsTask{}, @@ -114,7 +121,7 @@ func newECSService(service *ecs.Service) ecsService { // Returns a channel from which service ARNs can be read. // Cannot fail as it will attempt to deliver partial results, though that may end up being no results. -func (c ecsClient) listServices() <-chan string { +func (c ecsClientImpl) listServices() <-chan string { log.Debugf("Listing ECS services") results := make(chan string) go func() { @@ -140,7 +147,7 @@ func (c ecsClient) listServices() <-chan string { // Returns (input, done) channels. Service ARNs given to input are batched and details are fetched, // with full ecsService objects being put into the cache. Closes done when finished. -func (c ecsClient) describeServices() (chan<- string, <-chan struct{}) { +func (c ecsClientImpl) describeServices() (chan<- string, <-chan struct{}) { input := make(chan string) done := make(chan struct{}) @@ -182,7 +189,7 @@ func (c ecsClient) describeServices() (chan<- string, <-chan struct{}) { return input, done } -func (c ecsClient) describeServicesBatch(arns []string) { +func (c ecsClientImpl) describeServicesBatch(arns []string) { arnPtrs := make([]*string, 0, len(arns)) for i := range arns { arnPtrs = append(arnPtrs, &arns[i]) @@ -209,7 +216,7 @@ func (c ecsClient) describeServicesBatch(arns []string) { } // get details on given tasks, updating cache with the results -func (c ecsClient) getTasks(taskARNs []string) { +func (c ecsClientImpl) getTasks(taskARNs []string) { log.Debugf("Describing %d ECS tasks", len(taskARNs)) taskPtrs := make([]*string, len(taskARNs)) @@ -238,7 +245,7 @@ func (c ecsClient) getTasks(taskARNs []string) { } // Evict entries from the caches which have not been used within the eviction interval. -func (c ecsClient) evictOldCacheItems() { +func (c ecsClientImpl) evictOldCacheItems() { const evictTime = time.Minute now := time.Now() @@ -264,7 +271,7 @@ func (c ecsClient) evictOldCacheItems() { // Try to match a list of task ARNs to service names using cached info. // Returns (task to service map, unmatched tasks). Ignores tasks whose startedby values // don't appear to point to a service. -func (c ecsClient) matchTasksServices(taskARNs []string) (map[string]string, []string) { +func (c ecsClientImpl) matchTasksServices(taskARNs []string) (map[string]string, []string) { const servicePrefix = "ecs-svc" deploymentMap := map[string]string{} @@ -298,7 +305,7 @@ func (c ecsClient) matchTasksServices(taskARNs []string) (map[string]string, []s return results, unmatched } -func (c ecsClient) ensureTasks(taskARNs []string) { +func (c ecsClientImpl) ensureTasks(taskARNs []string) { tasksToFetch := []string{} now := time.Now() for _, taskARN := range taskARNs { @@ -314,7 +321,7 @@ func (c ecsClient) ensureTasks(taskARNs []string) { } } -func (c ecsClient) refreshServices(taskServiceMap map[string]string) map[string]bool { +func (c ecsClientImpl) refreshServices(taskServiceMap map[string]string) map[string]bool { toDescribe, done := c.describeServices() servicesRefreshed := map[string]bool{} for _, serviceName := range taskServiceMap { @@ -329,7 +336,7 @@ func (c ecsClient) refreshServices(taskServiceMap map[string]string) map[string] return servicesRefreshed } -func (c ecsClient) describeAllServices(servicesRefreshed map[string]bool) { +func (c ecsClientImpl) describeAllServices(servicesRefreshed map[string]bool) { serviceNamesChan := c.listServices() toDescribe, done := c.describeServices() go func() { @@ -344,7 +351,7 @@ func (c ecsClient) describeAllServices(servicesRefreshed map[string]bool) { <-done } -func (c ecsClient) makeECSInfo(taskARNs []string, taskServiceMap map[string]string) ecsInfo { +func (c ecsClientImpl) makeECSInfo(taskARNs []string, taskServiceMap map[string]string) ecsInfo { // The maps to return are the referenced subsets of the full caches tasks := map[string]ecsTask{} for _, taskARN := range taskARNs { @@ -372,8 +379,8 @@ func (c ecsClient) makeECSInfo(taskARNs []string, taskServiceMap map[string]stri return ecsInfo{services: services, tasks: tasks, taskServiceMap: taskServiceMap} } -// Returns a ecsInfo struct containing data needed for a report. -func (c ecsClient) getInfo(taskARNs []string) ecsInfo { +// Implements ecsClient.getInfo +func (c ecsClientImpl) getInfo(taskARNs []string) ecsInfo { log.Debugf("Getting ECS info on %d tasks", len(taskARNs)) // We do a weird order of operations here to minimize unneeded cache refreshes. diff --git a/probe/awsecs/reporter_test.go b/probe/awsecs/reporter_test.go index 3a92b48ff..ddcc55e40 100644 --- a/probe/awsecs/reporter_test.go +++ b/probe/awsecs/reporter_test.go @@ -5,6 +5,28 @@ import ( "testing" ) +const ( + testCluster = "test-cluster" + testFamily = "test-family" + testTaskARN = "arn:aws:ecs:us-east-1:123456789012:task/12345678-9abc-def0-1234-56789abcdef0" + testContainer = "test-container" + testContainerData = map[string]string{ + docker.LabelPrefix + "com.amazonaws.ecs.task-arn": + testTaskARN, + docker.LabelPrefix + "com.amazonaws.ecs.cluster": + testCluster, + docker.LabelPrefix + "com.amazonaws.ecs.task-definition-family": + testFamily, + } +) + +func getTestContainerNode() report.Node { + return report.MakeNodeWith( + report.MakeContainerNodeID("test-container"), + testContainerData + ) +} + func TestGetLabelInfo(t *testing.T) { r := Make() rpt, err := r.Report() @@ -17,25 +39,13 @@ func TestGetLabelInfo(t *testing.T) { t.Error("Empty report did not produce empty label info: %v != %v", labelInfo, expected) } - rpt.Containers = rpt.Containers.AddNode( - report.MakeNodeWith( - report.MakeContainerNodeID("test-container"), - map[string]string{ - docker.LabelPrefix + "com.amazonaws.ecs.task-arn": - "arn:aws:ecs:us-east-1:123456789012:task/12345678-9abc-def0-1234-56789abcdef0", - docker.LabelPrefix + "com.amazonaws.ecs.cluster": - "test-cluster", - docker.LabelPrefix + "com.amazonaws.ecs.task-definition-family": - "test-family", - } - ) - ) + rpt.Containers = rpt.Containers.AddNode(getTestContainerNode()) labelInfo = r.getLabelInfo(rpt) expected = map[string]map[string]*taskLabelInfo{ - "test-cluster": map[string]*taskLabelInfo{ - "arn:aws:ecs:us-east-1:123456789012:task/12345678-9abc-def0-1234-56789abcdef0": &taskLabelInfo{ - containerIDs: []string{"test-container"}, - family: "test-family", + testCluster: map[string]*taskLabelInfo{ + testTaskARN: &taskLabelInfo{ + containerIDs: []string{testContainer}, + family: testFamily, } } } @@ -43,3 +53,47 @@ func TestGetLabelInfo(t *testing.T) { t.Error("Did not get expected label info: %v != %v", labelInfo, expected) } } + +// Implements ecsClient +type mockEcsClient { + t *testing.T + expectedARNs []string + info ecsInfo +} + +func newMockEcsClient(t *testing.T, expectedARNs []string, info ecsInfo) *ecsClient { + return &mockEcsClient{ + t, + expectedARNs, + info, + } +} + +func (c mockEcsClient) getInfo(taskARNs []string) ecsInfo { + if !reflect.DeepEqual(taskARNs, c.expectedARNs) { + c.t.Fatal("getInfo called with wrong ARNs: %v != %v", taskARNs, c.expectedARNs) + } + return c.info +} + +func TestTagReport(t *testing.T) { + r := Make() + + r.clientsByCluster[testCluster] = newMockEcsClient( + t, + []string{}, + ecsInfo{ + // TODO fill in values below + tasks: map[string]ecsTask{}, + services: map[string]ecsService{}, + taskServiceMap: map[string]string{}, + }, + ) + + rpt, err := r.Report() + if err != nil { + t.Fatal("Error making report") + } + rpt = r.Tag(rpt) + // TODO check it matches +}