Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 8 additions & 80 deletions crates/cuda_std_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{ToTokens, quote_spanned};
use syn::{
Error, FnArg, Ident, ItemFn, ReturnType, Stmt, Token, parse::Parse, parse_macro_input,
parse_quote, punctuated::Punctuated, spanned::Spanned,
FnArg, Ident, ItemFn, ReturnType, Stmt, parse_macro_input, parse_quote, spanned::Spanned,
};

/// Registers a function as a gpu kernel.
///
/// This attribute must always be placed on gpu kernel functions.
///
/// This attribute does a couple of things:
/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx file.
/// - Tells `rustc_codegen_nvvm` to mark this as a gpu kernel and to not remove it from the ptx
/// file.
/// - Marks the function as `no_mangle`.
/// - Errors if the function is not unsafe.
/// - Makes sure function parameters are all [`Copy`].
/// - Makes sure the function doesn't return anything.
///
/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer is able to
/// offer intellisense by default.
/// Note that this does not cfg the function for nvptx(64), that is explicit so that rust analyzer
/// is able to offer intellisense by default.
#[proc_macro_attribute]
pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
let cloned = input.clone();
let _ = parse_macro_input!(input as KernelHints);
let input = parse_macro_input!(cloned as proc_macro2::TokenStream);
let mut item = parse_macro_input!(item as ItemFn);
let no_mangle = parse_quote!(#[unsafe(no_mangle)]);
item.attrs.push(no_mangle);
let internal = parse_quote!(#[cfg_attr(target_arch="nvptx64", nvvm_internal::kernel(#input))]);
item.attrs.push(internal);

// used to guarantee some things about how params are passed in the codegen.
// Used to guarantee some things about how params are passed in the codegen.
item.sig.abi = Some(parse_quote!(extern "C"));

let check_fn = parse_quote! {
Expand Down Expand Up @@ -71,80 +69,10 @@ pub fn kernel(input: proc_macro::TokenStream, item: proc_macro::TokenStream) ->
item.to_token_stream().into()
}

/// Dimensionality named by a kernel hint value (`1d`, `2d`, or `3d`).
///
/// All variants are field-less, so equality is total and `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dimension {
    /// Spelled `1d` or `1D`.
    Dim1,
    /// Spelled `2d` or `2D`.
    Dim2,
    /// Spelled `3d` or `3D`.
    Dim3,
}

impl Parse for Dimension {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let val = Ident::parse(input)?;
let val = val.to_string();
match val.as_str() {
"1d" | "1D" => Ok(Self::Dim1),
"2d" | "2D" => Ok(Self::Dim2),
"3d" | "3D" => Ok(Self::Dim3),
_ => Err(syn::Error::new(Span::call_site(), "Invalid dimension")),
}
}
}

/// A single `key = value` option parsed from the `kernel` attribute arguments.
///
/// Derives the same basic traits as the other hint types in this file so values can be
/// printed and compared while debugging the macro.
#[derive(Debug, Clone, Copy, PartialEq)]
enum KernelHint {
    /// Produced by `grid_dim = <dimension>`.
    GridDim(Dimension),
    /// Produced by `block_dim = <dimension>`.
    BlockDim(Dimension),
}

impl Parse for KernelHint {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let name = Ident::parse(input)?;
let key = name.to_string();
<Token![=]>::parse(input)?;
match key.as_str() {
"grid_dim" => {
let dim = Dimension::parse(input)?;
Ok(Self::GridDim(dim))
}
"block_dim" => {
let dim = Dimension::parse(input)?;
Ok(Self::BlockDim(dim))
}
_ => Err(Error::new(Span::call_site(), "Unrecognized option")),
}
}
}

/// The set of hints collected from a `kernel` attribute's argument list.
///
/// Both fields default to `None`; a field is set when the corresponding hint appears in
/// the parsed list.
#[derive(Debug, Default, Clone, PartialEq)]
struct KernelHints {
/// Value of the `grid_dim` hint, or `None` if it was not given.
grid_dim: Option<Dimension>,
/// Value of the `block_dim` hint, or `None` if it was not given.
block_dim: Option<Dimension>,
}

impl Parse for KernelHints {
    /// Parses a comma-separated list of [`KernelHint`]s (a trailing comma is allowed)
    /// into a [`KernelHints`] value.
    ///
    /// An empty list yields the default value with both fields `None`. When the same key
    /// appears more than once, the last occurrence wins.
    ///
    /// # Errors
    ///
    /// Propagates the first error produced while parsing the list.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let parsed = Punctuated::<KernelHint, Token![,]>::parse_terminated(input)?;

        let mut hints = Self::default();
        // `Punctuated` iterates by value directly; no intermediate `Vec` is needed.
        for hint in parsed {
            match hint {
                KernelHint::GridDim(dim) => hints.grid_dim = Some(dim),
                KernelHint::BlockDim(dim) => hints.block_dim = Some(dim),
            }
        }

        Ok(hints)
    }
}

// derived from rust-gpu's gpu_only

/// Creates a cpu version of the function which panics and cfg-gates the function for only nvptx/nvptx64.
/// Creates a cpu version of the function which panics and cfg-gates the function for only
/// nvptx/nvptx64.
#[proc_macro_attribute]
pub fn gpu_only(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
let syn::ItemFn {
Expand Down
Loading