44
55use proc_macro2:: TokenStream ;
66use quote:: quote;
7- use syn:: { Data , DeriveInput , Fields , Expr , Lit , spanned:: Spanned } ;
7+ use syn:: { Data , DeriveInput , Expr , Fields , spanned:: Spanned } ;
8+
9+ /// Parse the repr type from #[repr(...)] attribute
10+ fn parse_repr_type ( input : & DeriveInput ) -> TokenStream {
11+ for attr in & input. attrs {
12+ if attr. path ( ) . is_ident ( "repr" ) {
13+ // Parse as a single type first (e.g., #[repr(u8)])
14+ if let Ok ( ty) = attr. parse_args :: < syn:: Type > ( ) {
15+ if is_valid_discriminant_type ( & ty) {
16+ return quote ! { #ty } ;
17+ }
18+ } else {
19+ // Parse as list of types (for example #[repr(C, u8)])
20+ if let Ok ( meta_list) = attr. parse_args_with (
21+ syn:: punctuated:: Punctuated :: < syn:: Meta , syn:: Token ![ , ] > :: parse_terminated,
22+ ) {
23+ for meta in meta_list {
24+ if let syn:: Meta :: Path ( path) = meta {
25+ let ty: syn:: Type = syn:: Type :: Path ( syn:: TypePath {
26+ qself : None ,
27+ path : path. clone ( ) ,
28+ } ) ;
29+ if is_valid_discriminant_type ( & ty) {
30+ return quote ! { #ty } ;
31+ }
32+ }
33+ }
34+ }
35+ }
36+ }
37+ }
38+
39+ // Fallback if no valid repr type found
40+ quote ! { isize }
41+ }
42+
43+ /// Check if a type is a valid discriminant type
44+ fn is_valid_discriminant_type ( ty : & syn:: Type ) -> bool {
45+ if let syn:: Type :: Path ( type_path) = ty
46+ && let Some ( ident) = type_path. path . get_ident ( )
47+ {
48+ let ident_str = ident. to_string ( ) ;
49+ return matches ! ( ident_str. as_str( ) , |"u8" | "i8"
50+ | "u16"
51+ | "i16"
52+ | "u32"
53+ | "i32"
54+ | "u64"
55+ | "i64"
56+ | "u128"
57+ | "i128"
58+ | "usize"
59+ | "isize" ) ;
60+ }
61+ false
62+ }
863
964pub fn expand_data_enum ( input : DeriveInput ) -> syn:: Result < TokenStream > {
1065 let enum_name = & input. ident ;
1166 let generics = & input. generics ;
67+ let discriminant_type = parse_repr_type ( & input) ;
1268
1369 let variants = match & input. data {
1470 Data :: Enum ( data_enum) => & data_enum. variants ,
@@ -34,11 +90,12 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
3490
3591 // Information about the variants
3692 let mut variant_data = Vec :: new ( ) ;
37- let mut current_discriminant: Expr = syn:: parse_quote!( 0 ) ;
93+ let mut current_discriminant: Expr = syn:: parse_quote!( 0 ) ;
3894 for variant in variants. iter ( ) {
3995 let variant_name = & variant. ident ;
4096
41- let discriminant_value = variant. discriminant
97+ let discriminant_value = variant
98+ . discriminant
4299 . as_ref ( )
43100 . map ( |( _, expr) | expr. clone ( ) )
44101 . unwrap_or_else ( || current_discriminant. clone ( ) ) ;
@@ -220,14 +277,28 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
220277 . collect ( ) ;
221278
222279 let tuple_repr = if union_names. is_empty ( ) {
223- quote ! { ( u8 , ) }
280+ quote ! { ( #discriminant_type , ) }
224281 } else {
225- quote ! { ( u8 , #( #union_names) , * ) }
282+ quote ! { ( #discriminant_type , #( #union_names) , * ) }
226283 } ;
227284
285+ // Generate const items for discriminant values
286+ // This is so that we can later pattern match on the discriminant values
287+ let discriminant_consts: Vec < _ > = variant_data
288+ . iter ( )
289+ . map ( |( variant_name, disc_value, _, _) | {
290+ let const_name =
291+ quote:: format_ident!( "__DISC_{}" , variant_name. to_string( ) . to_uppercase( ) ) ;
292+ quote ! {
293+ const #const_name: #discriminant_type = #disc_value;
294+ }
295+ } )
296+ . collect ( ) ;
297+
228298 // Generate into_tuple match arms
229- let into_tuple_arms: Vec < _ > = variant_data. iter ( ) . map ( |( variant_name, disc_value , field_names, _field_types) | {
299+ let into_tuple_arms: Vec < _ > = variant_data. iter ( ) . map ( |( variant_name, _disc_value , field_names, _field_types) | {
230300 let variant_name_lower = variant_name. to_string ( ) . to_lowercase ( ) ;
301+ let const_name = quote:: format_ident!( "__DISC_{}" , variant_name. to_string( ) . to_uppercase( ) ) ;
231302
232303 if field_names. is_empty ( ) {
233304 // Unit variant
@@ -239,16 +310,15 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
239310
240311 if union_constructions. is_empty ( ) {
241312 quote ! {
242- Self :: #variant_name => ( #disc_value , )
313+ Self :: #variant_name => ( #const_name , )
243314 }
244315 } else {
245316 quote ! {
246- Self :: #variant_name => ( #disc_value , #( #union_constructions) , * )
317+ Self :: #variant_name => ( #const_name , #( #union_constructions) , * )
247318 }
248319 }
249- }
250- else {
251- // Variant with fields
320+ } else {
321+ // Variant with fields
252322 let union_constructions: Vec < _ > = ( 0 ..max_fields) . map ( |idx| {
253323 let union_name = quote:: format_ident!( "{}Field{}" , enum_name, idx) ;
254324 let union_field = quote:: format_ident!( "{}" , variant_name_lower) ;
@@ -261,21 +331,23 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
261331 } ) . collect ( ) ;
262332
263333 quote ! {
264- Self :: #variant_name( #( #field_names) , * ) => ( #disc_value , #( #union_constructions) , * )
334+ Self :: #variant_name( #( #field_names) , * ) => ( #const_name , #( #union_constructions) , * )
265335 }
266336 }
267337 } ) . collect ( ) ;
268338
269339 // Generate from_tuple match arms
270340 let from_tuple_arms: Vec < _ > = variant_data
271341 . iter ( )
272- . map ( |( variant_name, disc_value , field_names, _field_types) | {
342+ . map ( |( variant_name, _disc_value , field_names, _field_types) | {
273343 let variant_name_lower = variant_name. to_string ( ) . to_lowercase ( ) ;
344+ let const_name =
345+ quote:: format_ident!( "__DISC_{}" , variant_name. to_string( ) . to_uppercase( ) ) ;
274346
275347 if field_names. is_empty ( ) {
276348 // Unit variant
277349 quote ! {
278- #disc_value => Self :: #variant_name
350+ #const_name => Self :: #variant_name
279351 }
280352 } else {
281353 // Variant with fields - extract from unions
@@ -291,7 +363,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
291363 . collect ( ) ;
292364
293365 quote ! {
294- #disc_value => {
366+ #const_name => {
295367 #( #field_extractions) *
296368 Self :: #variant_name( #( #field_names) , * )
297369 }
@@ -310,6 +382,8 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
310382 . collect ( ) ;
311383
312384 let expanded = quote ! {
385+ #( #discriminant_consts) *
386+
313387 #( #union_types) *
314388
315389 #[ allow( dead_code) ]
@@ -378,7 +452,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
378452 let ( __disc_ptr, #( #union_field_names) , * ) = value;
379453 unsafe {
380454 #ref_struct_name {
381- discriminant: core :: mem :: transmute ( __disc_ptr. as_ptr( ) ) ,
455+ discriminant: & * ( __disc_ptr. as_ptr( ) as * const _ ) ,
382456 #( #ref_union_field_names: #union_field_names. as_ref( ) ) , *
383457 }
384458 }
@@ -391,7 +465,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
391465 let ( __disc_ptr, #( mut #union_field_names) , * ) = value;
392466 unsafe {
393467 #mut_struct_name {
394- discriminant: core :: mem :: transmute ( __disc_ptr. as_ptr( ) ) ,
468+ discriminant: & * ( __disc_ptr. as_ptr( ) as * const _ ) ,
395469 #( #ref_union_field_names: #union_field_names. as_mut( ) ) , *
396470 }
397471 }
@@ -406,7 +480,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
406480 let ( __disc_ptr, #( #union_field_names) , * ) = value;
407481 unsafe {
408482 #slice_struct_name {
409- discriminant: core:: slice:: from_raw_parts( core :: mem :: transmute ( __disc_ptr. as_ptr( ) ) , __soa_len) ,
483+ discriminant: core:: slice:: from_raw_parts( __disc_ptr. as_ptr( ) as * const _ , __soa_len) ,
410484 #( #ref_union_field_names: core:: slice:: from_raw_parts( #union_field_names. as_ptr( ) , __soa_len) ) , *
411485 }
412486 }
@@ -421,7 +495,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
421495 let ( __disc_ptr, #( #union_field_names) , * ) = value;
422496 unsafe {
423497 #slice_mut_struct_name {
424- discriminant: core:: slice:: from_raw_parts_mut( core :: mem :: transmute ( __disc_ptr. as_ptr( ) ) , __soa_len) ,
498+ discriminant: core:: slice:: from_raw_parts_mut( __disc_ptr. as_ptr( ) as * mut _ , __soa_len) ,
425499 #( #ref_union_field_names: core:: slice:: from_raw_parts_mut( #union_field_names. as_ptr( ) , __soa_len) ) , *
426500 }
427501 }
@@ -439,6 +513,7 @@ mod tests {
439513 #[ test]
440514 fn test_expand_derive_enum ( ) {
441515 let input: DeriveInput = syn:: parse_quote! {
516+ #[ repr( u8 ) ]
442517 enum TestEnum {
443518 Undefined ,
444519 Null ,
@@ -449,7 +524,10 @@ mod tests {
449524
450525 let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
451526
452- // std::fs::write("test_expand_derive_enum_output.rs", &result).unwrap();
527+ assert ! ( result. contains( "const __DISC_UNDEFINED : u8 = 0 ;" ) ) ;
528+ assert ! ( result. contains( "const __DISC_NULL : u8 = 0 + 1 ;" ) ) ;
529+ assert ! ( result. contains( "const __DISC_BOOLEAN : u8 = 0 + 1 + 1 ;" ) ) ;
530+ assert ! ( result. contains( "const __DISC_NUMBER : u8 = 0 + 1 + 1 + 1 ;" ) ) ;
453531
454532 assert ! ( result. contains( "union TestEnumField0" ) ) ;
455533 assert ! ( result. contains( "undefined : ()" ) ) ;
@@ -480,8 +558,6 @@ mod tests {
480558
481559 let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
482560
483- // std::fs::write("test_expand_derive_enum_explicit_discriminant.rs", &result).unwrap();
484-
485561 assert ! ( result. contains( "union ValueField0" ) ) ;
486562 assert ! ( result. contains( "undefined : ()" ) ) ;
487563 assert ! ( result. contains( "string : & 'static str" ) ) ;
@@ -541,8 +617,6 @@ mod tests {
541617
542618 let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
543619
544- // std::fs::write("test_multiple_field_variants_output.rs", &result).unwrap();
545-
546620 assert ! ( result. contains( "union MultiFieldField0" ) ) ;
547621 assert ! ( result. contains( "union MultiFieldField1" ) ) ;
548622 assert ! ( result. contains( "union MultiFieldField2" ) ) ;
@@ -553,4 +627,88 @@ mod tests {
553627 assert ! ( result. contains( "String" ) ) ;
554628 assert ! ( result. contains( "i32" ) ) ;
555629 }
630+
631+ #[ test]
632+ fn test_respects_repr ( ) {
633+ let input: DeriveInput = syn:: parse_quote! {
634+ #[ repr( u16 ) ]
635+ enum CustomRepr {
636+ A ,
637+ B ,
638+ C ,
639+ }
640+ } ;
641+
642+ let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
643+
644+ assert ! ( result. contains( "const __DISC_A : u16 = 0 ;" ) ) ;
645+ assert ! ( result. contains( "const __DISC_B : u16 = 0 + 1 ;" ) ) ;
646+ assert ! ( result. contains( "const __DISC_C : u16 = 0 + 1 + 1 ;" ) ) ;
647+ }
648+
649+ #[ test]
650+ fn test_respects_custom_discriminants ( ) {
651+ let input: DeriveInput = syn:: parse_quote! {
652+ enum CustomDiscriminants {
653+ A = 10 ,
654+ B = 20 ,
655+ C , // 20 + 1
656+ }
657+ } ;
658+
659+ let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
660+
661+ assert ! ( result. contains( "const __DISC_A : isize = 10 ;" ) ) ;
662+ assert ! ( result. contains( "const __DISC_B : isize = 20 ;" ) ) ;
663+ assert ! ( result. contains( "const __DISC_C : isize = 20 + 1 ;" ) ) ;
664+ }
665+
666+ #[ test]
667+ fn test_function_const_discriminants ( ) {
668+ let input: DeriveInput = syn:: parse_quote! {
669+ enum FuncConstDiscriminants {
670+ A = MY_CONST ,
671+ B = my_function( ) ,
672+ C , // my_function() + 1
673+ }
674+ } ;
675+
676+ let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
677+
678+ assert ! ( result. contains( "const __DISC_A : isize = MY_CONST ;" ) ) ;
679+ assert ! ( result. contains( "const __DISC_B : isize = my_function () ;" ) ) ;
680+ assert ! ( result. contains( "const __DISC_C : isize = my_function () + 1 ;" ) ) ;
681+ }
682+
683+ #[ test]
684+ fn test_does_not_use_c_as_type ( ) {
685+ let input: DeriveInput = syn:: parse_quote! {
686+ #[ repr( C ) ] // should default to isize
687+ enum NotC {
688+ A ,
689+ B ,
690+ }
691+ } ;
692+
693+ let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
694+
695+ assert ! ( result. contains( "const __DISC_A : isize = 0 ;" ) ) ;
696+ assert ! ( result. contains( "const __DISC_B : isize = 0 + 1 ;" ) ) ;
697+ }
698+
699+ #[ test]
700+ fn test_repr_c_u8_uses_u8 ( ) {
701+ let input: DeriveInput = syn:: parse_quote! {
702+ #[ repr( C , u8 ) ]
703+ enum CAndU8 {
704+ A ,
705+ B ,
706+ }
707+ } ;
708+
709+ let result = expand_data_enum ( input) . unwrap ( ) . to_string ( ) ;
710+
711+ assert ! ( result. contains( "const __DISC_A : u8 = 0 ;" ) ) ;
712+ assert ! ( result. contains( "const __DISC_B : u8 = 0 + 1 ;" ) ) ;
713+ }
556714}
0 commit comments