diff --git a/probe/plugins/max_bytes_reader.go b/probe/plugins/max_bytes_reader.go new file mode 100644 index 000000000..4775b13d3 --- /dev/null +++ b/probe/plugins/max_bytes_reader.go @@ -0,0 +1,43 @@ +package plugins + +import ( + "io" +) + +// MaxBytesReader is similar to net/http.MaxBytesReader, but lets us choose how +// to handle an overflow by providing an error. net/http.MaxBytesReader uses +// net/http internals to render a naff error message. There are other +// discrepancies with how this detects overflows. Not sure if that will cause +// issues. If you want to use it as part of an HTTP server, it's probably best +// to change it so you can provide a callback func, which renders your error +// message, as returning an error into the middle of the net/http server will +// not be useful. +func MaxBytesReader(r io.ReadCloser, maxBytes int64, err error) io.ReadCloser { + if r == nil { + return nil + } + + return &maxBytesReader{ + ReadCloser: r, + bytesRemaining: maxBytes, + err: err, + } +} + +type maxBytesReader struct { + io.ReadCloser + bytesRemaining int64 + err error // Callback when overflowing +} + +func (r *maxBytesReader) Read(p []byte) (int, error) { + if r.bytesRemaining <= 0 { + return 0, r.err + } + if int64(len(p)) > r.bytesRemaining { + p = p[0:r.bytesRemaining] + } + n, err := r.ReadCloser.Read(p) + r.bytesRemaining -= int64(n) + return n, err +} diff --git a/probe/plugins/max_bytes_reader_internal_test.go b/probe/plugins/max_bytes_reader_internal_test.go new file mode 100644 index 000000000..4eb3f385d --- /dev/null +++ b/probe/plugins/max_bytes_reader_internal_test.go @@ -0,0 +1,89 @@ +package plugins + +import ( + "bytes" + "errors" + "io/ioutil" + "strings" + "testing" +) + +func TestMaxBytesReaderReturnsAllDataIfSmaller(t *testing.T) { + result, err := ioutil.ReadAll(MaxBytesReader(ioutil.NopCloser(strings.NewReader("some data")), 1024, errors.New("test error"))) + if err != nil { + t.Error(err) + } + if string(result) != "some data" { + t.Errorf("Expected %q, got %q", "some data", string(result)) + } +} + +func TestMaxBytesReaderReturnsNilIfNil(t *testing.T) { + result := MaxBytesReader(nil, 1024, errors.New("test error")) + if result != nil { + t.Errorf("Expected nil, got: %q", result) + } +} + +func TestMaxBytesReaderReturnsErrorIfLarger(t *testing.T) { + input := &bytes.Buffer{} + for i := int64(0); i <= 1024; i++ { + input.WriteByte(byte(i)) + } + + result, err := ioutil.ReadAll(MaxBytesReader(ioutil.NopCloser(input), 1024, errors.New("test error"))) + if err.Error() != "test error" { + t.Errorf("Expected error to be %q, got: %q", "test error", err.Error()) + } + if len(result) != 1024 { + t.Errorf("Expected result length to be 1024, but got: %d", len(result)) + } +} + +func TestMaxBytesReaderReturnsErrorIfLargerAndMassiveBufferGiven(t *testing.T) { + input := &bytes.Buffer{} + for i := int64(0); i <= 1024; i++ { + input.WriteByte(byte(i)) + } + + buffer := make([]byte, 1024+2) + reader := MaxBytesReader(ioutil.NopCloser(input), 1024, errors.New("test error")) + + // First read is scoped down to the maximum + readCount, err := reader.Read(buffer) + if err != nil { + t.Error(err) + } + if readCount != 1024 { + t.Errorf("Expected result length to be 1024, but got: %d", readCount) + } + + // Second read returns an error + readCount, err = reader.Read(buffer) + if err.Error() != "test error" { + t.Errorf("Expected error to be %q, got: %q", "test error", err.Error()) + } + if readCount != 0 { + t.Errorf("Expected result length to be 0, but got: %d", readCount) + } +} + +type testReadCloser struct { + closeError error +} + +func (c testReadCloser) Read(p []byte) (n int, err error) { + return 0, nil +} + +func (c testReadCloser) Close() error { + return c.closeError +} + +func TestMaxBytesReaderPassesThroughErrorsWhenClosing(t *testing.T) { + readcloser := testReadCloser{errors.New("test error")} + err := MaxBytesReader(readcloser, 1024, errors.New("overflow")).Close() + if err == nil || err.Error() != "test error" { + t.Errorf("Expected error to be %q, got: %q", "test error", err) + } +} diff --git a/probe/plugins/registry.go b/probe/plugins/registry.go index f4a12bb70..80612a08d 100644 --- a/probe/plugins/registry.go +++ b/probe/plugins/registry.go @@ -2,11 +2,11 @@ package plugins import ( "fmt" - "io" "net/http" "net/url" "path/filepath" "sort" + "strings" "sync" "syscall" "time" @@ -24,7 +24,9 @@ import ( // Exposed for testing var ( - transport = makeUnixRoundTripper + transport = makeUnixRoundTripper + maxResponseBytes int64 = 50 * 1024 * 1024 + errResponseTooLarge = fmt.Errorf("response must be shorter than 50MB") ) const ( @@ -220,8 +222,11 @@ func NewPlugin(ctx context.Context, socket string, client *http.Client, expected params.Add(k, v) } + id := strings.TrimSuffix(filepath.Base(socket), filepath.Ext(socket)) + ctx, cancel := context.WithCancel(ctx) return &Plugin{ + PluginSpec: xfer.PluginSpec{ID: id, Label: id}, context: ctx, socket: socket, expectedAPIVersion: expectedAPIVersion, @@ -234,16 +239,6 @@ func NewPlugin(ctx context.Context, socket string, client *http.Client, expected // Report gets the latest report from the plugin func (p *Plugin) Report() (result report.Report, err error) { result = report.MakeReport() - if err := p.get("/report", p.handshakeMetadata, &result); err != nil { - return result, err - } - if result.Plugins.Size() != 1 { - return result, fmt.Errorf("plugins: %s report must contain exactly one plugin (found %d)", p.socket, result.Plugins.Size()) - } - - key := result.Plugins.Keys()[0] - spec, _ := result.Plugins.Lookup(key) - p.PluginSpec = spec defer func() { p.setStatus(err) result.Plugins = result.Plugins.Add(p.PluginSpec) @@ -253,6 +248,17 @@ func (p *Plugin) Report() (result report.Report, err error) { } }() + if err := p.get("/report", p.handshakeMetadata, &result); err != nil { + return result, err + } + if result.Plugins.Size() != 1 { + return result, fmt.Errorf("report must contain exactly one plugin (found %d)", result.Plugins.Size()) + } + + key := result.Plugins.Keys()[0] + spec, _ := result.Plugins.Lookup(key) + p.PluginSpec = spec + foundReporter := false for _, i := range spec.Interfaces { if i == "reporter" { @@ -294,7 +300,11 @@ func (p *Plugin) get(path string, params url.Values, result interface{}) error { return fmt.Errorf("plugin returned non-200 status code: %s", resp.Status) } defer resp.Body.Close() - if err := codec.NewDecoder(io.LimitReader(resp.Body, 50*1024*1024), &codec.JsonHandle{}).Decode(&result); err != nil { + err = codec.NewDecoder(MaxBytesReader(resp.Body, maxResponseBytes, errResponseTooLarge), &codec.JsonHandle{}).Decode(&result) + if err == errResponseTooLarge { + return err + } + if err != nil { return fmt.Errorf("decoding error: %s", err) } return nil diff --git a/probe/plugins/registry_internal_test.go b/probe/plugins/registry_internal_test.go index e3500b5aa..6f60cc82b 100644 --- a/probe/plugins/registry_internal_test.go +++ b/probe/plugins/registry_internal_test.go @@ -219,7 +219,20 @@ func TestRegistryLoadsExistingPluginsEvenWhenOneFails(t *testing.T) { defer r.Close() r.Report() - checkLoadedPluginIDs(t, r.ForEach, []string{"", "testPlugin"}) + checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{ + { + ID: "aFailure", + Label: "aFailure", + Status: "error: plugin returned non-200 status code: 500 Internal Server Error", + }, + { + ID: "testPlugin", + Label: "testPlugin", + Interfaces: []string{"reporter"}, + APIVersion: "1", + Status: "ok", + }, + }) } func TestRegistryDiscoversNewPlugins(t *testing.T) { @@ -412,8 +425,6 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) { r.Report() checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{ - {ID: ""}, - {ID: ""}, { ID: "noInterface", Label: "noInterface", @@ -425,6 +436,16 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) { Interfaces: []string{"reporter"}, Status: `error: spec must contain a label`, }, + { + ID: "non200ResponseCode", + Label: "non200ResponseCode", + Status: "error: plugin returned non-200 status code: 500 Internal Server Error", + }, + { + ID: "nonJSONResponseBody", + Label: "nonJSONResponseBody", + Status: "error: decoding error: [pos 4]: json: expecting ull: got otJ", + }, { ID: "okPlugin", Label: "okPlugin", @@ -440,3 +461,43 @@ func TestRegistryRejectsErroneousPluginResponses(t *testing.T) { }, }) } + +func TestRegistryRejectsPluginResponsesWhichAreTooLarge(t *testing.T) { + description := "" + for i := 0; i < 129; i++ { + description += "a" + } + response := fmt.Sprintf( + `{ + "Plugins": [ + { + "id": "foo", + "label": "foo", + "description": %q, + "interfaces": ["reporter"] + } + ] + }`, + description, + ) + setup(t, mockPlugin{t: t, Name: "foo", Handler: stringHandler(http.StatusOK, response)}.file()) + oldMaxResponseBytes := maxResponseBytes + maxResponseBytes = 128 + + defer func() { + maxResponseBytes = oldMaxResponseBytes + restore(t) + }() + + root := "/plugins" + r, err := NewRegistry(root, "", nil) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + r.Report() + checkLoadedPlugins(t, r.ForEach, []xfer.PluginSpec{ + {ID: "foo", Label: "foo", Status: `error: response must be shorter than 50MB`}, + }) +}