diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 0cb0484e65..a72d16d1d0 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -84,6 +84,7 @@ pub mod type_printer; pub mod types; pub mod update; pub mod variable; +pub mod websocketprovider; pub mod worker_thread; pub mod workflow; diff --git a/rust/src/websocketprovider.rs b/rust/src/websocketprovider.rs new file mode 100644 index 0000000000..6619b79b69 --- /dev/null +++ b/rust/src/websocketprovider.rs @@ -0,0 +1,496 @@ +use core::{ffi, mem, ptr}; + +use binaryninjacore_sys::*; + +use crate::rc::{Array, CoreArrayProvider, CoreArrayProviderInner, Ref, RefCountable}; +use crate::string::{BnStrCompatible, BnString}; + +#[derive(Clone, Copy, Hash, PartialEq, Eq)] +#[repr(transparent)] +pub struct WebsocketProvider { + handle: ptr::NonNull, +} +unsafe impl Sync for WebsocketProvider {} +unsafe impl Send for WebsocketProvider {} + +impl WebsocketProvider { + pub(crate) unsafe fn from_raw(handle: ptr::NonNull) -> Self { + Self { handle } + } + + pub(crate) unsafe fn ref_from_raw(handle: &*mut BNWebsocketProvider) -> &Self { + assert!(!handle.is_null()); + mem::transmute(handle) + } + + #[allow(clippy::mut_from_ref)] + pub(crate) unsafe fn as_raw(&self) -> &mut BNWebsocketProvider { + &mut *self.handle.as_ptr() + } + + pub fn all() -> Array { + let mut count = 0; + let result = unsafe { BNGetWebsocketProviderList(&mut count) }; + assert!(!result.is_null()); + unsafe { Array::new(result, count, ()) } + } + + pub fn by_name(name: S) -> Option { + let name = name.into_bytes_with_nul(); + let result = + unsafe { BNGetWebsocketProviderByName(name.as_ref().as_ptr() as *const ffi::c_char) }; + ptr::NonNull::new(result).map(|h| unsafe { Self::from_raw(h) }) + } + + pub fn name(&self) -> BnString { + let result = unsafe { BNGetWebsocketProviderName(self.as_raw()) }; + assert!(!result.is_null()); + unsafe { BnString::from_raw(result) } + } + + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `read_handle` - Handles the received data from the Socket + pub fn connect<'a, U, I, K, V, F>( + &'a self, + url: U, + headers: I, + read_handle: &'a mut F, + ) -> Option>> + where + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + F: FnMut(&[u8]) -> bool, + { + let cb_callback = BNWebsocketClientOutputCallbacks { + context: read_handle as *mut _ as *mut ffi::c_void, + connectedCallback: Some(cb_connected_nop), + disconnectedCallback: Some(cb_disconnected_nop), + errorCallback: Some(cb_error_nop), + readCallback: Some(cb_read_closure::), + }; + let client_ptr = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; + connect_client(client_ptr, url, headers, cb_callback) + } + + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// Callbacks will be called **on the thread of the connection**, so be sure + /// to ExecuteOnMainThread any long-running or gui operations in the callbacks. + /// + /// If the connection succeeds, [WebsocketClientCallback::connected] will be called. On normal termination, [WebsocketClientCallback::disconnected] will be called. + /// + /// If the connection succeeds, but later fails, [WebsocketClientCallback::disconnected] will not be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If the connection fails, neither [WebsocketClientCallback::connected] nor [WebsocketClientCallback::disconnected] will be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If [WebsocketClientCallback::connected] or [WebsocketClientCallback::read] return false, the connection will be aborted. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `callback` - Callbacks for various websocket events + pub fn connect_with_callback<'a, U, I, K, V, W>( + &self, + url: U, + headers: I, + callback: &'a mut W, + ) -> Option>> + where + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + W: WebsocketClientCallback, + { + let cb_callback = BNWebsocketClientOutputCallbacks { + context: callback as *mut W as *mut _, + connectedCallback: Some(cb_connected::), + disconnectedCallback: Some(cb_disconnected::), + errorCallback: Some(cb_error::), + readCallback: Some(cb_read::), + }; + let client_ptr = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; + connect_client(client_ptr, url, headers, cb_callback) + } +} + +impl CoreArrayProvider for WebsocketProvider { + type Raw = *mut BNWebsocketProvider; + type Context = (); + type Wrapped<'a> = &'a Self; +} + +unsafe impl CoreArrayProviderInner for WebsocketProvider { + unsafe fn free(raw: *mut Self::Raw, _count: usize, _context: &Self::Context) { + BNFreeWebsocketProviderList(raw) + } + + unsafe fn wrap_raw<'a>(raw: &'a Self::Raw, _context: &'a Self::Context) -> Self::Wrapped<'a> { + Self::ref_from_raw(raw) + } +} + +/// Implements a websocket client. +#[repr(transparent)] +pub struct WebsocketClient<'a> { + handle: ptr::NonNull, + // lifetime of callbacks, AKA don't drop callbacks while the client still running + _callback: std::marker::PhantomData<&'a ()>, +} +unsafe impl Sync for WebsocketClient<'_> {} +unsafe impl Send for WebsocketClient<'_> {} + +impl ToOwned for WebsocketClient<'_> { + type Owned = Ref; + + fn to_owned(&self) -> Self::Owned { + unsafe { ::inc_ref(self) } + } +} + +unsafe impl RefCountable for WebsocketClient<'_> { + unsafe fn inc_ref(handle: &Self) -> Ref { + let result = BNNewWebsocketClientReference(handle.as_raw()); + unsafe { Self::ref_from_raw(ptr::NonNull::new(result).unwrap()) } + } + + unsafe fn dec_ref(handle: &Self) { + BNFreeWebsocketClient(handle.as_raw()) + } +} + +impl WebsocketClient<'_> { + pub(crate) unsafe fn ref_from_raw(handle: ptr::NonNull) -> Ref { + Ref::new(Self { + handle, + _callback: std::marker::PhantomData, + }) + } + + #[allow(clippy::mut_from_ref)] + pub(crate) unsafe fn as_raw(&self) -> &mut BNWebsocketClient { + &mut *self.handle.as_ptr() + } + + /// Write some data to the websocket + pub fn write(&self, data: &[u8]) -> bool { + let len = u64::try_from(data.len()).unwrap(); + unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) != 0 } + } + + /// Disconnect the websocket + pub fn disconnect(&self) -> bool { + unsafe { BNDisconnectWebsocketClient(self.as_raw()) } + } +} + +pub struct CoreWebSocketClient<'a>(WebsocketClient<'a>); + +impl CoreWebSocketClient<'_> { + /// Call the connect callback function, forward the callback returned value + pub fn notify_connect(&self) -> bool { + unsafe { BNNotifyWebsocketClientConnect(self.0.as_raw()) } + } + + /// Notify the callback function of a disconnect, but don't disconnect, + /// use the [Self::disconnect] function for that + pub fn notify_disconnect(&self) { + unsafe { BNNotifyWebsocketClientDisconnect(self.0.as_raw()) } + } + + /// Call the error callback function + pub fn notify_error(&self, error: S) { + let error = error.into_bytes_with_nul(); + unsafe { + BNNotifyWebsocketClientError( + self.0.as_raw(), + error.as_ref().as_ptr() as *const ffi::c_char, + ) + } + } + + /// Call the read callback function, forward the callback returned value + pub fn notify_read(&self, data: &[u8]) -> bool { + unsafe { + BNNotifyWebsocketClientReadData( + self.0.as_raw(), + data.as_ptr() as *mut _, + data.len().try_into().unwrap(), + ) + } + } +} + +pub trait WebsocketCustomProvider: Sync + Send { + type Client<'a>: WebsocketCustomClient; + + fn new(core: WebsocketProvider) -> Self; + fn get_core(&self) -> &WebsocketProvider; + fn init_client<'a>(&self, core: CoreWebSocketClient<'a>) -> Self::Client<'a>; + + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `read_handle` - Handles the received data from the Socket + fn connect<'a, U, I, K, V, F>( + &self, + url: U, + headers: I, + read_handle: &'a mut F, + ) -> Option>> + where + Self: Sized, + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + F: FnMut(&[u8]) -> bool, + { + let cb_callback = BNWebsocketClientOutputCallbacks { + context: read_handle as *mut _ as *mut ffi::c_void, + connectedCallback: Some(cb_connected_nop), + disconnectedCallback: Some(cb_disconnected_nop), + errorCallback: Some(cb_error_nop), + readCallback: Some(cb_read_closure::), + }; + let client_ptr = new_client(self); + connect_client(client_ptr, url, headers, cb_callback) + } + + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// Callbacks will be called **on the thread of the connection**, so be sure + /// to ExecuteOnMainThread any long-running or gui operations in the callbacks. + /// + /// If the connection succeeds, [WebsocketClientCallback::connected] will be called. On normal termination, [WebsocketClientCallback::disconnected] will be called. + /// + /// If the connection succeeds, but later fails, [WebsocketClientCallback::disconnected] will not be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If the connection fails, neither [WebsocketClientCallback::connected] nor [WebsocketClientCallback::disconnected] will be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If [WebsocketClientCallback::connected] or [WebsocketClientCallback::read] return false, the connection will be aborted. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `callback` - Callbacks for various websocket events + fn connect_with_callback<'a, U, I, K, V, W>( + &self, + url: U, + headers: I, + callback: &'a mut W, + ) -> Option>> + where + Self: Sized, + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + W: WebsocketClientCallback, + { + let cb_callback = BNWebsocketClientOutputCallbacks { + context: callback as *mut W as *mut _, + connectedCallback: Some(cb_connected::), + disconnectedCallback: Some(cb_disconnected::), + errorCallback: Some(cb_error::), + readCallback: Some(cb_read::), + }; + let client_ptr = new_client(self); + connect_client(client_ptr, url, headers, cb_callback) + } +} + +fn new_client(provider: &W) -> *mut BNWebsocketClient { + // SAFETY: Websocket client is freed by cb_destroy_client + let custom_uinit = Box::leak(Box::new(mem::MaybeUninit::zeroed())); + let mut callbacks = BNWebsocketClientCallbacks { + context: custom_uinit as *mut _ as *mut ffi::c_void, + connect: Some(cb_connect::>), + destroyClient: Some(cb_destroy_client::>), + disconnect: Some(cb_disconnect::>), + write: Some(cb_write::>), + }; + let handle = unsafe { BNInitWebsocketClient(provider.get_core().as_raw(), &mut callbacks) }; + custom_uinit.write(provider.init_client(CoreWebSocketClient(WebsocketClient { + handle: ptr::NonNull::new(handle).unwrap(), + _callback: std::marker::PhantomData, + }))); + handle +} + +fn connect_client<'a, U, I, K, V>( + client_ptr: *mut BNWebsocketClient, + url: U, + headers: I, + mut cb_callback: BNWebsocketClientOutputCallbacks, +) -> Option>> +where + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, +{ + let client = unsafe { WebsocketClient::ref_from_raw(ptr::NonNull::new(client_ptr).unwrap()) }; + // SAFETY: freed by WebsocketClientConnectedWithCallback::drop + let url = url.into_bytes_with_nul(); + let (header_keys, header_values): (Vec, Vec) = headers + .into_iter() + .map(|(k, v)| (k.into_bytes_with_nul(), v.into_bytes_with_nul())) + .unzip(); + let header_keys: Vec<*const ffi::c_char> = header_keys + .iter() + .map(|k| k.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let header_values: Vec<*const ffi::c_char> = header_values + .iter() + .map(|v| v.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let success = unsafe { + BNConnectWebsocketClient( + client.as_raw(), + url.as_ref().as_ptr() as *const ffi::c_char, + header_keys.len().try_into().unwrap(), + header_keys.as_ptr(), + header_values.as_ptr(), + &mut cb_callback, + ) + }; + success.then_some(client) +} + +pub trait WebsocketClientCallback: Sync + Send { + fn connected(&self) -> bool; + fn disconnected(&self); + fn error(&self, msg: &str); + fn read(&self, data: &mut [u8]) -> bool; +} + +pub trait WebsocketCustomClient: Sync + Send { + fn connect(&self, host: &str, header_keys: &[BnString], header_values: &[BnString]) -> bool; + fn write(&self, data: &[u8]) -> bool; + fn disconnect(&self) -> bool; +} + +pub fn register_websocket_provider(name: S) -> (&'static W, WebsocketProvider) +where + S: BnStrCompatible, + W: WebsocketCustomProvider + 'static, +{ + let name = name.into_bytes_with_nul(); + // SAFETY: Websocket provider is never freed + let provider_uinit = Box::leak(Box::new(mem::MaybeUninit::zeroed())); + let result = unsafe { + BNRegisterWebsocketProvider( + name.as_ref().as_ptr() as *const ffi::c_char, + &mut BNWebsocketProviderCallbacks { + context: provider_uinit as *mut _ as *mut ffi::c_void, + createClient: Some(cb_create_client::), + }, + ) + }; + let provider_core = unsafe { WebsocketProvider::from_raw(ptr::NonNull::new(result).unwrap()) }; + provider_uinit.write(W::new(provider_core)); + (unsafe { provider_uinit.assume_init_ref() }, provider_core) +} + +unsafe extern "C" fn cb_create_client( + ctxt: *mut ::std::os::raw::c_void, +) -> *mut BNWebsocketClient { + let ctxt: &mut W = &mut *(ctxt as *mut W); + new_client(ctxt) +} + +unsafe extern "C" fn cb_destroy_client(ctxt: *mut ffi::c_void) { + let ctxt: Box = Box::from_raw(&mut *(ctxt as *mut W)); + drop(ctxt) +} + +unsafe extern "C" fn cb_connect( + ctxt: *mut ffi::c_void, + host: *const ffi::c_char, + header_count: u64, + header_keys: *const *const ffi::c_char, + header_values: *const *const ffi::c_char, +) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let host = ffi::CStr::from_ptr(host); + // SAFETY BnString and *mut ffi::c_char are transparnet + let header_count = usize::try_from(header_count).unwrap(); + let header_keys = core::slice::from_raw_parts(header_keys as *const BnString, header_count); + let header_values = core::slice::from_raw_parts(header_values as *const BnString, header_count); + ctxt.connect(&host.to_string_lossy(), header_keys, header_values) +} + +unsafe extern "C" fn cb_write( + data: *const u8, + len: u64, + ctxt: *mut ffi::c_void, +) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let len = usize::try_from(len).unwrap(); + let data = core::slice::from_raw_parts(data, len); + ctxt.write(data) +} + +unsafe extern "C" fn cb_disconnect(ctxt: *mut ffi::c_void) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + ctxt.disconnect() +} + +unsafe extern "C" fn cb_connected(ctxt: *mut ffi::c_void) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + ctxt.connected() +} + +unsafe extern "C" fn cb_disconnected(ctxt: *mut ffi::c_void) { + let ctxt: &mut W = &mut *(ctxt as *mut W); + ctxt.disconnected() +} + +unsafe extern "C" fn cb_error( + msg: *const ffi::c_char, + ctxt: *mut ffi::c_void, +) { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let msg = ffi::CStr::from_ptr(msg); + ctxt.error(&msg.to_string_lossy()) +} + +unsafe extern "C" fn cb_read( + data: *mut u8, + len: u64, + ctxt: *mut ::std::os::raw::c_void, +) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let len = usize::try_from(len).unwrap(); + let data = core::slice::from_raw_parts_mut(data, len); + ctxt.read(data) +} + +unsafe extern "C" fn cb_connected_nop(_ctxt: *mut ffi::c_void) -> bool { + true +} + +unsafe extern "C" fn cb_disconnected_nop(_ctxt: *mut ffi::c_void) {} + +unsafe extern "C" fn cb_error_nop(_msg: *const ffi::c_char, _ctxt: *mut ffi::c_void) {} + +unsafe extern "C" fn cb_read_closure bool>( + data: *mut u8, + len: u64, + ctxt: *mut ::std::os::raw::c_void, +) -> bool { + let ctxt: &mut F = &mut *(ctxt as *mut F); + let len = usize::try_from(len).unwrap(); + let data = core::slice::from_raw_parts_mut(data, len); + let ctxt: &mut F = &mut *(ctxt as *mut F); + ctxt(data) +} diff --git a/rust/tests/websocketprovider.rs b/rust/tests/websocketprovider.rs new file mode 100644 index 0000000000..1b50518ef0 --- /dev/null +++ b/rust/tests/websocketprovider.rs @@ -0,0 +1,92 @@ +use binaryninja::headless::Session; +use binaryninja::string::BnString; +use binaryninja::websocketprovider::{ + register_websocket_provider, CoreWebSocketClient, WebsocketCustomClient, + WebsocketCustomProvider, WebsocketProvider, +}; +use rstest::*; + +#[fixture] +#[once] +fn session() -> Session { + Session::new().expect("Failed to initialize session") +} + +struct MyWebsocketProvider { + core: WebsocketProvider, +} + +impl WebsocketCustomProvider for MyWebsocketProvider { + type Client<'a> = MyWebsocketClient<'a>; + + fn new(core: WebsocketProvider) -> Self { + Self { core } + } + + fn get_core(&self) -> &WebsocketProvider { + &self.core + } + + fn init_client<'a>(&self, core: CoreWebSocketClient<'a>) -> Self::Client<'a> { + MyWebsocketClient { core } + } +} + +struct MyWebsocketClient<'a> { + core: CoreWebSocketClient<'a>, +} + +impl WebsocketCustomClient for MyWebsocketClient<'_> { + fn connect(&self, _host: &str, _header_keys: &[BnString], _header_values: &[BnString]) -> bool { + true + } + + fn write(&self, data: &[u8]) -> bool { + if !self.core.notify_read("sent: ".as_bytes()) { + return false; + } + if !self.core.notify_read(data) { + return false; + } + self.core.notify_read("\n".as_bytes()) + } + + fn disconnect(&self) -> bool { + true + } +} + +#[rstest] +fn reg_websocket_provider(_session: &Session) { + let (rust_provider, _core_provider) = + register_websocket_provider::<_, MyWebsocketProvider>("RustWebsocketProvider"); + let mut handle = |_: &[u8]| true; + let _client = rust_provider + .connect("url", [("header", "value")], &mut handle) + .unwrap(); +} + +#[rstest] +fn listen_websocket_provider(_session: &Session) { + let (rust_provider, _core_provider) = + register_websocket_provider::<_, MyWebsocketProvider>("RustWebsocketProvider2"); + + let mut data_read = vec![]; + let mut read_handle = |data: &[u8]| { + data_read.extend_from_slice(data); + true + }; + let client = rust_provider + .connect("url", [("header", "value")], &mut read_handle) + .unwrap(); + // NOTE important to enforce that this line will result compilation errors + //let _ = read_handle(&[]); + + assert!(client.write("test1".as_bytes())); + assert!(client.write("test2".as_bytes())); + + client.disconnect(); + drop(client); + + assert_eq!(&data_read[..], "sent: test1\nsent: test2\n".as_bytes()); +}