Prevent multiple connections for the same app.

This commit is contained in:
Tom Wilkie
2015-06-10 12:03:32 +00:00
parent 9c6ed7b3c4
commit 894439a449
9 changed files with 52 additions and 18 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"log"
"log/syslog"
"math/rand"
"net/http"
_ "net/http/pprof"
"os"
@@ -60,11 +61,11 @@ func main() {
log.SetOutput(f)
}
log.Printf("app starting, version %s", version)
id := strconv.FormatInt(rand.Int63(), 16)
log.Printf("app starting, version %s, id %s", version, id)
// Collector deals with the probes, and generates merged reports.
xfer.MaxBackoff = 10 * time.Second
c := xfer.NewCollector(*batch)
c := xfer.NewCollector(*batch, id)
defer c.Stop()
r := newStaticResolver(probes, c.Add)

View File

@@ -49,7 +49,7 @@ func main() {
// Collector deals with the probes, and generates a single merged report
// every second.
c := xfer.NewCollector(*batch)
c := xfer.NewCollector(*batch, "id")
for _, addr := range fixedAddresses {
c.Add(addr)
}

View File

@@ -26,7 +26,7 @@ func main() {
flag.Parse()
xfer.MaxBackoff = 10 * time.Second
c := xfer.NewCollector(*batch)
c := xfer.NewCollector(*batch, "id")
for _, addr := range strings.Split(*probes, ",") {
c.Add(addr)
}

View File

@@ -28,7 +28,7 @@ func main() {
// Collector deals with the probes, and generates merged reports.
xfer.MaxBackoff = 1 * time.Second
c := xfer.NewCollector(1 * time.Second)
c := xfer.NewCollector(1*time.Second, "id")
for _, addr := range strings.Split(*probes, ",") {
c.Add(addr)
}

View File

@@ -18,7 +18,7 @@ const (
var (
// MaxBackoff is the maximum time between connect retries.
// It's exported so it's externally configurable.
MaxBackoff = 2 * time.Minute
MaxBackoff = 1 * time.Minute
// This is extracted out for mocking.
tick = time.Tick
@@ -43,10 +43,11 @@ type realCollector struct {
add chan string
remove chan string
quit chan struct{}
id string
}
// NewCollector produces and returns a report collector.
func NewCollector(batchTime time.Duration) Collector {
func NewCollector(batchTime time.Duration, id string) Collector {
c := &realCollector{
in: make(chan report.Report),
out: make(chan report.Report),
@@ -54,6 +55,7 @@ func NewCollector(batchTime time.Duration) Collector {
add: make(chan string),
remove: make(chan string),
quit: make(chan struct{}),
id: id,
}
go c.loop(batchTime)
return c
@@ -75,7 +77,7 @@ func (c *realCollector) loop(batchTime time.Duration) {
wg.Add(1)
go func(quit chan struct{}) {
defer wg.Done()
reportCollector(ip, c.in, quit)
c.reportCollector(ip, quit)
}(addrs[ip])
}
@@ -147,7 +149,7 @@ func (c *realCollector) Stop() {
// reportCollector is the loop to connect to a single Probe. It'll keep
// running until the quit channel is closed.
func reportCollector(ip string, col chan<- report.Report, quit <-chan struct{}) {
func (c *realCollector) reportCollector(ip string, quit <-chan struct{}) {
backoff := initialBackoff / 2
for {
backoff *= 2
@@ -183,6 +185,11 @@ func reportCollector(ip string, col chan<- report.Report, quit <-chan struct{})
}()
// Connection accepted.
if err := gob.NewEncoder(conn).Encode(HandshakeRequest{ID: c.id}); err != nil {
log.Printf("handshake error: %v", err)
break
}
dec := gob.NewDecoder(conn)
for {
var report report.Report
@@ -199,7 +206,7 @@ func reportCollector(ip string, col chan<- report.Report, quit <-chan struct{})
}
select {
case col <- report:
case c.in <- report:
case <-quit:
return
}

View File

@@ -22,7 +22,7 @@ func TestCollector(t *testing.T) {
defer func() { tick = oldTick }()
// Build a collector
collector := NewCollector(time.Second)
collector := NewCollector(time.Second, "id")
defer collector.Stop()
concreteCollector, ok := collector.(*realCollector)
@@ -54,7 +54,7 @@ func TestCollector(t *testing.T) {
}
func TestCollectorQuitWithActiveConnections(t *testing.T) {
c := NewCollector(time.Second)
c := NewCollector(time.Second, "id")
c.Add("1.2.3.4:56789")
c.Stop()
}

View File

@@ -31,7 +31,7 @@ func TestMerge(t *testing.T) {
defer p2.Close()
batchTime := 100 * time.Millisecond
c := xfer.NewCollector(batchTime)
c := xfer.NewCollector(batchTime, "id")
c.Add(p1Addr)
c.Add(p2Addr)
defer c.Stop()

View File

@@ -21,6 +21,11 @@ type TCPPublisher struct {
closer io.Closer
}
// HandshakeRequest contains the unique ID of the connecting app.
type HandshakeRequest struct {
ID string
}
// NewTCPPublisher listens for connections on listenAddress. Only one client
// is accepted at a time; other clients are accepted, but disconnected right
// away. Reports published via publish() will be written to the connected
@@ -68,20 +73,20 @@ func (p *TCPPublisher) loop(incoming <-chan net.Conn) {
}
// Don't allow multiple connections from the same remote host.
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
listenerID, err := getListenerID(conn)
if err != nil {
log.Printf("incoming connection: %s: %v (dropped)", conn.RemoteAddr(), err)
conn.Close()
continue
}
if _, ok := activeConns[host]; ok {
if _, ok := activeConns[listenerID]; ok {
log.Printf("duplicate connection: %s (dropped)", conn.RemoteAddr())
conn.Close()
continue
}
log.Printf("connection initiated: %s", conn.RemoteAddr())
activeConns[host] = connEncoder{conn, gob.NewEncoder(conn)}
log.Printf("connection initiated: %s (%s)", conn.RemoteAddr(), listenerID)
activeConns[listenerID] = connEncoder{conn, gob.NewEncoder(conn)}
case msg, ok := <-p.msg:
if !ok {
@@ -99,6 +104,15 @@ func (p *TCPPublisher) loop(incoming <-chan net.Conn) {
}
}
func getListenerID(c net.Conn) (string, error) {
var req HandshakeRequest
if err := gob.NewDecoder(c).Decode(&req); err != nil {
return "", err
}
return req.ID, nil
}
func fwd(ln net.Listener) chan net.Conn {
c := make(chan net.Conn)

View File

@@ -37,6 +37,11 @@ func TestTCPPublisher(t *testing.T) {
defer conn.Close()
time.Sleep(time.Millisecond)
// Send handshake
if err := gob.NewEncoder(conn).Encode(xfer.HandshakeRequest{ID: "foo"}); err != nil {
t.Fatal(err)
}
// Publish a message
p.Publish(report.Report{})
@@ -69,6 +74,9 @@ func TestPublisherClosesDuplicateConnections(t *testing.T) {
t.Fatal(err)
}
defer conn.Close()
if err := gob.NewEncoder(conn).Encode(xfer.HandshakeRequest{ID: "foo"}); err != nil {
t.Fatal(err)
}
time.Sleep(time.Millisecond)
// Try to connect the same listener
@@ -76,6 +84,10 @@ func TestPublisherClosesDuplicateConnections(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// Send handshake
if err := gob.NewEncoder(dupconn).Encode(xfer.HandshakeRequest{ID: "foo"}); err != nil {
t.Fatal(err)
}
defer dupconn.Close()
// Publish a message