Merge pull request #995 from weaveworks/828-websocket-ping

Add ping/pong to websocket protocol
This commit is contained in:
Paul Bellamy
2016-02-25 17:13:14 +00:00
8 changed files with 186 additions and 77 deletions

View File

@@ -5,7 +5,6 @@ import (
"time"
log "github.com/Sirupsen/logrus"
"github.com/gorilla/websocket"
"golang.org/x/net/context"
"github.com/weaveworks/scope/common/xfer"
@@ -14,8 +13,7 @@ import (
)
const (
websocketLoop = 1 * time.Second
websocketTimeout = 10 * time.Second
websocketLoop = 1 * time.Second
)
// APITopology is returned by the /api/topology/{name} handler.
@@ -67,10 +65,6 @@ func handleNode(nodeID string) func(context.Context, Reporter, render.Renderer,
}
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
func handleWebsocket(
ctx context.Context,
w http.ResponseWriter,
@@ -79,7 +73,7 @@ func handleWebsocket(
renderer render.Renderer,
loop time.Duration,
) {
conn, err := upgrader.Upgrade(w, r, nil)
conn, err := xfer.Upgrade(w, r, nil)
if err != nil {
// log.Info("Upgrade:", err)
return
@@ -87,9 +81,9 @@ func handleWebsocket(
defer conn.Close()
quit := make(chan struct{})
go func(c *websocket.Conn) {
go func(c xfer.Websocket) {
for { // just discard everything the browser sends
if _, _, err := c.NextReader(); err != nil {
if _, _, err := c.ReadMessage(); err != nil {
if !xfer.IsExpectedWSCloseError(err) {
log.Println("err:", err)
}
@@ -112,18 +106,13 @@ func handleWebsocket(
diff := render.TopoDiff(previousTopo, newTopo)
previousTopo = newTopo
if err := conn.SetWriteDeadline(time.Now().Add(websocketTimeout)); err != nil {
if err := conn.WriteJSON(diff); err != nil {
if !xfer.IsExpectedWSCloseError(err) {
log.Println("err:", err)
log.Errorf("cannot serialize topology diff: %s", err)
}
return
}
if err := xfer.WriteJSONtoWS(conn, diff); err != nil {
log.Errorf("cannot serialize topology diff: %s", err)
return
}
select {
case <-wait:
case <-tick:

View File

@@ -1,6 +1,7 @@
package app
import (
"io"
"net/http"
"net/rpc"
@@ -56,7 +57,7 @@ func handleProbeWS(cr ControlRouter) CtxHandlerFunc {
return
}
conn, err := upgrader.Upgrade(w, r, nil)
conn, err := xfer.Upgrade(w, r, nil)
if err != nil {
log.Printf("Error upgrading control websocket: %v", err)
return
@@ -79,8 +80,8 @@ func handleProbeWS(cr ControlRouter) CtxHandlerFunc {
return
}
defer cr.Deregister(ctx, probeID, id)
if err := codec.WaitForReadError(); err != nil && !xfer.IsExpectedWSCloseError(err) {
log.Printf("Error reading from probe %s control websocket: %v", probeID, err)
if err := codec.WaitForReadError(); err != nil && err != io.EOF && !xfer.IsExpectedWSCloseError(err) {
log.Errorf("Error on websocket: %v", err)
}
}
}

View File

@@ -35,7 +35,7 @@ func handlePipeWs(pr PipeRouter, end End) CtxHandlerFunc {
}
defer pr.Release(ctx, id, end)
conn, err := upgrader.Upgrade(w, r, nil)
conn, err := xfer.Upgrade(w, r, nil)
if err != nil {
log.Errorf("Error upgrading pipe %s (%d) websocket: %v", id, end, err)
return

View File

@@ -4,8 +4,6 @@ import (
"fmt"
"net/rpc"
"sync"
"github.com/gorilla/websocket"
)
// ErrInvalidMessage is the error returned when the on-wire message is unexpected.
@@ -70,12 +68,12 @@ func ResponseError(err error) Response {
// that transmits and receives RPC messages over a websocker, as JSON.
type JSONWebsocketCodec struct {
sync.Mutex
conn *websocket.Conn
conn Websocket
err chan error
}
// NewJSONWebsocketCodec makes a new JSONWebsocketCodec
func NewJSONWebsocketCodec(conn *websocket.Conn) *JSONWebsocketCodec {
func NewJSONWebsocketCodec(conn Websocket) *JSONWebsocketCodec {
return &JSONWebsocketCodec{
conn: conn,
err: make(chan error, 1),
@@ -93,10 +91,10 @@ func (j *JSONWebsocketCodec) WriteRequest(r *rpc.Request, v interface{}) error {
j.Lock()
defer j.Unlock()
if err := WriteJSONtoWS(j.conn, Message{Request: r}); err != nil {
if err := j.conn.WriteJSON(Message{Request: r}); err != nil {
return err
}
return WriteJSONtoWS(j.conn, Message{Value: v})
return j.conn.WriteJSON(Message{Value: v})
}
// WriteResponse implements rpc.ServerCodec
@@ -104,15 +102,15 @@ func (j *JSONWebsocketCodec) WriteResponse(r *rpc.Response, v interface{}) error
j.Lock()
defer j.Unlock()
if err := WriteJSONtoWS(j.conn, Message{Response: r}); err != nil {
if err := j.conn.WriteJSON(Message{Response: r}); err != nil {
return err
}
return WriteJSONtoWS(j.conn, Message{Value: v})
return j.conn.WriteJSON(Message{Value: v})
}
func (j *JSONWebsocketCodec) readMessage(v interface{}) (*Message, error) {
m := Message{Value: v}
if err := ReadJSONfromWS(j.conn, &m); err != nil {
if err := j.conn.ReadJSON(&m); err != nil {
j.err <- err
close(j.err)
return nil, err

View File

@@ -11,7 +11,7 @@ import (
// to the UI.
type Pipe interface {
Ends() (io.ReadWriter, io.ReadWriter)
CopyToWebsocket(io.ReadWriter, *websocket.Conn) error
CopyToWebsocket(io.ReadWriter, Websocket) error
Close() error
Closed() bool
@@ -83,7 +83,7 @@ func (p *pipe) OnClose(f func()) {
}
// CopyToWebsocket copies pipe data to/from a websocket. It blocks.
func (p *pipe) CopyToWebsocket(end io.ReadWriter, conn *websocket.Conn) error {
func (p *pipe) CopyToWebsocket(end io.ReadWriter, conn Websocket) error {
p.mtx.Lock()
if p.closed {
p.mtx.Unlock()

View File

@@ -2,11 +2,165 @@ package xfer
import (
"io"
"net/http"
"sync"
"time"
log "github.com/Sirupsen/logrus"
"github.com/gorilla/websocket"
"github.com/ugorji/go/codec"
"github.com/weaveworks/scope/common/mtime"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer. Needs to be less
// than the idle timeout on whatever frontend server is proxying the
// websocket connections (e.g. nginx).
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait. The peer
// must respond with a pong in < pongWait. But it may take writeWait for the
// pong to be sent. Therefore we want to allow time for that, and a bit of
// delay/round-trip in case the peer is busy. 1/3 of pongWait seems like a
// reasonable amount of time to respond to a ping.
pingPeriod = ((pongWait - writeWait) * 2 / 3)
)
// Websocket exposes the bits of *websocket.Conn we actually use.
type Websocket interface {
ReadMessage() (messageType int, p []byte, err error)
WriteMessage(messageType int, data []byte) error
ReadJSON(v interface{}) error
WriteJSON(v interface{}) error
Close() error
}
type pingingWebsocket struct {
pinger *time.Timer
readLock sync.Mutex
writeLock sync.Mutex
conn *websocket.Conn
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Websocket, error) {
wsConn, err := upgrader.Upgrade(w, r, responseHeader)
if err != nil {
return nil, err
}
return Ping(wsConn), nil
}
// WSDialer can dial a new websocket
type WSDialer interface {
Dial(urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error)
}
// DialWS creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
func DialWS(d WSDialer, urlStr string, requestHeader http.Header) (Websocket, *http.Response, error) {
wsConn, resp, err := d.Dial(urlStr, requestHeader)
if err != nil {
return nil, nil, err
}
return Ping(wsConn), resp, nil
}
// Ping adds a periodic ping to a websocket connection.
func Ping(c *websocket.Conn) Websocket {
p := &pingingWebsocket{conn: c}
p.conn.SetPongHandler(p.pong)
p.conn.SetReadDeadline(mtime.Now().Add(pongWait))
p.pinger = time.AfterFunc(pingPeriod, p.ping)
return p
}
func (p *pingingWebsocket) ping() {
p.writeLock.Lock()
defer p.writeLock.Unlock()
if err := p.conn.WriteControl(websocket.PingMessage, nil, mtime.Now().Add(writeWait)); err != nil {
log.Errorf("websocket ping error: %v", err)
p.Close()
}
p.pinger.Reset(pingPeriod)
}
func (p *pingingWebsocket) pong(string) error {
p.conn.SetReadDeadline(mtime.Now().Add(pongWait))
return nil
}
// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (p *pingingWebsocket) ReadMessage() (int, []byte, error) {
p.readLock.Lock()
defer p.readLock.Unlock()
return p.conn.ReadMessage()
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (p *pingingWebsocket) WriteMessage(messageType int, data []byte) error {
p.writeLock.Lock()
defer p.writeLock.Unlock()
if err := p.conn.SetWriteDeadline(mtime.Now().Add(writeWait)); err != nil {
return err
}
return p.conn.WriteMessage(messageType, data)
}
// WriteJSON writes the JSON encoding of v to the connection.
func (p *pingingWebsocket) WriteJSON(v interface{}) error {
p.writeLock.Lock()
defer p.writeLock.Unlock()
w, err := p.conn.NextWriter(websocket.TextMessage)
if err != nil {
return err
}
if err := p.conn.SetWriteDeadline(mtime.Now().Add(writeWait)); err != nil {
return err
}
err1 := codec.NewEncoder(w, &codec.JsonHandle{}).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
func (p *pingingWebsocket) ReadJSON(v interface{}) error {
p.readLock.Lock()
defer p.readLock.Unlock()
_, r, err := p.conn.NextReader()
if err != nil {
return err
}
err = codec.NewDecoder(r, &codec.JsonHandle{}).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}
// Close closes the connection
func (p *pingingWebsocket) Close() error {
p.pinger.Stop()
return p.conn.Close()
}
// IsExpectedWSCloseError returns boolean indicating whether the error is a
// clean disconnection.
func IsExpectedWSCloseError(err error) bool {
@@ -17,32 +171,3 @@ func IsExpectedWSCloseError(err error) bool {
websocket.CloseNoStatusReceived,
)
}
// WriteJSONtoWS writes the JSON encoding of v to the connection.
func WriteJSONtoWS(c *websocket.Conn, v interface{}) error {
w, err := c.NextWriter(websocket.TextMessage)
if err != nil {
return err
}
err1 := codec.NewEncoder(w, &codec.JsonHandle{}).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSONfromWS reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
func ReadJSONfromWS(c *websocket.Conn, v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
err = codec.NewDecoder(r, &codec.JsonHandle{}).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

View File

@@ -46,7 +46,7 @@ type appClient struct {
backgroundWait sync.WaitGroup
// Track ongoing websocket connections
conns map[string]*websocket.Conn
conns map[string]xfer.Websocket
// For publish
publishLoop sync.Once
@@ -73,7 +73,7 @@ func NewAppClient(pc ProbeConfig, hostname, target string, control xfer.ControlH
wsDialer: websocket.Dialer{
TLSClientConfig: httpTransport.TLSClientConfig,
},
conns: map[string]*websocket.Conn{},
conns: map[string]xfer.Websocket{},
readers: make(chan io.Reader),
control: control,
}, nil
@@ -88,7 +88,7 @@ func (c *appClient) hasQuit() bool {
}
}
func (c *appClient) registerConn(id string, conn *websocket.Conn) bool {
func (c *appClient) registerConn(id string, conn xfer.Websocket) bool {
c.mtx.Lock()
defer c.mtx.Unlock()
if c.hasQuit() {
@@ -130,7 +130,7 @@ func (c *appClient) Stop() {
for _, conn := range c.conns {
conn.Close()
}
c.conns = map[string]*websocket.Conn{}
c.conns = map[string]xfer.Websocket{}
c.mtx.Unlock()
c.backgroundWait.Wait()
@@ -188,13 +188,11 @@ func (c *appClient) controlConnection() (bool, error) {
headers := http.Header{}
c.ProbeConfig.authorizeHeaders(headers)
url := sanitize.URL("ws://", 0, "/api/control/ws")(c.target)
conn, _, err := c.wsDialer.Dial(url, headers)
conn, _, err := xfer.DialWS(&c.wsDialer, url, headers)
if err != nil {
return false, err
}
defer func() {
conn.Close()
}()
defer conn.Close()
codec := xfer.NewJSONWebsocketCodec(conn)
server := rpc.NewServer()
@@ -271,7 +269,7 @@ func (c *appClient) pipeConnection(id string, pipe xfer.Pipe) (bool, error) {
headers := http.Header{}
c.ProbeConfig.authorizeHeaders(headers)
url := sanitize.URL("ws://", 0, fmt.Sprintf("/api/pipe/%s/probe", id))(c.target)
conn, resp, err := c.wsDialer.Dial(url, headers)
conn, resp, err := xfer.DialWS(&c.wsDialer, url, headers)
if resp != nil && resp.StatusCode == http.StatusNotFound {
// Special handling - 404 means the app/user has closed the pipe
pipe.Close()

View File

@@ -6,8 +6,6 @@ import (
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/weaveworks/scope/common/xfer"
"github.com/weaveworks/scope/probe/controls"
"github.com/weaveworks/scope/probe/docker"
@@ -43,11 +41,11 @@ func TestControls(t *testing.T) {
type mockPipe struct{}
func (mockPipe) Ends() (io.ReadWriter, io.ReadWriter) { return nil, nil }
func (mockPipe) CopyToWebsocket(io.ReadWriter, *websocket.Conn) error { return nil }
func (mockPipe) Close() error { return nil }
func (mockPipe) Closed() bool { return false }
func (mockPipe) OnClose(func()) {}
func (mockPipe) Ends() (io.ReadWriter, io.ReadWriter) { return nil, nil }
func (mockPipe) CopyToWebsocket(io.ReadWriter, xfer.Websocket) error { return nil }
func (mockPipe) Close() error { return nil }
func (mockPipe) Closed() bool { return false }
func (mockPipe) OnClose(func()) {}
func TestPipes(t *testing.T) {
oldNewPipe := controls.NewPipe