diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 92cf9d494..211232a74 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -17,6 +17,7 @@ use defguard_common::{ settings::{initialize_current_settings, update_current_settings}, }, }, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, types::proxy::ProxyControlMessage, }; @@ -30,7 +31,7 @@ use defguard_core::{ }, events::{ApiEvent, BidiStreamEvent}, gateway_config, - grpc::{GatewayEvent, WorkerState, run_grpc_server}, + grpc::{WorkerState, run_grpc_server}, init_dev_env, init_vpn_location, run_web_server, setup_logs::CoreSetupLogLayer, utility_thread::run_utility_thread, @@ -210,7 +211,7 @@ async fn main() -> Result<(), anyhow::Error> { // setup communication channels for services let (webhook_tx, webhook_rx) = unbounded_channel::(); // RX is discarded here since it can be derived from TX later on - let (gateway_tx, _gateway_rx) = broadcast::channel::(256); + let (gateway_tx, _gateway_rx) = broadcast::channel::(256); let (event_logger_tx, event_logger_rx) = unbounded_channel::(); let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); diff --git a/crates/defguard_common/src/gateway_event.rs b/crates/defguard_common/src/gateway_event.rs new file mode 100644 index 000000000..426504d3a --- /dev/null +++ b/crates/defguard_common/src/gateway_event.rs @@ -0,0 +1,66 @@ +//! Gateway command types and helpers for communicating with the gateway manager service. +//! +//! [`GatewayCommand`] is the primary type sent from core to the gateway manager over +//! an in-process broadcast channel. The gateway manager converts each command to the +//! appropriate protobuf wire message before forwarding it to the gateway daemon. + +use tokio::sync::broadcast::Sender; +use tracing::{debug, error}; + +use crate::{ + db::{ + Id, + models::{ + Device, WireguardNetwork, + device::{DeviceInfo, DeviceNetworkInfo}, + }, + }, + gateway_types::{FirewallConfig, WireguardPeer}, +}; + +/// A command sent from core to the gateway manager service. +/// +/// Each variant instructs the gateway daemon to update its WireGuard state or +/// firewall configuration. Native Rust types are used throughout; conversion to +/// protobuf wire types happens at the serialization boundary in the gateway manager. +#[derive(Clone, Debug)] +pub enum GatewayCommand { + NetworkCreated(Id, WireguardNetwork), + NetworkModified( + Id, + WireguardNetwork, + Vec, + Option, + ), + NetworkDeleted(Id, String), + DeviceCreated(DeviceInfo), + DeviceModified(DeviceInfo), + DeviceDeleted(DeviceInfo), + FirewallConfigChanged(Id, FirewallConfig), + FirewallDisabled(Id), + VpnSessionAuthorized(Id, Device, DeviceNetworkInfo), + VpnSessionDeauthorized(Id, Device), +} + +/// Sends a [`GatewayCommand`] to the gateway manager service. +/// +/// In API handler context prefer `AppState::send_gateway_command`. +pub fn send_gateway_command(command: GatewayCommand, gateway_tx: &Sender) { + debug!("Sending the following command to Gateway Manager: {command:?}"); + if let Err(err) = gateway_tx.send(command) { + error!("Error sending gateway command: {err}"); + } +} + +/// Sends multiple [`GatewayCommand`]s to the gateway manager service. +/// +/// In API handler context prefer `AppState::send_multiple_gateway_commands`. +pub fn send_multiple_gateway_commands( + commands: Vec, + gateway_tx: &Sender, +) { + debug!("Sending {} gateway commands", commands.len()); + for command in commands { + send_gateway_command(command, gateway_tx); + } +} diff --git a/crates/defguard_common/src/gateway_types.rs b/crates/defguard_common/src/gateway_types.rs new file mode 100644 index 000000000..25737def9 --- /dev/null +++ b/crates/defguard_common/src/gateway_types.rs @@ -0,0 +1,113 @@ +//! Native Rust types for data carried in [`crate::gateway_event::GatewayCommand`] variants. +//! +//! These are domain types; conversion to protobuf wire types happens at the +//! serialization boundary (gateway manager) via `From` impls in `defguard_proto`. + +/// A WireGuard peer entry to be configured on a gateway. +#[derive(Clone, Debug, PartialEq)] +pub struct WireguardPeer { + pub pubkey: String, + pub allowed_ips: Vec, + pub preshared_key: Option, + pub keepalive_interval: Option, +} + +/// Default firewall action applied to traffic that does not match any rule. +#[derive(Clone, Debug, Default, PartialEq)] +pub enum FirewallPolicy { + #[default] + Unspecified, + Allow, + Deny, +} + +/// IP protocol version a firewall rule applies to. +#[derive(Clone, Debug, Default, PartialEq)] +pub enum IpVersion { + #[default] + Unspecified, + Ipv4, + Ipv6, +} + +/// Network protocol matched by a firewall rule. +#[derive(Clone, Debug, PartialEq)] +pub enum Protocol { + Unspecified, + Icmp, + Tcp, + Udp, +} + +impl From for Protocol { + fn from(v: i32) -> Self { + match v { + 1 => Self::Icmp, + 6 => Self::Tcp, + 17 => Self::Udp, + _ => Self::Unspecified, + } + } +} + +/// An inclusive range of IP addresses. +#[derive(Clone, Debug, PartialEq)] +pub struct IpRange { + pub start: String, + pub end: String, +} + +/// An IP address, range, or subnet. +#[derive(Clone, Debug, PartialEq)] +pub enum IpAddress { + /// Single IP address string (e.g. `"10.0.0.1"`). + Ip(String), + /// Inclusive IP range. + IpRange(IpRange), + /// IP subnet in CIDR notation (e.g. `"10.0.0.0/24"`). + IpSubnet(String), +} + +/// An inclusive range of port numbers. +#[derive(Clone, Debug, PartialEq)] +pub struct PortRange { + pub start: u32, + pub end: u32, +} + +/// A single port or an inclusive port range matched by a firewall rule. +#[derive(Clone, Debug, PartialEq)] +pub enum Port { + Single(u32), + Range(PortRange), +} + +/// A single ACL-derived firewall rule to be enforced on a gateway. +#[derive(Clone, Debug, PartialEq)] +pub struct FirewallRule { + pub id: i64, + pub source_addrs: Vec, + pub destination_addrs: Vec, + pub destination_ports: Vec, + pub protocols: Vec, + pub verdict: FirewallPolicy, + pub comment: Option, + pub ip_version: IpVersion, +} + +/// Source NAT binding that rewrites the source IP of matching VPN traffic. +#[derive(Clone, Debug, PartialEq)] +pub struct SnatBinding { + pub id: i64, + pub source_addrs: Vec, + pub public_ip: String, + pub comment: Option, +} + +/// Full firewall configuration to be applied to a gateway location. +#[derive(Clone, Debug, Default, PartialEq)] +pub struct FirewallConfig { + pub default_policy: FirewallPolicy, + pub rules: Vec, + pub snat_bindings: Vec, +} diff --git a/crates/defguard_common/src/lib.rs b/crates/defguard_common/src/lib.rs index 394f7e333..9eff8e1d0 100644 --- a/crates/defguard_common/src/lib.rs +++ b/crates/defguard_common/src/lib.rs @@ -2,6 +2,8 @@ pub mod auth; pub mod config; pub mod csv; pub mod db; +pub mod gateway_event; +pub mod gateway_types; pub mod globals; pub mod hex; pub mod messages; diff --git a/crates/defguard_core/src/appstate.rs b/crates/defguard_core/src/appstate.rs index 499775c14..a69237d3c 100644 --- a/crates/defguard_core/src/appstate.rs +++ b/crates/defguard_core/src/appstate.rs @@ -19,7 +19,7 @@ use crate::{ db::{AppEvent, WebHook}, error::WebError, events::ApiEvent, - grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + grpc::{GatewayCommand, send_gateway_command, send_multiple_gateway_commands}, version::IncompatibleComponents, }; @@ -29,7 +29,7 @@ const X_DEFGUARD_EVENT: &str = "x-defguard-event"; pub struct AppState { pub pool: PgPool, tx: UnboundedSender, - pub wireguard_tx: Sender, + pub gateway_tx: Sender, pub web_reload_tx: tokio::sync::broadcast::Sender<()>, pub failed_logins: Arc>, key: Key, @@ -86,16 +86,16 @@ impl AppState { } } - /// Sends given `GatewayEvent` to be handled by gateway GRPC server. - /// Convenience wrapper around [`send_wireguard_event`] - pub fn send_wireguard_event(&self, event: GatewayEvent) { - send_wireguard_event(event, &self.wireguard_tx); + /// Sends given `GatewayCommand` to be handled by gateway manager service. + /// Convenience wrapper around [`send_gateway_command`] + pub fn send_gateway_command(&self, command: GatewayCommand) { + send_gateway_command(command, &self.gateway_tx); } - /// Sends multiple events to be handled by gateway GRPC server. - /// Convenience wrapper around [`send_multiple_wireguard_events`] - pub fn send_multiple_wireguard_events(&self, events: Vec) { - send_multiple_wireguard_events(events, &self.wireguard_tx); + /// Sends multiple commands to be handled by gateway manager service. + /// Convenience wrapper around [`send_multiple_gateway_commands`] + pub fn send_multiple_gateway_commands(&self, commands: Vec) { + send_multiple_gateway_commands(commands, &self.gateway_tx); } /// Sends event to the main event router @@ -118,7 +118,7 @@ impl AppState { pool: PgPool, tx: UnboundedSender, rx: UnboundedReceiver, - wireguard_tx: Sender, + gateway_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, key: Key, failed_logins: Arc>, @@ -132,7 +132,7 @@ impl AppState { Self { pool, tx, - wireguard_tx, + gateway_tx, web_reload_tx, failed_logins, key, diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index a13f0a5c4..699f48aaf 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -27,7 +27,7 @@ use crate::{ ApiAclRule, EditAclRule, alias::EditAclAlias, destination::EditAclDestination, }, }, - grpc::GatewayEvent, + grpc::GatewayCommand, }; #[derive(Debug, Error)] @@ -547,7 +547,7 @@ impl AclRule { match try_get_location_firewall_config(&location, &mut transaction).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -555,7 +555,7 @@ impl AclRule { None => { debug!( "No firewall config generated for location {location}. Not sending a \ - gateway event" + gateway command" ); } } @@ -1828,7 +1828,7 @@ impl AclAlias { match try_get_location_firewall_config(&location, &mut transaction).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -1836,7 +1836,7 @@ impl AclAlias { None => { debug!( "No firewall config generated for location {location}. Not sending a \ - gateway event" + gateway command" ); } } diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index f6da8105e..22d358b8e 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -31,7 +31,7 @@ use crate::{ license::get_cached_license, limits::{get_counts, update_counts}, }, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::user::check_username, user_management::{delete_user_and_cleanup_devices, disable_user, sync_allowed_user_devices}, }; @@ -324,7 +324,7 @@ async fn sync_user_groups( directory_sync: &T, user: &User, pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { info!("Syncing groups of user {} with the directory", user.email); let directory_groups = directory_sync.get_user_groups(&user.email).await?; @@ -371,7 +371,7 @@ async fn sync_user_groups( } } - sync_allowed_user_devices(user, &mut transaction, wg_tx) + sync_allowed_user_devices(user, &mut transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( @@ -418,7 +418,7 @@ pub(crate) async fn test_directory_sync_connection( pub async fn sync_user_groups_if_configured( user: &User, pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -435,7 +435,7 @@ pub async fn sync_user_groups_if_configured( match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { dir_sync.prepare().await?; - sync_user_groups(&dir_sync, user, pool, wg_tx).await?; + sync_user_groups(&dir_sync, user, pool, gateway_tx).await?; } Err(err) => { error!("Failed to build directory sync client: {err}"); @@ -482,7 +482,7 @@ async fn create_and_add_to_group( async fn sync_all_users_groups( directory_sync: &T, pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, all_users: Option<&[DirectoryUser]>, ) -> Result<(), DirectorySyncError> { info!("Syncing all users' groups with the directory, this may take a while..."); @@ -579,7 +579,7 @@ async fn sync_all_users_groups( create_and_add_to_group(&user, group, pool).await?; } - sync_allowed_user_devices(&user, &mut transaction, wg_tx).await.map_err(|err| { + sync_allowed_user_devices(&user, &mut transaction, gateway_tx).await.map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( "Failed to sync allowed devices for user {} during directory synchronization: {err}", user.email @@ -618,7 +618,7 @@ fn is_directory_sync_enabled(provider: Option<&OpenIdProvider>) -> bool { async fn sync_all_users_state( pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, all_users: &[DirectoryUser], ) -> Result<(), DirectorySyncError> { info!("Syncing all users' state with the directory, this may take a while..."); @@ -656,7 +656,7 @@ async fn sync_all_users_state( &mut transaction, &inactive_directory_users, &mut modified_users, - wg_tx, + gateway_tx, ) .await?; @@ -789,7 +789,7 @@ async fn sync_all_users_state( the admin behavior setting is set to disable", user.email ); - disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { + disable_user(&mut user, &mut transaction, gateway_tx).await.map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable admin {} during directory synchronization: {err}", user.email @@ -819,7 +819,7 @@ async fn sync_all_users_state( if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) + delete_user_and_cleanup_devices(user, &mut transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -844,7 +844,7 @@ async fn sync_all_users_state( the user behavior setting is set to disable", user.email ); - disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { + disable_user(&mut user, &mut transaction, gateway_tx).await.map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable user {} during directory synchronization: {err}", user.email @@ -866,7 +866,7 @@ async fn sync_all_users_state( if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) + delete_user_and_cleanup_devices(user, &mut transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -904,7 +904,7 @@ async fn sync_inactive_directory_users( transaction: &mut PgConnection, inactive_directory_users: &[&DirectoryUser], modified_users: &mut Vec>, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { // find all active Defguard users disabled in directory let disabled_users_emails = inactive_directory_users @@ -929,7 +929,7 @@ async fn sync_inactive_directory_users( "Disabling user {} because they are disabled in the directory", user.email ); - disable_user(&mut user, transaction, wg_tx) + disable_user(&mut user, transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -1005,7 +1005,7 @@ pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { // Performs the directory sync job. This function is called by the utility thread. pub(crate) async fn do_directory_sync( pool: &PgPool, - wireguard_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -1042,7 +1042,7 @@ pub(crate) async fn do_directory_sync( DirectorySyncTarget::All | DirectorySyncTarget::Users ) { let users = dir_sync.get_all_users().await?; - sync_all_users_state(pool, wireguard_tx, &users).await?; + sync_all_users_state(pool, gateway_tx, &users).await?; all_users = Some(users); } if matches!( @@ -1063,7 +1063,7 @@ pub(crate) async fn do_directory_sync( } _ => None, // No need to pass all users for other providers, for the time being. }; - sync_all_users_groups(&dir_sync, pool, wireguard_tx, users_to_pass.as_deref()) + sync_all_users_groups(&dir_sync, pool, gateway_tx, users_to_pass.as_deref()) .await?; } } diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 3eb6a4fff..0487838e7 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -146,7 +146,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -167,7 +167,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -176,7 +176,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } // Delete users, keep admins @@ -186,7 +186,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -208,7 +208,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -216,8 +216,8 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_none()); assert!(get_test_user(&pool, "testuser").await.is_some()); - let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + let event = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2.id); } else { panic!("Expected a DeviceDeleted event"); @@ -229,7 +229,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); User::init_admin_user(&pool, "pass123").await.unwrap(); let _ = make_test_provider( @@ -254,7 +254,7 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -270,8 +270,8 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); // Check that we received a device deleted event for whichever admin was removed - let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + let event = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert!(dev.device.user_id == user1.id || dev.device.user_id == user3.id); } else { panic!("Expected a DeviceDeleted event"); @@ -284,7 +284,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -308,7 +308,7 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -324,8 +324,8 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); // Check for device deletion events - let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + let event1 = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user2.id @@ -335,8 +335,8 @@ mod test { panic!("Expected a DeviceDeleted event"); } - let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + let event2 = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user2.id @@ -353,7 +353,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Disable, @@ -395,20 +395,20 @@ mod test { assert!(testuserdisabled.is_active); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); // Check for device disconnection events - let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + let event1 = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); } - let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + let event2 = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); @@ -436,7 +436,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // Added mut gateway_rx make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -468,13 +468,13 @@ mod test { assert!(testuserdisabled.is_active); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); // Check for device disconnection events - let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + let event1 = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user3.id @@ -484,8 +484,8 @@ mod test { panic!("Expected a DeviceDisconnected event"); } - let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + let event2 = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user3.id @@ -514,7 +514,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -530,7 +530,7 @@ mod test { make_test_user_and_device("testuser2", &pool).await; make_test_user_and_device("testuserdisabled", &pool).await; let all_users = client.get_all_users().await.unwrap(); - sync_all_users_groups(&client, &pool, &wg_tx, Some(&all_users)) + sync_all_users_groups(&client, &pool, &gateway_tx, Some(&all_users)) .await .unwrap(); @@ -571,7 +571,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -585,7 +585,7 @@ mod test { let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - sync_user_groups_if_configured(&user, &pool, &wg_tx) + sync_user_groups_if_configured(&user, &pool, &gateway_tx) .await .unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); @@ -600,7 +600,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -614,7 +614,7 @@ mod test { let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); } @@ -625,7 +625,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -648,24 +648,24 @@ mod test { let user2_pre_sync = make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; assert!(user2.is_none()); let mut transaction = pool.begin().await.unwrap(); - sync_allowed_user_devices(&user, &mut transaction, &wg_tx) + sync_allowed_user_devices(&user, &mut transaction, &gateway_tx) .await .unwrap(); transaction.commit().await.unwrap(); - let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + let event = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2_pre_sync.id); } else { panic!("Expected DeviceDeleted event"); } - let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceCreated(dev)) = event { + let event = gateway_rx.try_recv(); + if let Ok(GatewayCommand::DeviceCreated(dev)) = event { panic!("Unexpected DeviceCreated event: {dev:?}"); } } @@ -676,7 +676,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -691,7 +691,7 @@ mod test { make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; @@ -704,7 +704,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -724,7 +724,7 @@ mod test { assert_eq!(user_groups.len(), 1); assert!(user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // He should still be an admin as it's the last one assert!(user.is_admin(&pool).await.unwrap()); @@ -733,7 +733,7 @@ mod test { let user2 = make_test_user_and_device("testuser2", &pool).await; user2.add_to_group(&pool, &admin_grp).await.unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let admins = User::find_admins(&pool).await.unwrap(); // There should be only one admin left @@ -742,7 +742,7 @@ mod test { let defguard_user = make_test_user_and_device("defguard", &pool).await; make_admin(&pool, &defguard_user).await; - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); } #[sqlx::test] @@ -751,7 +751,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -768,7 +768,7 @@ mod test { make_admin(&pool, &defguard_user).await; assert!(defguard_user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // The user should still be an admin assert!(defguard_user.is_admin(&pool).await.unwrap()); @@ -780,7 +780,7 @@ mod test { .await .unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user = User::find_by_username(&pool, "defguard").await.unwrap(); assert!(user.is_none()); } @@ -791,7 +791,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // disable prefetching users make_test_provider( @@ -809,14 +809,14 @@ mod test { let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // no users in Defguard after sync let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } #[sqlx::test] @@ -825,7 +825,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( @@ -843,14 +843,14 @@ mod test { let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // all active directory users were synced let defguard_users = User::all(&pool).await.unwrap(); assert_eq!(defguard_users.len(), 3); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } #[sqlx::test] @@ -862,7 +862,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( @@ -892,7 +892,7 @@ mod test { set_cached_license(Some(license)); update_counts(&pool).await.unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); update_counts(&pool).await.unwrap(); let user_count = get_counts().user(); @@ -902,6 +902,6 @@ mod test { assert_eq!(defguard_users.len(), user_limit as usize); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } } diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index a49bc2f5d..aba2ea4b7 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -1,13 +1,14 @@ use std::{net::IpAddr, ops::RangeInclusive}; -use defguard_common::db::{ - Id, - models::{Device, ModelError, WireguardNetwork, user::User}, -}; -use defguard_proto::enterprise::firewall::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, - PortRange as PortRangeProto, SnatBinding as SnatBindingProto, ip_address::Address, - port::Port as PortInner, +use defguard_common::{ + db::{ + Id, + models::{Device, ModelError, WireguardNetwork, user::User}, + }, + gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, + PortRange as GwPortRange, Protocol as GwProtocol, SnatBinding, + }, }; use ipnetwork::IpNetwork; use sqlx::{PgConnection, query_as, query_scalar}; @@ -421,7 +422,7 @@ fn create_rules( protocols: &[Protocol], comment: &str, ) -> (Option, FirewallRule) { - let ip_version = i32::from(ip_version); + let gw_protocols: Vec = protocols.iter().map(|&p| GwProtocol::from(p)).collect(); let allow = if source_addrs.is_empty() { debug!("Source address list is empty. Skipping generating the ALLOW rule for this ACL"); None @@ -432,10 +433,10 @@ fn create_rules( source_addrs: source_addrs.to_vec(), destination_addrs: destination_addrs.to_vec(), destination_ports: destination_ports.to_vec(), - protocols: protocols.to_vec(), - verdict: i32::from(FirewallPolicy::Allow), + protocols: gw_protocols.clone(), + verdict: FirewallPolicy::Allow, comment: Some(format!("{comment} ALLOW")), - ip_version, + ip_version: ip_version.clone(), }; debug!("ALLOW rule generated from ACL: {rule:?}"); Some(rule) @@ -449,7 +450,7 @@ fn create_rules( destination_addrs: destination_addrs.to_vec(), destination_ports: Vec::new(), protocols: Vec::new(), - verdict: i32::from(FirewallPolicy::Deny), + verdict: FirewallPolicy::Deny, comment: Some(format!("{comment} DENY")), ip_version, }; @@ -625,12 +626,10 @@ fn extract_all_subnets_from_range(range_start: IpAddr, range_end: IpAddr) -> Vec // If decomposition produced nothing for a multi-IP range, the range straddles // a CIDR boundary and cannot be expressed as subnets - fall back to IpRange. if networks.is_empty() && range_start != range_end { - return vec![IpAddress { - address: Some(Address::IpRange(IpRange { - start: range_start.to_string(), - end: range_end.to_string(), - })), - }]; + return vec![IpAddress::IpRange(IpRange { + start: range_start.to_string(), + end: range_end.to_string(), + })]; } networks @@ -638,12 +637,11 @@ fn extract_all_subnets_from_range(range_start: IpAddr, range_end: IpAddr) -> Vec .map(|network| { let is_host = (network.is_ipv4() && network.prefix() == 32) || (network.is_ipv6() && network.prefix() == 128); - let address = if is_host { - Some(Address::Ip(network.ip().to_string())) + if is_host { + IpAddress::Ip(network.ip().to_string()) } else { - Some(Address::IpSubnet(network.to_string())) - }; - IpAddress { address } + IpAddress::IpSubnet(network.to_string()) + } }) .collect() } @@ -680,16 +678,12 @@ fn merge_port_ranges(port_ranges: Vec) -> Vec { let range_start = *range.start(); let range_end = *range.end(); if range_start == range_end { - Port { - port: Some(PortInner::SinglePort(u32::from(range_start))), - } + Port::Single(u32::from(range_start)) } else { - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: u32::from(range_start), - end: u32::from(range_end), - })), - } + Port::Range(GwPortRange { + start: u32::from(range_start), + end: u32::from(range_end), + }) } }) .collect() @@ -702,7 +696,7 @@ fn merge_port_ranges(port_ranges: Vec) -> Vec { async fn generate_user_snat_bindings_for_location( location_id: Id, conn: &mut PgConnection, -) -> sqlx::Result> { +) -> sqlx::Result> { debug!("Generating SNAT bindings for location {location_id}"); let user_snat_bindings = UserSnatBinding::all_for_location(&mut *conn, location_id).await?; @@ -754,7 +748,7 @@ async fn generate_user_snat_bindings_for_location( } // create the SNAT binding proto - let snat_binding = SnatBindingProto { + let snat_binding = SnatBinding { id: user_binding.id, source_addrs, public_ip: user_binding.public_ip.to_string(), @@ -849,7 +843,7 @@ pub async fn try_get_location_firewall_config( generate_firewall_rules_from_acls(location.id, location_acls, &mut *conn).await?; let snat_bindings = generate_user_snat_bindings_for_location(location.id, &mut *conn).await?; let firewall_config = FirewallConfig { - default_policy: default_policy.into(), + default_policy, rules: firewall_rules, snat_bindings, }; diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index ae4c50b63..a53c91ec8 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -1,20 +1,22 @@ +use defguard_common::{ + db::{NoId, models::WireguardNetwork, setup_pool}, + gateway_types::{ + FirewallPolicy, IpAddress, IpRange, Port, PortRange as GwPortRange, Protocol as GwProtocol, + }, +}; +use defguard_proto::enterprise::firewall::Protocol as ProtoProtocol; +use rand::thread_rng; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, ops::RangeInclusive, }; -use defguard_common::db::{NoId, models::WireguardNetwork, setup_pool}; -use defguard_proto::enterprise::firewall::{ - FirewallPolicy, IpAddress, IpRange, Port, Protocol, ip_address::Address, - port::Port as PortInner, -}; -use rand::thread_rng; -use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; - use super::{create_acl_rule, create_test_users_and_devices, set_test_license_business}; use crate::enterprise::{ db::models::acl::{ - AclAlias, AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AliasKind, RuleState, + AclAlias, AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AliasKind, PortRange, + RuleState, }, firewall::{process_destination_addrs, try_get_location_firewall_config}, }; @@ -50,21 +52,13 @@ fn test_process_destination_addrs_v4() { assert_eq!( destination_addrs.0, [ - IpAddress { - address: Some(Address::IpSubnet("10.0.1.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.3.255".to_owned(), - end: "10.0.4.0".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }, + IpAddress::IpSubnet("10.0.1.0/24".to_owned()), + IpAddress::IpSubnet("10.0.2.0/24".to_owned()), + IpAddress::IpRange(IpRange { + start: "10.0.3.255".to_owned(), + end: "10.0.4.0".to_owned(), + }), + IpAddress::IpSubnet("192.168.1.0/24".to_owned()), ] ); @@ -111,21 +105,11 @@ fn test_process_destination_addrs_v6() { assert_eq!( destination_addrs.1, [ - IpAddress { - address: Some(Address::IpSubnet("2001:db8:1::/64".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:2::/64".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:3::/64".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("2001:db8:4::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:4::2/127".to_owned())) - } + IpAddress::IpSubnet("2001:db8:1::/64".to_owned()), + IpAddress::IpSubnet("2001:db8:2::/64".to_owned()), + IpAddress::IpSubnet("2001:db8:3::/64".to_owned()), + IpAddress::Ip("2001:db8:4::1".to_owned()), + IpAddress::IpSubnet("2001:db8:4::2/127".to_owned()) ] ); @@ -195,29 +179,25 @@ async fn test_any_address_overwrites_manual_destination( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert!(allow_rule.destination_addrs.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert!(deny_rule.destination_addrs.is_empty()); } @@ -294,29 +274,25 @@ async fn test_any_address_overwrites_destination_alias_addrs( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert!(allow_rule.destination_addrs.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert!(deny_rule.destination_addrs.is_empty()); } @@ -390,35 +366,29 @@ async fn test_manual_destination_includes_component_alias_address_range( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; - let expected_destination_addrs = [IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.2.0.255".to_owned(), - end: "10.2.1.0".to_owned(), - })), - }]; + let expected_destination_addrs = [IpAddress::IpRange(IpRange { + start: "10.2.0.255".to_owned(), + end: "10.2.1.0".to_owned(), + })]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); } @@ -495,43 +465,35 @@ async fn test_manual_destination_merges_rule_and_component_alias_address_ranges( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.2.0.255".to_owned(), - end: "10.2.1.0".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.3.0.255".to_owned(), - end: "10.3.1.0".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.2.0.255".to_owned(), + end: "10.2.1.0".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.3.0.255".to_owned(), + end: "10.3.1.0".to_owned(), + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); } @@ -560,10 +522,10 @@ async fn test_any_port_preserves_destination_addresses_and_protocols( allow_all_users: true, addresses: vec!["192.168.50.0/24".parse().unwrap()], ports: vec![ - crate::enterprise::db::models::acl::PortRange::new(22, 22).into(), - crate::enterprise::db::models::acl::PortRange::new(443, 443).into(), + PortRange::new(22, 22).into(), + PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: true, any_protocol: false, @@ -597,51 +559,34 @@ async fn test_any_port_preserves_destination_addresses_and_protocols( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.50.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.60.10/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.60.12/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.60.16/30".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.60.20".to_owned())), - }, + IpAddress::IpSubnet("192.168.50.0/24".to_owned()), + IpAddress::IpSubnet("192.168.60.10/31".to_owned()), + IpAddress::IpSubnet("192.168.60.12/30".to_owned()), + IpAddress::IpSubnet("192.168.60.16/30".to_owned()), + IpAddress::Ip("192.168.60.20".to_owned()), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert!(allow_rule.destination_ports.is_empty()); - assert_eq!( - allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -678,7 +623,7 @@ async fn test_any_protocol_preserves_destination_addresses_and_ports( crate::enterprise::db::models::acl::PortRange::new(80, 80).into(), crate::enterprise::db::models::acl::PortRange::new(1000, 1005).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: false, any_protocol: true, @@ -709,52 +654,38 @@ async fn test_any_protocol_preserves_destination_addresses_and_ports( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.70.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.80.1".to_owned())), - }, + IpAddress::IpSubnet("192.168.70.0/24".to_owned()), + IpAddress::Ip("192.168.80.1".to_owned()), ]; let expected_ports = [ - Port { - port: Some(PortInner::SinglePort(80)), - }, - Port { - port: Some(PortInner::PortRange( - defguard_proto::enterprise::firewall::PortRange { - start: 1000, - end: 1005, - }, - )), - }, + Port::Single(80), + Port::Range(GwPortRange { + start: 1000, + end: 1005, + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert_eq!(allow_rule.destination_ports, expected_ports); assert!(allow_rule.protocols.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -787,7 +718,7 @@ async fn test_destination_alias_any_port_preserves_addresses_and_protocols( crate::enterprise::db::models::acl::PortRange::new(22, 22).into(), crate::enterprise::db::models::acl::PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: true, any_protocol: false, @@ -838,51 +769,34 @@ async fn test_destination_alias_any_port_preserves_addresses_and_protocols( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.90.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.91.10/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.91.12/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.91.16/30".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.91.20".to_owned())), - }, + IpAddress::IpSubnet("192.168.90.0/24".to_owned()), + IpAddress::IpSubnet("192.168.91.10/31".to_owned()), + IpAddress::IpSubnet("192.168.91.12/30".to_owned()), + IpAddress::IpSubnet("192.168.91.16/30".to_owned()), + IpAddress::Ip("192.168.91.20".to_owned()), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert!(allow_rule.destination_ports.is_empty()); - assert_eq!( - allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -918,7 +832,7 @@ async fn test_destination_alias_any_protocol_preserves_addresses_and_ports( crate::enterprise::db::models::acl::PortRange::new(80, 80).into(), crate::enterprise::db::models::acl::PortRange::new(1000, 1005).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: false, any_protocol: true, @@ -959,52 +873,38 @@ async fn test_destination_alias_any_protocol_preserves_addresses_and_ports( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.110.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.120.1".to_owned())), - }, + IpAddress::IpSubnet("192.168.110.0/24".to_owned()), + IpAddress::Ip("192.168.120.1".to_owned()), ]; let expected_ports = [ - Port { - port: Some(PortInner::SinglePort(80)), - }, - Port { - port: Some(PortInner::PortRange( - defguard_proto::enterprise::firewall::PortRange { - start: 1000, - end: 1005, - }, - )), - }, + Port::Single(80), + Port::Range(GwPortRange { + start: 1000, + end: 1005, + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert_eq!(allow_rule.destination_ports, expected_ports); assert!(allow_rule.protocols.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); diff --git a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs index 47ae9b0db..96b1b96e4 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs @@ -1,12 +1,14 @@ +use defguard_common::{ + db::{ + Id, NoId, + models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, + setup_pool, + }, + gateway_types::{FirewallPolicy, IpVersion}, +}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chrono::NaiveDateTime; -use defguard_common::db::{ - Id, NoId, - models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, - setup_pool, -}; -use defguard_proto::enterprise::firewall::{FirewallPolicy, IpVersion}; use ipnetwork::IpNetwork; use rand::{Rng, rngs::ThreadRng, thread_rng}; use sqlx::{ @@ -124,12 +126,12 @@ async fn test_gh1868_ipv6_rule_is_not_created_with_v4_only_destination( assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict(), FirewallPolicy::Allow); - assert_eq!(allow_rule.ip_version(), IpVersion::Ipv4); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv4); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict(), FirewallPolicy::Deny); - assert_eq!(allow_rule.ip_version(), IpVersion::Ipv4); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv4); } #[sqlx::test] @@ -184,12 +186,12 @@ async fn test_gh1868_ipv4_rule_is_not_created_with_v6_only_destination( assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!(allow_rule.ip_version, i32::from(IpVersion::Ipv6)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv6); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); - assert_eq!(allow_rule.ip_version, i32::from(IpVersion::Ipv6)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv6); } #[sqlx::test] @@ -244,16 +246,16 @@ async fn test_gh1868_ipv4_and_ipv6_rules_are_created_with_any_destination( assert_eq!(generated_firewall_rules.len(), 4); let allow_rule_ipv4 = &generated_firewall_rules[0]; - assert_eq!(allow_rule_ipv4.verdict(), FirewallPolicy::Allow); - assert_eq!(allow_rule_ipv4.ip_version(), IpVersion::Ipv4); + assert_eq!(allow_rule_ipv4.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); let allow_rule_ipv6 = &generated_firewall_rules[1]; - assert_eq!(allow_rule_ipv6.verdict(), FirewallPolicy::Allow); - assert_eq!(allow_rule_ipv6.ip_version(), IpVersion::Ipv6); + assert_eq!(allow_rule_ipv6.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); let deny_rule_ipv4 = &generated_firewall_rules[2]; - assert_eq!(deny_rule_ipv4.verdict(), FirewallPolicy::Deny); - assert_eq!(allow_rule_ipv4.ip_version(), IpVersion::Ipv4); + assert_eq!(deny_rule_ipv4.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); let deny_rule_ipv6 = &generated_firewall_rules[3]; - assert_eq!(deny_rule_ipv6.verdict(), FirewallPolicy::Deny); - assert_eq!(allow_rule_ipv6.ip_version(), IpVersion::Ipv6); + assert_eq!(deny_rule_ipv6.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs b/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs index e163d35a8..89a843d45 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs @@ -1,9 +1,6 @@ +use defguard_common::gateway_types::{IpAddress, IpRange, Port, PortRange as GwPortRange}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use defguard_proto::enterprise::firewall::{ - IpAddress, IpRange, Port, PortRange as PortRangeProto, ip_address::Address, - port::Port as PortInner, -}; use ipnetwork::Ipv6Network; use crate::enterprise::{ @@ -29,30 +26,14 @@ fn test_merge_v4_addrs() { assert_eq!( merged_addrs, [ - IpAddress { - address: Some(Address::Ip("10.0.8.127".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.8.128/25".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.9.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.10.0/27".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("10.0.20.20".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.60.20/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.60.24/31".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.0.20".to_owned())), - }, + IpAddress::Ip("10.0.8.127".to_owned()), + IpAddress::IpSubnet("10.0.8.128/25".to_owned()), + IpAddress::IpSubnet("10.0.9.0/24".to_owned()), + IpAddress::IpSubnet("10.0.10.0/27".to_owned()), + IpAddress::Ip("10.0.20.20".to_owned()), + IpAddress::IpSubnet("10.0.60.20/30".to_owned()), + IpAddress::IpSubnet("10.0.60.24/31".to_owned()), + IpAddress::Ip("192.168.0.20".to_owned()), ] ); @@ -69,12 +50,8 @@ fn test_merge_v4_addrs() { assert_eq!( merged_addrs, [ - IpAddress { - address: Some(Address::IpSubnet("10.0.10.0/30".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("10.0.10.20".to_owned())), - }, + IpAddress::IpSubnet("10.0.10.0/30".to_owned()), + IpAddress::Ip("10.0.10.20".to_owned()), ] ); } @@ -96,27 +73,13 @@ fn test_merge_v6_addrs() { assert_eq!( merged_addrs, [ - IpAddress { - address: Some(Address::Ip("2001:db8:1::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:1::2/127".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:1::4/126".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:1::8".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:2::1".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:3::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:3::2/127".to_owned())) - } + IpAddress::Ip("2001:db8:1::1".to_owned()), + IpAddress::IpSubnet("2001:db8:1::2/127".to_owned()), + IpAddress::IpSubnet("2001:db8:1::4/126".to_owned()), + IpAddress::Ip("2001:db8:1::8".to_owned()), + IpAddress::Ip("2001:db8:2::1".to_owned()), + IpAddress::Ip("2001:db8:3::1".to_owned()), + IpAddress::IpSubnet("2001:db8:3::2/127".to_owned()) ] ); } @@ -132,12 +95,8 @@ fn test_merge_addrs_extracts_ipv4_subnets() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.2.0/24".to_owned())) - }, + IpAddress::IpSubnet("192.168.1.0/24".to_owned()), + IpAddress::IpSubnet("192.168.2.0/24".to_owned()), ] ); } @@ -153,12 +112,8 @@ fn test_merge_addrs_extracts_ipv6_subnets() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("2001:db8::/32".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db9::/112".to_owned())) - }, + IpAddress::IpSubnet("2001:db8::/32".to_owned()), + IpAddress::IpSubnet("2001:db9::/112".to_owned()), ] ); } @@ -173,12 +128,10 @@ fn test_merge_addrs_falls_back_to_range_when_no_subnet_fits() { assert_eq!( result, - [IpAddress { - address: Some(Address::IpRange(IpRange { - start: "192.168.1.255".to_owned(), - end: "192.168.2.0".to_owned(), - })), - },] + [IpAddress::IpRange(IpRange { + start: "192.168.1.255".to_owned(), + end: "192.168.2.0".to_owned(), + }),] ); let start = "2001:db8:ffff:ffff:ffff:ffff:ffff:ffff" @@ -191,12 +144,10 @@ fn test_merge_addrs_falls_back_to_range_when_no_subnet_fits() { assert_eq!( result, - [IpAddress { - address: Some(Address::IpRange(IpRange { - start: "2001:db8:ffff:ffff:ffff:ffff:ffff:ffff".to_owned(), - end: "2001:db9::".to_owned(), - })), - },] + [IpAddress::IpRange(IpRange { + start: "2001:db8:ffff:ffff:ffff:ffff:ffff:ffff".to_owned(), + end: "2001:db9::".to_owned(), + }),] ); } @@ -208,12 +159,7 @@ fn test_merge_addrs_handles_single_ip() { let result = merge_addrs(ranges); - assert_eq!( - result, - [IpAddress { - address: Some(Address::Ip("192.168.1.1".to_owned())), - },] - ); + assert_eq!(result, [IpAddress::Ip("192.168.1.1".to_owned()),]); let start = "2001:db8::".parse::().unwrap(); let end = "2001:db8::".parse::().unwrap(); @@ -221,12 +167,7 @@ fn test_merge_addrs_handles_single_ip() { let result = merge_addrs(ranges); - assert_eq!( - result, - [IpAddress { - address: Some(Address::Ip("2001:db8::".to_owned())), - },] - ); + assert_eq!(result, [IpAddress::Ip("2001:db8::".to_owned()),]); } #[test] @@ -298,12 +239,8 @@ fn test_merge_addrs_subnet_at_start_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/26".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.1.64".to_owned())), - }, + IpAddress::IpSubnet("192.168.1.0/26".to_owned()), + IpAddress::Ip("192.168.1.64".to_owned()), ] ); @@ -317,12 +254,8 @@ fn test_merge_addrs_subnet_at_start_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("2001:db8::/122".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("2001:db8::40".to_owned())), - }, + IpAddress::IpSubnet("2001:db8::/122".to_owned()), + IpAddress::Ip("2001:db8::40".to_owned()), ] ); } @@ -338,12 +271,8 @@ fn test_merge_addrs_subnet_at_end_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::Ip("192.168.1.15".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.1.16/28".to_owned())), - }, + IpAddress::Ip("192.168.1.15".to_owned()), + IpAddress::IpSubnet("192.168.1.16/28".to_owned()), ] ); @@ -357,12 +286,8 @@ fn test_merge_addrs_subnet_at_end_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::Ip("2001:db8::f".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8::10/124".to_owned())), - }, + IpAddress::Ip("2001:db8::f".to_owned()), + IpAddress::IpSubnet("2001:db8::10/124".to_owned()), ] ); } @@ -372,12 +297,7 @@ fn test_merge_port_ranges() { // single port let input_ranges = vec![PortRange::new(100, 100)]; let merged = merge_port_ranges(input_ranges); - assert_eq!( - merged, - [Port { - port: Some(PortInner::SinglePort(100)) - }] - ); + assert_eq!(merged, [Port::Single(100)]); // overlapping ranges let input_ranges = vec![ @@ -388,12 +308,10 @@ fn test_merge_port_ranges() { let merged = merge_port_ranges(input_ranges); assert_eq!( merged, - [Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 300 - })) - }] + [Port::Range(GwPortRange { + start: 100, + end: 300 + })] ); // duplicate ranges @@ -412,18 +330,14 @@ fn test_merge_port_ranges() { assert_eq!( merged, [ - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 300 - })) - }, - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 350, - end: 400 - })) - } + Port::Range(GwPortRange { + start: 100, + end: 300 + }), + Port::Range(GwPortRange { + start: 350, + end: 400 + }) ] ); @@ -440,24 +354,16 @@ fn test_merge_port_ranges() { assert_eq!( merged, [ - Port { - port: Some(PortInner::SinglePort(50)) - }, - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 151, - end: 300 - })) - }, - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 501, - end: 699 - })) - }, - Port { - port: Some(PortInner::SinglePort(800)) - } + Port::Single(50), + Port::Range(GwPortRange { + start: 151, + end: 300 + }), + Port::Range(GwPortRange { + start: 501, + end: 699 + }), + Port::Single(800) ] ); @@ -466,12 +372,10 @@ fn test_merge_port_ranges() { let merged = merge_port_ranges(input_ranges); assert_eq!( merged, - [Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 200 - })) - }] + [Port::Range(GwPortRange { + start: 100, + end: 200 + })] ); } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index 616f4373c..e25397c8f 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -1,18 +1,21 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chrono::NaiveDateTime; -use defguard_common::db::{ - Id, NoId, - models::{ - Device, DeviceType, WireguardNetwork, device::WireguardNetworkDevice, group::Group, - user::User, +use defguard_common::{ + db::{ + Id, NoId, + models::{ + Device, DeviceType, WireguardNetwork, device::WireguardNetworkDevice, group::Group, + user::User, + }, + setup_pool, + }, + gateway_types::{ + FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as GwPortRange, + Protocol as GwProtocol, }, - setup_pool, -}; -use defguard_proto::enterprise::firewall::{ - FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as PortRangeProto, Protocol, - ip_address::Address, port::Port as PortInner, }; +use defguard_proto::enterprise::firewall::Protocol as ProtoProtocol; use ipnetwork::IpNetwork; use rand::{Rng, rngs::ThreadRng, thread_rng}; use sqlx::{ @@ -78,12 +81,10 @@ fn random_network_device_with_id(rng: &mut R, id: Id) -> Device { fn expected_ipv4_source_range_for_user(user_id: Id) -> IpAddress { let user_octet = user_id as u8; - IpAddress { - address: Some(Address::IpRange(IpRange { - start: format!("10.0.{user_octet}.1"), - end: format!("10.0.{user_octet}.2"), - })), - } + IpAddress::IpRange(IpRange { + start: format!("10.0.{user_octet}.1"), + end: format!("10.0.{user_octet}.2"), + }) } async fn create_test_user_with_devices( @@ -483,7 +484,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO PortRange::new(80, 80).into(), PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -529,7 +530,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO deny_all_network_devices: false, addresses: Vec::new(), // Will use destination ranges instead ports: vec![PortRange::new(53, 53).into()], - protocols: vec![Protocol::Udp.into(), Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Udp as i32, ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -585,7 +586,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO .unwrap(); assert_eq!( generated_firewall_config.default_policy, - i32::from(FirewallPolicy::Allow) + FirewallPolicy::Allow ); let generated_firewall_rules = generated_firewall_config.rules; @@ -594,142 +595,87 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO // First ACL - Web Access ALLOW let web_allow_rule = &generated_firewall_rules[0]; - assert_eq!(web_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!(web_allow_rule.protocols, vec![i32::from(Protocol::Tcp)]); + assert_eq!(web_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule.protocols, vec![GwProtocol::Tcp]); assert_eq!( web_allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); assert_eq!( web_allow_rule.destination_ports, - [ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + [Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("10.0.100.1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::Ip("10.0.100.1".to_owned()), ] ); // First ACL - Web Access DENY let web_deny_rule = &generated_firewall_rules[2]; - assert_eq!(web_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule.verdict, FirewallPolicy::Deny); assert!(web_deny_rule.protocols.is_empty()); assert!(web_deny_rule.destination_ports.is_empty()); assert!(web_deny_rule.source_addrs.is_empty()); assert_eq!( web_deny_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); // Second ACL - DNS Access ALLOW let dns_allow_rule = &generated_firewall_rules[1]; - assert_eq!(dns_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!( - dns_allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] - ); + assert_eq!(dns_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(dns_allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); + assert_eq!(dns_allow_rule.destination_ports, [Port::Single(53)]); // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.100.1".to_owned(), - end: "10.0.100.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.100.1".to_owned(), + end: "10.0.100.2".to_owned(), + }), ] ); let expected_destination_addrs = [ - IpAddress { - address: Some(Address::Ip("10.0.1.13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.14/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.16/28".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.40/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.52/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.56/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.64/26".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.128/25".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.0/27".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.40/30".to_owned())), - }, + IpAddress::Ip("10.0.1.13".to_owned()), + IpAddress::IpSubnet("10.0.1.14/31".to_owned()), + IpAddress::IpSubnet("10.0.1.16/28".to_owned()), + IpAddress::IpSubnet("10.0.1.32/29".to_owned()), + IpAddress::IpSubnet("10.0.1.40/30".to_owned()), + IpAddress::IpSubnet("10.0.1.52/30".to_owned()), + IpAddress::IpSubnet("10.0.1.56/29".to_owned()), + IpAddress::IpSubnet("10.0.1.64/26".to_owned()), + IpAddress::IpSubnet("10.0.1.128/25".to_owned()), + IpAddress::IpSubnet("10.0.2.0/27".to_owned()), + IpAddress::IpSubnet("10.0.2.32/29".to_owned()), + IpAddress::IpSubnet("10.0.2.40/30".to_owned()), ]; assert_eq!(dns_allow_rule.destination_addrs, expected_destination_addrs); // Second ACL - DNS Access DENY let dns_deny_rule = &generated_firewall_rules[3]; - assert_eq!(dns_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule.protocols.is_empty(),); assert!(dns_deny_rule.destination_ports.is_empty(),); assert!(dns_deny_rule.source_addrs.is_empty(),); @@ -904,7 +850,7 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO PortRange::new(80, 80).into(), PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -950,7 +896,7 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO deny_all_network_devices: false, addresses: Vec::new(), // Will use destination ranges instead ports: vec![PortRange::new(53, 53).into()], - protocols: vec![Protocol::Udp.into(), Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Udp as i32, ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -1006,7 +952,7 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO .unwrap(); assert_eq!( generated_firewall_config.default_policy, - i32::from(FirewallPolicy::Allow) + FirewallPolicy::Allow ); let generated_firewall_rules = generated_firewall_config.rules; @@ -1015,166 +961,95 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO // First ACL - Web Access ALLOW let web_allow_rule = &generated_firewall_rules[0]; - assert_eq!(web_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!(web_allow_rule.protocols, vec![i32::from(Protocol::Tcp)]); + assert_eq!(web_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule.protocols, vec![GwProtocol::Tcp]); assert_eq!( web_allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); assert_eq!( web_allow_rule.destination_ports, - [ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + [Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("ff00::100:1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::Ip("ff00::100:1".to_owned()), ] ); // First ACL - Web Access DENY let web_deny_rule = &generated_firewall_rules[2]; - assert_eq!(web_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule.verdict, FirewallPolicy::Deny); assert!(web_deny_rule.protocols.is_empty()); assert!(web_deny_rule.destination_ports.is_empty()); assert!(web_deny_rule.source_addrs.is_empty()); assert_eq!( web_deny_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); // Second ACL - DNS Access ALLOW let dns_allow_rule = &generated_firewall_rules[1]; - assert_eq!(dns_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!( - dns_allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] - ); + assert_eq!(dns_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(dns_allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); + assert_eq!(dns_allow_rule.destination_ports, [Port::Single(53)]); let expected_destination_addrs = vec![ - IpAddress { - address: Some(Address::Ip("fc00::1:13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:14/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:18/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:20/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:40/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:52/127".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:54/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:58/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:60/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:80/121".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:100/120".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:200/119".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:400/118".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:800/117".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:1000/116".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:2000/115".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:4000/114".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:8000/113".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:0/122".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:40/126".to_owned())), - }, + IpAddress::Ip("fc00::1:13".to_owned()), + IpAddress::IpSubnet("fc00::1:14/126".to_owned()), + IpAddress::IpSubnet("fc00::1:18/125".to_owned()), + IpAddress::IpSubnet("fc00::1:20/123".to_owned()), + IpAddress::IpSubnet("fc00::1:40/126".to_owned()), + IpAddress::IpSubnet("fc00::1:52/127".to_owned()), + IpAddress::IpSubnet("fc00::1:54/126".to_owned()), + IpAddress::IpSubnet("fc00::1:58/125".to_owned()), + IpAddress::IpSubnet("fc00::1:60/123".to_owned()), + IpAddress::IpSubnet("fc00::1:80/121".to_owned()), + IpAddress::IpSubnet("fc00::1:100/120".to_owned()), + IpAddress::IpSubnet("fc00::1:200/119".to_owned()), + IpAddress::IpSubnet("fc00::1:400/118".to_owned()), + IpAddress::IpSubnet("fc00::1:800/117".to_owned()), + IpAddress::IpSubnet("fc00::1:1000/116".to_owned()), + IpAddress::IpSubnet("fc00::1:2000/115".to_owned()), + IpAddress::IpSubnet("fc00::1:4000/114".to_owned()), + IpAddress::IpSubnet("fc00::1:8000/113".to_owned()), + IpAddress::IpSubnet("fc00::2:0/122".to_owned()), + IpAddress::IpSubnet("fc00::2:40/126".to_owned()), ]; // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::100:1".to_owned(), - end: "ff00::100:2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::100:1".to_owned(), + end: "ff00::100:2".to_owned(), + }), ] ); assert_eq!(dns_allow_rule.destination_addrs, expected_destination_addrs); // Second ACL - DNS Access DENY let dns_deny_rule = &generated_firewall_rules[3]; - assert_eq!(dns_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule.protocols.is_empty(),); assert!(dns_deny_rule.destination_ports.is_empty(),); assert!(dns_deny_rule.source_addrs.is_empty(),); @@ -1368,7 +1243,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P PortRange::new(80, 80).into(), PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -1414,7 +1289,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P deny_all_network_devices: false, addresses: Vec::new(), // Will use destination ranges instead ports: vec![PortRange::new(53, 53).into()], - protocols: vec![Protocol::Udp.into(), Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Udp as i32, ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -1472,7 +1347,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P .unwrap(); assert_eq!( generated_firewall_config.default_policy, - i32::from(FirewallPolicy::Allow) + FirewallPolicy::Allow ); let generated_firewall_rules = generated_firewall_config.rules; @@ -1481,201 +1356,120 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P // First ACL - Web Access ALLOW let web_allow_rule_ipv4 = &generated_firewall_rules[0]; - assert_eq!( - web_allow_rule_ipv4.verdict, - i32::from(FirewallPolicy::Allow) - ); - assert_eq!( - web_allow_rule_ipv4.protocols, - vec![i32::from(Protocol::Tcp)] - ); + assert_eq!(web_allow_rule_ipv4.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule_ipv4.protocols, vec![GwProtocol::Tcp]); assert_eq!( web_allow_rule_ipv4.destination_addrs, - vec![IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + vec![IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); assert_eq!( web_allow_rule_ipv4.destination_ports, - vec![ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + vec![Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule_ipv4.source_addrs, vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("10.0.100.1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::Ip("10.0.100.1".to_owned()), ] ); let web_allow_rule_ipv6 = &generated_firewall_rules[1]; - assert_eq!( - web_allow_rule_ipv6.verdict, - i32::from(FirewallPolicy::Allow) - ); - assert_eq!(web_allow_rule_ipv6.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(web_allow_rule_ipv6.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule_ipv6.protocols, [GwProtocol::Tcp]); assert_eq!( web_allow_rule_ipv6.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); assert_eq!( web_allow_rule_ipv6.destination_ports, - [ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + [Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule_ipv6.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("ff00::100:1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::Ip("ff00::100:1".to_owned()), ] ); // First ACL - Web Access DENY let web_deny_rule_ipv4 = &generated_firewall_rules[4]; - assert_eq!(web_deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(web_deny_rule_ipv4.protocols.is_empty()); assert!(web_deny_rule_ipv4.destination_ports.is_empty()); assert!(web_deny_rule_ipv4.source_addrs.is_empty()); assert_eq!( web_deny_rule_ipv4.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); let web_deny_rule_ipv6 = &generated_firewall_rules[5]; - assert_eq!(web_deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(web_deny_rule_ipv6.protocols.is_empty()); assert!(web_deny_rule_ipv6.destination_ports.is_empty()); assert!(web_deny_rule_ipv6.source_addrs.is_empty()); assert_eq!( web_deny_rule_ipv6.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); // Second ACL - DNS Access ALLOW let dns_allow_rule_ipv4 = &generated_firewall_rules[2]; - assert_eq!( - dns_allow_rule_ipv4.verdict, - i32::from(FirewallPolicy::Allow) - ); + assert_eq!(dns_allow_rule_ipv4.verdict, FirewallPolicy::Allow); assert_eq!( dns_allow_rule_ipv4.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule_ipv4.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] + [GwProtocol::Tcp, GwProtocol::Udp] ); + assert_eq!(dns_allow_rule_ipv4.destination_ports, [Port::Single(53)]); // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule_ipv4.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.100.1".to_owned(), - end: "10.0.100.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.100.1".to_owned(), + end: "10.0.100.2".to_owned(), + }), ] ); let expected_destination_addrs_v4 = vec![ - IpAddress { - address: Some(Address::Ip("10.0.1.13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.14/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.16/28".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.40/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.52/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.56/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.64/26".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.128/25".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.0/27".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.40/30".to_owned())), - }, + IpAddress::Ip("10.0.1.13".to_owned()), + IpAddress::IpSubnet("10.0.1.14/31".to_owned()), + IpAddress::IpSubnet("10.0.1.16/28".to_owned()), + IpAddress::IpSubnet("10.0.1.32/29".to_owned()), + IpAddress::IpSubnet("10.0.1.40/30".to_owned()), + IpAddress::IpSubnet("10.0.1.52/30".to_owned()), + IpAddress::IpSubnet("10.0.1.56/29".to_owned()), + IpAddress::IpSubnet("10.0.1.64/26".to_owned()), + IpAddress::IpSubnet("10.0.1.128/25".to_owned()), + IpAddress::IpSubnet("10.0.2.0/27".to_owned()), + IpAddress::IpSubnet("10.0.2.32/29".to_owned()), + IpAddress::IpSubnet("10.0.2.40/30".to_owned()), ]; assert_eq!( @@ -1684,106 +1478,52 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P ); let dns_allow_rule_ipv6 = &generated_firewall_rules[3]; - assert_eq!( - dns_allow_rule_ipv6.verdict, - i32::from(FirewallPolicy::Allow) - ); + assert_eq!(dns_allow_rule_ipv6.verdict, FirewallPolicy::Allow); assert_eq!( dns_allow_rule_ipv6.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule_ipv6.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] + [GwProtocol::Tcp, GwProtocol::Udp] ); + assert_eq!(dns_allow_rule_ipv6.destination_ports, [Port::Single(53)]); // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule_ipv6.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::100:1".to_owned(), - end: "ff00::100:2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::100:1".to_owned(), + end: "ff00::100:2".to_owned(), + }), ] ); let expected_destination_addrs_v6 = vec![ - IpAddress { - address: Some(Address::Ip("fc00::1:13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:14/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:18/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:20/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:40/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:52/127".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:54/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:58/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:60/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:80/121".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:100/120".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:200/119".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:400/118".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:800/117".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:1000/116".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:2000/115".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:4000/114".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:8000/113".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:0/122".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:40/126".to_owned())), - }, + IpAddress::Ip("fc00::1:13".to_owned()), + IpAddress::IpSubnet("fc00::1:14/126".to_owned()), + IpAddress::IpSubnet("fc00::1:18/125".to_owned()), + IpAddress::IpSubnet("fc00::1:20/123".to_owned()), + IpAddress::IpSubnet("fc00::1:40/126".to_owned()), + IpAddress::IpSubnet("fc00::1:52/127".to_owned()), + IpAddress::IpSubnet("fc00::1:54/126".to_owned()), + IpAddress::IpSubnet("fc00::1:58/125".to_owned()), + IpAddress::IpSubnet("fc00::1:60/123".to_owned()), + IpAddress::IpSubnet("fc00::1:80/121".to_owned()), + IpAddress::IpSubnet("fc00::1:100/120".to_owned()), + IpAddress::IpSubnet("fc00::1:200/119".to_owned()), + IpAddress::IpSubnet("fc00::1:400/118".to_owned()), + IpAddress::IpSubnet("fc00::1:800/117".to_owned()), + IpAddress::IpSubnet("fc00::1:1000/116".to_owned()), + IpAddress::IpSubnet("fc00::1:2000/115".to_owned()), + IpAddress::IpSubnet("fc00::1:4000/114".to_owned()), + IpAddress::IpSubnet("fc00::1:8000/113".to_owned()), + IpAddress::IpSubnet("fc00::2:0/122".to_owned()), + IpAddress::IpSubnet("fc00::2:40/126".to_owned()), ]; assert_eq!( @@ -1793,7 +1533,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P // Second ACL - DNS Access DENY let dns_deny_rule_ipv4 = &generated_firewall_rules[6]; - assert_eq!(dns_deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule_ipv4.protocols.is_empty(),); assert!(dns_deny_rule_ipv4.destination_ports.is_empty(),); assert!(dns_deny_rule_ipv4.source_addrs.is_empty(),); @@ -1803,7 +1543,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P ); let dns_deny_rule_ipv6 = &generated_firewall_rules[7]; - assert_eq!(dns_deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule_ipv6.protocols.is_empty(),); assert!(dns_deny_rule_ipv6.destination_ports.is_empty(),); assert!(dns_deny_rule_ipv6.source_addrs.is_empty(),); @@ -1889,30 +1629,22 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { // check generated rules assert_eq!(generated_firewall_rules.len(), 4); let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::Ip("10.0.2.3".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }, + IpAddress::Ip("10.0.2.3".to_owned()), + IpAddress::IpSubnet("192.168.1.0/24".to_owned()), ]; let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert!(allow_rule.destination_ports.is_empty()); @@ -1923,17 +1655,15 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { ); let alias_allow_rule = &generated_firewall_rules[1]; - assert_eq!(alias_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(alias_allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(alias_allow_rule.source_addrs, expected_source_addrs); assert!(alias_allow_rule.destination_addrs.is_empty()); assert_eq!( alias_allow_rule.destination_ports, - vec![Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 200, - })) - }] + vec![Port::Range(GwPortRange { + start: 100, + end: 200, + })] ); assert!(alias_allow_rule.protocols.is_empty()); assert_eq!( @@ -1942,7 +1672,7 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { ); let deny_rule = &generated_firewall_rules[2]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -1953,7 +1683,7 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { ); let alias_deny_rule = &generated_firewall_rules[3]; - assert_eq!(alias_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(alias_deny_rule.verdict, FirewallPolicy::Deny); assert!(alias_deny_rule.source_addrs.is_empty()); assert!(alias_deny_rule.destination_addrs.is_empty()); assert!(alias_deny_rule.destination_ports.is_empty()); @@ -2040,34 +1770,26 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt // check generated rules assert_eq!(generated_firewall_rules.len(), 4); let expected_source_addrs = vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let alias_allow_rule_1 = &generated_firewall_rules[0]; - assert_eq!(alias_allow_rule_1.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(alias_allow_rule_1.verdict, FirewallPolicy::Allow); assert_eq!(alias_allow_rule_1.source_addrs, expected_source_addrs); assert_eq!( alias_allow_rule_1.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.3".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.3".to_owned()),] ); assert_eq!( alias_allow_rule_1.destination_ports, - vec![Port { - port: Some(PortInner::SinglePort(5432)) - }] + vec![Port::Single(5432)] ); assert!(alias_allow_rule_1.protocols.is_empty()); assert_eq!( @@ -2076,19 +1798,15 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt ); let alias_allow_rule_2 = &generated_firewall_rules[1]; - assert_eq!(alias_allow_rule_2.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(alias_allow_rule_2.verdict, FirewallPolicy::Allow); assert_eq!(alias_allow_rule_2.source_addrs, expected_source_addrs); assert_eq!( alias_allow_rule_2.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.4".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.4".to_owned()),] ); assert_eq!( alias_allow_rule_2.destination_ports, - vec![Port { - port: Some(PortInner::SinglePort(6379)) - }] + vec![Port::Single(6379)] ); assert!(alias_allow_rule_2.protocols.is_empty()); assert_eq!( @@ -2097,13 +1815,11 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt ); let alias_deny_rule_1 = &generated_firewall_rules[2]; - assert_eq!(alias_deny_rule_1.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(alias_deny_rule_1.verdict, FirewallPolicy::Deny); assert!(alias_deny_rule_1.source_addrs.is_empty()); assert_eq!( alias_deny_rule_1.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.3".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.3".to_owned()),] ); assert!(alias_deny_rule_1.destination_ports.is_empty()); assert!(alias_deny_rule_1.protocols.is_empty()); @@ -2113,13 +1829,11 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt ); let alias_deny_rule_2 = &generated_firewall_rules[3]; - assert_eq!(alias_deny_rule_2.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(alias_deny_rule_2.verdict, FirewallPolicy::Deny); assert!(alias_deny_rule_2.source_addrs.is_empty()); assert_eq!( alias_deny_rule_2.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.4".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.4".to_owned()),] ); assert!(alias_deny_rule_2.destination_ports.is_empty()); assert!(alias_deny_rule_2.protocols.is_empty()); @@ -2183,7 +1897,7 @@ async fn test_no_allowed_users_ipv4(_: PgPoolOptions, options: PgConnectOptions) // only deny rules are generated assert_eq!(generated_firewall_rules.len(), 2); for rule in generated_firewall_rules { - assert_eq!(rule.verdict(), FirewallPolicy::Deny); + assert_eq!(rule.verdict, FirewallPolicy::Deny); } } @@ -2238,7 +1952,7 @@ async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( allow_all_groups: true, addresses: vec!["192.168.10.0/24".parse().unwrap()], ports: vec![PortRange::new(443, 443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2282,21 +1996,14 @@ async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( .collect(); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.10.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.10.0/24".to_owned())] ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(443)), - }] - ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); assert!( allow_rule .source_addrs @@ -2311,7 +2018,7 @@ async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( ); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2374,7 +2081,7 @@ async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_r allow_all_groups: true, addresses: vec!["192.168.30.0/24".parse().unwrap()], ports: vec![PortRange::new(443, 443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2436,21 +2143,14 @@ async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_r .collect(); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.30.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.30.0/24".to_owned())] ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(443)), - }] - ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); assert!( allow_rule .source_addrs @@ -2459,7 +2159,7 @@ async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_r ); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2518,7 +2218,7 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour deny_all_groups: true, addresses: vec!["192.168.20.0/24".parse().unwrap()], ports: vec![PortRange::new(8443, 8443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2551,7 +2251,7 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!( allow_rule.source_addrs, [expected_ipv4_source_range_for_user( @@ -2560,17 +2260,10 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour ); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.20.0/24".to_owned())), - }] - ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(8443)), - }] + [IpAddress::IpSubnet("192.168.20.0/24".to_owned())] ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(8443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); for denied_user in [ grouped_denied_user_a.id, grouped_denied_user_b.id, @@ -2586,7 +2279,7 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour } let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2653,7 +2346,7 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi deny_all_groups: true, addresses: vec!["192.168.40.0/24".parse().unwrap()], ports: vec![PortRange::new(8443, 8443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2714,7 +2407,7 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!( allow_rule.source_addrs, [expected_ipv4_source_range_for_user( @@ -2723,17 +2416,10 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi ); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.40.0/24".to_owned())), - }] - ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(8443)), - }] + [IpAddress::IpSubnet("192.168.40.0/24".to_owned())] ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(8443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); for denied_user in expected_denied_user_ids { assert!( allow_rule @@ -2744,7 +2430,7 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi } let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2880,28 +2566,24 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon assert_eq!(generated_firewall_rules_ipv4.len(), 2); let expected_source_addrs_ipv4 = vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let allow_rule_ipv4 = &generated_firewall_rules_ipv4[0]; - assert_eq!(allow_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(allow_rule_ipv4.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(allow_rule_ipv4.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv4.source_addrs, expected_source_addrs_ipv4); assert!(allow_rule_ipv4.destination_addrs.is_empty()); let deny_rule_ipv4 = &generated_firewall_rules_ipv4[1]; - assert_eq!(deny_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv4.source_addrs.is_empty()); assert!(deny_rule_ipv4.destination_addrs.is_empty()); @@ -2914,28 +2596,24 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon assert_eq!(generated_firewall_rules_ipv6.len(), 2); let expected_source_addrs_ipv6 = vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), ]; let allow_rule_ipv6 = &generated_firewall_rules_ipv6[0]; - assert_eq!(allow_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(allow_rule_ipv6.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(allow_rule_ipv6.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv6.source_addrs, expected_source_addrs_ipv6); assert!(allow_rule_ipv6.destination_addrs.is_empty()); let deny_rule_ipv6 = &generated_firewall_rules_ipv6[1]; - assert_eq!(deny_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv6.source_addrs.is_empty()); assert!(deny_rule_ipv6.destination_addrs.is_empty()); @@ -2949,26 +2627,26 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon assert_eq!(generated_firewall_rules_ipv4_and_ipv6.len(), 4); let allow_rule_ipv4 = &generated_firewall_rules_ipv4_and_ipv6[0]; - assert_eq!(allow_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(allow_rule_ipv4.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(allow_rule_ipv4.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv4.source_addrs, expected_source_addrs_ipv4); assert!(allow_rule_ipv4.destination_addrs.is_empty()); let allow_rule_ipv6 = &generated_firewall_rules_ipv4_and_ipv6[1]; - assert_eq!(allow_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(allow_rule_ipv6.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(allow_rule_ipv6.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv6.source_addrs, expected_source_addrs_ipv6); assert!(allow_rule_ipv6.destination_addrs.is_empty()); let deny_rule_ipv4 = &generated_firewall_rules_ipv4_and_ipv6[2]; - assert_eq!(deny_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv4.source_addrs.is_empty()); assert!(deny_rule_ipv4.destination_addrs.is_empty()); let deny_rule_ipv6 = &generated_firewall_rules_ipv4_and_ipv6[3]; - assert_eq!(deny_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv6.source_addrs.is_empty()); assert!(deny_rule_ipv6.destination_addrs.is_empty()); } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/source.rs b/crates/defguard_core/src/enterprise/firewall/tests/source.rs index 98d463820..4db9d82d9 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/source.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/source.rs @@ -1,7 +1,6 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - -use defguard_proto::enterprise::firewall::{IpAddress, IpVersion, ip_address::Address}; +use defguard_common::gateway_types::{IpAddress, IpVersion}; use rand::thread_rng; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use crate::enterprise::firewall::{ get_source_addrs, get_source_network_devices, get_source_users, @@ -71,21 +70,11 @@ fn test_process_source_addrs_v4() { assert_eq!( source_addrs, [ - IpAddress { - address: Some(Address::Ip("10.0.1.1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.2/31".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.4/31".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("172.16.1.1".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.1.100".to_owned())), - }, + IpAddress::Ip("10.0.1.1".to_owned()), + IpAddress::IpSubnet("10.0.1.2/31".to_owned()), + IpAddress::IpSubnet("10.0.1.4/31".to_owned()), + IpAddress::Ip("172.16.1.1".to_owned()), + IpAddress::Ip("192.168.1.100".to_owned()), ] ); @@ -126,21 +115,11 @@ fn test_process_source_addrs_v6() { assert_eq!( source_addrs, [ - IpAddress { - address: Some(Address::Ip("2001:db8::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8::2/127".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8::4/127".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:0:1::1".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("2001:db8:0:2::1".to_owned())), - }, + IpAddress::Ip("2001:db8::1".to_owned()), + IpAddress::IpSubnet("2001:db8::2/127".to_owned()), + IpAddress::IpSubnet("2001:db8::4/127".to_owned()), + IpAddress::Ip("2001:db8:0:1::1".to_owned()), + IpAddress::Ip("2001:db8:0:2::1".to_owned()), ] ); diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index 47df9efd8..1cdc92d16 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -636,7 +636,7 @@ pub async fn auth_callback( // sync the groups for the MFA enabled user logging in through the provider without firing it on // every login attempt, even for standard, non-provider users. if let Err(err) = - sync_user_groups_if_configured(&user, &appstate.pool, &appstate.wireguard_tx).await + sync_user_groups_if_configured(&user, &appstate.pool, &appstate.gateway_tx).await { error!( "Failed to sync user groups for user {} with the directory while the user was trying \ diff --git a/crates/defguard_core/src/enterprise/snat/handlers.rs b/crates/defguard_core/src/enterprise/snat/handlers.rs index a591ac641..23082fc17 100644 --- a/crates/defguard_core/src/enterprise/snat/handlers.rs +++ b/crates/defguard_core/src/enterprise/snat/handlers.rs @@ -21,7 +21,7 @@ use crate::{ }, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{ApiResponse, ApiResult}, }; @@ -159,7 +159,7 @@ pub async fn create_snat_binding( debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -259,7 +259,7 @@ pub async fn modify_snat_binding( debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -344,7 +344,7 @@ pub async fn delete_snat_binding( debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index f47b2a836..25b5556f8 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,18 +10,14 @@ use defguard_common::{ config::server_config, db::{ Id, - models::{ - Device, Settings, WireguardNetwork, - device::{DeviceInfo, DeviceNetworkInfo}, - wireguard::ServiceLocationMode, - }, + models::{Settings, WireguardNetwork, wireguard::ServiceLocationMode}, }, types::UrlParseError, }; use reqwest::Url; use serde::Serialize; use sqlx::PgPool; -use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; +use tokio::sync::mpsc::UnboundedSender; use crate::{ db::AppEvent, @@ -49,10 +45,7 @@ pub mod proto { } } -use defguard_proto::{ - enterprise::firewall::FirewallConfig, gateway::Peer, - worker::worker_service_server::WorkerServiceServer, -}; +use defguard_proto::worker::worker_service_server::WorkerServiceServer; use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; // gRPC header for passing auth token from clients @@ -214,40 +207,9 @@ impl From for defguard_proto::client_types::InstanceInfo { } } -// TODO: move this to common crate -#[derive(Clone, Debug)] -pub enum GatewayEvent { - NetworkCreated(Id, WireguardNetwork), - NetworkModified(Id, WireguardNetwork, Vec, Option), - NetworkDeleted(Id, String), - DeviceCreated(DeviceInfo), - DeviceModified(DeviceInfo), - DeviceDeleted(DeviceInfo), - FirewallConfigChanged(Id, FirewallConfig), - FirewallDisabled(Id), - VpnSessionAuthorized(Id, Device, DeviceNetworkInfo), - VpnSessionDeauthorized(Id, Device), -} - -/// Sends given `GatewayEvent` to be handled by gateway GRPC server -/// -/// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead -pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { - debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); - if let Err(err) = wg_tx.send(event) { - error!("Error sending WireGuard event {err}"); - } -} - -/// Sends multiple events to be handled by gateway gRPC server. -/// -/// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead -pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { - debug!("Sending {} WireGuard events", events.len()); - for event in events { - send_wireguard_event(event, wg_tx); - } -} +pub use defguard_common::gateway_event::{ + GatewayCommand, send_gateway_command, send_multiple_gateway_commands, +}; /// If this location is marked as a service location, checks if all requirements are met for it to /// function: diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 917f41809..2869fdb7d 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -50,7 +50,7 @@ use crate::{ posture::{PostureCheckError, PostureResult, validate_posture}, }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, - grpc::{GatewayEvent, utils::parse_client_ip_agent}, + grpc::{GatewayCommand, utils::parse_client_ip_agent}, }; const CLIENT_SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes @@ -82,7 +82,7 @@ pub struct ClientLoginSession { pub struct ClientMfaServer { pub(crate) pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, pub(crate) sessions: Arc>>, remote_mfa_responses: Arc>>>, bidi_event_tx: UnboundedSender, @@ -103,14 +103,14 @@ impl ClientMfaServer { #[must_use] pub fn new( pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, bidi_event_tx: UnboundedSender, remote_mfa_responses: Arc>>>, sessions: Arc>>, ) -> Self { Self { pool, - wireguard_tx, + gateway_tx, sessions, remote_mfa_responses, bidi_event_tx, @@ -745,8 +745,8 @@ impl ClientMfaServer { // send gateway event debug!("Sending `peer_create` message to gateway"); let event = - GatewayEvent::VpnSessionAuthorized(location.id, device.clone(), gateway_network_info); - self.wireguard_tx.send(event).map_err(|err| { + GatewayCommand::VpnSessionAuthorized(location.id, device.clone(), gateway_network_info); + self.gateway_tx.send(event).map_err(|err| { error!("Error sending WireGuard event: {err}"); Status::internal("unexpected error") })?; @@ -932,8 +932,8 @@ impl ClientMfaServer { })?; let event = - GatewayEvent::VpnSessionAuthorized(location.id, device.clone(), gateway_network_info); - self.wireguard_tx.send(event).map_err(|err| { + GatewayCommand::VpnSessionAuthorized(location.id, device.clone(), gateway_network_info); + self.gateway_tx.send(event).map_err(|err| { error!("Error sending WireGuard event: {err}"); Status::internal("unexpected error") })?; @@ -1028,8 +1028,8 @@ impl ClientMfaServer { // gateway update is only needed to remove peers that were authorized at runtime - MFA and posture-check sessions // this is needed to remove peers for both Connected and New sessions if requires_gateway_update { - let gateway_event = GatewayEvent::VpnSessionDeauthorized(location.id, device.clone()); - self.wireguard_tx.send(gateway_event).map_err(|err| { + let gateway_event = GatewayCommand::VpnSessionDeauthorized(location.id, device.clone()); + self.gateway_tx.send(gateway_event).map_err(|err| { error!("Error sending WireGuard event: {err}"); Status::internal("unexpected error") })?; @@ -1118,7 +1118,7 @@ mod tests { limits::{Counts, set_counts}, }, events::{BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, - grpc::{GatewayEvent, proto::enterprise::license::LicenseLimits}, + grpc::{GatewayCommand, proto::enterprise::license::LicenseLimits}, }; const REPLACEMENT_MFA_PRESHARED_KEY: &str = "replacement-mfa-psk"; @@ -1157,7 +1157,7 @@ mod tests { .try_recv() .expect("expected VPN authorization gateway event") { - GatewayEvent::VpnSessionAuthorized(location_id, authorized_device, network_info) => { + GatewayCommand::VpnSessionAuthorized(location_id, authorized_device, network_info) => { assert_eq!(location_id, location.id); assert_eq!(authorized_device.id, device.id); assert_eq!(network_info.network_id, location.id); @@ -1224,7 +1224,7 @@ mod tests { .try_recv() .expect("expected VPN deauthorization gateway event for replaced posture session") { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -1234,7 +1234,7 @@ mod tests { .try_recv() .expect("expected VPN authorization gateway event for replacement posture session") { - GatewayEvent::VpnSessionAuthorized(location_id, authorized_device, network_info) => { + GatewayCommand::VpnSessionAuthorized(location_id, authorized_device, network_info) => { assert_eq!(location_id, location.id); assert_eq!(authorized_device.id, device.id); assert!(network_info.preshared_key.is_some()); @@ -1340,7 +1340,7 @@ mod tests { .try_recv() .expect("expected MFA gateway disconnect event for replaced connected session"); match gateway_event { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -1415,7 +1415,7 @@ mod tests { .try_recv() .expect("expected MFA gateway disconnect event for replaced new session"); match gateway_event { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -1508,9 +1508,9 @@ mod tests { ) -> ( ClientMfaServer, tokio::sync::mpsc::UnboundedReceiver, - tokio::sync::broadcast::Receiver, + tokio::sync::broadcast::Receiver, ) { - let (wireguard_tx, wireguard_rx) = broadcast::channel(8); + let (gateway_tx, gateway_rx) = broadcast::channel(8); let (bidi_event_tx, bidi_event_rx) = mpsc::unbounded_channel(); let remote_mfa_responses: Arc>>> = Arc::default(); @@ -1519,13 +1519,13 @@ mod tests { ( ClientMfaServer::new( pool, - wireguard_tx, + gateway_tx, bidi_event_tx, remote_mfa_responses, sessions, ), bidi_event_rx, - wireguard_rx, + gateway_rx, ) } @@ -1632,7 +1632,7 @@ mod tests { ); match gateway_rx.try_recv() { - Ok(GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device)) => { + Ok(GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device)) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } diff --git a/crates/defguard_core/src/handlers/group.rs b/crates/defguard_core/src/handlers/group.rs index fbe40ab42..db6db11a2 100644 --- a/crates/defguard_core/src/handlers/group.rs +++ b/crates/defguard_core/src/handlers/group.rs @@ -110,7 +110,7 @@ pub(crate) async fn bulk_assign_to_groups( } } - sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.gateway_tx).await?; transaction.commit().await?; @@ -364,7 +364,7 @@ pub(crate) async fn create_group( .insert(&group_info.name); } - sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.gateway_tx).await?; transaction.commit().await?; @@ -498,7 +498,7 @@ pub(crate) async fn modify_group( .insert(group.name.as_str()); } - sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.gateway_tx).await?; let users_after = group.members(&mut *transaction).await?.clone(); transaction.commit().await?; @@ -608,7 +608,7 @@ pub(crate) async fn delete_group( // sync allowed devices for all locations let mut conn = appstate.pool.acquire().await?; - sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.gateway_tx).await?; info!( "User {} deleted group {}", @@ -665,7 +665,7 @@ pub(crate) async fn add_group_member( ldap_add_user_to_groups(&user, hashset![group.name.as_str()], &appstate.pool).await; ldap_update_user_state(&mut user, &appstate.pool).await; let mut conn = appstate.pool.acquire().await?; - sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.gateway_tx).await?; info!("Added user: {} to group: {}", user.username, group.name); appstate.emit_event(ApiEvent { context, @@ -728,7 +728,7 @@ pub(crate) async fn remove_group_member( .await; let mut conn = appstate.pool.acquire().await?; - sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.gateway_tx).await?; info!("Removed user: {} from group: {}", user.username, group.name); appstate.emit_event(ApiEvent { context, diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index 015ec84eb..6ccb1edfe 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -34,7 +34,7 @@ use crate::{ firewall::try_get_location_firewall_config, limits::update_counts, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{ device_for_admin_or_self, pagination::{PaginatedApiResponse, PaginatedApiResult, PaginationParams}, @@ -670,7 +670,7 @@ pub(crate) async fn add_network_device( .add_to_network(&network, &ips, &mut transaction) .await?; - appstate.send_wireguard_event(GatewayEvent::DeviceCreated(DeviceInfo { + appstate.send_gateway_command(GatewayCommand::DeviceCreated(DeviceInfo { device: device.clone(), network_info: vec![network_info.clone()], })); @@ -681,7 +681,7 @@ pub(crate) async fn add_network_device( if let Some(firewall_config) = try_get_location_firewall_config(&network, &mut transaction).await? { - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( network.id, firewall_config, )); @@ -778,14 +778,14 @@ pub async fn modify_network_device( wireguard_network_device.wireguard_ips = data.assigned_ips; wireguard_network_device.update(&mut *transaction).await?; let device_info = DeviceInfo::from_device(&mut *transaction, device.clone()).await?; - appstate.send_wireguard_event(GatewayEvent::DeviceModified(device_info)); + appstate.send_gateway_command(GatewayCommand::DeviceModified(device_info)); // send firewall update event if ACLs are enabled if device_network.acl_enabled { if let Some(firewall_config) = try_get_location_firewall_config(&device_network, &mut transaction).await? { - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( device_network.id, firewall_config, )); diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index 5b4fb2529..aa5d6662c 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -775,7 +775,7 @@ pub(crate) async fn modify_user( "User {} changed {username} groups or status, syncing allowed network devices.", session.user.username ); - sync_allowed_user_devices(&user, &mut transaction, &appstate.wireguard_tx).await?; + sync_allowed_user_devices(&user, &mut transaction, &appstate.gateway_tx).await?; } // remove API tokens when deactivating a user @@ -908,7 +908,7 @@ pub(crate) async fn delete_user( } else { None }; - delete_user_and_cleanup_devices(user.clone(), &mut transaction, &appstate.wireguard_tx) + delete_user_and_cleanup_devices(user.clone(), &mut transaction, &appstate.gateway_tx) .await?; appstate.trigger_action(AppEvent::UserDeleted(username.clone())); diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 3e7095b49..103643bef 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -38,7 +38,7 @@ use crate::{ limits::{get_counts, update_counts}, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{gateway::GatewayInfo, network_devices::DeviceWireGuardConfig}, location_management::{ allowed_peers::get_location_allowed_peers, handle_imported_devices, handle_mapped_devices, @@ -260,7 +260,7 @@ pub(crate) async fn create_network( network.add_all_allowed_devices(&mut transaction).await?; info!("Assigning IPs for existing devices in network {network}"); - appstate.send_wireguard_event(GatewayEvent::NetworkCreated(network.id, network.clone())); + appstate.send_gateway_command(GatewayCommand::NetworkCreated(network.id, network.clone())); transaction.commit().await?; @@ -381,7 +381,7 @@ pub(crate) async fn modify_network( let peers = get_location_allowed_peers(&network, &mut transaction).await?; let maybe_firewall_config = try_get_location_firewall_config(&network, &mut transaction).await?; - appstate.send_wireguard_event(GatewayEvent::NetworkModified( + appstate.send_gateway_command(GatewayCommand::NetworkModified( network.id, network.clone(), peers, @@ -448,7 +448,7 @@ pub(crate) async fn delete_network( } network.clone().delete(&mut *transaction).await?; transaction.commit().await?; - appstate.send_wireguard_event(GatewayEvent::NetworkDeleted(network_id, network_name)); + appstate.send_gateway_command(GatewayCommand::NetworkDeleted(network_id, network_name)); info!( "User {} deleted WireGuard network {network_id}", session.user.username, @@ -655,7 +655,7 @@ pub(crate) async fn import_network( .await?; info!("New network {network} created"); - appstate.send_wireguard_event(GatewayEvent::NetworkCreated(network.id, network.clone())); + appstate.send_gateway_command(GatewayCommand::NetworkCreated(network.id, network.clone())); let reserved_ips = imported_devices .iter() @@ -663,13 +663,13 @@ pub(crate) async fn import_network( .collect::>(); let (devices, gateway_events) = handle_imported_devices(&network, &mut transaction, imported_devices).await?; - appstate.send_multiple_wireguard_events(gateway_events); + appstate.send_multiple_gateway_commands(gateway_events); // assign IPs for other existing devices debug!("Assigning IPs in imported network for remaining existing devices"); let gateway_events = sync_location_allowed_devices(&network, &mut transaction, Some(&reserved_ips)).await?; - appstate.send_multiple_wireguard_events(gateway_events); + appstate.send_multiple_gateway_commands(gateway_events); debug!("Assigned IPs in imported network for remaining existing devices"); transaction.commit().await?; @@ -716,7 +716,7 @@ pub(crate) async fn add_user_devices( // wrap loop in transaction to abort if a device is invalid let mut transaction = appstate.pool.begin().await?; let events = handle_mapped_devices(&network, &mut transaction, &mapped_devices).await?; - appstate.send_multiple_wireguard_events(events); + appstate.send_multiple_gateway_commands(events); transaction.commit().await?; info!( @@ -871,7 +871,7 @@ pub(crate) async fn add_device( let (network_info, configs) = device.add_to_all_networks(&mut transaction).await?; - // prepare a list of gateway events to be sent + // prepare a list of gateway commands to be sent let mut events = Vec::new(); // get all locations affected by device being added @@ -892,7 +892,7 @@ pub(crate) async fn add_device( "Sending firewall config update for location {location} affected by adding new \ user {username} devices" ); - events.push(GatewayEvent::FirewallConfigChanged( + events.push(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -901,12 +901,12 @@ pub(crate) async fn add_device( } // add peer on relevant gateways - events.push(GatewayEvent::DeviceCreated(DeviceInfo { + events.push(GatewayCommand::DeviceCreated(DeviceInfo { device: device.clone(), network_info: network_info.clone(), })); - appstate.send_multiple_wireguard_events(events); + appstate.send_multiple_gateway_commands(events); let template_locations = configs .iter() @@ -1061,7 +1061,7 @@ pub(crate) async fn modify_device( network_info.push(device_network_info); } } - appstate.send_wireguard_event(GatewayEvent::DeviceModified(DeviceInfo { + appstate.send_gateway_command(GatewayCommand::DeviceModified(DeviceInfo { device: device.clone(), network_info, })); @@ -1186,7 +1186,7 @@ pub(crate) async fn delete_device( debug!( "Sending firewall config update for location {location} affected by deleting user {username} device" ); - events.push(GatewayEvent::FirewallConfigChanged( + events.push(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -1195,10 +1195,10 @@ pub(crate) async fn delete_device( } let device_id = device_info.device.id; - events.push(GatewayEvent::DeviceDeleted(device_info.clone())); + events.push(GatewayCommand::DeviceDeleted(device_info.clone())); - // send generated gateway events - appstate.send_multiple_wireguard_events(events); + // send generated gateway commands + appstate.send_multiple_gateway_commands(events); // Emit event specific to the device type. match device.device_type { diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 48c20ba00..0b8a57ae5 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -126,7 +126,7 @@ use crate::{ create_snat_binding, delete_snat_binding, list_snat_bindings, modify_snat_binding, }, }, - grpc::{GatewayEvent, WorkerState}, + grpc::{GatewayCommand, WorkerState}, handlers::{ app_info::get_app_info, auth::{ @@ -256,7 +256,7 @@ async fn openapi() -> Json { pub fn build_webapp( webhook_tx: UnboundedSender, webhook_rx: UnboundedReceiver, - wireguard_tx: Sender, + gateway_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, worker_state: Arc>, pool: PgPool, @@ -734,7 +734,7 @@ pub fn build_webapp( pool.clone(), webhook_tx, webhook_rx, - wireguard_tx, + gateway_tx, web_reload_tx, key, failed_logins, @@ -804,7 +804,7 @@ pub async fn run_web_server( worker_state: Arc>, webhook_tx: UnboundedSender, webhook_rx: UnboundedReceiver, - wireguard_tx: Sender, + gateway_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, pool: PgPool, failed_logins: Arc>, @@ -822,7 +822,7 @@ pub async fn run_web_server( let webapp = build_webapp( webhook_tx, webhook_rx, - wireguard_tx, + gateway_tx, web_reload_tx.clone(), worker_state, pool.clone(), diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 575ae6397..24d2629d2 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -1,5 +1,7 @@ -use defguard_common::db::{Id, models::WireguardNetwork}; -use defguard_proto::gateway::Peer; +use defguard_common::{ + db::{Id, models::WireguardNetwork}, + gateway_types::WireguardPeer, +}; use sqlx::{PgConnection, query}; use crate::grpc::should_prevent_service_location_usage; @@ -14,7 +16,7 @@ use crate::grpc::should_prevent_service_location_usage; pub async fn get_location_allowed_peers( location: &WireguardNetwork, conn: &mut PgConnection, -) -> sqlx::Result> { +) -> sqlx::Result> { debug!("Fetching all allowed peers for location {}", location.id); if should_prevent_service_location_usage(location) { @@ -47,7 +49,7 @@ pub async fn get_location_allowed_peers( return Ok(rows .into_iter() - .map(|row| Peer { + .map(|row| WireguardPeer { pubkey: row.pubkey, allowed_ips: row.allowed_ips, preshared_key: None, @@ -87,7 +89,7 @@ pub async fn get_location_allowed_peers( Ok(rows .into_iter() - .map(|row| Peer { + .map(|row| WireguardPeer { pubkey: row.pubkey, allowed_ips: row.allowed_ips, preshared_key: Some(row.preshared_key), diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index 0f58a4529..4d6328bd5 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -18,7 +18,7 @@ use tokio::sync::broadcast::Sender; use crate::{ enterprise::firewall::{FirewallError, try_get_location_firewall_config}, - grpc::{GatewayEvent, send_multiple_wireguard_events}, + grpc::{GatewayCommand, send_multiple_gateway_commands}, wg_config::ImportedDevice, }; @@ -41,7 +41,7 @@ pub enum LocationManagementError { // run sync_allowed_devices on all wireguard networks pub(crate) async fn sync_all_networks( conn: &mut PgConnection, - wireguard_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), LocationManagementError> { info!("Syncing allowed devices for all WireGuard locations"); let locations = WireguardNetwork::all(&mut *conn).await?; @@ -53,14 +53,14 @@ pub(crate) async fn sync_all_networks( if let Some(firewall_config) = try_get_location_firewall_config(&network, &mut *conn).await? { - gateway_events.push(GatewayEvent::FirewallConfigChanged( + gateway_events.push(GatewayCommand::FirewallConfigChanged( network.id, firewall_config, )); } - // check if any gateway events need to be sent + // check if any gateway commands need to be sent if !gateway_events.is_empty() { - send_multiple_wireguard_events(gateway_events, wireguard_tx); + send_multiple_gateway_commands(gateway_events, gateway_tx); } } Ok(()) @@ -75,7 +75,7 @@ pub(crate) async fn sync_location_allowed_devices( location: &WireguardNetwork, conn: &mut PgConnection, reserved_ips: Option<&[IpAddr]>, -) -> Result, LocationManagementError> { +) -> Result, LocationManagementError> { info!("Synchronizing IPs in network {location} for all allowed devices "); // list all allowed devices let mut allowed_devices = location.get_allowed_devices(&mut *conn).await?; @@ -119,7 +119,7 @@ pub(crate) async fn sync_allowed_devices_for_user( conn: &mut PgConnection, user: &User, reserved_ips: Option<&[IpAddr]>, -) -> Result, WireguardNetworkError> { +) -> Result, WireguardNetworkError> { info!("Synchronizing IPs in network {location} for all allowed devices "); // list all allowed devices let allowed_devices = location @@ -160,12 +160,12 @@ pub async fn process_device_access_changes( mut allowed_devices: HashMap>, currently_configured_devices: Vec, reserved_ips: Option<&[IpAddr]>, -) -> Result, WireguardNetworkError> { +) -> Result, WireguardNetworkError> { // Loop through current device configurations; remove no longer allowed, readdress // when necessary; remove processed entry from all devices list initial list should // now contain only devices to be added. let mut used_ips = location.all_used_ips_for_network(&mut *transaction).await?; - let mut events: Vec = Vec::new(); + let mut events: Vec = Vec::new(); for device_network_config in currently_configured_devices { // Device is allowed and an IP was already assigned if let Some(device) = allowed_devices.remove(&device_network_config.device_id) { @@ -186,7 +186,7 @@ pub async fn process_device_access_changes( let network_info = wireguard_network_device .to_device_network_info_runtime(&mut *transaction, location) .await?; - events.push(GatewayEvent::DeviceModified(DeviceInfo { + events.push(GatewayCommand::DeviceModified(DeviceInfo { device, network_info: vec![network_info], })); @@ -204,7 +204,7 @@ pub async fn process_device_access_changes( Device::find_by_id(&mut *transaction, device_network_config.device_id).await? { let network_info = device_network_config.to_device_network_info(location, None); - events.push(GatewayEvent::DeviceDeleted(DeviceInfo { + events.push(GatewayCommand::DeviceDeleted(DeviceInfo { device, network_info: vec![network_info], })); @@ -224,7 +224,7 @@ pub async fn process_device_access_changes( let network_info = wireguard_network_device .to_device_network_info_runtime(&mut *transaction, location) .await?; - events.push(GatewayEvent::DeviceCreated(DeviceInfo { + events.push(GatewayCommand::DeviceCreated(DeviceInfo { device, network_info: vec![network_info], })); @@ -236,12 +236,12 @@ pub async fn process_device_access_changes( /// Check if devices found in an imported config file exist already, /// if they do assign a specified IP. /// Return a list of imported devices which need to be manually mapped to a user -/// and a list of WireGuard events to be sent out. +/// and a list of gateway commands to be sent out. pub(crate) async fn handle_imported_devices( location: &WireguardNetwork, transaction: &mut PgConnection, imported_devices: Vec, -) -> Result<(Vec, Vec), WireguardNetworkError> { +) -> Result<(Vec, Vec), WireguardNetworkError> { let allowed_devices = location.get_allowed_devices(&mut *transaction).await?; // convert to a map for easier processing let allowed_devices: HashMap> = allowed_devices @@ -276,7 +276,7 @@ pub(crate) async fn handle_imported_devices( .to_device_network_info_runtime(&mut *transaction, location) .await?; // send device to connected gateways - events.push(GatewayEvent::DeviceModified(DeviceInfo { + events.push(GatewayCommand::DeviceModified(DeviceInfo { device: existing_device, network_info: vec![network_info], })); @@ -301,7 +301,7 @@ pub(crate) async fn handle_mapped_devices( location: &WireguardNetwork, conn: &mut PgConnection, mapped_devices: &[MappedDevice], -) -> Result, WireguardNetworkError> { +) -> Result, WireguardNetworkError> { info!("Mapping user devices for network {location}"); // get allowed groups for network let allowed_groups = location.get_allowed_groups(&mut *conn).await?; @@ -366,7 +366,7 @@ pub(crate) async fn handle_mapped_devices( // send device to connected gateways if !network_info.is_empty() { - events.push(GatewayEvent::DeviceCreated(DeviceInfo { + events.push(GatewayCommand::DeviceCreated(DeviceInfo { device, network_info, })); @@ -470,11 +470,11 @@ mod test { assert_eq!(events.len(), 2); assert!(events.iter().any(|e| match e { - GatewayEvent::DeviceCreated(info) => info.device.id == device1.id, + GatewayCommand::DeviceCreated(info) => info.device.id == device1.id, _ => false, })); assert!(events.iter().any(|e| match e { - GatewayEvent::DeviceCreated(info) => info.device.id == device2.id, + GatewayCommand::DeviceCreated(info) => info.device.id == device2.id, _ => false, })); @@ -485,7 +485,7 @@ mod test { assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device3.id); } _ => panic!("Expected DeviceCreated event"), @@ -612,7 +612,7 @@ mod test { .unwrap(); assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device1.id); } _ => panic!("Expected DeviceCreated event"), @@ -623,7 +623,7 @@ mod test { .unwrap(); assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device2.id); } _ => panic!("Expected DeviceCreated event"), @@ -634,7 +634,7 @@ mod test { .unwrap(); assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device3.id); } _ => panic!("Expected DeviceCreated event"), diff --git a/crates/defguard_core/src/user_management.rs b/crates/defguard_core/src/user_management.rs index 1fc5d7cea..e059e189b 100644 --- a/crates/defguard_core/src/user_management.rs +++ b/crates/defguard_core/src/user_management.rs @@ -10,7 +10,7 @@ use tokio::sync::broadcast::Sender; use crate::{ enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, error::WebError, - grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + grpc::{GatewayCommand, send_gateway_command, send_multiple_gateway_commands}, location_management::sync_allowed_devices_for_user, }; @@ -18,7 +18,7 @@ use crate::{ pub async fn delete_user_and_cleanup_devices( user: User, conn: &mut PgConnection, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), WebError> { let username = user.username.clone(); debug!("Deleting user {username}, removing his devices from gateways and updating ldap...",); @@ -33,7 +33,7 @@ pub async fn delete_user_and_cleanup_devices( for network_info in &device_info.network_info { affected_location_ids.insert(network_info.network_id); } - events.push(GatewayEvent::DeviceDeleted(device_info)); + events.push(GatewayCommand::DeviceDeleted(device_info)); } user.delete(&mut *conn).await?; @@ -49,7 +49,7 @@ pub async fn delete_user_and_cleanup_devices( debug!( "Sending firewall config update for location {location} affected by deleting user {username} devices" ); - events.push(GatewayEvent::FirewallConfigChanged( + events.push(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -57,7 +57,7 @@ pub async fn delete_user_and_cleanup_devices( } } - send_multiple_wireguard_events(events, wg_tx); + send_multiple_gateway_commands(events, gateway_tx); info!( "The user {} has been deleted and his devices removed from gateways.", &username @@ -69,12 +69,12 @@ pub async fn delete_user_and_cleanup_devices( pub async fn disable_user( user: &mut User, conn: &mut PgConnection, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), WebError> { user.is_active = false; user.save(&mut *conn).await?; user.logout_all_sessions(&mut *conn).await?; - sync_allowed_user_devices(user, conn, wg_tx).await?; + sync_allowed_user_devices(user, conn, gateway_tx).await?; Ok(()) } @@ -82,7 +82,7 @@ pub async fn disable_user( pub async fn sync_allowed_user_devices( user: &User, conn: &mut PgConnection, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), WebError> { debug!("Syncing allowed devices of user {}", user.username); let locations = WireguardNetwork::all(&mut *conn).await?; @@ -93,16 +93,16 @@ pub async fn sync_allowed_user_devices( // check if any peers were updated if !gateway_events.is_empty() { // send peer update events - send_multiple_wireguard_events(gateway_events, wg_tx); + send_multiple_gateway_commands(gateway_events, gateway_tx); } // send firewall config update if ACLs & enterprise features are enabled if let Some(firewall_config) = try_get_location_firewall_config(&location, &mut *conn).await? { - send_wireguard_event( - GatewayEvent::FirewallConfigChanged(location.id, firewall_config), - wg_tx, + send_gateway_command( + GatewayCommand::FirewallConfigChanged(location.id, firewall_config), + gateway_tx, ); } } diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index 84acfb248..6e2f27f22 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -26,7 +26,7 @@ use crate::{ ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, limits::update_counts, }, - grpc::GatewayEvent, + grpc::GatewayCommand, letsencrypt::do_letsencrypt_refresh, location_management::allowed_peers::get_location_allowed_peers, updates::do_new_version_check, @@ -46,7 +46,7 @@ const ACL_EXPIRY_SYSTEM_ACTOR: &str = "system:acl-expiry"; #[instrument(skip_all)] pub async fn run_utility_thread( pool: &PgPool, - wireguard_tx: broadcast::Sender, + gateway_tx: broadcast::Sender, proxy_control_tx: mpsc::Sender, web_reload_tx: broadcast::Sender<()>, ) -> Result<(), anyhow::Error> { @@ -64,7 +64,7 @@ pub async fn run_utility_thread( let directory_sync_task = || async { if let Err(e) = Box::pin( - do_directory_sync(pool, &wireguard_tx).instrument(info_span!("directory_sync_task")), + do_directory_sync(pool, &gateway_tx).instrument(info_span!("directory_sync_task")), ) .await { @@ -100,7 +100,7 @@ pub async fn run_utility_thread( }; let expired_acl_rules_task = || async { - if let Err(err) = expired_acl_rules_check(pool, wireguard_tx.clone()) + if let Err(err) = expired_acl_rules_check(pool, gateway_tx.clone()) .instrument(info_span!("expired_acl_rules_task")) .await { @@ -174,7 +174,7 @@ pub async fn run_utility_thread( {new_enterprise_enabled}" ); if let Err(err) = - enterprise_status_check(pool, wireguard_tx.clone(), new_enterprise_enabled) + enterprise_status_check(pool, gateway_tx.clone(), new_enterprise_enabled) .instrument(info_span!("enterprise_status_check")) .await { @@ -199,7 +199,7 @@ pub async fn run_utility_thread( /// Check if enterprise status has changed and perform any necessary actions async fn enterprise_status_check( pool: &PgPool, - wireguard_tx: broadcast::Sender, + gateway_tx: broadcast::Sender, enable_enterprise: bool, ) -> Result<(), anyhow::Error> { // fetch all ACL-enabled networks @@ -221,13 +221,13 @@ async fn enterprise_status_check( // Handle service location update or just update the firewall if location.service_location_mode == ServiceLocationMode::Disabled { - wireguard_tx.send(GatewayEvent::FirewallConfigChanged( + gateway_tx.send(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, ))?; } else { let new_peers = get_location_allowed_peers(&location, &mut transaction).await?; - wireguard_tx.send(GatewayEvent::NetworkModified( + gateway_tx.send(GatewayCommand::NetworkModified( location.id, location, new_peers, @@ -242,13 +242,13 @@ async fn enterprise_status_check( for location in locations { if location.service_location_mode == ServiceLocationMode::Disabled { debug!("Disabling gateway firewall configuration for location {location:?}"); - wireguard_tx.send(GatewayEvent::FirewallDisabled(location.id))?; + gateway_tx.send(GatewayCommand::FirewallDisabled(location.id))?; } else { debug!( "Disabling gateway firewall configuration and service location client \ connections for location {location}" ); - wireguard_tx.send(GatewayEvent::NetworkModified( + gateway_tx.send(GatewayCommand::NetworkModified( location.id, location, // Send empty peer list, we are disabling the service location @@ -265,7 +265,7 @@ async fn enterprise_status_check( /// Find newly expired ACL rules and update their status. async fn expired_acl_rules_check( pool: &PgPool, - wireguard_tx: broadcast::Sender, + gateway_tx: broadcast::Sender, ) -> Result<(), anyhow::Error> { // mark relevant rules as expired let updated_rules = query_as!( @@ -309,7 +309,7 @@ async fn expired_acl_rules_check( match try_get_location_firewall_config(&location, &mut conn).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - wireguard_tx.send(GatewayEvent::FirewallConfigChanged( + gateway_tx.send(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, ))?; diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 350f010d9..725d13e8b 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -30,7 +30,7 @@ use defguard_core::{ db::AppEvent, enterprise::license::{License, LicenseTier, SupportType, set_cached_license}, events::ApiEvent, - grpc::{GatewayEvent, WorkerState}, + grpc::{GatewayCommand, WorkerState}, handlers::{Auth, user::UserDetails}, }; use reqwest::{StatusCode, header::HeaderName}; @@ -60,7 +60,7 @@ pub const X_FORWARDED_URI: HeaderName = HeaderName::from_static("x-forwarded-uri pub(crate) struct ClientState { pub pool: PgPool, pub worker_state: Arc>, - pub wireguard_rx: Receiver, + pub gateway_rx: Receiver, pub test_user: User, #[allow(dead_code)] pub config: DefGuardConfig, @@ -70,14 +70,14 @@ impl ClientState { pub fn new( pool: PgPool, worker_state: Arc>, - wireguard_rx: Receiver, + gateway_rx: Receiver, test_user: User, config: DefGuardConfig, ) -> Self { Self { pool, worker_state, - wireguard_rx, + gateway_rx, test_user, config, } @@ -92,7 +92,7 @@ pub(crate) async fn make_base_client( let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); - let (wg_tx, wg_rx) = broadcast::channel::(16); + let (gateway_tx, gateway_rx) = broadcast::channel::(16); let failed_logins = FailedLoginMap::new(); let failed_logins = Arc::new(Mutex::new(failed_logins)); @@ -113,7 +113,7 @@ pub(crate) async fn make_base_client( let client_state = ClientState::new( pool.clone(), worker_state.clone(), - wg_rx, + gateway_rx, User::find_by_username(&pool, "hpotter") .await .unwrap() @@ -146,7 +146,7 @@ pub(crate) async fn make_base_client( let webapp = build_webapp( tx, rx, - wg_tx, + gateway_tx, web_reload_tx, worker_state, pool, diff --git a/crates/defguard_core/tests/integration/api/proxy_certs.rs b/crates/defguard_core/tests/integration/api/proxy_certs.rs index 0473395ad..86c5ccc84 100644 --- a/crates/defguard_core/tests/integration/api/proxy_certs.rs +++ b/crates/defguard_core/tests/integration/api/proxy_certs.rs @@ -30,7 +30,7 @@ use defguard_core::{ db::AppEvent, enterprise::license::{License, LicenseTier, SupportType, set_cached_license}, events::ApiEvent, - grpc::{GatewayEvent, WorkerState}, + grpc::{GatewayCommand, WorkerState}, handlers::Auth, }; use reqwest::StatusCode; @@ -114,7 +114,7 @@ async fn make_test_client_with_proxy_rx( let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); - let (wg_tx, _wg_rx) = broadcast::channel::(16); + let (gateway_tx, _wg_rx) = broadcast::channel::(16); let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); @@ -140,7 +140,7 @@ async fn make_test_client_with_proxy_rx( let webapp = build_webapp( tx, rx, - wg_tx, + gateway_tx, web_reload_tx, worker_state, pool.clone(), diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index 0ca61f6d8..150ad98da 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -22,7 +22,7 @@ use defguard_core::{ license::{License, LicenseTier, SupportType, get_cached_license, set_cached_license}, limits::update_counts, }, - grpc::{GatewayEvent, proto::enterprise::license::LicenseLimits}, + grpc::{GatewayCommand, proto::enterprise::license::LicenseLimits}, handlers::{Auth, GroupInfo, wireguard::WireguardNetworkData}, }; use ipnetwork::IpNetwork; @@ -44,7 +44,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -54,8 +54,8 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { let response = make_network(&client, "network").await; let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // check vpn locations for `admin` group let admin_id = Group::find_by_name(&client_state.pool, "admin") @@ -102,8 +102,8 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { ] ); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkModified(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkModified(..)); // check vpn locations for `admin` group let response = client.get(format!("/api/v1/group/{admin_id}")).send().await; @@ -134,8 +134,8 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkDeleted(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkDeleted(..)); } #[sqlx::test] @@ -507,7 +507,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -515,8 +515,8 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { // create network make_network(&client, "network").await; - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network details let response = client.get("/api/v1/network/1").send().await; @@ -534,8 +534,8 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::CREATED); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceCreated(..)); // an IP was assigned for new device let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 1) @@ -549,7 +549,10 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { // add another network make_network(&client, "network").await; - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); // an IP was assigned for an existing device let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 1) @@ -594,8 +597,8 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceModified(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceModified(..)); // device details let response = client @@ -636,8 +639,8 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkDeleted(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkDeleted(..)); // delete device let response = client @@ -645,8 +648,8 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceDeleted(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceDeleted(..)); let response = client.get("/api/v1/device").json(&device).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -930,7 +933,7 @@ async fn test_device_pubkey(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -938,8 +941,8 @@ async fn test_device_pubkey(_: PgPoolOptions, options: PgConnectOptions) { // create network make_network(&client, "network").await; - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network details let response = client.get("/api/v1/network/1").send().await; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs index 270d46c7f..d8833f7cb 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs @@ -16,7 +16,7 @@ use defguard_common::{ }, }; use defguard_core::{ - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{ Auth, wireguard::{ImportedNetworkData, WireguardNetworkData}, @@ -148,7 +148,7 @@ async fn test_create_new_network(_: PgPoolOptions, options: PgConnectOptions) { let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -180,9 +180,9 @@ async fn test_create_new_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); - assert_err!(wg_rx.try_recv()); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); + assert_err!(gateway_rx.try_recv()); // network configuration was created only for admin and allowed user let peers = get_location_allowed_peers(&network, &mut conn) @@ -201,7 +201,7 @@ async fn test_create_new_network_allow_all_groups(_: PgPoolOptions, options: PgC let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -237,7 +237,10 @@ async fn test_create_new_network_allow_all_groups(_: PgPoolOptions, options: PgC .await .unwrap(); assert_eq!(allowed_groups, vec!["allowed group"]); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -257,7 +260,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -289,8 +292,8 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network configuration was created for admin and the allowed group member let peers = get_location_allowed_peers(&network, &mut conn) @@ -324,7 +327,10 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let new_peers = get_location_allowed_peers(&network, &mut conn) .await @@ -357,7 +363,10 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let new_peers = get_location_allowed_peers(&network, &mut conn) .await @@ -391,7 +400,10 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let new_peers = get_location_allowed_peers(&network, &mut conn) .await @@ -400,7 +412,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(new_peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(new_peers[1].pubkey, devices[2].wireguard_pubkey); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); } #[sqlx::test] @@ -411,7 +423,7 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -441,7 +453,10 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: .await; assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -478,7 +493,10 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: .await .unwrap(); assert_eq!(allowed_groups, vec!["allowed group"]); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -499,7 +517,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -558,11 +576,11 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network config was only created for one of the existing devices and the admin device - let GatewayEvent::DeviceModified(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceModified(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!(device_info.device.id, devices[1].id); @@ -573,7 +591,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne peers[1].allowed_ips[0] ); - let GatewayEvent::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!(device_info.device.id, devices[0].id); @@ -584,7 +602,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne peers[0].allowed_ips[0] ); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); } #[sqlx::test] @@ -595,7 +613,7 @@ async fn test_import_mapping_devices(_: PgPoolOptions, options: PgConnectOptions let (users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -642,7 +660,7 @@ PersistentKeepalive = 300 let mut mapped_devices = response.devices; assert_eq!(mapped_devices.len(), 4); for _ in 0..3 { - wg_rx.try_recv().unwrap(); + gateway_rx.try_recv().unwrap(); } // assign devices to users @@ -668,7 +686,7 @@ PersistentKeepalive = 300 assert_eq!(peers[3].pubkey, mapped_devices[1].wireguard_pubkey); // assert events - let GatewayEvent::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!( @@ -682,7 +700,7 @@ PersistentKeepalive = 300 mapped_devices[0].wireguard_ips, ); - let GatewayEvent::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!( @@ -696,7 +714,7 @@ PersistentKeepalive = 300 mapped_devices[1].wireguard_ips, ); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); } /// Test that changing groups for a particular user generates correct update events @@ -708,7 +726,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -740,9 +758,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); - assert_err!(wg_rx.try_recv()); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); + assert_err!(gateway_rx.try_recv()); // network configuration was created only for admin and allowed user let peers = get_location_allowed_peers(&network, &mut conn) @@ -762,9 +780,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceDeleted(..)); - assert_err!(wg_rx.try_recv()); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceDeleted(..)); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -782,7 +800,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -800,9 +818,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceCreated(..)); - assert_err!(wg_rx.try_recv()); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceCreated(..)); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -823,7 +841,7 @@ async fn test_modify_user_no_effect_when_allow_all_groups( let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -853,8 +871,11 @@ async fn test_modify_user_no_effect_when_allow_all_groups( .await; assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); - assert_err!(wg_rx.try_recv()); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -870,7 +891,7 @@ async fn test_modify_user_no_effect_when_allow_all_groups( .await; assert_eq!(response.status(), StatusCode::OK); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -934,7 +955,7 @@ async fn test_delete_only_allowed_group_rejected(_: PgPoolOptions, options: PgCo let (_users, devices) = setup_test_users(&client_state.pool).await; let mut conn = client_state.pool.acquire().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -966,8 +987,8 @@ async fn test_delete_only_allowed_group_rejected(_: PgPoolOptions, options: PgCo assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); let peers = get_location_allowed_peers(&network, &mut conn) .await @@ -1007,7 +1028,7 @@ async fn test_delete_allowed_group_when_location_keeps_other_groups( let (client, client_state) = make_test_client(pool).await; setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -1037,8 +1058,8 @@ async fn test_delete_allowed_group_when_location_keeps_other_groups( .await; assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); let allowed_group_id = Group::find_by_name(&client_state.pool, "allowed group") .await diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs index 071847c0f..13fc2422a 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs @@ -5,7 +5,7 @@ use defguard_common::db::{ models::{Device, WireguardNetwork}, }; use defguard_core::{ - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{Auth, network_devices::AddNetworkDevice}, }; use ipnetwork::IpNetwork; @@ -98,7 +98,7 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -108,13 +108,13 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let response = make_first_network(&client).await; let network_1: WireguardNetwork = response.json().await; assert_eq!(network_1.name, "network"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); let response = make_second_network(&client).await; let network_2: WireguardNetwork = response.json().await; assert_eq!(network_2.name, "network-2"); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // ip suggestions let response = client.get("/api/v1/device/network/ip/1").send().await; @@ -201,8 +201,8 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let configured = json["device"]["configured"].as_bool().unwrap(); let config_text = json["config"]["config"].as_str().unwrap(); assert!(configured); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceCreated(..)); // download WG config let response = client.get("/api/v1/device/network/1/config").send().await; @@ -238,8 +238,8 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { .unwrap(); assert_eq!(device.name, "device-1"); assert_eq!(device.description, Some("new description".to_owned())); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceModified(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::DeviceModified(..)); // Make sure the device is only in the selected network let device_networks = diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index d176b3885..05d3b965c 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -6,7 +6,7 @@ use defguard_common::db::models::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }; use defguard_core::{ - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{Auth, wireguard::ImportedNetworkData}, }; use matches::assert_matches; @@ -102,7 +102,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { transaction.commit().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -137,12 +137,15 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(network.dns, Some("10.0.0.2".to_owned())); assert_eq!(network.allowed_ips, vec!["10.0.0.0/24".parse().unwrap()]); assert_eq!(network.connected_at, None); - let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let event = gateway_rx.try_recv().unwrap(); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // existing devices assertion // imported config for an existing device - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::DeviceModified(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::DeviceModified(..) + ); let user_device_1 = UserDevice::from_device(&pool, device_1) .await .unwrap() @@ -153,7 +156,10 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { vec!["10.0.0.12"] ); // generated IP for other existing device - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::DeviceCreated(..)); + assert_matches!( + gateway_rx.try_recv().unwrap(), + GatewayCommand::DeviceCreated(..) + ); let user_device_2 = UserDevice::from_device(&pool, device_2) .await .unwrap() @@ -203,23 +209,23 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); // assert events - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); match event { - GatewayEvent::DeviceCreated(device_info) => { + GatewayCommand::DeviceCreated(device_info) => { assert_eq!(device_info.device.name, "device_1"); } _ => unreachable!("Invalid event type received"), } - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); match event { - GatewayEvent::DeviceCreated(device_info) => { + GatewayCommand::DeviceCreated(device_info) => { assert_eq!(device_info.device.name, "device_2"); } _ => unreachable!("Invalid event type received"), } - let event = wg_rx.try_recv(); + let event = gateway_rx.try_recv(); assert_matches!(event, Err(TryRecvError::Empty)); // assert user devices diff --git a/crates/defguard_event_router/Cargo.toml b/crates/defguard_event_router/Cargo.toml index 8b6423149..c36146d0a 100644 --- a/crates/defguard_event_router/Cargo.toml +++ b/crates/defguard_event_router/Cargo.toml @@ -9,6 +9,7 @@ rust-version.workspace = true [dependencies] # internal crates +defguard_common = { workspace = true } defguard_core = { workspace = true } defguard_event_logger = { workspace = true } defguard_session_manager = { workspace = true } @@ -19,4 +20,3 @@ tokio = { workspace = true } tracing = { workspace = true } [dev-dependencies] -defguard_common = { workspace = true } diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 748f8ff3e..d5252bab3 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -114,10 +114,8 @@ mod tests { wireguard::{LocationMfaMode, ServiceLocationMode}, }, }; - use defguard_core::{ - events::{BidiRequestContext, BidiStreamEventType}, - grpc::GatewayEvent, - }; + use defguard_common::gateway_event::GatewayCommand; + use defguard_core::events::{BidiRequestContext, BidiStreamEventType}; use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; use super::*; @@ -163,13 +161,13 @@ mod tests { let (_bidi_tx, bidi_rx) = unbounded_channel(); let (_session_manager_tx, session_manager_rx) = unbounded_channel(); let (event_logger_tx, event_logger_rx) = unbounded_channel(); - let (wireguard_tx, _wireguard_rx) = broadcast::channel::(1); + let (gateway_tx, _gateway_rx) = broadcast::channel::(1); ( EventRouter::new( RouterReceiverSet::new(api_rx, bidi_rx, session_manager_rx), event_logger_tx, - wireguard_tx, + gateway_tx, Arc::new(Notify::new()), ), event_logger_rx, diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index b46e6d257..193744288 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -13,16 +13,14 @@ //! MPSC channel. //! 2. The router processes these events and forwards them to the appropriate services: //! - Activity log events go to the event logger service -//! - WireGuard events go to the gateway service +//! - gateway commands go to the gateway service //! - Mail events go to the mail service //! - etc. use std::sync::Arc; -use defguard_core::{ - events::{ApiEvent, BidiStreamEvent}, - grpc::GatewayEvent, -}; +use defguard_common::gateway_event::GatewayCommand; +use defguard_core::events::{ApiEvent, BidiStreamEvent}; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use defguard_session_manager::events::SessionManagerEvent; use error::EventRouterError; @@ -63,7 +61,7 @@ impl RouterReceiverSet { struct EventRouter { receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, activity_log_stream_reload_notify: Arc, } @@ -89,13 +87,13 @@ impl EventRouter { fn new( receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, activity_log_stream_reload_notify: Arc, ) -> Self { Self { receivers, event_logger_tx, - wireguard_tx, + gateway_tx, activity_log_stream_reload_notify, } } @@ -140,7 +138,7 @@ impl EventRouter { pub async fn run_event_router( receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, activity_log_stream_reload_notify: Arc, ) -> Result<(), EventRouterError> { info!("Starting main event router service"); @@ -148,7 +146,7 @@ pub async fn run_event_router( let mut event_router = EventRouter::new( receivers, event_logger_tx, - wireguard_tx, + gateway_tx, activity_log_stream_reload_notify, ); diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index d300c6ce3..5f9ca2257 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -21,17 +21,18 @@ use defguard_common::{ wireguard::DEFAULT_WIREGUARD_MTU, }, }, + gateway_event::GatewayCommand, + gateway_types::{FirewallConfig, WireguardPeer}, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::{ enterprise::firewall::try_get_location_firewall_config, - grpc::GatewayEvent, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, location_management::allowed_peers::get_location_allowed_peers, }; use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::{ - enterprise::firewall::FirewallConfig, + enterprise::firewall::FirewallConfig as ProtoFirewallConfig, gateway::{ Configuration, CoreResponse, Peer, PeerStats, Update, UpdateType, core_request, core_response, gateway_client, update, @@ -87,7 +88,7 @@ pub(crate) struct GatewayHandler { gateway: Gateway, message_id: AtomicU64, pool: PgPool, - events_tx: Sender, + events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, updates_handler_handle: Option>, @@ -101,7 +102,7 @@ impl GatewayHandler { pub fn new( gateway: Gateway, pool: PgPool, - events_tx: Sender, + events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, ) -> Result { @@ -564,7 +565,7 @@ impl GatewayHandler { pub(crate) fn new_with_test_socket( gateway: Gateway, pool: PgPool, - events_tx: Sender, + events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, socket_path: PathBuf, @@ -622,7 +623,7 @@ struct GatewayUpdatesHandler { gateway_name: String, pool: Option, session_authorization_required: bool, - events_rx: broadcast::Receiver, + events_rx: broadcast::Receiver, tx: UnboundedSender, } @@ -633,7 +634,7 @@ impl GatewayUpdatesHandler { network: WireguardNetwork, gateway_name: String, pool: Option, - events_rx: broadcast::Receiver, + events_rx: broadcast::Receiver, tx: UnboundedSender, ) -> Self { Self { @@ -754,7 +755,7 @@ impl GatewayUpdatesHandler { while let Ok(update) = self.events_rx.recv().await { debug!("Received WireGuard update: {update:?}"); let result = match update { - GatewayEvent::NetworkCreated(network_id, network) => { + GatewayCommand::NetworkCreated(network_id, network) => { if network_id == self.network_id { self.send_network_update( &network, @@ -766,7 +767,7 @@ impl GatewayUpdatesHandler { Ok(()) } } - GatewayEvent::NetworkModified( + GatewayCommand::NetworkModified( network_id, network, peers, @@ -787,14 +788,14 @@ impl GatewayUpdatesHandler { Ok(()) } } - GatewayEvent::NetworkDeleted(network_id, network_name) => { + GatewayCommand::NetworkDeleted(network_id, network_name) => { if network_id == self.network_id { self.send_network_delete(&network_name) } else { Ok(()) } } - GatewayEvent::DeviceCreated(device) => { + GatewayCommand::DeviceCreated(device) => { // check if a peer has to be added in the current network match device .network_info @@ -810,7 +811,7 @@ impl GatewayUpdatesHandler { None => Ok(()), } } - GatewayEvent::DeviceModified(device) => { + GatewayCommand::DeviceModified(device) => { // check if a peer has to be updated in the current network match device .network_info @@ -826,7 +827,7 @@ impl GatewayUpdatesHandler { None => Ok(()), } } - GatewayEvent::DeviceDeleted(device) => { + GatewayCommand::DeviceDeleted(device) => { // check if a peer has to be updated in the current network match device .network_info @@ -837,28 +838,28 @@ impl GatewayUpdatesHandler { None => Ok(()), } } - GatewayEvent::FirewallConfigChanged(location_id, firewall_config) => { + GatewayCommand::FirewallConfigChanged(location_id, firewall_config) => { if location_id == self.network_id { self.send_firewall_update(firewall_config) } else { Ok(()) } } - GatewayEvent::FirewallDisabled(location_id) => { + GatewayCommand::FirewallDisabled(location_id) => { if location_id == self.network_id { self.send_firewall_disable() } else { Ok(()) } } - GatewayEvent::VpnSessionDeauthorized(location_id, device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, device) => { if location_id == self.network_id { self.send_peer_delete(&device.wireguard_pubkey) } else { Ok(()) } } - GatewayEvent::VpnSessionAuthorized(location_id, device, network_info) => { + GatewayCommand::VpnSessionAuthorized(location_id, device, network_info) => { if location_id == self.network_id { if network_info.network_id != location_id { error!( @@ -892,11 +893,13 @@ impl GatewayUpdatesHandler { fn send_network_update( &self, network: &WireguardNetwork, - peers: Vec, + peers: Vec, firewall_config: Option, update_type: i32, ) -> Result<(), Status> { debug!("Sending network update for network {network}"); + let proto_peers: Vec = peers.into_iter().map(Into::into).collect(); + let proto_firewall: Option = firewall_config.map(Into::into); if let Err(err) = self.tx.send(CoreResponse { id: 0, payload: Some(core_response::Payload::Update(Update { @@ -906,8 +909,8 @@ impl GatewayUpdatesHandler { private_key: network.prvkey.clone(), addresses: network.address().iter().map(ToString::to_string).collect(), port: network.port.cast_unsigned(), - peers, - firewall_config, + peers: proto_peers, + firewall_config: proto_firewall, mtu: network.mtu.cast_unsigned(), fwmark: network.fwmark as u32, })), @@ -1022,11 +1025,12 @@ impl GatewayUpdatesHandler { "Sending firewall config update for network {} with config {firewall_config:?}", self.network ); + let proto_firewall: ProtoFirewallConfig = firewall_config.into(); if let Err(err) = self.tx.send(CoreResponse { id: 0, payload: Some(core_response::Payload::Update(Update { update_type: UpdateType::Modify as i32, - update: Some(update::Update::FirewallConfig(firewall_config)), + update: Some(update::Update::FirewallConfig(proto_firewall)), })), }) { let msg = format!( @@ -1107,14 +1111,14 @@ mod tests { }, setup_pool, }; - use defguard_core::grpc::GatewayEvent; - use defguard_proto::gateway::{Configuration, Peer, PeerStats, core_response}; + use defguard_common::gateway_event::GatewayCommand; + use defguard_proto::gateway::{Configuration, PeerStats, core_response}; use prost_types::Timestamp; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::{broadcast, mpsc::unbounded_channel, watch}; use super::{ - FirewallConfig, GatewayHandler, GatewayUpdatesHandler, try_protos_into_stats_message, + GatewayHandler, GatewayUpdatesHandler, WireguardPeer, try_protos_into_stats_message, }; fn test_network(location_mfa_mode: LocationMfaMode) -> WireguardNetwork { @@ -1235,16 +1239,20 @@ mod tests { #[test] fn gen_config_maps_network_fields() { + use defguard_common::gateway_types::{ + FirewallConfig as NativeFirewallConfig, FirewallPolicy, + }; + use defguard_proto::enterprise::firewall::FirewallPolicy as ProtoFirewallPolicy; let config = Configuration::new( &build_network(), - vec![Peer { + vec![WireguardPeer { pubkey: "peer-public-key".to_owned(), allowed_ips: vec!["10.10.0.2/32".to_owned()], preshared_key: Some("peer-preshared-key".to_owned()), keepalive_interval: Some(25), }], - Some(FirewallConfig { - default_policy: 0, + Some(NativeFirewallConfig { + default_policy: FirewallPolicy::Unspecified, rules: Vec::new(), snat_bindings: Vec::new(), }), @@ -1269,7 +1277,10 @@ mod tests { let firewall_config = config .firewall_config .expect("generated config should include firewall config"); - assert_eq!(firewall_config.default_policy, 0); + assert_eq!( + firewall_config.default_policy, + ProtoFirewallPolicy::Unspecified as i32 + ); assert!(firewall_config.rules.is_empty()); assert!(firewall_config.snat_bindings.is_empty()); } @@ -1451,7 +1462,7 @@ mod tests { .save(&pool) .await .unwrap(); - let (events_tx, _events_rx) = broadcast::channel::(1); + let (events_tx, _events_rx) = broadcast::channel::(1); let (peer_stats_tx, _peer_stats_rx) = unbounded_channel(); let (_certs_tx, certs_rx) = watch::channel(Arc::new(HashMap::::new())); let handler = diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 8cdac400e..f5691da93 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -11,9 +11,9 @@ use std::{ use defguard_common::{ db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::gateway_client::GatewayClient; use defguard_version::client::ClientVersionInterceptor; use sqlx::{PgPool, postgres::PgListener}; @@ -656,14 +656,14 @@ mod unit_tests { /// events, notifications, and side effects to Core components. #[derive(Clone)] pub struct GatewayTxSet { - events: Sender, + events: Sender, peer_stats: UnboundedSender, } impl GatewayTxSet { #[must_use] pub const fn new( - events: Sender, + events: Sender, peer_stats: UnboundedSender, ) -> Self { Self { events, peer_stats } diff --git a/crates/defguard_gateway_manager/src/tests/common/mod.rs b/crates/defguard_gateway_manager/src/tests/common/mod.rs index eee7e9e47..e59f72c57 100644 --- a/crates/defguard_gateway_manager/src/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/src/tests/common/mod.rs @@ -19,9 +19,9 @@ use defguard_common::{ }, setup_pool, }, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::{CoreRequest, CoreResponse, PeerStats, core_request, gateway_server}; use prost_types::Timestamp; use sqlx::{PgPool, postgres::PgConnectOptions}; @@ -476,7 +476,7 @@ pub(crate) struct HandlerTestContext { pub(crate) network: WireguardNetwork, pub(crate) gateway: Gateway, pub(crate) peer_stats_rx: UnboundedReceiver, - events_tx: Option>, + events_tx: Option>, pub(crate) mock_gateway: Option, handler_task: Option>>, } @@ -489,7 +489,7 @@ impl HandlerTestContext { pub(crate) async fn new_with_events_tx( options: PgConnectOptions, - events_tx: broadcast::Sender, + events_tx: broadcast::Sender, ) -> Self { let pool = setup_pool(options).await; initialize_current_settings(&pool) @@ -524,7 +524,7 @@ impl HandlerTestContext { } } - pub(crate) fn events_tx(&self) -> &broadcast::Sender { + pub(crate) fn events_tx(&self) -> &broadcast::Sender { self.events_tx .as_ref() .expect("events sender already taken from context") diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs index a46a56e46..97986d1e3 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -2,7 +2,7 @@ mod support; use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_proto::gateway::{UpdateType, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs index a15b8cbd1..6008540cf 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs @@ -18,7 +18,7 @@ async fn test_device_created_for_network_produces_peer_create_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::DeviceCreated(device_info)), + .send(GatewayCommand::DeviceCreated(device_info)), "failed to broadcast created device event" ); @@ -46,7 +46,7 @@ async fn test_device_created_before_config_handshake_is_ignored( "created-before-config-device", "tND8hJQhYnI8naBTo59He43zYldagfjlwmSxWEc01Cc=", "10.10.0.11", - GatewayEvent::DeviceCreated, + GatewayCommand::DeviceCreated, ) .await; } @@ -69,10 +69,11 @@ async fn test_device_modified_for_network_produces_peer_modify_update( ) .await; - let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) - .await - .expect("failed to load device network info") - .expect("expected device network info for modified device"); + let mut network_device = + WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) + .await + .expect("failed to load device network info") + .expect("expected device network info for modified device"); network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; network_device .update(&context.pool) @@ -86,7 +87,7 @@ async fn test_device_modified_for_network_produces_peer_modify_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::DeviceModified(device_info)), + .send(GatewayCommand::DeviceModified(device_info)), "failed to broadcast modified device event" ); @@ -114,7 +115,7 @@ async fn test_device_modified_before_config_handshake_is_ignored( "modified-before-config-device", "wyFOHCec/Fi9s+cARikVO71JhyYtYMk0FrQx3fK2PTM=", "10.10.0.22", - GatewayEvent::DeviceModified, + GatewayCommand::DeviceModified, ) .await; } @@ -138,7 +139,7 @@ async fn test_device_deleted_for_network_produces_peer_delete_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::DeviceDeleted(device_info)), + .send(GatewayCommand::DeviceDeleted(device_info)), "failed to broadcast deleted device event" ); @@ -166,7 +167,7 @@ async fn test_device_deleted_before_config_handshake_is_ignored( "deleted-before-config-device", "m84QJmDMkqdCj8AB2NTE8F55W7M/i3CaaD3eQbQdInY=", "10.10.0.31", - GatewayEvent::DeviceDeleted, + GatewayCommand::DeviceDeleted, ) .await; } @@ -181,7 +182,7 @@ async fn test_device_created_for_different_network_is_ignored( "created-other-network-device", "W6wBmd8wgTwvCyGqDRXk6Hf4OMqDUbUn2XWKnG5wVVQ=", "10.11.0.10", - GatewayEvent::DeviceCreated, + GatewayCommand::DeviceCreated, ) .await; } @@ -196,7 +197,7 @@ async fn test_device_modified_for_different_network_is_ignored( "modified-other-network-device", "yjuzq0cLk3Ww5oQcqK6YkSKwXnqQ1V9OlSMFAEkr0lU=", "10.11.0.20", - GatewayEvent::DeviceModified, + GatewayCommand::DeviceModified, ) .await; } @@ -211,7 +212,7 @@ async fn test_device_deleted_for_different_network_is_ignored( "deleted-other-network-device", "Jtp+K8xnFXuF4cae+tVGZNwoSM2fXjJbRl3sI6rdcAQ=", "10.11.0.30", - GatewayEvent::DeviceDeleted, + GatewayCommand::DeviceDeleted, ) .await; } diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs index 1545aab35..5086f9dda 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs @@ -11,7 +11,7 @@ async fn test_matching_location_firewall_config_changed_event_produces_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::FirewallConfigChanged( + .send(GatewayCommand::FirewallConfigChanged( context.network.id, expected_firewall_config.clone(), )), @@ -37,7 +37,7 @@ async fn test_matching_location_firewall_disabled_event_produces_disable_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::FirewallDisabled(context.network.id)), + .send(GatewayCommand::FirewallDisabled(context.network.id)), "failed to broadcast firewall disabled event" ); @@ -56,7 +56,7 @@ async fn test_different_location_firewall_config_changed_event_is_ignored( let expected_firewall_config = build_test_firewall_config(); assert_firewall_event_for_different_network_is_ignored(options, move |other_network_id| { - GatewayEvent::FirewallConfigChanged(other_network_id, expected_firewall_config) + GatewayCommand::FirewallConfigChanged(other_network_id, expected_firewall_config) }) .await; } @@ -67,7 +67,7 @@ async fn test_different_location_firewall_disabled_event_is_ignored( options: PgConnectOptions, ) { assert_firewall_event_for_different_network_is_ignored(options, |other_network_id| { - GatewayEvent::FirewallDisabled(other_network_id) + GatewayCommand::FirewallDisabled(other_network_id) }) .await; } diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs index c21e4b594..9b17a67c6 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs @@ -18,7 +18,7 @@ async fn test_matching_location_posture_vpn_session_authorized_produces_peer_cre .await; assert_send_ok!( - context.events_tx().send(GatewayEvent::VpnSessionAuthorized( + context.events_tx().send(GatewayCommand::VpnSessionAuthorized( context.network.id, device, network_device, @@ -60,7 +60,7 @@ async fn test_matching_location_posture_vpn_session_authorized_without_psk_is_sk network_device.preshared_key = None; assert_send_ok!( - context.events_tx().send(GatewayEvent::VpnSessionAuthorized( + context.events_tx().send(GatewayCommand::VpnSessionAuthorized( context.network.id, device, network_device, @@ -93,7 +93,7 @@ async fn test_matching_location_vpn_session_authorized_produces_peer_create( .await; assert_send_ok!( - context.events_tx().send(GatewayEvent::VpnSessionAuthorized( + context.events_tx().send(GatewayCommand::VpnSessionAuthorized( context.network.id, device, network_device, @@ -139,7 +139,7 @@ async fn test_vpn_session_authorized_with_mismatched_network_id_is_ignored( .await; assert_send_ok!( - context.events_tx().send(GatewayEvent::VpnSessionAuthorized( + context.events_tx().send(GatewayCommand::VpnSessionAuthorized( context.network.id, device, network_device, @@ -173,7 +173,7 @@ async fn test_matching_location_vpn_session_deauthorized_produces_peer_delete( assert_send_ok!( context .events_tx() - .send(GatewayEvent::VpnSessionDeauthorized( + .send(GatewayCommand::VpnSessionDeauthorized( context.network.id, device, )), diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs index e7aeddcdb..176319151 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs @@ -8,7 +8,7 @@ async fn test_matching_location_network_deleted_event_produces_delete_update( let _ = context.complete_config_handshake().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( + context.events_tx().send(GatewayCommand::NetworkDeleted( context.network.id, context.network.name.clone(), )), @@ -45,7 +45,7 @@ async fn test_matching_location_network_modified_event_produces_modify_update( modified_network.fwmark = 42; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkModified( + context.events_tx().send(GatewayCommand::NetworkModified( context.network.id, modified_network, Vec::new(), @@ -92,7 +92,7 @@ async fn test_matching_location_network_created_event_produces_create_update( created_network.fwmark = 17; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkCreated( + context.events_tx().send(GatewayCommand::NetworkCreated( context.network.id, created_network, )), @@ -145,7 +145,7 @@ async fn test_only_matching_handler_receives_network_modified_update( assert_send_ok!( matching_context .events_tx() - .send(GatewayEvent::NetworkModified( + .send(GatewayCommand::NetworkModified( matching_context.network.id, modified_network, Vec::new(), @@ -189,7 +189,7 @@ async fn test_different_location_network_created_event_is_ignored( let _ = context.complete_config_handshake().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkCreated( + context.events_tx().send(GatewayCommand::NetworkCreated( other_network.id, other_network, )), @@ -212,7 +212,7 @@ async fn test_different_location_network_deleted_event_is_ignored( let _ = context.complete_config_handshake().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( + context.events_tx().send(GatewayCommand::NetworkDeleted( other_network.id, other_network.name.clone(), )), @@ -222,7 +222,7 @@ async fn test_different_location_network_deleted_event_is_ignored( context.mock_gateway_mut().expect_no_outbound().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( + context.events_tx().send(GatewayCommand::NetworkDeleted( context.network.id, context.network.name.clone(), )), diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index 0f2b60180..cc28d7f36 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -1,25 +1,26 @@ use std::net::IpAddr; -use defguard_common::db::{ - Id, NoId, - models::{ - device::{Device, DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice}, - user::User, - vpn_client_session::VpnClientSession, - wireguard::{LocationMfaMode, WireguardNetwork}, +use defguard_common::gateway_event::GatewayCommand; +use defguard_common::{ + db::{ + Id, NoId, + models::{ + device::{Device, DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice}, + user::User, + vpn_client_session::VpnClientSession, + wireguard::{LocationMfaMode, WireguardNetwork}, + }, }, -}; -use defguard_core::{ - enterprise::db::models::device_posture::{ - DevicePosture, DevicePostureLocation, DevicePostureOsRule, OsType, + gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, + Protocol as GwProtocol, SnatBinding, }, - grpc::GatewayEvent, +}; +use defguard_core::enterprise::db::models::device_posture::{ + DevicePosture, DevicePostureLocation, DevicePostureOsRule, OsType, }; use defguard_proto::{ - enterprise::firewall::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, - SnatBinding, ip_address::Address, port::Port as PortInner, - }, + enterprise::firewall::FirewallConfig as ProtoFirewallConfig, gateway::{ CoreResponse, Update, UpdateType, core_response, update::{self}, @@ -263,7 +264,7 @@ pub(crate) async fn assert_device_event_is_ignored_before_config_handshake( device_name: &str, device_pubkey: &str, device_ip: &str, - build_event: fn(DeviceInfo) -> GatewayEvent, + build_event: fn(DeviceInfo) -> GatewayCommand, ) { let mut context = HandlerTestContext::new(options).await; assert_eq!(context.events_tx().receiver_count(), 0); @@ -288,7 +289,7 @@ pub(crate) async fn assert_device_event_for_different_network_is_ignored( device_name: &str, device_pubkey: &str, device_ip: &str, - build_event: fn(DeviceInfo) -> GatewayEvent, + build_event: fn(DeviceInfo) -> GatewayCommand, ) { let mut context = HandlerTestContext::new(options).await; let other_network = context.create_other_network().await; @@ -316,7 +317,7 @@ pub(crate) async fn assert_device_event_for_different_network_is_ignored( pub(crate) async fn assert_firewall_event_for_different_network_is_ignored( options: PgConnectOptions, - build_event: impl FnOnce(Id) -> GatewayEvent, + build_event: impl FnOnce(Id) -> GatewayCommand, ) { let mut context = HandlerTestContext::new(options).await; let other_network = context.create_other_network().await; @@ -439,28 +440,20 @@ pub(crate) fn assert_network_modify_update( pub(crate) fn build_test_firewall_config() -> FirewallConfig { FirewallConfig { - default_policy: i32::from(FirewallPolicy::Allow), + default_policy: FirewallPolicy::Allow, rules: vec![FirewallRule { id: 101, - source_addrs: vec![IpAddress { - address: Some(Address::IpSubnet("10.10.0.0/24".to_owned())), - }], - destination_addrs: vec![IpAddress { - address: Some(Address::Ip("198.51.100.20".to_owned())), - }], - destination_ports: vec![Port { - port: Some(PortInner::SinglePort(443)), - }], - protocols: vec![i32::from(Protocol::Tcp)], - verdict: i32::from(FirewallPolicy::Deny), + source_addrs: vec![IpAddress::IpSubnet("10.10.0.0/24".to_owned())], + destination_addrs: vec![IpAddress::Ip("198.51.100.20".to_owned())], + destination_ports: vec![Port::Single(443)], + protocols: vec![GwProtocol::Tcp], + verdict: FirewallPolicy::Deny, comment: Some("block test https destination".to_owned()), - ip_version: i32::from(IpVersion::Ipv4), + ip_version: IpVersion::Ipv4, }], snat_bindings: vec![SnatBinding { id: 202, - source_addrs: vec![IpAddress { - address: Some(Address::IpSubnet("10.10.0.0/24".to_owned())), - }], + source_addrs: vec![IpAddress::IpSubnet("10.10.0.0/24".to_owned())], public_ip: "203.0.113.44".to_owned(), comment: Some("test snat binding".to_owned()), }], @@ -471,6 +464,7 @@ pub(crate) fn assert_firewall_modify_update( outbound: CoreResponse, expected_firewall_config: &FirewallConfig, ) { + let expected_proto: ProtoFirewallConfig = expected_firewall_config.clone().into(); match outbound.payload { Some(core_response::Payload::Update(Update { update_type, @@ -479,22 +473,19 @@ pub(crate) fn assert_firewall_modify_update( assert_eq!(update_type, UpdateType::Modify as i32); assert_eq!( firewall_config.default_policy, - expected_firewall_config.default_policy - ); - assert_eq!( - firewall_config.rules.len(), - expected_firewall_config.rules.len() + expected_proto.default_policy ); + assert_eq!(firewall_config.rules.len(), expected_proto.rules.len()); assert_eq!( firewall_config.snat_bindings.len(), - expected_firewall_config.snat_bindings.len() + expected_proto.snat_bindings.len() ); let firewall_rule = firewall_config .rules .first() .expect("expected firewall rule in update payload"); - let expected_firewall_rule = expected_firewall_config + let expected_firewall_rule = expected_proto .rules .first() .expect("expected firewall rule in test config"); @@ -520,7 +511,7 @@ pub(crate) fn assert_firewall_modify_update( .snat_bindings .first() .expect("expected SNAT binding in update payload"); - let expected_snat_binding = expected_firewall_config + let expected_snat_binding = expected_proto .snat_bindings .first() .expect("expected SNAT binding in test config"); diff --git a/crates/defguard_proto/src/gateway_conversions.rs b/crates/defguard_proto/src/gateway_conversions.rs new file mode 100644 index 000000000..a2e92215c --- /dev/null +++ b/crates/defguard_proto/src/gateway_conversions.rs @@ -0,0 +1,132 @@ +//! Conversions from [`defguard_common::gateway_types`] domain types to their +//! protobuf counterparts. These are the sole serialization boundary between +//! the native gateway command representation and the wire protocol. + +use defguard_common::gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, PortRange, + Protocol, SnatBinding, WireguardPeer, +}; + +use crate::{ + enterprise::firewall::{ + FirewallConfig as ProtoFirewallConfig, FirewallPolicy as ProtoFirewallPolicy, + FirewallRule as ProtoFirewallRule, IpAddress as ProtoIpAddress, IpRange as ProtoIpRange, + IpVersion as ProtoIpVersion, Port as ProtoPort, PortRange as ProtoPortRange, + Protocol as ProtoProtocol, SnatBinding as ProtoSnatBinding, ip_address::Address, + port::Port as PortInner, + }, + gateway::Peer, +}; + +impl From for Peer { + fn from(p: WireguardPeer) -> Self { + Self { + pubkey: p.pubkey, + allowed_ips: p.allowed_ips, + preshared_key: p.preshared_key, + keepalive_interval: p.keepalive_interval, + } + } +} + +fn policy_to_i32(policy: FirewallPolicy) -> i32 { + match policy { + FirewallPolicy::Unspecified => ProtoFirewallPolicy::Unspecified as i32, + FirewallPolicy::Allow => ProtoFirewallPolicy::Allow as i32, + FirewallPolicy::Deny => ProtoFirewallPolicy::Deny as i32, + } +} + +fn ip_version_to_i32(v: IpVersion) -> i32 { + match v { + IpVersion::Unspecified => ProtoIpVersion::Unspecified as i32, + IpVersion::Ipv4 => ProtoIpVersion::Ipv4 as i32, + IpVersion::Ipv6 => ProtoIpVersion::Ipv6 as i32, + } +} + +fn protocol_to_i32(p: Protocol) -> i32 { + match p { + Protocol::Unspecified => ProtoProtocol::Unspecified as i32, + Protocol::Icmp => ProtoProtocol::Icmp as i32, + Protocol::Tcp => ProtoProtocol::Tcp as i32, + Protocol::Udp => ProtoProtocol::Udp as i32, + } +} + +impl From for ProtoIpRange { + fn from(r: IpRange) -> Self { + Self { + start: r.start, + end: r.end, + } + } +} + +impl From for ProtoIpAddress { + fn from(addr: IpAddress) -> Self { + Self { + address: Some(match addr { + IpAddress::Ip(ip) => Address::Ip(ip), + IpAddress::IpRange(r) => Address::IpRange(r.into()), + IpAddress::IpSubnet(s) => Address::IpSubnet(s), + }), + } + } +} + +impl From for ProtoPortRange { + fn from(r: PortRange) -> Self { + Self { + start: r.start, + end: r.end, + } + } +} + +impl From for ProtoPort { + fn from(p: Port) -> Self { + Self { + port: Some(match p { + Port::Single(n) => PortInner::SinglePort(n), + Port::Range(r) => PortInner::PortRange(r.into()), + }), + } + } +} + +impl From for ProtoFirewallRule { + fn from(r: FirewallRule) -> Self { + Self { + id: r.id, + source_addrs: r.source_addrs.into_iter().map(Into::into).collect(), + destination_addrs: r.destination_addrs.into_iter().map(Into::into).collect(), + destination_ports: r.destination_ports.into_iter().map(Into::into).collect(), + protocols: r.protocols.into_iter().map(protocol_to_i32).collect(), + verdict: policy_to_i32(r.verdict), + comment: r.comment, + ip_version: ip_version_to_i32(r.ip_version), + } + } +} + +impl From for ProtoSnatBinding { + fn from(b: SnatBinding) -> Self { + Self { + id: b.id, + source_addrs: b.source_addrs.into_iter().map(Into::into).collect(), + public_ip: b.public_ip, + comment: b.comment, + } + } +} + +impl From for ProtoFirewallConfig { + fn from(c: FirewallConfig) -> Self { + Self { + default_policy: policy_to_i32(c.default_policy), + rules: c.rules.into_iter().map(Into::into).collect(), + snat_bindings: c.snat_bindings.into_iter().map(Into::into).collect(), + } + } +} diff --git a/crates/defguard_proto/src/lib.rs b/crates/defguard_proto/src/lib.rs index db7d37dde..8e3221214 100644 --- a/crates/defguard_proto/src/lib.rs +++ b/crates/defguard_proto/src/lib.rs @@ -1,3 +1,5 @@ +pub mod gateway_conversions; + use std::fmt; mod generated { @@ -85,15 +87,13 @@ use defguard_common::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, + gateway_types::{FirewallConfig, WireguardPeer}, }; use proxy::CoreError; use serde::Serialize; use tonic::Status; -use crate::{ - enterprise::firewall::FirewallConfig, - gateway::{Configuration, Peer}, -}; +use crate::gateway::Configuration; // Client MFA methods impl fmt::Display for MfaMethod { @@ -224,7 +224,7 @@ impl From for client_types::ServiceLocationMode { impl Configuration { pub fn new( location: &WireguardNetwork, - peers: Vec, + peers: Vec, maybe_firewall_config: Option, ) -> Self { Self { @@ -232,8 +232,8 @@ impl Configuration { port: location.port.cast_unsigned(), private_key: location.prvkey.clone(), addresses: location.address().iter().map(ToString::to_string).collect(), - peers, - firewall_config: maybe_firewall_config, + peers: peers.into_iter().map(Into::into).collect(), + firewall_config: maybe_firewall_config.map(Into::into), mtu: location.mtu.cast_unsigned(), fwmark: location.fwmark as u32, } diff --git a/crates/defguard_proxy_manager/src/handler.rs b/crates/defguard_proxy_manager/src/handler.rs index 870c2bb5f..5e73bb11c 100644 --- a/crates/defguard_proxy_manager/src/handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -31,7 +31,7 @@ use defguard_core::{ ldap::utils::ldap_update_user_state, }, grpc::{ - GatewayEvent, + GatewayCommand, proxy::client_mfa::{ ClientLoginSession, ClientMfaServer, ClientMfaStartOutcome, PostureCheckOutcome, }, @@ -473,7 +473,7 @@ impl ProxyHandler { async fn message_loop( &mut self, tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, resp_stream: &mut Streaming, ) -> Result<(), ProxyError> { let pool = self.pool.clone(); @@ -872,7 +872,7 @@ impl ProxyHandler { if let Err(err) = sync_user_groups_if_configured( &user, &pool, - &wireguard_tx, + &gateway_tx, ) .await { diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index cb6c546be..baded624b 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -9,11 +9,11 @@ use std::{path::PathBuf, str::FromStr, sync::Mutex as StdMutex}; use axum_extra::extract::cookie::Key; use defguard_common::{ db::{Id, models::proxy::Proxy}, + gateway_event::GatewayCommand, types::proxy::ProxyControlMessage, }; use defguard_core::{ - events::BidiStreamEvent, - grpc::{GatewayEvent, proxy::client_mfa::ClientLoginSession}, + events::BidiStreamEvent, grpc::proxy::client_mfa::ClientLoginSession, version::IncompatibleComponents, }; use defguard_proto::proxy::{CoreResponse, HttpsCerts, core_response}; @@ -401,14 +401,14 @@ impl ProxyManager { /// events, notifications, and side effects to Core components. #[derive(Clone)] pub struct ProxyTxSet { - wireguard: Sender, + wireguard: Sender, bidi_events: UnboundedSender, } impl ProxyTxSet { #[must_use] pub const fn new( - wireguard: Sender, + wireguard: Sender, bidi_events: UnboundedSender, ) -> Self { Self { diff --git a/crates/defguard_proxy_manager/src/servers/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs index 792ed90d5..c23760acd 100644 --- a/crates/defguard_proxy_manager/src/servers/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -21,7 +21,7 @@ use defguard_core::{ }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, grpc::{ - GatewayEvent, InstanceInfo, + GatewayCommand, InstanceInfo, client_version::ClientFeature, utils::{build_device_config_response, parse_client_ip_agent}, }, @@ -48,7 +48,7 @@ use tonic::Status; pub(crate) struct EnrollmentServer { pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, bidi_event_tx: UnboundedSender, } @@ -56,12 +56,12 @@ impl EnrollmentServer { #[must_use] pub(crate) fn new( pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, bidi_event_tx: UnboundedSender, ) -> Self { Self { pool, - wireguard_tx, + gateway_tx, bidi_event_tx, } } @@ -96,10 +96,10 @@ impl EnrollmentServer { } } - /// Sends given `GatewayEvent` to be handled by gateway GRPC server - pub(crate) fn send_wireguard_event(&self, event: GatewayEvent) { - if let Err(err) = self.wireguard_tx.send(event) { - error!("Error sending WireGuard event {err}"); + /// Sends given `GatewayCommand` to be handled by gateway manager service + pub(crate) fn send_gateway_command(&self, event: GatewayCommand) { + if let Err(err) = self.gateway_tx.send(event) { + error!("Error sending Gateway command: {err}"); } } @@ -785,7 +785,7 @@ impl EnrollmentServer { adding new device {}, user {}({})", device.wireguard_pubkey, user.username, user.id ); - self.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + self.send_gateway_command(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -797,7 +797,7 @@ impl EnrollmentServer { "Sending DeviceCreated event to gateway for device {}, user {}({:?})", device.wireguard_pubkey, user.username, user.id, ); - self.send_wireguard_event(GatewayEvent::DeviceCreated(DeviceInfo { + self.send_gateway_command(GatewayCommand::DeviceCreated(DeviceInfo { device: device.clone(), network_info, })); @@ -1201,9 +1201,9 @@ mod test { settings.enrollment_send_welcome_email = false; update_current_settings(&pool, settings).await.unwrap(); - let (wireguard_tx, _) = broadcast::channel(1); + let (gateway_tx, _) = broadcast::channel(1); let (bidi_event_tx, _) = unbounded_channel(); - let server = EnrollmentServer::new(pool.clone(), wireguard_tx, bidi_event_tx); + let server = EnrollmentServer::new(pool.clone(), gateway_tx, bidi_event_tx); let mut transaction = pool.begin().await.unwrap(); let result = server diff --git a/crates/defguard_proxy_manager/src/tests/common/mod.rs b/crates/defguard_proxy_manager/src/tests/common/mod.rs index a2e19d770..a3f4601f8 100644 --- a/crates/defguard_proxy_manager/src/tests/common/mod.rs +++ b/crates/defguard_proxy_manager/src/tests/common/mod.rs @@ -16,15 +16,18 @@ use axum::{ response::Json, }; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; -use defguard_common::db::{ - Id, NoId, - models::{ - proxy::Proxy, - settings::{Settings, initialize_current_settings}, +use defguard_common::{ + db::{ + Id, NoId, + models::{ + proxy::Proxy, + settings::{Settings, initialize_current_settings}, + }, + setup_pool, }, - setup_pool, + gateway_event::GatewayCommand, }; -use defguard_core::{events::BidiStreamEvent, grpc::GatewayEvent}; +use defguard_core::events::BidiStreamEvent; use defguard_proto::proxy::{ AcmeChallenge, AcmeIssueEvent, CoreRequest, CoreResponse, InitialInfo, core_response, proxy_server, @@ -381,7 +384,7 @@ impl Drop for MockProxyHarness { pub(crate) struct HandlerTestContext { pub(crate) pool: PgPool, pub(crate) proxy: Proxy, - pub(crate) wireguard_tx: broadcast::Sender, + pub(crate) gateway_tx: broadcast::Sender, pub(crate) bidi_events_rx: UnboundedReceiver, pub(crate) mock_proxy: Option, handler_task: Option>>, @@ -406,9 +409,9 @@ impl HandlerTestContext { let proxy = create_proxy(&pool).await; - let (wireguard_tx, _) = broadcast::channel(16); + let (gateway_tx, _) = broadcast::channel(16); let (bidi_events_tx, bidi_events_rx) = mpsc::unbounded_channel::(); - let tx_set = ProxyTxSet::new(wireguard_tx.clone(), bidi_events_tx); + let tx_set = ProxyTxSet::new(gateway_tx.clone(), bidi_events_tx); let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); let incompatible_components = Arc::new(std::sync::RwLock::new( @@ -450,7 +453,7 @@ impl HandlerTestContext { Self { pool, proxy, - wireguard_tx, + gateway_tx, bidi_events_rx, mock_proxy: Some(mock_proxy), handler_task: Some(handler_task), @@ -586,9 +589,9 @@ impl ManagerTestContext { pub(crate) async fn start(&mut self) { assert!(self.manager_task.is_none(), "proxy manager already started"); - let (wireguard_tx, _) = broadcast::channel(16); + let (gateway_tx, _) = broadcast::channel(16); let (bidi_events_tx, _bidi_events_rx) = mpsc::unbounded_channel::(); - let tx_set = ProxyTxSet::new(wireguard_tx, bidi_events_tx); + let tx_set = ProxyTxSet::new(gateway_tx, bidi_events_tx); let incompatible_components = Arc::new(std::sync::RwLock::new( defguard_core::version::IncompatibleComponents::default(), diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs index 36af0a652..80477829e 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs @@ -1,11 +1,11 @@ -use defguard_common::db::models::{ - Device, Settings, User, biometric_auth::BiometricAuth, polling_token::PollingToken, - settings::update_current_settings, -}; -use defguard_core::{ - events::{BidiStreamEventType, EnrollmentEvent}, - grpc::GatewayEvent, +use defguard_common::{ + db::models::{ + Device, Settings, User, biometric_auth::BiometricAuth, polling_token::PollingToken, + settings::update_current_settings, + }, + gateway_event::GatewayCommand, }; +use defguard_core::events::{BidiStreamEventType, EnrollmentEvent}; use defguard_proto::{ client_types::{ExistingDevice, MfaMethod, NewDevice, RegisterMobileAuthRequest}, proxy::{CoreRequest, core_request, core_response}, @@ -261,8 +261,8 @@ async fn test_new_device_sends_gateway_device_created_event( // Start the enrollment session so Token::used_at is set. start_enrollment_session(&mut context, &token.id).await; - // Subscribe to gateway events BEFORE sending the request. - let mut gateway_rx = context.wireguard_tx.subscribe(); + // Subscribe to gateway commands BEFORE sending the request. + let mut gateway_rx = context.gateway_tx.subscribe(); let pubkey = "DhsoNUJPXGl2g5CdqrfE0d7r+AUSHyw5RlNgbXqHlKE="; context.mock_proxy().send_request(CoreRequest { @@ -281,12 +281,12 @@ async fn test_new_device_sends_gateway_device_created_event( // Check that a DeviceCreated event was broadcast. let event = timeout(TEST_TIMEOUT, gateway_rx.recv()) .await - .expect("timed out waiting for GatewayEvent::DeviceCreated") - .expect("gateway event channel closed"); + .expect("timed out waiting for GatewayCommand::DeviceCreated") + .expect("gateway command channel closed"); assert!( - matches!(event, GatewayEvent::DeviceCreated(_)), - "expected DeviceCreated gateway event, got: {event:?}" + matches!(event, GatewayCommand::DeviceCreated(_)), + "expected DeviceCreated gateway command, got: {event:?}" ); context.finish().await.expect_server_finished().await; diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs index 8f16a630e..ad1c0f186 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs @@ -1,5 +1,4 @@ -use defguard_common::db::Id; -use defguard_core::grpc::GatewayEvent; +use defguard_common::{db::Id, gateway_event::GatewayCommand}; use defguard_proto::{ client_types::{ClientMfaFinishRequest, ClientMfaStartRequest, MfaMethod}, proxy::{AwaitRemoteMfaFinishRequest, CoreRequest, core_request, core_response}, @@ -123,9 +122,9 @@ async fn test_mfa_finish_succeeds_with_totp_code(_: PgPoolOptions, options: PgCo ) .await; - // Subscribe before finish so the handler's wireguard_tx.send() has a receiver, + // Subscribe before finish so the handler's gateway_tx.send() has a receiver, // and keep the receiver alive so we can assert on the event. - let mut gateway_rx = context.wireguard_tx.subscribe(); + let mut gateway_rx = context.gateway_tx.subscribe(); let code = generate_totp_code(&user); let (_, psk) = send_mfa_finish(&mut context, &token, Some(&code)).await; @@ -138,14 +137,14 @@ async fn test_mfa_finish_succeeds_with_totp_code(_: PgPoolOptions, options: PgCo let session = assert_vpn_session_exists(&context.pool, network.id, device.id).await; assert!(session.preshared_key.is_some()); - // Verify GatewayEvent::VpnSessionAuthorized was broadcast. + // Verify GatewayCommand::VpnSessionAuthorized was broadcast. // Use the already-subscribed receiver - subscribing after send_mfa_finish would miss the event. let event = timeout(RECEIVE_TIMEOUT, gateway_rx.recv()) .await - .expect("timed out waiting for GatewayEvent::VpnSessionAuthorized") - .expect("gateway event channel closed"); + .expect("timed out waiting for GatewayCommand::VpnSessionAuthorized") + .expect("gateway command channel closed"); let gateway_loc_id = match event { - GatewayEvent::VpnSessionAuthorized(loc_id, _, _) => loc_id, + GatewayCommand::VpnSessionAuthorized(loc_id, _, _) => loc_id, other => panic!("expected VpnSessionAuthorized, got: {other:?}"), }; assert_eq!(gateway_loc_id, network.id); @@ -298,9 +297,9 @@ async fn test_mfa_finish_succeeds_and_creates_session(_: PgPoolOptions, options: .await; // Subscribe to the gateway broadcast BEFORE calling finish, so that the - // handler's wireguard_tx.send() has at least one active receiver (without + // handler's gateway_tx.send() has at least one active receiver (without // one the send would fail with SendError and return Internal). - let mut gateway_rx = context.wireguard_tx.subscribe(); + let mut gateway_rx = context.gateway_tx.subscribe(); // The start handler has already called generate_email_mfa_code internally // and the in-memory secret is still the same, so regenerating here gives @@ -316,13 +315,13 @@ async fn test_mfa_finish_succeeds_and_creates_session(_: PgPoolOptions, options: let session = assert_vpn_session_exists(&context.pool, network.id, device.id).await; assert!(session.preshared_key.is_some()); - // Verify GatewayEvent::VpnSessionAuthorized was broadcast + // Verify GatewayCommand::VpnSessionAuthorized was broadcast let event = timeout(RECEIVE_TIMEOUT, gateway_rx.recv()) .await - .expect("timed out waiting for GatewayEvent::VpnSessionAuthorized") - .expect("gateway event channel closed"); + .expect("timed out waiting for GatewayCommand::VpnSessionAuthorized") + .expect("gateway command channel closed"); let loc_id = match event { - GatewayEvent::VpnSessionAuthorized(loc_id, _, _) => loc_id, + GatewayCommand::VpnSessionAuthorized(loc_id, _, _) => loc_id, other => panic!("expected VpnSessionAuthorized, got: {other:?}"), }; assert_eq!(loc_id, network.id); @@ -358,8 +357,8 @@ async fn test_mfa_token_valid_before_finish_invalid_after( let valid = send_token_validation(&mut context, &token).await; assert!(valid, "token must be valid after start"); - // Subscribe before finish so the handler's wireguard_tx.send() has a receiver - let _gateway_rx = context.wireguard_tx.subscribe(); + // Subscribe before finish so the handler's gateway_tx.send() has a receiver + let _gateway_rx = context.gateway_tx.subscribe(); let code = user.generate_email_mfa_code().expect("generate email code"); send_mfa_finish(&mut context, &token, Some(&code)).await; @@ -468,8 +467,8 @@ async fn test_mfa_await_remote_receives_psk_after_finish( // we proceed with the finish call. task::yield_now().await; - // Subscribe before finish so the handler's wireguard_tx.send() has a receiver - let _gateway_rx = context.wireguard_tx.subscribe(); + // Subscribe before finish so the handler's gateway_tx.send() has a receiver + let _gateway_rx = context.gateway_tx.subscribe(); // Now finish the MFA login with the correct code. Use the no-recv variant // because two responses will arrive (ClientMfaFinish + AwaitRemoteMfaFinish) @@ -509,7 +508,7 @@ async fn test_mfa_await_remote_receives_psk_after_finish( /// When a second MFA cycle completes for the same device+location the handler /// must: /// - disconnect the first `VpnClientSession` (state → Disconnected), -/// - emit `GatewayEvent::VpnSessionDeauthorized` for the first session, and +/// - emit `GatewayCommand::VpnSessionDeauthorized` for the first session, and /// - create a new active `VpnClientSession`. #[sqlx::test] async fn test_mfa_finish_replaces_existing_session_disconnects_old( @@ -525,7 +524,7 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( // ---- First MFA cycle ---- // Must subscribe before finish so the send has a receiver. - let _gw_rx1 = context.wireguard_tx.subscribe(); + let _gw_rx1 = context.gateway_tx.subscribe(); let (_, token1) = send_mfa_start( &mut context, @@ -566,7 +565,7 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( // Subscribe before finish so both VpnSessionDeauthorized and // VpnSessionAuthorized have an active receiver. - let mut gw_rx2 = context.wireguard_tx.subscribe(); + let mut gw_rx2 = context.gateway_tx.subscribe(); let code2 = generate_totp_code(&user); let (_, psk2) = send_mfa_finish(&mut context, &token2, Some(&code2)).await; @@ -583,20 +582,20 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( for _ in 0..2 { let event = timeout(RECEIVE_TIMEOUT, gw_rx2.recv()) .await - .expect("timed out waiting for gateway event after second MFA finish") - .expect("gateway event channel closed"); + .expect("timed out waiting for gateway command after second MFA finish") + .expect("gateway command channel closed"); match event { - GatewayEvent::VpnSessionDeauthorized(loc_id, ref dev) => { + GatewayCommand::VpnSessionDeauthorized(loc_id, ref dev) => { assert_eq!(loc_id, network.id, "disconnected session location mismatch"); assert_eq!(dev.id, device.id, "disconnected session device mismatch"); got_disconnected = true; } - GatewayEvent::VpnSessionAuthorized(loc_id, _, _) => { + GatewayCommand::VpnSessionAuthorized(loc_id, _, _) => { assert_eq!(loc_id, network.id, "authorized session location mismatch"); got_authorized = true; } - other => panic!("unexpected gateway event: {other:?}"), + other => panic!("unexpected gateway command: {other:?}"), } } assert!(got_disconnected, "VpnSessionDeauthorized must be emitted"); diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs index 6b3e97b26..176d52e0f 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs @@ -282,7 +282,7 @@ async fn test_mfa_oidc_full_flow(_: PgPoolOptions, options: PgConnectOptions) { set_public_proxy_url(&context.pool, &mock.base_url).await; // Subscribe to gateway events before sending MFA finish. - let _gateway_rx = context.wireguard_tx.subscribe(); + let _gateway_rx = context.gateway_tx.subscribe(); // ---- Step 1: ClientMfaStart with Oidc method ---- let (_, mfa_token) = send_mfa_start( diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 4904a2603..44327f69b 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,5 +1,4 @@ -use defguard_common::db::Id; -use defguard_core::grpc::GatewayEvent; +use defguard_common::{db::Id, gateway_event::GatewayCommand}; use thiserror::Error; use tokio::sync::{broadcast::error::SendError as BroadcastSendError, mpsc::error::SendError}; @@ -36,7 +35,7 @@ pub enum SessionManagerError { #[error("Failed to send session manager event: {0}")] SessionManagerEventError(Box>), #[error("Failed to send gateway manager event: {0}")] - GatewayManagerEventError(Box>), + GatewayManagerEventError(Box>), } impl From> for SessionManagerError { @@ -44,8 +43,8 @@ impl From> for SessionManagerError { Self::SessionManagerEventError(Box::new(error)) } } -impl From> for SessionManagerError { - fn from(error: BroadcastSendError) -> Self { +impl From> for SessionManagerError { + fn from(error: BroadcastSendError) -> Self { Self::GatewayManagerEventError(Box::new(error)) } } diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 240dc9ae9..9aa180089 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -7,9 +7,9 @@ use defguard_common::{ vpn_client_session::{VpnClientSession, VpnClientSessionState}, }, }, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::GatewayEvent; use sqlx::{PgConnection, PgPool}; use tokio::{ sync::{ @@ -42,7 +42,7 @@ pub async fn run_session_manager( pool: PgPool, mut peer_stats_rx: UnboundedReceiver, session_manager_event_tx: UnboundedSender, - gateway_tx: Sender, + gateway_tx: Sender, ) -> Result<(), SessionManagerError> { info!("Starting VPN client session manager service"); let mut session_update_timer = interval(SESSION_UPDATE_INTERVAL); @@ -103,7 +103,7 @@ pub async fn run_session_manager_iteration( pub struct SessionManager { pool: PgPool, session_manager_event_tx: UnboundedSender, - gateway_tx: Sender, + gateway_tx: Sender, } impl SessionManager { @@ -111,7 +111,7 @@ impl SessionManager { pub fn new( pool: PgPool, session_manager_event_tx: UnboundedSender, - gateway_tx: Sender, + gateway_tx: Sender, ) -> Self { Self { pool, @@ -335,9 +335,9 @@ impl SessionManager { device: &Device, ) -> Result<(), SessionManagerError> { debug!( - "Sending VPN session deauthorization event for device {device} in location {location} to gateway manager" + "Sending MFA session disconnect event for device {device} in location {location} to gateway manager" ); - let event = GatewayEvent::VpnSessionDeauthorized(location.id, device.clone()); + let event = GatewayCommand::VpnSessionDeauthorized(location.id, device.clone()); self.gateway_tx.send(event)?; Ok(()) } diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index e675bdb3c..8b178945e 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -16,6 +16,7 @@ use defguard_common::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::enterprise::db::models::device_posture::{ @@ -40,7 +41,7 @@ pub(crate) struct SessionManagerHarness { stats_tx: mpsc::UnboundedSender, pub(crate) stats_rx: mpsc::UnboundedReceiver, pub(crate) event_rx: mpsc::UnboundedReceiver, - pub(crate) gateway_rx: broadcast::Receiver, + pub(crate) gateway_rx: broadcast::Receiver, } pub(crate) fn assert_no_session_manager_events(harness: &mut SessionManagerHarness) { diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index badea546f..72c4bf47e 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -9,7 +9,7 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; @@ -503,7 +503,7 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization .expect("timed out waiting for MFA disconnect gateway event") .expect("gateway event channel closed"); match gateway_event { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -570,7 +570,7 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( .expect("timed out waiting for MFA disconnect gateway event for new session") .expect("gateway event channel closed"); match gateway_event { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index 76282b135..bbbd1e9e4 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -8,7 +8,7 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; @@ -348,7 +348,7 @@ async fn test_never_connected_posture_new_session_disconnects_after_threshold( .expect("timed out waiting for posture disconnect gateway event for new session") .expect("gateway event channel closed"); match gateway_event { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -417,7 +417,7 @@ async fn test_inactive_posture_connected_session_disconnects_and_clears_authoriz .expect("timed out waiting for posture disconnect gateway event") .expect("gateway event channel closed"); match gateway_event { - GatewayEvent::VpnSessionDeauthorized(location_id, disconnected_device) => { + GatewayCommand::VpnSessionDeauthorized(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } diff --git a/crates/defguard_setup/src/migration.rs b/crates/defguard_setup/src/migration.rs index ee1980bf4..8c7c76b32 100644 --- a/crates/defguard_setup/src/migration.rs +++ b/crates/defguard_setup/src/migration.rs @@ -10,14 +10,15 @@ use axum::{ serve, }; use axum_extra::extract::cookie::Key; -use defguard_common::{VERSION, db::models::Settings, types::proxy::ProxyControlMessage}; +use defguard_common::{ + VERSION, db::models::Settings, gateway_event::GatewayCommand, types::proxy::ProxyControlMessage, +}; use defguard_core::{ appstate::AppState, auth::failed_login::FailedLoginMap, db::AppEvent, enterprise::handlers::openid_login::{auth_callback, get_auth_info}, events::ApiEvent, - grpc::GatewayEvent, handle_404, handlers::{ auth::{ @@ -60,7 +61,7 @@ use crate::handlers::{ pub struct MigrationWebapp { pub router: Router, _event_rx: mpsc::UnboundedReceiver, - _wireguard_rx: broadcast::Receiver, + _gateway_rx: broadcast::Receiver, _proxy_control_rx: mpsc::Receiver, } @@ -72,7 +73,7 @@ pub fn build_migration_webapp( let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); let (webhook_tx, webhook_rx) = mpsc::unbounded_channel::(); let (event_tx, event_rx) = mpsc::unbounded_channel::(); - let (wireguard_tx, wireguard_rx) = broadcast::channel::(64); + let (gateway_tx, gateway_rx) = broadcast::channel::(64); let (web_reload_tx, _web_reload_rx) = broadcast::channel::<()>(8); let (proxy_control_tx, proxy_control_rx) = mpsc::channel(32); let incompatible_components = Arc::new(RwLock::new(IncompatibleComponents::default())); @@ -86,7 +87,7 @@ pub fn build_migration_webapp( pool.clone(), webhook_tx, webhook_rx, - wireguard_tx.clone(), + gateway_tx.clone(), web_reload_tx, key, failed_logins.clone(), @@ -175,7 +176,7 @@ pub fn build_migration_webapp( MigrationWebapp { router, _event_rx: event_rx, - _wireguard_rx: wireguard_rx, + _gateway_rx: gateway_rx, _proxy_control_rx: proxy_control_rx, } }