Files
weave-scope/app/multitenant/sqs_control_router.go
Bryan Boreham 89363f5dcf Defer metrics registration until we need it
This avoids app-specific metrics appearing in the probe.
2019-07-04 14:24:22 +00:00

373 lines
10 KiB
Go

package multitenant
import (
"bytes"
"encoding/json"
"fmt"
"math/rand"
"sync"
"time"
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/weaveworks/common/instrument"
"github.com/weaveworks/scope/app"
"github.com/weaveworks/scope/common/xfer"
)
var (
// longPollTime is the value passed as WaitTimeSeconds on every
// ReceiveMessage call in this file.
longPollTime = aws.Int64(10)
// sqsRequestDuration is the histogram handed to
// instrument.TimeRequestHistogram for each SQS call made here; it is
// labelled by "method" and "status_code". It is only registered with
// Prometheus when registerSQSMetrics runs.
sqsRequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "scope",
Name: "sqs_request_duration_seconds",
Help: "Time in seconds spent doing SQS requests.",
Buckets: prometheus.DefBuckets,
}, []string{"method", "status_code"})
)
// registerSQSMetrics registers sqsRequestDuration with the default Prometheus
// registry. MustRegister panics on a duplicate registration, so this is only
// invoked through registerSQSMetricsOnce.
func registerSQSMetrics() {
prometheus.MustRegister(sqsRequestDuration)
}
// registerSQSMetricsOnce makes sure registerSQSMetrics runs at most once; it
// is triggered from NewSQSControlRouter, so the metric is not registered
// until a router is actually created.
var registerSQSMetricsOnce sync.Once
// sqsControlRouter routes control requests to probes over SQS.
//
// It creates a queue for every probe that connects to it (Register), and one
// queue for responses back to itself (loop). When it receives a request
// (Handle), it posts it to the probe queue. When the probe worker receives a
// request, it handles it and posts the response back to the response queue
// named in the request.
type sqsControlRouter struct {
// service is the SQS client used for every queue operation.
service *sqs.SQS
// responseQueueURL is nil until loop has created this app's response
// queue; access it via getResponseQueueURL/setResponseQueueURL.
responseQueueURL *string
// userIDer extracts the user ID from a request context.
userIDer UserIDer
// prefix is prepended to every queue name this router builds.
prefix string
// rpcTimeout bounds how long Handle waits for a probe's response.
rpcTimeout time.Duration
// mtx guards responseQueueURL, responses and probeWorkers.
mtx sync.Mutex
// responses maps an in-flight request ID to the channel its Handle call
// is waiting on.
responses map[string]chan xfer.Response
// probeWorkers maps the ID returned by Register to the worker serving
// that registration.
probeWorkers map[int64]*probeWorker
}
// sqsRequestMessage is the JSON body posted to a probe's queue by Handle.
type sqsRequestMessage struct {
// ID identifies the request; the probe worker copies it into the matching
// sqsResponseMessage.
ID string
// Request is the control request to pass to the probe's handler.
Request xfer.Request
// ResponseQueueURL is the queue the probe worker sends the response to.
ResponseQueueURL string
}
// sqsResponseMessage is the JSON body a probe worker posts back to the
// response queue named in the request.
type sqsResponseMessage struct {
// ID is the ID of the sqsRequestMessage being answered; handleResponses
// uses it to find the waiting Handle call.
ID string
// Response is the value returned by the probe's handler.
Response xfer.Response
}
// NewSQSControlRouter returns an app.ControlRouter that relays control
// requests and responses through SQS queues.
//
// config configures the AWS session, userIDer extracts the user ID from a
// request context, prefix is prepended to every queue name, and rpcTimeout
// bounds how long Handle waits for a probe's response. The first call also
// registers the SQS metrics. A background goroutine is started to create the
// response queue and poll it; until that queue exists, Handle returns an
// error.
func NewSQSControlRouter(config *aws.Config, userIDer UserIDer, prefix string, rpcTimeout time.Duration) app.ControlRouter {
	registerSQSMetricsOnce.Do(registerSQSMetrics)

	// responseQueueURL is left at its zero value (nil) until loop sets it.
	router := new(sqsControlRouter)
	router.service = sqs.New(session.New(config))
	router.userIDer = userIDer
	router.prefix = prefix
	router.rpcTimeout = rpcTimeout
	router.responses = make(map[string]chan xfer.Response)
	router.probeWorkers = make(map[int64]*probeWorker)

	go router.loop()
	return router
}
// Stop is a no-op that always returns nil. It does not stop the goroutine
// started by NewSQSControlRouter, nor any registered probe workers.
// NOTE(review): presumably present to satisfy app.ControlRouter — confirm.
func (cr *sqsControlRouter) Stop() error {
return nil
}
// setResponseQueueURL records the URL of this app's response queue while
// holding the router mutex.
func (cr *sqsControlRouter) setResponseQueueURL(url *string) {
	cr.mtx.Lock()
	cr.responseQueueURL = url
	cr.mtx.Unlock()
}
// getResponseQueueURL returns the URL of this app's response queue, read
// while holding the router mutex. The result is nil until the queue has been
// created.
func (cr *sqsControlRouter) getResponseQueueURL() *string {
	cr.mtx.Lock()
	url := cr.responseQueueURL
	cr.mtx.Unlock()
	return url
}
// getOrCreateQueue returns the URL of the SQS queue called name. It relies on
// CreateQueue creating the queue or, if it already exists, returning the URL
// of the existing one. The call is timed under the "SQS.CreateQueue" method
// label.
func (cr *sqsControlRouter) getOrCreateQueue(ctx context.Context, name string) (*string, error) {
	var output *sqs.CreateQueueOutput
	create := func(_ context.Context) error {
		var callErr error
		output, callErr = cr.service.CreateQueue(&sqs.CreateQueueInput{
			QueueName: aws.String(name),
		})
		return callErr
	}

	if err := instrument.TimeRequestHistogram(ctx, "SQS.CreateQueue", sqsRequestDuration, create); err != nil {
		return nil, err
	}
	return output.QueueUrl, nil
}
// loop is the router's background goroutine, started by NewSQSControlRouter.
//
// It first creates this app's response queue, retrying once a second until
// that succeeds, and publishes its URL via setResponseQueueURL. It then
// long-polls that queue forever: each received batch is deleted from the
// queue and passed to handleResponses. The function never returns.
func (cr *sqsControlRouter) loop() {
	ctx := context.Background()

	// This app has a random id and uses this as a return path for all
	// responses from probes. A fresh name is generated on every attempt.
	var responseQueueURL *string
	for {
		name := fmt.Sprintf("%scontrol-app-%d", cr.prefix, rand.Int63())
		url, err := cr.getOrCreateQueue(ctx, name)
		if err != nil {
			log.Errorf("Failed to create queue: %v", err)
			time.Sleep(1 * time.Second)
			continue
		}
		responseQueueURL = url
		break
	}
	cr.setResponseQueueURL(responseQueueURL)

	for {
		var res *sqs.ReceiveMessageOutput
		err := instrument.TimeRequestHistogram(ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
			var err error
			res, err = cr.service.ReceiveMessage(&sqs.ReceiveMessageInput{
				QueueUrl:        responseQueueURL,
				WaitTimeSeconds: longPollTime,
			})
			return err
		})
		if err != nil {
			log.Errorf("Error receiving message from %s: %v", *responseQueueURL, err)
			// Back off before retrying: a failing call can return
			// immediately, and retrying without a pause would spin this
			// goroutine and flood the log for as long as the error persists.
			time.Sleep(1 * time.Second)
			continue
		}
		if len(res.Messages) == 0 {
			continue
		}

		// Messages are deleted before they are handled; a failed delete is
		// only logged and the batch is still processed.
		if err := cr.deleteMessages(ctx, responseQueueURL, res.Messages); err != nil {
			log.Errorf("Error deleting message from %s: %v", *responseQueueURL, err)
		}
		cr.handleResponses(res)
	}
}
// deleteMessages removes the given messages from the queue at queueURL with a
// single DeleteMessageBatch call, timed under the "SQS.DeleteMessageBatch"
// method label. Only the error from the call itself is returned; the batch
// output is discarded.
func (cr *sqsControlRouter) deleteMessages(ctx context.Context, queueURL *string, messages []*sqs.Message) error {
	entries := make([]*sqs.DeleteMessageBatchRequestEntry, len(messages))
	for i, msg := range messages {
		entries[i] = &sqs.DeleteMessageBatchRequestEntry{
			Id:            msg.MessageId,
			ReceiptHandle: msg.ReceiptHandle,
		}
	}

	deleteBatch := func(_ context.Context) error {
		_, err := cr.service.DeleteMessageBatch(&sqs.DeleteMessageBatchInput{
			QueueUrl: queueURL,
			Entries:  entries,
		})
		return err
	}
	return instrument.TimeRequestHistogram(ctx, "SQS.DeleteMessageBatch", sqsRequestDuration, deleteBatch)
}
// handleResponses decodes every message in res as an sqsResponseMessage and
// hands the response to the Handle call waiting on that ID.
//
// Messages that have no body, fail to decode, or have no registered waiter
// are logged and dropped. The router mutex is held for the whole batch, so
// delivery to a waiter must never block: a blocked send here would keep the
// mutex locked and stall every other method that needs it.
func (cr *sqsControlRouter) handleResponses(res *sqs.ReceiveMessageOutput) {
	cr.mtx.Lock()
	defer cr.mtx.Unlock()

	for _, message := range res.Messages {
		if message == nil || message.Body == nil {
			log.Errorf("Error decoding message: %v", "message has no body")
			continue
		}

		var sqsResponse sqsResponseMessage
		if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsResponse); err != nil {
			log.Errorf("Error decoding message: %v", err)
			continue
		}

		waiter, ok := cr.responses[sqsResponse.ID]
		if !ok {
			log.Errorf("Dropping response %s - no one waiting for it!", sqsResponse.ID)
			continue
		}

		// Each waiter channel has room for exactly one response. If it is
		// already full (a duplicate delivery, or the waiter gave up after a
		// response was queued), drop this one instead of blocking while
		// holding the mutex.
		select {
		case waiter <- sqsResponse.Response:
		default:
			log.Errorf("Dropping response %s - waiter already has a response", sqsResponse.ID)
		}
	}
}
// sendMessage JSON-encodes message and posts it to the queue at queueURL. The
// SendMessage call is timed under the "SQS.SendMessage" method label. An
// encoding failure is returned before anything is sent.
func (cr *sqsControlRouter) sendMessage(ctx context.Context, queueURL *string, message interface{}) error {
	var encoded bytes.Buffer
	if err := json.NewEncoder(&encoded).Encode(message); err != nil {
		return err
	}

	body := encoded.String()
	log.Debugf("sendMessage to %s: %s", *queueURL, body)

	send := func(_ context.Context) error {
		_, err := cr.service.SendMessage(&sqs.SendMessageInput{
			QueueUrl:    queueURL,
			MessageBody: aws.String(body),
		})
		return err
	}
	return instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, send)
}
// Handle forwards req to the probe identified by probeID and waits for the
// probe's response.
//
// The user ID is taken from ctx; together with probeID and the router's
// prefix it names the probe's request queue. The request is posted to that
// queue tagged with a fresh ID and this app's response queue URL, and Handle
// then blocks until handleResponses delivers the matching response.
//
// It returns an error if the user ID cannot be determined, if the response
// queue has not been created yet, if the probe's queue cannot be looked up,
// if sending fails, if no response arrives within rpcTimeout, or if ctx is
// cancelled while waiting (in which case ctx.Err() is returned).
func (cr *sqsControlRouter) Handle(ctx context.Context, probeID string, req xfer.Request) (xfer.Response, error) {
	// Make sure we know the users
	userID, err := cr.userIDer(ctx)
	if err != nil {
		return xfer.Response{}, err
	}

	// Get the queue url for the local (control app) queue, and for the probe.
	responseQueueURL := cr.getResponseQueueURL()
	if responseQueueURL == nil {
		return xfer.Response{}, fmt.Errorf("no SQS queue yet")
	}

	var probeQueueURL *sqs.GetQueueUrlOutput
	err = instrument.TimeRequestHistogram(ctx, "SQS.GetQueueUrl", sqsRequestDuration, func(_ context.Context) error {
		probeQueueName := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
		var err error
		probeQueueURL, err = cr.service.GetQueueUrl(&sqs.GetQueueUrlInput{
			QueueName: aws.String(probeQueueName),
		})
		return err
	})
	if err != nil {
		return xfer.Response{}, err
	}

	// Add a response channel before we send the request, to prevent races.
	// The channel is buffered so the response can be delivered without the
	// sender having to wait for this goroutine.
	id := fmt.Sprintf("request-%s-%d", userID, rand.Int63())
	waiter := make(chan xfer.Response, 1)
	cr.mtx.Lock()
	cr.responses[id] = waiter
	cr.mtx.Unlock()
	defer func() {
		cr.mtx.Lock()
		delete(cr.responses, id)
		cr.mtx.Unlock()
	}()

	// Next, send the request to that queue. sendMessage already times the
	// call under "SQS.SendMessage"; wrapping it in another timer here would
	// record every send twice in the histogram.
	if err := cr.sendMessage(ctx, probeQueueURL.QueueUrl, sqsRequestMessage{
		ID:               id,
		Request:          req,
		ResponseQueueURL: *responseQueueURL,
	}); err != nil {
		return xfer.Response{}, err
	}

	// Finally, wait for a response on our queue. The timer is stopped on
	// return so it does not linger after an early response.
	timeout := time.NewTimer(cr.rpcTimeout)
	defer timeout.Stop()
	select {
	case response := <-waiter:
		return response, nil
	case <-timeout.C:
		return xfer.Response{}, fmt.Errorf("request timed out")
	case <-ctx.Done():
		// The caller has gone away; there is no point waiting any longer.
		return xfer.Response{}, ctx.Err()
	}
}
// Register sets up control handling for the probe identified by probeID.
//
// It gets or creates the probe's request queue (named from the router prefix,
// the user ID taken from ctx, and probeID), starts a probeWorker goroutine
// that polls that queue and passes each request to handler, and records the
// worker under a random ID. That ID is returned and is what Deregister
// expects. An error is returned if the user ID or the queue cannot be
// obtained; no worker is started in that case.
func (cr *sqsControlRouter) Register(ctx context.Context, probeID string, handler xfer.ControlHandlerFunc) (int64, error) {
	userID, err := cr.userIDer(ctx)
	if err != nil {
		return 0, err
	}

	queueName := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
	requestQueueURL, err := cr.getOrCreateQueue(ctx, queueName)
	if err != nil {
		return 0, err
	}

	workerID := rand.Int63()
	worker := &probeWorker{
		ctx:             ctx,
		router:          cr,
		requestQueueURL: requestQueueURL,
		handler:         handler,
		quit:            make(chan struct{}),
	}

	// Count the goroutine before starting it so stop() can wait for it.
	worker.done.Add(1)
	go worker.loop()

	cr.mtx.Lock()
	cr.probeWorkers[workerID] = worker
	cr.mtx.Unlock()

	return workerID, nil
}
// Deregister removes the worker recorded under id (as returned by Register)
// and, if there was one, stops it and waits for its goroutine to finish. The
// context and probeID arguments are not used. It always returns nil, even
// when id is unknown.
func (cr *sqsControlRouter) Deregister(_ context.Context, probeID string, id int64) error {
	cr.mtx.Lock()
	worker, found := cr.probeWorkers[id]
	delete(cr.probeWorkers, id)
	cr.mtx.Unlock()

	if !found {
		return nil
	}

	// Stop the worker outside the lock: stop() blocks until the worker's
	// loop has exited.
	worker.stop()
	return nil
}
// a probeWorker encapsulates a goroutine serving a probe's websocket connection.
// It is created and started by Register and stopped by Deregister.
type probeWorker struct {
// ctx is the context given to Register; it is passed to every
// instrumented SQS call the worker makes.
ctx context.Context
// router is the owning router, used for its SQS client and helpers.
router *sqsControlRouter
// requestQueueURL is the probe's request queue, which loop polls.
requestQueueURL *string
// handler is called with each request decoded from the queue.
handler xfer.ControlHandlerFunc
// quit is closed by stop to ask loop to exit.
quit chan struct{}
// done is released when loop returns; stop waits on it.
done sync.WaitGroup
}
// stop signals the worker's loop to exit by closing quit, then blocks until
// the loop has returned. loop only checks quit at the top of each iteration,
// so this can wait for an in-progress receive to finish. It must be called at
// most once per worker, because closing quit a second time would panic.
func (pw *probeWorker) stop() {
close(pw.quit)
pw.done.Wait()
}
// loop is the worker goroutine started by Register. Until quit is closed it
// long-polls the probe's request queue, deletes each received batch from the
// queue, passes every decodable request to the handler and posts the
// handler's response to the response queue named in the request.
//
// Failures to receive, delete, decode or send are logged and do not end the
// loop. After a receive error the worker pauses for a second (or until it is
// told to quit) so that a persistent error cannot turn into a busy loop.
func (pw *probeWorker) loop() {
	defer pw.done.Done()

	for {
		// have we been stopped?
		select {
		case <-pw.quit:
			return
		default:
		}

		var res *sqs.ReceiveMessageOutput
		err := instrument.TimeRequestHistogram(pw.ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
			var err error
			res, err = pw.router.service.ReceiveMessage(&sqs.ReceiveMessageInput{
				QueueUrl:        pw.requestQueueURL,
				WaitTimeSeconds: longPollTime,
			})
			return err
		})
		if err != nil {
			log.Errorf("Error receiving message: %v", err)
			// Back off before retrying, but stay responsive to stop().
			select {
			case <-pw.quit:
				return
			case <-time.After(1 * time.Second):
			}
			continue
		}
		if len(res.Messages) == 0 {
			continue
		}

		// Messages are deleted before they are handled; a failed delete is
		// only logged and the batch is still processed.
		if err := pw.router.deleteMessages(pw.ctx, pw.requestQueueURL, res.Messages); err != nil {
			log.Errorf("Error deleting message from %s: %v", *pw.requestQueueURL, err)
		}

		for _, message := range res.Messages {
			if message == nil || message.Body == nil {
				log.Errorf("Error decoding message from %s: %v", *pw.requestQueueURL, "message has no body")
				continue
			}

			var sqsRequest sqsRequestMessage
			if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsRequest); err != nil {
				log.Errorf("Error decoding message from %s: %v", *pw.requestQueueURL, err)
				continue
			}

			response := pw.handler(sqsRequest.Request)
			if err := pw.router.sendMessage(pw.ctx, &sqsRequest.ResponseQueueURL, sqsResponseMessage{
				ID:       sqsRequest.ID,
				Response: response,
			}); err != nil {
				log.Errorf("Error sending response: %v", err)
			}
		}
	}
}