diff --git a/broker/ipc.go b/broker/ipc.go index d0b4d474716ad426e89fbb70ba2cf364c541cc21..f4f7f6b7e2a1643bf0b05cefd8cce351ebc52fc1 100644 --- a/broker/ipc.go +++ b/broker/ipc.go @@ -214,7 +214,9 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { resp := &messages.ClientPollResponse{Answer: answer} err = sendClientResponse(resp, response) // Initial tracking of elapsed time. + i.ctx.metrics.lock.Lock() i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond + i.ctx.metrics.lock.Unlock() case <-time.After(time.Second * ClientTimeout): log.Println("Client: Timed out.") resp := &messages.ClientPollResponse{Error: messages.StrTimedOut} diff --git a/broker/sqs.go b/broker/sqs.go index 614dafeaf862f8226a7db1e323962a6c9eac2c1a..fb1164ec4480a4e55d0ab8d297fc0d9fb89a589f 100644 --- a/broker/sqs.go +++ b/broker/sqs.go @@ -213,7 +213,7 @@ func newSQSHandler(context context.Context, client sqsclient.SQSClient, sqsQueue func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) { log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL) - messagesChn := make(chan *types.Message, 2) + messagesChn := make(chan *types.Message, 20) go r.pollMessages(ctx, messagesChn) go r.cleanupClientQueues(ctx) @@ -223,8 +223,10 @@ func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) { // if context is cancelled return default: - r.handleMessage(ctx, message) - r.deleteMessage(ctx, message) + go func(msg *types.Message) { + r.handleMessage(ctx, msg) + r.deleteMessage(ctx, msg) + }(message) } } } diff --git a/broker/sqs_test.go b/broker/sqs_test.go index 7c7039082274fd8cc8557469b0c315b4e40e51f2..ab2e761c35cf46435082a243b897a3c4d9240a3a 100644 --- a/broker/sqs_test.go +++ b/broker/sqs_test.go @@ -7,6 +7,7 @@ import ( "log" "strconv" "sync" + "sync/atomic" "testing" "time" @@ -25,9 +26,6 @@ func TestSQS(t *testing.T) { ipcCtx := NewBrokerContext(log.New(buf, "", 0), "", "") i := &IPC{ipcCtx} - var logBuffer bytes.Buffer - log.SetOutput(&logBuffer) - Convey("Responds to SQS client offers...", func() { ctrl := gomock.NewController(t) mockSQSClient := sqsclient.NewMockSQSClient(ctrl) @@ -65,12 +63,7 @@ func TestSQS(t *testing.T) { } Convey("by ignoring it if no client id specified", func(c C) { - var wg sync.WaitGroup - wg.Add(1) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { return &sqs.ReceiveMessageOutput{ @@ -83,41 +76,32 @@ func TestSQS(t *testing.T) { }, nil }, ) - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).Times(1).Do( + mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(1).Do( func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { - defer wg.Done() - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() + sqsCancelFunc() }, ) + // We expect no queues to be created + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0) runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) { - var wg sync.WaitGroup - wg.Add(2) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - numTimes := 0 - // When ReceiveMessage is called for the first time, the error has not had a chance to be logged yet. - // Therefore, we opt to wait for the second call because we are guaranteed that the error was logged - // by then. - mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(2).DoAndReturn( + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - numTimes += 1 - if numTimes <= 2 { - wg.Done() - if numTimes == 2 { - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: encountered error while polling for messages: error") - } - } + sqsCancelFunc() return nil, errors.New("error") }, ) + // We expect no queues to be created or deleted + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0) + mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).Times(0) runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("by attempting to create a new sqs queue...", func() { @@ -125,68 +109,53 @@ func TestSQS(t *testing.T) { sqsCreateQueueInput := sqs.CreateQueueInput{ QueueName: aws.String("snowflake-client-fake-id"), } - - expectReceiveMessageReturnsValidMessage := func(sqsHandlerContext context.Context) { - mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( - func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - return &sqs.ReceiveMessageOutput{ - Messages: []types.Message{ - { - Body: messageBody, - MessageAttributes: map[string]types.MessageAttributeValue{ - "ClientID": {StringValue: &clientId}, - }, - ReceiptHandle: &receiptHandle, - }, - }, - }, nil + validMessage := &sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + Body: messageBody, + MessageAttributes: map[string]types.MessageAttributeValue{ + "ClientID": {StringValue: &clientId}, + }, + ReceiptHandle: &receiptHandle, }, - ) + }, } - Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) { - var wg sync.WaitGroup - wg.Add(2) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - expectReceiveMessageReturnsValidMessage(sqsHandlerContext) + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( + func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + sqsCancelFunc() + return validMessage, nil + }) mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes() - numTimes := 0 - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(2).Do( - func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { - numTimes += 1 - if numTimes <= 2 { - wg.Done() - if numTimes == 2 { - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: error encountered when creating answer queue for client fake-id: error") - } - } - }, - ) + mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("and responds with a proxy answer if available.", func(c C) { - var wg sync.WaitGroup - wg.Add(1) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - expectReceiveMessageReturnsValidMessage(sqsHandlerContext) + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( + func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + + go func(c C) { + snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) + <-snowflake.offerChannel + snowflake.answerChannel <- "fake answer" + }(c) + return validMessage, nil + }) mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ QueueUrl: responseQueueURL, }, nil).AnyTimes() - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() - numTimes := 0 + mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes() + var numTimes atomic.Uint32 mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { - numTimes += 1 - if numTimes == 1 { + n := numTimes.Add(1) + if n == 1 { c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}")) // Ensure that match is correctly recorded in metrics ipcCtx.metrics.printMetrics() @@ -201,19 +170,14 @@ client-ampcache-ips client-sqs-count 8 client-sqs-ips ??=8 `) - wg.Done() + sqsCancelFunc() } return &sqs.SendMessageOutput{}, nil }, ) runSQSHandler(sqsHandlerContext) - snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) - - offer := <-snowflake.offerChannel - So(offer.sdp, ShouldResemble, []byte("fake")) - - snowflake.answerChannel <- "fake answer" + <-sqsHandlerContext.Done() }) }) }) @@ -299,7 +263,6 @@ client-sqs-ips ??=8 // Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration wg.Done() sqsCancelFunc() - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: finished running iteration of client queue cleanup. found and deleted 2 client queues.") return &sqs.ListQueuesOutput{ QueueUrls: []string{}, }, nil