
Commit 6398f10

refactor: pattern match on discriminant consts
1 parent 63b02fd · commit 6398f10

1 file changed: +182 -24 lines

soavec_derive/src/soa_enum.rs (182 additions, 24 deletions)
@@ -4,11 +4,67 @@
 
 use proc_macro2::TokenStream;
 use quote::quote;
-use syn::{Data, DeriveInput, Fields, Expr, Lit, spanned::Spanned};
+use syn::{Data, DeriveInput, Expr, Fields, spanned::Spanned};
+
+/// Parse the repr type from #[repr(...)] attribute
+fn parse_repr_type(input: &DeriveInput) -> TokenStream {
+    for attr in &input.attrs {
+        if attr.path().is_ident("repr") {
+            // Parse as a single type first (e.g., #[repr(u8)])
+            if let Ok(ty) = attr.parse_args::<syn::Type>() {
+                if is_valid_discriminant_type(&ty) {
+                    return quote! { #ty };
+                }
+            } else {
+                // Parse as a list of types (for example #[repr(C, u8)])
+                if let Ok(meta_list) = attr.parse_args_with(
+                    syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
+                ) {
+                    for meta in meta_list {
+                        if let syn::Meta::Path(path) = meta {
+                            let ty: syn::Type = syn::Type::Path(syn::TypePath {
+                                qself: None,
+                                path: path.clone(),
+                            });
+                            if is_valid_discriminant_type(&ty) {
+                                return quote! { #ty };
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // Fallback if no valid repr type found
+    quote! { isize }
+}
+
+/// Check if a type is a valid discriminant type
+fn is_valid_discriminant_type(ty: &syn::Type) -> bool {
+    if let syn::Type::Path(type_path) = ty
+        && let Some(ident) = type_path.path.get_ident()
+    {
+        let ident_str = ident.to_string();
+        return matches!(ident_str.as_str(), |"u8"| "i8"
+            | "u16"
+            | "i16"
+            | "u32"
+            | "i32"
+            | "u64"
+            | "i64"
+            | "u128"
+            | "i128"
+            | "usize"
+            | "isize");
+    }
+    false
+}
 
 pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
     let enum_name = &input.ident;
     let generics = &input.generics;
+    let discriminant_type = parse_repr_type(&input);
 
     let variants = match &input.data {
         Data::Enum(data_enum) => &data_enum.variants,
@@ -34,11 +90,12 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
 
     // Information about the variants
     let mut variant_data = Vec::new();
-    let mut current_discriminant: Expr = syn::parse_quote!(0);
+    let mut current_discriminant: Expr = syn::parse_quote!(0);
     for variant in variants.iter() {
         let variant_name = &variant.ident;
 
-        let discriminant_value = variant.discriminant
+        let discriminant_value = variant
+            .discriminant
             .as_ref()
             .map(|(_, expr)| expr.clone())
             .unwrap_or_else(|| current_discriminant.clone());
@@ -220,14 +277,28 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
         .collect();
 
     let tuple_repr = if union_names.is_empty() {
-        quote! { (u8,) }
+        quote! { (#discriminant_type,) }
     } else {
-        quote! { (u8, #(#union_names),*) }
+        quote! { (#discriminant_type, #(#union_names),*) }
     };
 
+    // Generate const items for discriminant values
+    // This is so that we can later pattern match on the discriminant values
+    let discriminant_consts: Vec<_> = variant_data
+        .iter()
+        .map(|(variant_name, disc_value, _, _)| {
+            let const_name =
+                quote::format_ident!("__DISC_{}", variant_name.to_string().to_uppercase());
+            quote! {
+                const #const_name: #discriminant_type = #disc_value;
+            }
+        })
+        .collect();
+
     // Generate into_tuple match arms
-    let into_tuple_arms: Vec<_> = variant_data.iter().map(|(variant_name, disc_value, field_names, _field_types)| {
+    let into_tuple_arms: Vec<_> = variant_data.iter().map(|(variant_name, _disc_value, field_names, _field_types)| {
        let variant_name_lower = variant_name.to_string().to_lowercase();
+       let const_name = quote::format_ident!("__DISC_{}", variant_name.to_string().to_uppercase());
 
        if field_names.is_empty() {
            // Unit variant
@@ -239,16 +310,15 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
 
            if union_constructions.is_empty() {
                quote! {
-                    Self::#variant_name => (#disc_value,)
+                    Self::#variant_name => (#const_name,)
                }
            } else {
                quote! {
-                    Self::#variant_name => (#disc_value, #(#union_constructions),*)
+                    Self::#variant_name => (#const_name, #(#union_constructions),*)
                }
            }
-       }
-       else {
-           // Variant with fields
+       } else {
+           // Variant with fields
            let union_constructions: Vec<_> = (0..max_fields).map(|idx| {
                let union_name = quote::format_ident!("{}Field{}", enum_name, idx);
                let union_field = quote::format_ident!("{}", variant_name_lower);
@@ -261,21 +331,23 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
            }).collect();
 
            quote! {
-                Self::#variant_name(#(#field_names),*) => (#disc_value, #(#union_constructions),*)
+                Self::#variant_name(#(#field_names),*) => (#const_name, #(#union_constructions),*)
            }
        }
    }).collect();
 
    // Generate from_tuple match arms
    let from_tuple_arms: Vec<_> = variant_data
        .iter()
-        .map(|(variant_name, disc_value, field_names, _field_types)| {
+        .map(|(variant_name, _disc_value, field_names, _field_types)| {
            let variant_name_lower = variant_name.to_string().to_lowercase();
+            let const_name =
+                quote::format_ident!("__DISC_{}", variant_name.to_string().to_uppercase());
 
            if field_names.is_empty() {
                // Unit variant
                quote! {
-                    #disc_value => Self::#variant_name
+                    #const_name => Self::#variant_name
                }
            } else {
                // Variant with fields - extract from unions
@@ -291,7 +363,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
                    .collect();
 
                quote! {
-                    #disc_value => {
+                    #const_name => {
                        #(#field_extractions)*
                        Self::#variant_name(#(#field_names),*)
                    }
@@ -310,6 +382,8 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
        .collect();
 
    let expanded = quote! {
+        #(#discriminant_consts)*
+
        #(#union_types)*
 
        #[allow(dead_code)]
@@ -378,7 +452,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
                let (__disc_ptr, #(#union_field_names),*) = value;
                unsafe {
                    #ref_struct_name {
-                        discriminant: core::mem::transmute(__disc_ptr.as_ptr()),
+                        discriminant: &*(__disc_ptr.as_ptr() as *const _),
                        #(#ref_union_field_names: #union_field_names.as_ref()),*
                    }
                }
@@ -391,7 +465,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
                let (__disc_ptr, #(mut #union_field_names),*) = value;
                unsafe {
                    #mut_struct_name {
-                        discriminant: core::mem::transmute(__disc_ptr.as_ptr()),
+                        discriminant: &*(__disc_ptr.as_ptr() as *const _),
                        #(#ref_union_field_names: #union_field_names.as_mut()),*
                    }
                }
@@ -406,7 +480,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
                let (__disc_ptr, #(#union_field_names),*) = value;
                unsafe {
                    #slice_struct_name {
-                        discriminant: core::slice::from_raw_parts(core::mem::transmute(__disc_ptr.as_ptr()), __soa_len),
+                        discriminant: core::slice::from_raw_parts(__disc_ptr.as_ptr() as *const _, __soa_len),
                        #(#ref_union_field_names: core::slice::from_raw_parts(#union_field_names.as_ptr(), __soa_len)),*
                    }
                }
@@ -421,7 +495,7 @@ pub fn expand_data_enum(input: DeriveInput) -> syn::Result<TokenStream> {
                let (__disc_ptr, #(#union_field_names),*) = value;
                unsafe {
                    #slice_mut_struct_name {
-                        discriminant: core::slice::from_raw_parts_mut(core::mem::transmute(__disc_ptr.as_ptr()), __soa_len),
+                        discriminant: core::slice::from_raw_parts_mut(__disc_ptr.as_ptr() as *mut _, __soa_len),
                        #(#ref_union_field_names: core::slice::from_raw_parts_mut(#union_field_names.as_ptr(), __soa_len)),*
                    }
                }
@@ -439,6 +513,7 @@ mod tests {
    #[test]
    fn test_expand_derive_enum() {
        let input: DeriveInput = syn::parse_quote! {
+            #[repr(u8)]
            enum TestEnum {
                Undefined,
                Null,
@@ -449,7 +524,10 @@ mod tests {
 
        let result = expand_data_enum(input).unwrap().to_string();
 
-        // std::fs::write("test_expand_derive_enum_output.rs", &result).unwrap();
+        assert!(result.contains("const __DISC_UNDEFINED : u8 = 0 ;"));
+        assert!(result.contains("const __DISC_NULL : u8 = 0 + 1 ;"));
+        assert!(result.contains("const __DISC_BOOLEAN : u8 = 0 + 1 + 1 ;"));
+        assert!(result.contains("const __DISC_NUMBER : u8 = 0 + 1 + 1 + 1 ;"));
 
        assert!(result.contains("union TestEnumField0"));
        assert!(result.contains("undefined : ()"));
@@ -480,8 +558,6 @@ mod tests {
 
        let result = expand_data_enum(input).unwrap().to_string();
 
-        // std::fs::write("test_expand_derive_enum_explicit_discriminant.rs", &result).unwrap();
-
        assert!(result.contains("union ValueField0"));
        assert!(result.contains("undefined : ()"));
        assert!(result.contains("string : & 'static str"));
@@ -541,8 +617,6 @@ mod tests {
 
        let result = expand_data_enum(input).unwrap().to_string();
 
-        // std::fs::write("test_multiple_field_variants_output.rs", &result).unwrap();
-
        assert!(result.contains("union MultiFieldField0"));
        assert!(result.contains("union MultiFieldField1"));
        assert!(result.contains("union MultiFieldField2"));
@@ -553,4 +627,88 @@ mod tests {
        assert!(result.contains("String"));
        assert!(result.contains("i32"));
    }
+
+    #[test]
+    fn test_respects_repr() {
+        let input: DeriveInput = syn::parse_quote! {
+            #[repr(u16)]
+            enum CustomRepr {
+                A,
+                B,
+                C,
+            }
+        };
+
+        let result = expand_data_enum(input).unwrap().to_string();
+
+        assert!(result.contains("const __DISC_A : u16 = 0 ;"));
+        assert!(result.contains("const __DISC_B : u16 = 0 + 1 ;"));
+        assert!(result.contains("const __DISC_C : u16 = 0 + 1 + 1 ;"));
+    }
+
+    #[test]
+    fn test_respects_custom_discriminants() {
+        let input: DeriveInput = syn::parse_quote! {
+            enum CustomDiscriminants {
+                A = 10,
+                B = 20,
+                C, // 20 + 1
+            }
+        };
+
+        let result = expand_data_enum(input).unwrap().to_string();
+
+        assert!(result.contains("const __DISC_A : isize = 10 ;"));
+        assert!(result.contains("const __DISC_B : isize = 20 ;"));
+        assert!(result.contains("const __DISC_C : isize = 20 + 1 ;"));
+    }
+
+    #[test]
+    fn test_function_const_discriminants() {
+        let input: DeriveInput = syn::parse_quote! {
+            enum FuncConstDiscriminants {
+                A = MY_CONST,
+                B = my_function(),
+                C, // my_function() + 1
+            }
+        };
+
+        let result = expand_data_enum(input).unwrap().to_string();
+
+        assert!(result.contains("const __DISC_A : isize = MY_CONST ;"));
+        assert!(result.contains("const __DISC_B : isize = my_function () ;"));
+        assert!(result.contains("const __DISC_C : isize = my_function () + 1 ;"));
+    }
+
+    #[test]
+    fn test_does_not_use_c_as_type() {
+        let input: DeriveInput = syn::parse_quote! {
+            #[repr(C)] // should default to isize
+            enum NotC {
+                A,
+                B,
+            }
+        };
+
+        let result = expand_data_enum(input).unwrap().to_string();
+
+        assert!(result.contains("const __DISC_A : isize = 0 ;"));
+        assert!(result.contains("const __DISC_B : isize = 0 + 1 ;"));
+    }
+
+    #[test]
+    fn test_repr_c_u8_uses_u8() {
+        let input: DeriveInput = syn::parse_quote! {
+            #[repr(C, u8)]
+            enum CAndU8 {
+                A,
+                B,
+            }
+        };
+
+        let result = expand_data_enum(input).unwrap().to_string();
+
+        assert!(result.contains("const __DISC_A : u8 = 0 ;"));
+        assert!(result.contains("const __DISC_B : u8 = 0 + 1 ;"));
+    }
 }
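
The point of the new __DISC_* constants (per the comment in the diff, "so that we can later pattern match on the discriminant values") is that an arbitrary const expression such as my_function() cannot be written directly as a match pattern, while a named constant can. Below is a minimal standalone sketch of that idea, not code emitted by the macro; the MY_CONST and my_function names are borrowed from test_function_const_discriminants above, and the concrete values and the variant_name helper are hypothetical.

// Sketch only: why discriminants are routed through named constants.
const MY_CONST: isize = 10; // hypothetical value
const fn my_function() -> isize {
    20 // hypothetical value
}

// Same shape as the generated `const __DISC_* : <repr> = <expr> ;` items.
const __DISC_A: isize = MY_CONST;
const __DISC_B: isize = my_function();
const __DISC_C: isize = my_function() + 1;

// A stored discriminant can be matched against the constants; writing
// `my_function() => ...` as a pattern here would not compile.
fn variant_name(disc: isize) -> &'static str {
    match disc {
        __DISC_A => "A",
        __DISC_B => "B",
        __DISC_C => "C",
        _ => "unknown",
    }
}

fn main() {
    assert_eq!(variant_name(10), "A");
    assert_eq!(variant_name(20), "B");
    assert_eq!(variant_name(21), "C");
}

This appears to be why the from_tuple match arms in the diff switch from interpolating #disc_value (an expression, which is only a valid pattern when it happens to be a literal) to #const_name (a path to a constant, which is always a valid pattern).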

0 commit comments
