diff --git a/crates/cuda_std_macros/src/lib.rs b/crates/cuda_std_macros/src/lib.rs index 467a6319..ed97e7ce 100644 --- a/crates/cuda_std_macros/src/lib.rs +++ b/crates/cuda_std_macros/src/lib.rs @@ -1,9 +1,7 @@ use proc_macro::TokenStream; -use proc_macro2::Span; use quote::{ToTokens, quote_spanned}; use syn::{ - Error, FnArg, Ident, ItemFn, ReturnType, Stmt, Token, parse::Parse, parse_macro_input, - parse_quote, punctuated::Punctuated, spanned::Spanned, + FnArg, Ident, ItemFn, ReturnType, Stmt, parse_macro_input, parse_quote, spanned::Spanned, }; /// Registers a function as a gpu kernel. @@ -11,18 +9,18 @@ use syn::{ /// This attribute must always be placed on gpu kernel functions. /// /// This attribute does a couple of things: -/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx file. +/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx +/// file. /// - Marks the function as `no_mangle`. /// - Errors if the function is not unsafe. /// - Makes sure function parameters are all [`Copy`]. /// - Makes sure the function doesn't return anything. /// -/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer is able to -/// offer intellisense by default. +/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer +/// is able to offer intellisense by default. #[proc_macro_attribute] pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream { let cloned = input.clone(); - let _ = parse_macro_input!(input as KernelHints); let input = parse_macro_input!(cloned as proc_macro2::TokenStream); let mut item = parse_macro_input!(item as ItemFn); let no_mangle = parse_quote!(#[unsafe(no_mangle)]); @@ -30,7 +28,7 @@ pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> let internal = parse_quote!(#[cfg_attr(target_arch="nvptx64", nvvm_internal::kernel(#input))]); item.attrs.push(internal); - // used to guarantee some things about how params are passed in the codegen. + // Used to guarantee some things about how params are passed in the codegen. item.sig.abi = Some(parse_quote!(extern "C")); let check_fn = parse_quote! { @@ -71,80 +69,10 @@ pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> item.to_token_stream().into() } -#[derive(Debug, Clone, Copy, PartialEq)] -enum Dimension { - Dim1, - Dim2, - Dim3, -} - -impl Parse for Dimension { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let val = Ident::parse(input)?; - let val = val.to_string(); - match val.as_str() { - "1d" | "1D" => Ok(Self::Dim1), - "2d" | "2D" => Ok(Self::Dim2), - "3d" | "3D" => Ok(Self::Dim3), - _ => Err(syn::Error::new(Span::call_site(), "Invalid dimension")), - } - } -} - -enum KernelHint { - GridDim(Dimension), - BlockDim(Dimension), -} - -impl Parse for KernelHint { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let name = Ident::parse(input)?; - let key = name.to_string(); - ::parse(input)?; - match key.as_str() { - "grid_dim" => { - let dim = Dimension::parse(input)?; - Ok(Self::GridDim(dim)) - } - "block_dim" => { - let dim = Dimension::parse(input)?; - Ok(Self::BlockDim(dim)) - } - _ => Err(Error::new(Span::call_site(), "Unrecognized option")), - } - } -} - -#[derive(Debug, Default, Clone, PartialEq)] -struct KernelHints { - grid_dim: Option, - block_dim: Option, -} - -impl Parse for KernelHints { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - let iter = Punctuated::::parse_terminated(input)?; - let hints = iter - .into_pairs() - .map(|x| x.into_value()) - .collect::>(); - - let mut out = KernelHints::default(); - - for hint in hints { - match hint { - KernelHint::GridDim(dim) => out.grid_dim = Some(dim), - KernelHint::BlockDim(dim) => out.block_dim = Some(dim), - } - } - - Ok(out) - } -} - // derived from rust-gpu's gpu_only -/// Creates a cpu version of the function which panics and cfg-gates the function for only nvptx/nvptx64. +/// Creates a cpu version of the function which panics and cfg-gates the function for only +/// nvptx/nvptx64. #[proc_macro_attribute] pub fn gpu_only(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream { let syn::ItemFn {