diff --git a/probe/plugins/registry.go b/probe/plugins/registry.go index 11d9d1db1..9b46f88e8 100644 --- a/probe/plugins/registry.go +++ b/probe/plugins/registry.go @@ -231,32 +231,13 @@ func (r *Registry) updateAndRegisterControlsInReport(rpt *report.Report) { key := rpt.Plugins.Keys()[0] spec, _ := rpt.Plugins.Lookup(key) pluginID := spec.ID - topologies := topologyPointers(rpt) var newPluginControls []string - for _, topology := range topologies { + rpt.WalkTopologies(func(topology *report.Topology) { newPluginControls = append(newPluginControls, r.updateAndGetControlsInTopology(pluginID, topology)...) - } + }) r.updatePluginControls(pluginID, report.MakeStringSet(newPluginControls...)) } -func topologyPointers(rpt *report.Report) []*report.Topology { - // We cannot use rpt.Topologies(), because it makes a slice of - // topology copies and we need original locations to modify - // them. - return []*report.Topology{ - &rpt.Endpoint, - &rpt.Process, - &rpt.Container, - &rpt.ContainerImage, - &rpt.Pod, - &rpt.Service, - &rpt.Deployment, - &rpt.ReplicaSet, - &rpt.Host, - &rpt.Overlay, - } -} - func (r *Registry) updateAndGetControlsInTopology(pluginID string, topology *report.Topology) []string { var pluginControls []string newControls := report.Controls{} diff --git a/probe/topology_tagger.go b/probe/topology_tagger.go index ce37e19a3..49dfc3bb3 100644 --- a/probe/topology_tagger.go +++ b/probe/topology_tagger.go @@ -16,10 +16,10 @@ func (topologyTagger) Name() string { return "Topology" } // Tag implements Tagger func (topologyTagger) Tag(r report.Report) (report.Report, error) { - for name, t := range r.TopologyMap() { + r.WalkNamedTopologies(func(name string, t *report.Topology) { for _, node := range t.Nodes { t.AddNode(node.WithTopology(name)) } - } + }) return r, nil } diff --git a/report/report.go b/report/report.go index e25f35b69..19bf9e780 100644 --- a/report/report.go +++ b/report/report.go @@ -43,6 +43,26 @@ const ( ContainersKey = "containers" ) +// topologyNames are the names of all report topologies. +var topologyNames = []string{ + Endpoint, + Process, + Container, + ContainerImage, + Pod, + Service, + Deployment, + ReplicaSet, + DaemonSet, + StatefulSet, + CronJob, + Host, + Overlay, + ECSTask, + ECSService, + SwarmService, +} + // Report is the core data type. It's produced by probes, and consumed and // stored by apps. It's composed of multiple topologies, each representing // a different (related, but not equivalent) view of the network. @@ -226,28 +246,6 @@ func MakeReport() Report { } } -// TopologyMap gets a map from topology names to pointers to the respective topologies -func (r *Report) TopologyMap() map[string]*Topology { - return map[string]*Topology{ - Endpoint: &r.Endpoint, - Process: &r.Process, - Container: &r.Container, - ContainerImage: &r.ContainerImage, - Pod: &r.Pod, - Service: &r.Service, - Deployment: &r.Deployment, - ReplicaSet: &r.ReplicaSet, - DaemonSet: &r.DaemonSet, - StatefulSet: &r.StatefulSet, - CronJob: &r.CronJob, - Host: &r.Host, - Overlay: &r.Overlay, - ECSTask: &r.ECSTask, - ECSService: &r.ECSService, - SwarmService: &r.SwarmService, - } -} - // Copy returns a value copy of the report. func (r Report) Copy() Report { newReport := Report{ @@ -275,46 +273,73 @@ func (r Report) Merge(other Report) Report { return newReport } -// Topologies returns a slice of Topologies in this report -func (r Report) Topologies() []Topology { - result := []Topology{} - r.WalkTopologies(func(t *Topology) { - result = append(result, *t) - }) - return result -} - // WalkTopologies iterates through the Topologies of the report, // potentially modifying them func (r *Report) WalkTopologies(f func(*Topology)) { - var dummy Report - r.WalkPairedTopologies(&dummy, func(t, _ *Topology) { f(t) }) + for _, name := range topologyNames { + f(r.topology(name)) + } +} + +// WalkNamedTopologies iterates through the Topologies of the report, +// potentially modifying them. +func (r *Report) WalkNamedTopologies(f func(string, *Topology)) { + for _, name := range topologyNames { + f(name, r.topology(name)) + } } // WalkPairedTopologies iterates through the Topologies of this and another report, // potentially modifying one or both. func (r *Report) WalkPairedTopologies(o *Report, f func(*Topology, *Topology)) { - f(&r.Endpoint, &o.Endpoint) - f(&r.Process, &o.Process) - f(&r.Container, &o.Container) - f(&r.ContainerImage, &o.ContainerImage) - f(&r.Pod, &o.Pod) - f(&r.Service, &o.Service) - f(&r.Deployment, &o.Deployment) - f(&r.ReplicaSet, &o.ReplicaSet) - f(&r.DaemonSet, &o.DaemonSet) - f(&r.StatefulSet, &o.StatefulSet) - f(&r.CronJob, &o.CronJob) - f(&r.Host, &o.Host) - f(&r.Overlay, &o.Overlay) - f(&r.ECSTask, &o.ECSTask) - f(&r.ECSService, &o.ECSService) - f(&r.SwarmService, &o.SwarmService) + for _, name := range topologyNames { + f(r.topology(name), o.topology(name)) + } } -// Topology gets a topology by name +// topology returns a reference to one of the report's topologies, +// selected by name. +func (r *Report) topology(name string) *Topology { + switch name { + case Endpoint: + return &r.Endpoint + case Process: + return &r.Process + case Container: + return &r.Container + case ContainerImage: + return &r.ContainerImage + case Pod: + return &r.Pod + case Service: + return &r.Service + case Deployment: + return &r.Deployment + case ReplicaSet: + return &r.ReplicaSet + case DaemonSet: + return &r.DaemonSet + case StatefulSet: + return &r.StatefulSet + case CronJob: + return &r.CronJob + case Host: + return &r.Host + case Overlay: + return &r.Overlay + case ECSTask: + return &r.ECSTask + case ECSService: + return &r.ECSService + case SwarmService: + return &r.SwarmService + } + return nil +} + +// Topology returns one of the report's topologies, selected by name. func (r Report) Topology(name string) (Topology, bool) { - if t, ok := r.TopologyMap()[name]; ok { + if t := r.topology(name); t != nil { return *t, true } return Topology{}, false @@ -323,8 +348,8 @@ func (r Report) Topology(name string) (Topology, bool) { // Validate checks the report for various inconsistencies. func (r Report) Validate() error { var errs []string - for _, topology := range r.Topologies() { - if err := topology.Validate(); err != nil { + for _, name := range topologyNames { + if err := r.topology(name).Validate(); err != nil { errs = append(errs, err.Error()) } } diff --git a/report/report_test.go b/report/report_test.go index bf5dc2a36..f65576e79 100644 --- a/report/report_test.go +++ b/report/report_test.go @@ -20,14 +20,18 @@ func TestReportTopologies(t *testing.T) { topologyType = reflect.TypeOf(report.MakeTopology()) ) - var want int + var want, have int for i := 0; i < reportType.NumField(); i++ { if reportType.Field(i).Type == topologyType { want++ } } - if have := len(report.MakeReport().Topologies()); want != have { + r := report.MakeReport() + r.WalkTopologies(func(_ *report.Topology) { + have++ + }) + if want != have { t.Errorf("want %d, have %d", want, have) } }