From 91977358c3da83aed3b6fc23d2206efcd49934dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 18 May 2026 16:37:39 +0200 Subject: [PATCH 01/10] decouple gateway events from protobuf --- crates/defguard_common/src/gateway_types.rs | 99 ++ crates/defguard_common/src/lib.rs | 1 + .../src/enterprise/firewall/mod.rs | 69 +- .../enterprise/firewall/tests/destination.rs | 385 +++----- .../src/enterprise/firewall/tests/gh1868.rs | 34 +- .../firewall/tests/ip_address_handling.rs | 230 ++--- .../src/enterprise/firewall/tests/mod.rs | 920 ++++++------------ .../src/enterprise/firewall/tests/source.rs | 45 +- crates/defguard_core/src/grpc/mod.rs | 13 +- .../src/location_management/allowed_peers.rs | 12 +- .../defguard_gateway_manager/src/handler.rs | 33 +- .../tests/gateway_manager/handler/support.rs | 47 +- .../defguard_proto/src/gateway_conversions.rs | 128 +++ crates/defguard_proto/src/lib.rs | 14 +- 14 files changed, 855 insertions(+), 1175 deletions(-) create mode 100644 crates/defguard_common/src/gateway_types.rs create mode 100644 crates/defguard_proto/src/gateway_conversions.rs diff --git a/crates/defguard_common/src/gateway_types.rs b/crates/defguard_common/src/gateway_types.rs new file mode 100644 index 0000000000..4a3199bbe5 --- /dev/null +++ b/crates/defguard_common/src/gateway_types.rs @@ -0,0 +1,99 @@ +/// A WireGuard peer entry to be configured on a gateway. +#[derive(Clone, Debug, PartialEq)] +pub struct WireguardPeer { + pub pubkey: String, + pub allowed_ips: Vec, + pub preshared_key: Option, + pub keepalive_interval: Option, +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub enum FirewallPolicy { + #[default] + Unspecified, + Allow, + Deny, +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub enum IpVersion { + #[default] + Unspecified, + Ipv4, + Ipv6, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Protocol { + Unspecified, + Icmp, + Tcp, + Udp, +} + +impl From for Protocol { + fn from(v: i32) -> Self { + match v { + 1 => Self::Icmp, + 6 => Self::Tcp, + 17 => Self::Udp, + _ => Self::Unspecified, + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct IpRange { + pub start: String, + pub end: String, +} + +/// An IP address, range, or subnet. +#[derive(Clone, Debug, PartialEq)] +pub enum IpAddress { + /// Single IP address string (e.g. `"10.0.0.1"`). + Ip(String), + /// Inclusive IP range. + IpRange(IpRange), + /// IP subnet in CIDR notation (e.g. `"10.0.0.0/24"`). + IpSubnet(String), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PortRange { + pub start: u32, + pub end: u32, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Port { + Single(u32), + Range(PortRange), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct FirewallRule { + pub id: i64, + pub source_addrs: Vec, + pub destination_addrs: Vec, + pub destination_ports: Vec, + pub protocols: Vec, + pub verdict: FirewallPolicy, + pub comment: Option, + pub ip_version: IpVersion, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct SnatBinding { + pub id: i64, + pub source_addrs: Vec, + pub public_ip: String, + pub comment: Option, +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct FirewallConfig { + pub default_policy: FirewallPolicy, + pub rules: Vec, + pub snat_bindings: Vec, +} diff --git a/crates/defguard_common/src/lib.rs b/crates/defguard_common/src/lib.rs index 394f7e333a..bf1fb913ff 100644 --- a/crates/defguard_common/src/lib.rs +++ b/crates/defguard_common/src/lib.rs @@ -2,6 +2,7 @@ pub mod auth; pub mod config; pub mod csv; pub mod db; +pub mod gateway_types; pub mod globals; pub mod hex; pub mod messages; diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index bf4b793c83..af18056f79 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -3,14 +3,15 @@ use std::{ ops::RangeInclusive, }; -use defguard_common::db::{ - Id, - models::{Device, ModelError, WireguardNetwork, user::User}, -}; -use defguard_proto::enterprise::firewall::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, - PortRange as PortRangeProto, SnatBinding as SnatBindingProto, ip_address::Address, - port::Port as PortInner, +use defguard_common::{ + db::{ + Id, + models::{Device, ModelError, WireguardNetwork, user::User}, + }, + gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, + PortRange as GwPortRange, Protocol as GwProtocol, SnatBinding, + }, }; use ipnetwork::IpNetwork; use sqlx::{PgConnection, query_as, query_scalar}; @@ -424,7 +425,7 @@ fn create_rules( protocols: &[Protocol], comment: &str, ) -> (Option, FirewallRule) { - let ip_version = i32::from(ip_version); + let gw_protocols: Vec = protocols.iter().map(|&p| GwProtocol::from(p)).collect(); let allow = if source_addrs.is_empty() { debug!("Source address list is empty. Skipping generating the ALLOW rule for this ACL"); None @@ -435,10 +436,10 @@ fn create_rules( source_addrs: source_addrs.to_vec(), destination_addrs: destination_addrs.to_vec(), destination_ports: destination_ports.to_vec(), - protocols: protocols.to_vec(), - verdict: i32::from(FirewallPolicy::Allow), + protocols: gw_protocols.clone(), + verdict: FirewallPolicy::Allow, comment: Some(format!("{comment} ALLOW")), - ip_version, + ip_version: ip_version.clone(), }; debug!("ALLOW rule generated from ACL: {rule:?}"); Some(rule) @@ -452,7 +453,7 @@ fn create_rules( destination_addrs: destination_addrs.to_vec(), destination_ports: Vec::new(), protocols: Vec::new(), - verdict: i32::from(FirewallPolicy::Deny), + verdict: FirewallPolicy::Deny, comment: Some(format!("{comment} DENY")), ip_version, }; @@ -753,9 +754,7 @@ fn extract_all_subnets_from_range(range_start: IpAddr, range_end: IpAddr) -> Vec // Return early if range represents a single IP address. if range_start == range_end { - result.push(IpAddress { - address: Some(Address::Ip(range_start.to_string())), - }); + result.push(IpAddress::Ip(range_start.to_string())); return result; } @@ -770,9 +769,7 @@ fn extract_all_subnets_from_range(range_start: IpAddr, range_end: IpAddr) -> Vec // Check if the subnet covers the entire range if subnet_start == range_start && subnet_end == range_end { // Use subnet notation for the entire range - result.push(IpAddress { - address: Some(Address::IpSubnet(subnet.to_string())), - }); + result.push(IpAddress::IpSubnet(subnet.to_string())); } else { // Subnet is found within the range, append both subnet and remaining ranges. @@ -803,9 +800,7 @@ fn extract_all_subnets_from_range(range_start: IpAddr, range_end: IpAddr) -> Vec } // Add the subnet itself - result.push(IpAddress { - address: Some(Address::IpSubnet(subnet.to_string())), - }); + result.push(IpAddress::IpSubnet(subnet.to_string())); // Add range after subnet (if any) if subnet_end < range_end { @@ -834,12 +829,10 @@ fn extract_all_subnets_from_range(range_start: IpAddr, range_end: IpAddr) -> Vec } } else { // Fall back to range notation if no subnet is found. - result.push(IpAddress { - address: Some(Address::IpRange(IpRange { - start: range_start.to_string(), - end: range_end.to_string(), - })), - }); + result.push(IpAddress::IpRange(IpRange { + start: range_start.to_string(), + end: range_end.to_string(), + })); } result @@ -877,16 +870,12 @@ fn merge_port_ranges(port_ranges: Vec) -> Vec { let range_start = *range.start(); let range_end = *range.end(); if range_start == range_end { - Port { - port: Some(PortInner::SinglePort(u32::from(range_start))), - } + Port::Single(u32::from(range_start)) } else { - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: u32::from(range_start), - end: u32::from(range_end), - })), - } + Port::Range(GwPortRange { + start: u32::from(range_start), + end: u32::from(range_end), + }) } }) .collect() @@ -899,7 +888,7 @@ fn merge_port_ranges(port_ranges: Vec) -> Vec { async fn generate_user_snat_bindings_for_location( location_id: Id, conn: &mut PgConnection, -) -> sqlx::Result> { +) -> sqlx::Result> { debug!("Generating SNAT bindings for location {location_id}"); let user_snat_bindings = UserSnatBinding::all_for_location(&mut *conn, location_id).await?; @@ -951,7 +940,7 @@ async fn generate_user_snat_bindings_for_location( } // create the SNAT binding proto - let snat_binding = SnatBindingProto { + let snat_binding = SnatBinding { id: user_binding.id, source_addrs, public_ip: user_binding.public_ip.to_string(), @@ -1046,7 +1035,7 @@ pub async fn try_get_location_firewall_config( generate_firewall_rules_from_acls(location.id, location_acls, &mut *conn).await?; let snat_bindings = generate_user_snat_bindings_for_location(location.id, &mut *conn).await?; let firewall_config = FirewallConfig { - default_policy: default_policy.into(), + default_policy, rules: firewall_rules, snat_bindings, }; diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index ae4c50b633..478fdbf986 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -1,20 +1,21 @@ +use defguard_common::gateway_types::{ + FirewallPolicy, IpAddress, IpRange, Port, PortRange as GwPortRange, Protocol as GwProtocol, +}; +use defguard_proto::enterprise::firewall::Protocol as ProtoProtocol; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, ops::RangeInclusive, }; use defguard_common::db::{NoId, models::WireguardNetwork, setup_pool}; -use defguard_proto::enterprise::firewall::{ - FirewallPolicy, IpAddress, IpRange, Port, Protocol, ip_address::Address, - port::Port as PortInner, -}; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::{create_acl_rule, create_test_users_and_devices, set_test_license_business}; use crate::enterprise::{ db::models::acl::{ - AclAlias, AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AliasKind, RuleState, + AclAlias, AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AliasKind, PortRange, + RuleState, }, firewall::{process_destination_addrs, try_get_location_firewall_config}, }; @@ -50,21 +51,13 @@ fn test_process_destination_addrs_v4() { assert_eq!( destination_addrs.0, [ - IpAddress { - address: Some(Address::IpSubnet("10.0.1.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.3.255".to_owned(), - end: "10.0.4.0".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }, + IpAddress::IpSubnet("10.0.1.0/24".to_owned()), + IpAddress::IpSubnet("10.0.2.0/24".to_owned()), + IpAddress::IpRange(IpRange { + start: "10.0.3.255".to_owned(), + end: "10.0.4.0".to_owned(), + }), + IpAddress::IpSubnet("192.168.1.0/24".to_owned()), ] ); @@ -111,21 +104,11 @@ fn test_process_destination_addrs_v6() { assert_eq!( destination_addrs.1, [ - IpAddress { - address: Some(Address::IpSubnet("2001:db8:1::/64".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:2::/64".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:3::/64".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("2001:db8:4::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:4::2/127".to_owned())) - } + IpAddress::IpSubnet("2001:db8:1::/64".to_owned()), + IpAddress::IpSubnet("2001:db8:2::/64".to_owned()), + IpAddress::IpSubnet("2001:db8:3::/64".to_owned()), + IpAddress::Ip("2001:db8:4::1".to_owned()), + IpAddress::IpSubnet("2001:db8:4::2/127".to_owned()) ] ); @@ -195,29 +178,25 @@ async fn test_any_address_overwrites_manual_destination( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert!(allow_rule.destination_addrs.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert!(deny_rule.destination_addrs.is_empty()); } @@ -294,29 +273,25 @@ async fn test_any_address_overwrites_destination_alias_addrs( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert!(allow_rule.destination_addrs.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert!(deny_rule.destination_addrs.is_empty()); } @@ -390,35 +365,29 @@ async fn test_manual_destination_includes_component_alias_address_range( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; - let expected_destination_addrs = [IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.2.0.255".to_owned(), - end: "10.2.1.0".to_owned(), - })), - }]; + let expected_destination_addrs = [IpAddress::IpRange(IpRange { + start: "10.2.0.255".to_owned(), + end: "10.2.1.0".to_owned(), + })]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); } @@ -495,43 +464,35 @@ async fn test_manual_destination_merges_rule_and_component_alias_address_ranges( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.2.0.255".to_owned(), - end: "10.2.1.0".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.3.0.255".to_owned(), - end: "10.3.1.0".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.2.0.255".to_owned(), + end: "10.2.1.0".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.3.0.255".to_owned(), + end: "10.3.1.0".to_owned(), + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); } @@ -560,10 +521,10 @@ async fn test_any_port_preserves_destination_addresses_and_protocols( allow_all_users: true, addresses: vec!["192.168.50.0/24".parse().unwrap()], ports: vec![ - crate::enterprise::db::models::acl::PortRange::new(22, 22).into(), - crate::enterprise::db::models::acl::PortRange::new(443, 443).into(), + PortRange::new(22, 22).into(), + PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: true, any_protocol: false, @@ -597,51 +558,34 @@ async fn test_any_port_preserves_destination_addresses_and_protocols( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.50.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.60.10/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.60.12/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.60.16/30".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.60.20".to_owned())), - }, + IpAddress::IpSubnet("192.168.50.0/24".to_owned()), + IpAddress::IpSubnet("192.168.60.10/31".to_owned()), + IpAddress::IpSubnet("192.168.60.12/30".to_owned()), + IpAddress::IpSubnet("192.168.60.16/30".to_owned()), + IpAddress::Ip("192.168.60.20".to_owned()), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert!(allow_rule.destination_ports.is_empty()); - assert_eq!( - allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -678,7 +622,7 @@ async fn test_any_protocol_preserves_destination_addresses_and_ports( crate::enterprise::db::models::acl::PortRange::new(80, 80).into(), crate::enterprise::db::models::acl::PortRange::new(1000, 1005).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: false, any_protocol: true, @@ -709,52 +653,38 @@ async fn test_any_protocol_preserves_destination_addresses_and_ports( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.70.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.80.1".to_owned())), - }, + IpAddress::IpSubnet("192.168.70.0/24".to_owned()), + IpAddress::Ip("192.168.80.1".to_owned()), ]; let expected_ports = [ - Port { - port: Some(PortInner::SinglePort(80)), - }, - Port { - port: Some(PortInner::PortRange( - defguard_proto::enterprise::firewall::PortRange { - start: 1000, - end: 1005, - }, - )), - }, + Port::Single(80), + Port::Range(GwPortRange { + start: 1000, + end: 1005, + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert_eq!(allow_rule.destination_ports, expected_ports); assert!(allow_rule.protocols.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -787,7 +717,7 @@ async fn test_destination_alias_any_port_preserves_addresses_and_protocols( crate::enterprise::db::models::acl::PortRange::new(22, 22).into(), crate::enterprise::db::models::acl::PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: true, any_protocol: false, @@ -838,51 +768,34 @@ async fn test_destination_alias_any_port_preserves_addresses_and_protocols( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.90.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.91.10/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.91.12/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.91.16/30".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.91.20".to_owned())), - }, + IpAddress::IpSubnet("192.168.90.0/24".to_owned()), + IpAddress::IpSubnet("192.168.91.10/31".to_owned()), + IpAddress::IpSubnet("192.168.91.12/30".to_owned()), + IpAddress::IpSubnet("192.168.91.16/30".to_owned()), + IpAddress::Ip("192.168.91.20".to_owned()), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert!(allow_rule.destination_ports.is_empty()); - assert_eq!( - allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -918,7 +831,7 @@ async fn test_destination_alias_any_protocol_preserves_addresses_and_ports( crate::enterprise::db::models::acl::PortRange::new(80, 80).into(), crate::enterprise::db::models::acl::PortRange::new(1000, 1005).into(), ], - protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + protocols: vec![ProtoProtocol::Tcp as i32, ProtoProtocol::Udp as i32], any_address: false, any_port: false, any_protocol: true, @@ -959,52 +872,38 @@ async fn test_destination_alias_any_protocol_preserves_addresses_and_ports( .rules; let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::IpSubnet("192.168.110.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.120.1".to_owned())), - }, + IpAddress::IpSubnet("192.168.110.0/24".to_owned()), + IpAddress::Ip("192.168.120.1".to_owned()), ]; let expected_ports = [ - Port { - port: Some(PortInner::SinglePort(80)), - }, - Port { - port: Some(PortInner::PortRange( - defguard_proto::enterprise::firewall::PortRange { - start: 1000, - end: 1005, - }, - )), - }, + Port::Single(80), + Port::Range(GwPortRange { + start: 1000, + end: 1005, + }), ]; assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert_eq!(allow_rule.destination_ports, expected_ports); assert!(allow_rule.protocols.is_empty()); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); diff --git a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs index 47ae9b0db2..4774cd4da9 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs @@ -1,3 +1,4 @@ +use defguard_common::gateway_types::{FirewallPolicy, IpVersion}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chrono::NaiveDateTime; @@ -6,7 +7,6 @@ use defguard_common::db::{ models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, setup_pool, }; -use defguard_proto::enterprise::firewall::{FirewallPolicy, IpVersion}; use ipnetwork::IpNetwork; use rand::{Rng, rngs::ThreadRng, thread_rng}; use sqlx::{ @@ -124,12 +124,12 @@ async fn test_gh1868_ipv6_rule_is_not_created_with_v4_only_destination( assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict(), FirewallPolicy::Allow); - assert_eq!(allow_rule.ip_version(), IpVersion::Ipv4); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv4); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict(), FirewallPolicy::Deny); - assert_eq!(allow_rule.ip_version(), IpVersion::Ipv4); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv4); } #[sqlx::test] @@ -184,12 +184,12 @@ async fn test_gh1868_ipv4_rule_is_not_created_with_v6_only_destination( assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!(allow_rule.ip_version, i32::from(IpVersion::Ipv6)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv6); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); - assert_eq!(allow_rule.ip_version, i32::from(IpVersion::Ipv6)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule.ip_version, IpVersion::Ipv6); } #[sqlx::test] @@ -244,16 +244,16 @@ async fn test_gh1868_ipv4_and_ipv6_rules_are_created_with_any_destination( assert_eq!(generated_firewall_rules.len(), 4); let allow_rule_ipv4 = &generated_firewall_rules[0]; - assert_eq!(allow_rule_ipv4.verdict(), FirewallPolicy::Allow); - assert_eq!(allow_rule_ipv4.ip_version(), IpVersion::Ipv4); + assert_eq!(allow_rule_ipv4.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); let allow_rule_ipv6 = &generated_firewall_rules[1]; - assert_eq!(allow_rule_ipv6.verdict(), FirewallPolicy::Allow); - assert_eq!(allow_rule_ipv6.ip_version(), IpVersion::Ipv6); + assert_eq!(allow_rule_ipv6.verdict, FirewallPolicy::Allow); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); let deny_rule_ipv4 = &generated_firewall_rules[2]; - assert_eq!(deny_rule_ipv4.verdict(), FirewallPolicy::Deny); - assert_eq!(allow_rule_ipv4.ip_version(), IpVersion::Ipv4); + assert_eq!(deny_rule_ipv4.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); let deny_rule_ipv6 = &generated_firewall_rules[3]; - assert_eq!(deny_rule_ipv6.verdict(), FirewallPolicy::Deny); - assert_eq!(allow_rule_ipv6.ip_version(), IpVersion::Ipv6); + assert_eq!(deny_rule_ipv6.verdict, FirewallPolicy::Deny); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs b/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs index 252368d7a6..f251c3732c 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/ip_address_handling.rs @@ -1,9 +1,6 @@ +use defguard_common::gateway_types::{IpAddress, IpRange, Port, PortRange as GwPortRange}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use defguard_proto::enterprise::firewall::{ - IpAddress, IpRange, Port, PortRange as PortRangeProto, ip_address::Address, - port::Port as PortInner, -}; use ipnetwork::Ipv6Network; use crate::enterprise::{ @@ -30,30 +27,14 @@ fn test_merge_v4_addrs() { assert_eq!( merged_addrs, [ - IpAddress { - address: Some(Address::Ip("10.0.8.127".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.8.128/25".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.9.0/24".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.10.0/27".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("10.0.20.20".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.60.20/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.60.24/31".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.0.20".to_owned())), - }, + IpAddress::Ip("10.0.8.127".to_owned()), + IpAddress::IpSubnet("10.0.8.128/25".to_owned()), + IpAddress::IpSubnet("10.0.9.0/24".to_owned()), + IpAddress::IpSubnet("10.0.10.0/27".to_owned()), + IpAddress::Ip("10.0.20.20".to_owned()), + IpAddress::IpSubnet("10.0.60.20/30".to_owned()), + IpAddress::IpSubnet("10.0.60.24/31".to_owned()), + IpAddress::Ip("192.168.0.20".to_owned()), ] ); @@ -70,12 +51,8 @@ fn test_merge_v4_addrs() { assert_eq!( merged_addrs, [ - IpAddress { - address: Some(Address::IpSubnet("10.0.10.0/30".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("10.0.10.20".to_owned())), - }, + IpAddress::IpSubnet("10.0.10.0/30".to_owned()), + IpAddress::Ip("10.0.10.20".to_owned()), ] ); } @@ -97,27 +74,13 @@ fn test_merge_v6_addrs() { assert_eq!( merged_addrs, [ - IpAddress { - address: Some(Address::Ip("2001:db8:1::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:1::2/127".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:1::4/126".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:1::8".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:2::1".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:3::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8:3::2/127".to_owned())) - } + IpAddress::Ip("2001:db8:1::1".to_owned()), + IpAddress::IpSubnet("2001:db8:1::2/127".to_owned()), + IpAddress::IpSubnet("2001:db8:1::4/126".to_owned()), + IpAddress::Ip("2001:db8:1::8".to_owned()), + IpAddress::Ip("2001:db8:2::1".to_owned()), + IpAddress::Ip("2001:db8:3::1".to_owned()), + IpAddress::IpSubnet("2001:db8:3::2/127".to_owned()) ] ); } @@ -133,12 +96,8 @@ fn test_merge_addrs_extracts_ipv4_subnets() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.2.0/24".to_owned())) - }, + IpAddress::IpSubnet("192.168.1.0/24".to_owned()), + IpAddress::IpSubnet("192.168.2.0/24".to_owned()), ] ); } @@ -154,12 +113,8 @@ fn test_merge_addrs_extracts_ipv6_subnets() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("2001:db8::/32".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db9::/112".to_owned())) - }, + IpAddress::IpSubnet("2001:db8::/32".to_owned()), + IpAddress::IpSubnet("2001:db9::/112".to_owned()), ] ); } @@ -174,12 +129,10 @@ fn test_merge_addrs_falls_back_to_range_when_no_subnet_fits() { assert_eq!( result, - [IpAddress { - address: Some(Address::IpRange(IpRange { - start: "192.168.1.255".to_owned(), - end: "192.168.2.0".to_owned(), - })), - },] + [IpAddress::IpRange(IpRange { + start: "192.168.1.255".to_owned(), + end: "192.168.2.0".to_owned(), + }),] ); let start = "2001:db8:ffff:ffff:ffff:ffff:ffff:ffff" @@ -192,12 +145,10 @@ fn test_merge_addrs_falls_back_to_range_when_no_subnet_fits() { assert_eq!( result, - [IpAddress { - address: Some(Address::IpRange(IpRange { - start: "2001:db8:ffff:ffff:ffff:ffff:ffff:ffff".to_owned(), - end: "2001:db9::".to_owned(), - })), - },] + [IpAddress::IpRange(IpRange { + start: "2001:db8:ffff:ffff:ffff:ffff:ffff:ffff".to_owned(), + end: "2001:db9::".to_owned(), + }),] ); } @@ -209,12 +160,7 @@ fn test_merge_addrs_handles_single_ip() { let result = merge_addrs(ranges); - assert_eq!( - result, - [IpAddress { - address: Some(Address::Ip("192.168.1.1".to_owned())), - },] - ); + assert_eq!(result, [IpAddress::Ip("192.168.1.1".to_owned()),]); let start = "2001:db8::".parse::().unwrap(); let end = "2001:db8::".parse::().unwrap(); @@ -222,12 +168,7 @@ fn test_merge_addrs_handles_single_ip() { let result = merge_addrs(ranges); - assert_eq!( - result, - [IpAddress { - address: Some(Address::Ip("2001:db8::".to_owned())), - },] - ); + assert_eq!(result, [IpAddress::Ip("2001:db8::".to_owned()),]); } #[test] @@ -299,12 +240,8 @@ fn test_merge_addrs_subnet_at_start_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/26".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.1.64".to_owned())), - }, + IpAddress::IpSubnet("192.168.1.0/26".to_owned()), + IpAddress::Ip("192.168.1.64".to_owned()), ] ); @@ -318,12 +255,8 @@ fn test_merge_addrs_subnet_at_start_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::IpSubnet("2001:db8::/122".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("2001:db8::40".to_owned())), - }, + IpAddress::IpSubnet("2001:db8::/122".to_owned()), + IpAddress::Ip("2001:db8::40".to_owned()), ] ); } @@ -339,12 +272,8 @@ fn test_merge_addrs_subnet_at_end_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::Ip("192.168.1.15".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.1.16/28".to_owned())), - }, + IpAddress::Ip("192.168.1.15".to_owned()), + IpAddress::IpSubnet("192.168.1.16/28".to_owned()), ] ); @@ -358,12 +287,8 @@ fn test_merge_addrs_subnet_at_end_of_range() { assert_eq!( result, [ - IpAddress { - address: Some(Address::Ip("2001:db8::f".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8::10/124".to_owned())), - }, + IpAddress::Ip("2001:db8::f".to_owned()), + IpAddress::IpSubnet("2001:db8::10/124".to_owned()), ] ); } @@ -373,12 +298,7 @@ fn test_merge_port_ranges() { // single port let input_ranges = vec![PortRange::new(100, 100)]; let merged = merge_port_ranges(input_ranges); - assert_eq!( - merged, - [Port { - port: Some(PortInner::SinglePort(100)) - }] - ); + assert_eq!(merged, [Port::Single(100)]); // overlapping ranges let input_ranges = vec![ @@ -389,12 +309,10 @@ fn test_merge_port_ranges() { let merged = merge_port_ranges(input_ranges); assert_eq!( merged, - [Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 300 - })) - }] + [Port::Range(GwPortRange { + start: 100, + end: 300 + })] ); // duplicate ranges @@ -413,18 +331,14 @@ fn test_merge_port_ranges() { assert_eq!( merged, [ - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 300 - })) - }, - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 350, - end: 400 - })) - } + Port::Range(GwPortRange { + start: 100, + end: 300 + }), + Port::Range(GwPortRange { + start: 350, + end: 400 + }) ] ); @@ -441,24 +355,16 @@ fn test_merge_port_ranges() { assert_eq!( merged, [ - Port { - port: Some(PortInner::SinglePort(50)) - }, - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 151, - end: 300 - })) - }, - Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 501, - end: 699 - })) - }, - Port { - port: Some(PortInner::SinglePort(800)) - } + Port::Single(50), + Port::Range(GwPortRange { + start: 151, + end: 300 + }), + Port::Range(GwPortRange { + start: 501, + end: 699 + }), + Port::Single(800) ] ); @@ -467,12 +373,10 @@ fn test_merge_port_ranges() { let merged = merge_port_ranges(input_ranges); assert_eq!( merged, - [Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 200 - })) - }] + [Port::Range(GwPortRange { + start: 100, + end: 200 + })] ); } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index 616f4373cf..1801930b1c 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -9,10 +9,11 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_proto::enterprise::firewall::{ - FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as PortRangeProto, Protocol, - ip_address::Address, port::Port as PortInner, +use defguard_common::gateway_types::{ + FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as GwPortRange, + Protocol as GwProtocol, }; +use defguard_proto::enterprise::firewall::Protocol as ProtoProtocol; use ipnetwork::IpNetwork; use rand::{Rng, rngs::ThreadRng, thread_rng}; use sqlx::{ @@ -78,12 +79,10 @@ fn random_network_device_with_id(rng: &mut R, id: Id) -> Device { fn expected_ipv4_source_range_for_user(user_id: Id) -> IpAddress { let user_octet = user_id as u8; - IpAddress { - address: Some(Address::IpRange(IpRange { - start: format!("10.0.{user_octet}.1"), - end: format!("10.0.{user_octet}.2"), - })), - } + IpAddress::IpRange(IpRange { + start: format!("10.0.{user_octet}.1"), + end: format!("10.0.{user_octet}.2"), + }) } async fn create_test_user_with_devices( @@ -483,7 +482,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO PortRange::new(80, 80).into(), PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -529,7 +528,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO deny_all_network_devices: false, addresses: Vec::new(), // Will use destination ranges instead ports: vec![PortRange::new(53, 53).into()], - protocols: vec![Protocol::Udp.into(), Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Udp as i32, ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -585,7 +584,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO .unwrap(); assert_eq!( generated_firewall_config.default_policy, - i32::from(FirewallPolicy::Allow) + FirewallPolicy::Allow ); let generated_firewall_rules = generated_firewall_config.rules; @@ -594,142 +593,87 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO // First ACL - Web Access ALLOW let web_allow_rule = &generated_firewall_rules[0]; - assert_eq!(web_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!(web_allow_rule.protocols, vec![i32::from(Protocol::Tcp)]); + assert_eq!(web_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule.protocols, vec![GwProtocol::Tcp]); assert_eq!( web_allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); assert_eq!( web_allow_rule.destination_ports, - [ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + [Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("10.0.100.1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::Ip("10.0.100.1".to_owned()), ] ); // First ACL - Web Access DENY let web_deny_rule = &generated_firewall_rules[2]; - assert_eq!(web_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule.verdict, FirewallPolicy::Deny); assert!(web_deny_rule.protocols.is_empty()); assert!(web_deny_rule.destination_ports.is_empty()); assert!(web_deny_rule.source_addrs.is_empty()); assert_eq!( web_deny_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); // Second ACL - DNS Access ALLOW let dns_allow_rule = &generated_firewall_rules[1]; - assert_eq!(dns_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!( - dns_allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] - ); + assert_eq!(dns_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(dns_allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); + assert_eq!(dns_allow_rule.destination_ports, [Port::Single(53)]); // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.100.1".to_owned(), - end: "10.0.100.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.100.1".to_owned(), + end: "10.0.100.2".to_owned(), + }), ] ); let expected_destination_addrs = [ - IpAddress { - address: Some(Address::Ip("10.0.1.13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.14/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.16/28".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.40/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.52/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.56/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.64/26".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.128/25".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.0/27".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.40/30".to_owned())), - }, + IpAddress::Ip("10.0.1.13".to_owned()), + IpAddress::IpSubnet("10.0.1.14/31".to_owned()), + IpAddress::IpSubnet("10.0.1.16/28".to_owned()), + IpAddress::IpSubnet("10.0.1.32/29".to_owned()), + IpAddress::IpSubnet("10.0.1.40/30".to_owned()), + IpAddress::IpSubnet("10.0.1.52/30".to_owned()), + IpAddress::IpSubnet("10.0.1.56/29".to_owned()), + IpAddress::IpSubnet("10.0.1.64/26".to_owned()), + IpAddress::IpSubnet("10.0.1.128/25".to_owned()), + IpAddress::IpSubnet("10.0.2.0/27".to_owned()), + IpAddress::IpSubnet("10.0.2.32/29".to_owned()), + IpAddress::IpSubnet("10.0.2.40/30".to_owned()), ]; assert_eq!(dns_allow_rule.destination_addrs, expected_destination_addrs); // Second ACL - DNS Access DENY let dns_deny_rule = &generated_firewall_rules[3]; - assert_eq!(dns_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule.protocols.is_empty(),); assert!(dns_deny_rule.destination_ports.is_empty(),); assert!(dns_deny_rule.source_addrs.is_empty(),); @@ -904,7 +848,7 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO PortRange::new(80, 80).into(), PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -950,7 +894,7 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO deny_all_network_devices: false, addresses: Vec::new(), // Will use destination ranges instead ports: vec![PortRange::new(53, 53).into()], - protocols: vec![Protocol::Udp.into(), Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Udp as i32, ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -1006,7 +950,7 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO .unwrap(); assert_eq!( generated_firewall_config.default_policy, - i32::from(FirewallPolicy::Allow) + FirewallPolicy::Allow ); let generated_firewall_rules = generated_firewall_config.rules; @@ -1015,166 +959,95 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO // First ACL - Web Access ALLOW let web_allow_rule = &generated_firewall_rules[0]; - assert_eq!(web_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!(web_allow_rule.protocols, vec![i32::from(Protocol::Tcp)]); + assert_eq!(web_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule.protocols, vec![GwProtocol::Tcp]); assert_eq!( web_allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); assert_eq!( web_allow_rule.destination_ports, - [ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + [Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("ff00::100:1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::Ip("ff00::100:1".to_owned()), ] ); // First ACL - Web Access DENY let web_deny_rule = &generated_firewall_rules[2]; - assert_eq!(web_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule.verdict, FirewallPolicy::Deny); assert!(web_deny_rule.protocols.is_empty()); assert!(web_deny_rule.destination_ports.is_empty()); assert!(web_deny_rule.source_addrs.is_empty()); assert_eq!( web_deny_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); // Second ACL - DNS Access ALLOW let dns_allow_rule = &generated_firewall_rules[1]; - assert_eq!(dns_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); - assert_eq!( - dns_allow_rule.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] - ); + assert_eq!(dns_allow_rule.verdict, FirewallPolicy::Allow); + assert_eq!(dns_allow_rule.protocols, [GwProtocol::Tcp, GwProtocol::Udp]); + assert_eq!(dns_allow_rule.destination_ports, [Port::Single(53)]); let expected_destination_addrs = vec![ - IpAddress { - address: Some(Address::Ip("fc00::1:13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:14/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:18/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:20/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:40/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:52/127".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:54/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:58/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:60/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:80/121".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:100/120".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:200/119".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:400/118".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:800/117".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:1000/116".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:2000/115".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:4000/114".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:8000/113".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:0/122".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:40/126".to_owned())), - }, + IpAddress::Ip("fc00::1:13".to_owned()), + IpAddress::IpSubnet("fc00::1:14/126".to_owned()), + IpAddress::IpSubnet("fc00::1:18/125".to_owned()), + IpAddress::IpSubnet("fc00::1:20/123".to_owned()), + IpAddress::IpSubnet("fc00::1:40/126".to_owned()), + IpAddress::IpSubnet("fc00::1:52/127".to_owned()), + IpAddress::IpSubnet("fc00::1:54/126".to_owned()), + IpAddress::IpSubnet("fc00::1:58/125".to_owned()), + IpAddress::IpSubnet("fc00::1:60/123".to_owned()), + IpAddress::IpSubnet("fc00::1:80/121".to_owned()), + IpAddress::IpSubnet("fc00::1:100/120".to_owned()), + IpAddress::IpSubnet("fc00::1:200/119".to_owned()), + IpAddress::IpSubnet("fc00::1:400/118".to_owned()), + IpAddress::IpSubnet("fc00::1:800/117".to_owned()), + IpAddress::IpSubnet("fc00::1:1000/116".to_owned()), + IpAddress::IpSubnet("fc00::1:2000/115".to_owned()), + IpAddress::IpSubnet("fc00::1:4000/114".to_owned()), + IpAddress::IpSubnet("fc00::1:8000/113".to_owned()), + IpAddress::IpSubnet("fc00::2:0/122".to_owned()), + IpAddress::IpSubnet("fc00::2:40/126".to_owned()), ]; // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::100:1".to_owned(), - end: "ff00::100:2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::100:1".to_owned(), + end: "ff00::100:2".to_owned(), + }), ] ); assert_eq!(dns_allow_rule.destination_addrs, expected_destination_addrs); // Second ACL - DNS Access DENY let dns_deny_rule = &generated_firewall_rules[3]; - assert_eq!(dns_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule.protocols.is_empty(),); assert!(dns_deny_rule.destination_ports.is_empty(),); assert!(dns_deny_rule.source_addrs.is_empty(),); @@ -1368,7 +1241,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P PortRange::new(80, 80).into(), PortRange::new(443, 443).into(), ], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -1414,7 +1287,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P deny_all_network_devices: false, addresses: Vec::new(), // Will use destination ranges instead ports: vec![PortRange::new(53, 53).into()], - protocols: vec![Protocol::Udp.into(), Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Udp as i32, ProtoProtocol::Tcp as i32], enabled: true, parent_id: None, state: RuleState::Applied, @@ -1472,7 +1345,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P .unwrap(); assert_eq!( generated_firewall_config.default_policy, - i32::from(FirewallPolicy::Allow) + FirewallPolicy::Allow ); let generated_firewall_rules = generated_firewall_config.rules; @@ -1481,201 +1354,120 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P // First ACL - Web Access ALLOW let web_allow_rule_ipv4 = &generated_firewall_rules[0]; - assert_eq!( - web_allow_rule_ipv4.verdict, - i32::from(FirewallPolicy::Allow) - ); - assert_eq!( - web_allow_rule_ipv4.protocols, - vec![i32::from(Protocol::Tcp)] - ); + assert_eq!(web_allow_rule_ipv4.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule_ipv4.protocols, vec![GwProtocol::Tcp]); assert_eq!( web_allow_rule_ipv4.destination_addrs, - vec![IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + vec![IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); assert_eq!( web_allow_rule_ipv4.destination_ports, - vec![ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + vec![Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule_ipv4.source_addrs, vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("10.0.100.1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::Ip("10.0.100.1".to_owned()), ] ); let web_allow_rule_ipv6 = &generated_firewall_rules[1]; - assert_eq!( - web_allow_rule_ipv6.verdict, - i32::from(FirewallPolicy::Allow) - ); - assert_eq!(web_allow_rule_ipv6.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(web_allow_rule_ipv6.verdict, FirewallPolicy::Allow); + assert_eq!(web_allow_rule_ipv6.protocols, [GwProtocol::Tcp]); assert_eq!( web_allow_rule_ipv6.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); assert_eq!( web_allow_rule_ipv6.destination_ports, - [ - Port { - port: Some(PortInner::SinglePort(80)) - }, - Port { - port: Some(PortInner::SinglePort(443)) - } - ] + [Port::Single(80), Port::Single(443)] ); // Source addresses should include devices of users 1,2 and network_device_1 assert_eq!( web_allow_rule_ipv6.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::Ip("ff00::100:1".to_owned())), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::Ip("ff00::100:1".to_owned()), ] ); // First ACL - Web Access DENY let web_deny_rule_ipv4 = &generated_firewall_rules[4]; - assert_eq!(web_deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(web_deny_rule_ipv4.protocols.is_empty()); assert!(web_deny_rule_ipv4.destination_ports.is_empty()); assert!(web_deny_rule_ipv4.source_addrs.is_empty()); assert_eq!( web_deny_rule_ipv4.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.1.0/24".to_owned())] ); let web_deny_rule_ipv6 = &generated_firewall_rules[5]; - assert_eq!(web_deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(web_deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(web_deny_rule_ipv6.protocols.is_empty()); assert!(web_deny_rule_ipv6.destination_ports.is_empty()); assert!(web_deny_rule_ipv6.source_addrs.is_empty()); assert_eq!( web_deny_rule_ipv6.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("fc00::/112".to_owned())), - }] + [IpAddress::IpSubnet("fc00::/112".to_owned())] ); // Second ACL - DNS Access ALLOW let dns_allow_rule_ipv4 = &generated_firewall_rules[2]; - assert_eq!( - dns_allow_rule_ipv4.verdict, - i32::from(FirewallPolicy::Allow) - ); + assert_eq!(dns_allow_rule_ipv4.verdict, FirewallPolicy::Allow); assert_eq!( dns_allow_rule_ipv4.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule_ipv4.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] + [GwProtocol::Tcp, GwProtocol::Udp] ); + assert_eq!(dns_allow_rule_ipv4.destination_ports, [Port::Single(53)]); // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule_ipv4.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.100.1".to_owned(), - end: "10.0.100.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.100.1".to_owned(), + end: "10.0.100.2".to_owned(), + }), ] ); let expected_destination_addrs_v4 = vec![ - IpAddress { - address: Some(Address::Ip("10.0.1.13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.14/31".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.16/28".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.40/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.52/30".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.56/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.64/26".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.128/25".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.0/27".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.32/29".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.2.40/30".to_owned())), - }, + IpAddress::Ip("10.0.1.13".to_owned()), + IpAddress::IpSubnet("10.0.1.14/31".to_owned()), + IpAddress::IpSubnet("10.0.1.16/28".to_owned()), + IpAddress::IpSubnet("10.0.1.32/29".to_owned()), + IpAddress::IpSubnet("10.0.1.40/30".to_owned()), + IpAddress::IpSubnet("10.0.1.52/30".to_owned()), + IpAddress::IpSubnet("10.0.1.56/29".to_owned()), + IpAddress::IpSubnet("10.0.1.64/26".to_owned()), + IpAddress::IpSubnet("10.0.1.128/25".to_owned()), + IpAddress::IpSubnet("10.0.2.0/27".to_owned()), + IpAddress::IpSubnet("10.0.2.32/29".to_owned()), + IpAddress::IpSubnet("10.0.2.40/30".to_owned()), ]; assert_eq!( @@ -1684,106 +1476,52 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P ); let dns_allow_rule_ipv6 = &generated_firewall_rules[3]; - assert_eq!( - dns_allow_rule_ipv6.verdict, - i32::from(FirewallPolicy::Allow) - ); + assert_eq!(dns_allow_rule_ipv6.verdict, FirewallPolicy::Allow); assert_eq!( dns_allow_rule_ipv6.protocols, - [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] - ); - assert_eq!( - dns_allow_rule_ipv6.destination_ports, - [Port { - port: Some(PortInner::SinglePort(53)) - }] + [GwProtocol::Tcp, GwProtocol::Udp] ); + assert_eq!(dns_allow_rule_ipv6.destination_ports, [Port::Single(53)]); // Source addresses should include network_devices 1,2 assert_eq!( dns_allow_rule_ipv6.source_addrs, [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::100:1".to_owned(), - end: "ff00::100:2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::100:1".to_owned(), + end: "ff00::100:2".to_owned(), + }), ] ); let expected_destination_addrs_v6 = vec![ - IpAddress { - address: Some(Address::Ip("fc00::1:13".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:14/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:18/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:20/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:40/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:52/127".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:54/126".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:58/125".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:60/123".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:80/121".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:100/120".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:200/119".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:400/118".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:800/117".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:1000/116".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:2000/115".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:4000/114".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::1:8000/113".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:0/122".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("fc00::2:40/126".to_owned())), - }, + IpAddress::Ip("fc00::1:13".to_owned()), + IpAddress::IpSubnet("fc00::1:14/126".to_owned()), + IpAddress::IpSubnet("fc00::1:18/125".to_owned()), + IpAddress::IpSubnet("fc00::1:20/123".to_owned()), + IpAddress::IpSubnet("fc00::1:40/126".to_owned()), + IpAddress::IpSubnet("fc00::1:52/127".to_owned()), + IpAddress::IpSubnet("fc00::1:54/126".to_owned()), + IpAddress::IpSubnet("fc00::1:58/125".to_owned()), + IpAddress::IpSubnet("fc00::1:60/123".to_owned()), + IpAddress::IpSubnet("fc00::1:80/121".to_owned()), + IpAddress::IpSubnet("fc00::1:100/120".to_owned()), + IpAddress::IpSubnet("fc00::1:200/119".to_owned()), + IpAddress::IpSubnet("fc00::1:400/118".to_owned()), + IpAddress::IpSubnet("fc00::1:800/117".to_owned()), + IpAddress::IpSubnet("fc00::1:1000/116".to_owned()), + IpAddress::IpSubnet("fc00::1:2000/115".to_owned()), + IpAddress::IpSubnet("fc00::1:4000/114".to_owned()), + IpAddress::IpSubnet("fc00::1:8000/113".to_owned()), + IpAddress::IpSubnet("fc00::2:0/122".to_owned()), + IpAddress::IpSubnet("fc00::2:40/126".to_owned()), ]; assert_eq!( @@ -1793,7 +1531,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P // Second ACL - DNS Access DENY let dns_deny_rule_ipv4 = &generated_firewall_rules[6]; - assert_eq!(dns_deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule_ipv4.protocols.is_empty(),); assert!(dns_deny_rule_ipv4.destination_ports.is_empty(),); assert!(dns_deny_rule_ipv4.source_addrs.is_empty(),); @@ -1803,7 +1541,7 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P ); let dns_deny_rule_ipv6 = &generated_firewall_rules[7]; - assert_eq!(dns_deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(dns_deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(dns_deny_rule_ipv6.protocols.is_empty(),); assert!(dns_deny_rule_ipv6.destination_ports.is_empty(),); assert!(dns_deny_rule_ipv6.source_addrs.is_empty(),); @@ -1889,30 +1627,22 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { // check generated rules assert_eq!(generated_firewall_rules.len(), 4); let expected_source_addrs = [ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let expected_destination_addrs = [ - IpAddress { - address: Some(Address::Ip("10.0.2.3".to_owned())), - }, - IpAddress { - address: Some(Address::IpSubnet("192.168.1.0/24".to_owned())), - }, + IpAddress::Ip("10.0.2.3".to_owned()), + IpAddress::IpSubnet("192.168.1.0/24".to_owned()), ]; let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); assert!(allow_rule.destination_ports.is_empty()); @@ -1923,17 +1653,15 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { ); let alias_allow_rule = &generated_firewall_rules[1]; - assert_eq!(alias_allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(alias_allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(alias_allow_rule.source_addrs, expected_source_addrs); assert!(alias_allow_rule.destination_addrs.is_empty()); assert_eq!( alias_allow_rule.destination_ports, - vec![Port { - port: Some(PortInner::PortRange(PortRangeProto { - start: 100, - end: 200, - })) - }] + vec![Port::Range(GwPortRange { + start: 100, + end: 200, + })] ); assert!(alias_allow_rule.protocols.is_empty()); assert_eq!( @@ -1942,7 +1670,7 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { ); let deny_rule = &generated_firewall_rules[2]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -1953,7 +1681,7 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { ); let alias_deny_rule = &generated_firewall_rules[3]; - assert_eq!(alias_deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(alias_deny_rule.verdict, FirewallPolicy::Deny); assert!(alias_deny_rule.source_addrs.is_empty()); assert!(alias_deny_rule.destination_addrs.is_empty()); assert!(alias_deny_rule.destination_ports.is_empty()); @@ -2040,34 +1768,26 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt // check generated rules assert_eq!(generated_firewall_rules.len(), 4); let expected_source_addrs = vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let alias_allow_rule_1 = &generated_firewall_rules[0]; - assert_eq!(alias_allow_rule_1.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(alias_allow_rule_1.verdict, FirewallPolicy::Allow); assert_eq!(alias_allow_rule_1.source_addrs, expected_source_addrs); assert_eq!( alias_allow_rule_1.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.3".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.3".to_owned()),] ); assert_eq!( alias_allow_rule_1.destination_ports, - vec![Port { - port: Some(PortInner::SinglePort(5432)) - }] + vec![Port::Single(5432)] ); assert!(alias_allow_rule_1.protocols.is_empty()); assert_eq!( @@ -2076,19 +1796,15 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt ); let alias_allow_rule_2 = &generated_firewall_rules[1]; - assert_eq!(alias_allow_rule_2.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(alias_allow_rule_2.verdict, FirewallPolicy::Allow); assert_eq!(alias_allow_rule_2.source_addrs, expected_source_addrs); assert_eq!( alias_allow_rule_2.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.4".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.4".to_owned()),] ); assert_eq!( alias_allow_rule_2.destination_ports, - vec![Port { - port: Some(PortInner::SinglePort(6379)) - }] + vec![Port::Single(6379)] ); assert!(alias_allow_rule_2.protocols.is_empty()); assert_eq!( @@ -2097,13 +1813,11 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt ); let alias_deny_rule_1 = &generated_firewall_rules[2]; - assert_eq!(alias_deny_rule_1.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(alias_deny_rule_1.verdict, FirewallPolicy::Deny); assert!(alias_deny_rule_1.source_addrs.is_empty()); assert_eq!( alias_deny_rule_1.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.3".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.3".to_owned()),] ); assert!(alias_deny_rule_1.destination_ports.is_empty()); assert!(alias_deny_rule_1.protocols.is_empty()); @@ -2113,13 +1827,11 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt ); let alias_deny_rule_2 = &generated_firewall_rules[3]; - assert_eq!(alias_deny_rule_2.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(alias_deny_rule_2.verdict, FirewallPolicy::Deny); assert!(alias_deny_rule_2.source_addrs.is_empty()); assert_eq!( alias_deny_rule_2.destination_addrs, - vec![IpAddress { - address: Some(Address::Ip("10.0.2.4".to_owned())), - },] + vec![IpAddress::Ip("10.0.2.4".to_owned()),] ); assert!(alias_deny_rule_2.destination_ports.is_empty()); assert!(alias_deny_rule_2.protocols.is_empty()); @@ -2183,7 +1895,7 @@ async fn test_no_allowed_users_ipv4(_: PgPoolOptions, options: PgConnectOptions) // only deny rules are generated assert_eq!(generated_firewall_rules.len(), 2); for rule in generated_firewall_rules { - assert_eq!(rule.verdict(), FirewallPolicy::Deny); + assert_eq!(rule.verdict, FirewallPolicy::Deny); } } @@ -2238,7 +1950,7 @@ async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( allow_all_groups: true, addresses: vec!["192.168.10.0/24".parse().unwrap()], ports: vec![PortRange::new(443, 443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2282,21 +1994,14 @@ async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( .collect(); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.10.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.10.0/24".to_owned())] ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(443)), - }] - ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); assert!( allow_rule .source_addrs @@ -2311,7 +2016,7 @@ async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( ); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2374,7 +2079,7 @@ async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_r allow_all_groups: true, addresses: vec!["192.168.30.0/24".parse().unwrap()], ports: vec![PortRange::new(443, 443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2436,21 +2141,14 @@ async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_r .collect(); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule.source_addrs, expected_source_addrs); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.30.0/24".to_owned())), - }] + [IpAddress::IpSubnet("192.168.30.0/24".to_owned())] ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(443)), - }] - ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); assert!( allow_rule .source_addrs @@ -2459,7 +2157,7 @@ async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_r ); let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2518,7 +2216,7 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour deny_all_groups: true, addresses: vec!["192.168.20.0/24".parse().unwrap()], ports: vec![PortRange::new(8443, 8443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2551,7 +2249,7 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!( allow_rule.source_addrs, [expected_ipv4_source_range_for_user( @@ -2560,17 +2258,10 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour ); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.20.0/24".to_owned())), - }] - ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(8443)), - }] + [IpAddress::IpSubnet("192.168.20.0/24".to_owned())] ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(8443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); for denied_user in [ grouped_denied_user_a.id, grouped_denied_user_b.id, @@ -2586,7 +2277,7 @@ async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sour } let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2653,7 +2344,7 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi deny_all_groups: true, addresses: vec!["192.168.40.0/24".parse().unwrap()], ports: vec![PortRange::new(8443, 8443).into()], - protocols: vec![Protocol::Tcp.into()], + protocols: vec![ProtoProtocol::Tcp as i32], any_address: false, any_port: false, any_protocol: false, @@ -2714,7 +2405,7 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi assert_eq!(generated_firewall_rules.len(), 2); let allow_rule = &generated_firewall_rules[0]; - assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.verdict, FirewallPolicy::Allow); assert_eq!( allow_rule.source_addrs, [expected_ipv4_source_range_for_user( @@ -2723,17 +2414,10 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi ); assert_eq!( allow_rule.destination_addrs, - [IpAddress { - address: Some(Address::IpSubnet("192.168.40.0/24".to_owned())), - }] - ); - assert_eq!( - allow_rule.destination_ports, - [Port { - port: Some(PortInner::SinglePort(8443)), - }] + [IpAddress::IpSubnet("192.168.40.0/24".to_owned())] ); - assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert_eq!(allow_rule.destination_ports, [Port::Single(8443)]); + assert_eq!(allow_rule.protocols, [GwProtocol::Tcp]); for denied_user in expected_denied_user_ids { assert!( allow_rule @@ -2744,7 +2428,7 @@ async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_fi } let deny_rule = &generated_firewall_rules[1]; - assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule.verdict, FirewallPolicy::Deny); assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); assert!(deny_rule.destination_ports.is_empty()); @@ -2880,28 +2564,24 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon assert_eq!(generated_firewall_rules_ipv4.len(), 2); let expected_source_addrs_ipv4 = vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.1.1".to_owned(), - end: "10.0.1.2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "10.0.2.1".to_owned(), - end: "10.0.2.2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "10.0.1.1".to_owned(), + end: "10.0.1.2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "10.0.2.1".to_owned(), + end: "10.0.2.2".to_owned(), + }), ]; let allow_rule_ipv4 = &generated_firewall_rules_ipv4[0]; - assert_eq!(allow_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(allow_rule_ipv4.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(allow_rule_ipv4.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv4.source_addrs, expected_source_addrs_ipv4); assert!(allow_rule_ipv4.destination_addrs.is_empty()); let deny_rule_ipv4 = &generated_firewall_rules_ipv4[1]; - assert_eq!(deny_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv4.source_addrs.is_empty()); assert!(deny_rule_ipv4.destination_addrs.is_empty()); @@ -2914,28 +2594,24 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon assert_eq!(generated_firewall_rules_ipv6.len(), 2); let expected_source_addrs_ipv6 = vec![ - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::1:1".to_owned(), - end: "ff00::1:2".to_owned(), - })), - }, - IpAddress { - address: Some(Address::IpRange(IpRange { - start: "ff00::2:1".to_owned(), - end: "ff00::2:2".to_owned(), - })), - }, + IpAddress::IpRange(IpRange { + start: "ff00::1:1".to_owned(), + end: "ff00::1:2".to_owned(), + }), + IpAddress::IpRange(IpRange { + start: "ff00::2:1".to_owned(), + end: "ff00::2:2".to_owned(), + }), ]; let allow_rule_ipv6 = &generated_firewall_rules_ipv6[0]; - assert_eq!(allow_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(allow_rule_ipv6.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(allow_rule_ipv6.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv6.source_addrs, expected_source_addrs_ipv6); assert!(allow_rule_ipv6.destination_addrs.is_empty()); let deny_rule_ipv6 = &generated_firewall_rules_ipv6[1]; - assert_eq!(deny_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv6.source_addrs.is_empty()); assert!(deny_rule_ipv6.destination_addrs.is_empty()); @@ -2949,26 +2625,26 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon assert_eq!(generated_firewall_rules_ipv4_and_ipv6.len(), 4); let allow_rule_ipv4 = &generated_firewall_rules_ipv4_and_ipv6[0]; - assert_eq!(allow_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(allow_rule_ipv4.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(allow_rule_ipv4.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv4.source_addrs, expected_source_addrs_ipv4); assert!(allow_rule_ipv4.destination_addrs.is_empty()); let allow_rule_ipv6 = &generated_firewall_rules_ipv4_and_ipv6[1]; - assert_eq!(allow_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(allow_rule_ipv6.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(allow_rule_ipv6.verdict, FirewallPolicy::Allow); assert_eq!(allow_rule_ipv6.source_addrs, expected_source_addrs_ipv6); assert!(allow_rule_ipv6.destination_addrs.is_empty()); let deny_rule_ipv4 = &generated_firewall_rules_ipv4_and_ipv6[2]; - assert_eq!(deny_rule_ipv4.ip_version, i32::from(IpVersion::Ipv4)); - assert_eq!(deny_rule_ipv4.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv4.ip_version, IpVersion::Ipv4); + assert_eq!(deny_rule_ipv4.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv4.source_addrs.is_empty()); assert!(deny_rule_ipv4.destination_addrs.is_empty()); let deny_rule_ipv6 = &generated_firewall_rules_ipv4_and_ipv6[3]; - assert_eq!(deny_rule_ipv6.ip_version, i32::from(IpVersion::Ipv6)); - assert_eq!(deny_rule_ipv6.verdict, i32::from(FirewallPolicy::Deny)); + assert_eq!(deny_rule_ipv6.ip_version, IpVersion::Ipv6); + assert_eq!(deny_rule_ipv6.verdict, FirewallPolicy::Deny); assert!(deny_rule_ipv6.source_addrs.is_empty()); assert!(deny_rule_ipv6.destination_addrs.is_empty()); } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/source.rs b/crates/defguard_core/src/enterprise/firewall/tests/source.rs index 98d4638206..4db9d82d90 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/source.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/source.rs @@ -1,7 +1,6 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - -use defguard_proto::enterprise::firewall::{IpAddress, IpVersion, ip_address::Address}; +use defguard_common::gateway_types::{IpAddress, IpVersion}; use rand::thread_rng; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use crate::enterprise::firewall::{ get_source_addrs, get_source_network_devices, get_source_users, @@ -71,21 +70,11 @@ fn test_process_source_addrs_v4() { assert_eq!( source_addrs, [ - IpAddress { - address: Some(Address::Ip("10.0.1.1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.2/31".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("10.0.1.4/31".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("172.16.1.1".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("192.168.1.100".to_owned())), - }, + IpAddress::Ip("10.0.1.1".to_owned()), + IpAddress::IpSubnet("10.0.1.2/31".to_owned()), + IpAddress::IpSubnet("10.0.1.4/31".to_owned()), + IpAddress::Ip("172.16.1.1".to_owned()), + IpAddress::Ip("192.168.1.100".to_owned()), ] ); @@ -126,21 +115,11 @@ fn test_process_source_addrs_v6() { assert_eq!( source_addrs, [ - IpAddress { - address: Some(Address::Ip("2001:db8::1".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8::2/127".to_owned())) - }, - IpAddress { - address: Some(Address::IpSubnet("2001:db8::4/127".to_owned())) - }, - IpAddress { - address: Some(Address::Ip("2001:db8:0:1::1".to_owned())), - }, - IpAddress { - address: Some(Address::Ip("2001:db8:0:2::1".to_owned())), - }, + IpAddress::Ip("2001:db8::1".to_owned()), + IpAddress::IpSubnet("2001:db8::2/127".to_owned()), + IpAddress::IpSubnet("2001:db8::4/127".to_owned()), + IpAddress::Ip("2001:db8:0:1::1".to_owned()), + IpAddress::Ip("2001:db8:0:2::1".to_owned()), ] ); diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index e934a8d762..9d6f9e4749 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -16,6 +16,7 @@ use defguard_common::{ wireguard::ServiceLocationMode, }, }, + gateway_types::{FirewallConfig, WireguardPeer}, types::UrlParseError, }; use reqwest::Url; @@ -49,10 +50,7 @@ pub mod proto { } } -use defguard_proto::{ - enterprise::firewall::FirewallConfig, gateway::Peer, - worker::worker_service_server::WorkerServiceServer, -}; +use defguard_proto::worker::worker_service_server::WorkerServiceServer; use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; // gRPC header for passing auth token from clients @@ -218,7 +216,12 @@ impl From for defguard_proto::client_types::InstanceInfo { #[derive(Clone, Debug)] pub enum GatewayEvent { NetworkCreated(Id, WireguardNetwork), - NetworkModified(Id, WireguardNetwork, Vec, Option), + NetworkModified( + Id, + WireguardNetwork, + Vec, + Option, + ), NetworkDeleted(Id, String), DeviceCreated(DeviceInfo), DeviceModified(DeviceInfo), diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 4f0b130aaa..bfd2892a7e 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -1,5 +1,7 @@ -use defguard_common::db::{Id, models::WireguardNetwork}; -use defguard_proto::gateway::Peer; +use defguard_common::{ + db::{Id, models::WireguardNetwork}, + gateway_types::WireguardPeer, +}; use sqlx::{PgExecutor, query}; use crate::grpc::should_prevent_service_location_usage; @@ -14,7 +16,7 @@ use crate::grpc::should_prevent_service_location_usage; pub async fn get_location_allowed_peers<'e, E>( location: &WireguardNetwork, executor: E, -) -> sqlx::Result> +) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -49,7 +51,7 @@ where return Ok(rows .into_iter() - .map(|row| Peer { + .map(|row| WireguardPeer { pubkey: row.pubkey, allowed_ips: row.allowed_ips, preshared_key: None, @@ -89,7 +91,7 @@ where Ok(rows .into_iter() - .map(|row| Peer { + .map(|row| WireguardPeer { pubkey: row.pubkey, allowed_ips: row.allowed_ips, preshared_key: Some(row.preshared_key), diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 46e8b9e6c3..71db0b2902 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -21,6 +21,7 @@ use defguard_common::{ wireguard::DEFAULT_WIREGUARD_MTU, }, }, + gateway_types::{FirewallConfig, WireguardPeer}, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::{ @@ -31,7 +32,7 @@ use defguard_core::{ }; use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::{ - enterprise::firewall::FirewallConfig, + enterprise::firewall::FirewallConfig as ProtoFirewallConfig, gateway::{ Configuration, CoreResponse, Peer, PeerStats, Update, UpdateType, core_request, core_response, gateway_client, update, @@ -858,11 +859,13 @@ impl GatewayUpdatesHandler { fn send_network_update( &self, network: &WireguardNetwork, - peers: Vec, + peers: Vec, firewall_config: Option, update_type: i32, ) -> Result<(), Status> { debug!("Sending network update for network {network}"); + let proto_peers: Vec = peers.into_iter().map(Into::into).collect(); + let proto_firewall: Option = firewall_config.map(Into::into); if let Err(err) = self.tx.send(CoreResponse { id: 0, payload: Some(core_response::Payload::Update(Update { @@ -872,8 +875,8 @@ impl GatewayUpdatesHandler { private_key: network.prvkey.clone(), addresses: network.address().iter().map(ToString::to_string).collect(), port: network.port.cast_unsigned(), - peers, - firewall_config, + peers: proto_peers, + firewall_config: proto_firewall, mtu: network.mtu.cast_unsigned(), fwmark: network.fwmark as u32, })), @@ -988,11 +991,12 @@ impl GatewayUpdatesHandler { "Sending firewall config update for network {} with config {firewall_config:?}", self.network ); + let proto_firewall: ProtoFirewallConfig = firewall_config.into(); if let Err(err) = self.tx.send(CoreResponse { id: 0, payload: Some(core_response::Payload::Update(Update { update_type: UpdateType::Modify as i32, - update: Some(update::Update::FirewallConfig(firewall_config)), + update: Some(update::Update::FirewallConfig(proto_firewall)), })), }) { let msg = format!( @@ -1073,15 +1077,17 @@ mod tests { }, setup_pool, }; + use defguard_common::gateway_types::{FirewallConfig, FirewallPolicy, WireguardPeer}; use defguard_core::grpc::GatewayEvent; - use defguard_proto::gateway::{Configuration, Peer, PeerStats, core_response}; + use defguard_proto::{ + enterprise::firewall::FirewallPolicy as ProtoFirewallPolicy, + gateway::{Configuration, PeerStats, core_response}, + }; use prost_types::Timestamp; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::{broadcast, mpsc::unbounded_channel, watch}; - use super::{ - FirewallConfig, GatewayHandler, GatewayUpdatesHandler, try_protos_into_stats_message, - }; + use super::{GatewayHandler, GatewayUpdatesHandler, try_protos_into_stats_message}; fn test_network(location_mfa_mode: LocationMfaMode) -> WireguardNetwork { WireguardNetwork::new( @@ -1203,14 +1209,14 @@ mod tests { fn gen_config_maps_network_fields() { let config = Configuration::new( &build_network(), - vec![Peer { + vec![WireguardPeer { pubkey: "peer-public-key".to_owned(), allowed_ips: vec!["10.10.0.2/32".to_owned()], preshared_key: Some("peer-preshared-key".to_owned()), keepalive_interval: Some(25), }], Some(FirewallConfig { - default_policy: 0, + default_policy: FirewallPolicy::Unspecified, rules: Vec::new(), snat_bindings: Vec::new(), }), @@ -1235,7 +1241,10 @@ mod tests { let firewall_config = config .firewall_config .expect("generated config should include firewall config"); - assert_eq!(firewall_config.default_policy, 0); + assert_eq!( + firewall_config.default_policy, + ProtoFirewallPolicy::Unspecified as i32 + ); assert!(firewall_config.rules.is_empty()); assert!(firewall_config.snat_bindings.is_empty()); } diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index 4cd3b175b5..fcfc5c8756 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -9,12 +9,12 @@ use defguard_common::db::{ wireguard::{LocationMfaMode, WireguardNetwork}, }, }; +use defguard_common::gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, SnatBinding, +}; use defguard_core::grpc::GatewayEvent; use defguard_proto::{ - enterprise::firewall::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, - SnatBinding, ip_address::Address, port::Port as PortInner, - }, + enterprise::firewall::FirewallConfig as ProtoFirewallConfig, gateway::{ CoreResponse, Update, UpdateType, core_response, update::{self}, @@ -366,28 +366,20 @@ pub(crate) fn assert_network_modify_update( pub(crate) fn build_test_firewall_config() -> FirewallConfig { FirewallConfig { - default_policy: i32::from(FirewallPolicy::Allow), + default_policy: FirewallPolicy::Allow, rules: vec![FirewallRule { id: 101, - source_addrs: vec![IpAddress { - address: Some(Address::IpSubnet("10.10.0.0/24".to_owned())), - }], - destination_addrs: vec![IpAddress { - address: Some(Address::Ip("198.51.100.20".to_owned())), - }], - destination_ports: vec![Port { - port: Some(PortInner::SinglePort(443)), - }], - protocols: vec![i32::from(Protocol::Tcp)], - verdict: i32::from(FirewallPolicy::Deny), + source_addrs: vec![IpAddress::IpSubnet("10.10.0.0/24".to_owned())], + destination_addrs: vec![IpAddress::Ip("198.51.100.20".to_owned())], + destination_ports: vec![Port::Single(443)], + protocols: vec![Protocol::Tcp], + verdict: FirewallPolicy::Deny, comment: Some("block test https destination".to_owned()), - ip_version: i32::from(IpVersion::Ipv4), + ip_version: IpVersion::Ipv4, }], snat_bindings: vec![SnatBinding { id: 202, - source_addrs: vec![IpAddress { - address: Some(Address::IpSubnet("10.10.0.0/24".to_owned())), - }], + source_addrs: vec![IpAddress::IpSubnet("10.10.0.0/24".to_owned())], public_ip: "203.0.113.44".to_owned(), comment: Some("test snat binding".to_owned()), }], @@ -398,6 +390,8 @@ pub(crate) fn assert_firewall_modify_update( outbound: CoreResponse, expected_firewall_config: &FirewallConfig, ) { + // Convert expected native config to proto for comparison with the outbound proto payload + let expected_proto: ProtoFirewallConfig = expected_firewall_config.clone().into(); match outbound.payload { Some(core_response::Payload::Update(Update { update_type, @@ -406,22 +400,19 @@ pub(crate) fn assert_firewall_modify_update( assert_eq!(update_type, UpdateType::Modify as i32); assert_eq!( firewall_config.default_policy, - expected_firewall_config.default_policy - ); - assert_eq!( - firewall_config.rules.len(), - expected_firewall_config.rules.len() + expected_proto.default_policy ); + assert_eq!(firewall_config.rules.len(), expected_proto.rules.len()); assert_eq!( firewall_config.snat_bindings.len(), - expected_firewall_config.snat_bindings.len() + expected_proto.snat_bindings.len() ); let firewall_rule = firewall_config .rules .first() .expect("expected firewall rule in update payload"); - let expected_firewall_rule = expected_firewall_config + let expected_firewall_rule = expected_proto .rules .first() .expect("expected firewall rule in test config"); @@ -447,7 +438,7 @@ pub(crate) fn assert_firewall_modify_update( .snat_bindings .first() .expect("expected SNAT binding in update payload"); - let expected_snat_binding = expected_firewall_config + let expected_snat_binding = expected_proto .snat_bindings .first() .expect("expected SNAT binding in test config"); diff --git a/crates/defguard_proto/src/gateway_conversions.rs b/crates/defguard_proto/src/gateway_conversions.rs new file mode 100644 index 0000000000..7c09f84057 --- /dev/null +++ b/crates/defguard_proto/src/gateway_conversions.rs @@ -0,0 +1,128 @@ +use defguard_common::gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, PortRange, + Protocol, SnatBinding, WireguardPeer, +}; + +use crate::{ + enterprise::firewall::{ + FirewallConfig as ProtoFirewallConfig, FirewallPolicy as ProtoFirewallPolicy, + FirewallRule as ProtoFirewallRule, IpAddress as ProtoIpAddress, IpRange as ProtoIpRange, + IpVersion as ProtoIpVersion, Port as ProtoPort, PortRange as ProtoPortRange, + Protocol as ProtoProtocol, SnatBinding as ProtoSnatBinding, ip_address::Address, + port::Port as PortInner, + }, + gateway::Peer, +}; + +impl From for Peer { + fn from(p: WireguardPeer) -> Self { + Self { + pubkey: p.pubkey, + allowed_ips: p.allowed_ips, + preshared_key: p.preshared_key, + keepalive_interval: p.keepalive_interval, + } + } +} + +fn policy_to_i32(policy: FirewallPolicy) -> i32 { + match policy { + FirewallPolicy::Unspecified => ProtoFirewallPolicy::Unspecified as i32, + FirewallPolicy::Allow => ProtoFirewallPolicy::Allow as i32, + FirewallPolicy::Deny => ProtoFirewallPolicy::Deny as i32, + } +} + +fn ip_version_to_i32(v: IpVersion) -> i32 { + match v { + IpVersion::Unspecified => ProtoIpVersion::Unspecified as i32, + IpVersion::Ipv4 => ProtoIpVersion::Ipv4 as i32, + IpVersion::Ipv6 => ProtoIpVersion::Ipv6 as i32, + } +} + +fn protocol_to_i32(p: Protocol) -> i32 { + match p { + Protocol::Unspecified => ProtoProtocol::Unspecified as i32, + Protocol::Icmp => ProtoProtocol::Icmp as i32, + Protocol::Tcp => ProtoProtocol::Tcp as i32, + Protocol::Udp => ProtoProtocol::Udp as i32, + } +} + +impl From for ProtoIpRange { + fn from(r: IpRange) -> Self { + Self { + start: r.start, + end: r.end, + } + } +} + +impl From for ProtoIpAddress { + fn from(addr: IpAddress) -> Self { + Self { + address: Some(match addr { + IpAddress::Ip(ip) => Address::Ip(ip), + IpAddress::IpRange(r) => Address::IpRange(r.into()), + IpAddress::IpSubnet(s) => Address::IpSubnet(s), + }), + } + } +} + +impl From for ProtoPortRange { + fn from(r: PortRange) -> Self { + Self { + start: r.start, + end: r.end, + } + } +} + +impl From for ProtoPort { + fn from(p: Port) -> Self { + Self { + port: Some(match p { + Port::Single(n) => PortInner::SinglePort(n), + Port::Range(r) => PortInner::PortRange(r.into()), + }), + } + } +} + +impl From for ProtoFirewallRule { + fn from(r: FirewallRule) -> Self { + Self { + id: r.id, + source_addrs: r.source_addrs.into_iter().map(Into::into).collect(), + destination_addrs: r.destination_addrs.into_iter().map(Into::into).collect(), + destination_ports: r.destination_ports.into_iter().map(Into::into).collect(), + protocols: r.protocols.into_iter().map(protocol_to_i32).collect(), + verdict: policy_to_i32(r.verdict), + comment: r.comment, + ip_version: ip_version_to_i32(r.ip_version), + } + } +} + +impl From for ProtoSnatBinding { + fn from(b: SnatBinding) -> Self { + Self { + id: b.id, + source_addrs: b.source_addrs.into_iter().map(Into::into).collect(), + public_ip: b.public_ip, + comment: b.comment, + } + } +} + +impl From for ProtoFirewallConfig { + fn from(c: FirewallConfig) -> Self { + Self { + default_policy: policy_to_i32(c.default_policy), + rules: c.rules.into_iter().map(Into::into).collect(), + snat_bindings: c.snat_bindings.into_iter().map(Into::into).collect(), + } + } +} diff --git a/crates/defguard_proto/src/lib.rs b/crates/defguard_proto/src/lib.rs index db7d37dde3..8e32212149 100644 --- a/crates/defguard_proto/src/lib.rs +++ b/crates/defguard_proto/src/lib.rs @@ -1,3 +1,5 @@ +pub mod gateway_conversions; + use std::fmt; mod generated { @@ -85,15 +87,13 @@ use defguard_common::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, + gateway_types::{FirewallConfig, WireguardPeer}, }; use proxy::CoreError; use serde::Serialize; use tonic::Status; -use crate::{ - enterprise::firewall::FirewallConfig, - gateway::{Configuration, Peer}, -}; +use crate::gateway::Configuration; // Client MFA methods impl fmt::Display for MfaMethod { @@ -224,7 +224,7 @@ impl From for client_types::ServiceLocationMode { impl Configuration { pub fn new( location: &WireguardNetwork, - peers: Vec, + peers: Vec, maybe_firewall_config: Option, ) -> Self { Self { @@ -232,8 +232,8 @@ impl Configuration { port: location.port.cast_unsigned(), private_key: location.prvkey.clone(), addresses: location.address().iter().map(ToString::to_string).collect(), - peers, - firewall_config: maybe_firewall_config, + peers: peers.into_iter().map(Into::into).collect(), + firewall_config: maybe_firewall_config.map(Into::into), mtu: location.mtu.cast_unsigned(), fwmark: location.fwmark as u32, } From e23e9bcbee97cd9b1dc38f20c2cd420599ab053e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 18 May 2026 20:47:29 +0200 Subject: [PATCH 02/10] migrate GatewayEvent to defguard_common crate --- crates/defguard/src/main.rs | 3 +- crates/defguard_common/src/gateway_event.rs | 52 +++++++++++++++++++ crates/defguard_common/src/lib.rs | 1 + crates/defguard_core/src/grpc/mod.rs | 51 ++---------------- crates/defguard_event_router/Cargo.toml | 2 +- .../src/handlers/bidi.rs | 6 +-- crates/defguard_event_router/src/lib.rs | 6 +-- .../defguard_gateway_manager/src/handler.rs | 4 +- crates/defguard_gateway_manager/src/lib.rs | 2 +- .../src/tests/common/mod.rs | 2 +- .../src/tests/gateway_manager/handler.rs | 2 +- .../tests/gateway_manager/handler/support.rs | 2 +- crates/defguard_proxy_manager/src/lib.rs | 4 +- .../src/tests/common/mod.rs | 3 +- .../tests/proxy_manager/handler/enrollment.rs | 6 +-- .../src/tests/proxy_manager/handler/mfa.rs | 2 +- crates/defguard_session_manager/src/error.rs | 2 +- crates/defguard_session_manager/src/lib.rs | 2 +- .../tests/common/mod.rs | 3 +- .../tests/session_manager/mfa.rs | 2 +- crates/defguard_setup/src/migration.rs | 2 +- 21 files changed, 84 insertions(+), 75 deletions(-) create mode 100644 crates/defguard_common/src/gateway_event.rs diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 92cf9d494d..ed400eb65a 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -17,6 +17,7 @@ use defguard_common::{ settings::{initialize_current_settings, update_current_settings}, }, }, + gateway_event::GatewayEvent, messages::peer_stats_update::PeerStatsUpdate, types::proxy::ProxyControlMessage, }; @@ -30,7 +31,7 @@ use defguard_core::{ }, events::{ApiEvent, BidiStreamEvent}, gateway_config, - grpc::{GatewayEvent, WorkerState, run_grpc_server}, + grpc::{WorkerState, run_grpc_server}, init_dev_env, init_vpn_location, run_web_server, setup_logs::CoreSetupLogLayer, utility_thread::run_utility_thread, diff --git a/crates/defguard_common/src/gateway_event.rs b/crates/defguard_common/src/gateway_event.rs new file mode 100644 index 0000000000..7aa0da89a7 --- /dev/null +++ b/crates/defguard_common/src/gateway_event.rs @@ -0,0 +1,52 @@ +use tokio::sync::broadcast::Sender; +use tracing::{debug, error}; + +use crate::{ + db::{ + Id, + models::{ + Device, WireguardNetwork, + device::{DeviceInfo, DeviceNetworkInfo}, + }, + }, + gateway_types::{FirewallConfig, WireguardPeer}, +}; + +#[derive(Clone, Debug)] +pub enum GatewayEvent { + NetworkCreated(Id, WireguardNetwork), + NetworkModified( + Id, + WireguardNetwork, + Vec, + Option, + ), + NetworkDeleted(Id, String), + DeviceCreated(DeviceInfo), + DeviceModified(DeviceInfo), + DeviceDeleted(DeviceInfo), + FirewallConfigChanged(Id, FirewallConfig), + FirewallDisabled(Id), + MfaSessionAuthorized(Id, Device, DeviceNetworkInfo), + MfaSessionDisconnected(Id, Device), +} + +/// Sends a [`GatewayEvent`] to the gateway manager service. +/// +/// In API handler context prefer `AppState::send_wireguard_event`. +pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { + debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); + if let Err(err) = wg_tx.send(event) { + error!("Error sending WireGuard event {err}"); + } +} + +/// Sends multiple [`GatewayEvent`]s to the gateway manager service. +/// +/// In API handler context prefer `AppState::send_multiple_wireguard_events`. +pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { + debug!("Sending {} WireGuard events", events.len()); + for event in events { + send_wireguard_event(event, wg_tx); + } +} diff --git a/crates/defguard_common/src/lib.rs b/crates/defguard_common/src/lib.rs index bf1fb913ff..9eff8e1d0a 100644 --- a/crates/defguard_common/src/lib.rs +++ b/crates/defguard_common/src/lib.rs @@ -2,6 +2,7 @@ pub mod auth; pub mod config; pub mod csv; pub mod db; +pub mod gateway_event; pub mod gateway_types; pub mod globals; pub mod hex; diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 9d6f9e4749..6e8370f3b8 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,19 +10,14 @@ use defguard_common::{ config::server_config, db::{ Id, - models::{ - Device, Settings, WireguardNetwork, - device::{DeviceInfo, DeviceNetworkInfo}, - wireguard::ServiceLocationMode, - }, + models::{Settings, WireguardNetwork, wireguard::ServiceLocationMode}, }, - gateway_types::{FirewallConfig, WireguardPeer}, types::UrlParseError, }; use reqwest::Url; use serde::Serialize; use sqlx::PgPool; -use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; +use tokio::sync::mpsc::UnboundedSender; use crate::{ db::AppEvent, @@ -212,45 +207,9 @@ impl From for defguard_proto::client_types::InstanceInfo { } } -// TODO: move this to common crate -#[derive(Clone, Debug)] -pub enum GatewayEvent { - NetworkCreated(Id, WireguardNetwork), - NetworkModified( - Id, - WireguardNetwork, - Vec, - Option, - ), - NetworkDeleted(Id, String), - DeviceCreated(DeviceInfo), - DeviceModified(DeviceInfo), - DeviceDeleted(DeviceInfo), - FirewallConfigChanged(Id, FirewallConfig), - FirewallDisabled(Id), - MfaSessionAuthorized(Id, Device, DeviceNetworkInfo), - MfaSessionDisconnected(Id, Device), -} - -/// Sends given `GatewayEvent` to be handled by gateway GRPC server -/// -/// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead -pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { - debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); - if let Err(err) = wg_tx.send(event) { - error!("Error sending WireGuard event {err}"); - } -} - -/// Sends multiple events to be handled by gateway gRPC server. -/// -/// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead -pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { - debug!("Sending {} WireGuard events", events.len()); - for event in events { - send_wireguard_event(event, wg_tx); - } -} +pub use defguard_common::gateway_event::{ + GatewayEvent, send_multiple_wireguard_events, send_wireguard_event, +}; /// If this location is marked as a service location, checks if all requirements are met for it to /// function: diff --git a/crates/defguard_event_router/Cargo.toml b/crates/defguard_event_router/Cargo.toml index 8b64231497..c36146d0a3 100644 --- a/crates/defguard_event_router/Cargo.toml +++ b/crates/defguard_event_router/Cargo.toml @@ -9,6 +9,7 @@ rust-version.workspace = true [dependencies] # internal crates +defguard_common = { workspace = true } defguard_core = { workspace = true } defguard_event_logger = { workspace = true } defguard_session_manager = { workspace = true } @@ -19,4 +20,3 @@ tokio = { workspace = true } tracing = { workspace = true } [dev-dependencies] -defguard_common = { workspace = true } diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 748f8ff3ed..cc4e74a66b 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -114,10 +114,8 @@ mod tests { wireguard::{LocationMfaMode, ServiceLocationMode}, }, }; - use defguard_core::{ - events::{BidiRequestContext, BidiStreamEventType}, - grpc::GatewayEvent, - }; + use defguard_common::gateway_event::GatewayEvent; + use defguard_core::events::{BidiRequestContext, BidiStreamEventType}; use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; use super::*; diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index b46e6d2579..ffbc495253 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -19,10 +19,8 @@ use std::sync::Arc; -use defguard_core::{ - events::{ApiEvent, BidiStreamEvent}, - grpc::GatewayEvent, -}; +use defguard_common::gateway_event::GatewayEvent; +use defguard_core::events::{ApiEvent, BidiStreamEvent}; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use defguard_session_manager::events::SessionManagerEvent; use error::EventRouterError; diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 71db0b2902..c0a7232241 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -12,6 +12,7 @@ use std::{ }; use chrono::{DateTime, TimeDelta}; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::{ VERSION, db::{ @@ -26,7 +27,6 @@ use defguard_common::{ }; use defguard_core::{ enterprise::firewall::try_get_location_firewall_config, - grpc::GatewayEvent, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, location_management::allowed_peers::get_location_allowed_peers, }; @@ -1077,8 +1077,8 @@ mod tests { }, setup_pool, }; + use defguard_common::gateway_event::GatewayEvent; use defguard_common::gateway_types::{FirewallConfig, FirewallPolicy, WireguardPeer}; - use defguard_core::grpc::GatewayEvent; use defguard_proto::{ enterprise::firewall::FirewallPolicy as ProtoFirewallPolicy, gateway::{Configuration, PeerStats, core_response}, diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 8cdac400ef..323431d41b 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -9,11 +9,11 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, }; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::{ db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::gateway_client::GatewayClient; use defguard_version::client::ClientVersionInterceptor; use sqlx::{PgPool, postgres::PgListener}; diff --git a/crates/defguard_gateway_manager/src/tests/common/mod.rs b/crates/defguard_gateway_manager/src/tests/common/mod.rs index eee7e9e471..6290016be0 100644 --- a/crates/defguard_gateway_manager/src/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/src/tests/common/mod.rs @@ -11,6 +11,7 @@ use std::{ time::Duration, }; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::{ db::{ Id, NoId, @@ -21,7 +22,6 @@ use defguard_common::{ }, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::{CoreRequest, CoreResponse, PeerStats, core_request, gateway_server}; use prost_types::Timestamp; use sqlx::{PgPool, postgres::PgConnectOptions}; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs index 35896035f6..afe1ebd7a8 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -2,7 +2,7 @@ mod support; use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayEvent; use defguard_proto::gateway::{UpdateType, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index fcfc5c8756..7f1e61b5cc 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -9,10 +9,10 @@ use defguard_common::db::{ wireguard::{LocationMfaMode, WireguardNetwork}, }, }; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::gateway_types::{ FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, SnatBinding, }; -use defguard_core::grpc::GatewayEvent; use defguard_proto::{ enterprise::firewall::FirewallConfig as ProtoFirewallConfig, gateway::{ diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index cb6c546be8..839b722a4f 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -7,13 +7,13 @@ use std::{ use std::{path::PathBuf, str::FromStr, sync::Mutex as StdMutex}; use axum_extra::extract::cookie::Key; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::{ db::{Id, models::proxy::Proxy}, types::proxy::ProxyControlMessage, }; use defguard_core::{ - events::BidiStreamEvent, - grpc::{GatewayEvent, proxy::client_mfa::ClientLoginSession}, + events::BidiStreamEvent, grpc::proxy::client_mfa::ClientLoginSession, version::IncompatibleComponents, }; use defguard_proto::proxy::{CoreResponse, HttpsCerts, core_response}; diff --git a/crates/defguard_proxy_manager/src/tests/common/mod.rs b/crates/defguard_proxy_manager/src/tests/common/mod.rs index a2e19d770b..5130bf31fa 100644 --- a/crates/defguard_proxy_manager/src/tests/common/mod.rs +++ b/crates/defguard_proxy_manager/src/tests/common/mod.rs @@ -24,7 +24,8 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_core::{events::BidiStreamEvent, grpc::GatewayEvent}; +use defguard_common::gateway_event::GatewayEvent; +use defguard_core::events::BidiStreamEvent; use defguard_proto::proxy::{ AcmeChallenge, AcmeIssueEvent, CoreRequest, CoreResponse, InitialInfo, core_response, proxy_server, diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs index 36af0a6526..e602bef255 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs @@ -2,10 +2,8 @@ use defguard_common::db::models::{ Device, Settings, User, biometric_auth::BiometricAuth, polling_token::PollingToken, settings::update_current_settings, }; -use defguard_core::{ - events::{BidiStreamEventType, EnrollmentEvent}, - grpc::GatewayEvent, -}; +use defguard_common::gateway_event::GatewayEvent; +use defguard_core::events::{BidiStreamEventType, EnrollmentEvent}; use defguard_proto::{ client_types::{ExistingDevice, MfaMethod, NewDevice, RegisterMobileAuthRequest}, proxy::{CoreRequest, core_request, core_response}, diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs index e1a117c9e9..330566c45a 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs @@ -1,5 +1,5 @@ use defguard_common::db::Id; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayEvent; use defguard_proto::{ client_types::{ClientMfaFinishRequest, ClientMfaStartRequest, MfaMethod}, proxy::{AwaitRemoteMfaFinishRequest, CoreRequest, core_request, core_response}, diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 4904a26038..2a41c8e699 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,5 +1,5 @@ use defguard_common::db::Id; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayEvent; use thiserror::Error; use tokio::sync::{broadcast::error::SendError as BroadcastSendError, mpsc::error::SendError}; diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 88ad9862a0..c1a4ee6632 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -1,4 +1,5 @@ use chrono::Utc; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::{ db::{ Id, @@ -9,7 +10,6 @@ use defguard_common::{ }, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::GatewayEvent; use sqlx::{PgConnection, PgPool}; use tokio::{ sync::{ diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 36b3093737..63dc91d0ea 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -16,6 +16,7 @@ use defguard_common::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, + gateway_event::GatewayEvent, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_session_manager::{ @@ -37,7 +38,7 @@ pub(crate) struct SessionManagerHarness { stats_tx: mpsc::UnboundedSender, pub(crate) stats_rx: mpsc::UnboundedReceiver, pub(crate) event_rx: mpsc::UnboundedReceiver, - pub(crate) gateway_rx: broadcast::Receiver, + pub(crate) gateway_rx: broadcast::Receiver, } pub(crate) fn assert_no_session_manager_events(harness: &mut SessionManagerHarness) { diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index a3a688ccdc..f7c828559a 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -9,7 +9,7 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_core::grpc::GatewayEvent; +use defguard_common::gateway_event::GatewayEvent; use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; diff --git a/crates/defguard_setup/src/migration.rs b/crates/defguard_setup/src/migration.rs index ee1980bf4d..fb6fa87d6a 100644 --- a/crates/defguard_setup/src/migration.rs +++ b/crates/defguard_setup/src/migration.rs @@ -10,6 +10,7 @@ use axum::{ serve, }; use axum_extra::extract::cookie::Key; +use defguard_common::gateway_event::GatewayEvent; use defguard_common::{VERSION, db::models::Settings, types::proxy::ProxyControlMessage}; use defguard_core::{ appstate::AppState, @@ -17,7 +18,6 @@ use defguard_core::{ db::AppEvent, enterprise::handlers::openid_login::{auth_callback, get_auth_info}, events::ApiEvent, - grpc::GatewayEvent, handle_404, handlers::{ auth::{ From b06e3cba091f1be6fc8f99b40018d9caa05a93f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 19 May 2026 08:44:26 +0200 Subject: [PATCH 03/10] rename event to command --- crates/defguard/src/main.rs | 4 +- crates/defguard_common/src/gateway_event.rs | 22 +++---- crates/defguard_core/src/appstate.rs | 22 +++---- .../src/enterprise/db/models/acl.rs | 6 +- .../src/enterprise/directory_sync/mod.rs | 14 ++--- .../src/enterprise/directory_sync/tests.rs | 52 ++++++++-------- .../src/enterprise/snat/handlers.rs | 8 +-- crates/defguard_core/src/grpc/mod.rs | 2 +- .../src/grpc/proxy/client_mfa.rs | 20 +++---- .../src/handlers/network_devices.rs | 10 ++-- .../defguard_core/src/handlers/wireguard.rs | 30 +++++----- crates/defguard_core/src/lib.rs | 6 +- .../src/location_management/mod.rs | 42 ++++++------- crates/defguard_core/src/user_management.rs | 20 +++---- crates/defguard_core/src/utility_thread.rs | 18 +++--- .../tests/integration/api/common/mod.rs | 8 +-- .../tests/integration/api/proxy_certs.rs | 4 +- .../tests/integration/api/wireguard.rs | 25 ++++---- .../api/wireguard_network_allowed_groups.rs | 59 +++++++++++++------ .../api/wireguard_network_devices.rs | 10 ++-- .../api/wireguard_network_import.rs | 15 +++-- .../src/handlers/bidi.rs | 4 +- crates/defguard_event_router/src/lib.rs | 8 +-- .../defguard_gateway_manager/src/handler.rs | 36 +++++------ crates/defguard_gateway_manager/src/lib.rs | 6 +- .../src/tests/common/mod.rs | 8 +-- .../src/tests/gateway_manager/handler.rs | 2 +- .../gateway_manager/handler/device_events.rs | 27 +++++---- .../handler/firewall_events.rs | 8 +-- .../src/tests/gateway_manager/handler/mfa.rs | 6 +- .../gateway_manager/handler/network_events.rs | 14 ++--- .../tests/gateway_manager/handler/support.rs | 8 +-- crates/defguard_proxy_manager/src/handler.rs | 4 +- crates/defguard_proxy_manager/src/lib.rs | 6 +- .../src/servers/enrollment.rs | 16 ++--- .../src/tests/common/mod.rs | 4 +- .../tests/proxy_manager/handler/enrollment.rs | 4 +- .../src/tests/proxy_manager/handler/mfa.rs | 10 ++-- crates/defguard_session_manager/src/error.rs | 8 +-- crates/defguard_session_manager/src/lib.rs | 10 ++-- .../tests/common/mod.rs | 4 +- .../tests/session_manager/mfa.rs | 6 +- crates/defguard_setup/src/migration.rs | 6 +- 43 files changed, 315 insertions(+), 287 deletions(-) diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index ed400eb65a..211232a74d 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -17,7 +17,7 @@ use defguard_common::{ settings::{initialize_current_settings, update_current_settings}, }, }, - gateway_event::GatewayEvent, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, types::proxy::ProxyControlMessage, }; @@ -211,7 +211,7 @@ async fn main() -> Result<(), anyhow::Error> { // setup communication channels for services let (webhook_tx, webhook_rx) = unbounded_channel::(); // RX is discarded here since it can be derived from TX later on - let (gateway_tx, _gateway_rx) = broadcast::channel::(256); + let (gateway_tx, _gateway_rx) = broadcast::channel::(256); let (event_logger_tx, event_logger_rx) = unbounded_channel::(); let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); diff --git a/crates/defguard_common/src/gateway_event.rs b/crates/defguard_common/src/gateway_event.rs index 7aa0da89a7..c429fef259 100644 --- a/crates/defguard_common/src/gateway_event.rs +++ b/crates/defguard_common/src/gateway_event.rs @@ -13,7 +13,7 @@ use crate::{ }; #[derive(Clone, Debug)] -pub enum GatewayEvent { +pub enum GatewayCommand { NetworkCreated(Id, WireguardNetwork), NetworkModified( Id, @@ -31,22 +31,22 @@ pub enum GatewayEvent { MfaSessionDisconnected(Id, Device), } -/// Sends a [`GatewayEvent`] to the gateway manager service. +/// Sends a [`GatewayCommand`] to the gateway manager service. /// -/// In API handler context prefer `AppState::send_wireguard_event`. -pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { - debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); +/// In API handler context prefer `AppState::send_gateway_command`. +pub fn send_gateway_command(event: GatewayCommand, wg_tx: &Sender) { + debug!("Sending the following command to Gateway Manager: {event:?}"); if let Err(err) = wg_tx.send(event) { - error!("Error sending WireGuard event {err}"); + error!("Error sending Gateway command: {err}"); } } -/// Sends multiple [`GatewayEvent`]s to the gateway manager service. +/// Sends multiple [`GatewayCommand`]s to the gateway manager service. /// -/// In API handler context prefer `AppState::send_multiple_wireguard_events`. -pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { - debug!("Sending {} WireGuard events", events.len()); +/// In API handler context prefer `AppState::send_multiple_gateway_commands`. +pub fn send_multiple_gateway_commands(events: Vec, wg_tx: &Sender) { + debug!("Sending {} gateway commands", events.len()); for event in events { - send_wireguard_event(event, wg_tx); + send_gateway_command(event, wg_tx); } } diff --git a/crates/defguard_core/src/appstate.rs b/crates/defguard_core/src/appstate.rs index 499775c14d..7566e0e1c2 100644 --- a/crates/defguard_core/src/appstate.rs +++ b/crates/defguard_core/src/appstate.rs @@ -19,7 +19,7 @@ use crate::{ db::{AppEvent, WebHook}, error::WebError, events::ApiEvent, - grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + grpc::{GatewayCommand, send_gateway_command, send_multiple_gateway_commands}, version::IncompatibleComponents, }; @@ -29,7 +29,7 @@ const X_DEFGUARD_EVENT: &str = "x-defguard-event"; pub struct AppState { pub pool: PgPool, tx: UnboundedSender, - pub wireguard_tx: Sender, + pub wireguard_tx: Sender, pub web_reload_tx: tokio::sync::broadcast::Sender<()>, pub failed_logins: Arc>, key: Key, @@ -86,16 +86,16 @@ impl AppState { } } - /// Sends given `GatewayEvent` to be handled by gateway GRPC server. - /// Convenience wrapper around [`send_wireguard_event`] - pub fn send_wireguard_event(&self, event: GatewayEvent) { - send_wireguard_event(event, &self.wireguard_tx); + /// Sends given `GatewayCommand` to be handled by gateway manager service. + /// Convenience wrapper around [`send_gateway_command`] + pub fn send_gateway_command(&self, event: GatewayCommand) { + send_gateway_command(event, &self.wireguard_tx); } - /// Sends multiple events to be handled by gateway GRPC server. - /// Convenience wrapper around [`send_multiple_wireguard_events`] - pub fn send_multiple_wireguard_events(&self, events: Vec) { - send_multiple_wireguard_events(events, &self.wireguard_tx); + /// Sends multiple commands to be handled by gateway manager service. + /// Convenience wrapper around [`send_multiple_gateway_commands`] + pub fn send_multiple_gateway_commands(&self, events: Vec) { + send_multiple_gateway_commands(events, &self.wireguard_tx); } /// Sends event to the main event router @@ -118,7 +118,7 @@ impl AppState { pool: PgPool, tx: UnboundedSender, rx: UnboundedReceiver, - wireguard_tx: Sender, + wireguard_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, key: Key, failed_logins: Arc>, diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index b63abd59a9..c5f32de6f7 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -27,7 +27,7 @@ use crate::{ ApiAclRule, EditAclRule, alias::EditAclAlias, destination::EditAclDestination, }, }, - grpc::GatewayEvent, + grpc::GatewayCommand, }; #[derive(Debug, Error)] @@ -547,7 +547,7 @@ impl AclRule { match try_get_location_firewall_config(&location, &mut transaction).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -1696,7 +1696,7 @@ impl AclAlias { match try_get_location_firewall_config(&location, &mut transaction).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index f6da8105eb..c927c25e0b 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -31,7 +31,7 @@ use crate::{ license::get_cached_license, limits::{get_counts, update_counts}, }, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::user::check_username, user_management::{delete_user_and_cleanup_devices, disable_user, sync_allowed_user_devices}, }; @@ -324,7 +324,7 @@ async fn sync_user_groups( directory_sync: &T, user: &User, pool: &PgPool, - wg_tx: &Sender, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { info!("Syncing groups of user {} with the directory", user.email); let directory_groups = directory_sync.get_user_groups(&user.email).await?; @@ -418,7 +418,7 @@ pub(crate) async fn test_directory_sync_connection( pub async fn sync_user_groups_if_configured( user: &User, pool: &PgPool, - wg_tx: &Sender, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -482,7 +482,7 @@ async fn create_and_add_to_group( async fn sync_all_users_groups( directory_sync: &T, pool: &PgPool, - wg_tx: &Sender, + wg_tx: &Sender, all_users: Option<&[DirectoryUser]>, ) -> Result<(), DirectorySyncError> { info!("Syncing all users' groups with the directory, this may take a while..."); @@ -618,7 +618,7 @@ fn is_directory_sync_enabled(provider: Option<&OpenIdProvider>) -> bool { async fn sync_all_users_state( pool: &PgPool, - wg_tx: &Sender, + wg_tx: &Sender, all_users: &[DirectoryUser], ) -> Result<(), DirectorySyncError> { info!("Syncing all users' state with the directory, this may take a while..."); @@ -904,7 +904,7 @@ async fn sync_inactive_directory_users( transaction: &mut PgConnection, inactive_directory_users: &[&DirectoryUser], modified_users: &mut Vec>, - wg_tx: &Sender, + wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { // find all active Defguard users disabled in directory let disabled_users_emails = inactive_directory_users @@ -1005,7 +1005,7 @@ pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { // Performs the directory sync job. This function is called by the utility thread. pub(crate) async fn do_directory_sync( pool: &PgPool, - wireguard_tx: &Sender, + wireguard_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 3eb6a4fff8..2704837910 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -146,7 +146,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -186,7 +186,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -217,7 +217,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2.id); } else { panic!("Expected a DeviceDeleted event"); @@ -229,7 +229,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); User::init_admin_user(&pool, "pass123").await.unwrap(); let _ = make_test_provider( @@ -271,7 +271,7 @@ mod test { // Check that we received a device deleted event for whichever admin was removed let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert!(dev.device.user_id == user1.id || dev.device.user_id == user3.id); } else { panic!("Expected a DeviceDeleted event"); @@ -284,7 +284,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -325,7 +325,7 @@ mod test { // Check for device deletion events let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user2.id @@ -336,7 +336,7 @@ mod test { } let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user2.id @@ -353,7 +353,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Disable, @@ -401,14 +401,14 @@ mod test { // Check for device disconnection events let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); } let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); @@ -436,7 +436,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -474,7 +474,7 @@ mod test { // Check for device disconnection events let event1 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event1 { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user3.id @@ -485,7 +485,7 @@ mod test { } let event2 = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event2 { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id || dev.device.user_id == user3.id @@ -514,7 +514,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -571,7 +571,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -600,7 +600,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -625,7 +625,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -659,13 +659,13 @@ mod test { .unwrap(); transaction.commit().await.unwrap(); let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceDeleted(dev)) = event { + if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2_pre_sync.id); } else { panic!("Expected DeviceDeleted event"); } let event = wg_rx.try_recv(); - if let Ok(GatewayEvent::DeviceCreated(dev)) = event { + if let Ok(GatewayCommand::DeviceCreated(dev)) = event { panic!("Unexpected DeviceCreated event: {dev:?}"); } } @@ -676,7 +676,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -704,7 +704,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -751,7 +751,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (wg_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -791,7 +791,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // disable prefetching users make_test_provider( @@ -825,7 +825,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( @@ -862,7 +862,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( diff --git a/crates/defguard_core/src/enterprise/snat/handlers.rs b/crates/defguard_core/src/enterprise/snat/handlers.rs index a591ac6410..23082fc178 100644 --- a/crates/defguard_core/src/enterprise/snat/handlers.rs +++ b/crates/defguard_core/src/enterprise/snat/handlers.rs @@ -21,7 +21,7 @@ use crate::{ }, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{ApiResponse, ApiResult}, }; @@ -159,7 +159,7 @@ pub async fn create_snat_binding( debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -259,7 +259,7 @@ pub async fn modify_snat_binding( debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -344,7 +344,7 @@ pub async fn delete_snat_binding( debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 6e8370f3b8..b5130c1393 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -208,7 +208,7 @@ impl From for defguard_proto::client_types::InstanceInfo { } pub use defguard_common::gateway_event::{ - GatewayEvent, send_multiple_wireguard_events, send_wireguard_event, + GatewayCommand, send_multiple_gateway_commands, send_gateway_command, }; /// If this location is marked as a service location, checks if all requirements are met for it to diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 8cc99dd99d..50dcbb06a8 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -51,7 +51,7 @@ use crate::{ posture::{PostureCheckError, PostureResult, validate_posture}, }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, - grpc::{GatewayEvent, utils::parse_client_ip_agent}, + grpc::{GatewayCommand, utils::parse_client_ip_agent}, }; const CLIENT_SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes @@ -83,7 +83,7 @@ pub struct ClientLoginSession { pub struct ClientMfaServer { pub(crate) pool: PgPool, - wireguard_tx: Sender, + wireguard_tx: Sender, pub(crate) sessions: Arc>>, remote_mfa_responses: Arc>>>, bidi_event_tx: UnboundedSender, @@ -104,7 +104,7 @@ impl ClientMfaServer { #[must_use] pub fn new( pool: PgPool, - wireguard_tx: Sender, + wireguard_tx: Sender, bidi_event_tx: UnboundedSender, remote_mfa_responses: Arc>>>, sessions: Arc>>, @@ -713,7 +713,7 @@ impl ClientMfaServer { // send gateway event debug!("Sending `peer_create` message to gateway"); let event = - GatewayEvent::MfaSessionAuthorized(location.id, device.clone(), gateway_network_info); + GatewayCommand::MfaSessionAuthorized(location.id, device.clone(), gateway_network_info); self.wireguard_tx.send(event).map_err(|err| { error!("Error sending WireGuard event: {err}"); Status::internal("unexpected error") @@ -951,7 +951,7 @@ impl ClientMfaServer { // gateway update is only needed to remove peer for MFA sessions // this is needed to remove peers for both Connected and New sessions if is_mfa_session { - let gateway_event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + let gateway_event = GatewayCommand::MfaSessionDisconnected(location.id, device.clone()); self.wireguard_tx.send(gateway_event).map_err(|err| { error!("Error sending WireGuard event: {err}"); Status::internal("unexpected error") @@ -1021,7 +1021,7 @@ mod tests { use super::{ClientLoginSession, ClientMfaServer}; use crate::{ events::{BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, - grpc::GatewayEvent, + grpc::GatewayCommand, }; const REPLACEMENT_MFA_PRESHARED_KEY: &str = "replacement-mfa-psk"; @@ -1067,7 +1067,7 @@ mod tests { .try_recv() .expect("expected MFA gateway disconnect event for replaced connected session"); match gateway_event { - GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -1142,7 +1142,7 @@ mod tests { .try_recv() .expect("expected MFA gateway disconnect event for replaced new session"); match gateway_event { - GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -1235,7 +1235,7 @@ mod tests { ) -> ( ClientMfaServer, tokio::sync::mpsc::UnboundedReceiver, - tokio::sync::broadcast::Receiver, + tokio::sync::broadcast::Receiver, ) { let (wireguard_tx, wireguard_rx) = broadcast::channel(8); let (bidi_event_tx, bidi_event_rx) = mpsc::unbounded_channel(); @@ -1359,7 +1359,7 @@ mod tests { ); match gateway_rx.try_recv() { - Ok(GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device)) => { + Ok(GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device)) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index 015ec84ebb..6ccb1edfe5 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -34,7 +34,7 @@ use crate::{ firewall::try_get_location_firewall_config, limits::update_counts, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{ device_for_admin_or_self, pagination::{PaginatedApiResponse, PaginatedApiResult, PaginationParams}, @@ -670,7 +670,7 @@ pub(crate) async fn add_network_device( .add_to_network(&network, &ips, &mut transaction) .await?; - appstate.send_wireguard_event(GatewayEvent::DeviceCreated(DeviceInfo { + appstate.send_gateway_command(GatewayCommand::DeviceCreated(DeviceInfo { device: device.clone(), network_info: vec![network_info.clone()], })); @@ -681,7 +681,7 @@ pub(crate) async fn add_network_device( if let Some(firewall_config) = try_get_location_firewall_config(&network, &mut transaction).await? { - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( network.id, firewall_config, )); @@ -778,14 +778,14 @@ pub async fn modify_network_device( wireguard_network_device.wireguard_ips = data.assigned_ips; wireguard_network_device.update(&mut *transaction).await?; let device_info = DeviceInfo::from_device(&mut *transaction, device.clone()).await?; - appstate.send_wireguard_event(GatewayEvent::DeviceModified(device_info)); + appstate.send_gateway_command(GatewayCommand::DeviceModified(device_info)); // send firewall update event if ACLs are enabled if device_network.acl_enabled { if let Some(firewall_config) = try_get_location_firewall_config(&device_network, &mut transaction).await? { - appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + appstate.send_gateway_command(GatewayCommand::FirewallConfigChanged( device_network.id, firewall_config, )); diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index b7409663de..aaaa3b3897 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -38,7 +38,7 @@ use crate::{ limits::{get_counts, update_counts}, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{gateway::GatewayInfo, network_devices::DeviceWireGuardConfig}, location_management::{ allowed_peers::get_location_allowed_peers, handle_imported_devices, handle_mapped_devices, @@ -260,7 +260,7 @@ pub(crate) async fn create_network( network.add_all_allowed_devices(&mut transaction).await?; info!("Assigning IPs for existing devices in network {network}"); - appstate.send_wireguard_event(GatewayEvent::NetworkCreated(network.id, network.clone())); + appstate.send_gateway_command(GatewayCommand::NetworkCreated(network.id, network.clone())); transaction.commit().await?; @@ -381,7 +381,7 @@ pub(crate) async fn modify_network( let peers = get_location_allowed_peers(&network, &mut *transaction).await?; let maybe_firewall_config = try_get_location_firewall_config(&network, &mut transaction).await?; - appstate.send_wireguard_event(GatewayEvent::NetworkModified( + appstate.send_gateway_command(GatewayCommand::NetworkModified( network.id, network.clone(), peers, @@ -448,7 +448,7 @@ pub(crate) async fn delete_network( } network.clone().delete(&mut *transaction).await?; transaction.commit().await?; - appstate.send_wireguard_event(GatewayEvent::NetworkDeleted(network_id, network_name)); + appstate.send_gateway_command(GatewayCommand::NetworkDeleted(network_id, network_name)); info!( "User {} deleted WireGuard network {network_id}", session.user.username, @@ -655,7 +655,7 @@ pub(crate) async fn import_network( .await?; info!("New network {network} created"); - appstate.send_wireguard_event(GatewayEvent::NetworkCreated(network.id, network.clone())); + appstate.send_gateway_command(GatewayCommand::NetworkCreated(network.id, network.clone())); let reserved_ips = imported_devices .iter() @@ -663,13 +663,13 @@ pub(crate) async fn import_network( .collect::>(); let (devices, gateway_events) = handle_imported_devices(&network, &mut transaction, imported_devices).await?; - appstate.send_multiple_wireguard_events(gateway_events); + appstate.send_multiple_gateway_commands(gateway_events); // assign IPs for other existing devices debug!("Assigning IPs in imported network for remaining existing devices"); let gateway_events = sync_location_allowed_devices(&network, &mut transaction, Some(&reserved_ips)).await?; - appstate.send_multiple_wireguard_events(gateway_events); + appstate.send_multiple_gateway_commands(gateway_events); debug!("Assigned IPs in imported network for remaining existing devices"); transaction.commit().await?; @@ -716,7 +716,7 @@ pub(crate) async fn add_user_devices( // wrap loop in transaction to abort if a device is invalid let mut transaction = appstate.pool.begin().await?; let events = handle_mapped_devices(&network, &mut transaction, &mapped_devices).await?; - appstate.send_multiple_wireguard_events(events); + appstate.send_multiple_gateway_commands(events); transaction.commit().await?; info!( @@ -892,7 +892,7 @@ pub(crate) async fn add_device( "Sending firewall config update for location {location} affected by adding new \ user {username} devices" ); - events.push(GatewayEvent::FirewallConfigChanged( + events.push(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -901,12 +901,12 @@ pub(crate) async fn add_device( } // add peer on relevant gateways - events.push(GatewayEvent::DeviceCreated(DeviceInfo { + events.push(GatewayCommand::DeviceCreated(DeviceInfo { device: device.clone(), network_info: network_info.clone(), })); - appstate.send_multiple_wireguard_events(events); + appstate.send_multiple_gateway_commands(events); let template_locations = configs .iter() @@ -1061,7 +1061,7 @@ pub(crate) async fn modify_device( network_info.push(device_network_info); } } - appstate.send_wireguard_event(GatewayEvent::DeviceModified(DeviceInfo { + appstate.send_gateway_command(GatewayCommand::DeviceModified(DeviceInfo { device: device.clone(), network_info, })); @@ -1186,7 +1186,7 @@ pub(crate) async fn delete_device( debug!( "Sending firewall config update for location {location} affected by deleting user {username} device" ); - events.push(GatewayEvent::FirewallConfigChanged( + events.push(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, )); @@ -1195,10 +1195,10 @@ pub(crate) async fn delete_device( } let device_id = device_info.device.id; - events.push(GatewayEvent::DeviceDeleted(device_info.clone())); + events.push(GatewayCommand::DeviceDeleted(device_info.clone())); // send generated gateway events - appstate.send_multiple_wireguard_events(events); + appstate.send_multiple_gateway_commands(events); // Emit event specific to the device type. match device.device_type { diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 78cf0329f0..60c7e9c7cf 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -126,7 +126,7 @@ use crate::{ create_snat_binding, delete_snat_binding, list_snat_bindings, modify_snat_binding, }, }, - grpc::{GatewayEvent, WorkerState}, + grpc::{GatewayCommand, WorkerState}, handlers::{ app_info::get_app_info, auth::{ @@ -256,7 +256,7 @@ async fn openapi() -> Json { pub fn build_webapp( webhook_tx: UnboundedSender, webhook_rx: UnboundedReceiver, - wireguard_tx: Sender, + wireguard_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, worker_state: Arc>, pool: PgPool, @@ -804,7 +804,7 @@ pub async fn run_web_server( worker_state: Arc>, webhook_tx: UnboundedSender, webhook_rx: UnboundedReceiver, - wireguard_tx: Sender, + wireguard_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, pool: PgPool, failed_logins: Arc>, diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index 0f58a45291..164048d7d5 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -18,7 +18,7 @@ use tokio::sync::broadcast::Sender; use crate::{ enterprise::firewall::{FirewallError, try_get_location_firewall_config}, - grpc::{GatewayEvent, send_multiple_wireguard_events}, + grpc::{GatewayCommand, send_multiple_gateway_commands}, wg_config::ImportedDevice, }; @@ -41,7 +41,7 @@ pub enum LocationManagementError { // run sync_allowed_devices on all wireguard networks pub(crate) async fn sync_all_networks( conn: &mut PgConnection, - wireguard_tx: &Sender, + wireguard_tx: &Sender, ) -> Result<(), LocationManagementError> { info!("Syncing allowed devices for all WireGuard locations"); let locations = WireguardNetwork::all(&mut *conn).await?; @@ -53,14 +53,14 @@ pub(crate) async fn sync_all_networks( if let Some(firewall_config) = try_get_location_firewall_config(&network, &mut *conn).await? { - gateway_events.push(GatewayEvent::FirewallConfigChanged( + gateway_events.push(GatewayCommand::FirewallConfigChanged( network.id, firewall_config, )); } // check if any gateway events need to be sent if !gateway_events.is_empty() { - send_multiple_wireguard_events(gateway_events, wireguard_tx); + send_multiple_gateway_commands(gateway_events, wireguard_tx); } } Ok(()) @@ -75,7 +75,7 @@ pub(crate) async fn sync_location_allowed_devices( location: &WireguardNetwork, conn: &mut PgConnection, reserved_ips: Option<&[IpAddr]>, -) -> Result, LocationManagementError> { +) -> Result, LocationManagementError> { info!("Synchronizing IPs in network {location} for all allowed devices "); // list all allowed devices let mut allowed_devices = location.get_allowed_devices(&mut *conn).await?; @@ -119,7 +119,7 @@ pub(crate) async fn sync_allowed_devices_for_user( conn: &mut PgConnection, user: &User, reserved_ips: Option<&[IpAddr]>, -) -> Result, WireguardNetworkError> { +) -> Result, WireguardNetworkError> { info!("Synchronizing IPs in network {location} for all allowed devices "); // list all allowed devices let allowed_devices = location @@ -160,12 +160,12 @@ pub async fn process_device_access_changes( mut allowed_devices: HashMap>, currently_configured_devices: Vec, reserved_ips: Option<&[IpAddr]>, -) -> Result, WireguardNetworkError> { +) -> Result, WireguardNetworkError> { // Loop through current device configurations; remove no longer allowed, readdress // when necessary; remove processed entry from all devices list initial list should // now contain only devices to be added. let mut used_ips = location.all_used_ips_for_network(&mut *transaction).await?; - let mut events: Vec = Vec::new(); + let mut events: Vec = Vec::new(); for device_network_config in currently_configured_devices { // Device is allowed and an IP was already assigned if let Some(device) = allowed_devices.remove(&device_network_config.device_id) { @@ -186,7 +186,7 @@ pub async fn process_device_access_changes( let network_info = wireguard_network_device .to_device_network_info_runtime(&mut *transaction, location) .await?; - events.push(GatewayEvent::DeviceModified(DeviceInfo { + events.push(GatewayCommand::DeviceModified(DeviceInfo { device, network_info: vec![network_info], })); @@ -204,7 +204,7 @@ pub async fn process_device_access_changes( Device::find_by_id(&mut *transaction, device_network_config.device_id).await? { let network_info = device_network_config.to_device_network_info(location, None); - events.push(GatewayEvent::DeviceDeleted(DeviceInfo { + events.push(GatewayCommand::DeviceDeleted(DeviceInfo { device, network_info: vec![network_info], })); @@ -224,7 +224,7 @@ pub async fn process_device_access_changes( let network_info = wireguard_network_device .to_device_network_info_runtime(&mut *transaction, location) .await?; - events.push(GatewayEvent::DeviceCreated(DeviceInfo { + events.push(GatewayCommand::DeviceCreated(DeviceInfo { device, network_info: vec![network_info], })); @@ -241,7 +241,7 @@ pub(crate) async fn handle_imported_devices( location: &WireguardNetwork, transaction: &mut PgConnection, imported_devices: Vec, -) -> Result<(Vec, Vec), WireguardNetworkError> { +) -> Result<(Vec, Vec), WireguardNetworkError> { let allowed_devices = location.get_allowed_devices(&mut *transaction).await?; // convert to a map for easier processing let allowed_devices: HashMap> = allowed_devices @@ -276,7 +276,7 @@ pub(crate) async fn handle_imported_devices( .to_device_network_info_runtime(&mut *transaction, location) .await?; // send device to connected gateways - events.push(GatewayEvent::DeviceModified(DeviceInfo { + events.push(GatewayCommand::DeviceModified(DeviceInfo { device: existing_device, network_info: vec![network_info], })); @@ -301,7 +301,7 @@ pub(crate) async fn handle_mapped_devices( location: &WireguardNetwork, conn: &mut PgConnection, mapped_devices: &[MappedDevice], -) -> Result, WireguardNetworkError> { +) -> Result, WireguardNetworkError> { info!("Mapping user devices for network {location}"); // get allowed groups for network let allowed_groups = location.get_allowed_groups(&mut *conn).await?; @@ -366,7 +366,7 @@ pub(crate) async fn handle_mapped_devices( // send device to connected gateways if !network_info.is_empty() { - events.push(GatewayEvent::DeviceCreated(DeviceInfo { + events.push(GatewayCommand::DeviceCreated(DeviceInfo { device, network_info, })); @@ -470,11 +470,11 @@ mod test { assert_eq!(events.len(), 2); assert!(events.iter().any(|e| match e { - GatewayEvent::DeviceCreated(info) => info.device.id == device1.id, + GatewayCommand::DeviceCreated(info) => info.device.id == device1.id, _ => false, })); assert!(events.iter().any(|e| match e { - GatewayEvent::DeviceCreated(info) => info.device.id == device2.id, + GatewayCommand::DeviceCreated(info) => info.device.id == device2.id, _ => false, })); @@ -485,7 +485,7 @@ mod test { assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device3.id); } _ => panic!("Expected DeviceCreated event"), @@ -612,7 +612,7 @@ mod test { .unwrap(); assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device1.id); } _ => panic!("Expected DeviceCreated event"), @@ -623,7 +623,7 @@ mod test { .unwrap(); assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device2.id); } _ => panic!("Expected DeviceCreated event"), @@ -634,7 +634,7 @@ mod test { .unwrap(); assert_eq!(events.len(), 1); match &events[0] { - GatewayEvent::DeviceCreated(info) => { + GatewayCommand::DeviceCreated(info) => { assert_eq!(info.device.id, device3.id); } _ => panic!("Expected DeviceCreated event"), diff --git a/crates/defguard_core/src/user_management.rs b/crates/defguard_core/src/user_management.rs index 1fc5d7cea9..640d6b096a 100644 --- a/crates/defguard_core/src/user_management.rs +++ b/crates/defguard_core/src/user_management.rs @@ -10,7 +10,7 @@ use tokio::sync::broadcast::Sender; use crate::{ enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, error::WebError, - grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + grpc::{GatewayCommand, send_gateway_command, send_multiple_gateway_commands}, location_management::sync_allowed_devices_for_user, }; @@ -18,7 +18,7 @@ use crate::{ pub async fn delete_user_and_cleanup_devices( user: User, conn: &mut PgConnection, - wg_tx: &Sender, + wg_tx: &Sender, ) -> Result<(), WebError> { let username = user.username.clone(); debug!("Deleting user {username}, removing his devices from gateways and updating ldap...",); @@ -33,7 +33,7 @@ pub async fn delete_user_and_cleanup_devices( for network_info in &device_info.network_info { affected_location_ids.insert(network_info.network_id); } - events.push(GatewayEvent::DeviceDeleted(device_info)); + events.push(GatewayCommand::DeviceDeleted(device_info)); } user.delete(&mut *conn).await?; @@ -49,7 +49,7 @@ pub async fn delete_user_and_cleanup_devices( debug!( "Sending firewall config update for location {location} affected by deleting user {username} devices" ); - events.push(GatewayEvent::FirewallConfigChanged( + events.push(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -57,7 +57,7 @@ pub async fn delete_user_and_cleanup_devices( } } - send_multiple_wireguard_events(events, wg_tx); + send_multiple_gateway_commands(events, wg_tx); info!( "The user {} has been deleted and his devices removed from gateways.", &username @@ -69,7 +69,7 @@ pub async fn delete_user_and_cleanup_devices( pub async fn disable_user( user: &mut User, conn: &mut PgConnection, - wg_tx: &Sender, + wg_tx: &Sender, ) -> Result<(), WebError> { user.is_active = false; user.save(&mut *conn).await?; @@ -82,7 +82,7 @@ pub async fn disable_user( pub async fn sync_allowed_user_devices( user: &User, conn: &mut PgConnection, - wg_tx: &Sender, + wg_tx: &Sender, ) -> Result<(), WebError> { debug!("Syncing allowed devices of user {}", user.username); let locations = WireguardNetwork::all(&mut *conn).await?; @@ -93,15 +93,15 @@ pub async fn sync_allowed_user_devices( // check if any peers were updated if !gateway_events.is_empty() { // send peer update events - send_multiple_wireguard_events(gateway_events, wg_tx); + send_multiple_gateway_commands(gateway_events, wg_tx); } // send firewall config update if ACLs & enterprise features are enabled if let Some(firewall_config) = try_get_location_firewall_config(&location, &mut *conn).await? { - send_wireguard_event( - GatewayEvent::FirewallConfigChanged(location.id, firewall_config), + send_gateway_command( + GatewayCommand::FirewallConfigChanged(location.id, firewall_config), wg_tx, ); } diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index 694d8ebedf..3c0088faf2 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -26,7 +26,7 @@ use crate::{ ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, limits::update_counts, }, - grpc::GatewayEvent, + grpc::GatewayCommand, letsencrypt::do_letsencrypt_refresh, location_management::allowed_peers::get_location_allowed_peers, updates::do_new_version_check, @@ -46,7 +46,7 @@ const ACL_EXPIRY_SYSTEM_ACTOR: &str = "system:acl-expiry"; #[instrument(skip_all)] pub async fn run_utility_thread( pool: &PgPool, - wireguard_tx: broadcast::Sender, + wireguard_tx: broadcast::Sender, proxy_control_tx: mpsc::Sender, web_reload_tx: broadcast::Sender<()>, ) -> Result<(), anyhow::Error> { @@ -199,7 +199,7 @@ pub async fn run_utility_thread( /// Check if enterprise status has changed and perform any necessary actions async fn enterprise_status_check( pool: &PgPool, - wireguard_tx: broadcast::Sender, + wireguard_tx: broadcast::Sender, enable_enterprise: bool, ) -> Result<(), anyhow::Error> { // fetch all ACL-enabled networks @@ -221,13 +221,13 @@ async fn enterprise_status_check( // Handle service location update or just update the firewall if location.service_location_mode == ServiceLocationMode::Disabled { - wireguard_tx.send(GatewayEvent::FirewallConfigChanged( + wireguard_tx.send(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, ))?; } else { let new_peers = get_location_allowed_peers(&location, &mut *transaction).await?; - wireguard_tx.send(GatewayEvent::NetworkModified( + wireguard_tx.send(GatewayCommand::NetworkModified( location.id, location, new_peers, @@ -242,13 +242,13 @@ async fn enterprise_status_check( for location in locations { if location.service_location_mode == ServiceLocationMode::Disabled { debug!("Disabling gateway firewall configuration for location {location:?}"); - wireguard_tx.send(GatewayEvent::FirewallDisabled(location.id))?; + wireguard_tx.send(GatewayCommand::FirewallDisabled(location.id))?; } else { debug!( "Disabling gateway firewall configuration and service location client \ connections for location {location}" ); - wireguard_tx.send(GatewayEvent::NetworkModified( + wireguard_tx.send(GatewayCommand::NetworkModified( location.id, location, // Send empty peer list, we are disabling the service location @@ -265,7 +265,7 @@ async fn enterprise_status_check( /// Find newly expired ACL rules and update their status. async fn expired_acl_rules_check( pool: &PgPool, - wireguard_tx: broadcast::Sender, + wireguard_tx: broadcast::Sender, ) -> Result<(), anyhow::Error> { // mark relevant rules as expired let updated_rules = query_as!( @@ -309,7 +309,7 @@ async fn expired_acl_rules_check( match try_get_location_firewall_config(&location, &mut conn).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - wireguard_tx.send(GatewayEvent::FirewallConfigChanged( + wireguard_tx.send(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, ))?; diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 350f010d92..fc62437088 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -30,7 +30,7 @@ use defguard_core::{ db::AppEvent, enterprise::license::{License, LicenseTier, SupportType, set_cached_license}, events::ApiEvent, - grpc::{GatewayEvent, WorkerState}, + grpc::{GatewayCommand, WorkerState}, handlers::{Auth, user::UserDetails}, }; use reqwest::{StatusCode, header::HeaderName}; @@ -60,7 +60,7 @@ pub const X_FORWARDED_URI: HeaderName = HeaderName::from_static("x-forwarded-uri pub(crate) struct ClientState { pub pool: PgPool, pub worker_state: Arc>, - pub wireguard_rx: Receiver, + pub wireguard_rx: Receiver, pub test_user: User, #[allow(dead_code)] pub config: DefGuardConfig, @@ -70,7 +70,7 @@ impl ClientState { pub fn new( pool: PgPool, worker_state: Arc>, - wireguard_rx: Receiver, + wireguard_rx: Receiver, test_user: User, config: DefGuardConfig, ) -> Self { @@ -92,7 +92,7 @@ pub(crate) async fn make_base_client( let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); - let (wg_tx, wg_rx) = broadcast::channel::(16); + let (wg_tx, wg_rx) = broadcast::channel::(16); let failed_logins = FailedLoginMap::new(); let failed_logins = Arc::new(Mutex::new(failed_logins)); diff --git a/crates/defguard_core/tests/integration/api/proxy_certs.rs b/crates/defguard_core/tests/integration/api/proxy_certs.rs index 0473395ad1..aeb461a73c 100644 --- a/crates/defguard_core/tests/integration/api/proxy_certs.rs +++ b/crates/defguard_core/tests/integration/api/proxy_certs.rs @@ -30,7 +30,7 @@ use defguard_core::{ db::AppEvent, enterprise::license::{License, LicenseTier, SupportType, set_cached_license}, events::ApiEvent, - grpc::{GatewayEvent, WorkerState}, + grpc::{GatewayCommand, WorkerState}, handlers::Auth, }; use reqwest::StatusCode; @@ -114,7 +114,7 @@ async fn make_test_client_with_proxy_rx( let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); - let (wg_tx, _wg_rx) = broadcast::channel::(16); + let (wg_tx, _wg_rx) = broadcast::channel::(16); let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index 0ca61f6d88..f0db38bbe3 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -22,7 +22,7 @@ use defguard_core::{ license::{License, LicenseTier, SupportType, get_cached_license, set_cached_license}, limits::update_counts, }, - grpc::{GatewayEvent, proto::enterprise::license::LicenseLimits}, + grpc::{GatewayCommand, proto::enterprise::license::LicenseLimits}, handlers::{Auth, GroupInfo, wireguard::WireguardNetworkData}, }; use ipnetwork::IpNetwork; @@ -55,7 +55,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // check vpn locations for `admin` group let admin_id = Group::find_by_name(&client_state.pool, "admin") @@ -103,7 +103,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { ); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkModified(..)); + assert_matches!(event, GatewayCommand::NetworkModified(..)); // check vpn locations for `admin` group let response = client.get(format!("/api/v1/group/{admin_id}")).send().await; @@ -135,7 +135,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkDeleted(..)); + assert_matches!(event, GatewayCommand::NetworkDeleted(..)); } #[sqlx::test] @@ -516,7 +516,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { // create network make_network(&client, "network").await; let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network details let response = client.get("/api/v1/network/1").send().await; @@ -535,7 +535,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::CREATED); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceCreated(..)); + assert_matches!(event, GatewayCommand::DeviceCreated(..)); // an IP was assigned for new device let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 1) @@ -549,7 +549,10 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { // add another network make_network(&client, "network").await; - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); // an IP was assigned for an existing device let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 1) @@ -595,7 +598,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceModified(..)); + assert_matches!(event, GatewayCommand::DeviceModified(..)); // device details let response = client @@ -637,7 +640,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkDeleted(..)); + assert_matches!(event, GatewayCommand::NetworkDeleted(..)); // delete device let response = client @@ -646,7 +649,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceDeleted(..)); + assert_matches!(event, GatewayCommand::DeviceDeleted(..)); let response = client.get("/api/v1/device").json(&device).send().await; assert_eq!(response.status(), StatusCode::OK); @@ -939,7 +942,7 @@ async fn test_device_pubkey(_: PgPoolOptions, options: PgConnectOptions) { // create network make_network(&client, "network").await; let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network details let response = client.get("/api/v1/network/1").send().await; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs index 85d9400ea7..d9f4d97b0c 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs @@ -16,7 +16,7 @@ use defguard_common::{ }, }; use defguard_core::{ - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{ Auth, wireguard::{ImportedNetworkData, WireguardNetworkData}, @@ -180,7 +180,7 @@ async fn test_create_new_network(_: PgPoolOptions, options: PgConnectOptions) { let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); assert_err!(wg_rx.try_recv()); // network configuration was created only for admin and allowed user @@ -235,7 +235,10 @@ async fn test_create_new_network_allow_all_groups(_: PgPoolOptions, options: PgC .await .unwrap(); assert_eq!(allowed_groups, vec!["allowed group"]); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -287,7 +290,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network configuration was created for admin and the allowed group member let peers = get_location_allowed_peers(&network, &client_state.pool) @@ -321,7 +324,10 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let new_peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -354,7 +360,10 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let new_peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -388,7 +397,10 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let new_peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -437,7 +449,10 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: .await; assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -474,7 +489,10 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: .await .unwrap(); assert_eq!(allowed_groups, vec!["allowed group"]); - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkModified(..) + ); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -554,10 +572,10 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network config was only created for one of the existing devices and the admin device - let GatewayEvent::DeviceModified(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceModified(device_info) = wg_rx.try_recv().unwrap() else { panic!() }; assert_eq!(device_info.device.id, devices[1].id); @@ -568,7 +586,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne peers[1].allowed_ips[0] ); - let GatewayEvent::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { panic!() }; assert_eq!(device_info.device.id, devices[0].id); @@ -662,7 +680,7 @@ PersistentKeepalive = 300 assert_eq!(peers[3].pubkey, mapped_devices[1].wireguard_pubkey); // assert events - let GatewayEvent::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { panic!() }; assert_eq!( @@ -676,7 +694,7 @@ PersistentKeepalive = 300 mapped_devices[0].wireguard_ips, ); - let GatewayEvent::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { panic!() }; assert_eq!( @@ -734,7 +752,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); assert_err!(wg_rx.try_recv()); // network configuration was created only for admin and allowed user @@ -756,7 +774,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::OK); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceDeleted(..)); + assert_matches!(event, GatewayCommand::DeviceDeleted(..)); assert_err!(wg_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) @@ -794,7 +812,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::OK); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceCreated(..)); + assert_matches!(event, GatewayCommand::DeviceCreated(..)); assert_err!(wg_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) @@ -845,7 +863,10 @@ async fn test_modify_user_no_effect_when_allow_all_groups( .await; assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkCreated(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::NetworkCreated(..) + ); assert_err!(wg_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) @@ -958,7 +979,7 @@ async fn test_delete_only_allowed_group(_: PgPoolOptions, options: PgConnectOpti let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); let peers = get_location_allowed_peers(&network, &client_state.pool) .await diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs index 071847c0f7..e7c71866ff 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs @@ -5,7 +5,7 @@ use defguard_common::db::{ models::{Device, WireguardNetwork}, }; use defguard_core::{ - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{Auth, network_devices::AddNetworkDevice}, }; use ipnetwork::IpNetwork; @@ -109,12 +109,12 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let network_1: WireguardNetwork = response.json().await; assert_eq!(network_1.name, "network"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); let response = make_second_network(&client).await; let network_2: WireguardNetwork = response.json().await; assert_eq!(network_2.name, "network-2"); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // ip suggestions let response = client.get("/api/v1/device/network/ip/1").send().await; @@ -202,7 +202,7 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let config_text = json["config"]["config"].as_str().unwrap(); assert!(configured); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceCreated(..)); + assert_matches!(event, GatewayCommand::DeviceCreated(..)); // download WG config let response = client.get("/api/v1/device/network/1/config").send().await; @@ -239,7 +239,7 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(device.name, "device-1"); assert_eq!(device.description, Some("new description".to_owned())); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::DeviceModified(..)); + assert_matches!(event, GatewayCommand::DeviceModified(..)); // Make sure the device is only in the selected network let device_networks = diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index d176b38854..dde6414091 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -6,7 +6,7 @@ use defguard_common::db::models::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }; use defguard_core::{ - grpc::GatewayEvent, + grpc::GatewayCommand, handlers::{Auth, wireguard::ImportedNetworkData}, }; use matches::assert_matches; @@ -138,11 +138,14 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(network.allowed_ips, vec!["10.0.0.0/24".parse().unwrap()]); assert_eq!(network.connected_at, None); let event = wg_rx.try_recv().unwrap(); - assert_matches!(event, GatewayEvent::NetworkCreated(..)); + assert_matches!(event, GatewayCommand::NetworkCreated(..)); // existing devices assertion // imported config for an existing device - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::DeviceModified(..)); + assert_matches!( + wg_rx.try_recv().unwrap(), + GatewayCommand::DeviceModified(..) + ); let user_device_1 = UserDevice::from_device(&pool, device_1) .await .unwrap() @@ -153,7 +156,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { vec!["10.0.0.12"] ); // generated IP for other existing device - assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::DeviceCreated(..)); + assert_matches!(wg_rx.try_recv().unwrap(), GatewayCommand::DeviceCreated(..)); let user_device_2 = UserDevice::from_device(&pool, device_2) .await .unwrap() @@ -205,7 +208,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { // assert events let event = wg_rx.try_recv().unwrap(); match event { - GatewayEvent::DeviceCreated(device_info) => { + GatewayCommand::DeviceCreated(device_info) => { assert_eq!(device_info.device.name, "device_1"); } _ => unreachable!("Invalid event type received"), @@ -213,7 +216,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { let event = wg_rx.try_recv().unwrap(); match event { - GatewayEvent::DeviceCreated(device_info) => { + GatewayCommand::DeviceCreated(device_info) => { assert_eq!(device_info.device.name, "device_2"); } _ => unreachable!("Invalid event type received"), diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index cc4e74a66b..069f9f07e7 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -114,7 +114,7 @@ mod tests { wireguard::{LocationMfaMode, ServiceLocationMode}, }, }; - use defguard_common::gateway_event::GatewayEvent; + use defguard_common::gateway_event::GatewayCommand; use defguard_core::events::{BidiRequestContext, BidiStreamEventType}; use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; @@ -161,7 +161,7 @@ mod tests { let (_bidi_tx, bidi_rx) = unbounded_channel(); let (_session_manager_tx, session_manager_rx) = unbounded_channel(); let (event_logger_tx, event_logger_rx) = unbounded_channel(); - let (wireguard_tx, _wireguard_rx) = broadcast::channel::(1); + let (wireguard_tx, _wireguard_rx) = broadcast::channel::(1); ( EventRouter::new( diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index ffbc495253..d8ad3fb79a 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_core::events::{ApiEvent, BidiStreamEvent}; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use defguard_session_manager::events::SessionManagerEvent; @@ -61,7 +61,7 @@ impl RouterReceiverSet { struct EventRouter { receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + wireguard_tx: Sender, activity_log_stream_reload_notify: Arc, } @@ -87,7 +87,7 @@ impl EventRouter { fn new( receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + wireguard_tx: Sender, activity_log_stream_reload_notify: Arc, ) -> Self { Self { @@ -138,7 +138,7 @@ impl EventRouter { pub async fn run_event_router( receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + wireguard_tx: Sender, activity_log_stream_reload_notify: Arc, ) -> Result<(), EventRouterError> { info!("Starting main event router service"); diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index c0a7232241..32a6c23030 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -12,7 +12,7 @@ use std::{ }; use chrono::{DateTime, TimeDelta}; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ VERSION, db::{ @@ -88,7 +88,7 @@ pub(crate) struct GatewayHandler { gateway: Gateway, message_id: AtomicU64, pool: PgPool, - events_tx: Sender, + events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, updates_handler_handle: Option>, @@ -102,7 +102,7 @@ impl GatewayHandler { pub fn new( gateway: Gateway, pool: PgPool, - events_tx: Sender, + events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, ) -> Result { @@ -564,7 +564,7 @@ impl GatewayHandler { pub(crate) fn new_with_test_socket( gateway: Gateway, pool: PgPool, - events_tx: Sender, + events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, socket_path: PathBuf, @@ -620,7 +620,7 @@ struct GatewayUpdatesHandler { network_id: Id, network: WireguardNetwork, gateway_name: String, - events_rx: broadcast::Receiver, + events_rx: broadcast::Receiver, tx: UnboundedSender, } @@ -630,7 +630,7 @@ impl GatewayUpdatesHandler { network_id: Id, network: WireguardNetwork, gateway_name: String, - events_rx: broadcast::Receiver, + events_rx: broadcast::Receiver, tx: UnboundedSender, ) -> Self { Self { @@ -722,7 +722,7 @@ impl GatewayUpdatesHandler { while let Ok(update) = self.events_rx.recv().await { debug!("Received WireGuard update: {update:?}"); let result = match update { - GatewayEvent::NetworkCreated(network_id, network) => { + GatewayCommand::NetworkCreated(network_id, network) => { if network_id == self.network_id { self.send_network_update( &network, @@ -734,7 +734,7 @@ impl GatewayUpdatesHandler { Ok(()) } } - GatewayEvent::NetworkModified( + GatewayCommand::NetworkModified( network_id, network, peers, @@ -754,14 +754,14 @@ impl GatewayUpdatesHandler { Ok(()) } } - GatewayEvent::NetworkDeleted(network_id, network_name) => { + GatewayCommand::NetworkDeleted(network_id, network_name) => { if network_id == self.network_id { self.send_network_delete(&network_name) } else { Ok(()) } } - GatewayEvent::DeviceCreated(device) => { + GatewayCommand::DeviceCreated(device) => { // check if a peer has to be added in the current network match device .network_info @@ -777,7 +777,7 @@ impl GatewayUpdatesHandler { None => Ok(()), } } - GatewayEvent::DeviceModified(device) => { + GatewayCommand::DeviceModified(device) => { // check if a peer has to be updated in the current network match device .network_info @@ -793,7 +793,7 @@ impl GatewayUpdatesHandler { None => Ok(()), } } - GatewayEvent::DeviceDeleted(device) => { + GatewayCommand::DeviceDeleted(device) => { // check if a peer has to be updated in the current network match device .network_info @@ -804,28 +804,28 @@ impl GatewayUpdatesHandler { None => Ok(()), } } - GatewayEvent::FirewallConfigChanged(location_id, firewall_config) => { + GatewayCommand::FirewallConfigChanged(location_id, firewall_config) => { if location_id == self.network_id { self.send_firewall_update(firewall_config) } else { Ok(()) } } - GatewayEvent::FirewallDisabled(location_id) => { + GatewayCommand::FirewallDisabled(location_id) => { if location_id == self.network_id { self.send_firewall_disable() } else { Ok(()) } } - GatewayEvent::MfaSessionDisconnected(location_id, device) => { + GatewayCommand::MfaSessionDisconnected(location_id, device) => { if location_id == self.network_id { self.send_peer_delete(&device.wireguard_pubkey) } else { Ok(()) } } - GatewayEvent::MfaSessionAuthorized(location_id, device, network_info) => { + GatewayCommand::MfaSessionAuthorized(location_id, device, network_info) => { if location_id == self.network_id { if network_info.network_id != location_id { error!( @@ -1077,7 +1077,7 @@ mod tests { }, setup_pool, }; - use defguard_common::gateway_event::GatewayEvent; + use defguard_common::gateway_event::GatewayCommand; use defguard_common::gateway_types::{FirewallConfig, FirewallPolicy, WireguardPeer}; use defguard_proto::{ enterprise::firewall::FirewallPolicy as ProtoFirewallPolicy, @@ -1405,7 +1405,7 @@ mod tests { .save(&pool) .await .unwrap(); - let (events_tx, _events_rx) = broadcast::channel::(1); + let (events_tx, _events_rx) = broadcast::channel::(1); let (peer_stats_tx, _peer_stats_rx) = unbounded_channel(); let (_certs_tx, certs_rx) = watch::channel(Arc::new(HashMap::::new())); let handler = diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 323431d41b..d590549617 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -9,7 +9,7 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, }; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, messages::peer_stats_update::PeerStatsUpdate, @@ -656,14 +656,14 @@ mod unit_tests { /// events, notifications, and side effects to Core components. #[derive(Clone)] pub struct GatewayTxSet { - events: Sender, + events: Sender, peer_stats: UnboundedSender, } impl GatewayTxSet { #[must_use] pub const fn new( - events: Sender, + events: Sender, peer_stats: UnboundedSender, ) -> Self { Self { events, peer_stats } diff --git a/crates/defguard_gateway_manager/src/tests/common/mod.rs b/crates/defguard_gateway_manager/src/tests/common/mod.rs index 6290016be0..6ce372d4aa 100644 --- a/crates/defguard_gateway_manager/src/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/src/tests/common/mod.rs @@ -11,7 +11,7 @@ use std::{ time::Duration, }; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{ Id, NoId, @@ -476,7 +476,7 @@ pub(crate) struct HandlerTestContext { pub(crate) network: WireguardNetwork, pub(crate) gateway: Gateway, pub(crate) peer_stats_rx: UnboundedReceiver, - events_tx: Option>, + events_tx: Option>, pub(crate) mock_gateway: Option, handler_task: Option>>, } @@ -489,7 +489,7 @@ impl HandlerTestContext { pub(crate) async fn new_with_events_tx( options: PgConnectOptions, - events_tx: broadcast::Sender, + events_tx: broadcast::Sender, ) -> Self { let pool = setup_pool(options).await; initialize_current_settings(&pool) @@ -524,7 +524,7 @@ impl HandlerTestContext { } } - pub(crate) fn events_tx(&self) -> &broadcast::Sender { + pub(crate) fn events_tx(&self) -> &broadcast::Sender { self.events_tx .as_ref() .expect("events sender already taken from context") diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs index afe1ebd7a8..fd66630cff 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -2,7 +2,7 @@ mod support; use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_proto::gateway::{UpdateType, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs index a15b8cbd14..6008540cf8 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs @@ -18,7 +18,7 @@ async fn test_device_created_for_network_produces_peer_create_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::DeviceCreated(device_info)), + .send(GatewayCommand::DeviceCreated(device_info)), "failed to broadcast created device event" ); @@ -46,7 +46,7 @@ async fn test_device_created_before_config_handshake_is_ignored( "created-before-config-device", "tND8hJQhYnI8naBTo59He43zYldagfjlwmSxWEc01Cc=", "10.10.0.11", - GatewayEvent::DeviceCreated, + GatewayCommand::DeviceCreated, ) .await; } @@ -69,10 +69,11 @@ async fn test_device_modified_for_network_produces_peer_modify_update( ) .await; - let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) - .await - .expect("failed to load device network info") - .expect("expected device network info for modified device"); + let mut network_device = + WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) + .await + .expect("failed to load device network info") + .expect("expected device network info for modified device"); network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; network_device .update(&context.pool) @@ -86,7 +87,7 @@ async fn test_device_modified_for_network_produces_peer_modify_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::DeviceModified(device_info)), + .send(GatewayCommand::DeviceModified(device_info)), "failed to broadcast modified device event" ); @@ -114,7 +115,7 @@ async fn test_device_modified_before_config_handshake_is_ignored( "modified-before-config-device", "wyFOHCec/Fi9s+cARikVO71JhyYtYMk0FrQx3fK2PTM=", "10.10.0.22", - GatewayEvent::DeviceModified, + GatewayCommand::DeviceModified, ) .await; } @@ -138,7 +139,7 @@ async fn test_device_deleted_for_network_produces_peer_delete_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::DeviceDeleted(device_info)), + .send(GatewayCommand::DeviceDeleted(device_info)), "failed to broadcast deleted device event" ); @@ -166,7 +167,7 @@ async fn test_device_deleted_before_config_handshake_is_ignored( "deleted-before-config-device", "m84QJmDMkqdCj8AB2NTE8F55W7M/i3CaaD3eQbQdInY=", "10.10.0.31", - GatewayEvent::DeviceDeleted, + GatewayCommand::DeviceDeleted, ) .await; } @@ -181,7 +182,7 @@ async fn test_device_created_for_different_network_is_ignored( "created-other-network-device", "W6wBmd8wgTwvCyGqDRXk6Hf4OMqDUbUn2XWKnG5wVVQ=", "10.11.0.10", - GatewayEvent::DeviceCreated, + GatewayCommand::DeviceCreated, ) .await; } @@ -196,7 +197,7 @@ async fn test_device_modified_for_different_network_is_ignored( "modified-other-network-device", "yjuzq0cLk3Ww5oQcqK6YkSKwXnqQ1V9OlSMFAEkr0lU=", "10.11.0.20", - GatewayEvent::DeviceModified, + GatewayCommand::DeviceModified, ) .await; } @@ -211,7 +212,7 @@ async fn test_device_deleted_for_different_network_is_ignored( "deleted-other-network-device", "Jtp+K8xnFXuF4cae+tVGZNwoSM2fXjJbRl3sI6rdcAQ=", "10.11.0.30", - GatewayEvent::DeviceDeleted, + GatewayCommand::DeviceDeleted, ) .await; } diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs index 1545aab35a..5086f9dda3 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs @@ -11,7 +11,7 @@ async fn test_matching_location_firewall_config_changed_event_produces_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::FirewallConfigChanged( + .send(GatewayCommand::FirewallConfigChanged( context.network.id, expected_firewall_config.clone(), )), @@ -37,7 +37,7 @@ async fn test_matching_location_firewall_disabled_event_produces_disable_update( assert_send_ok!( context .events_tx() - .send(GatewayEvent::FirewallDisabled(context.network.id)), + .send(GatewayCommand::FirewallDisabled(context.network.id)), "failed to broadcast firewall disabled event" ); @@ -56,7 +56,7 @@ async fn test_different_location_firewall_config_changed_event_is_ignored( let expected_firewall_config = build_test_firewall_config(); assert_firewall_event_for_different_network_is_ignored(options, move |other_network_id| { - GatewayEvent::FirewallConfigChanged(other_network_id, expected_firewall_config) + GatewayCommand::FirewallConfigChanged(other_network_id, expected_firewall_config) }) .await; } @@ -67,7 +67,7 @@ async fn test_different_location_firewall_disabled_event_is_ignored( options: PgConnectOptions, ) { assert_firewall_event_for_different_network_is_ignored(options, |other_network_id| { - GatewayEvent::FirewallDisabled(other_network_id) + GatewayCommand::FirewallDisabled(other_network_id) }) .await; } diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs index 30b6da8895..78c0ce8fbc 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs @@ -18,7 +18,7 @@ async fn test_matching_location_mfa_session_authorized_produces_peer_create( .await; assert_send_ok!( - context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.events_tx().send(GatewayCommand::MfaSessionAuthorized( context.network.id, device, network_device, @@ -64,7 +64,7 @@ async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( .await; assert_send_ok!( - context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.events_tx().send(GatewayCommand::MfaSessionAuthorized( context.network.id, device, network_device, @@ -98,7 +98,7 @@ async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( assert_send_ok!( context .events_tx() - .send(GatewayEvent::MfaSessionDisconnected( + .send(GatewayCommand::MfaSessionDisconnected( context.network.id, device, )), diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs index e7aeddcdbe..1763191513 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs @@ -8,7 +8,7 @@ async fn test_matching_location_network_deleted_event_produces_delete_update( let _ = context.complete_config_handshake().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( + context.events_tx().send(GatewayCommand::NetworkDeleted( context.network.id, context.network.name.clone(), )), @@ -45,7 +45,7 @@ async fn test_matching_location_network_modified_event_produces_modify_update( modified_network.fwmark = 42; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkModified( + context.events_tx().send(GatewayCommand::NetworkModified( context.network.id, modified_network, Vec::new(), @@ -92,7 +92,7 @@ async fn test_matching_location_network_created_event_produces_create_update( created_network.fwmark = 17; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkCreated( + context.events_tx().send(GatewayCommand::NetworkCreated( context.network.id, created_network, )), @@ -145,7 +145,7 @@ async fn test_only_matching_handler_receives_network_modified_update( assert_send_ok!( matching_context .events_tx() - .send(GatewayEvent::NetworkModified( + .send(GatewayCommand::NetworkModified( matching_context.network.id, modified_network, Vec::new(), @@ -189,7 +189,7 @@ async fn test_different_location_network_created_event_is_ignored( let _ = context.complete_config_handshake().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkCreated( + context.events_tx().send(GatewayCommand::NetworkCreated( other_network.id, other_network, )), @@ -212,7 +212,7 @@ async fn test_different_location_network_deleted_event_is_ignored( let _ = context.complete_config_handshake().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( + context.events_tx().send(GatewayCommand::NetworkDeleted( other_network.id, other_network.name.clone(), )), @@ -222,7 +222,7 @@ async fn test_different_location_network_deleted_event_is_ignored( context.mock_gateway_mut().expect_no_outbound().await; assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( + context.events_tx().send(GatewayCommand::NetworkDeleted( context.network.id, context.network.name.clone(), )), diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index 7f1e61b5cc..bff29343eb 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -9,7 +9,7 @@ use defguard_common::db::{ wireguard::{LocationMfaMode, WireguardNetwork}, }, }; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::gateway_types::{ FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, SnatBinding, }; @@ -190,7 +190,7 @@ pub(crate) async fn assert_device_event_is_ignored_before_config_handshake( device_name: &str, device_pubkey: &str, device_ip: &str, - build_event: fn(DeviceInfo) -> GatewayEvent, + build_event: fn(DeviceInfo) -> GatewayCommand, ) { let mut context = HandlerTestContext::new(options).await; assert_eq!(context.events_tx().receiver_count(), 0); @@ -215,7 +215,7 @@ pub(crate) async fn assert_device_event_for_different_network_is_ignored( device_name: &str, device_pubkey: &str, device_ip: &str, - build_event: fn(DeviceInfo) -> GatewayEvent, + build_event: fn(DeviceInfo) -> GatewayCommand, ) { let mut context = HandlerTestContext::new(options).await; let other_network = context.create_other_network().await; @@ -243,7 +243,7 @@ pub(crate) async fn assert_device_event_for_different_network_is_ignored( pub(crate) async fn assert_firewall_event_for_different_network_is_ignored( options: PgConnectOptions, - build_event: impl FnOnce(Id) -> GatewayEvent, + build_event: impl FnOnce(Id) -> GatewayCommand, ) { let mut context = HandlerTestContext::new(options).await; let other_network = context.create_other_network().await; diff --git a/crates/defguard_proxy_manager/src/handler.rs b/crates/defguard_proxy_manager/src/handler.rs index 4cba396a37..15d8aa9fe5 100644 --- a/crates/defguard_proxy_manager/src/handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -31,7 +31,7 @@ use defguard_core::{ ldap::utils::ldap_update_user_state, }, grpc::{ - GatewayEvent, + GatewayCommand, proxy::client_mfa::{ClientLoginSession, ClientMfaServer, PostureCheckOutcome}, }, version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, @@ -471,7 +471,7 @@ impl ProxyHandler { async fn message_loop( &mut self, tx: UnboundedSender, - wireguard_tx: Sender, + wireguard_tx: Sender, resp_stream: &mut Streaming, ) -> Result<(), ProxyError> { let pool = self.pool.clone(); diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index 839b722a4f..394517de68 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -7,7 +7,7 @@ use std::{ use std::{path::PathBuf, str::FromStr, sync::Mutex as StdMutex}; use axum_extra::extract::cookie::Key; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{Id, models::proxy::Proxy}, types::proxy::ProxyControlMessage, @@ -401,14 +401,14 @@ impl ProxyManager { /// events, notifications, and side effects to Core components. #[derive(Clone)] pub struct ProxyTxSet { - wireguard: Sender, + wireguard: Sender, bidi_events: UnboundedSender, } impl ProxyTxSet { #[must_use] pub const fn new( - wireguard: Sender, + wireguard: Sender, bidi_events: UnboundedSender, ) -> Self { Self { diff --git a/crates/defguard_proxy_manager/src/servers/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs index 792ed90d50..0a1cb588ef 100644 --- a/crates/defguard_proxy_manager/src/servers/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -21,7 +21,7 @@ use defguard_core::{ }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, grpc::{ - GatewayEvent, InstanceInfo, + GatewayCommand, InstanceInfo, client_version::ClientFeature, utils::{build_device_config_response, parse_client_ip_agent}, }, @@ -48,7 +48,7 @@ use tonic::Status; pub(crate) struct EnrollmentServer { pool: PgPool, - wireguard_tx: Sender, + wireguard_tx: Sender, bidi_event_tx: UnboundedSender, } @@ -56,7 +56,7 @@ impl EnrollmentServer { #[must_use] pub(crate) fn new( pool: PgPool, - wireguard_tx: Sender, + wireguard_tx: Sender, bidi_event_tx: UnboundedSender, ) -> Self { Self { @@ -96,10 +96,10 @@ impl EnrollmentServer { } } - /// Sends given `GatewayEvent` to be handled by gateway GRPC server - pub(crate) fn send_wireguard_event(&self, event: GatewayEvent) { + /// Sends given `GatewayCommand` to be handled by gateway manager service + pub(crate) fn send_gateway_command(&self, event: GatewayCommand) { if let Err(err) = self.wireguard_tx.send(event) { - error!("Error sending WireGuard event {err}"); + error!("Error sending Gateway command: {err}"); } } @@ -785,7 +785,7 @@ impl EnrollmentServer { adding new device {}, user {}({})", device.wireguard_pubkey, user.username, user.id ); - self.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + self.send_gateway_command(GatewayCommand::FirewallConfigChanged( location_id, firewall_config, )); @@ -797,7 +797,7 @@ impl EnrollmentServer { "Sending DeviceCreated event to gateway for device {}, user {}({:?})", device.wireguard_pubkey, user.username, user.id, ); - self.send_wireguard_event(GatewayEvent::DeviceCreated(DeviceInfo { + self.send_gateway_command(GatewayCommand::DeviceCreated(DeviceInfo { device: device.clone(), network_info, })); diff --git a/crates/defguard_proxy_manager/src/tests/common/mod.rs b/crates/defguard_proxy_manager/src/tests/common/mod.rs index 5130bf31fa..81527c7fc3 100644 --- a/crates/defguard_proxy_manager/src/tests/common/mod.rs +++ b/crates/defguard_proxy_manager/src/tests/common/mod.rs @@ -24,7 +24,7 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_core::events::BidiStreamEvent; use defguard_proto::proxy::{ AcmeChallenge, AcmeIssueEvent, CoreRequest, CoreResponse, InitialInfo, core_response, @@ -382,7 +382,7 @@ impl Drop for MockProxyHarness { pub(crate) struct HandlerTestContext { pub(crate) pool: PgPool, pub(crate) proxy: Proxy, - pub(crate) wireguard_tx: broadcast::Sender, + pub(crate) wireguard_tx: broadcast::Sender, pub(crate) bidi_events_rx: UnboundedReceiver, pub(crate) mock_proxy: Option, handler_task: Option>>, diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs index e602bef255..d7cc39feb5 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs @@ -2,7 +2,7 @@ use defguard_common::db::models::{ Device, Settings, User, biometric_auth::BiometricAuth, polling_token::PollingToken, settings::update_current_settings, }; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_core::events::{BidiStreamEventType, EnrollmentEvent}; use defguard_proto::{ client_types::{ExistingDevice, MfaMethod, NewDevice, RegisterMobileAuthRequest}, @@ -283,7 +283,7 @@ async fn test_new_device_sends_gateway_device_created_event( .expect("gateway event channel closed"); assert!( - matches!(event, GatewayEvent::DeviceCreated(_)), + matches!(event, GatewayCommand::DeviceCreated(_)), "expected DeviceCreated gateway event, got: {event:?}" ); diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs index 330566c45a..e72ba11ae3 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs @@ -1,5 +1,5 @@ use defguard_common::db::Id; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_proto::{ client_types::{ClientMfaFinishRequest, ClientMfaStartRequest, MfaMethod}, proxy::{AwaitRemoteMfaFinishRequest, CoreRequest, core_request, core_response}, @@ -145,7 +145,7 @@ async fn test_mfa_finish_succeeds_with_totp_code(_: PgPoolOptions, options: PgCo .expect("timed out waiting for GatewayEvent::MfaSessionAuthorized") .expect("gateway event channel closed"); let gateway_loc_id = match event { - GatewayEvent::MfaSessionAuthorized(loc_id, _, _) => loc_id, + GatewayCommand::MfaSessionAuthorized(loc_id, _, _) => loc_id, other => panic!("expected MfaSessionAuthorized, got: {other:?}"), }; assert_eq!(gateway_loc_id, network.id); @@ -322,7 +322,7 @@ async fn test_mfa_finish_succeeds_and_creates_session(_: PgPoolOptions, options: .expect("timed out waiting for GatewayEvent::MfaSessionAuthorized") .expect("gateway event channel closed"); let loc_id = match event { - GatewayEvent::MfaSessionAuthorized(loc_id, _, _) => loc_id, + GatewayCommand::MfaSessionAuthorized(loc_id, _, _) => loc_id, other => panic!("expected MfaSessionAuthorized, got: {other:?}"), }; assert_eq!(loc_id, network.id); @@ -587,12 +587,12 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( .expect("gateway event channel closed"); match event { - GatewayEvent::MfaSessionDisconnected(loc_id, ref dev) => { + GatewayCommand::MfaSessionDisconnected(loc_id, ref dev) => { assert_eq!(loc_id, network.id, "disconnected session location mismatch"); assert_eq!(dev.id, device.id, "disconnected session device mismatch"); got_disconnected = true; } - GatewayEvent::MfaSessionAuthorized(loc_id, _, _) => { + GatewayCommand::MfaSessionAuthorized(loc_id, _, _) => { assert_eq!(loc_id, network.id, "authorized session location mismatch"); got_authorized = true; } diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 2a41c8e699..76aeb2f70e 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,5 +1,5 @@ use defguard_common::db::Id; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use thiserror::Error; use tokio::sync::{broadcast::error::SendError as BroadcastSendError, mpsc::error::SendError}; @@ -36,7 +36,7 @@ pub enum SessionManagerError { #[error("Failed to send session manager event: {0}")] SessionManagerEventError(Box>), #[error("Failed to send gateway manager event: {0}")] - GatewayManagerEventError(Box>), + GatewayManagerEventError(Box>), } impl From> for SessionManagerError { @@ -44,8 +44,8 @@ impl From> for SessionManagerError { Self::SessionManagerEventError(Box::new(error)) } } -impl From> for SessionManagerError { - fn from(error: BroadcastSendError) -> Self { +impl From> for SessionManagerError { + fn from(error: BroadcastSendError) -> Self { Self::GatewayManagerEventError(Box::new(error)) } } diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index c1a4ee6632..a6c8c79522 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -1,5 +1,5 @@ use chrono::Utc; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{ Id, @@ -42,7 +42,7 @@ pub async fn run_session_manager( pool: PgPool, mut peer_stats_rx: UnboundedReceiver, session_manager_event_tx: UnboundedSender, - gateway_tx: Sender, + gateway_tx: Sender, ) -> Result<(), SessionManagerError> { info!("Starting VPN client session manager service"); let mut session_update_timer = interval(SESSION_UPDATE_INTERVAL); @@ -103,7 +103,7 @@ pub async fn run_session_manager_iteration( pub struct SessionManager { pool: PgPool, session_manager_event_tx: UnboundedSender, - gateway_tx: Sender, + gateway_tx: Sender, } impl SessionManager { @@ -111,7 +111,7 @@ impl SessionManager { pub fn new( pool: PgPool, session_manager_event_tx: UnboundedSender, - gateway_tx: Sender, + gateway_tx: Sender, ) -> Self { Self { pool, @@ -323,7 +323,7 @@ impl SessionManager { debug!( "Sending MFA session disconnect event for device {device} in location {location} to gateway manager" ); - let event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + let event = GatewayCommand::MfaSessionDisconnected(location.id, device.clone()); self.gateway_tx.send(event)?; Ok(()) } diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 63dc91d0ea..db3fb1871e 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -16,7 +16,7 @@ use defguard_common::{ wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, - gateway_event::GatewayEvent, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_session_manager::{ @@ -38,7 +38,7 @@ pub(crate) struct SessionManagerHarness { stats_tx: mpsc::UnboundedSender, pub(crate) stats_rx: mpsc::UnboundedReceiver, pub(crate) event_rx: mpsc::UnboundedReceiver, - pub(crate) gateway_rx: broadcast::Receiver, + pub(crate) gateway_rx: broadcast::Receiver, } pub(crate) fn assert_no_session_manager_events(harness: &mut SessionManagerHarness) { diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index f7c828559a..942588dba0 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -9,7 +9,7 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; @@ -503,7 +503,7 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization .expect("timed out waiting for MFA disconnect gateway event") .expect("gateway event channel closed"); match gateway_event { - GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } @@ -570,7 +570,7 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( .expect("timed out waiting for MFA disconnect gateway event for new session") .expect("gateway event channel closed"); match gateway_event { - GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } diff --git a/crates/defguard_setup/src/migration.rs b/crates/defguard_setup/src/migration.rs index fb6fa87d6a..9d2d54d012 100644 --- a/crates/defguard_setup/src/migration.rs +++ b/crates/defguard_setup/src/migration.rs @@ -10,7 +10,7 @@ use axum::{ serve, }; use axum_extra::extract::cookie::Key; -use defguard_common::gateway_event::GatewayEvent; +use defguard_common::gateway_event::GatewayCommand; use defguard_common::{VERSION, db::models::Settings, types::proxy::ProxyControlMessage}; use defguard_core::{ appstate::AppState, @@ -60,7 +60,7 @@ use crate::handlers::{ pub struct MigrationWebapp { pub router: Router, _event_rx: mpsc::UnboundedReceiver, - _wireguard_rx: broadcast::Receiver, + _wireguard_rx: broadcast::Receiver, _proxy_control_rx: mpsc::Receiver, } @@ -72,7 +72,7 @@ pub fn build_migration_webapp( let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); let (webhook_tx, webhook_rx) = mpsc::unbounded_channel::(); let (event_tx, event_rx) = mpsc::unbounded_channel::(); - let (wireguard_tx, wireguard_rx) = broadcast::channel::(64); + let (wireguard_tx, wireguard_rx) = broadcast::channel::(64); let (web_reload_tx, _web_reload_rx) = broadcast::channel::<()>(8); let (proxy_control_tx, proxy_control_rx) = mpsc::channel(32); let incompatible_components = Arc::new(RwLock::new(IncompatibleComponents::default())); From 89c33eea283ca861da3e97ac6eb31fbac661772c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 19 May 2026 09:13:07 +0200 Subject: [PATCH 04/10] rename outstanding references --- crates/defguard_common/src/gateway_event.rs | 8 +- crates/defguard_core/src/appstate.rs | 10 +- .../src/enterprise/db/models/acl.rs | 4 +- .../src/enterprise/directory_sync/mod.rs | 34 +++--- .../src/enterprise/directory_sync/tests.rs | 100 +++++++++--------- .../src/enterprise/handlers/openid_login.rs | 2 +- .../src/grpc/proxy/client_mfa.rs | 32 +++--- crates/defguard_core/src/handlers/group.rs | 12 +-- crates/defguard_core/src/handlers/user.rs | 4 +- .../defguard_core/src/handlers/wireguard.rs | 4 +- crates/defguard_core/src/lib.rs | 8 +- .../src/location_management/mod.rs | 8 +- crates/defguard_core/src/user_management.rs | 14 +-- crates/defguard_core/src/utility_thread.rs | 24 ++--- .../tests/integration/api/common/mod.rs | 12 +-- .../tests/integration/api/proxy_certs.rs | 4 +- .../tests/integration/api/wireguard.rs | 26 ++--- .../api/wireguard_network_allowed_groups.rs | 76 ++++++------- .../api/wireguard_network_devices.rs | 10 +- .../api/wireguard_network_import.rs | 14 +-- .../src/handlers/bidi.rs | 4 +- crates/defguard_event_router/src/lib.rs | 12 +-- .../defguard_gateway_manager/src/handler.rs | 4 +- crates/defguard_proxy_manager/src/handler.rs | 4 +- .../src/servers/enrollment.rs | 12 +-- .../src/tests/common/mod.rs | 12 +-- .../tests/proxy_manager/handler/enrollment.rs | 10 +- .../src/tests/proxy_manager/handler/mfa.rs | 40 +++---- .../src/tests/proxy_manager/handler/oidc.rs | 2 +- crates/defguard_setup/src/migration.rs | 8 +- 30 files changed, 257 insertions(+), 257 deletions(-) diff --git a/crates/defguard_common/src/gateway_event.rs b/crates/defguard_common/src/gateway_event.rs index c429fef259..55d7c277de 100644 --- a/crates/defguard_common/src/gateway_event.rs +++ b/crates/defguard_common/src/gateway_event.rs @@ -34,9 +34,9 @@ pub enum GatewayCommand { /// Sends a [`GatewayCommand`] to the gateway manager service. /// /// In API handler context prefer `AppState::send_gateway_command`. -pub fn send_gateway_command(event: GatewayCommand, wg_tx: &Sender) { +pub fn send_gateway_command(event: GatewayCommand, gateway_tx: &Sender) { debug!("Sending the following command to Gateway Manager: {event:?}"); - if let Err(err) = wg_tx.send(event) { + if let Err(err) = gateway_tx.send(event) { error!("Error sending Gateway command: {err}"); } } @@ -44,9 +44,9 @@ pub fn send_gateway_command(event: GatewayCommand, wg_tx: &Sender, wg_tx: &Sender) { +pub fn send_multiple_gateway_commands(events: Vec, gateway_tx: &Sender) { debug!("Sending {} gateway commands", events.len()); for event in events { - send_gateway_command(event, wg_tx); + send_gateway_command(event, gateway_tx); } } diff --git a/crates/defguard_core/src/appstate.rs b/crates/defguard_core/src/appstate.rs index 7566e0e1c2..201bd14a9d 100644 --- a/crates/defguard_core/src/appstate.rs +++ b/crates/defguard_core/src/appstate.rs @@ -29,7 +29,7 @@ const X_DEFGUARD_EVENT: &str = "x-defguard-event"; pub struct AppState { pub pool: PgPool, tx: UnboundedSender, - pub wireguard_tx: Sender, + pub gateway_tx: Sender, pub web_reload_tx: tokio::sync::broadcast::Sender<()>, pub failed_logins: Arc>, key: Key, @@ -89,13 +89,13 @@ impl AppState { /// Sends given `GatewayCommand` to be handled by gateway manager service. /// Convenience wrapper around [`send_gateway_command`] pub fn send_gateway_command(&self, event: GatewayCommand) { - send_gateway_command(event, &self.wireguard_tx); + send_gateway_command(event, &self.gateway_tx); } /// Sends multiple commands to be handled by gateway manager service. /// Convenience wrapper around [`send_multiple_gateway_commands`] pub fn send_multiple_gateway_commands(&self, events: Vec) { - send_multiple_gateway_commands(events, &self.wireguard_tx); + send_multiple_gateway_commands(events, &self.gateway_tx); } /// Sends event to the main event router @@ -118,7 +118,7 @@ impl AppState { pool: PgPool, tx: UnboundedSender, rx: UnboundedReceiver, - wireguard_tx: Sender, + gateway_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, key: Key, failed_logins: Arc>, @@ -132,7 +132,7 @@ impl AppState { Self { pool, tx, - wireguard_tx, + gateway_tx, web_reload_tx, failed_logins, key, diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index c5f32de6f7..11ab825a1c 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -555,7 +555,7 @@ impl AclRule { None => { debug!( "No firewall config generated for location {location}. Not sending a \ - gateway event" + gateway command" ); } } @@ -1704,7 +1704,7 @@ impl AclAlias { None => { debug!( "No firewall config generated for location {location}. Not sending a \ - gateway event" + gateway command" ); } } diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index c927c25e0b..22d358b8e9 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -324,7 +324,7 @@ async fn sync_user_groups( directory_sync: &T, user: &User, pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { info!("Syncing groups of user {} with the directory", user.email); let directory_groups = directory_sync.get_user_groups(&user.email).await?; @@ -371,7 +371,7 @@ async fn sync_user_groups( } } - sync_allowed_user_devices(user, &mut transaction, wg_tx) + sync_allowed_user_devices(user, &mut transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( @@ -418,7 +418,7 @@ pub(crate) async fn test_directory_sync_connection( pub async fn sync_user_groups_if_configured( user: &User, pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -435,7 +435,7 @@ pub async fn sync_user_groups_if_configured( match DirectorySyncClient::build(pool).await { Ok(mut dir_sync) => { dir_sync.prepare().await?; - sync_user_groups(&dir_sync, user, pool, wg_tx).await?; + sync_user_groups(&dir_sync, user, pool, gateway_tx).await?; } Err(err) => { error!("Failed to build directory sync client: {err}"); @@ -482,7 +482,7 @@ async fn create_and_add_to_group( async fn sync_all_users_groups( directory_sync: &T, pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, all_users: Option<&[DirectoryUser]>, ) -> Result<(), DirectorySyncError> { info!("Syncing all users' groups with the directory, this may take a while..."); @@ -579,7 +579,7 @@ async fn sync_all_users_groups( create_and_add_to_group(&user, group, pool).await?; } - sync_allowed_user_devices(&user, &mut transaction, wg_tx).await.map_err(|err| { + sync_allowed_user_devices(&user, &mut transaction, gateway_tx).await.map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( "Failed to sync allowed devices for user {} during directory synchronization: {err}", user.email @@ -618,7 +618,7 @@ fn is_directory_sync_enabled(provider: Option<&OpenIdProvider>) -> bool { async fn sync_all_users_state( pool: &PgPool, - wg_tx: &Sender, + gateway_tx: &Sender, all_users: &[DirectoryUser], ) -> Result<(), DirectorySyncError> { info!("Syncing all users' state with the directory, this may take a while..."); @@ -656,7 +656,7 @@ async fn sync_all_users_state( &mut transaction, &inactive_directory_users, &mut modified_users, - wg_tx, + gateway_tx, ) .await?; @@ -789,7 +789,7 @@ async fn sync_all_users_state( the admin behavior setting is set to disable", user.email ); - disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { + disable_user(&mut user, &mut transaction, gateway_tx).await.map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable admin {} during directory synchronization: {err}", user.email @@ -819,7 +819,7 @@ async fn sync_all_users_state( if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) + delete_user_and_cleanup_devices(user, &mut transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -844,7 +844,7 @@ async fn sync_all_users_state( the user behavior setting is set to disable", user.email ); - disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { + disable_user(&mut user, &mut transaction, gateway_tx).await.map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable user {} during directory synchronization: {err}", user.email @@ -866,7 +866,7 @@ async fn sync_all_users_state( if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) + delete_user_and_cleanup_devices(user, &mut transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -904,7 +904,7 @@ async fn sync_inactive_directory_users( transaction: &mut PgConnection, inactive_directory_users: &[&DirectoryUser], modified_users: &mut Vec>, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { // find all active Defguard users disabled in directory let disabled_users_emails = inactive_directory_users @@ -929,7 +929,7 @@ async fn sync_inactive_directory_users( "Disabling user {} because they are disabled in the directory", user.email ); - disable_user(&mut user, transaction, wg_tx) + disable_user(&mut user, transaction, gateway_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -1005,7 +1005,7 @@ pub(crate) async fn get_directory_sync_interval(pool: &PgPool) -> u64 { // Performs the directory sync job. This function is called by the utility thread. pub(crate) async fn do_directory_sync( pool: &PgPool, - wireguard_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] if !is_business_license_active() { @@ -1042,7 +1042,7 @@ pub(crate) async fn do_directory_sync( DirectorySyncTarget::All | DirectorySyncTarget::Users ) { let users = dir_sync.get_all_users().await?; - sync_all_users_state(pool, wireguard_tx, &users).await?; + sync_all_users_state(pool, gateway_tx, &users).await?; all_users = Some(users); } if matches!( @@ -1063,7 +1063,7 @@ pub(crate) async fn do_directory_sync( } _ => None, // No need to pass all users for other providers, for the time being. }; - sync_all_users_groups(&dir_sync, pool, wireguard_tx, users_to_pass.as_deref()) + sync_all_users_groups(&dir_sync, pool, gateway_tx, users_to_pass.as_deref()) .await?; } } diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 2704837910..0487838e7e 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -146,7 +146,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -167,7 +167,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -176,7 +176,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } // Delete users, keep admins @@ -186,7 +186,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -208,7 +208,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -216,7 +216,7 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_none()); assert!(get_test_user(&pool, "testuser").await.is_some()); - let event = wg_rx.try_recv(); + let event = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2.id); } else { @@ -229,7 +229,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); User::init_admin_user(&pool, "pass123").await.unwrap(); let _ = make_test_provider( @@ -254,7 +254,7 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -270,7 +270,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); // Check that we received a device deleted event for whichever admin was removed - let event = wg_rx.try_recv(); + let event = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert!(dev.device.user_id == user1.id || dev.device.user_id == user3.id); } else { @@ -284,7 +284,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -308,7 +308,7 @@ mod test { assert!(get_test_user(&pool, "user2").await.is_some()); assert!(get_test_user(&pool, "testuser").await.is_some()); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); @@ -324,7 +324,7 @@ mod test { assert!(get_test_user(&pool, "testuser").await.is_some()); // Check for device deletion events - let event1 = wg_rx.try_recv(); + let event1 = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id @@ -335,7 +335,7 @@ mod test { panic!("Expected a DeviceDeleted event"); } - let event2 = wg_rx.try_recv(); + let event2 = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id @@ -353,7 +353,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Disable, @@ -395,19 +395,19 @@ mod test { assert!(testuserdisabled.is_active); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); // Check for device disconnection events - let event1 = wg_rx.try_recv(); + let event1 = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { panic!("Expected a DeviceDisconnected event"); } - let event2 = wg_rx.try_recv(); + let event2 = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!(dev.device.user_id == user2.id || dev.device.user_id == testuserdisabled.id); } else { @@ -436,7 +436,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); // Added mut wg_rx + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // Added mut gateway_rx make_test_provider( &pool, DirectorySyncUserBehavior::Keep, @@ -468,12 +468,12 @@ mod test { assert!(testuserdisabled.is_active); let all_users = client.get_all_users().await.unwrap(); - sync_all_users_state(&pool, &wg_tx, &all_users) + sync_all_users_state(&pool, &gateway_tx, &all_users) .await .unwrap(); // Check for device disconnection events - let event1 = wg_rx.try_recv(); + let event1 = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event1 { assert!( dev.device.user_id == user1.id @@ -484,7 +484,7 @@ mod test { panic!("Expected a DeviceDisconnected event"); } - let event2 = wg_rx.try_recv(); + let event2 = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event2 { assert!( dev.device.user_id == user1.id @@ -514,7 +514,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -530,7 +530,7 @@ mod test { make_test_user_and_device("testuser2", &pool).await; make_test_user_and_device("testuserdisabled", &pool).await; let all_users = client.get_all_users().await.unwrap(); - sync_all_users_groups(&client, &pool, &wg_tx, Some(&all_users)) + sync_all_users_groups(&client, &pool, &gateway_tx, Some(&all_users)) .await .unwrap(); @@ -571,7 +571,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -585,7 +585,7 @@ mod test { let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - sync_user_groups_if_configured(&user, &pool, &wg_tx) + sync_user_groups_if_configured(&user, &pool, &gateway_tx) .await .unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); @@ -600,7 +600,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -614,7 +614,7 @@ mod test { let user = make_test_user_and_device("testuser", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); } @@ -625,7 +625,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -648,23 +648,23 @@ mod test { let user2_pre_sync = make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; assert!(user2.is_none()); let mut transaction = pool.begin().await.unwrap(); - sync_allowed_user_devices(&user, &mut transaction, &wg_tx) + sync_allowed_user_devices(&user, &mut transaction, &gateway_tx) .await .unwrap(); transaction.commit().await.unwrap(); - let event = wg_rx.try_recv(); + let event = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceDeleted(dev)) = event { assert_eq!(dev.device.user_id, user2_pre_sync.id); } else { panic!("Expected DeviceDeleted event"); } - let event = wg_rx.try_recv(); + let event = gateway_rx.try_recv(); if let Ok(GatewayCommand::DeviceCreated(dev)) = event { panic!("Unexpected DeviceCreated event: {dev:?}"); } @@ -676,7 +676,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -691,7 +691,7 @@ mod test { make_test_user_and_device("user2", &pool).await; let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 0); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user_groups = user.member_of(&pool).await.unwrap(); assert_eq!(user_groups.len(), 3); let user2 = get_test_user(&pool, "user2").await; @@ -704,7 +704,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -724,7 +724,7 @@ mod test { assert_eq!(user_groups.len(), 1); assert!(user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // He should still be an admin as it's the last one assert!(user.is_admin(&pool).await.unwrap()); @@ -733,7 +733,7 @@ mod test { let user2 = make_test_user_and_device("testuser2", &pool).await; user2.add_to_group(&pool, &admin_grp).await.unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let admins = User::find_admins(&pool).await.unwrap(); // There should be only one admin left @@ -742,7 +742,7 @@ mod test { let defguard_user = make_test_user_and_device("defguard", &pool).await; make_admin(&pool, &defguard_user).await; - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); } #[sqlx::test] @@ -751,7 +751,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, _) = broadcast::channel::(16); + let (gateway_tx, _) = broadcast::channel::(16); make_test_provider( &pool, DirectorySyncUserBehavior::Delete, @@ -768,7 +768,7 @@ mod test { make_admin(&pool, &defguard_user).await; assert!(defguard_user.is_admin(&pool).await.unwrap()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // The user should still be an admin assert!(defguard_user.is_admin(&pool).await.unwrap()); @@ -780,7 +780,7 @@ mod test { .await .unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); let user = User::find_by_username(&pool, "defguard").await.unwrap(); assert!(user.is_none()); } @@ -791,7 +791,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // disable prefetching users make_test_provider( @@ -809,14 +809,14 @@ mod test { let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // no users in Defguard after sync let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } #[sqlx::test] @@ -825,7 +825,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config.clone()); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( @@ -843,14 +843,14 @@ mod test { let defguard_users = User::all(&pool).await.unwrap(); assert!(defguard_users.is_empty()); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); // all active directory users were synced let defguard_users = User::all(&pool).await.unwrap(); assert_eq!(defguard_users.len(), 3); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } #[sqlx::test] @@ -862,7 +862,7 @@ mod test { let config = DefGuardConfig::new_test_config(); let _ = SERVER_CONFIG.set(config); - let (wg_tx, mut wg_rx) = broadcast::channel::(16); + let (gateway_tx, mut gateway_rx) = broadcast::channel::(16); // enable prefetching users make_test_provider( @@ -892,7 +892,7 @@ mod test { set_cached_license(Some(license)); update_counts(&pool).await.unwrap(); - do_directory_sync(&pool, &wg_tx).await.unwrap(); + do_directory_sync(&pool, &gateway_tx).await.unwrap(); update_counts(&pool).await.unwrap(); let user_count = get_counts().user(); @@ -902,6 +902,6 @@ mod test { assert_eq!(defguard_users.len(), user_limit as usize); // No events - assert!(wg_rx.try_recv().is_err()); + assert!(gateway_rx.try_recv().is_err()); } } diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index 47df9efd84..1cdc92d161 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -636,7 +636,7 @@ pub async fn auth_callback( // sync the groups for the MFA enabled user logging in through the provider without firing it on // every login attempt, even for standard, non-provider users. if let Err(err) = - sync_user_groups_if_configured(&user, &appstate.pool, &appstate.wireguard_tx).await + sync_user_groups_if_configured(&user, &appstate.pool, &appstate.gateway_tx).await { error!( "Failed to sync user groups for user {} with the directory while the user was trying \ diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 50dcbb06a8..7d9e373ec3 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -83,7 +83,7 @@ pub struct ClientLoginSession { pub struct ClientMfaServer { pub(crate) pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, pub(crate) sessions: Arc>>, remote_mfa_responses: Arc>>>, bidi_event_tx: UnboundedSender, @@ -104,14 +104,14 @@ impl ClientMfaServer { #[must_use] pub fn new( pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, bidi_event_tx: UnboundedSender, remote_mfa_responses: Arc>>>, sessions: Arc>>, ) -> Self { Self { pool, - wireguard_tx, + gateway_tx, sessions, remote_mfa_responses, bidi_event_tx, @@ -710,12 +710,12 @@ impl ClientMfaServer { let gateway_network_info = Self::build_mfa_authorized_gateway_network_info(network_device, key.public.clone()); - // send gateway event + // send gateway command debug!("Sending `peer_create` message to gateway"); let event = GatewayCommand::MfaSessionAuthorized(location.id, device.clone(), gateway_network_info); - self.wireguard_tx.send(event).map_err(|err| { - error!("Error sending WireGuard event: {err}"); + self.gateway_tx.send(event).map_err(|err| { + error!("Error sending gateway command: {err}"); Status::internal("unexpected error") })?; @@ -951,9 +951,9 @@ impl ClientMfaServer { // gateway update is only needed to remove peer for MFA sessions // this is needed to remove peers for both Connected and New sessions if is_mfa_session { - let gateway_event = GatewayCommand::MfaSessionDisconnected(location.id, device.clone()); - self.wireguard_tx.send(gateway_event).map_err(|err| { - error!("Error sending WireGuard event: {err}"); + let gateway_command = GatewayCommand::MfaSessionDisconnected(location.id, device.clone()); + self.gateway_tx.send(gateway_command).map_err(|err| { + error!("Error sending gateway command: {err}"); Status::internal("unexpected error") })?; } @@ -1063,10 +1063,10 @@ mod tests { .await .expect("should replace connected MFA session"); - let gateway_event = gateway_rx + let gateway_command = gateway_rx .try_recv() .expect("expected MFA gateway disconnect event for replaced connected session"); - match gateway_event { + match gateway_command { GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); @@ -1138,10 +1138,10 @@ mod tests { .await .expect("should replace new MFA session"); - let gateway_event = gateway_rx + let gateway_command = gateway_rx .try_recv() .expect("expected MFA gateway disconnect event for replaced new session"); - match gateway_event { + match gateway_command { GatewayCommand::MfaSessionDisconnected(location_id, disconnected_device) => { assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); @@ -1237,7 +1237,7 @@ mod tests { tokio::sync::mpsc::UnboundedReceiver, tokio::sync::broadcast::Receiver, ) { - let (wireguard_tx, wireguard_rx) = broadcast::channel(8); + let (gateway_tx, gateway_rx) = broadcast::channel(8); let (bidi_event_tx, bidi_event_rx) = mpsc::unbounded_channel(); let remote_mfa_responses: Arc>>> = Arc::default(); @@ -1246,13 +1246,13 @@ mod tests { ( ClientMfaServer::new( pool, - wireguard_tx, + gateway_tx, bidi_event_tx, remote_mfa_responses, sessions, ), bidi_event_rx, - wireguard_rx, + gateway_rx, ) } diff --git a/crates/defguard_core/src/handlers/group.rs b/crates/defguard_core/src/handlers/group.rs index 09006a46b4..6023da1045 100644 --- a/crates/defguard_core/src/handlers/group.rs +++ b/crates/defguard_core/src/handlers/group.rs @@ -110,7 +110,7 @@ pub(crate) async fn bulk_assign_to_groups( } } - sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.gateway_tx).await?; transaction.commit().await?; @@ -364,7 +364,7 @@ pub(crate) async fn create_group( .insert(&group_info.name); } - sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.gateway_tx).await?; transaction.commit().await?; @@ -498,7 +498,7 @@ pub(crate) async fn modify_group( .insert(group.name.as_str()); } - sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.gateway_tx).await?; let users_after = group.members(&mut *transaction).await?.clone(); transaction.commit().await?; @@ -596,7 +596,7 @@ pub(crate) async fn delete_group( // sync allowed devices for all locations let mut conn = appstate.pool.acquire().await?; - sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.gateway_tx).await?; info!( "User {} deleted group {}", @@ -653,7 +653,7 @@ pub(crate) async fn add_group_member( ldap_add_user_to_groups(&user, hashset![group.name.as_str()], &appstate.pool).await; ldap_update_user_state(&mut user, &appstate.pool).await; let mut conn = appstate.pool.acquire().await?; - sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.gateway_tx).await?; info!("Added user: {} to group: {}", user.username, group.name); appstate.emit_event(ApiEvent { context, @@ -716,7 +716,7 @@ pub(crate) async fn remove_group_member( .await; let mut conn = appstate.pool.acquire().await?; - sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.gateway_tx).await?; info!("Removed user: {} from group: {}", user.username, group.name); appstate.emit_event(ApiEvent { context, diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index 5b4fb2529a..aa5d6662c2 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -775,7 +775,7 @@ pub(crate) async fn modify_user( "User {} changed {username} groups or status, syncing allowed network devices.", session.user.username ); - sync_allowed_user_devices(&user, &mut transaction, &appstate.wireguard_tx).await?; + sync_allowed_user_devices(&user, &mut transaction, &appstate.gateway_tx).await?; } // remove API tokens when deactivating a user @@ -908,7 +908,7 @@ pub(crate) async fn delete_user( } else { None }; - delete_user_and_cleanup_devices(user.clone(), &mut transaction, &appstate.wireguard_tx) + delete_user_and_cleanup_devices(user.clone(), &mut transaction, &appstate.gateway_tx) .await?; appstate.trigger_action(AppEvent::UserDeleted(username.clone())); diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index aaaa3b3897..8b9cdaf53a 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -871,7 +871,7 @@ pub(crate) async fn add_device( let (network_info, configs) = device.add_to_all_networks(&mut transaction).await?; - // prepare a list of gateway events to be sent + // prepare a list of gateway commands to be sent let mut events = Vec::new(); // get all locations affected by device being added @@ -1197,7 +1197,7 @@ pub(crate) async fn delete_device( let device_id = device_info.device.id; events.push(GatewayCommand::DeviceDeleted(device_info.clone())); - // send generated gateway events + // send generated gateway commands appstate.send_multiple_gateway_commands(events); // Emit event specific to the device type. diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 60c7e9c7cf..9bb7ae4f6c 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -256,7 +256,7 @@ async fn openapi() -> Json { pub fn build_webapp( webhook_tx: UnboundedSender, webhook_rx: UnboundedReceiver, - wireguard_tx: Sender, + gateway_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, worker_state: Arc>, pool: PgPool, @@ -734,7 +734,7 @@ pub fn build_webapp( pool.clone(), webhook_tx, webhook_rx, - wireguard_tx, + gateway_tx, web_reload_tx, key, failed_logins, @@ -804,7 +804,7 @@ pub async fn run_web_server( worker_state: Arc>, webhook_tx: UnboundedSender, webhook_rx: UnboundedReceiver, - wireguard_tx: Sender, + gateway_tx: Sender, web_reload_tx: tokio::sync::broadcast::Sender<()>, pool: PgPool, failed_logins: Arc>, @@ -822,7 +822,7 @@ pub async fn run_web_server( let webapp = build_webapp( webhook_tx, webhook_rx, - wireguard_tx, + gateway_tx, web_reload_tx.clone(), worker_state, pool.clone(), diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index 164048d7d5..4d6328bd54 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -41,7 +41,7 @@ pub enum LocationManagementError { // run sync_allowed_devices on all wireguard networks pub(crate) async fn sync_all_networks( conn: &mut PgConnection, - wireguard_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), LocationManagementError> { info!("Syncing allowed devices for all WireGuard locations"); let locations = WireguardNetwork::all(&mut *conn).await?; @@ -58,9 +58,9 @@ pub(crate) async fn sync_all_networks( firewall_config, )); } - // check if any gateway events need to be sent + // check if any gateway commands need to be sent if !gateway_events.is_empty() { - send_multiple_gateway_commands(gateway_events, wireguard_tx); + send_multiple_gateway_commands(gateway_events, gateway_tx); } } Ok(()) @@ -236,7 +236,7 @@ pub async fn process_device_access_changes( /// Check if devices found in an imported config file exist already, /// if they do assign a specified IP. /// Return a list of imported devices which need to be manually mapped to a user -/// and a list of WireGuard events to be sent out. +/// and a list of gateway commands to be sent out. pub(crate) async fn handle_imported_devices( location: &WireguardNetwork, transaction: &mut PgConnection, diff --git a/crates/defguard_core/src/user_management.rs b/crates/defguard_core/src/user_management.rs index 640d6b096a..e059e189b3 100644 --- a/crates/defguard_core/src/user_management.rs +++ b/crates/defguard_core/src/user_management.rs @@ -18,7 +18,7 @@ use crate::{ pub async fn delete_user_and_cleanup_devices( user: User, conn: &mut PgConnection, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), WebError> { let username = user.username.clone(); debug!("Deleting user {username}, removing his devices from gateways and updating ldap...",); @@ -57,7 +57,7 @@ pub async fn delete_user_and_cleanup_devices( } } - send_multiple_gateway_commands(events, wg_tx); + send_multiple_gateway_commands(events, gateway_tx); info!( "The user {} has been deleted and his devices removed from gateways.", &username @@ -69,12 +69,12 @@ pub async fn delete_user_and_cleanup_devices( pub async fn disable_user( user: &mut User, conn: &mut PgConnection, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), WebError> { user.is_active = false; user.save(&mut *conn).await?; user.logout_all_sessions(&mut *conn).await?; - sync_allowed_user_devices(user, conn, wg_tx).await?; + sync_allowed_user_devices(user, conn, gateway_tx).await?; Ok(()) } @@ -82,7 +82,7 @@ pub async fn disable_user( pub async fn sync_allowed_user_devices( user: &User, conn: &mut PgConnection, - wg_tx: &Sender, + gateway_tx: &Sender, ) -> Result<(), WebError> { debug!("Syncing allowed devices of user {}", user.username); let locations = WireguardNetwork::all(&mut *conn).await?; @@ -93,7 +93,7 @@ pub async fn sync_allowed_user_devices( // check if any peers were updated if !gateway_events.is_empty() { // send peer update events - send_multiple_gateway_commands(gateway_events, wg_tx); + send_multiple_gateway_commands(gateway_events, gateway_tx); } // send firewall config update if ACLs & enterprise features are enabled @@ -102,7 +102,7 @@ pub async fn sync_allowed_user_devices( { send_gateway_command( GatewayCommand::FirewallConfigChanged(location.id, firewall_config), - wg_tx, + gateway_tx, ); } } diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index 3c0088faf2..2ad2f41f65 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -46,7 +46,7 @@ const ACL_EXPIRY_SYSTEM_ACTOR: &str = "system:acl-expiry"; #[instrument(skip_all)] pub async fn run_utility_thread( pool: &PgPool, - wireguard_tx: broadcast::Sender, + gateway_tx: broadcast::Sender, proxy_control_tx: mpsc::Sender, web_reload_tx: broadcast::Sender<()>, ) -> Result<(), anyhow::Error> { @@ -64,7 +64,7 @@ pub async fn run_utility_thread( let directory_sync_task = || async { if let Err(e) = Box::pin( - do_directory_sync(pool, &wireguard_tx).instrument(info_span!("directory_sync_task")), + do_directory_sync(pool, &gateway_tx).instrument(info_span!("directory_sync_task")), ) .await { @@ -100,7 +100,7 @@ pub async fn run_utility_thread( }; let expired_acl_rules_task = || async { - if let Err(err) = expired_acl_rules_check(pool, wireguard_tx.clone()) + if let Err(err) = expired_acl_rules_check(pool, gateway_tx.clone()) .instrument(info_span!("expired_acl_rules_task")) .await { @@ -174,7 +174,7 @@ pub async fn run_utility_thread( {new_enterprise_enabled}" ); if let Err(err) = - enterprise_status_check(pool, wireguard_tx.clone(), new_enterprise_enabled) + enterprise_status_check(pool, gateway_tx.clone(), new_enterprise_enabled) .instrument(info_span!("enterprise_status_check")) .await { @@ -199,7 +199,7 @@ pub async fn run_utility_thread( /// Check if enterprise status has changed and perform any necessary actions async fn enterprise_status_check( pool: &PgPool, - wireguard_tx: broadcast::Sender, + gateway_tx: broadcast::Sender, enable_enterprise: bool, ) -> Result<(), anyhow::Error> { // fetch all ACL-enabled networks @@ -221,13 +221,13 @@ async fn enterprise_status_check( // Handle service location update or just update the firewall if location.service_location_mode == ServiceLocationMode::Disabled { - wireguard_tx.send(GatewayCommand::FirewallConfigChanged( + gateway_tx.send(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, ))?; } else { let new_peers = get_location_allowed_peers(&location, &mut *transaction).await?; - wireguard_tx.send(GatewayCommand::NetworkModified( + gateway_tx.send(GatewayCommand::NetworkModified( location.id, location, new_peers, @@ -242,13 +242,13 @@ async fn enterprise_status_check( for location in locations { if location.service_location_mode == ServiceLocationMode::Disabled { debug!("Disabling gateway firewall configuration for location {location:?}"); - wireguard_tx.send(GatewayCommand::FirewallDisabled(location.id))?; + gateway_tx.send(GatewayCommand::FirewallDisabled(location.id))?; } else { debug!( "Disabling gateway firewall configuration and service location client \ connections for location {location}" ); - wireguard_tx.send(GatewayCommand::NetworkModified( + gateway_tx.send(GatewayCommand::NetworkModified( location.id, location, // Send empty peer list, we are disabling the service location @@ -265,7 +265,7 @@ async fn enterprise_status_check( /// Find newly expired ACL rules and update their status. async fn expired_acl_rules_check( pool: &PgPool, - wireguard_tx: broadcast::Sender, + gateway_tx: broadcast::Sender, ) -> Result<(), anyhow::Error> { // mark relevant rules as expired let updated_rules = query_as!( @@ -309,7 +309,7 @@ async fn expired_acl_rules_check( match try_get_location_firewall_config(&location, &mut conn).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); - wireguard_tx.send(GatewayCommand::FirewallConfigChanged( + gateway_tx.send(GatewayCommand::FirewallConfigChanged( location.id, firewall_config, ))?; @@ -317,7 +317,7 @@ async fn expired_acl_rules_check( None => { debug!( "No firewall config generated for location {location}. Not sending a \ - gateway event" + gateway command" ); } } diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index fc62437088..725d13e8b5 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -60,7 +60,7 @@ pub const X_FORWARDED_URI: HeaderName = HeaderName::from_static("x-forwarded-uri pub(crate) struct ClientState { pub pool: PgPool, pub worker_state: Arc>, - pub wireguard_rx: Receiver, + pub gateway_rx: Receiver, pub test_user: User, #[allow(dead_code)] pub config: DefGuardConfig, @@ -70,14 +70,14 @@ impl ClientState { pub fn new( pool: PgPool, worker_state: Arc>, - wireguard_rx: Receiver, + gateway_rx: Receiver, test_user: User, config: DefGuardConfig, ) -> Self { Self { pool, worker_state, - wireguard_rx, + gateway_rx, test_user, config, } @@ -92,7 +92,7 @@ pub(crate) async fn make_base_client( let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); - let (wg_tx, wg_rx) = broadcast::channel::(16); + let (gateway_tx, gateway_rx) = broadcast::channel::(16); let failed_logins = FailedLoginMap::new(); let failed_logins = Arc::new(Mutex::new(failed_logins)); @@ -113,7 +113,7 @@ pub(crate) async fn make_base_client( let client_state = ClientState::new( pool.clone(), worker_state.clone(), - wg_rx, + gateway_rx, User::find_by_username(&pool, "hpotter") .await .unwrap() @@ -146,7 +146,7 @@ pub(crate) async fn make_base_client( let webapp = build_webapp( tx, rx, - wg_tx, + gateway_tx, web_reload_tx, worker_state, pool, diff --git a/crates/defguard_core/tests/integration/api/proxy_certs.rs b/crates/defguard_core/tests/integration/api/proxy_certs.rs index aeb461a73c..86c5ccc84e 100644 --- a/crates/defguard_core/tests/integration/api/proxy_certs.rs +++ b/crates/defguard_core/tests/integration/api/proxy_certs.rs @@ -114,7 +114,7 @@ async fn make_test_client_with_proxy_rx( let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); - let (wg_tx, _wg_rx) = broadcast::channel::(16); + let (gateway_tx, _wg_rx) = broadcast::channel::(16); let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); @@ -140,7 +140,7 @@ async fn make_test_client_with_proxy_rx( let webapp = build_webapp( tx, rx, - wg_tx, + gateway_tx, web_reload_tx, worker_state, pool.clone(), diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index f0db38bbe3..150ad98da1 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -44,7 +44,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -54,7 +54,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { let response = make_network(&client, "network").await; let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // check vpn locations for `admin` group @@ -102,7 +102,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { ] ); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkModified(..)); // check vpn locations for `admin` group @@ -134,7 +134,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkDeleted(..)); } @@ -507,7 +507,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -515,7 +515,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { // create network make_network(&client, "network").await; - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network details @@ -534,7 +534,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::CREATED); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceCreated(..)); // an IP was assigned for new device @@ -550,7 +550,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { // add another network make_network(&client, "network").await; assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkCreated(..) ); @@ -597,7 +597,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceModified(..)); // device details @@ -639,7 +639,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkDeleted(..)); // delete device @@ -648,7 +648,7 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceDeleted(..)); let response = client.get("/api/v1/device").json(&device).send().await; @@ -933,7 +933,7 @@ async fn test_device_pubkey(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -941,7 +941,7 @@ async fn test_device_pubkey(_: PgPoolOptions, options: PgConnectOptions) { // create network make_network(&client, "network").await; - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network details diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs index d9f4d97b0c..3873de4713 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs @@ -147,7 +147,7 @@ async fn test_create_new_network(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -179,9 +179,9 @@ async fn test_create_new_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); // network configuration was created only for admin and allowed user let peers = get_location_allowed_peers(&network, &client_state.pool) @@ -199,7 +199,7 @@ async fn test_create_new_network_allow_all_groups(_: PgPoolOptions, options: PgC let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -236,7 +236,7 @@ async fn test_create_new_network_allow_all_groups(_: PgPoolOptions, options: PgC .unwrap(); assert_eq!(allowed_groups, vec!["allowed group"]); assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkCreated(..) ); @@ -257,7 +257,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -289,7 +289,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network configuration was created for admin and the allowed group member @@ -325,7 +325,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkModified(..) ); @@ -361,7 +361,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkModified(..) ); @@ -398,7 +398,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkModified(..) ); @@ -409,7 +409,7 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(new_peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(new_peers[1].pubkey, devices[2].wireguard_pubkey); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); } #[sqlx::test] @@ -419,7 +419,7 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -450,7 +450,7 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkCreated(..) ); @@ -490,7 +490,7 @@ async fn test_modify_network_enable_allow_all_groups(_: PgPoolOptions, options: .unwrap(); assert_eq!(allowed_groups, vec!["allowed group"]); assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkModified(..) ); @@ -512,7 +512,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -571,11 +571,11 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // network config was only created for one of the existing devices and the admin device - let GatewayCommand::DeviceModified(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceModified(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!(device_info.device.id, devices[1].id); @@ -586,7 +586,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne peers[1].allowed_ips[0] ); - let GatewayCommand::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!(device_info.device.id, devices[0].id); @@ -597,7 +597,7 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne peers[0].allowed_ips[0] ); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); } #[sqlx::test] @@ -607,7 +607,7 @@ async fn test_import_mapping_devices(_: PgPoolOptions, options: PgConnectOptions let (client, client_state) = make_test_client(pool).await; let (users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -654,7 +654,7 @@ PersistentKeepalive = 300 let mut mapped_devices = response.devices; assert_eq!(mapped_devices.len(), 4); for _ in 0..3 { - wg_rx.try_recv().unwrap(); + gateway_rx.try_recv().unwrap(); } // assign devices to users @@ -680,7 +680,7 @@ PersistentKeepalive = 300 assert_eq!(peers[3].pubkey, mapped_devices[1].wireguard_pubkey); // assert events - let GatewayCommand::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!( @@ -694,7 +694,7 @@ PersistentKeepalive = 300 mapped_devices[0].wireguard_ips, ); - let GatewayCommand::DeviceCreated(device_info) = wg_rx.try_recv().unwrap() else { + let GatewayCommand::DeviceCreated(device_info) = gateway_rx.try_recv().unwrap() else { panic!() }; assert_eq!( @@ -708,7 +708,7 @@ PersistentKeepalive = 300 mapped_devices[1].wireguard_ips, ); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); } /// Test that changing groups for a particular user generates correct update events @@ -719,7 +719,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -751,9 +751,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); // network configuration was created only for admin and allowed user let peers = get_location_allowed_peers(&network, &client_state.pool) @@ -773,9 +773,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceDeleted(..)); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -793,7 +793,7 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -811,9 +811,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceCreated(..)); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -833,7 +833,7 @@ async fn test_modify_user_no_effect_when_allow_all_groups( let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -864,10 +864,10 @@ async fn test_modify_user_no_effect_when_allow_all_groups( assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::NetworkCreated(..) ); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -883,7 +883,7 @@ async fn test_modify_user_no_effect_when_allow_all_groups( .await; assert_eq!(response.status(), StatusCode::OK); - assert_err!(wg_rx.try_recv()); + assert_err!(gateway_rx.try_recv()); let peers = get_location_allowed_peers(&network, &client_state.pool) .await @@ -946,7 +946,7 @@ async fn test_delete_only_allowed_group(_: PgPoolOptions, options: PgConnectOpti let (client, client_state) = make_test_client(pool).await; let (_users, devices) = setup_test_users(&client_state.pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -978,7 +978,7 @@ async fn test_delete_only_allowed_group(_: PgPoolOptions, options: PgConnectOpti assert_eq!(response.status(), StatusCode::CREATED); let network: WireguardNetwork = response.json().await; assert_eq!(network.name, "network"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); let peers = get_location_allowed_peers(&network, &client_state.pool) diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs index e7c71866ff..13fc2422a5 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs @@ -98,7 +98,7 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let (client, client_state) = make_test_client(pool).await; - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -108,12 +108,12 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let response = make_first_network(&client).await; let network_1: WireguardNetwork = response.json().await; assert_eq!(network_1.name, "network"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); let response = make_second_network(&client).await; let network_2: WireguardNetwork = response.json().await; assert_eq!(network_2.name, "network-2"); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // ip suggestions @@ -201,7 +201,7 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { let configured = json["device"]["configured"].as_bool().unwrap(); let config_text = json["config"]["config"].as_str().unwrap(); assert!(configured); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceCreated(..)); // download WG config @@ -238,7 +238,7 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { .unwrap(); assert_eq!(device.name, "device-1"); assert_eq!(device.description, Some("new description".to_owned())); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::DeviceModified(..)); // Make sure the device is only in the selected network diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index dde6414091..a9dad7d566 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -102,7 +102,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { transaction.commit().await.unwrap(); - let mut wg_rx = client_state.wireguard_rx; + let mut gateway_rx = client_state.gateway_rx; let auth = Auth::new("admin", "pass123"); let response = &client.post("/api/v1/auth").json(&auth).send().await; @@ -137,13 +137,13 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(network.dns, Some("10.0.0.2".to_owned())); assert_eq!(network.allowed_ips, vec!["10.0.0.0/24".parse().unwrap()]); assert_eq!(network.connected_at, None); - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); assert_matches!(event, GatewayCommand::NetworkCreated(..)); // existing devices assertion // imported config for an existing device assert_matches!( - wg_rx.try_recv().unwrap(), + gateway_rx.try_recv().unwrap(), GatewayCommand::DeviceModified(..) ); let user_device_1 = UserDevice::from_device(&pool, device_1) @@ -156,7 +156,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { vec!["10.0.0.12"] ); // generated IP for other existing device - assert_matches!(wg_rx.try_recv().unwrap(), GatewayCommand::DeviceCreated(..)); + assert_matches!(gateway_rx.try_recv().unwrap(), GatewayCommand::DeviceCreated(..)); let user_device_2 = UserDevice::from_device(&pool, device_2) .await .unwrap() @@ -206,7 +206,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::CREATED); // assert events - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); match event { GatewayCommand::DeviceCreated(device_info) => { assert_eq!(device_info.device.name, "device_1"); @@ -214,7 +214,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { _ => unreachable!("Invalid event type received"), } - let event = wg_rx.try_recv().unwrap(); + let event = gateway_rx.try_recv().unwrap(); match event { GatewayCommand::DeviceCreated(device_info) => { assert_eq!(device_info.device.name, "device_2"); @@ -222,7 +222,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { _ => unreachable!("Invalid event type received"), } - let event = wg_rx.try_recv(); + let event = gateway_rx.try_recv(); assert_matches!(event, Err(TryRecvError::Empty)); // assert user devices diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 069f9f07e7..d5252bab30 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -161,13 +161,13 @@ mod tests { let (_bidi_tx, bidi_rx) = unbounded_channel(); let (_session_manager_tx, session_manager_rx) = unbounded_channel(); let (event_logger_tx, event_logger_rx) = unbounded_channel(); - let (wireguard_tx, _wireguard_rx) = broadcast::channel::(1); + let (gateway_tx, _gateway_rx) = broadcast::channel::(1); ( EventRouter::new( RouterReceiverSet::new(api_rx, bidi_rx, session_manager_rx), event_logger_tx, - wireguard_tx, + gateway_tx, Arc::new(Notify::new()), ), event_logger_rx, diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index d8ad3fb79a..193744288c 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -13,7 +13,7 @@ //! MPSC channel. //! 2. The router processes these events and forwards them to the appropriate services: //! - Activity log events go to the event logger service -//! - WireGuard events go to the gateway service +//! - gateway commands go to the gateway service //! - Mail events go to the mail service //! - etc. @@ -61,7 +61,7 @@ impl RouterReceiverSet { struct EventRouter { receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, activity_log_stream_reload_notify: Arc, } @@ -87,13 +87,13 @@ impl EventRouter { fn new( receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, activity_log_stream_reload_notify: Arc, ) -> Self { Self { receivers, event_logger_tx, - wireguard_tx, + gateway_tx, activity_log_stream_reload_notify, } } @@ -138,7 +138,7 @@ impl EventRouter { pub async fn run_event_router( receivers: RouterReceiverSet, event_logger_tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, activity_log_stream_reload_notify: Arc, ) -> Result<(), EventRouterError> { info!("Starting main event router service"); @@ -146,7 +146,7 @@ pub async fn run_event_router( let mut event_router = EventRouter::new( receivers, event_logger_tx, - wireguard_tx, + gateway_tx, activity_log_stream_reload_notify, ); diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 32a6c23030..8bed65d10c 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -615,7 +615,7 @@ impl GatewayHandler { } } -/// Helper struct for handling gateway events. +/// Helper struct for handling gateway commands. struct GatewayUpdatesHandler { network_id: Id, network: WireguardNetwork, @@ -712,7 +712,7 @@ impl GatewayUpdatesHandler { /// Process incoming Gateway events /// - /// Main gRPC server uses a shared channel for broadcasting all gateway events + /// Main gRPC server uses a shared channel for broadcasting all gateway commands /// so the handler must determine if an event is relevant for the network being serviced async fn run(&mut self) { info!( diff --git a/crates/defguard_proxy_manager/src/handler.rs b/crates/defguard_proxy_manager/src/handler.rs index 15d8aa9fe5..44b668f589 100644 --- a/crates/defguard_proxy_manager/src/handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -471,7 +471,7 @@ impl ProxyHandler { async fn message_loop( &mut self, tx: UnboundedSender, - wireguard_tx: Sender, + gateway_tx: Sender, resp_stream: &mut Streaming, ) -> Result<(), ProxyError> { let pool = self.pool.clone(); @@ -863,7 +863,7 @@ impl ProxyHandler { if let Err(err) = sync_user_groups_if_configured( &user, &pool, - &wireguard_tx, + &gateway_tx, ) .await { diff --git a/crates/defguard_proxy_manager/src/servers/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs index 0a1cb588ef..c23760acdd 100644 --- a/crates/defguard_proxy_manager/src/servers/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -48,7 +48,7 @@ use tonic::Status; pub(crate) struct EnrollmentServer { pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, bidi_event_tx: UnboundedSender, } @@ -56,12 +56,12 @@ impl EnrollmentServer { #[must_use] pub(crate) fn new( pool: PgPool, - wireguard_tx: Sender, + gateway_tx: Sender, bidi_event_tx: UnboundedSender, ) -> Self { Self { pool, - wireguard_tx, + gateway_tx, bidi_event_tx, } } @@ -98,7 +98,7 @@ impl EnrollmentServer { /// Sends given `GatewayCommand` to be handled by gateway manager service pub(crate) fn send_gateway_command(&self, event: GatewayCommand) { - if let Err(err) = self.wireguard_tx.send(event) { + if let Err(err) = self.gateway_tx.send(event) { error!("Error sending Gateway command: {err}"); } } @@ -1201,9 +1201,9 @@ mod test { settings.enrollment_send_welcome_email = false; update_current_settings(&pool, settings).await.unwrap(); - let (wireguard_tx, _) = broadcast::channel(1); + let (gateway_tx, _) = broadcast::channel(1); let (bidi_event_tx, _) = unbounded_channel(); - let server = EnrollmentServer::new(pool.clone(), wireguard_tx, bidi_event_tx); + let server = EnrollmentServer::new(pool.clone(), gateway_tx, bidi_event_tx); let mut transaction = pool.begin().await.unwrap(); let result = server diff --git a/crates/defguard_proxy_manager/src/tests/common/mod.rs b/crates/defguard_proxy_manager/src/tests/common/mod.rs index 81527c7fc3..2bef79c9a0 100644 --- a/crates/defguard_proxy_manager/src/tests/common/mod.rs +++ b/crates/defguard_proxy_manager/src/tests/common/mod.rs @@ -382,7 +382,7 @@ impl Drop for MockProxyHarness { pub(crate) struct HandlerTestContext { pub(crate) pool: PgPool, pub(crate) proxy: Proxy, - pub(crate) wireguard_tx: broadcast::Sender, + pub(crate) gateway_tx: broadcast::Sender, pub(crate) bidi_events_rx: UnboundedReceiver, pub(crate) mock_proxy: Option, handler_task: Option>>, @@ -407,9 +407,9 @@ impl HandlerTestContext { let proxy = create_proxy(&pool).await; - let (wireguard_tx, _) = broadcast::channel(16); + let (gateway_tx, _) = broadcast::channel(16); let (bidi_events_tx, bidi_events_rx) = mpsc::unbounded_channel::(); - let tx_set = ProxyTxSet::new(wireguard_tx.clone(), bidi_events_tx); + let tx_set = ProxyTxSet::new(gateway_tx.clone(), bidi_events_tx); let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); let incompatible_components = Arc::new(std::sync::RwLock::new( @@ -451,7 +451,7 @@ impl HandlerTestContext { Self { pool, proxy, - wireguard_tx, + gateway_tx, bidi_events_rx, mock_proxy: Some(mock_proxy), handler_task: Some(handler_task), @@ -587,9 +587,9 @@ impl ManagerTestContext { pub(crate) async fn start(&mut self) { assert!(self.manager_task.is_none(), "proxy manager already started"); - let (wireguard_tx, _) = broadcast::channel(16); + let (gateway_tx, _) = broadcast::channel(16); let (bidi_events_tx, _bidi_events_rx) = mpsc::unbounded_channel::(); - let tx_set = ProxyTxSet::new(wireguard_tx, bidi_events_tx); + let tx_set = ProxyTxSet::new(gateway_tx, bidi_events_tx); let incompatible_components = Arc::new(std::sync::RwLock::new( defguard_core::version::IncompatibleComponents::default(), diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs index d7cc39feb5..17e8cbbb55 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs @@ -259,8 +259,8 @@ async fn test_new_device_sends_gateway_device_created_event( // Start the enrollment session so Token::used_at is set. start_enrollment_session(&mut context, &token.id).await; - // Subscribe to gateway events BEFORE sending the request. - let mut gateway_rx = context.wireguard_tx.subscribe(); + // Subscribe to gateway commands BEFORE sending the request. + let mut gateway_rx = context.gateway_tx.subscribe(); let pubkey = "DhsoNUJPXGl2g5CdqrfE0d7r+AUSHyw5RlNgbXqHlKE="; context.mock_proxy().send_request(CoreRequest { @@ -279,12 +279,12 @@ async fn test_new_device_sends_gateway_device_created_event( // Check that a DeviceCreated event was broadcast. let event = timeout(TEST_TIMEOUT, gateway_rx.recv()) .await - .expect("timed out waiting for GatewayEvent::DeviceCreated") - .expect("gateway event channel closed"); + .expect("timed out waiting for GatewayCommand::DeviceCreated") + .expect("gateway command channel closed"); assert!( matches!(event, GatewayCommand::DeviceCreated(_)), - "expected DeviceCreated gateway event, got: {event:?}" + "expected DeviceCreated gateway command, got: {event:?}" ); context.finish().await.expect_server_finished().await; diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs index e72ba11ae3..f90aa9b1a2 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs @@ -123,9 +123,9 @@ async fn test_mfa_finish_succeeds_with_totp_code(_: PgPoolOptions, options: PgCo ) .await; - // Subscribe before finish so the handler's wireguard_tx.send() has a receiver, + // Subscribe before finish so the handler's gateway_tx.send() has a receiver, // and keep the receiver alive so we can assert on the event. - let mut gateway_rx = context.wireguard_tx.subscribe(); + let mut gateway_rx = context.gateway_tx.subscribe(); let code = generate_totp_code(&user); let (_, psk) = send_mfa_finish(&mut context, &token, Some(&code)).await; @@ -138,12 +138,12 @@ async fn test_mfa_finish_succeeds_with_totp_code(_: PgPoolOptions, options: PgCo let session = assert_vpn_session_exists(&context.pool, network.id, device.id).await; assert!(session.preshared_key.is_some()); - // Verify GatewayEvent::MfaSessionAuthorized was broadcast. + // Verify GatewayCommand::MfaSessionAuthorized was broadcast. // Use the already-subscribed receiver - subscribing after send_mfa_finish would miss the event. let event = timeout(RECEIVE_TIMEOUT, gateway_rx.recv()) .await - .expect("timed out waiting for GatewayEvent::MfaSessionAuthorized") - .expect("gateway event channel closed"); + .expect("timed out waiting for GatewayCommand::MfaSessionAuthorized") + .expect("gateway command channel closed"); let gateway_loc_id = match event { GatewayCommand::MfaSessionAuthorized(loc_id, _, _) => loc_id, other => panic!("expected MfaSessionAuthorized, got: {other:?}"), @@ -298,9 +298,9 @@ async fn test_mfa_finish_succeeds_and_creates_session(_: PgPoolOptions, options: .await; // Subscribe to the gateway broadcast BEFORE calling finish, so that the - // handler's wireguard_tx.send() has at least one active receiver (without + // handler's gateway_tx.send() has at least one active receiver (without // one the send would fail with SendError and return Internal). - let mut gateway_rx = context.wireguard_tx.subscribe(); + let mut gateway_rx = context.gateway_tx.subscribe(); // The start handler has already called generate_email_mfa_code internally // and the in-memory secret is still the same, so regenerating here gives @@ -316,11 +316,11 @@ async fn test_mfa_finish_succeeds_and_creates_session(_: PgPoolOptions, options: let session = assert_vpn_session_exists(&context.pool, network.id, device.id).await; assert!(session.preshared_key.is_some()); - // Verify GatewayEvent::MfaSessionAuthorized was broadcast + // Verify GatewayCommand::MfaSessionAuthorized was broadcast let event = timeout(RECEIVE_TIMEOUT, gateway_rx.recv()) .await - .expect("timed out waiting for GatewayEvent::MfaSessionAuthorized") - .expect("gateway event channel closed"); + .expect("timed out waiting for GatewayCommand::MfaSessionAuthorized") + .expect("gateway command channel closed"); let loc_id = match event { GatewayCommand::MfaSessionAuthorized(loc_id, _, _) => loc_id, other => panic!("expected MfaSessionAuthorized, got: {other:?}"), @@ -358,8 +358,8 @@ async fn test_mfa_token_valid_before_finish_invalid_after( let valid = send_token_validation(&mut context, &token).await; assert!(valid, "token must be valid after start"); - // Subscribe before finish so the handler's wireguard_tx.send() has a receiver - let _gateway_rx = context.wireguard_tx.subscribe(); + // Subscribe before finish so the handler's gateway_tx.send() has a receiver + let _gateway_rx = context.gateway_tx.subscribe(); let code = user.generate_email_mfa_code().expect("generate email code"); send_mfa_finish(&mut context, &token, Some(&code)).await; @@ -468,8 +468,8 @@ async fn test_mfa_await_remote_receives_psk_after_finish( // we proceed with the finish call. task::yield_now().await; - // Subscribe before finish so the handler's wireguard_tx.send() has a receiver - let _gateway_rx = context.wireguard_tx.subscribe(); + // Subscribe before finish so the handler's gateway_tx.send() has a receiver + let _gateway_rx = context.gateway_tx.subscribe(); // Now finish the MFA login with the correct code. Use the no-recv variant // because two responses will arrive (ClientMfaFinish + AwaitRemoteMfaFinish) @@ -509,7 +509,7 @@ async fn test_mfa_await_remote_receives_psk_after_finish( /// When a second MFA cycle completes for the same device+location the handler /// must: /// - disconnect the first `VpnClientSession` (state → Disconnected), -/// - emit `GatewayEvent::MfaSessionDisconnected` for the first session, and +/// - emit `GatewayCommand::MfaSessionDisconnected` for the first session, and /// - create a new active `VpnClientSession`. #[sqlx::test] async fn test_mfa_finish_replaces_existing_session_disconnects_old( @@ -525,7 +525,7 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( // ---- First MFA cycle ---- // Must subscribe before finish so the send has a receiver. - let _gw_rx1 = context.wireguard_tx.subscribe(); + let _gw_rx1 = context.gateway_tx.subscribe(); let (_, token1) = send_mfa_start( &mut context, @@ -566,7 +566,7 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( // Subscribe before finish so both MfaSessionDisconnected and // MfaSessionAuthorized have an active receiver. - let mut gw_rx2 = context.wireguard_tx.subscribe(); + let mut gw_rx2 = context.gateway_tx.subscribe(); let code2 = generate_totp_code(&user); let (_, psk2) = send_mfa_finish(&mut context, &token2, Some(&code2)).await; @@ -583,8 +583,8 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( for _ in 0..2 { let event = timeout(RECEIVE_TIMEOUT, gw_rx2.recv()) .await - .expect("timed out waiting for gateway event after second MFA finish") - .expect("gateway event channel closed"); + .expect("timed out waiting for gateway command after second MFA finish") + .expect("gateway command channel closed"); match event { GatewayCommand::MfaSessionDisconnected(loc_id, ref dev) => { @@ -596,7 +596,7 @@ async fn test_mfa_finish_replaces_existing_session_disconnects_old( assert_eq!(loc_id, network.id, "authorized session location mismatch"); got_authorized = true; } - other => panic!("unexpected gateway event: {other:?}"), + other => panic!("unexpected gateway command: {other:?}"), } } assert!(got_disconnected, "MfaSessionDisconnected must be emitted"); diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs index 6b3e97b26c..176d52e0f7 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/oidc.rs @@ -282,7 +282,7 @@ async fn test_mfa_oidc_full_flow(_: PgPoolOptions, options: PgConnectOptions) { set_public_proxy_url(&context.pool, &mock.base_url).await; // Subscribe to gateway events before sending MFA finish. - let _gateway_rx = context.wireguard_tx.subscribe(); + let _gateway_rx = context.gateway_tx.subscribe(); // ---- Step 1: ClientMfaStart with Oidc method ---- let (_, mfa_token) = send_mfa_start( diff --git a/crates/defguard_setup/src/migration.rs b/crates/defguard_setup/src/migration.rs index 9d2d54d012..10cc85d0c2 100644 --- a/crates/defguard_setup/src/migration.rs +++ b/crates/defguard_setup/src/migration.rs @@ -60,7 +60,7 @@ use crate::handlers::{ pub struct MigrationWebapp { pub router: Router, _event_rx: mpsc::UnboundedReceiver, - _wireguard_rx: broadcast::Receiver, + _gateway_rx: broadcast::Receiver, _proxy_control_rx: mpsc::Receiver, } @@ -72,7 +72,7 @@ pub fn build_migration_webapp( let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); let (webhook_tx, webhook_rx) = mpsc::unbounded_channel::(); let (event_tx, event_rx) = mpsc::unbounded_channel::(); - let (wireguard_tx, wireguard_rx) = broadcast::channel::(64); + let (gateway_tx, gateway_rx) = broadcast::channel::(64); let (web_reload_tx, _web_reload_rx) = broadcast::channel::<()>(8); let (proxy_control_tx, proxy_control_rx) = mpsc::channel(32); let incompatible_components = Arc::new(RwLock::new(IncompatibleComponents::default())); @@ -86,7 +86,7 @@ pub fn build_migration_webapp( pool.clone(), webhook_tx, webhook_rx, - wireguard_tx.clone(), + gateway_tx.clone(), web_reload_tx, key, failed_logins.clone(), @@ -175,7 +175,7 @@ pub fn build_migration_webapp( MigrationWebapp { router, _event_rx: event_rx, - _wireguard_rx: wireguard_rx, + _gateway_rx: gateway_rx, _proxy_control_rx: proxy_control_rx, } } From 09ccbf558a1a09a9da4132c1f838e25acd402c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 19 May 2026 11:07:37 +0200 Subject: [PATCH 05/10] cleanup --- crates/defguard_common/src/gateway_event.rs | 30 ++++++++++++++----- crates/defguard_common/src/gateway_types.rs | 14 +++++++++ crates/defguard_core/src/appstate.rs | 8 ++--- .../src/enterprise/firewall/mod.rs | 2 +- .../defguard_proto/src/gateway_conversions.rs | 4 +++ 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/crates/defguard_common/src/gateway_event.rs b/crates/defguard_common/src/gateway_event.rs index 55d7c277de..36d8fdedb9 100644 --- a/crates/defguard_common/src/gateway_event.rs +++ b/crates/defguard_common/src/gateway_event.rs @@ -1,3 +1,9 @@ +//! Gateway command types and helpers for communicating with the gateway manager service. +//! +//! [`GatewayCommand`] is the primary type sent from core to the gateway manager over +//! an in-process broadcast channel. The gateway manager converts each command to the +//! appropriate protobuf wire message before forwarding it to the gateway daemon. + use tokio::sync::broadcast::Sender; use tracing::{debug, error}; @@ -12,6 +18,11 @@ use crate::{ gateway_types::{FirewallConfig, WireguardPeer}, }; +/// A command sent from core to the gateway manager service. +/// +/// Each variant instructs the gateway daemon to update its WireGuard state or +/// firewall configuration. Native Rust types are used throughout; conversion to +/// protobuf wire types happens at the serialization boundary in the gateway manager. #[derive(Clone, Debug)] pub enum GatewayCommand { NetworkCreated(Id, WireguardNetwork), @@ -34,19 +45,22 @@ pub enum GatewayCommand { /// Sends a [`GatewayCommand`] to the gateway manager service. /// /// In API handler context prefer `AppState::send_gateway_command`. -pub fn send_gateway_command(event: GatewayCommand, gateway_tx: &Sender) { - debug!("Sending the following command to Gateway Manager: {event:?}"); - if let Err(err) = gateway_tx.send(event) { - error!("Error sending Gateway command: {err}"); +pub fn send_gateway_command(command: GatewayCommand, gateway_tx: &Sender) { + debug!("Sending the following command to Gateway Manager: {command:?}"); + if let Err(err) = gateway_tx.send(command) { + error!("Error sending gateway command: {err}"); } } /// Sends multiple [`GatewayCommand`]s to the gateway manager service. /// /// In API handler context prefer `AppState::send_multiple_gateway_commands`. -pub fn send_multiple_gateway_commands(events: Vec, gateway_tx: &Sender) { - debug!("Sending {} gateway commands", events.len()); - for event in events { - send_gateway_command(event, gateway_tx); +pub fn send_multiple_gateway_commands( + commands: Vec, + gateway_tx: &Sender, +) { + debug!("Sending {} gateway commands", commands.len()); + for command in commands { + send_gateway_command(command, gateway_tx); } } diff --git a/crates/defguard_common/src/gateway_types.rs b/crates/defguard_common/src/gateway_types.rs index 4a3199bbe5..25737def93 100644 --- a/crates/defguard_common/src/gateway_types.rs +++ b/crates/defguard_common/src/gateway_types.rs @@ -1,3 +1,8 @@ +//! Native Rust types for data carried in [`crate::gateway_event::GatewayCommand`] variants. +//! +//! These are domain types; conversion to protobuf wire types happens at the +//! serialization boundary (gateway manager) via `From` impls in `defguard_proto`. + /// A WireGuard peer entry to be configured on a gateway. #[derive(Clone, Debug, PartialEq)] pub struct WireguardPeer { @@ -7,6 +12,7 @@ pub struct WireguardPeer { pub keepalive_interval: Option, } +/// Default firewall action applied to traffic that does not match any rule. #[derive(Clone, Debug, Default, PartialEq)] pub enum FirewallPolicy { #[default] @@ -15,6 +21,7 @@ pub enum FirewallPolicy { Deny, } +/// IP protocol version a firewall rule applies to. #[derive(Clone, Debug, Default, PartialEq)] pub enum IpVersion { #[default] @@ -23,6 +30,7 @@ pub enum IpVersion { Ipv6, } +/// Network protocol matched by a firewall rule. #[derive(Clone, Debug, PartialEq)] pub enum Protocol { Unspecified, @@ -42,6 +50,7 @@ impl From for Protocol { } } +/// An inclusive range of IP addresses. #[derive(Clone, Debug, PartialEq)] pub struct IpRange { pub start: String, @@ -59,18 +68,21 @@ pub enum IpAddress { IpSubnet(String), } +/// An inclusive range of port numbers. #[derive(Clone, Debug, PartialEq)] pub struct PortRange { pub start: u32, pub end: u32, } +/// A single port or an inclusive port range matched by a firewall rule. #[derive(Clone, Debug, PartialEq)] pub enum Port { Single(u32), Range(PortRange), } +/// A single ACL-derived firewall rule to be enforced on a gateway. #[derive(Clone, Debug, PartialEq)] pub struct FirewallRule { pub id: i64, @@ -83,6 +95,7 @@ pub struct FirewallRule { pub ip_version: IpVersion, } +/// Source NAT binding that rewrites the source IP of matching VPN traffic. #[derive(Clone, Debug, PartialEq)] pub struct SnatBinding { pub id: i64, @@ -91,6 +104,7 @@ pub struct SnatBinding { pub comment: Option, } +/// Full firewall configuration to be applied to a gateway location. #[derive(Clone, Debug, Default, PartialEq)] pub struct FirewallConfig { pub default_policy: FirewallPolicy, diff --git a/crates/defguard_core/src/appstate.rs b/crates/defguard_core/src/appstate.rs index 201bd14a9d..a69237d3cd 100644 --- a/crates/defguard_core/src/appstate.rs +++ b/crates/defguard_core/src/appstate.rs @@ -88,14 +88,14 @@ impl AppState { /// Sends given `GatewayCommand` to be handled by gateway manager service. /// Convenience wrapper around [`send_gateway_command`] - pub fn send_gateway_command(&self, event: GatewayCommand) { - send_gateway_command(event, &self.gateway_tx); + pub fn send_gateway_command(&self, command: GatewayCommand) { + send_gateway_command(command, &self.gateway_tx); } /// Sends multiple commands to be handled by gateway manager service. /// Convenience wrapper around [`send_multiple_gateway_commands`] - pub fn send_multiple_gateway_commands(&self, events: Vec) { - send_multiple_gateway_commands(events, &self.gateway_tx); + pub fn send_multiple_gateway_commands(&self, commands: Vec) { + send_multiple_gateway_commands(commands, &self.gateway_tx); } /// Sends event to the main event router diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index af18056f79..1e57617280 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -939,7 +939,7 @@ async fn generate_user_snat_bindings_for_location( continue; } - // create the SNAT binding proto + // create the SNAT binding let snat_binding = SnatBinding { id: user_binding.id, source_addrs, diff --git a/crates/defguard_proto/src/gateway_conversions.rs b/crates/defguard_proto/src/gateway_conversions.rs index 7c09f84057..a2e92215c2 100644 --- a/crates/defguard_proto/src/gateway_conversions.rs +++ b/crates/defguard_proto/src/gateway_conversions.rs @@ -1,3 +1,7 @@ +//! Conversions from [`defguard_common::gateway_types`] domain types to their +//! protobuf counterparts. These are the sole serialization boundary between +//! the native gateway command representation and the wire protocol. + use defguard_common::gateway_types::{ FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, PortRange, Protocol, SnatBinding, WireguardPeer, From 1e201d2c826c011a33d00b36904c4c360d90f54d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 21 May 2026 21:45:23 +0200 Subject: [PATCH 06/10] post-merge cleanup --- .../src/enterprise/firewall/mod.rs | 2 +- .../defguard_gateway_manager/src/handler.rs | 3 +- .../src/tests/gateway_manager/handler.rs | 1 - .../src/tests/gateway_manager/handler/mfa.rs | 87 +++++++++++++++++-- .../tests/gateway_manager/handler/support.rs | 4 +- 5 files changed, 85 insertions(+), 12 deletions(-) diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index c87a6dc5a4..aba2ea4b78 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -843,7 +843,7 @@ pub async fn try_get_location_firewall_config( generate_firewall_rules_from_acls(location.id, location_acls, &mut *conn).await?; let snat_bindings = generate_user_snat_bindings_for_location(location.id, &mut *conn).await?; let firewall_config = FirewallConfig { - default_policy: default_policy, + default_policy, rules: firewall_rules, snat_bindings, }; diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index f872acb30a..78d9fe6a24 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -1118,8 +1118,7 @@ mod tests { use tokio::sync::{broadcast, mpsc::unbounded_channel, watch}; use super::{ - FirewallConfig, GatewayHandler, GatewayUpdatesHandler, WireguardPeer, - try_protos_into_stats_message, + GatewayHandler, GatewayUpdatesHandler, WireguardPeer, try_protos_into_stats_message, }; fn test_network(location_mfa_mode: LocationMfaMode) -> WireguardNetwork { diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs index ea1d9fc3d2..6680f6ba07 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -2,7 +2,6 @@ mod support; use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; -use defguard_common::gateway_types::WireguardPeer; use defguard_core::grpc::GatewayCommand; use defguard_proto::gateway::{UpdateType, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs index 96341e529b..9b17a67c63 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs @@ -1,5 +1,80 @@ #[sqlx::test] -async fn test_matching_location_mfa_session_authorized_produces_peer_create( +async fn test_matching_location_posture_vpn_session_authorized_produces_peer_create( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + enable_linux_posture_for_network(&context.pool, &context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_posture_device_for_current_network( + &context, + "posture-authorized-device", + "qk9cOvJ5pvyR0pL7B6V8DKtYlD1BHRY9QwIkEOROHjA=", + "10.10.0.42", + "posture-authorized-preshared-key", + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayCommand::VpnSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast posture VPN session authorized event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "qk9cOvJ5pvyR0pL7B6V8DKtYlD1BHRY9QwIkEOROHjA=", + &["10.10.0.42"], + Some("posture-authorized-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_posture_vpn_session_authorized_without_psk_is_skipped( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_linux_posture_for_network(&context.pool, &context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, mut network_device) = create_authorized_posture_device_for_current_network( + &context, + "posture-authorized-without-psk-device", + "qk9cOvJ5pvyR0pL7B6V8DKtYlD1BHRY9QwIkEOROHjB=", + "10.10.0.43", + "posture-authorized-preshared-key", + ) + .await; + network_device.preshared_key = None; + + assert_send_ok!( + context.events_tx().send(GatewayCommand::VpnSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast posture VPN session authorized event without PSK" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_vpn_session_authorized_produces_peer_create( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -23,7 +98,7 @@ async fn test_matching_location_mfa_session_authorized_produces_peer_create( device, network_device, )), - "failed to broadcast MFA session authorized event" + "failed to broadcast VPN session authorized event" ); let outbound = context.mock_gateway_mut().recv_outbound().await; @@ -41,7 +116,7 @@ async fn test_matching_location_mfa_session_authorized_produces_peer_create( } #[sqlx::test] -async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( +async fn test_vpn_session_authorized_with_mismatched_network_id_is_ignored( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -69,7 +144,7 @@ async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( device, network_device, )), - "failed to broadcast mismatched MFA session authorized event" + "failed to broadcast mismatched VPN session authorized event" ); context.mock_gateway_mut().expect_no_outbound().await; @@ -78,7 +153,7 @@ async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( } #[sqlx::test] -async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( +async fn test_matching_location_vpn_session_deauthorized_produces_peer_delete( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -102,7 +177,7 @@ async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( context.network.id, device, )), - "failed to broadcast MFA session disconnected event" + "failed to broadcast VPN session deauthorized event" ); let outbound = context.mock_gateway_mut().recv_outbound().await; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index eab99935f0..6b61405d65 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -10,8 +10,8 @@ use defguard_common::db::{ }, }; use defguard_common::gateway_types::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, - PortRange as GwPortRange, Protocol as GwProtocol, SnatBinding, + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, + Protocol as GwProtocol, SnatBinding, }; use defguard_core::{ enterprise::db::models::device_posture::{ From e0c57f406dab61abb8ef9d50712deb284cb240a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 22 May 2026 09:31:54 +0200 Subject: [PATCH 07/10] update query data --- ...72929d6a612846acf4a5489c36a04cb57de59.json | 22 +++++++++++++++++++ ...ab52c14293fa64d260db812c89ed43a5a7cc8.json | 15 +++++++++++++ ...68ebe8d0b64fe8bf6dafb9d92768397fbcecd.json | 15 +++++++++++++ ...4f06bc8e8fce23ee336d64f8608f0e7eefc8c.json | 14 ++++++++++++ ...1a2f9a2abcb8d044ca90a536c6e567e320d69.json | 15 +++++++++++++ 5 files changed, 81 insertions(+) create mode 100644 .sqlx/query-0ad599ef120ecc02b030bd1276172929d6a612846acf4a5489c36a04cb57de59.json create mode 100644 .sqlx/query-1bb3c8ecbd6500717d639678e08ab52c14293fa64d260db812c89ed43a5a7cc8.json create mode 100644 .sqlx/query-8662e89e31b58def986e1a155cd68ebe8d0b64fe8bf6dafb9d92768397fbcecd.json create mode 100644 .sqlx/query-8e70e158b41d640eba1333c0aea4f06bc8e8fce23ee336d64f8608f0e7eefc8c.json create mode 100644 .sqlx/query-e0a5e5060afa2628112ae29b8a51a2f9a2abcb8d044ca90a536c6e567e320d69.json diff --git a/.sqlx/query-0ad599ef120ecc02b030bd1276172929d6a612846acf4a5489c36a04cb57de59.json b/.sqlx/query-0ad599ef120ecc02b030bd1276172929d6a612846acf4a5489c36a04cb57de59.json new file mode 100644 index 0000000000..ce67e42930 --- /dev/null +++ b/.sqlx/query-0ad599ef120ecc02b030bd1276172929d6a612846acf4a5489c36a04cb57de59.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT recovery_codes FROM \"user\" WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "recovery_codes", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false + ] + }, + "hash": "0ad599ef120ecc02b030bd1276172929d6a612846acf4a5489c36a04cb57de59" +} diff --git a/.sqlx/query-1bb3c8ecbd6500717d639678e08ab52c14293fa64d260db812c89ed43a5a7cc8.json b/.sqlx/query-1bb3c8ecbd6500717d639678e08ab52c14293fa64d260db812c89ed43a5a7cc8.json new file mode 100644 index 0000000000..8d314a5faa --- /dev/null +++ b/.sqlx/query-1bb3c8ecbd6500717d639678e08ab52c14293fa64d260db812c89ed43a5a7cc8.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE session SET expires = $1 WHERE id = $2", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Timestamp", + "Text" + ] + }, + "nullable": [] + }, + "hash": "1bb3c8ecbd6500717d639678e08ab52c14293fa64d260db812c89ed43a5a7cc8" +} diff --git a/.sqlx/query-8662e89e31b58def986e1a155cd68ebe8d0b64fe8bf6dafb9d92768397fbcecd.json b/.sqlx/query-8662e89e31b58def986e1a155cd68ebe8d0b64fe8bf6dafb9d92768397fbcecd.json new file mode 100644 index 0000000000..82b7e023c7 --- /dev/null +++ b/.sqlx/query-8662e89e31b58def986e1a155cd68ebe8d0b64fe8bf6dafb9d92768397fbcecd.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO group_user (user_id, group_id) VALUES ($1, $2)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "8662e89e31b58def986e1a155cd68ebe8d0b64fe8bf6dafb9d92768397fbcecd" +} diff --git a/.sqlx/query-8e70e158b41d640eba1333c0aea4f06bc8e8fce23ee336d64f8608f0e7eefc8c.json b/.sqlx/query-8e70e158b41d640eba1333c0aea4f06bc8e8fce23ee336d64f8608f0e7eefc8c.json new file mode 100644 index 0000000000..d62bfc64a5 --- /dev/null +++ b/.sqlx/query-8e70e158b41d640eba1333c0aea4f06bc8e8fce23ee336d64f8608f0e7eefc8c.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO group_user (group_id, user_id) VALUES (1, $1)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "8e70e158b41d640eba1333c0aea4f06bc8e8fce23ee336d64f8608f0e7eefc8c" +} diff --git a/.sqlx/query-e0a5e5060afa2628112ae29b8a51a2f9a2abcb8d044ca90a536c6e567e320d69.json b/.sqlx/query-e0a5e5060afa2628112ae29b8a51a2f9a2abcb8d044ca90a536c6e567e320d69.json new file mode 100644 index 0000000000..6fff00d5e4 --- /dev/null +++ b/.sqlx/query-e0a5e5060afa2628112ae29b8a51a2f9a2abcb8d044ca90a536c6e567e320d69.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO group_user (group_id, user_id) VALUES (1, $1), (1, $2)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "e0a5e5060afa2628112ae29b8a51a2f9a2abcb8d044ca90a536c6e567e320d69" +} From 4b9b256224de7365ef73eaffff7b356388669209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 22 May 2026 09:43:40 +0200 Subject: [PATCH 08/10] revert unnecessary submodule change --- web/src/shared/defguard-ui | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/shared/defguard-ui b/web/src/shared/defguard-ui index 034614b3dc..fae350a2ba 160000 --- a/web/src/shared/defguard-ui +++ b/web/src/shared/defguard-ui @@ -1 +1 @@ -Subproject commit 034614b3dc8c43fae41c2c79e9c3993d2f732eb8 +Subproject commit fae350a2baadddb1edffce0ff30408283cdd8699 From 6428bd2ccc70c7de5eff553b055000160652fe27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 22 May 2026 11:53:22 +0200 Subject: [PATCH 09/10] fix tests --- crates/defguard_session_manager/src/lib.rs | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index f012dc51c3..a7969ca59c 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -219,6 +219,9 @@ impl SessionManager { "[{index}/{locations_count}] Disconnecting inactive sessions in location {location}" ); + let session_authorization_required = + location.mfa_enabled() || location.has_postures(&mut *transaction).await?; + // get all connected sessions which have become inactive let inactive_sessions = VpnClientSession::get_all_inactive_for_location(&mut *transaction, &location) @@ -234,13 +237,18 @@ impl SessionManager { "Disconnecting inactive session for user {}, device {} in location {location}", session.user_id, session.device_id ); - self.disconnect_session(&mut transaction, session, &location) - .await?; + self.disconnect_session( + &mut transaction, + session, + &location, + session_authorization_required, + ) + .await?; } // get all sessions which were created but have never connected - // this is only relevant for MFA locations - if location.mfa_enabled() { + // this is only relevant for locations that authorize peers at runtime + if session_authorization_required { let unused_sessions = VpnClientSession::get_never_connected(&mut *transaction, &location).await?; @@ -254,8 +262,13 @@ impl SessionManager { "Disconnecting never connected session for user {}, device {} in location {location}", session.user_id, session.device_id ); - self.disconnect_session(&mut transaction, session, &location) - .await?; + self.disconnect_session( + &mut transaction, + session, + &location, + session_authorization_required, + ) + .await?; } } } @@ -274,6 +287,7 @@ impl SessionManager { transaction: &mut PgConnection, mut session: VpnClientSession, location: &WireguardNetwork, + session_authorization_required: bool, ) -> Result<(), SessionManagerError> { let disconnect_timestamp = Utc::now().naive_utc(); let is_connected = session.connected_at.is_some(); @@ -294,8 +308,8 @@ impl SessionManager { session.device_id, ))?; - // remove peers from GW for MFA locations - if location.mfa_enabled() { + // remove peers from GW for locations that authorize peers at runtime + if session_authorization_required { self.send_peer_disconnect_message(location, &device)?; } From aa31144ad61208cb42b770e983e1f75580065f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 22 May 2026 11:58:41 +0200 Subject: [PATCH 10/10] merge use statements --- .../enterprise/firewall/tests/destination.rs | 13 ++++---- .../src/enterprise/firewall/tests/gh1868.rs | 14 ++++---- .../src/enterprise/firewall/tests/mod.rs | 22 +++++++------ .../src/location_management/allowed_peers.rs | 6 ++-- .../defguard_gateway_manager/src/handler.rs | 2 +- crates/defguard_gateway_manager/src/lib.rs | 2 +- .../src/tests/common/mod.rs | 2 +- .../src/tests/gateway_manager/handler.rs | 2 +- .../tests/gateway_manager/handler/support.rs | 32 +++++++++---------- crates/defguard_proxy_manager/src/lib.rs | 2 +- .../src/tests/common/mod.rs | 16 ++++++---- .../tests/proxy_manager/handler/enrollment.rs | 10 +++--- .../src/tests/proxy_manager/handler/mfa.rs | 3 +- crates/defguard_session_manager/src/error.rs | 3 +- crates/defguard_session_manager/src/lib.rs | 2 +- .../tests/session_manager/sessions.rs | 2 +- crates/defguard_setup/src/migration.rs | 5 +-- 17 files changed, 74 insertions(+), 64 deletions(-) diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index 478fdbf986..a53c91ec87 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -1,16 +1,17 @@ -use defguard_common::gateway_types::{ - FirewallPolicy, IpAddress, IpRange, Port, PortRange as GwPortRange, Protocol as GwProtocol, +use defguard_common::{ + db::{NoId, models::WireguardNetwork, setup_pool}, + gateway_types::{ + FirewallPolicy, IpAddress, IpRange, Port, PortRange as GwPortRange, Protocol as GwProtocol, + }, }; use defguard_proto::enterprise::firewall::Protocol as ProtoProtocol; +use rand::thread_rng; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, ops::RangeInclusive, }; -use defguard_common::db::{NoId, models::WireguardNetwork, setup_pool}; -use rand::thread_rng; -use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; - use super::{create_acl_rule, create_test_users_and_devices, set_test_license_business}; use crate::enterprise::{ db::models::acl::{ diff --git a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs index 4774cd4da9..96b1b96e4d 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs @@ -1,12 +1,14 @@ -use defguard_common::gateway_types::{FirewallPolicy, IpVersion}; +use defguard_common::{ + db::{ + Id, NoId, + models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, + setup_pool, + }, + gateway_types::{FirewallPolicy, IpVersion}, +}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chrono::NaiveDateTime; -use defguard_common::db::{ - Id, NoId, - models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, - setup_pool, -}; use ipnetwork::IpNetwork; use rand::{Rng, rngs::ThreadRng, thread_rng}; use sqlx::{ diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index 1801930b1c..e25397c8ff 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -1,17 +1,19 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chrono::NaiveDateTime; -use defguard_common::db::{ - Id, NoId, - models::{ - Device, DeviceType, WireguardNetwork, device::WireguardNetworkDevice, group::Group, - user::User, +use defguard_common::{ + db::{ + Id, NoId, + models::{ + Device, DeviceType, WireguardNetwork, device::WireguardNetworkDevice, group::Group, + user::User, + }, + setup_pool, + }, + gateway_types::{ + FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as GwPortRange, + Protocol as GwProtocol, }, - setup_pool, -}; -use defguard_common::gateway_types::{ - FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as GwPortRange, - Protocol as GwProtocol, }; use defguard_proto::enterprise::firewall::Protocol as ProtoProtocol; use ipnetwork::IpNetwork; diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 9dc5283502..24d2629d2f 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -1,5 +1,7 @@ -use defguard_common::db::{Id, models::WireguardNetwork}; -use defguard_common::gateway_types::WireguardPeer; +use defguard_common::{ + db::{Id, models::WireguardNetwork}, + gateway_types::WireguardPeer, +}; use sqlx::{PgConnection, query}; use crate::grpc::should_prevent_service_location_usage; diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 78d9fe6a24..5f9ca22579 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -1111,7 +1111,7 @@ mod tests { }, setup_pool, }; - use defguard_core::grpc::GatewayCommand; + use defguard_common::gateway_event::GatewayCommand; use defguard_proto::gateway::{Configuration, PeerStats, core_response}; use prost_types::Timestamp; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index d590549617..f5691da936 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -9,9 +9,9 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, }; -use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_proto::gateway::gateway_client::GatewayClient; diff --git a/crates/defguard_gateway_manager/src/tests/common/mod.rs b/crates/defguard_gateway_manager/src/tests/common/mod.rs index 6ce372d4aa..e59f72c571 100644 --- a/crates/defguard_gateway_manager/src/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/src/tests/common/mod.rs @@ -11,7 +11,6 @@ use std::{ time::Duration, }; -use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{ Id, NoId, @@ -20,6 +19,7 @@ use defguard_common::{ }, setup_pool, }, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_proto::gateway::{CoreRequest, CoreResponse, PeerStats, core_request, gateway_server}; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs index 6680f6ba07..97986d1e39 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -2,7 +2,7 @@ mod support; use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; -use defguard_core::grpc::GatewayCommand; +use defguard_common::gateway_event::GatewayCommand; use defguard_proto::gateway::{UpdateType, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index 6b61405d65..cc28d7f361 100644 --- a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -1,23 +1,23 @@ use std::net::IpAddr; -use defguard_common::db::{ - Id, NoId, - models::{ - device::{Device, DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice}, - user::User, - vpn_client_session::VpnClientSession, - wireguard::{LocationMfaMode, WireguardNetwork}, +use defguard_common::gateway_event::GatewayCommand; +use defguard_common::{ + db::{ + Id, NoId, + models::{ + device::{Device, DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice}, + user::User, + vpn_client_session::VpnClientSession, + wireguard::{LocationMfaMode, WireguardNetwork}, + }, }, -}; -use defguard_common::gateway_types::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, - Protocol as GwProtocol, SnatBinding, -}; -use defguard_core::{ - enterprise::db::models::device_posture::{ - DevicePosture, DevicePostureLocation, DevicePostureOsRule, OsType, + gateway_types::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, + Protocol as GwProtocol, SnatBinding, }, - grpc::GatewayCommand, +}; +use defguard_core::enterprise::db::models::device_posture::{ + DevicePosture, DevicePostureLocation, DevicePostureOsRule, OsType, }; use defguard_proto::{ enterprise::firewall::FirewallConfig as ProtoFirewallConfig, diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index 394517de68..baded624b5 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -7,9 +7,9 @@ use std::{ use std::{path::PathBuf, str::FromStr, sync::Mutex as StdMutex}; use axum_extra::extract::cookie::Key; -use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{Id, models::proxy::Proxy}, + gateway_event::GatewayCommand, types::proxy::ProxyControlMessage, }; use defguard_core::{ diff --git a/crates/defguard_proxy_manager/src/tests/common/mod.rs b/crates/defguard_proxy_manager/src/tests/common/mod.rs index 2bef79c9a0..a3f4601f86 100644 --- a/crates/defguard_proxy_manager/src/tests/common/mod.rs +++ b/crates/defguard_proxy_manager/src/tests/common/mod.rs @@ -16,15 +16,17 @@ use axum::{ response::Json, }; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; -use defguard_common::db::{ - Id, NoId, - models::{ - proxy::Proxy, - settings::{Settings, initialize_current_settings}, +use defguard_common::{ + db::{ + Id, NoId, + models::{ + proxy::Proxy, + settings::{Settings, initialize_current_settings}, + }, + setup_pool, }, - setup_pool, + gateway_event::GatewayCommand, }; -use defguard_common::gateway_event::GatewayCommand; use defguard_core::events::BidiStreamEvent; use defguard_proto::proxy::{ AcmeChallenge, AcmeIssueEvent, CoreRequest, CoreResponse, InitialInfo, core_response, diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs index 17e8cbbb55..80477829e8 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/enrollment.rs @@ -1,8 +1,10 @@ -use defguard_common::db::models::{ - Device, Settings, User, biometric_auth::BiometricAuth, polling_token::PollingToken, - settings::update_current_settings, +use defguard_common::{ + db::models::{ + Device, Settings, User, biometric_auth::BiometricAuth, polling_token::PollingToken, + settings::update_current_settings, + }, + gateway_event::GatewayCommand, }; -use defguard_common::gateway_event::GatewayCommand; use defguard_core::events::{BidiStreamEventType, EnrollmentEvent}; use defguard_proto::{ client_types::{ExistingDevice, MfaMethod, NewDevice, RegisterMobileAuthRequest}, diff --git a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs index 6cc6d831cb..ad1c0f186c 100644 --- a/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs +++ b/crates/defguard_proxy_manager/src/tests/proxy_manager/handler/mfa.rs @@ -1,5 +1,4 @@ -use defguard_common::db::Id; -use defguard_common::gateway_event::GatewayCommand; +use defguard_common::{db::Id, gateway_event::GatewayCommand}; use defguard_proto::{ client_types::{ClientMfaFinishRequest, ClientMfaStartRequest, MfaMethod}, proxy::{AwaitRemoteMfaFinishRequest, CoreRequest, core_request, core_response}, diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 76aeb2f70e..44327f69b5 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,5 +1,4 @@ -use defguard_common::db::Id; -use defguard_common::gateway_event::GatewayCommand; +use defguard_common::{db::Id, gateway_event::GatewayCommand}; use thiserror::Error; use tokio::sync::{broadcast::error::SendError as BroadcastSendError, mpsc::error::SendError}; diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index a7969ca59c..9aa1800895 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -1,5 +1,4 @@ use chrono::Utc; -use defguard_common::gateway_event::GatewayCommand; use defguard_common::{ db::{ Id, @@ -8,6 +7,7 @@ use defguard_common::{ vpn_client_session::{VpnClientSession, VpnClientSessionState}, }, }, + gateway_event::GatewayCommand, messages::peer_stats_update::PeerStatsUpdate, }; use sqlx::{PgConnection, PgPool}; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index 7abe0ee04e..bbbd1e9e4e 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -8,7 +8,7 @@ use defguard_common::db::{ }, setup_pool, }; -use defguard_core::grpc::GatewayCommand; +use defguard_common::gateway_event::GatewayCommand; use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; diff --git a/crates/defguard_setup/src/migration.rs b/crates/defguard_setup/src/migration.rs index 10cc85d0c2..8c7c76b328 100644 --- a/crates/defguard_setup/src/migration.rs +++ b/crates/defguard_setup/src/migration.rs @@ -10,8 +10,9 @@ use axum::{ serve, }; use axum_extra::extract::cookie::Key; -use defguard_common::gateway_event::GatewayCommand; -use defguard_common::{VERSION, db::models::Settings, types::proxy::ProxyControlMessage}; +use defguard_common::{ + VERSION, db::models::Settings, gateway_event::GatewayCommand, types::proxy::ProxyControlMessage, +}; use defguard_core::{ appstate::AppState, auth::failed_login::FailedLoginMap,