mirror of
https://github.com/weaveworks/scope.git
synced 2026-02-14 18:09:59 +00:00
373 lines
10 KiB
Go
373 lines
10 KiB
Go
package multitenant
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/rand"
|
|
"sync"
|
|
"time"
|
|
|
|
"context"
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/sqs"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/weaveworks/common/instrument"
|
|
"github.com/weaveworks/scope/app"
|
|
"github.com/weaveworks/scope/common/xfer"
|
|
)
|
|
|
|
var (
|
|
longPollTime = aws.Int64(10)
|
|
sqsRequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
|
Namespace: "scope",
|
|
Name: "sqs_request_duration_seconds",
|
|
Help: "Time in seconds spent doing SQS requests.",
|
|
Buckets: prometheus.DefBuckets,
|
|
}, []string{"method", "status_code"})
|
|
)
|
|
|
|
func registerSQSMetrics() {
|
|
prometheus.MustRegister(sqsRequestDuration)
|
|
}
|
|
|
|
// registerSQSMetricsOnce ensures the SQS metrics are registered a single
// time, however many routers are constructed.
var registerSQSMetricsOnce sync.Once
|
|
|
|
// sqsControlRouter:
|
|
// Creates a queue for every probe that connects to it, and a queue for
|
|
// responses back to it. When it receives a request, posts it to the
|
|
// probe queue. When probe receives a request, handles it and posts the
|
|
// response back to the response queue.
|
|
type sqsControlRouter struct {
|
|
service *sqs.SQS
|
|
responseQueueURL *string
|
|
userIDer UserIDer
|
|
prefix string
|
|
rpcTimeout time.Duration
|
|
|
|
mtx sync.Mutex
|
|
responses map[string]chan xfer.Response
|
|
probeWorkers map[int64]*probeWorker
|
|
}
|
|
|
|
type sqsRequestMessage struct {
|
|
ID string
|
|
Request xfer.Request
|
|
ResponseQueueURL string
|
|
}
|
|
|
|
type sqsResponseMessage struct {
|
|
ID string
|
|
Response xfer.Response
|
|
}
|
|
|
|
// NewSQSControlRouter the harbinger of death
|
|
func NewSQSControlRouter(config *aws.Config, userIDer UserIDer, prefix string, rpcTimeout time.Duration) app.ControlRouter {
|
|
registerSQSMetricsOnce.Do(registerSQSMetrics)
|
|
result := &sqsControlRouter{
|
|
service: sqs.New(session.New(config)),
|
|
responseQueueURL: nil,
|
|
userIDer: userIDer,
|
|
prefix: prefix,
|
|
rpcTimeout: rpcTimeout,
|
|
responses: map[string]chan xfer.Response{},
|
|
probeWorkers: map[int64]*probeWorker{},
|
|
}
|
|
go result.loop()
|
|
return result
|
|
}
|
|
|
|
func (cr *sqsControlRouter) Stop() error {
|
|
return nil
|
|
}
|
|
|
|
func (cr *sqsControlRouter) setResponseQueueURL(url *string) {
|
|
cr.mtx.Lock()
|
|
defer cr.mtx.Unlock()
|
|
cr.responseQueueURL = url
|
|
}
|
|
|
|
func (cr *sqsControlRouter) getResponseQueueURL() *string {
|
|
cr.mtx.Lock()
|
|
defer cr.mtx.Unlock()
|
|
return cr.responseQueueURL
|
|
}
|
|
|
|
func (cr *sqsControlRouter) getOrCreateQueue(ctx context.Context, name string) (*string, error) {
|
|
// CreateQueue creates a queue or if it already exists, returns url of said queue
|
|
var createQueueRes *sqs.CreateQueueOutput
|
|
var err error
|
|
err = instrument.TimeRequestHistogram(ctx, "SQS.CreateQueue", sqsRequestDuration, func(_ context.Context) error {
|
|
createQueueRes, err = cr.service.CreateQueue(&sqs.CreateQueueInput{
|
|
QueueName: aws.String(name),
|
|
})
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return createQueueRes.QueueUrl, nil
|
|
}
|
|
|
|
func (cr *sqsControlRouter) loop() {
|
|
var (
|
|
responseQueueURL *string
|
|
err error
|
|
ctx = context.Background()
|
|
)
|
|
for {
|
|
// This app has a random id and uses this as a return path for all responses from probes.
|
|
name := fmt.Sprintf("%scontrol-app-%d", cr.prefix, rand.Int63())
|
|
responseQueueURL, err = cr.getOrCreateQueue(ctx, name)
|
|
if err != nil {
|
|
log.Errorf("Failed to create queue: %v", err)
|
|
time.Sleep(1 * time.Second)
|
|
continue
|
|
}
|
|
cr.setResponseQueueURL(responseQueueURL)
|
|
break
|
|
}
|
|
|
|
for {
|
|
var res *sqs.ReceiveMessageOutput
|
|
var err error
|
|
err = instrument.TimeRequestHistogram(ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
|
|
res, err = cr.service.ReceiveMessage(&sqs.ReceiveMessageInput{
|
|
QueueUrl: responseQueueURL,
|
|
WaitTimeSeconds: longPollTime,
|
|
})
|
|
return err
|
|
})
|
|
if err != nil {
|
|
log.Errorf("Error receiving message from %s: %v", *responseQueueURL, err)
|
|
continue
|
|
}
|
|
|
|
if len(res.Messages) == 0 {
|
|
continue
|
|
}
|
|
if err := cr.deleteMessages(ctx, responseQueueURL, res.Messages); err != nil {
|
|
log.Errorf("Error deleting message from %s: %v", *responseQueueURL, err)
|
|
}
|
|
cr.handleResponses(res)
|
|
}
|
|
}
|
|
|
|
func (cr *sqsControlRouter) deleteMessages(ctx context.Context, queueURL *string, messages []*sqs.Message) error {
|
|
entries := []*sqs.DeleteMessageBatchRequestEntry{}
|
|
for _, message := range messages {
|
|
entries = append(entries, &sqs.DeleteMessageBatchRequestEntry{
|
|
ReceiptHandle: message.ReceiptHandle,
|
|
Id: message.MessageId,
|
|
})
|
|
}
|
|
return instrument.TimeRequestHistogram(ctx, "SQS.DeleteMessageBatch", sqsRequestDuration, func(_ context.Context) error {
|
|
_, err := cr.service.DeleteMessageBatch(&sqs.DeleteMessageBatchInput{
|
|
QueueUrl: queueURL,
|
|
Entries: entries,
|
|
})
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (cr *sqsControlRouter) handleResponses(res *sqs.ReceiveMessageOutput) {
|
|
cr.mtx.Lock()
|
|
defer cr.mtx.Unlock()
|
|
|
|
for _, message := range res.Messages {
|
|
var sqsResponse sqsResponseMessage
|
|
if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsResponse); err != nil {
|
|
log.Errorf("Error decoding message: %v", err)
|
|
continue
|
|
}
|
|
|
|
waiter, ok := cr.responses[sqsResponse.ID]
|
|
if !ok {
|
|
log.Errorf("Dropping response %s - no one waiting for it!", sqsResponse.ID)
|
|
continue
|
|
}
|
|
waiter <- sqsResponse.Response
|
|
}
|
|
}
|
|
|
|
func (cr *sqsControlRouter) sendMessage(ctx context.Context, queueURL *string, message interface{}) error {
|
|
buf := bytes.Buffer{}
|
|
if err := json.NewEncoder(&buf).Encode(message); err != nil {
|
|
return err
|
|
}
|
|
log.Debugf("sendMessage to %s: %s", *queueURL, buf.String())
|
|
|
|
return instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, func(_ context.Context) error {
|
|
_, err := cr.service.SendMessage(&sqs.SendMessageInput{
|
|
QueueUrl: queueURL,
|
|
MessageBody: aws.String(buf.String()),
|
|
})
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (cr *sqsControlRouter) Handle(ctx context.Context, probeID string, req xfer.Request) (xfer.Response, error) {
|
|
// Make sure we know the users
|
|
userID, err := cr.userIDer(ctx)
|
|
if err != nil {
|
|
return xfer.Response{}, err
|
|
}
|
|
|
|
// Get the queue url for the local (control app) queue, and for the probe.
|
|
responseQueueURL := cr.getResponseQueueURL()
|
|
if responseQueueURL == nil {
|
|
return xfer.Response{}, fmt.Errorf("no SQS queue yet")
|
|
}
|
|
|
|
var probeQueueURL *sqs.GetQueueUrlOutput
|
|
err = instrument.TimeRequestHistogram(ctx, "SQS.GetQueueUrl", sqsRequestDuration, func(_ context.Context) error {
|
|
probeQueueName := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
|
|
probeQueueURL, err = cr.service.GetQueueUrl(&sqs.GetQueueUrlInput{
|
|
QueueName: aws.String(probeQueueName),
|
|
})
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return xfer.Response{}, err
|
|
}
|
|
|
|
// Add a response channel before we send the request, to prevent races
|
|
id := fmt.Sprintf("request-%s-%d", userID, rand.Int63())
|
|
waiter := make(chan xfer.Response, 1)
|
|
cr.mtx.Lock()
|
|
cr.responses[id] = waiter
|
|
cr.mtx.Unlock()
|
|
defer func() {
|
|
cr.mtx.Lock()
|
|
delete(cr.responses, id)
|
|
cr.mtx.Unlock()
|
|
}()
|
|
|
|
// Next, send the request to that queue
|
|
if err := instrument.TimeRequestHistogram(ctx, "SQS.SendMessage", sqsRequestDuration, func(ctx context.Context) error {
|
|
return cr.sendMessage(ctx, probeQueueURL.QueueUrl, sqsRequestMessage{
|
|
ID: id,
|
|
Request: req,
|
|
ResponseQueueURL: *responseQueueURL,
|
|
})
|
|
}); err != nil {
|
|
return xfer.Response{}, err
|
|
}
|
|
|
|
// Finally, wait for a response on our queue
|
|
select {
|
|
case response := <-waiter:
|
|
return response, nil
|
|
case <-time.After(cr.rpcTimeout):
|
|
return xfer.Response{}, fmt.Errorf("request timed out")
|
|
}
|
|
}
|
|
|
|
func (cr *sqsControlRouter) Register(ctx context.Context, probeID string, handler xfer.ControlHandlerFunc) (int64, error) {
|
|
userID, err := cr.userIDer(ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
name := fmt.Sprintf("%sprobe-%s-%s", cr.prefix, userID, probeID)
|
|
queueURL, err := cr.getOrCreateQueue(ctx, name)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
pwID := rand.Int63()
|
|
pw := &probeWorker{
|
|
ctx: ctx,
|
|
router: cr,
|
|
requestQueueURL: queueURL,
|
|
handler: handler,
|
|
quit: make(chan struct{}),
|
|
}
|
|
pw.done.Add(1)
|
|
go pw.loop()
|
|
|
|
cr.mtx.Lock()
|
|
defer cr.mtx.Unlock()
|
|
cr.probeWorkers[pwID] = pw
|
|
return pwID, nil
|
|
}
|
|
|
|
func (cr *sqsControlRouter) Deregister(_ context.Context, probeID string, id int64) error {
|
|
cr.mtx.Lock()
|
|
pw, ok := cr.probeWorkers[id]
|
|
delete(cr.probeWorkers, id)
|
|
cr.mtx.Unlock()
|
|
if ok {
|
|
pw.stop()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// a probeWorker encapsulates a goroutine serving a probe's websocket connection.
|
|
type probeWorker struct {
|
|
ctx context.Context
|
|
router *sqsControlRouter
|
|
requestQueueURL *string
|
|
handler xfer.ControlHandlerFunc
|
|
quit chan struct{}
|
|
done sync.WaitGroup
|
|
}
|
|
|
|
func (pw *probeWorker) stop() {
|
|
close(pw.quit)
|
|
pw.done.Wait()
|
|
}
|
|
|
|
func (pw *probeWorker) loop() {
|
|
defer pw.done.Done()
|
|
for {
|
|
// have we been stopped?
|
|
select {
|
|
case <-pw.quit:
|
|
return
|
|
default:
|
|
}
|
|
|
|
var res *sqs.ReceiveMessageOutput
|
|
var err error
|
|
err = instrument.TimeRequestHistogram(pw.ctx, "SQS.ReceiveMessage", sqsRequestDuration, func(_ context.Context) error {
|
|
res, err = pw.router.service.ReceiveMessage(&sqs.ReceiveMessageInput{
|
|
QueueUrl: pw.requestQueueURL,
|
|
WaitTimeSeconds: longPollTime,
|
|
})
|
|
return err
|
|
})
|
|
if err != nil {
|
|
log.Errorf("Error receiving message: %v", err)
|
|
continue
|
|
}
|
|
|
|
if len(res.Messages) == 0 {
|
|
continue
|
|
}
|
|
if err := pw.router.deleteMessages(pw.ctx, pw.requestQueueURL, res.Messages); err != nil {
|
|
log.Errorf("Error deleting message from %s: %v", *pw.requestQueueURL, err)
|
|
}
|
|
|
|
for _, message := range res.Messages {
|
|
var sqsRequest sqsRequestMessage
|
|
if err := json.NewDecoder(bytes.NewBufferString(*message.Body)).Decode(&sqsRequest); err != nil {
|
|
log.Errorf("Error decoding message from: %v", err)
|
|
continue
|
|
}
|
|
|
|
response := pw.handler(sqsRequest.Request)
|
|
|
|
if err := pw.router.sendMessage(pw.ctx, &sqsRequest.ResponseQueueURL, sqsResponseMessage{
|
|
ID: sqsRequest.ID,
|
|
Response: response,
|
|
}); err != nil {
|
|
log.Errorf("Error sending response: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|