diff --git a/probe/host/reporter.go b/probe/host/reporter.go index 3aff7a602..1f0ab5ce2 100644 --- a/probe/host/reporter.go +++ b/probe/host/reporter.go @@ -2,7 +2,6 @@ package host import ( "fmt" - "net" "runtime" "sync" "time" @@ -86,20 +85,7 @@ func NewReporter(hostID, hostName, probeID, version string, pipes controls.PipeC func (*Reporter) Name() string { return "Host" } // GetLocalNetworks is exported for mocking -var GetLocalNetworks = func() ([]*net.IPNet, error) { - addrs, err := net.InterfaceAddrs() - if err != nil { - return nil, err - } - localNets := report.Networks{} - for _, addr := range addrs { - // Not all addrs are IPNets. - if ipNet, ok := addr.(*net.IPNet); ok { - localNets = append(localNets, ipNet) - } - } - return localNets, nil -} +var GetLocalNetworks = report.GetLocalNetworks // Report implements Reporter. func (r *Reporter) Report() (report.Report, error) { diff --git a/report/networks.go b/report/networks.go index 42efd5e4b..c2797760d 100644 --- a/report/networks.go +++ b/report/networks.go @@ -8,17 +8,11 @@ import ( // Networks represent a set of subnets type Networks []*net.IPNet -// Interface is exported for testing. -type Interface interface { - Addrs() ([]net.Addr, error) -} - -// Variables exposed for testing. +// LocalNetworks helps in determining which addresses a probe reports +// as being host-scoped. +// // TODO this design is broken, make it consistent with probe networks. -var ( - LocalNetworks = Networks{} - InterfaceByNameStub = func(name string) (Interface, error) { return net.InterfaceByName(name) } -) +var LocalNetworks = Networks{} // Contains returns true if IP is in Networks. func (n Networks) Contains(ip net.IP) bool { @@ -51,12 +45,7 @@ func LocalAddresses() ([]net.IP, error) { return []net.IP{}, err } - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok { - continue - } - + for _, ipnet := range ipv4Nets(addrs) { result = append(result, ipnet.IP) } } @@ -68,7 +57,7 @@ func LocalAddresses() ([]net.IP, error) { // supplied, such that MakeAddressNodeID will scope addresses in this subnet // as local. func AddLocalBridge(name string) error { - inf, err := InterfaceByNameStub(name) + inf, err := net.InterfaceByName(name) if err != nil { return err } @@ -77,18 +66,27 @@ func AddLocalBridge(name string) error { if err != nil { return err } - for _, addr := range addrs { - _, network, err := net.ParseCIDR(addr.String()) - if err != nil { - return err - } - if network == nil { - continue - } - - LocalNetworks = append(LocalNetworks, network) - } + LocalNetworks = ipv4Nets(addrs) return nil } + +// GetLocalNetworks returns all the local networks. +func GetLocalNetworks() ([]*net.IPNet, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + return ipv4Nets(addrs), nil +} + +func ipv4Nets(addrs []net.Addr) []*net.IPNet { + nets := Networks{} + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil { + nets = append(nets, ipnet) + } + } + return nets +} diff --git a/report/networks_test.go b/report/networks_test.go index a199fcf82..47ca616ac 100644 --- a/report/networks_test.go +++ b/report/networks_test.go @@ -4,9 +4,7 @@ import ( "net" "testing" - "github.com/weaveworks/common/test" "github.com/weaveworks/scope/report" - "github.com/weaveworks/scope/test/reflect" ) func TestContains(t *testing.T) { @@ -31,42 +29,3 @@ func mustParseCIDR(s string) *net.IPNet { } return ipNet } - -type mockInterface struct { - addrs []net.Addr -} - -type mockAddr string - -func (m mockInterface) Addrs() ([]net.Addr, error) { - return m.addrs, nil -} - -func (m mockAddr) Network() string { - return "ip+net" -} - -func (m mockAddr) String() string { - return string(m) -} - -func TestAddLocal(t *testing.T) { - oldInterfaceByNameStub := report.InterfaceByNameStub - defer func() { report.InterfaceByNameStub = oldInterfaceByNameStub }() - - report.InterfaceByNameStub = func(name string) (report.Interface, error) { - return mockInterface{[]net.Addr{mockAddr("52.53.54.55/16")}}, nil - } - - err := report.AddLocalBridge("foo") - if err != nil { - t.Errorf("%v", err) - } - - want := report.Networks([]*net.IPNet{mustParseCIDR("52.53.54.55/16")}) - have := report.LocalNetworks - - if !reflect.DeepEqual(want, have) { - t.Errorf("%s", test.Diff(want, have)) - } -}