♻️ Wrap relay config in a RwLock

main
_ 2020-11-25 02:30:57 +00:00
parent 7aafbba4d9
commit b40233cc62
1 changed files with 22 additions and 15 deletions

View File

@ -147,7 +147,7 @@ impl Default for ServerStatus {
} }
pub struct RelayState { pub struct RelayState {
config: Config, config: RwLock <Config>,
handlebars: Arc <Handlebars <'static>>, handlebars: Arc <Handlebars <'static>>,
// Key: Server ID // Key: Server ID
@ -166,7 +166,7 @@ impl From <&ConfigFile> for RelayState {
let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false);
Self { Self {
config: Config::from (config_file), config: Config::from (config_file).into (),
handlebars: Arc::new (load_templates (&PathBuf::new ()).unwrap ()), handlebars: Arc::new (load_templates (&PathBuf::new ()).unwrap ()),
request_rendezvous: Default::default (), request_rendezvous: Default::default (),
server_status: Default::default (), server_status: Default::default (),
@ -212,16 +212,20 @@ async fn handle_http_listen (
{ {
let trip_error = error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey"); let trip_error = error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey");
let expected_tripcode = match state.config.server_tripcodes.get (&watcher_code) { let expected_tripcode = {
None => { let config = state.config.read ().await;
error! ("Denied http_listen for non-existent server name {}", watcher_code);
return trip_error; match config.server_tripcodes.get (&watcher_code) {
}, None => {
Some (x) => x, error! ("Denied http_listen for non-existent server name {}", watcher_code);
return trip_error;
},
Some (x) => (*x).clone (),
}
}; };
let actual_tripcode = blake3::hash (api_key); let actual_tripcode = blake3::hash (api_key);
if expected_tripcode != &actual_tripcode { if expected_tripcode != actual_tripcode {
error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ()));
return trip_error; return trip_error;
} }
@ -382,8 +386,11 @@ async fn handle_http_request (
) )
-> Response <Body> -> Response <Body>
{ {
if ! state.config.server_tripcodes.contains_key (&watcher_code) { {
return error_reply (StatusCode::NOT_FOUND, "Unknown server"); let config = state.config.read ().await;
if ! config.server_tripcodes.contains_key (&watcher_code) {
return error_reply (StatusCode::NOT_FOUND, "Unknown server");
}
} }
let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) {
@ -669,16 +676,16 @@ pub async fn run_relay (
{ {
let mut tripcode_set = HashSet::new (); let mut tripcode_set = HashSet::new ();
let config = state.config.read ().await;
for (_, v) in state.config.server_tripcodes.iter () { for (_, v) in config.server_tripcodes.iter () {
if ! tripcode_set.insert (v) { if ! tripcode_set.insert (v) {
panic! ("Two servers have the same tripcode. That is not allowed."); panic! ("Two servers have the same tripcode. That is not allowed.");
} }
} }
info! ("Loaded {} server tripcodes", config.server_tripcodes.len ());
} }
info! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ());
let make_svc = make_service_fn (|_conn| { let make_svc = make_service_fn (|_conn| {
let state = state.clone (); let state = state.clone ();