diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aa6cae3..98d165a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -32,7 +32,7 @@ jobs: run: cargo clippy -- -D warnings - name: Run tests with debugs if: matrix.os == 'ubuntu-20.04' || matrix.os == 'macos-latest' - run: RUST_LOG=debug cargo test --verbose + run: RUST_LOG=debug cargo test --verbose --features plugins - name: Run tests on Windows if: matrix.os == 'windows-latest' - run: cargo test + run: cargo test --features plugins diff --git a/Cargo.toml b/Cargo.toml index b10befa..c67b07f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ description = "mDNS Service Discovery library with no async runtime dependency" async = ["flume/async"] logging = ["log"] default = ["async", "logging"] +plugins = [] [dependencies] flume = { version = "0.11", default-features = false } # channel between threads diff --git a/src/lib.rs b/src/lib.rs index 9f1af4d..e229b33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -148,6 +148,8 @@ mod log { mod dns_cache; mod dns_parser; mod error; +#[cfg(feature = "plugins")] +mod plugin; mod service_daemon; mod service_info; @@ -158,5 +160,8 @@ pub use service_daemon::{ }; pub use service_info::{AsIpAddrs, IntoTxtProperties, ServiceInfo, TxtProperties, TxtProperty}; +#[cfg(feature = "plugins")] +pub use plugin::PluginCommand; + /// A handler to receive messages from [ServiceDaemon]. Re-export from `flume` crate. pub use flume::Receiver; diff --git a/src/plugin.rs b/src/plugin.rs new file mode 100644 index 0000000..71d0d55 --- /dev/null +++ b/src/plugin.rs @@ -0,0 +1,15 @@ +use crate::ServiceInfo; +use flume::Sender; +use std::collections::HashMap; +use std::sync::Arc; + +/// Commands to be implemented by plugins +#[derive(Debug)] +pub enum PluginCommand { + Registered, + + /// Command to fetch services that are currently provided by the plugin + ListServices(Sender>>), + + Exit(Sender<()>), +} diff --git a/src/service_daemon.rs b/src/service_daemon.rs index dc6fae8..b0a3178 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -30,6 +30,8 @@ // corresponds to a set of DNS Resource Records. #[cfg(feature = "logging")] use crate::log::{debug, error, warn}; +#[cfg(feature = "plugins")] +use crate::plugin::PluginCommand; use crate::{ dns_cache::DnsCache, dns_parser::{ @@ -46,6 +48,7 @@ use flume::{bounded, Sender, TrySendError}; use if_addrs::{IfAddr, Interface}; use polling::Poller; use socket2::{SockAddr, Socket}; +use std::sync::Arc; use std::{ cmp::{self, Reverse}, collections::{BinaryHeap, HashMap, HashSet}, @@ -278,6 +281,21 @@ impl ServiceDaemon { self.send_cmd(Command::StopResolveHostname(hostname.to_string())) } + /// Registers a plugin provided by the library consumer, to support dynamic mDNS resolution. + /// + /// Please be aware that this resolution should be relatively consistent, e.g. configured + /// externally. + /// + /// If feature `plugins` is enabled, the daemon will send requests to the plugins + /// using the flume channel sender for which needs to be provided as `pc_send`. + /// + /// Please note that enabling the feature enables fetching the plugin-provided services + /// on *every* request, so this is disabled by default due to extra overhead. + #[cfg(feature = "plugins")] + pub fn register_plugin(&self, name: String, pc_send: Sender) -> Result<()> { + self.send_cmd(Command::RegisterPlugin(name, pc_send)) + } + /// Registers a service provided by this host. /// /// If `service_info` has no addresses yet and its `addr_auto` is enabled, @@ -401,9 +419,29 @@ impl ServiceDaemon { fn daemon_thread(signal_sock: UdpSocket, poller: Poller, receiver: Receiver) { let zc = Zeroconf::new(signal_sock, poller); + #[cfg(feature = "plugins")] + let plugin_senders = zc.plugin_senders.clone(); + if let Some(cmd) = Self::run(zc, receiver) { match cmd { Command::Exit(resp_s) => { + #[cfg(feature = "plugins")] + for (plugin, sender) in plugin_senders.clone() { + let (p_send, p_recv) = bounded(1); + + match sender.send(PluginCommand::Exit(p_send)) { + Ok(()) => {} + Err(e) => { + error!("failed to send plugin exit command: {}, {}", plugin, e) + } + }; + + match p_recv.recv() { + Ok(()) => debug!("plugin {} exited successfully", plugin), + Err(e) => error!("plugin {} failed to exit: {}", plugin, e), + } + } + // It is guaranteed that the receiver already dropped, // i.e. the daemon command channel closed. if let Err(e) = resp_s.send(DaemonStatus::Shutdown) { @@ -658,6 +696,11 @@ impl ServiceDaemon { zc.monitors.push(resp_s); } + #[cfg(feature = "plugins")] + Command::RegisterPlugin(name, papi_send) => { + zc.register_plugin(name, papi_send); + } + Command::SetOption(daemon_opt) => { zc.process_set_option(daemon_opt); } @@ -930,6 +973,9 @@ struct Zeroconf { /// Service instances that are pending for resolving SRV and TXT. pending_resolves: HashSet, + + #[cfg(feature = "plugins")] + plugin_senders: HashMap>, } impl Zeroconf { @@ -979,6 +1025,8 @@ impl Zeroconf { timers, status, pending_resolves: HashSet::new(), + #[cfg(feature = "plugins")] + plugin_senders: HashMap::new(), } } @@ -1875,12 +1923,26 @@ impl Zeroconf { // See https://datatracker.ietf.org/doc/html/rfc6763#section-9 const META_QUERY: &str = "_services._dns-sd._udp.local."; + let services_by_plugins = self.list_plugin_services(); + + let mut all_services: HashMap<&String, &ServiceInfo> = HashMap::new(); + + for (k, v) in &self.my_services { + all_services.insert(k, v); + } + + for (_plugin, services) in &services_by_plugins { + for (k, v) in services { + all_services.insert(k, v); + } + } + for question in msg.questions.iter() { debug!("query question: {:?}", &question); let qtype = question.entry.ty; if qtype == TYPE_PTR { - for service in self.my_services.values() { + for service in all_services.values() { if question.entry.name == service.get_type() || service .get_subtype() @@ -1906,7 +1968,7 @@ impl Zeroconf { } } else { if qtype == TYPE_A || qtype == TYPE_AAAA || qtype == TYPE_ANY { - for service in self.my_services.values() { + for service in all_services.values() { if service.get_hostname().to_lowercase() == question.entry.name.to_lowercase() { @@ -1940,7 +2002,7 @@ impl Zeroconf { } let name_to_find = question.entry.name.to_lowercase(); - let service = match self.my_services.get(&name_to_find) { + let service = match all_services.get(&name_to_find) { Some(s) => s, None => continue, }; @@ -2315,6 +2377,79 @@ impl Zeroconf { self.increase_counter(Counter::CacheRefreshSRV, query_srv_count); self.increase_counter(Counter::CacheRefreshAddr, query_addr_count); } + + // Returns (Plugin, Map) + #[cfg(feature = "plugins")] + fn list_plugin_services(&self) -> Vec<(String, HashMap>)> { + let mut output = vec![]; + + for key in self.plugin_senders.keys() { + output.push((key.clone(), self.list_plugin_services_for(key))); + } + + output + } + + #[cfg(not(feature = "plugins"))] + fn list_plugin_services(&self) -> Vec<(String, HashMap>)> { + vec![] + } + + #[cfg(feature = "plugins")] + fn list_plugin_services_for(&self, plugin: &str) -> HashMap> { + let (r_send, r_recv) = bounded(1); + + let p_send = match self.plugin_senders.get(plugin) { + None => { + warn!("Could not find plugin {}", plugin); + + return HashMap::new(); + } + Some(p_send) => p_send, + }; + + match p_send.send(PluginCommand::ListServices(r_send)) { + Ok(()) => {} + Err(e) => warn!("Failed to send ListServices command: {}", e), + } + + r_recv.recv().unwrap_or_else(|e| { + warn!("Could not receive service list: {}", e); + + HashMap::new() + }) + } + + #[cfg(feature = "plugins")] + fn register_plugin(&mut self, name: String, papi_send: Sender) { + if self.plugin_senders.contains_key(&name) { + let old_send = self.plugin_senders.get(&name).unwrap(); + + let (exit_send, exit_recv) = bounded(1); + + match old_send.send(PluginCommand::Exit(exit_send)) { + Ok(()) => debug!("Requested old plugin exit"), + Err(e) => warn!("Failed to send exit command to a plugin: {}", e), + } + + match exit_recv.recv_timeout(Duration::from_secs(1)) { + Ok(()) => debug!("The old plugin exited"), + Err(e) => warn!("Old plugin's exit timed out: {}", e), + } + } + + debug!("Registered a new plugin: {}", name); + + self.plugin_senders.insert(name, papi_send.clone()); + + match papi_send.send(PluginCommand::Registered) { + Ok(()) => {} + Err(e) => warn!( + "Failed to send a registration notification to a plugin: {}", + e + ), + }; + } } /// All possible events sent to the client from the daemon @@ -2410,6 +2545,9 @@ enum Command { SetOption(DaemonOption), + #[cfg(feature = "plugins")] + RegisterPlugin(String, Sender), + Exit(Sender), } @@ -2429,6 +2567,8 @@ impl fmt::Display for Command { Self::StopResolveHostname(_) => write!(f, "Command StopResolveHostname"), Self::Unregister(_, _) => write!(f, "Command Unregister"), Self::UnregisterResend(_, _) => write!(f, "Command UnregisterResend"), + #[cfg(feature = "plugins")] + Self::RegisterPlugin(name, _) => write!(f, "Command RegisterPlugin: {}", name), Self::Resolve(_, _) => write!(f, "Command Resolve"), } } diff --git a/tests/mdns_test.rs b/tests/mdns_test.rs index 620f53b..d683073 100644 --- a/tests/mdns_test.rs +++ b/tests/mdns_test.rs @@ -1,11 +1,19 @@ +#[cfg(feature = "plugins")] +use flume::bounded; use if_addrs::{IfAddr, Interface}; +#[cfg(feature = "plugins")] +use mdns_sd::PluginCommand; use mdns_sd::{ DaemonEvent, DaemonStatus, HostnameResolutionEvent, IfKind, IntoTxtProperties, ServiceDaemon, ServiceEvent, ServiceInfo, UnregisterStatus, }; use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +#[cfg(feature = "plugins")] +use std::sync::Arc; use std::thread::sleep; +#[cfg(feature = "plugins")] +use std::thread::spawn; use std::time::{Duration, SystemTime}; // use test_log::test; // commented out for debugging a flaky test in CI. @@ -1456,6 +1464,98 @@ fn test_domain_suffix_in_browse() { mdns_client.shutdown().unwrap(); } +#[test] +#[cfg(feature = "plugins")] +fn plugin_support_test() { + let mdns_server = ServiceDaemon::new().expect("failed to create mdns server"); + let mdns_client = ServiceDaemon::new().expect("Failed to create mdns client"); + + let mut ips = vec![]; + + for i in my_ip_interfaces() { + mdns_server.enable_interface(&i.name).unwrap(); + mdns_client.enable_interface(&i.name).unwrap(); + ips.push(i.ip().to_string()); + } + + let (papi_send, papi_recv) = bounded(100); + + mdns_server + .register_plugin("test".to_string(), papi_send) + .expect("failed to register plugin"); + + let cmd_registered = papi_recv.recv().expect("failed to receive command"); + + match cmd_registered { + PluginCommand::Registered => {} + _ => panic!("Wrong plugin command received"), + }; + + spawn(move || { + let service_info_arc = Arc::new({ + let service_type = "somehost._tcp.local."; + let instance_name = "somehost"; + let ip = ips.join(","); + let host_name = "somehost.local."; + let port = 5200; + let properties = [("property_1", "test"), ("property_2", "1234")]; + + ServiceInfo::new( + service_type, + instance_name, + host_name, + ip, + port, + &properties[..], + ) + .unwrap() + }); + + loop { + let cmd = papi_recv.recv(); + + match cmd { + Ok(PluginCommand::Registered) => {} + Ok(PluginCommand::Exit(sender)) => { + sender.send(()).unwrap(); + return; + } + Ok(PluginCommand::ListServices(sender)) => { + let mut map = HashMap::new(); + + map.insert("somehost.local.".to_string(), service_info_arc.clone()); + + sender.send(map).unwrap(); + } + Err(_) => return, + } + } + }); + + let browse_chan = mdns_client + .resolve_hostname("somehost.local.", None) + .unwrap(); + + let mut resolved = false; + + while let Ok(event) = browse_chan.recv() { + match event { + HostnameResolutionEvent::AddressesFound(host, _addresses) => { + resolved = true; + println!("Resolved a service of {}", &host); + break; + } + other => { + println!("Received event {:?}", other); + } + } + } + assert!(resolved); + + mdns_server.shutdown().unwrap(); + mdns_client.shutdown().unwrap(); +} + /// A helper function to include a timestamp for println. fn timed_println(msg: String) { let now = SystemTime::now();