Better error message when plugin responses are too large

This commit is contained in:
Paul Bellamy
2016-04-18 11:01:32 +01:00
parent 4b41c48e6f
commit f17c8e6bae
4 changed files with 219 additions and 16 deletions

View File

@@ -0,0 +1,43 @@
package plugins
import (
"io"
)
// MaxBytesReader is similar to net/http.MaxBytesReader, but lets us choose how
// to handle an overflow by providing an error. net/http.MaxBytesReader uses
// net/http internals to render a naff error message. There are other
// discrepancies with how this detects overflows. Not sure if that will cause
// issues. If you want to use it as part of an HTTP server, it's probably best
// to change it so you can provide a callback func, which renders your error
// message, as returning an error into the middle of the net/http server will
// not be useful.
func MaxBytesReader(r io.ReadCloser, maxBytes int64, err error) io.ReadCloser {
if r == nil {
return nil
}
return &maxBytesReader{
ReadCloser: r,
bytesRemaining: maxBytes,
err: err,
}
}
type maxBytesReader struct {
io.ReadCloser
bytesRemaining int64
err error // Callback when overflowing
}
func (r *maxBytesReader) Read(p []byte) (int, error) {
if r.bytesRemaining <= 0 {
return 0, r.err
}
if int64(len(p)) > r.bytesRemaining {
p = p[0:r.bytesRemaining]
}
n, err := r.ReadCloser.Read(p)
r.bytesRemaining -= int64(n)
return n, err
}

View File

@@ -0,0 +1,89 @@
package plugins
import (
"bytes"
"errors"
"io/ioutil"
"strings"
"testing"
)
func TestMaxBytesReaderReturnsAllDataIfSmaller(t *testing.T) {
result, err := ioutil.ReadAll(MaxBytesReader(ioutil.NopCloser(strings.NewReader("some data")), 1024, errors.New("test error")))
if err != nil {
t.Error(err)
}
if string(result) != "some data" {
t.Errorf("Expected %q, got %q", "some data", string(result))
}
}
func TestMaxBytesReaderReturnsNilIfNil(t *testing.T) {
result := MaxBytesReader(nil, 1024, errors.New("test error"))
if result != nil {
t.Errorf("Expected nil, got: %q", result)
}
}
func TestMaxBytesReaderReturnsErrorIfLarger(t *testing.T) {
input := &bytes.Buffer{}
for i := int64(0); i <= 1024; i++ {
input.WriteByte(byte(i))
}
result, err := ioutil.ReadAll(MaxBytesReader(ioutil.NopCloser(input), 1024, errors.New("test error")))
if err.Error() != "test error" {
t.Errorf("Expected error to be %q, got: %q", "test error", err.Error())
}
if len(result) != 1024 {
t.Errorf("Expected result length to be 1024, but got: %d", len(result))
}
}
func TestMaxBytesReaderReturnsErrorIfLargerAndMassiveBufferGiven(t *testing.T) {
input := &bytes.Buffer{}
for i := int64(0); i <= 1024; i++ {
input.WriteByte(byte(i))
}
buffer := make([]byte, 1024+2)
reader := MaxBytesReader(ioutil.NopCloser(input), 1024, errors.New("test error"))
// First read is scoped down to the maximum
readCount, err := reader.Read(buffer)
if err != nil {
t.Error(err)
}
if readCount != 1024 {
t.Errorf("Expected result length to be 1024, but got: %d", readCount)
}
// Second read returns an error
readCount, err = reader.Read(buffer)
if err.Error() != "test error" {
t.Errorf("Expected error to be %q, got: %q", "test error", err.Error())
}
if readCount != 0 {
t.Errorf("Expected result length to be 0, but got: %d", readCount)
}
}
type testReadCloser struct {
closeError error
}
func (c testReadCloser) Read(p []byte) (n int, err error) {
return 0, nil
}
func (c testReadCloser) Close() error {
return c.closeError
}
func TestMaxBytesReaderPassesThroughErrorsWhenClosing(t *testing.T) {
readcloser := testReadCloser{errors.New("test error")}
err := MaxBytesReader(readcloser, 1024, errors.New("overflow")).Close()
if err == nil || err.Error() != "test error" {
t.Errorf("Expected error to be %q, got: %q", "test error", err)
}
}

View File

