mirror of
https://github.com/weaveworks/scope.git
synced 2026-02-14 18:09:59 +00:00
Better error message when plugin responses are too large
This commit is contained in:
43
probe/plugins/max_bytes_reader.go
Normal file
43
probe/plugins/max_bytes_reader.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// MaxBytesReader is similar to net/http.MaxBytesReader, but lets us choose how
|
||||
// to handle an overflow by providing an error. net/http.MaxBytesReader uses
|
||||
// net/http internals to render a naff error message. There are other
|
||||
// discrepancies with how this detects overflows. Not sure if that will cause
|
||||
// issues. If you want to use it as part of an HTTP server, it's probably best
|
||||
// to change it so you can provide a callback func, which renders your error
|
||||
// message, as returning an error into the middle of the net/http server will
|
||||
// not be useful.
|
||||
func MaxBytesReader(r io.ReadCloser, maxBytes int64, err error) io.ReadCloser {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &maxBytesReader{
|
||||
ReadCloser: r,
|
||||
bytesRemaining: maxBytes,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
type maxBytesReader struct {
|
||||
io.ReadCloser
|
||||
bytesRemaining int64
|
||||
err error // Callback when overflowing
|
||||
}
|
||||
|
||||
func (r *maxBytesReader) Read(p []byte) (int, error) {
|
||||
if r.bytesRemaining <= 0 {
|
||||
return 0, r.err
|
||||
}
|
||||
if int64(len(p)) > r.bytesRemaining {
|
||||
p = p[0:r.bytesRemaining]
|
||||
}
|
||||
n, err := r.ReadCloser.Read(p)
|
||||
r.bytesRemaining -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
89
probe/plugins/max_bytes_reader_internal_test.go
Normal file
89
probe/plugins/max_bytes_reader_internal_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package plugins
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMaxBytesReaderReturnsAllDataIfSmaller(t *testing.T) {
|
||||
result, err := ioutil.ReadAll(MaxBytesReader(ioutil.NopCloser(strings.NewReader("some data")), 1024, errors.New("test error")))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if string(result) != "some data" {
|
||||
t.Errorf("Expected %q, got %q", "some data", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderReturnsNilIfNil(t *testing.T) {
|
||||
result := MaxBytesReader(nil, 1024, errors.New("test error"))
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got: %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderReturnsErrorIfLarger(t *testing.T) {
|
||||
input := &bytes.Buffer{}
|
||||
for i := int64(0); i <= 1024; i++ {
|
||||
input.WriteByte(byte(i))
|
||||
}
|
||||
|
||||
result, err := ioutil.ReadAll(MaxBytesReader(ioutil.NopCloser(input), 1024, errors.New("test error")))
|
||||
if err.Error() != "test error" {
|
||||
t.Errorf("Expected error to be %q, got: %q", "test error", err.Error())
|
||||
}
|
||||
if len(result) != 1024 {
|
||||
t.Errorf("Expected result length to be 1024, but got: %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderReturnsErrorIfLargerAndMassiveBufferGiven(t *testing.T) {
|
||||
input := &bytes.Buffer{}
|
||||
for i := int64(0); i <= 1024; i++ {
|
||||
input.WriteByte(byte(i))
|
||||
}
|
||||
|
||||
buffer := make([]byte, 1024+2)
|
||||
reader := MaxBytesReader(ioutil.NopCloser(input), 1024, errors.New("test error"))
|
||||
|
||||
// First read is scoped down to the maximum
|
||||
readCount, err := reader.Read(buffer)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if readCount != 1024 {
|
||||
t.Errorf("Expected result length to be 1024, but got: %d", readCount)
|
||||
}
|
||||
|
||||
// Second read returns an error
|
||||
readCount, err = reader.Read(buffer)
|
||||
if err.Error() != "test error" {
|
||||
t.Errorf("Expected error to be %q, got: %q", "test error", err.Error())
|
||||
}
|
||||
if readCount != 0 {
|
||||
t.Errorf("Expected result length to be 0, but got: %d", readCount)
|
||||
}
|
||||
}
|
||||
|
||||
type testReadCloser struct {
|
||||
closeError error
|
||||
}
|
||||
|
||||
func (c testReadCloser) Read(p []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c testReadCloser) Close() error {
|
||||
return c.closeError
|
||||
}
|
||||
|
||||
func TestMaxBytesReaderPassesThroughErrorsWhenClosing(t *testing.T) {
|
||||
readcloser := testReadCloser{errors.New("test error")}
|
||||
err := MaxBytesReader(readcloser, 1024, errors.New("overflow")).Close()
|
||||
if err == nil || err.Error() != "test error" {
|
||||
t.Errorf("Expected error to be %q, got: %q", "test error", err)
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,11 @@ package plugins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -24,7 +24,9 @@ import (
|
||||
|
||||
// Exposed for testing
|
||||
var (
|
||||
transport = makeUnixRoundTripper
|
||||
transport = makeUnixRoundTripper
|
||||
maxResponseBytes int64 = 50 * 1024 * 1024
|
||||
errResponseTooLarge = fmt.Errorf("response must be shorter than 50MB")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -220,8 +222,11 @@ func NewPlugin(ctx context.Context, socket string, client *http.Client, expected
|
||||
params.Add(k, v)
|
||||
}
|
||||
|
||||
id := strings.TrimSuffix(filepath.Base(socket), filepath.Ext(socket))
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Plugin{
|
||||
PluginSpec: xfer.PluginSpec{ID: id, Label: id},
|
||||
context: ctx,
|
||||
socket: socket,
|
||||
expectedAPIVersion: expectedAPIVersion,
|
||||
@@ -234,16 +239,6 @@ func NewPlugin(ctx context.Context, socket string, client *http.Client, expected
|
||||
// Report gets the latest report from the plugin
|
||||
func (p *Plugin) Report() (result report.Report, err error) {
|
||||
result = report.MakeReport()
|
||||
if err := p.get("/report", p.handshakeMetadata, &result); err != nil {
|
||||
return result, err
|
||||
}
|
||||
if result.Plugins.Size() != 1 {
|
||||
return result, fmt.Errorf("plugins: %s report must contain exactly one plugin (found %d)", p.socket, result.Plugins.Size())
|
||||
}
|
||||
|
||||
key := result.Plugins.Keys()[0]
|
||||
spec, _ := result.Plugins.Lookup(key)
|
||||
p.PluginSpec = spec
|
||||
defer func() {
|
||||
p.setStatus(err)
|
||||
result.Plugins = result.Plugins.Add(p.PluginSpec)
|
||||
@@ -253,6 +248,17 @@ func (p *Plugin) Report() (result report.Report, err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := p.get("/report", p.handshakeMetadata, &result); err != nil {
|
||||
return result, err
|
||||
}
|
||||
if result.Plugins.Size() != 1 {
|
||||
return result, fmt.Errorf("report must contain exactly one plugin (found %d)", result.Plugins.Size())
|
||||
}
|
||||
|
||||
key := result.Plugins.Keys()[0]
|
||||
spec, _ := result.Plugins.Lookup(key)
|
||||
p.PluginSpec = spec
|
||||
|
||||
foundReporter := false
|
||||
for _, i := range spec.Interfaces {
|
||||
if i == "reporter" {
|
||||
@@ -294,7 +300,11 @@ func (p *Plugin) get(path string, params url.Values, result interface{}) error {
|
||||
return fmt.Errorf("plugin returned non-200 status code: %s", resp.Status)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err := codec.NewDecoder(io.LimitReader(resp.Body, 50*1024*1024), &codec.JsonHandle{}).Decode(&result); err != nil {
|
||||
err = codec.NewDecoder(MaxBytesReader(resp.Body, maxResponseBytes, errResponseTooLarge), &codec.JsonHandle{}).Decode(&result)
|
||||
if err == errResponseTooLarge {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoding error: %s", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -219,7 +219,20 @@ func TestRegistryLoadsExistingPluginsEvenWhenOneFails(t *testing.T) {
|
||||
defer r.Close()
|
||||
|
||||
r.Report()
|
||||
checkLoadedPluginIDs(t, r.ForEach, []string{"", "testPlugin"})
|
||||
checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{
|
||||
{
|
||||
ID: "aFailure",
|
||||
Label: "aFailure",
|
||||
Status: "error: plugin returned non-200 status code: 500 Internal Server Error",
|
||||
},
|
||||
{
|
||||
ID: "testPlugin",
|
||||
Label: "testPlugin",
|
||||
Interfaces: []string{"reporter"},
|
||||
APIVersion: "1",
|
||||
Status: "ok",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistryDiscoversNewPlugins(t *testing.T) {
|
||||
@@ -412,8 +425,6 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) {
|
||||
|
||||
r.Report()
|
||||
checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{
|
||||
{ID: ""},
|
||||
{ID: ""},
|
||||
{
|
||||
ID: "noInterface",
|
||||
Label: "noInterface",
|
||||
@@ -425,6 +436,16 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) {
|
||||
Interfaces: []string{"reporter"},
|
||||
Status: `error: spec must contain a label`,
|
||||
},
|
||||
{
|
||||
ID: "non200ResponseCode",
|
||||
Label: "non200ResponseCode",
|
||||
Status: "error: plugin returned non-200 status code: 500 Internal Server Error",
|
||||
},
|
||||
{
|
||||
ID: "nonJSONResponseBody",
|
||||
Label: "nonJSONResponseBody",
|
||||
Status: "error: decoding error: [pos 4]: json: expecting ull: got otJ",
|
||||
},
|
||||
{
|
||||
ID: "okPlugin",
|
||||
Label: "okPlugin",
|
||||
@@ -440,3 +461,43 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistryRejectsPluginResponsesWhichAreTooLarge(t *testing.T) {
|
||||
description := ""
|
||||
for i := 0; i < 129; i++ {
|
||||
description += "a"
|
||||
}
|
||||
response := fmt.Sprintf(
|
||||
`{
|
||||
"Plugins": [
|
||||
{
|
||||
"id": "foo",
|
||||
"label": "foo",
|
||||
"description": %q,
|
||||
"interfaces": ["reporter"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
description,
|
||||
)
|
||||
setup(t, mockPlugin{t: t, Name: "foo", Handler: stringHandler(http.StatusOK, response)}.file())
|
||||
oldMaxResponseBytes := maxResponseBytes
|
||||
maxResponseBytes = 128
|
||||
|
||||
defer func() {
|
||||
maxResponseBytes = oldMaxResponseBytes
|
||||
restore(t)
|
||||
}()
|
||||
|
||||
root := "/plugins"
|
||||
r, err := NewRegistry(root, "", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
r.Report()
|
||||
checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{
|
||||
{ID: "foo", Label: "foo", Status: `error: response must be shorter than 50MB`},
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user