diff --git a/render/detailed/connections.go b/render/detailed/connections.go index d4ce34f4a..0cd8e742e 100644 --- a/render/detailed/connections.go +++ b/render/detailed/connections.go @@ -85,112 +85,33 @@ func (row connection) ID() string { return fmt.Sprintf("%s:%s-%s:%s-%s", row.remoteNodeID, row.remoteAddr, row.localNodeID, row.localAddr, row.port) } -func incomingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { - localEndpointIDs := endpointChildIDsOf(n) - - // For each node which has an edge TO me - counts := map[connection]int{} - for _, node := range ns { - if !node.Adjacency.Contains(n.ID) { - continue - } - // Work out what port they are talking to, and count the number of - // connections to that port. - for _, child := range endpointChildrenOf(node) { - for _, localEndpointID := range child.Adjacency.Intersection(localEndpointIDs) { - _, localAddr, port, ok := report.ParseEndpointNodeID(localEndpointID) - if !ok { - continue - } - key := newConnection(n, node, port, localEndpointID, localAddr) - counts[key] = counts[key] + 1 - } - } - } - - columnHeaders := NormalColumns - if isInternetNode(n) { - columnHeaders = InternetColumns - } - return ConnectionsSummary{ - ID: "incoming-connections", - TopologyID: topologyID, - Label: "Inbound", - Columns: columnHeaders, - Connections: connectionRows(r, ns, counts, isInternetNode(n)), - } +type connectionCounters struct { + counted map[string]struct{} + counts map[connection]int } -func outgoingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { - localEndpoints := endpointChildrenOf(n) +func newConnectionCounters() *connectionCounters { + return &connectionCounters{counted: map[string]struct{}{}, counts: map[connection]int{}} +} - // For each node which has an edge FROM me - counts := map[connection]int{} - for _, id := range n.Adjacency { - node, ok := ns[id] - if !ok { - continue - } - - remoteEndpointIDs := endpointChildIDsOf(node) - - for _, localEndpoint := range localEndpoints { - _, localAddr, _, ok := report.ParseEndpointNodeID(localEndpoint.ID) - if !ok { - continue - } - - for _, remoteEndpointID := range localEndpoint.Adjacency.Intersection(remoteEndpointIDs) { - _, _, port, ok := report.ParseEndpointNodeID(remoteEndpointID) - if !ok { - continue - } - key := newConnection(n, node, port, localEndpoint.ID, localAddr) - counts[key] = counts[key] + 1 - } - } +func (c *connectionCounters) add(sourceEndpoint report.Node, n report.Node, node report.Node, port string, endpointID string, localAddr string) { + // We identify connections by their source endpoint, pre-NAT, to + // ensure we only count them once. + connectionID := sourceEndpoint.ID + if copySourceEndpointID, _, ok := sourceEndpoint.Latest.LookupEntry("copy_of"); ok { + connectionID = copySourceEndpointID } - - columnHeaders := NormalColumns - if isInternetNode(n) { - columnHeaders = InternetColumns - } - return ConnectionsSummary{ - ID: "outgoing-connections", - TopologyID: topologyID, - Label: "Outbound", - Columns: columnHeaders, - Connections: connectionRows(r, ns, counts, isInternetNode(n)), + if _, ok := c.counted[connectionID]; ok { + return } + c.counted[connectionID] = struct{}{} + key := newConnection(n, node, port, endpointID, localAddr) + c.counts[key] = c.counts[key] + 1 } -func endpointChildrenOf(n report.Node) []report.Node { - result := []report.Node{} - n.Children.ForEach(func(child report.Node) { - if child.Topology == report.Endpoint { - result = append(result, child) - } - }) - return result -} - -func endpointChildIDsOf(n report.Node) report.IDList { - result := report.MakeIDList() - n.Children.ForEach(func(child report.Node) { - if child.Topology == report.Endpoint { - result = result.Add(child.ID) - } - }) - return result -} - -func isInternetNode(n report.Node) bool { - return n.ID == render.IncomingInternetID || n.ID == render.OutgoingInternetID -} - -func connectionRows(r report.Report, ns report.Nodes, in map[connection]int, includeLocal bool) []Connection { +func (c *connectionCounters) rows(r report.Report, ns report.Nodes, includeLocal bool) []Connection { output := []Connection{} - for row, count := range in { + for row, count := range c.counts { // Use MakeNodeSummary to render the id and label of this node // TODO(paulbellamy): Would be cleaner if we hade just a // MakeNodeID(ns[row.remoteNodeID]). As we don't need the whole summary. @@ -236,3 +157,125 @@ func connectionRows(r report.Report, ns report.Nodes, in map[connection]int, inc sort.Sort(connectionsByID(output)) return output } + +func incomingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { + localEndpointIDs, localEndpointIDCopies := endpointChildIDsAndCopyMapOf(n) + counts := newConnectionCounters() + + // For each node which has an edge TO me + for _, node := range ns { + if !node.Adjacency.Contains(n.ID) { + continue + } + // Work out what port they are talking to, and count the number of + // connections to that port. + for _, remoteEndpoint := range endpointChildrenOf(node) { + for _, localEndpointID := range remoteEndpoint.Adjacency.Intersection(localEndpointIDs) { + localEndpointID = canonicalEndpointID(localEndpointIDCopies, localEndpointID) + _, localAddr, port, ok := report.ParseEndpointNodeID(localEndpointID) + if !ok { + continue + } + counts.add(remoteEndpoint, n, node, port, localEndpointID, localAddr) + } + } + } + + columnHeaders := NormalColumns + if isInternetNode(n) { + columnHeaders = InternetColumns + } + return ConnectionsSummary{ + ID: "incoming-connections", + TopologyID: topologyID, + Label: "Inbound", + Columns: columnHeaders, + Connections: counts.rows(r, ns, isInternetNode(n)), + } +} + +func outgoingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { + localEndpoints := endpointChildrenOf(n) + counts := newConnectionCounters() + + // For each node which has an edge FROM me + for _, id := range n.Adjacency { + node, ok := ns[id] + if !ok { + continue + } + + remoteEndpointIDs, remoteEndpointIDCopies := endpointChildIDsAndCopyMapOf(node) + + for _, localEndpoint := range localEndpoints { + _, localAddr, _, ok := report.ParseEndpointNodeID(localEndpoint.ID) + if !ok { + continue + } + for _, remoteEndpointID := range localEndpoint.Adjacency.Intersection(remoteEndpointIDs) { + remoteEndpointID = canonicalEndpointID(remoteEndpointIDCopies, remoteEndpointID) + _, _, port, ok := report.ParseEndpointNodeID(remoteEndpointID) + if !ok { + continue + } + counts.add(localEndpoint, n, node, port, localEndpoint.ID, localAddr) + } + } + } + + columnHeaders := NormalColumns + if isInternetNode(n) { + columnHeaders = InternetColumns + } + return ConnectionsSummary{ + ID: "outgoing-connections", + TopologyID: topologyID, + Label: "Outbound", + Columns: columnHeaders, + Connections: counts.rows(r, ns, isInternetNode(n)), + } +} + +func endpointChildrenOf(n report.Node) []report.Node { + result := []report.Node{} + n.Children.ForEach(func(child report.Node) { + if child.Topology == report.Endpoint { + result = append(result, child) + } + }) + return result +} + +func endpointChildIDsAndCopyMapOf(n report.Node) (report.IDList, map[string]string) { + ids := report.MakeIDList() + copies := map[string]string{} + n.Children.ForEach(func(child report.Node) { + if child.Topology == report.Endpoint { + ids = ids.Add(child.ID) + if copyID, _, ok := child.Latest.LookupEntry("copy_of"); ok { + copies[child.ID] = copyID + } + } + }) + return ids, copies +} + +// canonicalEndpointID returns the original endpoint ID of which id is +// a "copy_of" (due to NATing), or, if the id is not a copy, the id +// itself. +// +// This is used for determining a unique destination endpoint ID for a +// connection, removing any arbitrariness in the destination port we +// are associating with the connection when it is encountered multiple +// times in the topology (with different destination endpoints, due to +// DNATing). +func canonicalEndpointID(copies map[string]string, id string) string { + if original, ok := copies[id]; ok { + return original + } + return id +} + +func isInternetNode(n report.Node) bool { + return n.ID == render.IncomingInternetID || n.ID == render.OutgoingInternetID +}