diff --git a/src/relay/mod.rs b/src/relay/mod.rs index 90a90c3..4af0314 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -147,7 +147,7 @@ impl Default for ServerStatus { } pub struct RelayState { - config: Config, + config: RwLock , handlebars: Arc >, // Key: Server ID @@ -166,7 +166,7 @@ impl From <&ConfigFile> for RelayState { let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); Self { - config: Config::from (config_file), + config: Config::from (config_file).into (), handlebars: Arc::new (load_templates (&PathBuf::new ()).unwrap ()), request_rendezvous: Default::default (), server_status: Default::default (), @@ -212,16 +212,20 @@ async fn handle_http_listen ( { let trip_error = error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey"); - let expected_tripcode = match state.config.server_tripcodes.get (&watcher_code) { - None => { - error! ("Denied http_listen for non-existent server name {}", watcher_code); - return trip_error; - }, - Some (x) => x, + let expected_tripcode = { + let config = state.config.read ().await; + + match config.server_tripcodes.get (&watcher_code) { + None => { + error! ("Denied http_listen for non-existent server name {}", watcher_code); + return trip_error; + }, + Some (x) => (*x).clone (), + } }; let actual_tripcode = blake3::hash (api_key); - if expected_tripcode != &actual_tripcode { + if expected_tripcode != actual_tripcode { error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); return trip_error; } @@ -382,8 +386,11 @@ async fn handle_http_request ( ) -> Response { - if ! state.config.server_tripcodes.contains_key (&watcher_code) { - return error_reply (StatusCode::NOT_FOUND, "Unknown server"); + { + let config = state.config.read ().await; + if ! config.server_tripcodes.contains_key (&watcher_code) { + return error_reply (StatusCode::NOT_FOUND, "Unknown server"); + } } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { @@ -669,16 +676,16 @@ pub async fn run_relay ( { let mut tripcode_set = HashSet::new (); - - for (_, v) in state.config.server_tripcodes.iter () { + let config = state.config.read ().await; + for (_, v) in config.server_tripcodes.iter () { if ! tripcode_set.insert (v) { panic! ("Two servers have the same tripcode. That is not allowed."); } } + + info! ("Loaded {} server tripcodes", config.server_tripcodes.len ()); } - info! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ()); - let make_svc = make_service_fn (|_conn| { let state = state.clone ();