From 30caee12a9fbcd181509430a8bcb316e646e9185 Mon Sep 17 00:00:00 2001 From: Rubens Brandao Date: Wed, 26 Jun 2024 17:24:33 -0300 Subject: [PATCH 1/2] implement rust WebsocketProvider --- rust/src/lib.rs | 1 + rust/src/websocketprovider.rs | 434 ++++++++++++++++++++++++++++++++++ 2 files changed, 435 insertions(+) create mode 100644 rust/src/websocketprovider.rs diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 0cb0484e65..a72d16d1d0 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -84,6 +84,7 @@ pub mod type_printer; pub mod types; pub mod update; pub mod variable; +pub mod websocketprovider; pub mod worker_thread; pub mod workflow; diff --git a/rust/src/websocketprovider.rs b/rust/src/websocketprovider.rs new file mode 100644 index 0000000000..7d6d70442d --- /dev/null +++ b/rust/src/websocketprovider.rs @@ -0,0 +1,434 @@ +use core::{ffi, mem, ptr}; + +use binaryninjacore_sys::*; + +use crate::rc::{Array, CoreArrayProvider, CoreArrayProviderInner}; +use crate::string::{BnStrCompatible, BnString}; + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct WebsocketProvider { + handle: ptr::NonNull, +} + +impl WebsocketProvider { + pub(crate) unsafe fn from_raw(handle: ptr::NonNull) -> Self { + Self { handle } + } + + pub(crate) unsafe fn ref_from_raw(handle: &*mut BNWebsocketProvider) -> &Self { + assert!(!handle.is_null()); + mem::transmute(handle) + } + + #[allow(clippy::mut_from_ref)] + pub(crate) unsafe fn as_raw(&self) -> &mut BNWebsocketProvider { + &mut *self.handle.as_ptr() + } + + pub fn all() -> Array { + let mut count = 0; + let result = unsafe { BNGetWebsocketProviderList(&mut count) }; + assert!(!result.is_null()); + unsafe { Array::new(result, count, ()) } + } + + pub fn by_name(name: S) -> Option { + let name = name.into_bytes_with_nul(); + let result = + unsafe { BNGetWebsocketProviderByName(name.as_ref().as_ptr() as *const ffi::c_char) }; + ptr::NonNull::new(result).map(|h| unsafe { Self::from_raw(h) }) + } + + pub fn name(&self) -> BnString { + let result = unsafe { BNGetWebsocketProviderName(self.as_raw()) }; + assert!(!result.is_null()); + unsafe { BnString::from_raw(result) } + } + + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + pub fn connect(self, url: U, headers: I) -> Option + where + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + { + let result = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; + let client = unsafe { WebsocketClient::from_raw(ptr::NonNull::new(result).unwrap()) }; + let url = url.into_bytes_with_nul(); + let (header_keys, header_values): (Vec, Vec) = headers + .into_iter() + .map(|(k, v)| (k.into_bytes_with_nul(), v.into_bytes_with_nul())) + .unzip(); + let header_keys: Vec<*const ffi::c_char> = header_keys + .iter() + .map(|k| k.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let header_values: Vec<*const ffi::c_char> = header_values + .iter() + .map(|v| v.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let mut cb_callback = BNWebsocketClientOutputCallbacks { + context: ptr::null_mut(), + connectedCallback: Some(cb_connected_nop), + disconnectedCallback: Some(cb_disconnected_nop), + errorCallback: Some(cb_error_nop), + readCallback: Some(cb_read_nop), + }; + let success = unsafe { + BNConnectWebsocketClient( + client.as_raw(), + url.as_ref().as_ptr() as *const ffi::c_char, + header_keys.len().try_into().unwrap(), + header_keys.as_ptr(), + header_values.as_ptr(), + &mut cb_callback, + ) + }; + success.then_some(client) + } + + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// Callbacks will be called **on the thread of the connection**, so be sure + /// to ExecuteOnMainThread any long-running or gui operations in the callbacks. + /// + /// If the connection succeeds, [WebsocketClientCallback::connected] will be called. On normal termination, [WebsocketClientCallback::disconnected] will be called. + /// + /// If the connection succeeds, but later fails, [WebsocketClientCallback::disconnected] will not be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If the connection fails, neither [WebsocketClientCallback::connected] nor [WebsocketClientCallback::disconnected] will be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If [WebsocketClientCallback::connected] or [WebsocketClientCallback::read] return false, the connection will be aborted. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `callback` - Callbacks for various websocket events + pub fn connect_with_callback( + self, + url: U, + headers: I, + callback: W, + ) -> Option> + where + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + W: WebsocketClientCallback, + { + let result = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; + let client = unsafe { WebsocketClient::from_raw(ptr::NonNull::new(result).unwrap()) }; + // SAFETY: freed by WebsocketClientConnectedWithCallback::drop + let callback = Box::leak(Box::new(callback)); + let url = url.into_bytes_with_nul(); + let (header_keys, header_values): (Vec, Vec) = headers + .into_iter() + .map(|(k, v)| (k.into_bytes_with_nul(), v.into_bytes_with_nul())) + .unzip(); + let header_keys: Vec<*const ffi::c_char> = header_keys + .iter() + .map(|k| k.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let header_values: Vec<*const ffi::c_char> = header_values + .iter() + .map(|v| v.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let mut cb_callback = BNWebsocketClientOutputCallbacks { + context: callback as *mut W as *mut _, + connectedCallback: Some(cb_connected::), + disconnectedCallback: Some(cb_disconnected::), + errorCallback: Some(cb_error::), + readCallback: Some(cb_read::), + }; + let success = unsafe { + BNConnectWebsocketClient( + client.as_raw(), + url.as_ref().as_ptr() as *const ffi::c_char, + header_keys.len().try_into().unwrap(), + header_keys.as_ptr(), + header_values.as_ptr(), + &mut cb_callback, + ) + }; + success.then(|| WebsocketClientHandleWithCallback { client, callback }) + } +} + +impl CoreArrayProvider for WebsocketProvider { + type Raw = *mut BNWebsocketProvider; + type Context = (); + type Wrapped<'a> = &'a Self; +} + +unsafe impl CoreArrayProviderInner for WebsocketProvider { + unsafe fn free(raw: *mut Self::Raw, _count: usize, _context: &Self::Context) { + BNFreeWebsocketProviderList(raw) + } + + unsafe fn wrap_raw<'a>(raw: &'a Self::Raw, _context: &'a Self::Context) -> Self::Wrapped<'a> { + Self::ref_from_raw(raw) + } +} + +/// Implements a websocket client. See [WebsocketProvider::connect] and [WebsocketProvider::connect_with_callback] for more details. +#[repr(transparent)] +pub struct WebsocketClient { + handle: ptr::NonNull, +} + +impl Clone for WebsocketClient { + fn clone(&self) -> Self { + let result = unsafe { BNNewWebsocketClientReference(self.as_raw()) }; + unsafe { Self::from_raw(ptr::NonNull::new(result).unwrap()) } + } +} + +impl Drop for WebsocketClient { + fn drop(&mut self) { + unsafe { BNFreeWebsocketClient(self.as_raw()) } + } +} + +impl WebsocketClient { + pub(crate) unsafe fn from_raw(handle: ptr::NonNull) -> Self { + Self { handle } + } + + pub(crate) unsafe fn into_raw(self) -> *mut BNWebsocketClient { + mem::ManuallyDrop::new(self).handle.as_ptr() + } + + #[allow(clippy::mut_from_ref)] + pub(crate) unsafe fn as_raw(&self) -> &mut BNWebsocketClient { + &mut *self.handle.as_ptr() + } + + pub fn new_custom(provider: WebsocketProvider) -> WebsocketClient + where + W: WebsocketCustomClient, + { + // SAFETY: Websocket client is freed by cb_destroy_client + let custom_uinit = Box::leak(Box::new(mem::MaybeUninit::zeroed())); + let mut callbacks = BNWebsocketClientCallbacks { + context: custom_uinit as *mut _ as *mut ffi::c_void, + connect: Some(cb_connect::), + destroyClient: Some(cb_destroy_client::), + disconnect: Some(cb_disconnect::), + write: Some(cb_write::), + }; + let result = unsafe { BNInitWebsocketClient(provider.as_raw(), &mut callbacks) }; + let client = unsafe { WebsocketClient::from_raw(ptr::NonNull::new(result).unwrap()) }; + custom_uinit.write(W::new(provider, &client)); + client + } + + /// Call the connect callback function, forward the callback returned value + pub fn notify_connect(&self) -> bool { + unsafe { BNNotifyWebsocketClientConnect(self.as_raw()) } + } + + /// Notify the callback function of a disconnect, but don't disconnect, + /// use the [Self::disconnect] function for that + pub fn notify_disconnect(&self) { + unsafe { BNNotifyWebsocketClientDisconnect(self.as_raw()) } + } + + /// Call the error callback function + pub fn notify_error(&self, error: S) { + let error = error.into_bytes_with_nul(); + unsafe { + BNNotifyWebsocketClientError( + self.as_raw(), + error.as_ref().as_ptr() as *const ffi::c_char, + ) + } + } + + /// Call the read callback function, forward the callback returned value + pub fn notify_read(&self, data: &mut [u8]) -> bool { + unsafe { + BNNotifyWebsocketClientReadData( + self.as_raw(), + data.as_mut_ptr(), + data.len().try_into().unwrap(), + ) + } + } + + /// Write some data to the websocket + pub fn write(&self, data: &[u8]) -> usize { + let len = u64::try_from(data.len()).unwrap(); + let result = unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) }; + usize::try_from(result).unwrap() + } + + /// Disconnect the websocket + pub fn disconnect(&self) -> bool { + unsafe { BNDisconnectWebsocketClient(self.as_raw()) } + } +} + +pub struct WebsocketClientHandleWithCallback { + client: WebsocketClient, + callback: *mut W, +} + +impl Drop for WebsocketClientHandleWithCallback { + fn drop(&mut self) { + let callback: Box = unsafe { Box::from_raw(self.callback) }; + drop(callback); + } +} + +impl AsRef for WebsocketClientHandleWithCallback { + fn as_ref(&self) -> &WebsocketClient { + &self.client + } +} + +impl core::ops::Deref for WebsocketClientHandleWithCallback { + type Target = WebsocketClient; + + fn deref(&self) -> &Self::Target { + &self.client + } +} + +pub trait WebsocketCustomProvider: Sync + Send { + fn new(core: WebsocketProvider) -> Self; + fn create_client(&self) -> WebsocketClient; +} + +pub trait WebsocketClientCallback: Sync + Send { + fn connected(&self) -> bool; + fn disconnected(&self); + fn error(&self, msg: &str); + fn read(&self, data: &mut [u8]) -> bool; +} + +pub trait WebsocketCustomClient: Sync + Send { + fn new(provider: WebsocketProvider, client: &WebsocketClient) -> Self; + fn connect(&self, host: &str, header_keys: &[BnString], header_values: &[BnString]) -> bool; + fn write(&self, data: &[u8]) -> bool; + fn disconnect(&self) -> bool; +} + +pub fn register_websocket_provider(name: S) -> (&'static W, WebsocketProvider) +where + S: BnStrCompatible, + W: WebsocketCustomProvider + 'static, +{ + let name = name.into_bytes_with_nul(); + // SAFETY: Websocket provider is never freed + let provider_uinit = Box::leak(Box::new(mem::MaybeUninit::zeroed())); + let result = unsafe { + BNRegisterWebsocketProvider( + name.as_ref().as_ptr() as *const ffi::c_char, + &mut BNWebsocketProviderCallbacks { + context: provider_uinit as *mut _ as *mut ffi::c_void, + createClient: Some(cb_create_client::), + }, + ) + }; + let provider_core = unsafe { WebsocketProvider::from_raw(ptr::NonNull::new(result).unwrap()) }; + provider_uinit.write(W::new(provider_core)); + (unsafe { provider_uinit.assume_init_ref() }, provider_core) +} + +unsafe extern "C" fn cb_create_client( + ctxt: *mut ::std::os::raw::c_void, +) -> *mut BNWebsocketClient { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let result = ctxt.create_client(); + result.into_raw() +} + +unsafe extern "C" fn cb_destroy_client(ctxt: *mut ffi::c_void) { + let ctxt: Box = Box::from_raw(&mut *(ctxt as *mut W)); + drop(ctxt) +} + +unsafe extern "C" fn cb_connect( + ctxt: *mut ffi::c_void, + host: *const ffi::c_char, + header_count: u64, + header_keys: *const *const ffi::c_char, + header_values: *const *const ffi::c_char, +) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let host = ffi::CStr::from_ptr(host); + // SAFETY BnString and *mut ffi::c_char are transparnet + let header_count = usize::try_from(header_count).unwrap(); + let header_keys = core::slice::from_raw_parts(header_keys as *const BnString, header_count); + let header_values = core::slice::from_raw_parts(header_values as *const BnString, header_count); + ctxt.connect(&host.to_string_lossy(), header_keys, header_values) +} + +unsafe extern "C" fn cb_write( + data: *const u8, + len: u64, + ctxt: *mut ffi::c_void, +) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let len = usize::try_from(len).unwrap(); + let data = core::slice::from_raw_parts(data, len); + ctxt.write(data) +} + +unsafe extern "C" fn cb_disconnect(ctxt: *mut ffi::c_void) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + ctxt.disconnect() +} + +unsafe extern "C" fn cb_connected(ctxt: *mut ffi::c_void) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + ctxt.connected() +} + +unsafe extern "C" fn cb_disconnected(ctxt: *mut ffi::c_void) { + let ctxt: &mut W = &mut *(ctxt as *mut W); + ctxt.disconnected() +} + +unsafe extern "C" fn cb_error( + msg: *const ffi::c_char, + ctxt: *mut ffi::c_void, +) { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let msg = ffi::CStr::from_ptr(msg); + ctxt.error(&msg.to_string_lossy()) +} + +unsafe extern "C" fn cb_read( + data: *mut u8, + len: u64, + ctxt: *mut ::std::os::raw::c_void, +) -> bool { + let ctxt: &mut W = &mut *(ctxt as *mut W); + let len = usize::try_from(len).unwrap(); + let data = core::slice::from_raw_parts_mut(data, len); + ctxt.read(data) +} + +unsafe extern "C" fn cb_connected_nop(_ctxt: *mut ffi::c_void) -> bool { + true +} + +unsafe extern "C" fn cb_disconnected_nop(_ctxt: *mut ffi::c_void) {} + +unsafe extern "C" fn cb_error_nop(_msg: *const ffi::c_char, _ctxt: *mut ffi::c_void) {} + +unsafe extern "C" fn cb_read_nop( + _data: *mut u8, + _len: u64, + _ctxt: *mut ::std::os::raw::c_void, +) -> bool { + true +} From 22d45c893bc31d5c5a6d08b1ec95d83fe9365d4e Mon Sep 17 00:00:00 2001 From: rbran Date: Wed, 5 Feb 2025 13:06:34 +0000 Subject: [PATCH 2/2] fix websocketprovider inconsistencies --- rust/src/websocketprovider.rs | 350 +++++++++++++++++++------------- rust/tests/websocketprovider.rs | 92 +++++++++ 2 files changed, 298 insertions(+), 144 deletions(-) create mode 100644 rust/tests/websocketprovider.rs diff --git a/rust/src/websocketprovider.rs b/rust/src/websocketprovider.rs index 7d6d70442d..6619b79b69 100644 --- a/rust/src/websocketprovider.rs +++ b/rust/src/websocketprovider.rs @@ -2,14 +2,16 @@ use core::{ffi, mem, ptr}; use binaryninjacore_sys::*; -use crate::rc::{Array, CoreArrayProvider, CoreArrayProviderInner}; +use crate::rc::{Array, CoreArrayProvider, CoreArrayProviderInner, Ref, RefCountable}; use crate::string::{BnStrCompatible, BnString}; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Hash, PartialEq, Eq)] #[repr(transparent)] pub struct WebsocketProvider { handle: ptr::NonNull, } +unsafe impl Sync for WebsocketProvider {} +unsafe impl Send for WebsocketProvider {} impl WebsocketProvider { pub(crate) unsafe fn from_raw(handle: ptr::NonNull) -> Self { @@ -51,46 +53,29 @@ impl WebsocketProvider { /// /// * `host` - Full url with scheme, domain, optionally port, and path /// * `headers` - HTTP header keys and values - pub fn connect(self, url: U, headers: I) -> Option + /// * `read_handle` - Handles the received data from the Socket + pub fn connect<'a, U, I, K, V, F>( + &'a self, + url: U, + headers: I, + read_handle: &'a mut F, + ) -> Option>> where U: BnStrCompatible, I: IntoIterator, K: BnStrCompatible, V: BnStrCompatible, + F: FnMut(&[u8]) -> bool, { - let result = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; - let client = unsafe { WebsocketClient::from_raw(ptr::NonNull::new(result).unwrap()) }; - let url = url.into_bytes_with_nul(); - let (header_keys, header_values): (Vec, Vec) = headers - .into_iter() - .map(|(k, v)| (k.into_bytes_with_nul(), v.into_bytes_with_nul())) - .unzip(); - let header_keys: Vec<*const ffi::c_char> = header_keys - .iter() - .map(|k| k.as_ref().as_ptr() as *const ffi::c_char) - .collect(); - let header_values: Vec<*const ffi::c_char> = header_values - .iter() - .map(|v| v.as_ref().as_ptr() as *const ffi::c_char) - .collect(); - let mut cb_callback = BNWebsocketClientOutputCallbacks { - context: ptr::null_mut(), + let cb_callback = BNWebsocketClientOutputCallbacks { + context: read_handle as *mut _ as *mut ffi::c_void, connectedCallback: Some(cb_connected_nop), disconnectedCallback: Some(cb_disconnected_nop), errorCallback: Some(cb_error_nop), - readCallback: Some(cb_read_nop), - }; - let success = unsafe { - BNConnectWebsocketClient( - client.as_raw(), - url.as_ref().as_ptr() as *const ffi::c_char, - header_keys.len().try_into().unwrap(), - header_keys.as_ptr(), - header_values.as_ptr(), - &mut cb_callback, - ) + readCallback: Some(cb_read_closure::), }; - success.then_some(client) + let client_ptr = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; + connect_client(client_ptr, url, headers, cb_callback) } /// Connect to a given url, asynchronously. The connection will be run in a @@ -110,12 +95,12 @@ impl WebsocketProvider { /// * `host` - Full url with scheme, domain, optionally port, and path /// * `headers` - HTTP header keys and values /// * `callback` - Callbacks for various websocket events - pub fn connect_with_callback( - self, + pub fn connect_with_callback<'a, U, I, K, V, W>( + &self, url: U, headers: I, - callback: W, - ) -> Option> + callback: &'a mut W, + ) -> Option>> where U: BnStrCompatible, I: IntoIterator, @@ -123,41 +108,15 @@ impl WebsocketProvider { V: BnStrCompatible, W: WebsocketClientCallback, { - let result = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; - let client = unsafe { WebsocketClient::from_raw(ptr::NonNull::new(result).unwrap()) }; - // SAFETY: freed by WebsocketClientConnectedWithCallback::drop - let callback = Box::leak(Box::new(callback)); - let url = url.into_bytes_with_nul(); - let (header_keys, header_values): (Vec, Vec) = headers - .into_iter() - .map(|(k, v)| (k.into_bytes_with_nul(), v.into_bytes_with_nul())) - .unzip(); - let header_keys: Vec<*const ffi::c_char> = header_keys - .iter() - .map(|k| k.as_ref().as_ptr() as *const ffi::c_char) - .collect(); - let header_values: Vec<*const ffi::c_char> = header_values - .iter() - .map(|v| v.as_ref().as_ptr() as *const ffi::c_char) - .collect(); - let mut cb_callback = BNWebsocketClientOutputCallbacks { + let cb_callback = BNWebsocketClientOutputCallbacks { context: callback as *mut W as *mut _, connectedCallback: Some(cb_connected::), disconnectedCallback: Some(cb_disconnected::), errorCallback: Some(cb_error::), readCallback: Some(cb_read::), }; - let success = unsafe { - BNConnectWebsocketClient( - client.as_raw(), - url.as_ref().as_ptr() as *const ffi::c_char, - header_keys.len().try_into().unwrap(), - header_keys.as_ptr(), - header_values.as_ptr(), - &mut cb_callback, - ) - }; - success.then(|| WebsocketClientHandleWithCallback { client, callback }) + let client_ptr = unsafe { BNCreateWebsocketProviderClient(self.as_raw()) }; + connect_client(client_ptr, url, headers, cb_callback) } } @@ -177,32 +136,41 @@ unsafe impl CoreArrayProviderInner for WebsocketProvider { } } -/// Implements a websocket client. See [WebsocketProvider::connect] and [WebsocketProvider::connect_with_callback] for more details. +/// Implements a websocket client. #[repr(transparent)] -pub struct WebsocketClient { +pub struct WebsocketClient<'a> { handle: ptr::NonNull, + // lifetime of callbacks, AKA don't drop callbacks while the client still running + _callback: std::marker::PhantomData<&'a ()>, } +unsafe impl Sync for WebsocketClient<'_> {} +unsafe impl Send for WebsocketClient<'_> {} + +impl ToOwned for WebsocketClient<'_> { + type Owned = Ref; -impl Clone for WebsocketClient { - fn clone(&self) -> Self { - let result = unsafe { BNNewWebsocketClientReference(self.as_raw()) }; - unsafe { Self::from_raw(ptr::NonNull::new(result).unwrap()) } + fn to_owned(&self) -> Self::Owned { + unsafe { ::inc_ref(self) } } } -impl Drop for WebsocketClient { - fn drop(&mut self) { - unsafe { BNFreeWebsocketClient(self.as_raw()) } +unsafe impl RefCountable for WebsocketClient<'_> { + unsafe fn inc_ref(handle: &Self) -> Ref { + let result = BNNewWebsocketClientReference(handle.as_raw()); + unsafe { Self::ref_from_raw(ptr::NonNull::new(result).unwrap()) } } -} -impl WebsocketClient { - pub(crate) unsafe fn from_raw(handle: ptr::NonNull) -> Self { - Self { handle } + unsafe fn dec_ref(handle: &Self) { + BNFreeWebsocketClient(handle.as_raw()) } +} - pub(crate) unsafe fn into_raw(self) -> *mut BNWebsocketClient { - mem::ManuallyDrop::new(self).handle.as_ptr() +impl WebsocketClient<'_> { + pub(crate) unsafe fn ref_from_raw(handle: ptr::NonNull) -> Ref { + Ref::new(Self { + handle, + _callback: std::marker::PhantomData, + }) } #[allow(clippy::mut_from_ref)] @@ -210,34 +178,30 @@ impl WebsocketClient { &mut *self.handle.as_ptr() } - pub fn new_custom(provider: WebsocketProvider) -> WebsocketClient - where - W: WebsocketCustomClient, - { - // SAFETY: Websocket client is freed by cb_destroy_client - let custom_uinit = Box::leak(Box::new(mem::MaybeUninit::zeroed())); - let mut callbacks = BNWebsocketClientCallbacks { - context: custom_uinit as *mut _ as *mut ffi::c_void, - connect: Some(cb_connect::), - destroyClient: Some(cb_destroy_client::), - disconnect: Some(cb_disconnect::), - write: Some(cb_write::), - }; - let result = unsafe { BNInitWebsocketClient(provider.as_raw(), &mut callbacks) }; - let client = unsafe { WebsocketClient::from_raw(ptr::NonNull::new(result).unwrap()) }; - custom_uinit.write(W::new(provider, &client)); - client + /// Write some data to the websocket + pub fn write(&self, data: &[u8]) -> bool { + let len = u64::try_from(data.len()).unwrap(); + unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) != 0 } + } + + /// Disconnect the websocket + pub fn disconnect(&self) -> bool { + unsafe { BNDisconnectWebsocketClient(self.as_raw()) } } +} + +pub struct CoreWebSocketClient<'a>(WebsocketClient<'a>); +impl CoreWebSocketClient<'_> { /// Call the connect callback function, forward the callback returned value pub fn notify_connect(&self) -> bool { - unsafe { BNNotifyWebsocketClientConnect(self.as_raw()) } + unsafe { BNNotifyWebsocketClientConnect(self.0.as_raw()) } } /// Notify the callback function of a disconnect, but don't disconnect, /// use the [Self::disconnect] function for that pub fn notify_disconnect(&self) { - unsafe { BNNotifyWebsocketClientDisconnect(self.as_raw()) } + unsafe { BNNotifyWebsocketClientDisconnect(self.0.as_raw()) } } /// Call the error callback function @@ -245,65 +209,161 @@ impl WebsocketClient { let error = error.into_bytes_with_nul(); unsafe { BNNotifyWebsocketClientError( - self.as_raw(), + self.0.as_raw(), error.as_ref().as_ptr() as *const ffi::c_char, ) } } /// Call the read callback function, forward the callback returned value - pub fn notify_read(&self, data: &mut [u8]) -> bool { + pub fn notify_read(&self, data: &[u8]) -> bool { unsafe { BNNotifyWebsocketClientReadData( - self.as_raw(), - data.as_mut_ptr(), + self.0.as_raw(), + data.as_ptr() as *mut _, data.len().try_into().unwrap(), ) } } - - /// Write some data to the websocket - pub fn write(&self, data: &[u8]) -> usize { - let len = u64::try_from(data.len()).unwrap(); - let result = unsafe { BNWriteWebsocketClientData(self.as_raw(), data.as_ptr(), len) }; - usize::try_from(result).unwrap() - } - - /// Disconnect the websocket - pub fn disconnect(&self) -> bool { - unsafe { BNDisconnectWebsocketClient(self.as_raw()) } - } } -pub struct WebsocketClientHandleWithCallback { - client: WebsocketClient, - callback: *mut W, -} +pub trait WebsocketCustomProvider: Sync + Send { + type Client<'a>: WebsocketCustomClient; + + fn new(core: WebsocketProvider) -> Self; + fn get_core(&self) -> &WebsocketProvider; + fn init_client<'a>(&self, core: CoreWebSocketClient<'a>) -> Self::Client<'a>; -impl Drop for WebsocketClientHandleWithCallback { - fn drop(&mut self) { - let callback: Box = unsafe { Box::from_raw(self.callback) }; - drop(callback); + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `read_handle` - Handles the received data from the Socket + fn connect<'a, U, I, K, V, F>( + &self, + url: U, + headers: I, + read_handle: &'a mut F, + ) -> Option>> + where + Self: Sized, + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + F: FnMut(&[u8]) -> bool, + { + let cb_callback = BNWebsocketClientOutputCallbacks { + context: read_handle as *mut _ as *mut ffi::c_void, + connectedCallback: Some(cb_connected_nop), + disconnectedCallback: Some(cb_disconnected_nop), + errorCallback: Some(cb_error_nop), + readCallback: Some(cb_read_closure::), + }; + let client_ptr = new_client(self); + connect_client(client_ptr, url, headers, cb_callback) } -} -impl AsRef for WebsocketClientHandleWithCallback { - fn as_ref(&self) -> &WebsocketClient { - &self.client + /// Connect to a given url, asynchronously. The connection will be run in a + /// separate thread managed by the websocket provider. + /// + /// Callbacks will be called **on the thread of the connection**, so be sure + /// to ExecuteOnMainThread any long-running or gui operations in the callbacks. + /// + /// If the connection succeeds, [WebsocketClientCallback::connected] will be called. On normal termination, [WebsocketClientCallback::disconnected] will be called. + /// + /// If the connection succeeds, but later fails, [WebsocketClientCallback::disconnected] will not be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If the connection fails, neither [WebsocketClientCallback::connected] nor [WebsocketClientCallback::disconnected] will be called, and [WebsocketClientCallback::error] will be called instead. + /// + /// If [WebsocketClientCallback::connected] or [WebsocketClientCallback::read] return false, the connection will be aborted. + /// + /// * `host` - Full url with scheme, domain, optionally port, and path + /// * `headers` - HTTP header keys and values + /// * `callback` - Callbacks for various websocket events + fn connect_with_callback<'a, U, I, K, V, W>( + &self, + url: U, + headers: I, + callback: &'a mut W, + ) -> Option>> + where + Self: Sized, + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, + W: WebsocketClientCallback, + { + let cb_callback = BNWebsocketClientOutputCallbacks { + context: callback as *mut W as *mut _, + connectedCallback: Some(cb_connected::), + disconnectedCallback: Some(cb_disconnected::), + errorCallback: Some(cb_error::), + readCallback: Some(cb_read::), + }; + let client_ptr = new_client(self); + connect_client(client_ptr, url, headers, cb_callback) } } -impl core::ops::Deref for WebsocketClientHandleWithCallback { - type Target = WebsocketClient; - - fn deref(&self) -> &Self::Target { - &self.client - } +fn new_client(provider: &W) -> *mut BNWebsocketClient { + // SAFETY: Websocket client is freed by cb_destroy_client + let custom_uinit = Box::leak(Box::new(mem::MaybeUninit::zeroed())); + let mut callbacks = BNWebsocketClientCallbacks { + context: custom_uinit as *mut _ as *mut ffi::c_void, + connect: Some(cb_connect::>), + destroyClient: Some(cb_destroy_client::>), + disconnect: Some(cb_disconnect::>), + write: Some(cb_write::>), + }; + let handle = unsafe { BNInitWebsocketClient(provider.get_core().as_raw(), &mut callbacks) }; + custom_uinit.write(provider.init_client(CoreWebSocketClient(WebsocketClient { + handle: ptr::NonNull::new(handle).unwrap(), + _callback: std::marker::PhantomData, + }))); + handle } -pub trait WebsocketCustomProvider: Sync + Send { - fn new(core: WebsocketProvider) -> Self; - fn create_client(&self) -> WebsocketClient; +fn connect_client<'a, U, I, K, V>( + client_ptr: *mut BNWebsocketClient, + url: U, + headers: I, + mut cb_callback: BNWebsocketClientOutputCallbacks, +) -> Option>> +where + U: BnStrCompatible, + I: IntoIterator, + K: BnStrCompatible, + V: BnStrCompatible, +{ + let client = unsafe { WebsocketClient::ref_from_raw(ptr::NonNull::new(client_ptr).unwrap()) }; + // SAFETY: freed by WebsocketClientConnectedWithCallback::drop + let url = url.into_bytes_with_nul(); + let (header_keys, header_values): (Vec, Vec) = headers + .into_iter() + .map(|(k, v)| (k.into_bytes_with_nul(), v.into_bytes_with_nul())) + .unzip(); + let header_keys: Vec<*const ffi::c_char> = header_keys + .iter() + .map(|k| k.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let header_values: Vec<*const ffi::c_char> = header_values + .iter() + .map(|v| v.as_ref().as_ptr() as *const ffi::c_char) + .collect(); + let success = unsafe { + BNConnectWebsocketClient( + client.as_raw(), + url.as_ref().as_ptr() as *const ffi::c_char, + header_keys.len().try_into().unwrap(), + header_keys.as_ptr(), + header_values.as_ptr(), + &mut cb_callback, + ) + }; + success.then_some(client) } pub trait WebsocketClientCallback: Sync + Send { @@ -314,7 +374,6 @@ pub trait WebsocketClientCallback: Sync + Send { } pub trait WebsocketCustomClient: Sync + Send { - fn new(provider: WebsocketProvider, client: &WebsocketClient) -> Self; fn connect(&self, host: &str, header_keys: &[BnString], header_values: &[BnString]) -> bool; fn write(&self, data: &[u8]) -> bool; fn disconnect(&self) -> bool; @@ -346,8 +405,7 @@ unsafe extern "C" fn cb_create_client( ctxt: *mut ::std::os::raw::c_void, ) -> *mut BNWebsocketClient { let ctxt: &mut W = &mut *(ctxt as *mut W); - let result = ctxt.create_client(); - result.into_raw() + new_client(ctxt) } unsafe extern "C" fn cb_destroy_client(ctxt: *mut ffi::c_void) { @@ -425,10 +483,14 @@ unsafe extern "C" fn cb_disconnected_nop(_ctxt: *mut ffi::c_void) {} unsafe extern "C" fn cb_error_nop(_msg: *const ffi::c_char, _ctxt: *mut ffi::c_void) {} -unsafe extern "C" fn cb_read_nop( - _data: *mut u8, - _len: u64, - _ctxt: *mut ::std::os::raw::c_void, +unsafe extern "C" fn cb_read_closure bool>( + data: *mut u8, + len: u64, + ctxt: *mut ::std::os::raw::c_void, ) -> bool { - true + let ctxt: &mut F = &mut *(ctxt as *mut F); + let len = usize::try_from(len).unwrap(); + let data = core::slice::from_raw_parts_mut(data, len); + let ctxt: &mut F = &mut *(ctxt as *mut F); + ctxt(data) } diff --git a/rust/tests/websocketprovider.rs b/rust/tests/websocketprovider.rs new file mode 100644 index 0000000000..1b50518ef0 --- /dev/null +++ b/rust/tests/websocketprovider.rs @@ -0,0 +1,92 @@ +use binaryninja::headless::Session; +use binaryninja::string::BnString; +use binaryninja::websocketprovider::{ + register_websocket_provider, CoreWebSocketClient, WebsocketCustomClient, + WebsocketCustomProvider, WebsocketProvider, +}; +use rstest::*; + +#[fixture] +#[once] +fn session() -> Session { + Session::new().expect("Failed to initialize session") +} + +struct MyWebsocketProvider { + core: WebsocketProvider, +} + +impl WebsocketCustomProvider for MyWebsocketProvider { + type Client<'a> = MyWebsocketClient<'a>; + + fn new(core: WebsocketProvider) -> Self { + Self { core } + } + + fn get_core(&self) -> &WebsocketProvider { + &self.core + } + + fn init_client<'a>(&self, core: CoreWebSocketClient<'a>) -> Self::Client<'a> { + MyWebsocketClient { core } + } +} + +struct MyWebsocketClient<'a> { + core: CoreWebSocketClient<'a>, +} + +impl WebsocketCustomClient for MyWebsocketClient<'_> { + fn connect(&self, _host: &str, _header_keys: &[BnString], _header_values: &[BnString]) -> bool { + true + } + + fn write(&self, data: &[u8]) -> bool { + if !self.core.notify_read("sent: ".as_bytes()) { + return false; + } + if !self.core.notify_read(data) { + return false; + } + self.core.notify_read("\n".as_bytes()) + } + + fn disconnect(&self) -> bool { + true + } +} + +#[rstest] +fn reg_websocket_provider(_session: &Session) { + let (rust_provider, _core_provider) = + register_websocket_provider::<_, MyWebsocketProvider>("RustWebsocketProvider"); + let mut handle = |_: &[u8]| true; + let _client = rust_provider + .connect("url", [("header", "value")], &mut handle) + .unwrap(); +} + +#[rstest] +fn listen_websocket_provider(_session: &Session) { + let (rust_provider, _core_provider) = + register_websocket_provider::<_, MyWebsocketProvider>("RustWebsocketProvider2"); + + let mut data_read = vec![]; + let mut read_handle = |data: &[u8]| { + data_read.extend_from_slice(data); + true + }; + let client = rust_provider + .connect("url", [("header", "value")], &mut read_handle) + .unwrap(); + // NOTE important to enforce that this line will result compilation errors + //let _ = read_handle(&[]); + + assert!(client.write("test1".as_bytes())); + assert!(client.write("test2".as_bytes())); + + client.disconnect(); + drop(client); + + assert_eq!(&data_read[..], "sent: test1\nsent: test2\n".as_bytes()); +}