Loading crates/arti-client/examples/hook-tcp.rs 0 → 100644 +339 −0 Original line number Diff line number Diff line //! This example showcases using a custom [`TcpProvider`] to do custom actions before Arti initiates //! TCP connections, and after the connections are closed. //! //! This might be useful, for example, to dynamically open ports on a restrictive firewall or modify //! routing information. It would also be possible to adapt the example to make it proxy the TCP //! connections somehow, depending on your usecase. use std::future::Future; use std::io::Result as IoResult; use std::net::SocketAddr; use std::pin::Pin; use std::task::{Context, Poll}; use anyhow::Result; use arti_client::{TorClient, TorClientConfig}; use tokio_crate as tokio; use tor_rtcompat::tokio::TokioNativeTlsRuntime; use futures::{AsyncRead, AsyncWrite, FutureExt, Stream}; use tor_rtcompat::tls::NativeTlsProvider; use tor_rtcompat::{CompoundRuntime, TcpListener, TcpProvider}; use futures::io::{AsyncReadExt, AsyncWriteExt}; #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let config = TorClientConfig::default(); let rt = TokioNativeTlsRuntime::current()?; // Instantiate our custom TCP provider (see implementation below). let tcp_rt = CustomTcpProvider { inner: rt.clone() }; // Create a `CompoundRuntime`, swapping out the TCP part of the runtime for our custom one. let rt = CompoundRuntime::new(rt.clone(), rt, tcp_rt, NativeTlsProvider::default()); eprintln!("connecting to Tor..."); let tor_client = TorClient::create_bootstrapped(rt, config).await?; eprintln!("connecting to example.com..."); let mut stream = tor_client.connect(("example.com", 80)).await?; eprintln!("sending request..."); stream .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n") .await?; stream.flush().await?; eprintln!("reading response..."); let mut buf = Vec::new(); stream.read_to_end(&mut buf).await?; println!("{}", String::from_utf8_lossy(&buf)); Ok(()) } /// A custom TCP provider that relies on an existing TCP provider (`inner`), but modifies its /// behavior. struct CustomTcpProvider<T> { /// The underlying TCP provider. inner: T, } /// A custom TCP stream that wraps another TCP provider's TCP stream type, letting us do things /// when the stream is read from, written to, or closed. struct CustomTcpStream<T> { /// The underlying TCP stream. inner: T, /// The address of the remote peer at the other end of this stream. addr: SocketAddr, /// The current state of the socket: whether it is open, in the process of closing, or closed. state: TcpState, } /// An enum representing states a TCP stream can be in. #[derive(PartialEq)] enum TcpState { /// Stream is open. Open, /// We've sent a close, but haven't received one. SendClosed, /// We've received a close, but haven't sent one. RecvClosed, /// Stream is fully closed. Closed, } /// A wrapper over a `TcpListener`. struct CustomTcpListener<T> { inner: T, } /// An `Incoming` type for our `CustomTcpListener`. struct CustomIncoming<T> { inner: T, } impl<T> TcpProvider for CustomTcpProvider<T> where T: TcpProvider, { type TcpStream = CustomTcpStream<T::TcpStream>; type TcpListener = CustomTcpListener<T::TcpListener>; // This is an async trait method (using the `async_trait` crate). We manually implement it // here so that we don't borrow `self` for too long. // (The lifetimes are explicit and somewhat ugly because that's how `async_trait` works.) fn connect<'a, 'b, 'c>( &'a self, addr: &'b SocketAddr, ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpStream>> + Send + 'c>> where 'a: 'c, 'b: 'c, Self: 'c, { // Use the underlying TCP provider implementation to do the connection, and // return our wrapper around it once done. println!("tcp connect to {}", addr); self.inner .connect(addr) .map(move |r| { r.map(|stream| CustomTcpStream { inner: stream, addr: *addr, state: TcpState::Open, }) }) .boxed() } // This is also an async trait method (see above). fn listen<'a, 'b, 'c>( &'a self, addr: &'b SocketAddr, ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpListener>> + Send + 'c>> where 'a: 'c, 'b: 'c, Self: 'c, { // Use the underlying TCP provider implementation to make the listener, and // return our wrapper around it once done. println!("tcp listen on {}", addr); self.inner .listen(addr) .map(|l| l.map(|listener| CustomTcpListener { inner: listener })) .boxed() } } // We implement `AsyncRead` and `AsyncWrite` for our custom TCP stream object. // This implementation mostly uses the underlying stream's methods, but we insert some // code to check for a zero-byte read (indicating stream closure), and callers closing the // stream, and use that to update our `TcpState`. // When we detect that the stream is closed, we run some code (in this case, just a `println!`). impl<T> AsyncRead for CustomTcpStream<T> where T: AsyncRead + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<IoResult<usize>> { // Call the underlying stream's method. let res = Pin::new(&mut self.inner).poll_read(cx, buf); // Check for a zero-byte read, indicating closure. if let Poll::Ready(Ok(0)) = res { // Ignore if buf is zero-lenght, reading 0 bytes doesn't mean eof in that case if !buf.is_empty() { match self.state { // If we're already closed, do nothing. TcpState::Closed | TcpState::RecvClosed => (), // We're open, and haven't tried to close the stream yet, so note that // the other side closed it. TcpState::Open => self.state = TcpState::RecvClosed, // We've closed the stream on our end, and the other side has now closed it // too, so the stream is now fully closed. TcpState::SendClosed => { println!("closed a connecion to {}", self.addr); self.state = TcpState::Closed; } } } } res } // Do the same thing, but for `poll_read_vectored`. fn poll_read_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [std::io::IoSliceMut<'_>], ) -> Poll<IoResult<usize>> { let res = Pin::new(&mut self.inner).poll_read_vectored(cx, bufs); if let Poll::Ready(Ok(0)) = res { if bufs.iter().any(|buf| !buf.is_empty()) { match self.state { TcpState::Closed | TcpState::RecvClosed => (), TcpState::Open => self.state = TcpState::RecvClosed, TcpState::SendClosed => { println!("closed a connecion to {}", self.addr); self.state = TcpState::Closed; } } } } res } } // The only thing that's custom here is checking for closure. Everything else is just calling // `self.inner`. impl<T> AsyncWrite for CustomTcpStream<T> where T: AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<IoResult<usize>> { Pin::new(&mut self.inner).poll_write(cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> { Pin::new(&mut self.inner).poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> { let res = Pin::new(&mut self.inner).poll_close(cx); if res.is_ready() { match self.state { TcpState::Closed | TcpState::SendClosed => (), TcpState::Open => self.state = TcpState::SendClosed, TcpState::RecvClosed => { println!("closed a connection to {}", self.addr); self.state = TcpState::Closed; } } } res } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll<IoResult<usize>> { Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) } } impl<T> Drop for CustomTcpStream<T> { fn drop(&mut self) { if self.state != TcpState::Closed { println!("closed a connection to {}", self.addr); } } } type AcceptResult<T> = IoResult<(T, SocketAddr)>; impl<T> TcpListener for CustomTcpListener<T> where T: TcpListener, { type TcpStream = CustomTcpStream<T::TcpStream>; type Incoming = CustomIncoming<T::Incoming>; // This is also an async trait method (see earlier commentary). fn accept<'a, 'b>( &'a self, ) -> Pin<Box<dyn Future<Output = AcceptResult<Self::TcpStream>> + Send + 'b>> where 'a: 'b, Self: 'b, { // As with other implementations, we just defer to `self.inner` and wrap the result. self.inner .accept() .inspect(|r| { if let Ok((_, addr)) = r { println!("accepted connection from {}", addr) } }) .map(|r| { r.map(|(stream, addr)| { ( CustomTcpStream { inner: stream, addr, state: TcpState::Open, }, addr, ) }) }) .boxed() } fn incoming(self) -> Self::Incoming { CustomIncoming { inner: self.inner.incoming(), } } fn local_addr(&self) -> IoResult<SocketAddr> { self.inner.local_addr() } } impl<T, S> Stream for CustomIncoming<T> where T: Stream<Item = IoResult<(S, SocketAddr)>> + Unpin, { type Item = IoResult<(CustomTcpStream<S>, SocketAddr)>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { match Pin::new(&mut self.inner).poll_next(cx) { Poll::Ready(Some(Ok((stream, addr)))) => Poll::Ready(Some(Ok(( CustomTcpStream { inner: stream, addr, state: TcpState::Open, }, addr, )))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } } crates/tor-rtcompat/src/lib.rs +5 −0 Original line number Diff line number Diff line Loading @@ -164,6 +164,11 @@ pub use timer::{SleepProviderExt, Timeout, TimeoutError}; /// create them. pub mod tls { pub use crate::traits::{CertifiedConn, TlsConnector}; #[cfg(feature = "native-tls")] pub use crate::impls::native_tls::NativeTlsProvider; #[cfg(feature = "rustls")] pub use crate::impls::rustls::RustlsProvider; } #[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "tokio"))] Loading Loading
crates/arti-client/examples/hook-tcp.rs 0 → 100644 +339 −0 Original line number Diff line number Diff line //! This example showcases using a custom [`TcpProvider`] to do custom actions before Arti initiates //! TCP connections, and after the connections are closed. //! //! This might be useful, for example, to dynamically open ports on a restrictive firewall or modify //! routing information. It would also be possible to adapt the example to make it proxy the TCP //! connections somehow, depending on your usecase. use std::future::Future; use std::io::Result as IoResult; use std::net::SocketAddr; use std::pin::Pin; use std::task::{Context, Poll}; use anyhow::Result; use arti_client::{TorClient, TorClientConfig}; use tokio_crate as tokio; use tor_rtcompat::tokio::TokioNativeTlsRuntime; use futures::{AsyncRead, AsyncWrite, FutureExt, Stream}; use tor_rtcompat::tls::NativeTlsProvider; use tor_rtcompat::{CompoundRuntime, TcpListener, TcpProvider}; use futures::io::{AsyncReadExt, AsyncWriteExt}; #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let config = TorClientConfig::default(); let rt = TokioNativeTlsRuntime::current()?; // Instantiate our custom TCP provider (see implementation below). let tcp_rt = CustomTcpProvider { inner: rt.clone() }; // Create a `CompoundRuntime`, swapping out the TCP part of the runtime for our custom one. let rt = CompoundRuntime::new(rt.clone(), rt, tcp_rt, NativeTlsProvider::default()); eprintln!("connecting to Tor..."); let tor_client = TorClient::create_bootstrapped(rt, config).await?; eprintln!("connecting to example.com..."); let mut stream = tor_client.connect(("example.com", 80)).await?; eprintln!("sending request..."); stream .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n") .await?; stream.flush().await?; eprintln!("reading response..."); let mut buf = Vec::new(); stream.read_to_end(&mut buf).await?; println!("{}", String::from_utf8_lossy(&buf)); Ok(()) } /// A custom TCP provider that relies on an existing TCP provider (`inner`), but modifies its /// behavior. struct CustomTcpProvider<T> { /// The underlying TCP provider. inner: T, } /// A custom TCP stream that wraps another TCP provider's TCP stream type, letting us do things /// when the stream is read from, written to, or closed. struct CustomTcpStream<T> { /// The underlying TCP stream. inner: T, /// The address of the remote peer at the other end of this stream. addr: SocketAddr, /// The current state of the socket: whether it is open, in the process of closing, or closed. state: TcpState, } /// An enum representing states a TCP stream can be in. #[derive(PartialEq)] enum TcpState { /// Stream is open. Open, /// We've sent a close, but haven't received one. SendClosed, /// We've received a close, but haven't sent one. RecvClosed, /// Stream is fully closed. Closed, } /// A wrapper over a `TcpListener`. struct CustomTcpListener<T> { inner: T, } /// An `Incoming` type for our `CustomTcpListener`. struct CustomIncoming<T> { inner: T, } impl<T> TcpProvider for CustomTcpProvider<T> where T: TcpProvider, { type TcpStream = CustomTcpStream<T::TcpStream>; type TcpListener = CustomTcpListener<T::TcpListener>; // This is an async trait method (using the `async_trait` crate). We manually implement it // here so that we don't borrow `self` for too long. // (The lifetimes are explicit and somewhat ugly because that's how `async_trait` works.) fn connect<'a, 'b, 'c>( &'a self, addr: &'b SocketAddr, ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpStream>> + Send + 'c>> where 'a: 'c, 'b: 'c, Self: 'c, { // Use the underlying TCP provider implementation to do the connection, and // return our wrapper around it once done. println!("tcp connect to {}", addr); self.inner .connect(addr) .map(move |r| { r.map(|stream| CustomTcpStream { inner: stream, addr: *addr, state: TcpState::Open, }) }) .boxed() } // This is also an async trait method (see above). fn listen<'a, 'b, 'c>( &'a self, addr: &'b SocketAddr, ) -> Pin<Box<dyn Future<Output = IoResult<Self::TcpListener>> + Send + 'c>> where 'a: 'c, 'b: 'c, Self: 'c, { // Use the underlying TCP provider implementation to make the listener, and // return our wrapper around it once done. println!("tcp listen on {}", addr); self.inner .listen(addr) .map(|l| l.map(|listener| CustomTcpListener { inner: listener })) .boxed() } } // We implement `AsyncRead` and `AsyncWrite` for our custom TCP stream object. // This implementation mostly uses the underlying stream's methods, but we insert some // code to check for a zero-byte read (indicating stream closure), and callers closing the // stream, and use that to update our `TcpState`. // When we detect that the stream is closed, we run some code (in this case, just a `println!`). impl<T> AsyncRead for CustomTcpStream<T> where T: AsyncRead + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<IoResult<usize>> { // Call the underlying stream's method. let res = Pin::new(&mut self.inner).poll_read(cx, buf); // Check for a zero-byte read, indicating closure. if let Poll::Ready(Ok(0)) = res { // Ignore if buf is zero-lenght, reading 0 bytes doesn't mean eof in that case if !buf.is_empty() { match self.state { // If we're already closed, do nothing. TcpState::Closed | TcpState::RecvClosed => (), // We're open, and haven't tried to close the stream yet, so note that // the other side closed it. TcpState::Open => self.state = TcpState::RecvClosed, // We've closed the stream on our end, and the other side has now closed it // too, so the stream is now fully closed. TcpState::SendClosed => { println!("closed a connecion to {}", self.addr); self.state = TcpState::Closed; } } } } res } // Do the same thing, but for `poll_read_vectored`. fn poll_read_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [std::io::IoSliceMut<'_>], ) -> Poll<IoResult<usize>> { let res = Pin::new(&mut self.inner).poll_read_vectored(cx, bufs); if let Poll::Ready(Ok(0)) = res { if bufs.iter().any(|buf| !buf.is_empty()) { match self.state { TcpState::Closed | TcpState::RecvClosed => (), TcpState::Open => self.state = TcpState::RecvClosed, TcpState::SendClosed => { println!("closed a connecion to {}", self.addr); self.state = TcpState::Closed; } } } } res } } // The only thing that's custom here is checking for closure. Everything else is just calling // `self.inner`. impl<T> AsyncWrite for CustomTcpStream<T> where T: AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<IoResult<usize>> { Pin::new(&mut self.inner).poll_write(cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> { Pin::new(&mut self.inner).poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> { let res = Pin::new(&mut self.inner).poll_close(cx); if res.is_ready() { match self.state { TcpState::Closed | TcpState::SendClosed => (), TcpState::Open => self.state = TcpState::SendClosed, TcpState::RecvClosed => { println!("closed a connection to {}", self.addr); self.state = TcpState::Closed; } } } res } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll<IoResult<usize>> { Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) } } impl<T> Drop for CustomTcpStream<T> { fn drop(&mut self) { if self.state != TcpState::Closed { println!("closed a connection to {}", self.addr); } } } type AcceptResult<T> = IoResult<(T, SocketAddr)>; impl<T> TcpListener for CustomTcpListener<T> where T: TcpListener, { type TcpStream = CustomTcpStream<T::TcpStream>; type Incoming = CustomIncoming<T::Incoming>; // This is also an async trait method (see earlier commentary). fn accept<'a, 'b>( &'a self, ) -> Pin<Box<dyn Future<Output = AcceptResult<Self::TcpStream>> + Send + 'b>> where 'a: 'b, Self: 'b, { // As with other implementations, we just defer to `self.inner` and wrap the result. self.inner .accept() .inspect(|r| { if let Ok((_, addr)) = r { println!("accepted connection from {}", addr) } }) .map(|r| { r.map(|(stream, addr)| { ( CustomTcpStream { inner: stream, addr, state: TcpState::Open, }, addr, ) }) }) .boxed() } fn incoming(self) -> Self::Incoming { CustomIncoming { inner: self.inner.incoming(), } } fn local_addr(&self) -> IoResult<SocketAddr> { self.inner.local_addr() } } impl<T, S> Stream for CustomIncoming<T> where T: Stream<Item = IoResult<(S, SocketAddr)>> + Unpin, { type Item = IoResult<(CustomTcpStream<S>, SocketAddr)>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { match Pin::new(&mut self.inner).poll_next(cx) { Poll::Ready(Some(Ok((stream, addr)))) => Poll::Ready(Some(Ok(( CustomTcpStream { inner: stream, addr, state: TcpState::Open, }, addr, )))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } }
crates/tor-rtcompat/src/lib.rs +5 −0 Original line number Diff line number Diff line Loading @@ -164,6 +164,11 @@ pub use timer::{SleepProviderExt, Timeout, TimeoutError}; /// create them. pub mod tls { pub use crate::traits::{CertifiedConn, TlsConnector}; #[cfg(feature = "native-tls")] pub use crate::impls::native_tls::NativeTlsProvider; #[cfg(feature = "rustls")] pub use crate::impls::rustls::RustlsProvider; } #[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "tokio"))] Loading