diff --git a/crates/tor-circmgr/src/preemptive.rs b/crates/tor-circmgr/src/preemptive.rs index 42066e6bd6a099189fc6e85449cd1ee359b56c18..b02ed61b35a4162c9510436a73cdab4b366f3783 100644 --- a/crates/tor-circmgr/src/preemptive.rs +++ b/crates/tor-circmgr/src/preemptive.rs @@ -99,8 +99,6 @@ mod test { .unwrap(); let predictor = PreemptiveCircuitPredictor::new(cfg); - let mut results = predictor.predict(); - results.sort(); assert_eq!( predictor.predict(), vec![TargetCircUsage::Preemptive { @@ -117,21 +115,16 @@ mod test { let predictor = PreemptiveCircuitPredictor::new(cfg); - let mut results = predictor.predict(); - results.sort(); - assert_eq!( - results, - vec![ - TargetCircUsage::Preemptive { - port: None, - circs: 2 - }, - TargetCircUsage::Preemptive { - port: Some(TargetPort::ipv4(80)), - circs: 2 - }, - ] - ); + let results = predictor.predict(); + assert_eq!(results.len(), 2); + assert!(results.contains(&TargetCircUsage::Preemptive { + port: None, + circs: 2 + })); + assert!(results.contains(&TargetCircUsage::Preemptive { + port: Some(TargetPort::ipv4(80)), + circs: 2 + })); } #[test] @@ -153,21 +146,16 @@ mod test { predictor.note_usage(Some(TargetPort::ipv4(1234)), Instant::now()); - let mut results = predictor.predict(); - results.sort(); - assert_eq!( - results, - vec![ - TargetCircUsage::Preemptive { - port: None, - circs: 2 - }, - TargetCircUsage::Preemptive { - port: Some(TargetPort::ipv4(1234)), - circs: 2 - } - ] - ); + let results = predictor.predict(); + assert_eq!(results.len(), 2); + assert!(results.contains(&TargetCircUsage::Preemptive { + port: None, + circs: 2 + })); + assert!(results.contains(&TargetCircUsage::Preemptive { + port: Some(TargetPort::ipv4(1234)), + circs: 2 + })); } #[test] diff --git a/crates/tor-circmgr/src/usage.rs b/crates/tor-circmgr/src/usage.rs index 89dd86d69a91dab19728006126bded61630662d2..314b7947655d40476925aca2f10643a00365dce5 100644 --- a/crates/tor-circmgr/src/usage.rs +++ b/crates/tor-circmgr/src/usage.rs @@ -102,7 +102,9 @@ impl Display for TargetPorts { use std::any::Any; +/// TODO pub trait AsAny { + /// TODO fn as_any(&self) -> &dyn Any; } @@ -125,7 +127,7 @@ impl<T: IsolationHelper + std::fmt::Debug + Send + Sync + 'static> Isolation for if let Some(other) = other.as_any().downcast_ref() { self.isolated_same_type(other) } else { - false + true } } fn join(&self, other: &dyn Isolation) -> JoinResult { @@ -244,17 +246,17 @@ impl IsolationToken { impl IsolationHelper for IsolationToken { fn isolated_same_type(&self, other: &Self) -> bool { - self == other + self != other } fn join_same_type(&self, other: &Self) -> JoinResult { if self.isolated_same_type(other) { + JoinResult::NoJoin + } else { // for IsolationToken, any of the three would be correct, but the last one is probably // slower. JoinResult::UseLeft // JoinResult::UseRight // JoinResult::New(Arc::new(*self)) - } else { - JoinResult::NoJoin } } } @@ -266,8 +268,8 @@ impl IsolationHelper for IsolationToken { #[derive(Clone, Debug, derive_builder::Builder)] pub struct StreamIsolation { /// Any isolation token set on the stream. - #[builder(setter(strip_option), default)] - stream_token: Option<Arc<dyn Isolation>>, + #[builder(default = "Arc::new(IsolationToken::no_isolation())")] + stream_token: Arc<dyn Isolation>, /// Any additional isolation token set on an object that "owns" this /// stream. This is typically owned by a `TorClient`. #[builder(default = "IsolationToken::no_isolation()")] @@ -292,11 +294,7 @@ impl StreamIsolation { /// `other`. fn may_share_circuit(&self, other: &StreamIsolation) -> bool { self.owner_token == other.owner_token - && match (&self.stream_token, &other.stream_token) { - (None, None) => true, - (Some(this), Some(other)) => !this.isolated(other.as_ref()), - _ => false, - } + && !self.stream_token.isolated(other.stream_token.as_ref()) } /// Return a StreamIsolation that is the intersection of self and other. @@ -305,18 +303,14 @@ impl StreamIsolation { if self.owner_token != other.owner_token { return None; } - match (&self.stream_token, &other.stream_token) { - (None, None) => Some(self.clone()), - (Some(this), Some(other_stream)) => match this.join(other_stream.as_ref()) { - JoinResult::New(isolation) => Some(StreamIsolation { - stream_token: Some(isolation), - owner_token: self.owner_token, - }), - JoinResult::UseLeft => Some(self.clone()), - JoinResult::UseRight => Some(other.clone()), - JoinResult::NoJoin => None, - }, - _ => None, + match self.stream_token.join(other.stream_token.as_ref()) { + JoinResult::New(isolation) => Some(StreamIsolation { + stream_token: isolation, + owner_token: self.owner_token, + }), + JoinResult::UseLeft => Some(self.clone()), + JoinResult::UseRight => Some(other.clone()), + JoinResult::NoJoin => None, } } } @@ -679,11 +673,11 @@ mod test { let targ_dir = TargetCircUsage::Dir; let supp_exit = SupportedCircUsage::Exit { policy: policy.clone(), - isolation: Some(isolation), + isolation: Some(isolation.clone()), }; let supp_exit_iso2 = SupportedCircUsage::Exit { policy: policy.clone(), - isolation: Some(isolation2), + isolation: Some(isolation2.clone()), }; let supp_exit_no_iso = SupportedCircUsage::Exit { policy, @@ -693,7 +687,7 @@ mod test { let targ_80_v4 = TargetCircUsage::Exit { ports: vec![TargetPort::ipv4(80)], - isolation, + isolation: isolation.clone(), }; let targ_80_v4_iso2 = TargetCircUsage::Exit { ports: vec![TargetPort::ipv4(80)], @@ -701,11 +695,11 @@ mod test { }; let targ_80_23_v4 = TargetCircUsage::Exit { ports: vec![TargetPort::ipv4(80), TargetPort::ipv4(23)], - isolation, + isolation: isolation.clone(), }; let targ_80_23_mixed = TargetCircUsage::Exit { ports: vec![TargetPort::ipv4(80), TargetPort::ipv6(23)], - isolation, + isolation: isolation.clone(), }; let targ_999_v6 = TargetCircUsage::Exit { ports: vec![TargetPort::ipv6(999)], @@ -760,11 +754,11 @@ mod test { let targ_dir = TargetCircUsage::Dir; let supp_exit = SupportedCircUsage::Exit { policy: policy.clone(), - isolation: Some(isolation), + isolation: Some(isolation.clone()), }; let supp_exit_iso2 = SupportedCircUsage::Exit { policy: policy.clone(), - isolation: Some(isolation2), + isolation: Some(isolation2.clone()), }; let supp_exit_no_iso = SupportedCircUsage::Exit { policy, @@ -866,7 +860,7 @@ mod test { let exit_usage = TargetCircUsage::Exit { ports: vec![TargetPort::ipv4(995)], - isolation, + isolation: isolation.clone(), }; let (p_exit, u_exit, _, _) = exit_usage .build_path(&mut rng, di, guards, &config) @@ -874,9 +868,9 @@ mod test { assert!(matches!( u_exit, SupportedCircUsage::Exit { - isolation: iso, + isolation: ref iso, .. - } if iso == Some(isolation) + } if iso == &Some(isolation) )); assert!(u_exit.supports(&exit_usage)); assert_eq!(p_exit.len(), 3); @@ -934,16 +928,28 @@ mod test { let no_isolation = StreamIsolation::no_isolation(); let no_isolation2 = StreamIsolation::builder() .owner_token(IsolationToken::no_isolation()) - .stream_token(IsolationToken::no_isolation()) + .stream_token(Arc::new(IsolationToken::no_isolation())) .build() .unwrap(); - assert_eq!(no_isolation, no_isolation2); + assert_eq!(no_isolation.owner_token, no_isolation2.owner_token); + assert_eq!( + no_isolation + .stream_token + .as_ref() + .as_any() + .downcast_ref::<IsolationToken>(), + no_isolation2 + .stream_token + .as_ref() + .as_any() + .downcast_ref::<IsolationToken>() + ); assert!(no_isolation.may_share_circuit(&no_isolation2)); let tok = IsolationToken::new(); let some_isolation = StreamIsolation::builder().owner_token(tok).build().unwrap(); let some_isolation2 = StreamIsolation::builder() - .stream_token(tok) + .stream_token(Arc::new(tok)) .build() .unwrap(); assert!(!no_isolation.may_share_circuit(&some_isolation));