diff --git a/probe/endpoint/conntrack.go b/probe/endpoint/conntrack.go index 149f6896e..d3bc9e679 100644 --- a/probe/endpoint/conntrack.go +++ b/probe/endpoint/conntrack.go @@ -65,7 +65,7 @@ type conntrack struct { Flows []Flow `xml:"flow"` } -// Conntracker is somethin that tracks connections. +// Conntracker is something that tracks connections. type Conntracker interface { WalkFlows(f func(Flow)) Stop() @@ -81,7 +81,7 @@ type conntracker struct { } // NewConntracker creates and starts a new Conntracter -var NewConntracker = func(existingConns bool, args ...string) (Conntracker, error) { +func NewConntracker(existingConns bool, args ...string) (Conntracker, error) { if !ConntrackModulePresent() { return nil, fmt.Errorf("No conntrack module") } diff --git a/probe/endpoint/nat.go b/probe/endpoint/nat.go index 023f0f4c1..fe9927155 100644 --- a/probe/endpoint/nat.go +++ b/probe/endpoint/nat.go @@ -22,12 +22,8 @@ type NATMapper struct { } // NewNATMapper is exposed for testing -func NewNATMapper() (*NATMapper, error) { - ct, err := NewConntracker(true, "--any-nat") - if err != nil { - return nil, err - } - return &NATMapper{ct}, nil +func NewNATMapper(ct Conntracker) NATMapper { + return NATMapper{ct} } func toMapping(f Flow) *endpointMapping { @@ -53,7 +49,7 @@ func toMapping(f Flow) *endpointMapping { // ApplyNAT duplicates Nodes in the endpoint topology of a // report, based on the NAT table as returns by natTable. -func (n *NATMapper) ApplyNAT(rpt report.Report, scope string) { +func (n NATMapper) ApplyNAT(rpt report.Report, scope string) { n.WalkFlows(func(f Flow) { var ( mapping = toMapping(f) diff --git a/probe/endpoint/nat_test.go b/probe/endpoint/nat_test.go index 8858b07c3..65b2c16ea 100644 --- a/probe/endpoint/nat_test.go +++ b/probe/endpoint/nat_test.go @@ -22,43 +22,76 @@ func (m *mockConntracker) WalkFlows(f func(endpoint.Flow)) { func (m *mockConntracker) Stop() {} func TestNat(t *testing.T) { - oldNewConntracker := endpoint.NewConntracker - defer func() { endpoint.NewConntracker = oldNewConntracker }() + // test that two containers, on the docker network, get their connections mapped + // correctly. + // the setup is this: + // + // container2 (10.0.47.2:222222), host2 (2.3.4.5:22223) -> + // host1 (1.2.3.4:80), container1 (10.0.47.2:80) - endpoint.NewConntracker = func(existingConns bool, args ...string) (endpoint.Conntracker, error) { + // from the PoV of host1 + { flow := makeFlow("") addIndependant(&flow, 1, "") - flow.Original = addMeta(&flow, "original", "10.0.47.1", "2.3.4.5", 80, 22222) - flow.Reply = addMeta(&flow, "reply", "2.3.4.5", "1.2.3.4", 22222, 80) - - return &mockConntracker{ + flow.Original = addMeta(&flow, "original", "2.3.4.5", "1.2.3.4", 222222, 80) + flow.Reply = addMeta(&flow, "reply", "10.0.47.1", "2.3.4.5", 80, 222222) + ct := &mockConntracker{ flows: []endpoint.Flow{flow}, - }, nil + } + + have := report.MakeReport() + originalID := report.MakeEndpointNodeID("host1", "10.0.47.1", "80") + have.Endpoint.AddNode(originalID, report.MakeNodeWith(report.Metadata{ + endpoint.Addr: "10.0.47.1", + endpoint.Port: "80", + "foo": "bar", + })) + + want := have.Copy() + want.Endpoint.AddNode(report.MakeEndpointNodeID("host1", "1.2.3.4", "80"), report.MakeNodeWith(report.Metadata{ + endpoint.Addr: "1.2.3.4", + endpoint.Port: "80", + "copy_of": originalID, + "foo": "bar", + })) + + natmapper := endpoint.NewNATMapper(ct) + natmapper.ApplyNAT(have, "host1") + if !reflect.DeepEqual(want, have) { + t.Fatal(test.Diff(want, have)) + } } - have := report.MakeReport() - originalID := report.MakeEndpointNodeID("host1", "10.0.47.1", "80") - have.Endpoint.AddNode(originalID, report.MakeNodeWith(report.Metadata{ - endpoint.Addr: "10.0.47.1", - endpoint.Port: "80", - "foo": "bar", - })) + // form the PoV of host2 + { + flow := makeFlow("") + addIndependant(&flow, 2, "") + flow.Original = addMeta(&flow, "original", "10.0.47.2", "1.2.3.4", 22222, 80) + flow.Reply = addMeta(&flow, "reply", "1.2.3.4", "2.3.4.5", 80, 22223) + ct := &mockConntracker{ + flows: []endpoint.Flow{flow}, + } - want := have.Copy() - want.Endpoint.AddNode(report.MakeEndpointNodeID("host1", "1.2.3.4", "80"), report.MakeNodeWith(report.Metadata{ - endpoint.Addr: "1.2.3.4", - endpoint.Port: "80", - "copy_of": originalID, - "foo": "bar", - })) + have := report.MakeReport() + originalID := report.MakeEndpointNodeID("host2", "10.0.47.2", "22222") + have.Endpoint.AddNode(originalID, report.MakeNodeWith(report.Metadata{ + endpoint.Addr: "10.0.47.2", + endpoint.Port: "22222", + "foo": "baz", + })) - natmapper, err := endpoint.NewNATMapper() - if err != nil { - t.Fatal(err) - } + want := have.Copy() + want.Endpoint.AddNode(report.MakeEndpointNodeID("host2", "2.3.4.5", "22223"), report.MakeNodeWith(report.Metadata{ + endpoint.Addr: "2.3.4.5", + endpoint.Port: "22223", + "copy_of": originalID, + "foo": "baz", + })) - natmapper.ApplyNAT(have, "host1") - if !reflect.DeepEqual(want, have) { - t.Fatal(test.Diff(want, have)) + natmapper := endpoint.NewNATMapper(ct) + natmapper.ApplyNAT(have, "host1") + if !reflect.DeepEqual(want, have) { + t.Fatal(test.Diff(want, have)) + } } } diff --git a/probe/endpoint/reporter.go b/probe/endpoint/reporter.go index 9cd45725d..b5b67d7fb 100644 --- a/probe/endpoint/reporter.go +++ b/probe/endpoint/reporter.go @@ -52,7 +52,7 @@ func NewReporter(hostID, hostName string, includeProcesses bool, useConntrack bo var ( conntrackModulePresent = ConntrackModulePresent() conntracker Conntracker - natmapper *NATMapper + natmapper NATMapper err error ) if conntrackModulePresent && useConntrack { @@ -62,17 +62,18 @@ func NewReporter(hostID, hostName string, includeProcesses bool, useConntrack bo } } if conntrackModulePresent { - natmapper, err = NewNATMapper() + ct, err := NewConntracker(true, "--any-nat") if err != nil { - log.Printf("Failed to start natMapper: %v", err) + log.Printf("Failed to start conntracker for natmapper: %v", err) } + natmapper = NewNATMapper(ct) } return &Reporter{ hostID: hostID, hostName: hostName, includeProcesses: includeProcesses, conntracker: conntracker, - natmapper: natmapper, + natmapper: &natmapper, revResolver: NewReverseResolver(), } }