Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pdl-compiler/src/backends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ pub mod json;
pub mod rust;
pub mod rust_legacy;

mod common;
pub mod common;
2 changes: 1 addition & 1 deletion pdl-compiler/src/backends/common/alignment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub enum Chunk<S: Symbol> {
}

/// Packs symbols of various sizes, which may not be byte-aligned, into a sequence of byte-aligned chunks.
#[derive(Debug, Clone)]
#[derive(Default, Debug, Clone)]
pub struct ByteAligner<S: Symbol> {
/// Staged fields, waiting to be packed into a chunk.
staged_fields: Vec<Field<S>>,
Expand Down
92 changes: 87 additions & 5 deletions pdl-compiler/src/backends/rust/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,23 @@ fn constraint_value_str(fields: &[&'_ ast::Field], constraint: &ast::Constraint)
}
}

/// Return the default value for a field.
/// Only concrete data fields are considered for inclusion,
/// other kinds will yield an unreachable! error.
fn data_field_default(field: &ast::Field) -> proc_macro2::TokenStream {
match &field.desc {
_ if field.cond.is_some() => quote! { None },
ast::FieldDesc::Scalar { .. } => quote! { 0 },
ast::FieldDesc::Typedef { .. } => quote! { Default::default() },
ast::FieldDesc::Array { width: Some(_), size: Some(size), .. } => quote! { [0; #size] },
ast::FieldDesc::Array { size: Some(_), .. } => {
quote! { std::array::from_fn(|_| Default::default()) }
}
ast::FieldDesc::Array { .. } => quote! { vec![] },
_ => unreachable!(),
}
}

fn implements_copy(scope: &analyzer::Scope<'_>, field: &ast::Field) -> bool {
match &field.desc {
ast::FieldDesc::Scalar { .. } => true,
Expand Down Expand Up @@ -454,7 +471,10 @@ fn generate_root_packet_decl(
}
})
.collect::<Vec<_>>();
let data_field_defaults =
data_fields.iter().copied().map(data_field_default).collect::<Vec<_>>();
let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
let payload_default = decl.payload().map(|_| quote! { payload: vec![], });
let payload_accessor =
decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });

Expand Down Expand Up @@ -519,14 +539,29 @@ fn generate_root_packet_decl(
}
});

// Provide the implementation of the default trait.
// #[derive(Default)] unfortunately does not support fixed sized arrays of length larger
// than 32, which forces us to write a manual implementation.
// https://github.com/rust-lang/rust/issues/61415
let default_impl = quote! {
impl Default for #name {
fn default() -> #name {
#name {
#( #data_field_ids: #data_field_defaults, )*
#payload_default
}
}
}
};

// Provide the implementation of the specialization function.
// The specialization function is only provided for declarations that have
// child packets.
let specialize = (!children_decl.is_empty())
.then(|| generate_specialize_impl(scope, schema, decl, id, &data_fields).unwrap());

quote! {
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct #name {
#( pub #data_field_ids: #data_field_types, )*
Expand All @@ -546,6 +581,8 @@ fn generate_root_packet_decl(
)*
}

#default_impl

impl Packet for #name {
#encoded_len
#encode
Expand Down Expand Up @@ -594,7 +631,10 @@ fn generate_derived_packet_decl(
}
})
.collect::<Vec<_>>();
let data_field_defaults =
data_fields.iter().copied().map(data_field_default).collect::<Vec<_>>();
let payload_field = decl.payload().map(|_| quote! { pub payload: Vec<u8>, });
let payload_default = decl.payload().map(|_| quote! { payload: vec![], });
let payload_accessor =
decl.payload().map(|_| quote! { pub fn payload(&self) -> &[u8] { &self.payload } });

Expand Down Expand Up @@ -866,8 +906,23 @@ fn generate_derived_packet_decl(
let specialize = (!children_decl.is_empty())
.then(|| generate_specialize_impl(scope, schema, decl, id, &data_fields).unwrap());

// Provide the implementation of the default trait.
// #[derive(Default)] unfortunately does not support fixed sized arrays of length larger
// than 32, which forces us to write a manual implementation.
// https://github.com/rust-lang/rust/issues/61415
let default_impl = quote! {
impl Default for #name {
fn default() -> #name {
#name {
#( #data_field_ids: #data_field_defaults, )*
#payload_default
}
}
}
};

