Skip to content

Commit 971c8d3

Browse files
committed
Payjoin-cli should cache ohttp-keys for re-use
This pr adds the functionality of ohttp-keys catching to payjoin-cli , ohttp-keys should not be fetched each time and a cached key should be use. Keys expire in 6 months .
1 parent baa63f6 commit 971c8d3

File tree

2 files changed

+91
-7
lines changed

2 files changed

+91
-7
lines changed

payjoin-cli/src/app/v2/ohttp.rs

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
use std::fs;
2+
use std::path::PathBuf;
13
use std::sync::{Arc, Mutex};
4+
use std::time::{Duration, SystemTime};
25

36
use anyhow::{anyhow, Result};
7+
use serde::{Deserialize, Serialize};
48

59
use super::Config;
610

11+
// 6 months
12+
const CACHE_DURATION: Duration = Duration::from_secs(6 * 30 * 24 * 60 * 60);
13+
714
#[derive(Debug, Clone)]
815
pub struct RelayManager {
916
selected_relay: Option<url::Url>,
@@ -38,12 +45,12 @@ pub(crate) async fn unwrap_ohttp_keys_or_else_fetch(
3845
ohttp_keys,
3946
relay_url: config.v2()?.ohttp_relays[0].clone(),
4047
});
41-
} else {
42-
println!("Bootstrapping private network transport over Oblivious HTTP");
43-
let fetched_keys = fetch_ohttp_keys(config, directory, relay_manager).await?;
44-
45-
Ok(fetched_keys)
4648
}
49+
50+
println!("Bootstrapping private network transport over Oblivious HTTP");
51+
let fetched_keys = fetch_ohttp_keys(config, directory, relay_manager).await?;
52+
53+
Ok(fetched_keys)
4754
}
4855

4956
async fn fetch_ohttp_keys(
@@ -77,6 +84,17 @@ async fn fetch_ohttp_keys(
7784
.expect("Lock should not be poisoned")
7885
.set_selected_relay(selected_relay.clone());
7986

87+
// try cache for this selected relay first
88+
if let Some(cached) = read_cached_ohttp_keys(&selected_relay) {
89+
println!("using Cached keys for relay: {}", selected_relay);
90+
if !is_expired(&cached) && cached.relay_url == selected_relay {
91+
return Ok(ValidatedOhttpKeys {
92+
ohttp_keys: cached.keys,
93+
relay_url: cached.relay_url,
94+
});
95+
}
96+
}
97+
8098
let ohttp_keys = {
8199
#[cfg(feature = "_manual-tls")]
82100
{
@@ -101,8 +119,17 @@ async fn fetch_ohttp_keys(
101119
};
102120

103121
match ohttp_keys {
104-
Ok(keys) =>
105-
return Ok(ValidatedOhttpKeys { ohttp_keys: keys, relay_url: selected_relay }),
122+
Ok(keys) => {
123+
// Cache the keys if they are not already cached for this relay
124+
if read_cached_ohttp_keys(&selected_relay).is_none() {
125+
if let Err(e) = cache_ohttp_keys(&keys, &selected_relay) {
126+
tracing::debug!(
127+
"Failed to cache OHTTP keys for relay {selected_relay}: {e:?}"
128+
);
129+
}
130+
}
131+
return Ok(ValidatedOhttpKeys { ohttp_keys: keys, relay_url: selected_relay });
132+
}
106133
Err(payjoin::io::Error::UnexpectedStatusCode(e)) => {
107134
return Err(payjoin::io::Error::UnexpectedStatusCode(e).into());
108135
}
@@ -116,3 +143,49 @@ async fn fetch_ohttp_keys(
116143
}
117144
}
118145
}
146+
147+
#[derive(Serialize, Deserialize, Debug)]
148+
struct CachedOhttpKeys {
149+
keys: payjoin::OhttpKeys,
150+
relay_url: payjoin::Url,
151+
fetched_at: u64,
152+
}
153+
154+
fn get_cache_file(relay_url: &payjoin::Url) -> PathBuf {
155+
dirs::cache_dir()
156+
.unwrap()
157+
.join("payjoin-cli")
158+
.join(relay_url.host_str().unwrap())
159+
.join("ohttp-keys.json")
160+
}
161+
162+
fn read_cached_ohttp_keys(relay_url: &payjoin::Url) -> Option<CachedOhttpKeys> {
163+
let cache_file = get_cache_file(relay_url);
164+
if !cache_file.exists() {
165+
return None;
166+
}
167+
let data = fs::read_to_string(cache_file).ok().unwrap();
168+
serde_json::from_str(&data).ok()
169+
}
170+
171+
fn cache_ohttp_keys(ohttp_keys: &payjoin::OhttpKeys, relay_url: &payjoin::Url) -> Result<()> {
172+
let cached = CachedOhttpKeys {
173+
keys: ohttp_keys.clone(),
174+
relay_url: relay_url.clone(),
175+
fetched_at: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(),
176+
};
177+
178+
let serialized = serde_json::to_string(&cached)?;
179+
let path = get_cache_file(relay_url);
180+
fs::create_dir_all(path.parent().unwrap())?;
181+
fs::write(path, serialized)?;
182+
Ok(())
183+
}
184+
185+
fn is_expired(cached_keys: &CachedOhttpKeys) -> bool {
186+
let now = SystemTime::now()
187+
.duration_since(SystemTime::UNIX_EPOCH)
188+
.unwrap_or(Duration::ZERO)
189+
.as_secs();
190+
now.saturating_sub(cached_keys.fetched_at) > CACHE_DURATION.as_secs()
191+
}

payjoin-cli/tests/e2e.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ mod e2e {
6464
res
6565
}
6666

67+
fn clear_payjoin_cache() -> std::io::Result<()> {
68+
let cache_dir = dirs::cache_dir().unwrap().join("payjoin-cli");
69+
70+
if cache_dir.exists() {
71+
std::fs::remove_dir_all(cache_dir)?;
72+
}
73+
Ok(())
74+
}
75+
6776
#[cfg(feature = "v1")]
6877
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
6978
async fn send_receive_payjoin_v1() -> Result<(), BoxError> {
@@ -203,6 +212,8 @@ mod e2e {
203212
use tempfile::TempDir;
204213
use tokio::process::Child;
205214

215+
clear_payjoin_cache()?;
216+
206217
type Result<T> = std::result::Result<T, BoxError>;
207218

208219
init_tracing();

0 commit comments

Comments
 (0)