From ad45ae2c964afa064056d53731e1079b0a3220f2 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Tue, 19 May 2015 12:21:11 +0200 Subject: [PATCH 1/4] Publisher refuses connections from the same host --- xfer/publisher.go | 27 +++++++++++++++-- xfer/publisher_test.go | 69 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/xfer/publisher.go b/xfer/publisher.go index 6a89298de..af24007ea 100644 --- a/xfer/publisher.go +++ b/xfer/publisher.go @@ -53,7 +53,7 @@ func (p *TCPPublisher) Publish(msg report.Report) { } func (p *TCPPublisher) loop(incoming <-chan net.Conn) { - var activeConns = make(map[net.Conn]*gob.Encoder) + activeConns := map[net.Conn]*gob.Encoder{} for { select { @@ -62,6 +62,27 @@ func (p *TCPPublisher) loop(incoming <-chan net.Conn) { return // someone closed our connection chan -- weird? } + // Don't allow multiple connections from the same remote host. + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + log.Printf("incoming connection: %s: %v (dropped)", conn.RemoteAddr(), err) + conn.Close() + continue + } + outer: + for activeConn := range activeConns { + activeHost, _, err := net.SplitHostPort(activeConn.RemoteAddr().String()) + if err != nil { + log.Printf("active connection: %s: %v (strange)", activeConn.RemoteAddr(), err) + continue + } + if host == activeHost { + log.Printf("duplicate connection: %s (dropped)", conn.RemoteAddr()) + conn.Close() + continue outer + } + } + log.Printf("connection initiated: %s", conn.RemoteAddr()) activeConns[conn] = gob.NewEncoder(conn) @@ -70,10 +91,10 @@ func (p *TCPPublisher) loop(incoming <-chan net.Conn) { return // someone closed our msg chan, so we're done } - var teminatedConns []net.Conn + teminatedConns := []net.Conn{} for conn, encoder := range activeConns { if err := encoder.Encode(msg); err != nil { - log.Printf("connection terminated: %v", err) + log.Printf("connection terminated: %s: %v", conn.RemoteAddr(), err) teminatedConns = append(teminatedConns, conn) conn.Close() } diff --git a/xfer/publisher_test.go b/xfer/publisher_test.go index 5b25130af..3e770dfe1 100644 --- a/xfer/publisher_test.go +++ b/xfer/publisher_test.go @@ -2,6 +2,7 @@ package xfer_test import ( "encoding/gob" + "fmt" "io/ioutil" "log" "net" @@ -15,9 +16,8 @@ import ( func TestTCPPublisher(t *testing.T) { log.SetOutput(ioutil.Discard) - // Build the address - port := ":12345" - addr, err := net.ResolveTCPAddr("tcp4", "127.0.0.1"+port) + // Choose a port + port, err := getFreePort() if err != nil { t.Fatal(err) } @@ -30,7 +30,7 @@ func TestTCPPublisher(t *testing.T) { defer p.Close() // Start a raw listener - conn, err := net.DialTCP("tcp4", nil, addr) + conn, err := net.Dial("tcp4", "127.0.0.1"+port) if err != nil { t.Fatal(err) } @@ -46,3 +46,64 @@ func TestTCPPublisher(t *testing.T) { t.Fatal(err) } } + +func TestPublisherClosesDuplicateConnections(t *testing.T) { + log.SetOutput(ioutil.Discard) + + // Choose a port + port, err := getFreePort() + if err != nil { + t.Fatal(err) + } + + // Start a publisher + p, err := xfer.NewTCPPublisher(port) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + // Connect a listener + conn, err := net.Dial("tcp4", "127.0.0.1"+port) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + time.Sleep(time.Millisecond) + + // Try to connect the same listener + dupconn, err := net.Dial("tcp4", "127.0.0.1"+port) + if err != nil { + t.Fatal(err) + } + defer dupconn.Close() + + // Publish a message + p.Publish(report.Report{}) + + // The first listener should receive it + var r report.Report + if err := gob.NewDecoder(conn).Decode(&r); err != nil { + t.Fatal(err) + } + + // The duplicate listener should have an error + if err := gob.NewDecoder(dupconn).Decode(&r); err == nil { + t.Errorf("expected error, got none") + } else { + t.Logf("dupconn got expected error: %v", err) + } +} + +func getFreePort() (string, error) { + ln, err := net.Listen("tcp4", ":0") + if err != nil { + return "", fmt.Errorf("Listen: %v", err) + } + defer ln.Close() + _, port, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + return "", fmt.Errorf("SplitHostPort(%s): %v", ln.Addr().String(), err) + } + return ":" + port, nil +} From 11f85cda1bff1cb60c52ac6e5163f6a08c63ad6c Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Tue, 19 May 2015 12:22:02 +0200 Subject: [PATCH 2/4] Remove needless probe -version --- probe/main.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/probe/main.go b/probe/main.go index ba4fe4d98..ba45114d3 100644 --- a/probe/main.go +++ b/probe/main.go @@ -2,7 +2,6 @@ package main import ( "flag" - "fmt" "log" "net" "net/http" @@ -22,7 +21,6 @@ import ( func main() { var ( httpListen = flag.String("http.listen", "", "listen address for HTTP profiling and instrumentation server") - version = flag.Bool("version", false, "print version number and exit") publishInterval = flag.Duration("publish.interval", 1*time.Second, "publish (output) interval") spyInterval = flag.Duration("spy.interval", 100*time.Millisecond, "spy (scan) interval") listen = flag.String("listen", ":"+strconv.Itoa(xfer.ProbePort), "listen address") @@ -39,12 +37,6 @@ func main() { os.Exit(1) } - // -version flag: - if *version { - fmt.Printf("unstable\n") - return - } - procspy.SetProcRoot(*procRoot) if *httpListen != "" { From 08f59057427651dee4d9e98196f112b2a794510b Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Tue, 19 May 2015 12:45:48 +0200 Subject: [PATCH 3/4] Use host -> (Conn, Encoder) mapping for active conns --- xfer/publisher.go | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/xfer/publisher.go b/xfer/publisher.go index af24007ea..7dc8758f6 100644 --- a/xfer/publisher.go +++ b/xfer/publisher.go @@ -53,7 +53,7 @@ func (p *TCPPublisher) Publish(msg report.Report) { } func (p *TCPPublisher) loop(incoming <-chan net.Conn) { - activeConns := map[net.Conn]*gob.Encoder{} + activeConns := map[string]connEncoder{} // host: connEncoder for { select { @@ -69,44 +69,36 @@ func (p *TCPPublisher) loop(incoming <-chan net.Conn) { conn.Close() continue } - outer: - for activeConn := range activeConns { - activeHost, _, err := net.SplitHostPort(activeConn.RemoteAddr().String()) - if err != nil { - log.Printf("active connection: %s: %v (strange)", activeConn.RemoteAddr(), err) - continue - } - if host == activeHost { - log.Printf("duplicate connection: %s (dropped)", conn.RemoteAddr()) - conn.Close() - continue outer - } + if _, ok := activeConns[host]; ok { + log.Printf("duplicate connection: %s (dropped)", conn.RemoteAddr()) + conn.Close() + continue } log.Printf("connection initiated: %s", conn.RemoteAddr()) - activeConns[conn] = gob.NewEncoder(conn) + activeConns[host] = connEncoder{conn, gob.NewEncoder(conn)} case msg, ok := <-p.msg: if !ok { return // someone closed our msg chan, so we're done } - teminatedConns := []net.Conn{} - for conn, encoder := range activeConns { - if err := encoder.Encode(msg); err != nil { - log.Printf("connection terminated: %s: %v", conn.RemoteAddr(), err) - teminatedConns = append(teminatedConns, conn) - conn.Close() + for host, connEncoder := range activeConns { + if err := connEncoder.Encoder.Encode(msg); err != nil { + log.Printf("connection terminated: %s: %v", connEncoder.Conn.RemoteAddr(), err) + connEncoder.Conn.Close() + delete(activeConns, host) } } - - for _, conn := range teminatedConns { - delete(activeConns, conn) - } } } } +type connEncoder struct { + net.Conn + *gob.Encoder +} + func fwd(ln net.Listener) chan net.Conn { c := make(chan net.Conn) From 36d04da82eca99c533540926ac3abaf4a8209feb Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Tue, 19 May 2015 14:48:16 +0200 Subject: [PATCH 4/4] type connEncoder can be function-scoped --- xfer/publisher.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xfer/publisher.go b/xfer/publisher.go index 7dc8758f6..183841ccf 100644 --- a/xfer/publisher.go +++ b/xfer/publisher.go @@ -53,6 +53,11 @@ func (p *TCPPublisher) Publish(msg report.Report) { } func (p *TCPPublisher) loop(incoming <-chan net.Conn) { + type connEncoder struct { + net.Conn + *gob.Encoder + } + activeConns := map[string]connEncoder{} // host: connEncoder for { @@ -94,11 +99,6 @@ func (p *TCPPublisher) loop(incoming <-chan net.Conn) { } } -type connEncoder struct { - net.Conn - *gob.Encoder -} - func fwd(ln net.Listener) chan net.Conn { c := make(chan net.Conn)