diff --git a/packages/fortifier-macros-tests/tests/validations.rs b/packages/fortifier-macros-tests/tests/validations.rs index cd55e47..935a952 100644 --- a/packages/fortifier-macros-tests/tests/validations.rs +++ b/packages/fortifier-macros-tests/tests/validations.rs @@ -3,7 +3,6 @@ use trybuild::TestCases; #[test] fn validations() { let t = TestCases::new(); - t.pass("tests/validations/*/root_generics_pass.rs"); - // t.pass("tests/validations/*/*_pass.rs"); - // t.compile_fail("tests/validations/*/*_fail.rs"); + t.pass("tests/validations/*/*_pass.rs"); + t.compile_fail("tests/validations/*/*_fail.rs"); } diff --git a/packages/fortifier-macros-tests/tests/validations/custom/options_pass.rs b/packages/fortifier-macros-tests/tests/validations/custom/options_pass.rs new file mode 100644 index 0000000..af1abf7 --- /dev/null +++ b/packages/fortifier-macros-tests/tests/validations/custom/options_pass.rs @@ -0,0 +1,114 @@ +use fortifier::{Validate, error_code}; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Validate)] +struct CustomData<'a> { + #[validate(custom(function = custom, error = CustomError))] + zero_options: &'a str, + + #[validate(custom(function = custom, error = CustomError))] + strip_one_option: Option<&'a str>, + #[validate(custom(function = custom, error = CustomError))] + strip_two_options: Option>, + #[validate(custom(function = custom, error = CustomError))] + strip_three_options: Option>>, + + #[validate(custom(function = custom_one_option, error = CustomError, options))] + strip_no_options_from_one: Option<&'a str>, + #[validate(custom(function = custom_two_options, error = CustomError, options))] + strip_no_options_from_two: Option>, + #[validate(custom(function = custom_three_options, error = CustomError, options))] + strip_no_options_from_three: Option>>, + + #[validate(custom(function = custom_one_option, error = CustomError, options = 1))] + strip_to_one_option_from_one: Option<&'a str>, + #[validate(custom(function = custom_one_option, error = CustomError, options = 1))] + strip_to_one_option_from_two: Option>, + #[validate(custom(function = custom_one_option, error = CustomError, options = 1))] + strip_to_one_option_from_three: Option>>, + + #[validate(custom(function = custom_one_option, error = CustomError, options = 2))] + strip_to_two_options_from_one: Option<&'a str>, + #[validate(custom(function = custom_two_options, error = CustomError, options = 2))] + strip_to_two_options_from_two: Option>, + #[validate(custom(function = custom_two_options, error = CustomError, options = 2))] + strip_to_two_options_from_three: Option>>, +} + +error_code!(CustomErrorCode, "custom"); + +#[derive(Debug, Deserialize, PartialEq, Serialize)] +#[serde(rename_all = "camelCase")] +struct CustomError { + code: CustomErrorCode, +} + +fn custom(value: &str) -> Result<(), CustomError> { + if value == "" { + Ok(()) + } else { + Err(CustomError { + code: CustomErrorCode, + }) + } +} + +fn custom_one_option(value: &Option<&str>) -> Result<(), CustomError> { + if let Some(value) = value + && *value == "" + { + Ok(()) + } else { + Err(CustomError { + code: CustomErrorCode, + }) + } +} + +fn custom_two_options(value: &Option>) -> Result<(), CustomError> { + if let Some(Some(value)) = value + && *value == "" + { + Ok(()) + } else { + Err(CustomError { + code: CustomErrorCode, + }) + } +} + +fn custom_three_options(value: &Option>>) -> Result<(), CustomError> { + if let Some(Some(Some(value))) = value + && *value == "" + { + Ok(()) + } else { + Err(CustomError { + code: CustomErrorCode, + }) + } +} + +fn main() { + let data = CustomData { + zero_options: "", + + strip_one_option: Some(""), + strip_two_options: Some(Some("")), + strip_three_options: Some(Some(Some(""))), + + strip_no_options_from_one: Some(""), + strip_no_options_from_two: Some(Some("")), + strip_no_options_from_three: Some(Some(Some(""))), + + strip_to_one_option_from_one: Some(""), + strip_to_one_option_from_two: Some(Some("")), + strip_to_one_option_from_three: Some(Some(Some(""))), + + strip_to_two_options_from_one: Some(""), + strip_to_two_options_from_two: Some(Some("")), + strip_to_two_options_from_three: Some(Some(Some(""))), + }; + + assert_eq!(data.validate_sync(), Ok(())); +} diff --git a/packages/fortifier-macros/src/util.rs b/packages/fortifier-macros/src/util.rs index 01d4d79..0416526 100644 --- a/packages/fortifier-macros/src/util.rs +++ b/packages/fortifier-macros/src/util.rs @@ -1,6 +1,6 @@ use convert_case::{Case, Casing}; use quote::format_ident; -use syn::Ident; +use syn::{GenericArgument, Ident, Path, PathArguments, Type}; pub fn upper_camel_ident(ident: &Ident) -> Ident { let s = ident.to_string(); @@ -11,3 +11,30 @@ pub fn upper_camel_ident(ident: &Ident) -> Ident { format_ident!("{}", s.to_case(Case::UpperCamel)) } } + +pub fn path_to_string(path: &Path) -> String { + path.segments + .iter() + .map(|segment| segment.ident.to_string()) + .collect::>() + .join("::") +} + +pub fn is_option_path(path: &Path) -> bool { + let path_string = path_to_string(path); + path_string == "Option" || path_string == "std::option::Option" +} + +pub fn count_options(r#type: &Type) -> usize { + if let Type::Path(r#type) = r#type + && let Some(segment) = r#type.path.segments.last() + && let PathArguments::AngleBracketed(arguments) = &segment.arguments + && arguments.args.len() == 1 + && is_option_path(&r#type.path) + && let Some(GenericArgument::Type(argument_type)) = arguments.args.first() + { + 1 + count_options(argument_type) + } else { + 0 + } +} diff --git a/packages/fortifier-macros/src/validate/type.rs b/packages/fortifier-macros/src/validate/type.rs index 7fe01ed..343da11 100644 --- a/packages/fortifier-macros/src/validate/type.rs +++ b/packages/fortifier-macros/src/validate/type.rs @@ -5,7 +5,9 @@ use syn::{ TypeParamBound, WherePredicate, punctuated::Punctuated, token::PathSep, }; -use crate::{integrations::where_predicate, validate::error::format_error_ident}; +use crate::{ + integrations::where_predicate, util::path_to_string, validate::error::format_error_ident, +}; /// Primitive and built-in types. /// @@ -244,14 +246,6 @@ impl ValidateResult { } } -fn path_to_string(path: &Path) -> String { - path.segments - .iter() - .map(|segment| segment.ident.to_string()) - .collect::>() - .join("::") -} - fn is_validate_path(path: &Path) -> bool { let path_string = path_to_string(path); path_string == "Validate" diff --git a/packages/fortifier-macros/src/validations/custom.rs b/packages/fortifier-macros/src/validations/custom.rs index 2e7d1e1..64e0d2e 100644 --- a/packages/fortifier-macros/src/validations/custom.rs +++ b/packages/fortifier-macros/src/validations/custom.rs @@ -1,28 +1,31 @@ use proc_macro2::TokenStream; use quote::{ToTokens, format_ident, quote}; -use syn::{Ident, LitBool, Path, Result, Type, TypePath, meta::ParseNestedMeta}; +use syn::{Ident, LitBool, LitInt, Path, Result, Type, TypePath, meta::ParseNestedMeta}; use crate::{ generics::{Generic, generic_arguments}, - util::upper_camel_ident, + util::{count_options, upper_camel_ident}, validation::{Execution, Validation}, }; pub struct Custom { + r#type: Type, name: Ident, - execution: Execution, error_type: TypePath, function_path: Path, + execution: Execution, context: bool, + max_options: usize, } impl Validation for Custom { - fn parse(_type: &Type, meta: &ParseNestedMeta<'_>) -> Result { + fn parse(r#type: &Type, meta: &ParseNestedMeta<'_>) -> Result { let mut name = None; - let mut execution = Execution::Sync; let mut error_type: Option = None; let mut function_path: Option = None; + let mut execution = Execution::Sync; let mut context = false; + let mut max_options = 0; meta.parse_nested_meta(|meta| { if meta.path.is_ident("async") { @@ -54,6 +57,15 @@ impl Validation for Custom { } else if meta.path.is_ident("function") { function_path = Some(meta.value()?.parse()?); + Ok(()) + } else if meta.path.is_ident("options") { + if let Ok(value) = meta.value() { + let lit: LitInt = value.parse()?; + max_options = lit.base10_parse::()?; + } else { + max_options = usize::MAX; + } + Ok(()) } else if meta.path.is_ident("name") { let ident = meta.value()?.parse()?; @@ -78,11 +90,13 @@ impl Validation for Custom { }); Ok(Custom { + r#type: r#type.clone(), name, - execution, error_type, function_path, + execution, context, + max_options, }) } @@ -103,24 +117,53 @@ impl Validation for Custom { } fn expr(&self, execution: Execution, expr: &TokenStream) -> Option { - let context_expr = self.context.then(|| quote!(, &context)); - match (execution, self.execution) { - (Execution::Sync, Execution::Sync) => { - let function_path = &self.function_path; + (Execution::Sync, Execution::Sync) => Some(wrapper( + &self.r#type, + &self.function_path, + expr, + self.context, + self.max_options, + None, + )), + (Execution::Async, Execution::Async) => Some(wrapper( + &self.r#type, + &self.function_path, + expr, + self.context, + self.max_options, + Some(quote!(.await)), + )), + _ => None, + } + } +} - Some(quote! { - #function_path(&#expr #context_expr) - }) - } - (Execution::Async, Execution::Async) => { - let function_path = &self.function_path; +fn wrapper( + r#type: &Type, + function_path: &Path, + expr: &TokenStream, + context: bool, + max_options: usize, + suffix: Option, +) -> TokenStream { + let context_expr = context.then(|| quote!(, &context)); + + let count = count_options(r#type); + let remove_count = count.saturating_sub(max_options); + + if remove_count > 0 { + let mut wrapper = quote!(value); + for _ in 0..remove_count { + wrapper = quote!(Some(#wrapper)); + } - Some(quote! { - #function_path(&#expr #context_expr).await - }) - } - _ => None, + quote! { + { if let #wrapper = &#expr { #function_path(value #context_expr) #suffix} else { Ok(())} } + } + } else { + quote! { + #function_path(&#expr #context_expr) #suffix } } } diff --git a/packages/fortifier/src/error_code.rs b/packages/fortifier/src/error_code.rs index ffc88b7..01a37dc 100644 --- a/packages/fortifier/src/error_code.rs +++ b/packages/fortifier/src/error_code.rs @@ -1,6 +1,46 @@ +#[cfg(all(feature = "serde", feature = "utoipa"))] /// Implement an error code. #[macro_export] macro_rules! error_code { + ($name:ident, $code:literal) => { + $crate::error_code_base!($name, $code); + $crate::error_code_serde!($name, $code); + $crate::error_code_utoipa!($name, $code); + }; +} + +#[cfg(all(feature = "serde", not(feature = "utoipa")))] +/// Implement an error code. +#[macro_export] +macro_rules! error_code { + ($name:ident, $code:literal) => { + $crate::error_code_base!($name, $code); + $crate::error_code_serde!($name, $code); + }; +} + +#[cfg(all(not(feature = "serde"), feature = "utoipa"))] +/// Implement an error code. +#[macro_export] +macro_rules! error_code { + ($name:ident, $code:literal) => { + $crate::error_code_base!($name, $code); + $crate::error_code_utoipa!($name, $code); + }; +} + +#[cfg(all(not(feature = "serde"), not(feature = "utoipa")))] +/// Implement an error code. +#[macro_export] +macro_rules! error_code { + ($name:ident, $code:literal) => { + $crate::error_code_base!($name, $code); + }; +} + +/// Implement an error code. +#[macro_export] +macro_rules! error_code_base { ($name:ident, $code:literal) => { const CODE: &str = $code; @@ -27,20 +67,25 @@ macro_rules! error_code { ::std::fmt::Debug::fmt(&**self, f) } } + }; +} - #[cfg(feature = "serde")] +/// Implement [`serde`] traits for an error code. +#[cfg(feature = "serde")] +#[macro_export] +macro_rules! error_code_serde { + ($name:ident, $code:literal) => { impl<'de> ::serde::Deserialize<'de> for $name { fn deserialize(deserializer: D) -> Result where D: ::serde::Deserializer<'de>, { deserializer - .deserialize_any($crate::integrations::serde::MustBeStrVisitor(CODE)) + .deserialize_any($crate::serde::MustBeStrVisitor(CODE)) .map(|()| Self) } } - #[cfg(feature = "serde")] impl ::serde::Serialize for $name { fn serialize(&self, serializer: S) -> Result where @@ -49,8 +94,14 @@ macro_rules! error_code { serializer.serialize_str(CODE) } } + }; +} - #[cfg(feature = "utoipa")] +/// Implement [`utoipa`] traits for an error code. +#[cfg(feature = "utoipa")] +#[macro_export] +macro_rules! error_code_utoipa { + ($name:ident, $code:literal) => { impl ::utoipa::PartialSchema for $name { fn schema() -> ::utoipa::openapi::RefOr<::utoipa::openapi::schema::Schema> { ::utoipa::openapi::schema::ObjectBuilder::new() @@ -61,7 +112,6 @@ macro_rules! error_code { } } - #[cfg(feature = "utoipa")] impl ::utoipa::ToSchema for $name {} }; } diff --git a/packages/fortifier/src/validate.rs b/packages/fortifier/src/validate.rs index 79028cb..c09490c 100644 --- a/packages/fortifier/src/validate.rs +++ b/packages/fortifier/src/validate.rs @@ -70,35 +70,6 @@ pub trait Validate: ValidateWithContext { } } -/// Generate an infallible validate implementation for a type. -#[macro_export] -macro_rules! validate_ok { - ($type:ty) => { - impl $crate::ValidateWithContext for $type { - type Context = (); - type Error = ::std::convert::Infallible; - - fn validate_sync_with_context( - &self, - _context: &Self::Context, - ) -> Result<(), $crate::ValidationErrors> { - Ok(()) - } - - fn validate_async_with_context( - &self, - _context: &Self::Context, - ) -> ::std::pin::Pin< - Box>> + Send>, - > { - Box::pin(async { Ok(()) }) - } - } - - impl $crate::Validate for $type {} - }; -} - /// Generate a dereference validate implementation for a type. #[macro_export] macro_rules! validate_with_deref {