diff --git a/Cargo.toml b/Cargo.toml index 3aaf02d..1bf8a55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,16 @@ +[workspace] +members = [".", "enum-display-macro"] + [package] name = "enum-display" description = "A macro to derive Display for enums" keywords = ["enum", "display", "derive", "macro"] -version = "0.1.4" +version = "0.1.5" edition = "2021" license = "MIT" documentation = "https://docs.rs/enum-display" homepage = "https://github.com/SeedyROM/enum-display" repository = "https://github.com/SeedyROM/enum-display" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] -enum-display-macro = { version = "0.1.4" } +enum-display-macro = { version = "0.1.5", path = "./enum-display-macro" } diff --git a/enum-display-macro/Cargo.toml b/enum-display-macro/Cargo.toml index a2bf804..b780b58 100644 --- a/enum-display-macro/Cargo.toml +++ b/enum-display-macro/Cargo.toml @@ -2,7 +2,7 @@ name = "enum-display-macro" description = "A macro to derive Display for enums" keywords = ["enum", "display", "derive", "macro"] -version = "0.1.4" +version = "0.1.5" edition = "2021" license = "MIT" documentation = "https://docs.rs/enum-display/tree/main/enum-display-derive" @@ -12,9 +12,11 @@ repository = "https://github.com/SeedyROM/enum-display/tree/main/enum-display-de # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +proc-macro2 = "1.0.95" convert_case = "0.6.0" quote = "1.0.21" syn = { version = "1.0.101", features = ["full"] } +regex = "1.11.1" [lib] proc-macro = true diff --git a/enum-display-macro/src/lib.rs b/enum-display-macro/src/lib.rs index 5dc0c7b..6483879 100644 --- a/enum-display-macro/src/lib.rs +++ b/enum-display-macro/src/lib.rs @@ -1,86 +1,317 @@ use convert_case::{Case, Casing}; use proc_macro::{self, TokenStream}; +use proc_macro2::Span; use quote::quote; -use syn::{parse_macro_input, DeriveInput}; - -fn parse_case_name(case_name: &str) -> Case { - match case_name { - "Upper" => Case::Upper, - "Lower" => Case::Lower, - "Title" => Case::Title, - "Toggle" => Case::Toggle, - "Camel" => Case::Camel, - "Pascal" => Case::Pascal, - "UpperCamel" => Case::UpperCamel, - "Snake" => Case::Snake, - "UpperSnake" => Case::UpperSnake, - "ScreamingSnake" => Case::ScreamingSnake, - "Kebab" => Case::Kebab, - "Cobol" => Case::Cobol, - "UpperKebab" => Case::UpperKebab, - "Train" => Case::Train, - "Flat" => Case::Flat, - "UpperFlat" => Case::UpperFlat, - "Alternating" => Case::Alternating, - _ => panic!("Unrecognized case name: {}", case_name), +use regex::Regex; +use syn::{parse_macro_input, Attribute, DeriveInput, FieldsNamed, FieldsUnnamed, Ident, Variant}; + +// Enum attributes +struct EnumAttrs { + case_transform: Option, +} + +impl EnumAttrs { + fn from_attrs(attrs: Vec) -> Self { + let mut case_transform: Option = None; + + for attr in attrs.into_iter() { + if attr.path.is_ident("enum_display") { + let meta = attr.parse_meta().unwrap(); + if let syn::Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested { + if name_value.path.is_ident("case") { + if let syn::Lit::Str(lit_str) = name_value.lit { + case_transform = + Some(Self::parse_case_name(lit_str.value().as_str())); + } + } + } + } + } + } + } + + Self { case_transform } + } + + fn parse_case_name(case_name: &str) -> Case { + match case_name { + "Upper" => Case::Upper, + "Lower" => Case::Lower, + "Title" => Case::Title, + "Toggle" => Case::Toggle, + "Camel" => Case::Camel, + "Pascal" => Case::Pascal, + "UpperCamel" => Case::UpperCamel, + "Snake" => Case::Snake, + "UpperSnake" => Case::UpperSnake, + "ScreamingSnake" => Case::ScreamingSnake, + "Kebab" => Case::Kebab, + "Cobol" => Case::Cobol, + "UpperKebab" => Case::UpperKebab, + "Train" => Case::Train, + "Flat" => Case::Flat, + "UpperFlat" => Case::UpperFlat, + "Alternating" => Case::Alternating, + _ => panic!("Unrecognized case name: {case_name}"), + } + } + + fn transform_case(&self, ident: String) -> String { + if let Some(case) = self.case_transform { + ident.to_case(case) + } else { + ident + } } } -#[proc_macro_derive(EnumDisplay, attributes(enum_display))] -pub fn derive(input: TokenStream) -> TokenStream { - // Parse the input tokens into a syntax tree - let DeriveInput { - ident, data, attrs, .. - } = parse_macro_input!(input); +// Variant attributes +struct VariantAttrs { + format: Option, +} - // Should we transform the case of the enum variants? - let mut case_transform: Option = None; - - // Find the enum_display attribute - for attr in attrs.into_iter() { - if attr.path.is_ident("enum_display") { - let meta = attr.parse_meta().unwrap(); - if let syn::Meta::List(list) = meta { - for nested in list.nested { - if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested { - if name_value.path.is_ident("case") { - if let syn::Lit::Str(lit_str) = name_value.lit { - // Set the case transform - case_transform = Some(parse_case_name(lit_str.value().as_str())); +impl VariantAttrs { + fn from_attrs(attrs: Vec) -> Self { + let mut format = None; + + // Find the display attribute + for attr in attrs.into_iter() { + if attr.path.is_ident("display") { + let meta = attr.parse_meta().unwrap(); + if let syn::Meta::List(list) = meta { + if let Some(first_nested) = list.nested.first() { + match first_nested { + // Handle literal string: #[display("format string")] + syn::NestedMeta::Lit(syn::Lit::Str(lit_str)) => { + format = + Some(Self::translate_numeric_placeholders(&lit_str.value())); + } + // Handle named value: #[display(format = "format string")] + syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) => { + if let syn::Lit::Str(lit_str) = &name_value.lit { + format = Some(Self::translate_numeric_placeholders( + &lit_str.value(), + )); + } } + _ => {} } } } } } + + Self { format } } - // Build the match arms - let variants = match data { + // Translates {123:?} to {_unnamed_123:?} for safer format arg usage + fn translate_numeric_placeholders(fmt: &str) -> String { + let re = Regex::new(r"\{\s*(\d+)\s*([^}]*)\}").unwrap(); + re.replace_all(fmt, |caps: ®ex::Captures| { + let idx = &caps[1]; + let fmt_spec = &caps[2]; + format!("{{_unnamed_{idx}{fmt_spec}}}") + }) + .to_string() + } +} + +// Shared intermediate variant info +struct VariantInfo { + ident: Ident, + ident_transformed: String, + attrs: VariantAttrs, +} + +// Intermediate Named variant info +struct NamedVariantIR { + info: VariantInfo, + fields: Vec, +} + +impl NamedVariantIR { + fn from_fields_named(fields_named: FieldsNamed, info: VariantInfo) -> Self { + let fields = fields_named + .named + .into_iter() + .filter_map(|field| field.ident) + .collect(); + Self { info, fields } + } + + fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream { + let VariantInfo { + ident, + ident_transformed, + attrs, + } = self.info; + let fields = self.fields; + match (any_has_format, attrs.format) { + (true, Some(fmt)) => { + quote! { #ident { #(#fields),* } => { let variant = #ident_transformed; format!(#fmt) } } + } + (true, None) => quote! { #ident { .. } => String::from(#ident_transformed), }, + (false, None) => quote! { #ident { .. } => #ident_transformed, }, + _ => unreachable!( + "`any_has_format` should never be false when a variant has format string" + ), + } + } +} + +// Intermediate Unnamed variant info +struct UnnamedVariantIR { + info: VariantInfo, + fields: Vec, +} + +impl UnnamedVariantIR { + fn from_fields_unnamed(fields_unnamed: FieldsUnnamed, info: VariantInfo) -> Self { + let fields: Vec = fields_unnamed + .unnamed + .into_iter() + .enumerate() + .map(|(i, _)| Ident::new(format!("_unnamed_{i}").as_str(), Span::call_site())) + .collect(); + Self { info, fields } + } + + fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream { + let VariantInfo { + ident, + ident_transformed, + attrs, + } = self.info; + let fields = self.fields; + match (any_has_format, attrs.format) { + (true, Some(fmt)) => { + quote! { #ident(#(#fields),*) => { let variant = #ident_transformed; format!(#fmt) } } + } + (true, None) => quote! { #ident(..) => String::from(#ident_transformed), }, + (false, None) => quote! { #ident(..) => #ident_transformed, }, + _ => unreachable!( + "`any_has_format` should never be false when a variant has format string" + ), + } + } +} + +// Intermediate Unit variant info +struct UnitVariantIR { + info: VariantInfo, +} + +impl UnitVariantIR { + fn new(info: VariantInfo) -> Self { + Self { info } + } + + fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream { + let VariantInfo { + ident, + ident_transformed, + attrs, + } = self.info; + match (any_has_format, attrs.format) { + (true, Some(fmt)) => { + quote! { #ident => { let variant = #ident_transformed; format!(#fmt) } } + } + (true, None) => quote! { #ident => String::from(#ident_transformed), }, + (false, None) => quote! { #ident => #ident_transformed, }, + _ => unreachable!( + "`any_has_format` should never be false when a variant has format string" + ), + } + } +} + +// Intermediate version of Variant +enum VariantIR { + Named(NamedVariantIR), + Unnamed(UnnamedVariantIR), + Unit(UnitVariantIR), +} + +impl VariantIR { + fn from_variant(variant: Variant, enum_attrs: &EnumAttrs) -> Self { + let ident_str = variant.ident.to_string(); + let info = VariantInfo { + ident: variant.ident, + ident_transformed: enum_attrs.transform_case(ident_str), + attrs: VariantAttrs::from_attrs(variant.attrs), + }; + match variant.fields { + syn::Fields::Named(fields_named) => { + Self::Named(NamedVariantIR::from_fields_named(fields_named, info)) + } + syn::Fields::Unnamed(fields_unnamed) => { + Self::Unnamed(UnnamedVariantIR::from_fields_unnamed(fields_unnamed, info)) + } + syn::Fields::Unit => Self::Unit(UnitVariantIR::new(info)), + } + } + + fn generate(self, any_has_format: bool) -> proc_macro2::TokenStream { + match self { + VariantIR::Named(named_variant) => named_variant.generate(any_has_format), + VariantIR::Unnamed(unnamed_variant) => unnamed_variant.generate(any_has_format), + VariantIR::Unit(unit_variant) => unit_variant.generate(any_has_format), + } + } + + fn has_format(&self) -> bool { + match self { + VariantIR::Named(named_variant) => &named_variant.info, + VariantIR::Unnamed(unnamed_variant) => &unnamed_variant.info, + VariantIR::Unit(unit_variant) => &unit_variant.info, + } + .attrs + .format + .is_some() + } +} + +#[proc_macro_derive(EnumDisplay, attributes(enum_display, display))] +pub fn derive(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let DeriveInput { + ident, + data, + attrs, + generics, + .. + } = parse_macro_input!(input); + + // Copy generics and bounds + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // Read enum attrs + let enum_attrs = EnumAttrs::from_attrs(attrs); + + // Read variants and variant attrs into an intermediate format + let intermediate_variants: Vec = match data { syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, _ => panic!("EnumDisplay can only be derived for enums"), } .into_iter() - .map(|variant| { - let ident = variant.ident; - let ident_str = if case_transform.is_some() { - ident.to_string().to_case(case_transform.unwrap()) - } else { - ident.to_string() - }; + .map(|variant| VariantIR::from_variant(variant, &enum_attrs)) + .collect(); - match variant.fields { - syn::Fields::Named(_) => quote! { - #ident { .. } => #ident_str, - }, - syn::Fields::Unnamed(_) => quote! { - #ident(..) => #ident_str, - }, - syn::Fields::Unit => quote! { - #ident => #ident_str, - }, - } - }); + // If any variants have a format string, the output of all match arms must be String instead of &str + // This is because we can't return a reference to the temporary output of format!() + let any_has_format = intermediate_variants.iter().any(|v| v.has_format()); + let post_fix = if any_has_format { + quote! { .as_str() } + } else { + quote! {} + }; + + // Build the match arms + let variants = intermediate_variants + .into_iter() + .map(|v| v.generate(any_has_format)); // #[allow(unused_qualifications)] is needed // due to https://github.com/SeedyROM/enum-display/issues/1 @@ -88,13 +319,13 @@ pub fn derive(input: TokenStream) -> TokenStream { let output = quote! { #[automatically_derived] #[allow(unused_qualifications)] - impl ::core::fmt::Display for #ident { + impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::write_str( f, match self { - #(#ident::#variants)* - }, + #(Self::#variants)* + }#post_fix ) } } diff --git a/src/lib.rs b/src/lib.rs index 2014389..62ea829 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,13 +44,43 @@ mod tests { #[derive(EnumDisplay)] enum TestEnum { Name, + + #[display("Overridden Name")] + OverriddenName, + + #[display("Unit: {variant}")] + NameFullFormat, + Address { street: String, city: String, state: String, zip: String, }, + + #[display("Named: {variant} {{{street}, {zip}}}")] + AddressPartialFormat { + street: String, + city: String, + state: String, + zip: String, + }, + + #[display("Named: {variant} {{{street}, {city}, {state}, {zip}}}")] + AddressFullFormat { + street: String, + city: String, + state: String, + zip: String, + }, + DateOfBirth(u32, u32, u32), + + #[display("Unnamed: {variant}({2})")] + DateOfBirthPartialFormat(u32, u32, u32), + + #[display("Unnamed: {variant}({0}, {1}, {2})")] + DateOfBirthFullFormat(u32, u32, u32), } #[allow(dead_code)] @@ -67,9 +97,27 @@ mod tests { DateOfBirth(u32, u32, u32), } + #[allow(dead_code)] + #[derive(EnumDisplay)] + enum TestEnumWithLifetimeAndGenerics<'a, T: Clone> + where + T: std::fmt::Display, + { + Name, + Address { + street: &'a T, + city: &'a T, + state: &'a T, + zip: &'a T, + }, + DateOfBirth(u32, u32, u32), + } + #[test] fn test_unit_field_variant() { assert_eq!(TestEnum::Name.to_string(), "Name"); + assert_eq!(TestEnum::OverriddenName.to_string(), "Overridden Name"); + assert_eq!(TestEnum::NameFullFormat.to_string(), "Unit: NameFullFormat"); } #[test] @@ -84,11 +132,39 @@ mod tests { .to_string(), "Address" ); + assert_eq!( + TestEnum::AddressPartialFormat { + street: "123 Main St".to_string(), + city: "Any Town".to_string(), + state: "CA".to_string(), + zip: "12345".to_string() + } + .to_string(), + "Named: AddressPartialFormat {123 Main St, 12345}" + ); + assert_eq!( + TestEnum::AddressFullFormat { + street: "123 Main St".to_string(), + city: "Any Town".to_string(), + state: "CA".to_string(), + zip: "12345".to_string() + } + .to_string(), + "Named: AddressFullFormat {123 Main St, Any Town, CA, 12345}" + ); } #[test] fn test_unnamed_fields_variant() { - assert_eq!(TestEnum::DateOfBirth(1, 1, 2000).to_string(), "DateOfBirth"); + assert_eq!(TestEnum::DateOfBirth(1, 2, 1999).to_string(), "DateOfBirth"); + assert_eq!( + TestEnum::DateOfBirthPartialFormat(1, 2, 1999).to_string(), + "Unnamed: DateOfBirthPartialFormat(1999)" + ); + assert_eq!( + TestEnum::DateOfBirthFullFormat(1, 2, 1999).to_string(), + "Unnamed: DateOfBirthFullFormat(1, 2, 1999)" + ); } #[test] @@ -117,4 +193,34 @@ mod tests { "date-of-birth" ); } + + #[test] + fn test_unit_field_variant_with_lifetime_and_generics() { + assert_eq!( + TestEnumWithLifetimeAndGenerics::<'_, String>::Name.to_string(), + "Name" + ); + } + + #[test] + fn test_named_fields_variant_with_lifetime_and_generics() { + assert_eq!( + TestEnumWithLifetimeAndGenerics::Address { + street: &"123 Main St".to_string(), + city: &"Any Town".to_string(), + state: &"CA".to_string(), + zip: &"12345".to_string() + } + .to_string(), + "Address" + ); + } + + #[test] + fn test_unnamed_fields_variant_with_lifetime_and_generics() { + assert_eq!( + TestEnumWithLifetimeAndGenerics::<'_, String>::DateOfBirth(1, 1, 2000).to_string(), + "DateOfBirth" + ); + } }