Verified Commit b519cbc3 authored by Randell Jesup's avatar Randell Jesup Committed by ma1
Browse files

Bug 2036907: Fix WebTransportSessionBase::mListener data race r=necko-reviewers,valentin

parent 45cb457e
Loading
Loading
Loading
Loading
+15 −10
Original line number Diff line number Diff line
@@ -367,10 +367,13 @@ bool Http2WebTransportSessionImpl::OnCapsule(Capsule&& aCapsule) {
      LOG(("Handling DATAGRAM\n"));
      WebTransportDatagramCapsule& datagram =
          aCapsule.GetWebTransportDatagramCapsule();
      if (RefPtr<WebTransportSessionEventListener> baseListener =
              GetListener()) {
        if (nsCOMPtr<WebTransportSessionEventListenerInternal> listener =
              do_QueryInterface(mListener)) {
                do_QueryInterface(baseListener)) {
          listener->OnDatagramReceivedInternal(std::move(datagram.mPayload));
        }
      }
      break;
    }
    default:
@@ -434,8 +437,8 @@ bool Http2WebTransportSessionImpl::HandleStreamStopSendingCapsule(

  uint8_t wtError = Http3ErrorToWebTransportError(stopSending.mErrorCode);
  nsresult rv = GetNSResultFromWebTransportError(wtError);
  if (mListener) {
    mListener->OnStopSending(aId, rv);
  if (RefPtr<WebTransportSessionEventListener> listener = GetListener()) {
    listener->OnStopSending(aId, rv);
  }
  return true;
}
@@ -454,8 +457,8 @@ bool Http2WebTransportSessionImpl::HandleStreamResetCapsule(

  uint8_t wtError = Http3ErrorToWebTransportError(reset.mErrorCode);
  nsresult rv = GetNSResultFromWebTransportError(wtError);
  if (mListener) {
    mListener->OnResetReceived(aId, rv);
  if (RefPtr<WebTransportSessionEventListener> listener = GetListener()) {
    listener->OnResetReceived(aId, rv);
  }

  return true;
@@ -510,11 +513,13 @@ bool Http2WebTransportSessionImpl::ProcessIncomingStreamCapsule(
      return false;
    }
    mIncomingStreams.InsertOrUpdate(newStreamID, stream);
    if (RefPtr<WebTransportSessionEventListener> baseListener = GetListener()) {
      if (nsCOMPtr<WebTransportSessionEventListenerInternal> listener =
            do_QueryInterface(mListener)) {
              do_QueryInterface(baseListener)) {
        listener->OnIncomingStreamAvailableInternal(stream);
      }
    }
  }

  stream = mIncomingStreams.Get(aID);
  if (stream) {
+24 −19
Original line number Diff line number Diff line
@@ -324,9 +324,8 @@ nsresult Http3WebTransportSession::OnWriteSegment(char* buf, uint32_t count,

void Http3WebTransportSession::Close(nsresult aResult) {
  LOG(("Http3WebTransportSession::Close %p", this));
  if (mListener) {
    mListener->OnSessionClosed(NS_SUCCEEDED(aResult), 0, ""_ns);
    mListener = nullptr;
  if (RefPtr<WebTransportSessionEventListener> listener = TakeListener()) {
    listener->OnSessionClosed(NS_SUCCEEDED(aResult), 0, ""_ns);
  }
  if (mTransaction) {
    mTransaction->Close(aResult);
@@ -347,9 +346,8 @@ void Http3WebTransportSession::OnSessionClosed(bool aCleanly, uint32_t aStatus,
    mTransaction->Close(NS_BASE_STREAM_CLOSED);
    mTransaction = nullptr;
  }
  if (mListener) {
    mListener->OnSessionClosed(aCleanly, aStatus, aReason);
    mListener = nullptr;
  if (RefPtr<WebTransportSessionEventListener> listener = TakeListener()) {
    listener->OnSessionClosed(aCleanly, aStatus, aReason);
  }
  mRecvState = RECV_DONE;
  mSendState = SEND_DONE;
@@ -366,7 +364,8 @@ void Http3WebTransportSession::CloseSession(uint32_t aStatus,
    mRecvState = CLOSE_PENDING;
    mSendState = SEND_DONE;
  }
  mListener = nullptr;
  RefPtr<WebTransportSessionEventListener> listener = TakeListener();
  // let it drop
}

void Http3WebTransportSession::TransactionIsDone(nsresult aResult) {
@@ -447,12 +446,13 @@ Http3WebTransportSession::OnIncomingWebTransportStream(
    }
  }

  if (!mListener) {
  RefPtr<WebTransportSessionEventListener> baseListener = GetListener();
  if (!baseListener) {
    return nullptr;
  }

  if (nsCOMPtr<WebTransportSessionEventListenerInternal> listener =
          do_QueryInterface(mListener)) {
          do_QueryInterface(baseListener)) {
    listener->OnIncomingStreamAvailableInternal(stream);
  }
  return stream.forget();
@@ -473,24 +473,26 @@ void Http3WebTransportSession::SendDatagram(nsTArray<uint8_t>&& aData,
void Http3WebTransportSession::OnDatagramReceived(nsTArray<uint8_t>&& aData) {
  MOZ_ASSERT(OnSocketThread(), "not on socket thread");
  LOG(("Http3WebTransportSession::OnDatagramReceived this=%p", this));
  if (mRecvState != ACTIVE || !mListener) {
  RefPtr<WebTransportSessionEventListener> baseListener2 = GetListener();
  if (mRecvState != ACTIVE || !baseListener2) {
    return;
  }

  if (nsCOMPtr<WebTransportSessionEventListenerInternal> listener =
          do_QueryInterface(mListener)) {
          do_QueryInterface(baseListener2)) {
    listener->OnDatagramReceivedInternal(std::move(aData));
  }
}

void Http3WebTransportSession::GetMaxDatagramSize() {
  MOZ_ASSERT(OnSocketThread(), "not on socket thread");
  if (mRecvState != ACTIVE || !mListener) {
  RefPtr<WebTransportSessionEventListener> listener = GetListener();
  if (mRecvState != ACTIVE || !listener) {
    return;
  }

  uint64_t size = mSession->MaxDatagramSize(mStreamId);
  mListener->OnMaxDatagramSize(size);
  listener->OnMaxDatagramSize(size);
}

void Http3WebTransportSession::OnOutgoingDatagramOutCome(
@@ -499,30 +501,33 @@ void Http3WebTransportSession::OnOutgoingDatagramOutCome(
  LOG(("Http3WebTransportSession::OnOutgoingDatagramOutCome this=%p id=%" PRIx64
       ", outCome=%d mRecvState=%d",
       this, aId, static_cast<uint32_t>(aOutCome), mRecvState));
  if (mRecvState != ACTIVE || !mListener || !aId) {
  RefPtr<WebTransportSessionEventListener> listener = GetListener();
  if (mRecvState != ACTIVE || !listener || !aId) {
    return;
  }

  mListener->OnOutgoingDatagramOutCome(aId, aOutCome);
  listener->OnOutgoingDatagramOutCome(aId, aOutCome);
}

void Http3WebTransportSession::OnStreamStopSending(uint64_t aId,
                                                   nsresult aError) {
  LOG(("OnStreamStopSending id:%" PRId64, aId));
  if (!mListener) {
  RefPtr<WebTransportSessionEventListener> listener = GetListener();
  if (!listener) {
    return;
  }

  mListener->OnStopSending(aId, aError);
  listener->OnStopSending(aId, aError);
}

void Http3WebTransportSession::OnStreamReset(uint64_t aId, nsresult aError) {
  LOG(("OnStreamReset id:%" PRId64, aId));
  if (!mListener) {
  RefPtr<WebTransportSessionEventListener> listener = GetListener();
  if (!listener) {
    return;
  }

  mListener->OnResetReceived(aId, aError);
  listener->OnResetReceived(aId, aError);
}

}  // namespace mozilla::net
+13 −0
Original line number Diff line number Diff line
@@ -11,7 +11,20 @@ namespace mozilla::net {

void WebTransportSessionBase::SetWebTransportSessionEventListener(
    WebTransportSessionEventListener* listener) {
  MutexAutoLock lock(mListenerLock);
  mListener = listener;
}

already_AddRefed<WebTransportSessionEventListener>
WebTransportSessionBase::GetListener() {
  MutexAutoLock lock(mListenerLock);
  return do_AddRef(mListener);
}

already_AddRefed<WebTransportSessionEventListener>
WebTransportSessionBase::TakeListener() {
  MutexAutoLock lock(mListenerLock);
  return mListener.forget();
}

}  // namespace mozilla::net
+7 −1
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@

#include <functional>

#include "mozilla/Mutex.h"
#include "nsISupportsImpl.h"
#include "nsTArray.h"

@@ -42,7 +43,12 @@ class WebTransportSessionBase {
 protected:
  virtual ~WebTransportSessionBase() = default;

  RefPtr<WebTransportSessionEventListener> mListener;
  already_AddRefed<WebTransportSessionEventListener> GetListener();
  already_AddRefed<WebTransportSessionEventListener> TakeListener();

  Mutex mListenerLock{"WebTransportSessionBase::mListenerLock"};
  RefPtr<WebTransportSessionEventListener> mListener
      MOZ_GUARDED_BY(mListenerLock);
};

}  // namespace mozilla::net