Commit a3c472c6 authored by Nick Mathewson's avatar Nick Mathewson 🥔
Browse files

Tests and refactoring for IsolationMap.

parent be482381
Loading
Loading
Loading
Loading
+50 −13
Original line number Diff line number Diff line
@@ -5,14 +5,13 @@

use futures::future::FutureExt;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Error as IoError};
use futures::lock::Mutex;
use futures::stream::StreamExt;
use futures::task::SpawnExt;
use std::collections::HashMap;
use std::convert::TryInto;
use std::io::Result as IoResult;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::sync::{self, Arc};
use std::time::{Duration, Instant};
use tracing::{error, info, warn};

@@ -52,7 +51,7 @@ type IsolationKey = (usize, IpAddr, SocksAuth);
/// Shared and garbage-collected Map used to isolate connections.
struct IsolationMap {
    /// Inner map guarded by a Mutex
    inner: Mutex<IsolationMapInner>,
    inner: sync::Mutex<IsolationMapInner>,
}

/// Inner map, generally guarded by a Mutex
@@ -63,13 +62,17 @@ struct IsolationMapInner {
    next_gc: Instant,
}

/// How frequently should we discard entries from the isolation map, and
/// how old should we let them get?
const ISOMAP_GC_INTERVAL: Duration = Duration::from_secs(60 * 30);

impl IsolationMap {
    /// Create a new, empty, IsolationMap
    fn new() -> Self {
        IsolationMap {
            inner: Mutex::new(IsolationMapInner {
            inner: sync::Mutex::new(IsolationMapInner {
                map: HashMap::new(),
                next_gc: Instant::now() + Duration::new(60 * 30, 0),
                next_gc: Instant::now() + ISOMAP_GC_INTERVAL,
            }),
        }
    }
@@ -78,13 +81,12 @@ impl IsolationMap {
    /// if none exists for this key.
    ///
    /// Every 30 minutes, on next call to this functions, entry older than 30 minutes are removed
    async fn get_or_create(&self, key: IsolationKey) -> IsolationToken {
        let now = Instant::now();
        let mut inner = self.inner.lock().await;
    fn get_or_create(&self, key: IsolationKey, now: Instant) -> IsolationToken {
        let mut inner = self.inner.lock().expect("Posioned lock on isolation map.");
        if inner.next_gc < now {
            inner.next_gc = now + Duration::new(60 * 30, 0);
            inner.next_gc = now + ISOMAP_GC_INTERVAL;

            let old_limit = now - Duration::new(60 * 30, 0);
            let old_limit = now - ISOMAP_GC_INTERVAL;
            inner.map.retain(|_, val| val.1 > old_limit);
        }
        let entry = inner
@@ -178,9 +180,7 @@ where
    // the same values for all of these properties.)
    let auth = request.auth().clone();
    let (source_address, ip) = isolation_info;
    let isolation_token = isolation_map
        .get_or_create((source_address, ip, auth))
        .await;
    let isolation_token = isolation_map.get_or_create((source_address, ip, auth), Instant::now());

    // Determine whether we want to ask for IPv4/IPv6 addresses.
    let mut prefs = stream_preference(&request, &addr);
@@ -448,3 +448,40 @@ pub(crate) async fn run_socks_proxy<R: Runtime>(

    Ok(())
}

#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_isomap() {
        let m = IsolationMap::new();

        let k1 = (6, "10.0.0.1".parse().unwrap(), SocksAuth::NoAuth);
        let k2 = (
            6,
            "10.0.0.1".parse().unwrap(),
            SocksAuth::Socks4(vec![1, 2, 3]),
        );

        let t1 = Instant::now() + ISOMAP_GC_INTERVAL / 2;

        let tok1 = m.get_or_create(k1.clone(), t1);
        let tok2 = m.get_or_create(k2, t1);
        assert_ne!(tok1, tok2);
        assert_eq!(tok1, m.get_or_create(k1.clone(), t1));

        // Now make sure the GC happens, but the items aren't deleted since
        // they aren't quite old enough
        let t2 = t1 + (ISOMAP_GC_INTERVAL * 3) / 4;
        assert_eq!(tok1, m.get_or_create(k1.clone(), t2));

        // Now make sure that the GC happens, and the items _are_ deleted
        // as to old.
        let t3 = t2 + ISOMAP_GC_INTERVAL * 2;
        let tok3 = m.get_or_create(k1, t3);
        assert_ne!(tok3, tok2);
        assert_ne!(tok3, tok1);
    }
}