quote! {
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct #name {
#( pub #data_field_ids: #data_field_types, )*
Expand Down Expand Up @@ -900,6 +955,8 @@ fn generate_derived_packet_decl(
)*
}

#default_impl

impl Packet for #name {
#encoded_len
#encode
Expand Down Expand Up @@ -1005,6 +1062,27 @@ fn generate_enum_decl(id: &str, tags: &[ast::Tag], width: usize) -> proc_macro2:
}
}

// Generate the default enum value.
// The default value is the first tag of the enum.
// If the first tag identifies a range, then the first value
// of the range is used.
let default_value = match &tags[0] {
ast::Tag::Value(ast::TagValue { id, .. }) => {
let id = format_tag_ident(id);
quote! { #name::#id }
}
ast::Tag::Range(ast::TagRange { tags, .. }) if !tags.is_empty() => {
let id = format_tag_ident(&tags[0].id);
quote! { #name::#id }
}
ast::Tag::Range(ast::TagRange { id, range, .. }) => {
let id = format_tag_ident(id);
let value = format_value(*range.start());
quote! { #name::#id(Private(#value)) }
}
ast::Tag::Other(_) => todo!(),
};

// Generate the cases for parsing the enum value from an integer.
let mut from_cases = vec![];
for tag in tags.iter() {
Expand Down Expand Up @@ -1080,15 +1158,19 @@ fn generate_enum_decl(id: &str, tags: &[ast::Tag], width: usize) -> proc_macro2:

quote! {
#repr_u64
#[derive(Default, Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
pub enum #name {
// No specific value => use the first one as default
#[default]
#(#variants,)*
}

impl Default for #name {
fn default() -> #name {
#default_value
}
}

impl TryFrom<#backing_type> for #name {
type Error = #backing_type;
fn try_from(value: #backing_type) -> Result<Self, Self::Error> {
Expand Down
11 changes: 6 additions & 5 deletions pdl-compiler/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

//! PDL parser and analyzer.

use std::path::Path;

use argh::FromArgs;
use codespan_reporting::term::{self, termcolor};

Expand Down Expand Up @@ -80,11 +78,13 @@ struct Opt {
/// For the rust backend this is a path e.g. "module::CustomField" or "super::CustomField".
custom_field: Vec<String>,

#[cfg(feature = "java")]
#[argh(option)]
/// directory where generated files should go. This only works when 'output_format' is 'java'.
/// If omitted, the generated code will be printed to stdout.
output_dir: Option<String>,

#[cfg(feature = "java")]
#[argh(option)]
/// java package to contain the generated classes.
java_package: Option<String>,
Expand Down Expand Up @@ -149,7 +149,7 @@ fn generate_backend(opt: &Opt) -> Result<(), String> {
&sources,
&analyzed_file,
&opt.custom_field,
Path::new(output_dir),
std::path::Path::new(output_dir),
package,
)
}
Expand All @@ -167,7 +167,8 @@ fn generate_backend(opt: &Opt) -> Result<(), String> {
Err(err) => {
let writer = termcolor::StandardStream::stderr(termcolor::ColorChoice::Always);
let config = term::Config::default();
term::emit_to_write_style(&mut writer.lock(), &config, &sources, &err).expect("Could not print error");
term::emit_to_write_style(&mut writer.lock(), &config, &sources, &err)
.expect("Could not print error");
Err(String::from("Error while parsing input"))
}
}
Expand Down Expand Up @@ -196,7 +197,7 @@ fn generate_tests(opt: &Opt, test_file: &str) -> Result<(), String> {

backends::java::test::generate_tests(
test_file,
Path::new(output_dir),
std::path::Path::new(output_dir),
package.clone(),
&opt.input_file,
&opt.exclude_declaration,
Expand Down
3 changes: 1 addition & 2 deletions pdl-compiler/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,7 @@ pub fn parse_inline(
) -> Result<ast::File, Diagnostic<ast::FileId>> {
let root = PDLParser::parse(Rule::file, &source)
.map_err(|e| {
Diagnostic::error()
.with_message(format!("failed to parse input file '{name}': {e}"))
Diagnostic::error().with_message(format!("failed to parse input file '{name}': {e}"))
})?
.next()
.unwrap();
Expand Down
Loading
Loading