Commit 43506601 authored by Nick Mathewson's avatar Nick Mathewson 🤹
Browse files

Move tor-dirmgr to use a sync::Mutex.

The futures::lock::Mutex was unnecessary, since we never held it
when we were suspending.
parent c8cfbda3
Loading
Loading
Loading
Loading
+6 −8
Original line number Diff line number Diff line
@@ -20,13 +20,13 @@ use tor_rtcompat::{Runtime, SleepProviderExt};
use tracing::{info, trace, warn};

/// Try to read a set of documents from `dirmgr` by ID.
async fn load_all<R: Runtime>(
fn load_all<R: Runtime>(
    dirmgr: &DirMgr<R>,
    missing: Vec<DocId>,
) -> Result<HashMap<DocId, DocumentText>> {
    let mut loaded = HashMap::new();
    for query in docid::partition_by_type(missing.into_iter()).values() {
        dirmgr.load_documents_into(query, &mut loaded).await?;
        dirmgr.load_documents_into(query, &mut loaded)?;
    }
    Ok(loaded)
}
@@ -60,7 +60,7 @@ async fn fetch_multiple<R: Runtime>(
) -> Result<Vec<(ClientRequest, DirResponse)>> {
    let mut requests = Vec::new();
    for (_type, query) in docid::partition_by_type(missing.into_iter()) {
        requests.extend(dirmgr.query_into_requests(query).await?);
        requests.extend(dirmgr.query_into_requests(query)?);
    }

    // TODO: instead of waiting for all the queries to finish, we
@@ -98,7 +98,7 @@ async fn load_once<R: Runtime>(
            "Found {} missing documents; trying to load them",
            missing.len()
        );
        let documents = load_all(dirmgr, missing).await?;
        let documents = load_all(dirmgr, missing)?;
        state.add_from_cache(documents)
    };
    dirmgr.notify().await;
@@ -153,11 +153,9 @@ async fn download_attempt<R: Runtime>(
    let fetched = fetch_multiple(Arc::clone(dirmgr), missing, parallelism).await?;
    for (client_req, dir_response) in fetched {
        let text = String::from_utf8(dir_response.into_output())?;
        match dirmgr.expand_response_text(&client_req, text).await {
        match dirmgr.expand_response_text(&client_req, text) {
            Ok(text) => {
                let outcome = state
                    .add_from_download(&text, &client_req, Some(&dirmgr.store))
                    .await;
                let outcome = state.add_from_download(&text, &client_req, Some(&dirmgr.store));
                dirmgr.notify().await;
                match outcome {
                    Ok(b) => changed |= b,
+22 −24
Original line number Diff line number Diff line
@@ -70,13 +70,12 @@ use tor_circmgr::CircMgr;
use tor_netdir::NetDir;
use tor_netdoc::doc::netstatus::ConsensusFlavor;

use async_trait::async_trait;
use futures::{channel::oneshot, lock::Mutex, task::SpawnExt};
use futures::{channel::oneshot, task::SpawnExt};
use tor_rtcompat::{Runtime, SleepProviderExt};
use tracing::{info, trace, warn};

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::{collections::HashMap, sync::Weak};
use std::{fmt::Debug, time::SystemTime};

@@ -114,10 +113,6 @@ pub struct DirMgr<R: Runtime> {
    /// Handle to our sqlite cache.
    // XXXX I'd like to use an rwlock, but that's not feasible, since
    // rusqlite::Connection isn't Sync.
    // TODO: Does this have to be a futures::Mutex?  I would rather have
    // a rule that we never hold the guard for this mutex across an async
    // suspend point.  But that will be hard to enforce until the
    // `must_not_suspend` lint is in stable.
    store: Mutex<SqliteStore>,
    /// Our latest sufficiently bootstrapped directory, if we have one.
    ///
@@ -261,7 +256,7 @@ impl<R: Runtime> DirMgr<R> {
            {
                let dirmgr = upgrade_weak_ref(weak)?;
                trace!("Trying to take ownership of the directory cache lock");
                if dirmgr.try_upgrade_to_readwrite().await? {
                if dirmgr.try_upgrade_to_readwrite()? {
                    // We now own the lock!  (Maybe we owned it before; the
                    // upgrade_to_readwrite() function is idempotent.)  We can
                    // do our own bootstrapping.
@@ -397,8 +392,11 @@ impl<R: Runtime> DirMgr<R> {
    /// Return true if we got the lock, or if we already had it.
    ///
    /// Return false if another process has the lock
    async fn try_upgrade_to_readwrite(&self) -> Result<bool> {
        self.store.lock().await.upgrade_to_readwrite()
    fn try_upgrade_to_readwrite(&self) -> Result<bool> {
        self.store
            .lock()
            .expect("Directory storage lock poisoned")
            .upgrade_to_readwrite()
    }

    /// Construct a DirMgr from a DirMgrConfig.
@@ -464,11 +462,11 @@ impl<R: Runtime> DirMgr<R> {

    /// Try to load the text of a single document described by `doc` from
    /// storage.
    pub async fn text(&self, doc: &DocId) -> Result<Option<DocumentText>> {
    pub fn text(&self, doc: &DocId) -> Result<Option<DocumentText>> {
        use itertools::Itertools;
        let mut result = HashMap::new();
        let query = (*doc).into();
        self.load_documents_into(&query, &mut result).await?;
        self.load_documents_into(&query, &mut result)?;
        let item = result.into_iter().at_most_one().map_err(|_| {
            Error::CacheCorruption("Found more than one entry in storage for given docid")
        })?;
@@ -488,14 +486,14 @@ impl<R: Runtime> DirMgr<R> {
    ///
    /// If many of the documents have the same type, this can be more
    /// efficient than calling [`text`](Self::text).
    pub async fn texts<T>(&self, docs: T) -> Result<HashMap<DocId, DocumentText>>
    pub fn texts<T>(&self, docs: T) -> Result<HashMap<DocId, DocumentText>>
    where
        T: IntoIterator<Item = DocId>,
    {
        let partitioned = docid::partition_by_type(docs);
        let mut result = HashMap::new();
        for (_, query) in partitioned.into_iter() {
            self.load_documents_into(&query, &mut result).await?
            self.load_documents_into(&query, &mut result)?;
        }
        Ok(result)
    }
@@ -518,13 +516,13 @@ impl<R: Runtime> DirMgr<R> {
    }

    /// Load all the documents for a single DocumentQuery from the store.
    async fn load_documents_into(
    fn load_documents_into(
        &self,
        query: &DocQuery,
        result: &mut HashMap<DocId, DocumentText>,
    ) -> Result<()> {
        use DocQuery::*;
        let store = self.store.lock().await;
        let store = self.store.lock().expect("Directory storage lock poisoned");
        match query {
            LatestConsensus {
                flavor,
@@ -573,12 +571,12 @@ impl<R: Runtime> DirMgr<R> {
    ///
    /// This conversion has to be a function of the dirmgr, since it may
    /// require knowledge about our current state.
    async fn query_into_requests(&self, q: DocQuery) -> Result<Vec<ClientRequest>> {
    fn query_into_requests(&self, q: DocQuery) -> Result<Vec<ClientRequest>> {
        let mut res = Vec::new();
        for q in q.split_for_download() {
            match q {
                DocQuery::LatestConsensus { flavor, .. } => {
                    res.push(self.make_consensus_request(flavor).await?);
                    res.push(self.make_consensus_request(flavor)?);
                }
                DocQuery::AuthCert(ids) => {
                    res.push(ClientRequest::AuthCert(ids.into_iter().collect()));
@@ -596,10 +594,11 @@ impl<R: Runtime> DirMgr<R> {

    /// Construct an appropriate ClientRequest to download a consensus
    /// of the given flavor.
    async fn make_consensus_request(&self, flavor: ConsensusFlavor) -> Result<ClientRequest> {
    fn make_consensus_request(&self, flavor: ConsensusFlavor) -> Result<ClientRequest> {
        #![allow(clippy::unnecessary_wraps)]
        let mut request = tor_dirclient::request::ConsensusRequest::new(flavor);

        let r = self.store.lock().await;
        let r = self.store.lock().expect("Directory storage lock poisoned");
        match r.latest_consensus_meta(flavor) {
            Ok(Some(meta)) => {
                request.set_last_consensus_date(meta.lifetime().valid_after());
@@ -621,12 +620,12 @@ impl<R: Runtime> DirMgr<R> {
    /// Currently, this handles expanding consensus diffs, and nothing
    /// else.  We do it at this stage of our downloading operation
    /// because it requires access to the store.
    async fn expand_response_text(&self, req: &ClientRequest, text: String) -> Result<String> {
    fn expand_response_text(&self, req: &ClientRequest, text: String) -> Result<String> {
        if let ClientRequest::Consensus(req) = req {
            if tor_consdiff::looks_like_diff(&text) {
                if let Some(old_d) = req.old_consensus_digests().next() {
                    let db_val = {
                        let s = self.store.lock().await;
                        let s = self.store.lock().expect("Directory storage lock poisoned");
                        s.consensus_by_sha3_digest_of_signed_part(old_d)?
                    };
                    if let Some((old_consensus, meta)) = db_val {
@@ -672,7 +671,6 @@ enum Readiness {
/// Resetting happens when this state needs to go back to an initial
/// state in order to start over -- either because of an error or
/// because the information it has downloaded is no longer timely.
#[async_trait]
trait DirState: Send {
    /// Return a human-readable description of this state.
    fn describe(&self) -> String;
@@ -707,7 +705,7 @@ trait DirState: Send {
    // TODO: It would be better to not have this function be async,
    // once the `must_not_suspend` lint is stable.
    // TODO: this should take a "DirSource" too.
    async fn add_from_download(
    fn add_from_download(
        &mut self,
        text: &str,
        request: &ClientRequest,
+8 −13
Original line number Diff line number Diff line
@@ -10,13 +10,11 @@
//! [`bootstrap`](crate::bootstrap) module for functions that actually
//! load or download directory information.

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::lock::Mutex;
use rand::Rng;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::sync::Weak;
use std::sync::{Mutex, Weak};
use std::time::{Duration, SystemTime};
use tor_netdir::{MdReceiver, NetDir, PartialNetDir};
use tor_netdoc::doc::netstatus::Lifetime;
@@ -132,7 +130,6 @@ impl<DM: WriteNetDir> GetConsensusState<DM> {
    }
}

#[async_trait]
impl<DM: WriteNetDir> DirState for GetConsensusState<DM> {
    fn describe(&self) -> String {
        if self.next.is_some() {
@@ -185,7 +182,7 @@ impl<DM: WriteNetDir> DirState for GetConsensusState<DM> {
        self.add_consensus_text(true, text.as_str()?)
            .map(|meta| meta.is_some())
    }
    async fn add_from_download(
    fn add_from_download(
        &mut self,
        text: &str,
        _request: &ClientRequest,
@@ -193,7 +190,7 @@ impl<DM: WriteNetDir> DirState for GetConsensusState<DM> {
    ) -> Result<bool> {
        if let Some(meta) = self.add_consensus_text(false, text)? {
            if let Some(store) = storage {
                let mut w = store.lock().await;
                let mut w = store.lock().expect("Directory storage lock poisoned");
                w.store_consensus(meta, ConsensusFlavor::Microdesc, true, text)?;
            }
            Ok(true)
@@ -301,7 +298,6 @@ struct GetCertsState<DM: WriteNetDir> {
    writedir: Weak<DM>,
}

#[async_trait]
impl<DM: WriteNetDir> DirState for GetCertsState<DM> {
    fn describe(&self) -> String {
        let total = self.certs.len() + self.missing_certs.len();
@@ -348,7 +344,7 @@ impl<DM: WriteNetDir> DirState for GetCertsState<DM> {
        }
        Ok(changed)
    }
    async fn add_from_download(
    fn add_from_download(
        &mut self,
        text: &str,
        request: &ClientRequest,
@@ -397,7 +393,7 @@ impl<DM: WriteNetDir> DirState for GetCertsState<DM> {
                .iter()
                .map(|(cert, s)| (AuthCertMeta::from_authcert(cert), *s))
                .collect();
            let mut w = store.lock().await;
            let mut w = store.lock().expect("Directory storage lock poisoned");
            w.store_authcerts(&v[..])?;
        }

@@ -538,7 +534,6 @@ impl<DM: WriteNetDir> GetMicrodescsState<DM> {
    }
}

#[async_trait]
impl<DM: WriteNetDir> DirState for GetMicrodescsState<DM> {
    fn describe(&self) -> String {
        format!(
@@ -591,7 +586,7 @@ impl<DM: WriteNetDir> DirState for GetMicrodescsState<DM> {

        Ok(changed)
    }
    async fn add_from_download(
    fn add_from_download(
        &mut self,
        text: &str,
        request: &ClientRequest,
@@ -621,7 +616,7 @@ impl<DM: WriteNetDir> DirState for GetMicrodescsState<DM> {

        let mark_listed = self.meta.lifetime().valid_after();
        if let Some(store) = storage {
            let mut s = store.lock().await;
            let mut s = store.lock().expect("Directory storage lock poisoned");
            if !self.newly_listed.is_empty() {
                s.update_microdescs_listed(self.newly_listed.iter(), mark_listed)?;
                self.newly_listed.clear();
@@ -636,7 +631,7 @@ impl<DM: WriteNetDir> DirState for GetMicrodescsState<DM> {
        if self.register_microdescs(new_mds.into_iter().map(|(_, md)| md)) {
            // oh hey, this is no longer pending.
            if let Some(store) = storage {
                let mut store = store.lock().await;
                let mut store = store.lock().expect("Directory storage lock poisoned");
                info!("Marked consensus usable.");
                store.mark_consensus_usable(&self.meta)?;
                // DOCDOC: explain why we're doing this here.