Commit d47fecdb authored by opara's avatar opara 🙃
Browse files

tor-rtcompat: manually create and connect() socket

This will allow us in the future to set custom sockopts on the socket
before calling connect().

I tested an arti proxy with tokio and async-std manually. Arti doesn't
yet support smol so I was not able to test it, but it's using the same
code as async-std so I would expect it to work.
parent 926c7a54
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -116,7 +116,7 @@ zeroize = "1"
[target.'cfg(not(all(target_arch="wasm32", target_os="unknown")))'.dependencies]
socket2 = "0.6.1"

[target.'cfg(target_os = "linux")'.dependencies]
[target.'cfg(unix)'.dependencies]
libc = "0.2"

[dev-dependencies]
+126 −2
Original line number Diff line number Diff line
@@ -20,7 +20,10 @@ pub(crate) mod native_tls;
pub(crate) mod streamops;
pub(crate) mod unimpl_tls;

use crate::network::{CommonListenOptions, TcpListenOptions};
use crate::network::{
    CommonConnectOptions, CommonListenOptions, TcpConnectOptions, TcpListenOptions,
};
use socket2::Socket;

#[cfg(unix)]
use tor_error::warn_report;
@@ -103,7 +106,7 @@ pub(crate) fn tcp_listen(
    addr: &std::net::SocketAddr,
    options: &TcpListenOptions,
) -> std::io::Result<std::net::TcpListener> {
    use socket2::{Domain, Socket, Type};
    use socket2::{Domain, Type};

    // Destructure the options so that we don't forget to use any.
    let TcpListenOptions {
@@ -193,6 +196,127 @@ pub(crate) fn tcp_listen(
    Err(std::io::Error::from(std::io::ErrorKind::Unsupported))
}

/// Initialize a TCP socket in preparation for a connect().
///
/// The socket will be non-blocking, and the socket handle will be close-on-exec/non-inheritable.
/// Other socket options may also be set depending on the socket type and platform.
///
/// This returns a socket without any `connect()` call. The caller MUST:
///
/// 1. connect() the socket.
/// 2. Wait for the socket to become writable using whatever mechanism
///    is available with the current runtime.
/// 3. Check `SO_ERROR` for errors.
///
/// Historically we relied on the runtime to create and connect the socket, but we need some
/// specific socket options set, and not all runtimes will behave the same. It's better for us to
/// create the socket with the options we need and with consistent behaviour across all runtimes.
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub(crate) fn tcp_pre_connect(
    addr: &std::net::SocketAddr,
    options: &TcpConnectOptions,
) -> std::io::Result<socket2::Socket> {
    use socket2::{Domain, Type};

    // Destructure the options so that we don't forget to use any.
    let TcpConnectOptions {
        common: CommonConnectOptions {},
    } = options;

    let domain = match addr {
        std::net::SocketAddr::V4(_) => Domain::IPV4,
        std::net::SocketAddr::V6(_) => Domain::IPV6,
    };

    // `socket2::Socket::new()`:
    // > This function corresponds to `socket(2)` on Unix and `WSASocketW` on Windows.
    // >
    // > On Unix-like systems, the close-on-exec flag is set on the new socket. Additionally, on
    // > Apple platforms `SOCK_NOSIGPIPE` is set. On Windows, the socket is made non-inheritable.
    let socket = Socket::new(domain, Type::STREAM, None)?;

    socket.set_nonblocking(true)?;

    // TODO: In the future, we'll likely want to support optionally binding to an address or to a
    // network interface (`SO_BINDTODEVICE`). See c-tor's `OutboundBindAddresses`.
    // If we do, we will also want to set `IP_BIND_ADDRESS_NO_PORT`.
    // We may also want to consider setting `IPV6_V6ONLY` (do we want to support connecting to
    // IPv4-mapped IPv6 addresses while we already do happy eyeballs?).

    // We do not connect() here so that we can use whatever connection mechanism is best for the
    // runtime being used.

    Ok(socket)
}

/// Stub replacement for tcp_pre_connect on wasm32-unknown
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub(crate) fn tcp_pre_connect(
    _addr: &std::net::SocketAddr,
    _options: &TcpConnectOptions,
) -> std::io::Result<socket2::Socket> {
    Err(std::io::Error::from(std::io::ErrorKind::Unsupported))
}

/// Connect a TCP socket using the async-io crate.
///
/// This in theory should be runtime-independent as async-io spawns its own thread to poll the
/// socket. But this is inefficient on some runtimes like tokio.
///
/// Runtimes that want to connect manually should use [`tcp_pre_connect()`] to set up the socket,
/// and then connect it manually.
#[cfg(any(feature = "async-std", feature = "smol"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub(crate) async fn tcp_async_io_connect(
    addr: &std::net::SocketAddr,
    options: &TcpConnectOptions,
) -> std::io::Result<std::net::TcpStream> {
    use async_io::Async;

    // The socket before connect() has been called.
    let socket = tcp_pre_connect(addr, options)?;

    // Different platforms return different results from non-blocking `connect()`s.
    // Here we've checked that we match mio (tokio's low-level I/O code) for unix and windows
    // to ensure that we're handling the right error kind/errno.
    match socket.connect(&(*addr).into()) {
        Ok(()) => {}
        // On unix, mio checks for `EINPROGRESS`:
        // https://github.com/tokio-rs/mio/blob/0db25a7eae653f02e964a28d9aaf65b74c941208/src/sys/unix/tcp.rs#L35
        #[cfg(unix)]
        Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
        // On windows, mio checks for `WouldBlock`:
        // https://github.com/tokio-rs/mio/blob/0db25a7eae653f02e964a28d9aaf65b74c941208/src/sys/windows/tcp.rs#L44
        #[cfg(windows)]
        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
        Err(e) => return Err(e),
    }

    // The socket is already non-blocking,
    // so `Async` doesn't need to set as non-blocking again.
    let socket = Async::new_nonblocking(socket)?;

    // Wait for the socket to become writable, indicating that it's connected.
    socket.writable().await?;

    // Check `SO_ERROR`.
    if let Some(e) = socket.get_ref().take_error()? {
        return Err(e);
    }

    Ok(socket.into_inner()?.into())
}

/// Stub replacement for tcp_async_io_connect on wasm32-unknown
#[cfg(any(feature = "async-std", feature = "smol"))]
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub(crate) async fn tcp_async_io_connect(
    _addr: &std::net::SocketAddr,
    _options: &TcpConnectOptions,
) -> std::io::Result<std::net::TcpStream> {
    Err(std::io::Error::from(std::io::ErrorKind::Unsupported))
}

/// Helper: Implement an unreachable NetProvider<unix::SocketAddr> for a given runtime.
#[cfg(not(unix))]
macro_rules! impl_unix_non_provider {
+3 −4
Original line number Diff line number Diff line
@@ -86,10 +86,9 @@ mod net {
            addr: &SocketAddr,
            options: &Self::ConnectOptions,
        ) -> IoResult<Self::Stream> {
            // XXXX use the options
            let _ = options;

            TcpStream::connect(addr).await
            // The async-std runtime uses async-io internally.
            let stream = impls::tcp_async_io_connect(addr, options).await?;
            Ok(stream.into())
        }
        async fn listen(
            &self,
+5 −3
Original line number Diff line number Diff line
@@ -87,10 +87,12 @@ pub(crate) mod net {
            addr: &SocketAddr,
            options: &Self::ConnectOptions,
        ) -> IoResult<Self::Stream> {
            // XXXX use the options
            let _ = options;
            // The smol runtime uses async-io internally.
            let stream = impls::tcp_async_io_connect(addr, options).await?;

            TcpStream::connect(addr).await
            // The socket is already non-blocking,
            // so `Async` doesn't need to set as non-blocking again.
            Ok(Async::new_nonblocking(stream)?.into())
        }

        async fn listen(
+21 −5
Original line number Diff line number Diff line
@@ -231,11 +231,27 @@ impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
        addr: &std::net::SocketAddr,
        options: &Self::ConnectOptions,
    ) -> IoResult<Self::Stream> {
        // XXXX use the options
        let _ = options;

        let s = net::TokioTcpStream::connect(addr).await?;
        Ok(s.into())
        // The socket before connect() has been called.
        let socket = super::tcp_pre_connect(addr, options)?;

        // It might seem a little weird to convert the `socket2::Socket` to a std `TcpStream` before
        // it's connected, but this is the approach recommended by tokio.
        //
        // https://docs.rs/tokio/latest/tokio/net/struct.TcpSocket.html#method.from_std_stream
        //
        // > Converts a `std::net::TcpStream` into a `TcpSocket`. The provided socket must not have
        // > been connected prior to calling this function. This function is typically used together
        // > with crates such as socket2 to configure socket options that are not available on
        // > `TcpSocket`.
        //
        // The socket will already be non-blocking.
        let socket = std::net::TcpStream::from(socket);
        let socket = tokio_crate::net::TcpSocket::from_std_stream(socket);

        // Let tokio handle the connection.
        let socket = socket.connect(*addr).await?;

        Ok(socket.into())
    }
    async fn listen(
        &self,