Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion payjoin/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ pub use url::{ParseError, Url};
pub(crate) mod error_codes;

pub(crate) mod output_substitution;
#[cfg(feature = "v1")]
pub use output_substitution::OutputSubstitution;

#[cfg(feature = "v2")]
Expand Down
6 changes: 6 additions & 0 deletions payjoin/src/core/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,10 @@ pub struct OutputSubstitutionError(InternalOutputSubstitutionError);

#[derive(Debug, PartialEq, Eq)]
pub(crate) enum InternalOutputSubstitutionError {
#[cfg(feature = "v1")]
/// Output substitution is disabled and output value was decreased
DecreasedValueWhenDisabled,
#[cfg(feature = "v1")]
/// Output substitution is disabled and script pubkey was changed
ScriptPubKeyChangedWhenDisabled,
/// Current output substitution implementation doesn't support reducing the number of outputs
Expand All @@ -292,7 +294,9 @@ pub(crate) enum InternalOutputSubstitutionError {
impl fmt::Display for OutputSubstitutionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.0 {
#[cfg(feature = "v1")]
InternalOutputSubstitutionError::DecreasedValueWhenDisabled => write!(f, "Decreasing the receiver output value is not allowed when output substitution is disabled"),
#[cfg(feature = "v1")]
InternalOutputSubstitutionError::ScriptPubKeyChangedWhenDisabled => write!(f, "Changing the receiver output script pubkey is not allowed when output substitution is disabled"),
InternalOutputSubstitutionError::NotEnoughOutputs => write!(
f,
Expand All @@ -311,7 +315,9 @@ impl From<InternalOutputSubstitutionError> for OutputSubstitutionError {
impl std::error::Error for OutputSubstitutionError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.0 {
#[cfg(feature = "v1")]
InternalOutputSubstitutionError::DecreasedValueWhenDisabled => None,
#[cfg(feature = "v1")]
InternalOutputSubstitutionError::ScriptPubKeyChangedWhenDisabled => None,
InternalOutputSubstitutionError::NotEnoughOutputs => None,
InternalOutputSubstitutionError::InvalidDrainScript => None,
Expand Down
168 changes: 168 additions & 0 deletions payjoin/src/core/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::collections::BTreeMap;
use std::str::FromStr;

use bitcoin::hashes::sha256d::Hash;
use bitcoin::secp256k1::rand::seq::SliceRandom;
use bitcoin::secp256k1::rand::{self, Rng};
use bitcoin::{
psbt, AddressType, FeeRate, OutPoint, Psbt, Script, ScriptBuf, Sequence, Transaction, TxIn,
TxOut, Weight,
Expand All @@ -29,6 +31,7 @@ pub use crate::psbt::PsbtInputError;
use crate::psbt::{
InputWeightError, InternalInputPair, InternalPsbtInputError, PrevTxOutError, PsbtExt,
};
use crate::receive::error::InternalOutputSubstitutionError;
use crate::{ImplementationError, Version};

mod error;
Expand Down Expand Up @@ -479,6 +482,147 @@ impl Original {
}
}

/// Shuffles `new` vector, then interleaves its elements with those from `original`,
/// maintaining the relative order in `original` but randomly inserting elements from `new`.
///
/// The combined result replaces the contents of `original`.
fn interleave_shuffle<T: Clone, R: rand::Rng>(original: &mut Vec<T>, new: &mut [T], rng: &mut R) {
// Shuffle the substitute_outputs
new.shuffle(rng);
// Create a new vector to store the combined result
let mut combined = Vec::with_capacity(original.len() + new.len());
// Initialize indices
let mut original_index = 0;
let mut new_index = 0;
// Interleave elements
while original_index < original.len() || new_index < new.len() {
if original_index < original.len() && (new_index >= new.len() || rng.gen_bool(0.5)) {
combined.push(original[original_index].clone());
original_index += 1;
} else {
combined.push(new[new_index].clone());
new_index += 1;
}
}
*original = combined;
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WantsOutputs {
original_psbt: Psbt,
payjoin_psbt: Psbt,
params: Params,
change_vout: usize,
owned_vouts: Vec<usize>,
}

impl WantsOutputs {
pub(crate) fn from_original(original: Original, owned_vouts: Vec<usize>) -> Self {
Self {
original_psbt: original.psbt.clone(),
payjoin_psbt: original.psbt,
params: original.params,
change_vout: owned_vouts[0],
owned_vouts,
}
}

/// Returns whether the receiver is allowed to substitute original outputs or not.
pub(crate) fn output_substitution(&self) -> crate::OutputSubstitution {
self.params.output_substitution
}
/// Substitute the receiver output script with the provided script.
pub(crate) fn substitute_receiver_script(
&self,
output_script: &Script,
) -> Result<Self, OutputSubstitutionError> {
let output_value = self.original_psbt.unsigned_tx.output[self.change_vout].value;
let outputs = [TxOut { value: output_value, script_pubkey: output_script.into() }];
self.replace_receiver_outputs(outputs, output_script)
}

/// Replaces **all** receiver outputs with the one or more provided `replacement_outputs`, and
/// sets up the passed `drain_script` as the receiver-owned output which might have its value
/// adjusted based on the modifications the receiver makes in the subsequent typestates.
///
/// The sender's outputs are not touched. Existing receiver outputs will be replaced with the
/// outputs in the `replacement_outputs` argument. The number of replacement outputs should
/// match or exceed the number of receiver outputs in the original proposal PSBT.
///
/// The drain script is the receiver script which will have its value adjusted based on the
/// modifications the receiver makes on the transaction in the subsequent typestates. For
/// example, if the receiver adds their own input, then the drain script output will have its
/// value increased by the same amount. Or if an output needs to have its value reduced to
/// account for fees, the value of the output for this script will be reduced.
pub(crate) fn replace_receiver_outputs(
&self,
replacement_outputs: impl IntoIterator<Item = TxOut>,
drain_script: &Script,
) -> Result<Self, OutputSubstitutionError> {
let mut payjoin_psbt = self.original_psbt.clone();
let mut outputs = vec![];
let mut replacement_outputs: Vec<TxOut> = replacement_outputs.into_iter().collect();
let mut rng = rand::thread_rng();
// Substitute the existing receiver outputs, keeping the sender/receiver output ordering
for (i, original_output) in self.original_psbt.unsigned_tx.output.iter().enumerate() {
if self.owned_vouts.contains(&i) {
// Receiver output: substitute in-place a provided replacement output
if replacement_outputs.is_empty() {
return Err(InternalOutputSubstitutionError::NotEnoughOutputs.into());
}
match replacement_outputs
.iter()
.position(|txo| txo.script_pubkey == original_output.script_pubkey)
{
// Select an output with the same address if one was provided
Some(pos) => {
let txo = replacement_outputs.swap_remove(pos);
#[cfg(feature = "v1")]
if self.output_substitution() == crate::OutputSubstitution::Disabled
&& txo.value < original_output.value
{
return Err(
InternalOutputSubstitutionError::DecreasedValueWhenDisabled.into(),
);
}
outputs.push(txo);
}
// Otherwise randomly select one of the replacement outputs
None => {
#[cfg(feature = "v1")]
if self.output_substitution() == crate::OutputSubstitution::Disabled {
return Err(
InternalOutputSubstitutionError::ScriptPubKeyChangedWhenDisabled
.into(),
);
}
let index = rng.gen_range(0..replacement_outputs.len());
let txo = replacement_outputs.swap_remove(index);
outputs.push(txo);
}
}
} else {
// Sender output: leave it as is
outputs.push(original_output.clone());
}
}
// Insert all remaining outputs at random indices for privacy
interleave_shuffle(&mut outputs, &mut replacement_outputs, &mut rng);
// Identify the receiver output that will be used for change and fees
let change_vout = outputs.iter().position(|txo| txo.script_pubkey == *drain_script);
// Update the payjoin PSBT outputs
payjoin_psbt.outputs = vec![Default::default(); outputs.len()];
payjoin_psbt.unsigned_tx.output = outputs;
Ok(Self {
original_psbt: self.original_psbt.clone(),
payjoin_psbt,
params: self.params.clone(),
change_vout: change_vout.ok_or(InternalOutputSubstitutionError::InvalidDrainScript)?,
owned_vouts: self.owned_vouts.clone(),
})
}
}

#[cfg(test)]
mod tests {
use bitcoin::absolute::{LockTime, Time};
Expand All @@ -490,6 +634,8 @@ mod tests {
witness, Amount, PubkeyHash, ScriptBuf, ScriptHash, Txid, WScriptHash, XOnlyPublicKey,
};
use payjoin_test_utils::{DUMMY20, DUMMY32};
use rand::rngs::StdRng;
use rand::SeedableRng;

use super::*;
use crate::psbt::InternalPsbtInputError::InvalidScriptPubKey;
Expand Down Expand Up @@ -781,4 +927,26 @@ mod tests {
PsbtInputError::from(InvalidScriptPubKey(AddressType::P2tr))
)
}

#[test]
fn test_interleave_shuffle() {
let mut original1 = vec![1, 2, 3];
let mut original2 = original1.clone();
let mut original3 = original1.clone();
let mut new1 = vec![4, 5, 6];
let mut new2 = new1.clone();
let mut new3 = new1.clone();
let mut rng1 = StdRng::seed_from_u64(123);
let mut rng2 = StdRng::seed_from_u64(234);
let mut rng3 = StdRng::seed_from_u64(345);
// Operate on the same data multiple times with different RNG seeds.
interleave_shuffle(&mut original1, &mut new1, &mut rng1);
interleave_shuffle(&mut original2, &mut new2, &mut rng2);
interleave_shuffle(&mut original3, &mut new3, &mut rng3);
// The result should be different for each seed
// and the relative ordering from `original` always preserved/
assert_eq!(original1, vec![1, 6, 2, 5, 4, 3]);
assert_eq!(original2, vec![1, 5, 4, 2, 6, 3]);
assert_eq!(original3, vec![4, 5, 1, 2, 6, 3]);
}
}
Loading
Loading