Skip to content
Open
279 changes: 270 additions & 9 deletions crates/defguard_common/src/db/models/device.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt, net::IpAddr};
use std::{collections::HashSet, fmt, net::IpAddr};

use base64::{Engine, prelude::BASE64_STANDARD};
use chrono::{NaiveDate, NaiveDateTime, Timelike, Utc};
Expand Down Expand Up @@ -843,6 +843,7 @@ impl Device<Id> {
&self,
transaction: &mut PgConnection,
network: &WireguardNetwork<Id>,
used_ips: &HashSet<IpAddr>,
reserved_ips: Option<&[IpAddr]>,
current_ips: Option<&[IpAddr]>,
) -> Result<WireguardNetworkDevice, ModelError> {
Expand Down Expand Up @@ -872,15 +873,16 @@ impl Device<Id> {
}
let mut picked = None;
for ip in address {
if network
.can_assign_ips(transaction, &[ip], Some(self.id))
.await
.is_ok()
&& !reserved.contains(&ip)
{
picked = Some(ip);
break;
if ip == address.network() || ip == address.broadcast() || ip == address.ip() {
continue;
}

if used_ips.contains(&ip) || reserved.contains(&ip) {
continue;
}

picked = Some(ip);
break;
}

// Return error if no address can be assigned
Expand Down Expand Up @@ -1129,6 +1131,265 @@ mod test {
assert!(device.is_err());
}

/// Test that assign_next_network_ip correctly preserves or reassigns device IPs
/// when a network's address list changes.
/// Initial network: 10.0.0.0/8, 123.10.0.0/16, 123.123.123.0/24
/// Device IPs: 10.0.0.234, 123.10.33.44, 123.123.123.52
/// New network: 10.0.0.0/16, 123.12.0.0/16, 123.123.0.0/16
/// Expected:
/// - 10.0.0.234 KEPT (still within 10.0.0.0/16)
/// - 123.10.33.44 CHANGED (not within 123.12.0.0/16)
/// - 123.123.123.52 KEPT (still within 123.123.0.0/16)
#[sqlx::test]
async fn test_assign_next_network_ip_preserves_matching_subnets(
_: PgPoolOptions,
options: PgConnectOptions,
) {
let pool = setup_pool(options).await;

let mut network = WireguardNetwork::default();
network
.try_set_address("10.0.0.1/8,123.10.0.1/16,123.123.123.1/24")
.unwrap();
let network = network.save(&pool).await.unwrap();

let user = User::new(
"testuser",
Some("password"),
"Tester",
"Test",
"test@test.com",
None,
)
.save(&pool)
.await
.unwrap();

let device = Device::new(
"dev1".into(),
"key1".into(),
user.id,
DeviceType::User,
None,
true,
)
.save(&pool)
.await
.unwrap();

let ip = IpAddr::from_str("10.0.0.234").unwrap();
let ip2 = IpAddr::from_str("123.10.33.44").unwrap();
let ip3 = IpAddr::from_str("123.123.123.52").unwrap();
let initial_ips = vec![ip, ip2, ip3];

let mut conn = pool.acquire().await.unwrap();
WireguardNetworkDevice::new(network.id, device.id, initial_ips.clone())
.insert(&mut *conn)
.await
.unwrap();

let mut updated_network = network.clone();
updated_network.address = vec![
"10.0.0.0/16".parse::<IpNetwork>().unwrap(),
"123.12.0.0/16".parse::<IpNetwork>().unwrap(),
"123.123.0.0/16".parse::<IpNetwork>().unwrap(),
];
updated_network.save(&mut *conn).await.unwrap();

let used_ips = updated_network
.all_used_ips_for_network(&mut conn)
.await
.unwrap();

let result = device
.assign_next_network_ip(
&mut conn,
&updated_network,
&used_ips,
None,
Some(&initial_ips),
)
.await
.unwrap();

let new_ips = &result.wireguard_ips;
assert_eq!(new_ips.len(), 3, "should have one IP per subnet");

assert!(
new_ips.contains(&ip),
"10.0.0.234 should be kept – it is still within 10.0.0.0/16; got {new_ips:?}"
);

assert!(
!new_ips.contains(&ip2),
"123.10.33.44 should be reassigned – not within 123.12.0.0/16; got {new_ips:?}"
);
let network: IpNetwork = "123.12.0.0/16".parse().unwrap();
assert!(
new_ips.iter().any(|ip| network.contains(*ip)),
"a new IP within 123.12.0.0/16 should be assigned; got {new_ips:?}"
);

assert!(
new_ips.contains(&ip3),
"123.123.123.52 should be kept – it is still within 123.123.0.0/16; got {new_ips:?}"
);
}
/// Initial: 10.0.0.0/8 | 10.1.0.5
/// Modified: 10.0.0.0/16 | 10.1.0.5 should be replaced with a 10.0.x.x address
#[sqlx::test]
async fn test_assign_next_network_ip_subnet_narrowed(
_: PgPoolOptions,
options: PgConnectOptions,
) {
let pool = setup_pool(options).await;

let mut network = WireguardNetwork::default();
network.try_set_address("10.0.0.1/8").unwrap();
let network = network.save(&pool).await.unwrap();

let user = User::new(
"testuser",
Some("password"),
"Tester",
"Test",
"test@test.com",
None,
)
.save(&pool)
.await
.unwrap();

let device = Device::new(
"dev1".into(),
"key1".into(),
user.id,
DeviceType::User,
None,
true,
)
.save(&pool)
.await
.unwrap();

let ip = IpAddr::from_str("10.1.0.5").unwrap();
let initial_ips = vec![ip];

let mut conn = pool.acquire().await.unwrap();
WireguardNetworkDevice::new(network.id, device.id, initial_ips.clone())
.insert(&mut *conn)
.await
.unwrap();

let mut updated_network = network.clone();
updated_network.address = vec!["10.0.0.0/16".parse::<IpNetwork>().unwrap()];
updated_network.save(&mut *conn).await.unwrap();

let used_ips = updated_network
.all_used_ips_for_network(&mut conn)
.await
.unwrap();

let result = device
.assign_next_network_ip(
&mut conn,
&updated_network,
&used_ips,
None,
Some(&initial_ips),
)
.await
.unwrap();

let new_ips = &result.wireguard_ips;
assert_eq!(new_ips.len(), 1, "should have one IP per subnet");

assert!(
!new_ips.contains(&ip),
"10.1.0.5 should be reassigned – outside narrowed 10.0.0.0/16; got {new_ips:?}"
);
let narrowed_net: IpNetwork = "10.0.0.0/16".parse().unwrap();
assert!(
new_ips.iter().all(|ip| narrowed_net.contains(*ip)),
"new IP must be within 10.0.0.0/16; got {new_ips:?}"
);
}

/// Initial: 123.123.123.0/24 | 123.123.123.254
/// Modified: 123.123.0.0/16 | 123.123.123.254 still fits
#[sqlx::test]
async fn test_assign_next_network_ip_still_valid_after_widening(
_: PgPoolOptions,
options: PgConnectOptions,
) {
let pool = setup_pool(options).await;

let mut network = WireguardNetwork::default();
network.try_set_address("123.123.123.1/24").unwrap();
let network = network.save(&pool).await.unwrap();

let user = User::new(
"testuser",
Some("password"),
"Tester",
"Test",
"test@test.com",
None,
)
.save(&pool)
.await
.unwrap();

let device = Device::new(
"dev1".into(),
"key1".into(),
user.id,
DeviceType::User,
None,
true,
)
.save(&pool)
.await
.unwrap();

let ip = IpAddr::from_str("123.123.123.254").unwrap();
let initial_ips = vec![ip];

let mut conn = pool.acquire().await.unwrap();
WireguardNetworkDevice::new(network.id, device.id, initial_ips.clone())
.insert(&mut *conn)
.await
.unwrap();

let mut updated_network = network.clone();
updated_network.address = vec!["123.123.0.0/16".parse::<IpNetwork>().unwrap()];
updated_network.save(&mut *conn).await.unwrap();

let used_ips = updated_network
.all_used_ips_for_network(&mut conn)
.await
.unwrap();

let result = device
.assign_next_network_ip(
&mut conn,
&updated_network,
&used_ips,
None,
Some(&initial_ips),
)
.await
.unwrap();

let new_ips = &result.wireguard_ips;
assert_eq!(new_ips.len(), 1, "should have one IP per subnet");

assert!(
new_ips.contains(&ip),
"123.123.123.254 should be preserved – still within widened 123.123.0.0/16; got {new_ips:?}"
);
}

#[test]
fn test_pubkey_validation() {
let invalid_test_key = "invalid_key";
Expand Down
27 changes: 21 additions & 6 deletions crates/defguard_common/src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
fmt::{self, Display},
iter::zip,
net::{IpAddr, Ipv4Addr},
Expand Down Expand Up @@ -438,10 +438,11 @@ impl WireguardNetwork<Id> {
"Assigning IPs in network {} for all existing devices ",
self
);
let used_ips = self.all_used_ips_for_network(&mut *transaction).await?;
let devices = self.get_allowed_devices(&mut *transaction).await?;
for device in devices {
device
.assign_next_network_ip(&mut *transaction, self, None, None)
.assign_next_network_ip(&mut *transaction, self, &used_ips, None, None)
.await?;
}
Ok(())
Expand All @@ -457,9 +458,11 @@ impl WireguardNetwork<Id> {
info!("Assigning IP in network {self} for {device}");
let allowed_devices = self.get_allowed_devices(&mut *transaction).await?;
let allowed_device_ids: Vec<i64> = allowed_devices.iter().map(|dev| dev.id).collect();
let used_ips = self.all_used_ips_for_network(&mut *transaction).await?;

if allowed_device_ids.contains(&device.id) {
let wireguard_network_device = device
.assign_next_network_ip(&mut *transaction, self, reserved_ips, None)
.assign_next_network_ip(&mut *transaction, self, &used_ips, reserved_ips, None)
.await?;
Ok(wireguard_network_device)
} else {
Expand Down Expand Up @@ -533,9 +536,7 @@ impl WireguardNetwork<Id> {
// split into separate stats for each device
let mut device_stats: HashMap<Id, Vec<WireguardDeviceTransferRow>> =
stats.into_iter().fold(HashMap::new(), |mut acc, item| {
acc.entry(item.device_id)
.or_insert_with(Vec::new)
.push(item);
acc.entry(item.device_id).or_default().push(item);
acc
});

Expand Down Expand Up @@ -1354,6 +1355,20 @@ impl WireguardNetwork<Id> {
.fetch_all(executor)
.await
}

/// Obtain all used ips for network
pub async fn all_used_ips_for_network(
&self,
transaction: &mut PgConnection,
) -> Result<HashSet<IpAddr>, SqlxError> {
let all_devices =
WireguardNetworkDevice::all_for_network(&mut *transaction, self.id).await?;
let used_ips: HashSet<IpAddr> = all_devices
.into_iter()
.flat_map(|device| device.wireguard_ips)
.collect();
Ok(used_ips)
}
}

// [`IpNetwork`] does not implement [`Default`]
Expand Down
11 changes: 0 additions & 11 deletions crates/defguard_core/src/handlers/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,6 @@ pub(crate) async fn modify_network(
let before = network.clone();
let new_addresses = data.parse_addresses()?;

// Block network address changes if any device is assigned to the network
if before.address != new_addresses
&& WireguardNetworkDevice::has_devices_in_network(&appstate.pool, network_id).await?
{
return Err(WebError::BadRequest(
"Cannot change network address while devices are assigned to this network. \
Remove all devices first."
.into(),
));
}

network.address = new_addresses;
network.allowed_ips = data.parse_allowed_ips();
network.name = data.name;
Expand Down
Loading
Loading