Original file line numberDiff line numberDiff line change
Expand Up@@ -6,25 +6,30 @@ use proc_macro::TokenStream;

use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, quote_spanned, TokenStreamExt};
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::{
parse_macro_input, parse_quote, Error, Expr, ExprLit, ExprPath, FnArg, Ident, ItemFn,
ItemStruct, Lit, LitStr, Pat, Visibility,
parse_macro_input, parse_quote, token, Expr, ExprLit, ExprPath, FnArg, Ident, ItemFn,
ItemStruct, Lit, LitStr, Pat, Type, TypePath, Visibility,
};

macro_rules! err {
($span:expr, $message:expr $(,)?) => {
Error::new($span.span(), $message).to_compile_error()
syn::Error::new($span.span(), $message).to_compile_error()
};
($span:expr, $message:expr, $($args:expr),*) => {
Error::new($span.span(), format!($message, $($args),*)).to_compile_error()
syn::Error::new($span.span(), format!($message, $($args),*)).to_compile_error()
};
}

/// Attribute macro for marking structs as UEFI protocols.
///
/// The macro takes one argument, either a GUID string or the path to a `Guid`
/// constant.
/// The macro takes one, two, or 3 arguments. The first two are GUIDs
/// or the path to a `Guid` constant. The first argument is always the
/// GUID of the protocol, while the optional second argument is the
/// GUID of the service binding protocol, when applicable.
///
/// The third argument is a struct
///
/// The macro can only be applied to a struct. It implements the
/// [`Protocol`] trait and the `unsafe` [`Identify`] trait for the
Expand All@@ -49,38 +54,47 @@ macro_rules! err {
/// #[unsafe_protocol(PROTO_GUID)]
/// struct ExampleProtocol2 {}
///
/// const SERVICE_GUID: Guid = guid!("12345678-9abc-def0-1234-56789abcdef1");
/// #[unsafe_protocol(PROTO_GUID, SERVICE_GUID)]
/// struct ExampleProtocol3 {}
///
/// assert_eq!(ExampleProtocol1::GUID, PROTO_GUID);
/// assert_eq!(ExampleProtocol2::GUID, PROTO_GUID);
/// assert_eq!(ExampleProtocol3::GUID, PROTO_GUID);
///
/// assert_eq!(ExampleProtocol1::SERVICE_BINDING, None);
/// assert_eq!(ExampleProtocol2::SERVICE_BINDING, None);
/// assert_eq!(ExampleProtocol3::SERVICE_BINDING, Some(SERVICE_GUID));
/// ```
///
/// [`Identify`]: https://docs.rs/uefi/latest/uefi/trait.Identify.html
/// [`Protocol`]: https://docs.rs/uefi/latest/uefi/proto/trait.Protocol.html
/// [send-and-sync]: https://doc.rust-lang.org/nomicon/send-and-sync.html
#[proc_macro_attribute]
pub fn unsafe_protocol(args: TokenStream, input: TokenStream) -> TokenStream {
let expr = parse_macro_input!(args as Expr);
let args = parse_macro_input!(args as ProtocolArgs);
let item_struct = parse_macro_input!(input as ItemStruct);
let ident = &item_struct.ident;
let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();

let guid_val = match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) => {
quote!(::uefi::guid!(#lit))
let proto_guid = guid_from_expr(args.protocol_guid);
let service_binding_guid = match args.service_binding_guid {
None => quote!(None),
Some(expr) => {
let guid = guid_from_expr(expr);
quote!(Some(#guid))
}
Expr::Path(ExprPath { path, .. }) => quote!(#path),
_ => {
return err!(
expr,
"macro input must be either a string literal or path to a constant"
)
.into()
};
let wrapper_type = match args.wrapper_type {
None => quote!(::uefi::proto::NoWrapper),
Some(Type::Path(TypePath { path, .. })) => {
let wrapper_ident = &item_struct.ident;
let wrapper_generics = &item_struct.generics;
quote!(::uefi::proto::StructWrapper<#path, #wrapper_ident #wrapper_generics>)
}
Some(typ) => return err!(typ, "wrapper type must be a path").into(),
};

let item_struct = parse_macro_input!(input as ItemStruct);

let ident = &item_struct.ident;
let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();

quote! {
// Disable this lint for now. It doesn't account for the fact that
// currently it doesn't work to `derive(Debug)` on structs that have
Expand All@@ -91,14 +105,93 @@ pub fn unsafe_protocol(args: TokenStream, input: TokenStream) -> TokenStream {
#item_struct

unsafe impl #impl_generics ::uefi::Identify for #ident #ty_generics #where_clause {
const GUID: ::uefi::Guid = #guid_val;
const GUID: ::uefi::Guid = #proto_guid;
}

impl #impl_generics ::uefi::proto::Protocol for #ident #ty_generics #where_clause {}
impl #impl_generics ::uefi::proto::Protocol for #ident #ty_generics #where_clause {
const SERVICE_BINDING: Option<::uefi::Guid> = #service_binding_guid;
type Raw = #wrapper_type;
}
}
.into()
}

struct ProtocolArgs {
protocol_guid: Expr,
service_binding_guid: Option<Expr>,
wrapper_type: Option<Type>,
}

impl Parse for ProtocolArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
// Always parse a GUID
let protocol_guid = input.parse::<Expr>()?;
let mut args = Self {
protocol_guid,
service_binding_guid: None,
wrapper_type: None,
};

// Next is always a comma
if !input.is_empty() {
let _ = input.parse::<token::Comma>()?;
}

// Next can be a GUID or a comma
let lookahead = input.lookahead1();

// ... so parse a GUID if not a comma
if !input.is_empty() && !lookahead.peek(token::Comma) {
let service_binding_guid = input.parse::<Expr>()?;
args.service_binding_guid = Some(service_binding_guid);
}

// ... and then parse a comma unless at the end
if !input.is_empty() {
let _ = input.parse::<token::Comma>()?;
}

// Next can be a type or a (trailing) comma
let lookahead = input.lookahead1();

// ... so parse a Type if not a comma
if !input.is_empty() && !lookahead.peek(token::Comma) {
let wrapper_type = input.parse::<Type>()?;
args.wrapper_type = Some(wrapper_type);
}

// ... and then parse a (trailing) comma unless at the end
if !input.is_empty() {
let _ = input.parse::<token::Comma>()?;
}

// Error if this is not the end
if !input.is_empty() {
return Err(input.error("up to 3 comma-separated args are supported"));
}

Ok(args)
}
}

fn guid_from_expr(expr: Expr) -> TokenStream2 {
match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) => {
quote!(::uefi::guid!(#lit))
}
Expr::Path(ExprPath { path, .. }) => quote!(#path),
_ => {
return err!(
expr,
"macro input must be either a string literal or path to a constant"
)
.into()
}
}
}

/// Get the name of a function's argument at `arg_index`.
fn get_function_arg_name(f: &ItemFn, arg_index: usize, errors: &mut TokenStream2) -> Option<Ident> {
if let Some(FnArg::Typed(arg)) = f.sig.inputs.iter().nth(arg_index) {
Expand Down
Loading