diff --git a/render/detailed/connections.go b/render/detailed/connections.go index d4ce34f4a..7b21703c4 100644 --- a/render/detailed/connections.go +++ b/render/detailed/connections.go @@ -85,112 +85,39 @@ func (row connection) ID() string { return fmt.Sprintf("%s:%s-%s:%s-%s", row.remoteNodeID, row.remoteAddr, row.localNodeID, row.localAddr, row.port) } -func incomingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { - localEndpointIDs := endpointChildIDsOf(n) - - // For each node which has an edge TO me - counts := map[connection]int{} - for _, node := range ns { - if !node.Adjacency.Contains(n.ID) { - continue - } - // Work out what port they are talking to, and count the number of - // connections to that port. - for _, child := range endpointChildrenOf(node) { - for _, localEndpointID := range child.Adjacency.Intersection(localEndpointIDs) { - _, localAddr, port, ok := report.ParseEndpointNodeID(localEndpointID) - if !ok { - continue - } - key := newConnection(n, node, port, localEndpointID, localAddr) - counts[key] = counts[key] + 1 - } - } - } - - columnHeaders := NormalColumns - if isInternetNode(n) { - columnHeaders = InternetColumns - } - return ConnectionsSummary{ - ID: "incoming-connections", - TopologyID: topologyID, - Label: "Inbound", - Columns: columnHeaders, - Connections: connectionRows(r, ns, counts, isInternetNode(n)), - } +type connectionCounters struct { + counted map[string]struct{} + counts map[connection]int } -func outgoingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { - localEndpoints := endpointChildrenOf(n) +func newConnectionCounters() *connectionCounters { + return &connectionCounters{counted: map[string]struct{}{}, counts: map[connection]int{}} +} - // For each node which has an edge FROM me - counts := map[connection]int{} - for _, id := range n.Adjacency { - node, ok := ns[id] - if !ok { - continue - } - - remoteEndpointIDs := endpointChildIDsOf(node) - - for _, localEndpoint := range localEndpoints { - _, localAddr, _, ok := report.ParseEndpointNodeID(localEndpoint.ID) - if !ok { - continue - } - - for _, remoteEndpointID := range localEndpoint.Adjacency.Intersection(remoteEndpointIDs) { - _, _, port, ok := report.ParseEndpointNodeID(remoteEndpointID) - if !ok { - continue - } - key := newConnection(n, node, port, localEndpoint.ID, localAddr) - counts[key] = counts[key] + 1 - } - } +func (c *connectionCounters) add(sourceEndpoint report.Node, n report.Node, node report.Node, port string, endpointID string, localAddr string) { + // We identify connections by their source endpoint, pre-NAT, to + // ensure we only count them once. + // + // There is some arbitrariness here: We may see the same + // connection multiple times, with different destination + // endpoints, due to DNATing; the (destination) port under which + // we track that connection is determined by the first destination + // endpoint we encounter. + connectionID := sourceEndpoint.ID + if copySourceEndpointID, _, ok := sourceEndpoint.Latest.LookupEntry("copy_of"); ok { + connectionID = copySourceEndpointID } - - columnHeaders := NormalColumns - if isInternetNode(n) { - columnHeaders = InternetColumns - } - return ConnectionsSummary{ - ID: "outgoing-connections", - TopologyID: topologyID, - Label: "Outbound", - Columns: columnHeaders, - Connections: connectionRows(r, ns, counts, isInternetNode(n)), + if _, ok := c.counted[connectionID]; ok { + return } + c.counted[connectionID] = struct{}{} + key := newConnection(n, node, port, endpointID, localAddr) + c.counts[key] = c.counts[key] + 1 } -func endpointChildrenOf(n report.Node) []report.Node { - result := []report.Node{} - n.Children.ForEach(func(child report.Node) { - if child.Topology == report.Endpoint { - result = append(result, child) - } - }) - return result -} - -func endpointChildIDsOf(n report.Node) report.IDList { - result := report.MakeIDList() - n.Children.ForEach(func(child report.Node) { - if child.Topology == report.Endpoint { - result = result.Add(child.ID) - } - }) - return result -} - -func isInternetNode(n report.Node) bool { - return n.ID == render.IncomingInternetID || n.ID == render.OutgoingInternetID -} - -func connectionRows(r report.Report, ns report.Nodes, in map[connection]int, includeLocal bool) []Connection { +func (c *connectionCounters) rows(r report.Report, ns report.Nodes, includeLocal bool) []Connection { output := []Connection{} - for row, count := range in { + for row, count := range c.counts { // Use MakeNodeSummary to render the id and label of this node // TODO(paulbellamy): Would be cleaner if we hade just a // MakeNodeID(ns[row.remoteNodeID]). As we don't need the whole summary. @@ -236,3 +163,103 @@ func connectionRows(r report.Report, ns report.Nodes, in map[connection]int, inc sort.Sort(connectionsByID(output)) return output } + +func incomingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { + localEndpointIDs := endpointChildIDsOf(n) + counts := newConnectionCounters() + + // For each node which has an edge TO me + for _, node := range ns { + if !node.Adjacency.Contains(n.ID) { + continue + } + // Work out what port they are talking to, and count the number of + // connections to that port. + for _, remoteEndpoint := range endpointChildrenOf(node) { + for _, localEndpointID := range remoteEndpoint.Adjacency.Intersection(localEndpointIDs) { + _, localAddr, port, ok := report.ParseEndpointNodeID(localEndpointID) + if !ok { + continue + } + counts.add(remoteEndpoint, n, node, port, localEndpointID, localAddr) + } + } + } + + columnHeaders := NormalColumns + if isInternetNode(n) { + columnHeaders = InternetColumns + } + return ConnectionsSummary{ + ID: "incoming-connections", + TopologyID: topologyID, + Label: "Inbound", + Columns: columnHeaders, + Connections: counts.rows(r, ns, isInternetNode(n)), + } +} + +func outgoingConnectionsSummary(topologyID string, r report.Report, n report.Node, ns report.Nodes) ConnectionsSummary { + localEndpoints := endpointChildrenOf(n) + counts := newConnectionCounters() + + // For each node which has an edge FROM me + for _, id := range n.Adjacency { + node, ok := ns[id] + if !ok { + continue + } + + remoteEndpointIDs := endpointChildIDsOf(node) + + for _, localEndpoint := range localEndpoints { + _, localAddr, _, ok := report.ParseEndpointNodeID(localEndpoint.ID) + if !ok { + continue + } + for _, remoteEndpointID := range localEndpoint.Adjacency.Intersection(remoteEndpointIDs) { + _, _, port, ok := report.ParseEndpointNodeID(remoteEndpointID) + if !ok { + continue + } + counts.add(localEndpoint, n, node, port, localEndpoint.ID, localAddr) + } + } + } + + columnHeaders := NormalColumns + if isInternetNode(n) { + columnHeaders = InternetColumns + } + return ConnectionsSummary{ + ID: "outgoing-connections", + TopologyID: topologyID, + Label: "Outbound", + Columns: columnHeaders, + Connections: counts.rows(r, ns, isInternetNode(n)), + } +} + +func endpointChildrenOf(n report.Node) []report.Node { + result := []report.Node{} + n.Children.ForEach(func(child report.Node) { + if child.Topology == report.Endpoint { + result = append(result, child) + } + }) + return result +} + +func endpointChildIDsOf(n report.Node) report.IDList { + result := report.MakeIDList() + n.Children.ForEach(func(child report.Node) { + if child.Topology == report.Endpoint { + result = result.Add(child.ID) + } + }) + return result +} + +func isInternetNode(n report.Node) bool { + return n.ID == render.IncomingInternetID || n.ID == render.OutgoingInternetID +}