diff --git a/app/api_topology.go b/app/api_topology.go index 305567a5d..16f03f192 100644 --- a/app/api_topology.go +++ b/app/api_topology.go @@ -5,7 +5,6 @@ import ( "time" log "github.com/Sirupsen/logrus" - "github.com/gorilla/websocket" "golang.org/x/net/context" "github.com/weaveworks/scope/common/xfer" @@ -14,8 +13,7 @@ import ( ) const ( - websocketLoop = 1 * time.Second - websocketTimeout = 10 * time.Second + websocketLoop = 1 * time.Second ) // APITopology is returned by the /api/topology/{name} handler. @@ -67,10 +65,6 @@ func handleNode(nodeID string) func(context.Context, Reporter, render.Renderer, } } -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, -} - func handleWebsocket( ctx context.Context, w http.ResponseWriter, @@ -79,7 +73,7 @@ func handleWebsocket( renderer render.Renderer, loop time.Duration, ) { - conn, err := upgrader.Upgrade(w, r, nil) + conn, err := xfer.Upgrade(w, r, nil) if err != nil { // log.Info("Upgrade:", err) return @@ -87,9 +81,9 @@ func handleWebsocket( defer conn.Close() quit := make(chan struct{}) - go func(c *websocket.Conn) { + go func(c xfer.Websocket) { for { // just discard everything the browser sends - if _, _, err := c.NextReader(); err != nil { + if _, _, err := c.ReadMessage(); err != nil { if !xfer.IsExpectedWSCloseError(err) { log.Println("err:", err) } @@ -112,18 +106,13 @@ func handleWebsocket( diff := render.TopoDiff(previousTopo, newTopo) previousTopo = newTopo - if err := conn.SetWriteDeadline(time.Now().Add(websocketTimeout)); err != nil { + if err := conn.WriteJSON(diff); err != nil { if !xfer.IsExpectedWSCloseError(err) { - log.Println("err:", err) + log.Errorf("cannot serialize topology diff: %s", err) } return } - if err := xfer.WriteJSONtoWS(conn, diff); err != nil { - log.Errorf("cannot serialize topology diff: %s", err) - return - } - select { case <-wait: case <-tick: diff --git a/app/controls.go b/app/controls.go index 00eb43d20..2548651af 100644 --- a/app/controls.go +++ b/app/controls.go @@ -1,6 +1,7 @@ package app import ( + "io" "net/http" "net/rpc" @@ -56,7 +57,7 @@ func handleProbeWS(cr ControlRouter) CtxHandlerFunc { return } - conn, err := upgrader.Upgrade(w, r, nil) + conn, err := xfer.Upgrade(w, r, nil) if err != nil { log.Printf("Error upgrading control websocket: %v", err) return @@ -79,8 +80,8 @@ func handleProbeWS(cr ControlRouter) CtxHandlerFunc { return } defer cr.Deregister(ctx, probeID, id) - if err := codec.WaitForReadError(); err != nil && !xfer.IsExpectedWSCloseError(err) { - log.Printf("Error reading from probe %s control websocket: %v", probeID, err) + if err := codec.WaitForReadError(); err != nil && err != io.EOF && !xfer.IsExpectedWSCloseError(err) { + log.Errorf("Error on websocket: %v", err) } } } diff --git a/app/pipes.go b/app/pipes.go index 7ca90b543..c4724a360 100644 --- a/app/pipes.go +++ b/app/pipes.go @@ -35,7 +35,7 @@ func handlePipeWs(pr PipeRouter, end End) CtxHandlerFunc { } defer pr.Release(ctx, id, end) - conn, err := upgrader.Upgrade(w, r, nil) + conn, err := xfer.Upgrade(w, r, nil) if err != nil { log.Errorf("Error upgrading pipe %s (%d) websocket: %v", id, end, err) return diff --git a/common/xfer/controls.go b/common/xfer/controls.go index cfb0d1dde..3bb9c711a 100644 --- a/common/xfer/controls.go +++ b/common/xfer/controls.go @@ -4,8 +4,6 @@ import ( "fmt" "net/rpc" "sync" - - "github.com/gorilla/websocket" ) // ErrInvalidMessage is the error returned when the on-wire message is unexpected. @@ -70,12 +68,12 @@ func ResponseError(err error) Response { // that transmits and receives RPC messages over a websocker, as JSON. type JSONWebsocketCodec struct { sync.Mutex - conn *websocket.Conn + conn Websocket err chan error } // NewJSONWebsocketCodec makes a new JSONWebsocketCodec -func NewJSONWebsocketCodec(conn *websocket.Conn) *JSONWebsocketCodec { +func NewJSONWebsocketCodec(conn Websocket) *JSONWebsocketCodec { return &JSONWebsocketCodec{ conn: conn, err: make(chan error, 1), @@ -93,10 +91,10 @@ func (j *JSONWebsocketCodec) WriteRequest(r *rpc.Request, v interface{}) error { j.Lock() defer j.Unlock() - if err := WriteJSONtoWS(j.conn, Message{Request: r}); err != nil { + if err := j.conn.WriteJSON(Message{Request: r}); err != nil { return err } - return WriteJSONtoWS(j.conn, Message{Value: v}) + return j.conn.WriteJSON(Message{Value: v}) } // WriteResponse implements rpc.ServerCodec @@ -104,15 +102,15 @@ func (j *JSONWebsocketCodec) WriteResponse(r *rpc.Response, v interface{}) error j.Lock() defer j.Unlock() - if err := WriteJSONtoWS(j.conn, Message{Response: r}); err != nil { + if err := j.conn.WriteJSON(Message{Response: r}); err != nil { return err } - return WriteJSONtoWS(j.conn, Message{Value: v}) + return j.conn.WriteJSON(Message{Value: v}) } func (j *JSONWebsocketCodec) readMessage(v interface{}) (*Message, error) { m := Message{Value: v} - if err := ReadJSONfromWS(j.conn, &m); err != nil { + if err := j.conn.ReadJSON(&m); err != nil { j.err <- err close(j.err) return nil, err diff --git a/common/xfer/pipes.go b/common/xfer/pipes.go index 8e85515ff..8172644b6 100644 --- a/common/xfer/pipes.go +++ b/common/xfer/pipes.go @@ -11,7 +11,7 @@ import ( // to the UI. type Pipe interface { Ends() (io.ReadWriter, io.ReadWriter) - CopyToWebsocket(io.ReadWriter, *websocket.Conn) error + CopyToWebsocket(io.ReadWriter, Websocket) error Close() error Closed() bool @@ -83,7 +83,7 @@ func (p *pipe) OnClose(f func()) { } // CopyToWebsocket copies pipe data to/from a websocket. It blocks. -func (p *pipe) CopyToWebsocket(end io.ReadWriter, conn *websocket.Conn) error { +func (p *pipe) CopyToWebsocket(end io.ReadWriter, conn Websocket) error { p.mtx.Lock() if p.closed { p.mtx.Unlock() diff --git a/common/xfer/websocket.go b/common/xfer/websocket.go index 1d3dd07be..f5faabe54 100644 --- a/common/xfer/websocket.go +++ b/common/xfer/websocket.go @@ -2,11 +2,165 @@ package xfer import ( "io" + "net/http" + "sync" + "time" + log "github.com/Sirupsen/logrus" "github.com/gorilla/websocket" "github.com/ugorji/go/codec" + + "github.com/weaveworks/scope/common/mtime" ) +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. Needs to be less + // than the idle timeout on whatever frontend server is proxying the + // websocket connections (e.g. nginx). + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. The peer + // must respond with a pong in < pongWait. But it may take writeWait for the + // pong to be sent. Therefore we want to allow time for that, and a bit of + // delay/round-trip in case the peer is busy. 1/3 of pongWait seems like a + // reasonable amount of time to respond to a ping. + pingPeriod = ((pongWait - writeWait) * 2 / 3) +) + +// Websocket exposes the bits of *websocket.Conn we actually use. +type Websocket interface { + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + ReadJSON(v interface{}) error + WriteJSON(v interface{}) error + Close() error +} + +type pingingWebsocket struct { + pinger *time.Timer + readLock sync.Mutex + writeLock sync.Mutex + conn *websocket.Conn +} + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (Websocket, error) { + wsConn, err := upgrader.Upgrade(w, r, responseHeader) + if err != nil { + return nil, err + } + return Ping(wsConn), nil +} + +// WSDialer can dial a new websocket +type WSDialer interface { + Dial(urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) +} + +// DialWS creates a new client connection. Use requestHeader to specify the +// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). +// Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +func DialWS(d WSDialer, urlStr string, requestHeader http.Header) (Websocket, *http.Response, error) { + wsConn, resp, err := d.Dial(urlStr, requestHeader) + if err != nil { + return nil, nil, err + } + return Ping(wsConn), resp, nil +} + +// Ping adds a periodic ping to a websocket connection. +func Ping(c *websocket.Conn) Websocket { + p := &pingingWebsocket{conn: c} + p.conn.SetPongHandler(p.pong) + p.conn.SetReadDeadline(mtime.Now().Add(pongWait)) + p.pinger = time.AfterFunc(pingPeriod, p.ping) + return p +} + +func (p *pingingWebsocket) ping() { + p.writeLock.Lock() + defer p.writeLock.Unlock() + if err := p.conn.WriteControl(websocket.PingMessage, nil, mtime.Now().Add(writeWait)); err != nil { + log.Errorf("websocket ping error: %v", err) + p.Close() + } + p.pinger.Reset(pingPeriod) +} + +func (p *pingingWebsocket) pong(string) error { + p.conn.SetReadDeadline(mtime.Now().Add(pongWait)) + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (p *pingingWebsocket) ReadMessage() (int, []byte, error) { + p.readLock.Lock() + defer p.readLock.Unlock() + return p.conn.ReadMessage() +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (p *pingingWebsocket) WriteMessage(messageType int, data []byte) error { + p.writeLock.Lock() + defer p.writeLock.Unlock() + if err := p.conn.SetWriteDeadline(mtime.Now().Add(writeWait)); err != nil { + return err + } + return p.conn.WriteMessage(messageType, data) +} + +// WriteJSON writes the JSON encoding of v to the connection. +func (p *pingingWebsocket) WriteJSON(v interface{}) error { + p.writeLock.Lock() + defer p.writeLock.Unlock() + w, err := p.conn.NextWriter(websocket.TextMessage) + if err != nil { + return err + } + if err := p.conn.SetWriteDeadline(mtime.Now().Add(writeWait)); err != nil { + return err + } + err1 := codec.NewEncoder(w, &codec.JsonHandle{}).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +func (p *pingingWebsocket) ReadJSON(v interface{}) error { + p.readLock.Lock() + defer p.readLock.Unlock() + _, r, err := p.conn.NextReader() + if err != nil { + return err + } + err = codec.NewDecoder(r, &codec.JsonHandle{}).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} + +// Close closes the connection +func (p *pingingWebsocket) Close() error { + p.pinger.Stop() + return p.conn.Close() +} + // IsExpectedWSCloseError returns boolean indicating whether the error is a // clean disconnection. func IsExpectedWSCloseError(err error) bool { @@ -17,32 +171,3 @@ func IsExpectedWSCloseError(err error) bool { websocket.CloseNoStatusReceived, ) } - -// WriteJSONtoWS writes the JSON encoding of v to the connection. -func WriteJSONtoWS(c *websocket.Conn, v interface{}) error { - w, err := c.NextWriter(websocket.TextMessage) - if err != nil { - return err - } - err1 := codec.NewEncoder(w, &codec.JsonHandle{}).Encode(v) - err2 := w.Close() - if err1 != nil { - return err1 - } - return err2 -} - -// ReadJSONfromWS reads the next JSON-encoded message from the connection and stores -// it in the value pointed to by v. -func ReadJSONfromWS(c *websocket.Conn, v interface{}) error { - _, r, err := c.NextReader() - if err != nil { - return err - } - err = codec.NewDecoder(r, &codec.JsonHandle{}).Decode(v) - if err == io.EOF { - // One value is expected in the message. - err = io.ErrUnexpectedEOF - } - return err -} diff --git a/probe/appclient/app_client.go b/probe/appclient/app_client.go index 1424d9962..8d8409f0a 100644 --- a/probe/appclient/app_client.go +++ b/probe/appclient/app_client.go @@ -46,7 +46,7 @@ type appClient struct { backgroundWait sync.WaitGroup // Track ongoing websocket connections - conns map[string]*websocket.Conn + conns map[string]xfer.Websocket // For publish publishLoop sync.Once @@ -73,7 +73,7 @@ func NewAppClient(pc ProbeConfig, hostname, target string, control xfer.ControlH wsDialer: websocket.Dialer{ TLSClientConfig: httpTransport.TLSClientConfig, }, - conns: map[string]*websocket.Conn{}, + conns: map[string]xfer.Websocket{}, readers: make(chan io.Reader), control: control, }, nil @@ -88,7 +88,7 @@ func (c *appClient) hasQuit() bool { } } -func (c *appClient) registerConn(id string, conn *websocket.Conn) bool { +func (c *appClient) registerConn(id string, conn xfer.Websocket) bool { c.mtx.Lock() defer c.mtx.Unlock() if c.hasQuit() { @@ -130,7 +130,7 @@ func (c *appClient) Stop() { for _, conn := range c.conns { conn.Close() } - c.conns = map[string]*websocket.Conn{} + c.conns = map[string]xfer.Websocket{} c.mtx.Unlock() c.backgroundWait.Wait() @@ -188,13 +188,11 @@ func (c *appClient) controlConnection() (bool, error) { headers := http.Header{} c.ProbeConfig.authorizeHeaders(headers) url := sanitize.URL("ws://", 0, "/api/control/ws")(c.target) - conn, _, err := c.wsDialer.Dial(url, headers) + conn, _, err := xfer.DialWS(&c.wsDialer, url, headers) if err != nil { return false, err } - defer func() { - conn.Close() - }() + defer conn.Close() codec := xfer.NewJSONWebsocketCodec(conn) server := rpc.NewServer() @@ -271,7 +269,7 @@ func (c *appClient) pipeConnection(id string, pipe xfer.Pipe) (bool, error) { headers := http.Header{} c.ProbeConfig.authorizeHeaders(headers) url := sanitize.URL("ws://", 0, fmt.Sprintf("/api/pipe/%s/probe", id))(c.target) - conn, resp, err := c.wsDialer.Dial(url, headers) + conn, resp, err := xfer.DialWS(&c.wsDialer, url, headers) if resp != nil && resp.StatusCode == http.StatusNotFound { // Special handling - 404 means the app/user has closed the pipe pipe.Close() diff --git a/probe/docker/controls_test.go b/probe/docker/controls_test.go index bf5093fd1..e7f17aef7 100644 --- a/probe/docker/controls_test.go +++ b/probe/docker/controls_test.go @@ -6,8 +6,6 @@ import ( "testing" "time" - "github.com/gorilla/websocket" - "github.com/weaveworks/scope/common/xfer" "github.com/weaveworks/scope/probe/controls" "github.com/weaveworks/scope/probe/docker" @@ -43,11 +41,11 @@ func TestControls(t *testing.T) { type mockPipe struct{} -func (mockPipe) Ends() (io.ReadWriter, io.ReadWriter) { return nil, nil } -func (mockPipe) CopyToWebsocket(io.ReadWriter, *websocket.Conn) error { return nil } -func (mockPipe) Close() error { return nil } -func (mockPipe) Closed() bool { return false } -func (mockPipe) OnClose(func()) {} +func (mockPipe) Ends() (io.ReadWriter, io.ReadWriter) { return nil, nil } +func (mockPipe) CopyToWebsocket(io.ReadWriter, xfer.Websocket) error { return nil } +func (mockPipe) Close() error { return nil } +func (mockPipe) Closed() bool { return false } +func (mockPipe) OnClose(func()) {} func TestPipes(t *testing.T) { oldNewPipe := controls.NewPipe