@@ -2,11 +2,11 @@ package plugins
import (
"fmt"
"io"
"net/http"
"net/url"
"path/filepath"
"sort"
"strings"
"sync"
"syscall"
"time"
@@ -24,7 +24,9 @@ import (
// Exposed for testing
var (
transport = makeUnixRoundTripper
transport = makeUnixRoundTripper
maxResponseBytes int64 = 50 * 1024 * 1024
errResponseTooLarge = fmt.Errorf("response must be shorter than 50MB")
)
const (
@@ -220,8 +222,11 @@ func NewPlugin(ctx context.Context, socket string, client *http.Client, expected
params.Add(k, v)
}
id := strings.TrimSuffix(filepath.Base(socket), filepath.Ext(socket))
ctx, cancel := context.WithCancel(ctx)
return &Plugin{
PluginSpec: xfer.PluginSpec{ID: id, Label: id},
context: ctx,
socket: socket,
expectedAPIVersion: expectedAPIVersion,
@@ -234,16 +239,6 @@ func NewPlugin(ctx context.Context, socket string, client *http.Client, expected
// Report gets the latest report from the plugin
func (p *Plugin) Report() (result report.Report, err error) {
result = report.MakeReport()
if err := p.get("/report", p.handshakeMetadata, &result); err != nil {
return result, err
}
if result.Plugins.Size() != 1 {
return result, fmt.Errorf("plugins: %s report must contain exactly one plugin (found %d)", p.socket, result.Plugins.Size())
}
key := result.Plugins.Keys()[0]
spec, _ := result.Plugins.Lookup(key)
p.PluginSpec = spec
defer func() {
p.setStatus(err)
result.Plugins = result.Plugins.Add(p.PluginSpec)
@@ -253,6 +248,17 @@ func (p *Plugin) Report() (result report.Report, err error) {
}
}()
if err := p.get("/report", p.handshakeMetadata, &result); err != nil {
return result, err
}
if result.Plugins.Size() != 1 {
return result, fmt.Errorf("report must contain exactly one plugin (found %d)", result.Plugins.Size())
}
key := result.Plugins.Keys()[0]
spec, _ := result.Plugins.Lookup(key)
p.PluginSpec = spec
foundReporter := false
for _, i := range spec.Interfaces {
if i == "reporter" {
@@ -294,7 +300,11 @@ func (p *Plugin) get(path string, params url.Values, result interface{}) error {
return fmt.Errorf("plugin returned non-200 status code: %s", resp.Status)
}
defer resp.Body.Close()
if err := codec.NewDecoder(io.LimitReader(resp.Body, 50*1024*1024), &codec.JsonHandle{}).Decode(&result); err != nil {
err = codec.NewDecoder(MaxBytesReader(resp.Body, maxResponseBytes, errResponseTooLarge), &codec.JsonHandle{}).Decode(&result)
if err == errResponseTooLarge {
return err
}
if err != nil {
return fmt.Errorf("decoding error: %s", err)
}
return nil

View File

@@ -219,7 +219,20 @@ func TestRegistryLoadsExistingPluginsEvenWhenOneFails(t *testing.T) {
defer r.Close()
r.Report()
checkLoadedPluginIDs(t, r.ForEach, []string{"", "testPlugin"})
checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{
{
ID: "aFailure",
Label: "aFailure",
Status: "error: plugin returned non-200 status code: 500 Internal Server Error",
},
{
ID: "testPlugin",
Label: "testPlugin",
Interfaces: []string{"reporter"},
APIVersion: "1",
Status: "ok",
},
})
}
func TestRegistryDiscoversNewPlugins(t *testing.T) {
@@ -412,8 +425,6 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) {
r.Report()
checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{
{ID: ""},
{ID: ""},
{
ID: "noInterface",
Label: "noInterface",
@@ -425,6 +436,16 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) {
Interfaces: []string{"reporter"},
Status: `error: spec must contain a label`,
},
{
ID: "non200ResponseCode",
Label: "non200ResponseCode",
Status: "error: plugin returned non-200 status code: 500 Internal Server Error",
},
{
ID: "nonJSONResponseBody",
Label: "nonJSONResponseBody",
Status: "error: decoding error: [pos 4]: json: expecting ull: got otJ",
},
{
ID: "okPlugin",
Label: "okPlugin",
@@ -440,3 +461,43 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) {
},
})
}
func TestRegistryRejectsPluginResponsesWhichAreTooLarge(t *testing.T) {
description := ""
for i := 0; i < 129; i++ {
description += "a"
}
response := fmt.Sprintf(
`{
"Plugins": [
{
"id": "foo",
"label": "foo",
"description": %q,
"interfaces": ["reporter"]
}
]
}`,
description,
)
setup(t, mockPlugin{t: t, Name: "foo", Handler: stringHandler(http.StatusOK, response)}.file())
oldMaxResponseBytes := maxResponseBytes
maxResponseBytes = 128
defer func() {
maxResponseBytes = oldMaxResponseBytes
restore(t)
}()
root := "/plugins"
r, err := NewRegistry(root, "", nil)
if err != nil {
t.Fatal(err)
}
defer r.Close()
r.Report()
checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{
{ID: "foo", Label: "foo", Status: `error: response must be shorter than 50MB`},
})
}