Commit 12c294c7 authored by David Goulet's avatar David Goulet
Browse files

socket: Add async UdpSocket support



Signed-off-by: default avatarDavid Goulet <dgoulet@ev0ke.net>
parent a05ff87d
Loading
Loading
Loading
Loading
+115 −2
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@ use std::{
    task::{Context, Poll},
};

use log::info;
use smoltcp::{
    iface::{Interface, SocketHandle},
    phy::TunTapInterface,
@@ -21,6 +20,121 @@ pub struct TcpSocket {
    iface: IFace,
}

pub struct UdpSocket {
    handle: SocketHandle,
    iface: IFace,
}

impl UdpSocket {
    pub fn new(iface: IFace, s: smoltcp::socket::UdpSocket<'static>) -> Self {
        let handle = iface.lock().unwrap().add_socket(s);
        Self { handle, iface }
    }

    fn with<R>(&mut self, f: impl FnOnce(&mut smoltcp::socket::UdpSocket) -> R) -> R {
        f(self
            .iface
            .lock()
            .unwrap()
            .get_socket::<smoltcp::socket::UdpSocket>(self.handle))
    }

    pub fn dest(&mut self) -> (IpAddr, u16) {
        self.with(|s| {
            let endpoint = s.endpoint();
            (endpoint.addr.into(), endpoint.port)
        })
    }
}

impl Drop for UdpSocket {
    fn drop(&mut self) {
        self.iface.lock().unwrap().remove_socket(self.handle);
    }
}

impl AsyncRead for UdpSocket {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.with(|s| match s.can_recv() {
            true => {
                let rbuf = s.recv();
                match rbuf {
                    Err(e) => {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            format!("{}", e),
                        )))
                    }
                    Ok((payload, _)) => {
                        if payload.len() > 0 {
                            buf.put_slice(payload);
                            Poll::Ready(Ok(()))
                        } else {
                            s.register_recv_waker(cx.waker());
                            Poll::Pending
                        }
                    }
                }
            }
            false => {
                s.register_recv_waker(cx.waker());
                Poll::Pending
            }
        })
    }
}

impl AsyncWrite for UdpSocket {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let dest = self.dest().into();
        let p = self.with(|s| match s.can_send() {
            true => match s.send_slice(buf, dest) {
                Ok(_) => {
                    s.register_send_waker(cx.waker());
                    Poll::Pending
                }
                Err(e) => Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("{}", e),
                ))),
            },
            false => {
                s.register_send_waker(cx.waker());
                Poll::Pending
            }
        });
        match p {
            Poll::Ready(_) => {
                // We need to poll in order to process the ingress packets.
                let _ = self
                    .iface
                    .lock()
                    .unwrap()
                    .poll(smoltcp::time::Instant::now());
            }
            _ => (),
        };
        p
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.with(|s| s.close());
        Poll::Ready(Ok(()))
    }
}

impl TcpSocket {
    pub fn new(iface: IFace, s: smoltcp::socket::TcpSocket<'static>) -> Self {
        let handle = iface.lock().unwrap().add_socket(s);
@@ -68,7 +182,6 @@ impl AsyncRead for TcpSocket {
                    Ok(b) => {
                        if b.len() > 0 {
                            buf.put_slice(b);
                            info!("Async Read: {:?}", buf);
                            Poll::Ready(Ok(()))
                        } else {
                            s.register_recv_waker(cx.waker());