diff --git a/.gitignore b/.gitignore index bab5191acce..283b789aa87 100644 --- a/.gitignore +++ b/.gitignore @@ -103,3 +103,4 @@ shards/tests/.edits_backup local/ .claude/ +shards/rust_macro/Cargo.lock diff --git a/deps/CMakeLists.txt b/deps/CMakeLists.txt index 16922689e93..e65c93ffe12 100644 --- a/deps/CMakeLists.txt +++ b/deps/CMakeLists.txt @@ -312,5 +312,3 @@ add_subdirectory(opus) add_library(crdt_lite INTERFACE) target_include_directories(crdt_lite INTERFACE crdt-lite) - -add_subdirectory(FTXUI) \ No newline at end of file diff --git a/shards/modules/core/src/uuid.rs b/shards/modules/core/src/uuid.rs index 5c631a46e2b..db7139f4843 100644 --- a/shards/modules/core/src/uuid.rs +++ b/shards/modules/core/src/uuid.rs @@ -5,26 +5,20 @@ use shards::core::register_legacy_shard; use shards::core::register_shard; use shards::shard::LegacyShard; use shards::shard::Shard; +use shards::simple_shard; +use shards::types::BytesOut; use shards::types::ClonedVar; use shards::types::Context; use shards::types::ExposedTypes; use shards::types::InstanceData; use shards::types::OptionalString; -use shards::types::BOOL_TYPES_SLICE; +use shards::types::StringOut; use shards::types::BYTES_OR_STRING_TYPES; -use shards::types::BYTES_TYPES; -use shards::types::INT16_TYPES; use shards::types::INT_TYPES; -use shards::types::INT_TYPES_SLICE; -use shards::types::NONE_TYPES; - -use shards::types::Parameters; use shards::types::Type; -use shards::types::STRING_TYPES; - use shards::types::common_type; use shards::types::Types; use shards::types::Var; @@ -33,40 +27,45 @@ use core::convert::TryInto; use std::str::FromStr; use std::sync::RwLock; -#[derive(Default)] -struct UUIDCreate {} - -impl LegacyShard for UUIDCreate { - fn registerName() -> &'static str { - cstr!("UUID") - } - - fn hash() -> u32 { - compile_time_crc32::crc32!("UUID-rust-0x20250822") - } +// Simple shards using the new macro - fn name(&mut self) -> &str { - "UUID" - } - - fn help(&mut self) -> OptionalString { - OptionalString(shccstr!("Outputs a UUID (Universally Unique Identifier).")) - } +#[simple_shard("UUID", "Outputs a UUID (Universally Unique Identifier).")] +fn uuid_create(_: ()) -> [u8; 16] { + let uuid = uuid::Uuid::new_v4(); + *uuid.as_bytes() +} - fn inputTypes(&mut self) -> &std::vec::Vec { - &NONE_TYPES +#[simple_shard("UUID.ToString", "Reads a UUID and formats it into a readable string.")] +fn uuid_to_string( + input: [u8; 16], + #[param("Hyphenated", "Whether to use hyphens in the output.", default = false)] + hyphenated: bool, +) -> StringOut { + let uuid = uuid::Uuid::from_bytes(input); + if hyphenated { + uuid.hyphenated().to_string().as_str().into() + } else { + uuid.simple().to_string().as_str().into() } +} - fn outputTypes(&mut self) -> &std::vec::Vec { - &INT16_TYPES - } +#[simple_shard("UUID.ToBytes", "Reads a UUID and formats it into bytes.")] +fn uuid_to_bytes(input: [u8; 16]) -> BytesOut { + input.as_slice().into() +} - fn activate(&mut self, _: &Context, _: &Var) -> Result, &str> { - let uuid = uuid::Uuid::new_v4(); - Ok(Some(uuid.as_bytes().into())) - } +#[simple_shard("NanoID", "Creates a random NanoID.")] +fn nanoid_create( + _: (), + #[param("Size", "The output string length of the created NanoID.", default = 21i64)] + size: i64, +) -> StringOut { + let size = size as usize; + nanoid::nanoid!(size).as_str().into() } +// Legacy shard for UUID.Convert (handles multiple input types) + #[derive(Default)] struct UUIDConvert {} @@ -94,7 +93,7 @@ impl LegacyShard for UUIDConvert { } fn outputTypes(&mut self) -> &std::vec::Vec { - &INT16_TYPES + &shards::types::INT16_TYPES } fn activate(&mut self, _: &Context, input: &Var) -> Result, &str> { @@ -113,188 +112,7 @@ impl LegacyShard for UUIDConvert { } } -lazy_static! { - static ref PARAMETERS: Parameters = vec![( - cstr!("Hyphenated"), - shccstr!("Whether to use hyphens in the output."), - BOOL_TYPES_SLICE - ) - .into(),]; -} - -#[derive(Default)] -struct UUIDToString { - output: ClonedVar, - hyphenated: bool, -} - -impl LegacyShard for UUIDToString { - fn registerName() -> &'static str { - cstr!("UUID.ToString") - } - - fn hash() -> u32 { - compile_time_crc32::crc32!("UUID.ToString-rust-0x20250822") - } - - fn name(&mut self) -> &str { - "UUID.ToString" - } - - fn help(&mut self) -> OptionalString { - OptionalString(shccstr!( - "Reads an UUID and formats it into a readable string." - )) - } - - fn inputTypes(&mut self) -> &std::vec::Vec { - &INT16_TYPES - } - - fn outputTypes(&mut self) -> &std::vec::Vec { - &STRING_TYPES - } - - fn parameters(&mut self) -> Option<&Parameters> { - Some(&PARAMETERS) - } - - fn setParam(&mut self, index: i32, value: &Var) -> Result<(), &str> { - match index { - 0 => Ok(self.hyphenated = value.try_into()?), - _ => unreachable!(), - } - } - - fn getParam(&mut self, index: i32) -> Var { - match index { - 0 => self.hyphenated.into(), - _ => unreachable!(), - } - } - - fn activate(&mut self, _: &Context, input: &Var) -> Result, &str> { - let bytes: [u8; 16] = input.try_into()?; - let uuid = uuid::Uuid::from_bytes(bytes); - self.output = if self.hyphenated { - uuid.hyphenated().to_string().into() - } else { - uuid.simple().to_string().into() - }; - Ok(Some(self.output.0)) - } -} - -#[derive(Default)] -struct UUIDToBytes { - output: ClonedVar, -} - -impl LegacyShard for UUIDToBytes { - fn registerName() -> &'static str { - cstr!("UUID.ToBytes") - } - - fn hash() -> u32 { - compile_time_crc32::crc32!("UUID.ToBytes-rust-0x20250822") - } - - fn name(&mut self) -> &str { - "UUID.ToBytes" - } - - fn help(&mut self) -> OptionalString { - OptionalString(shccstr!("Reads an UUID and formats it into bytes.")) - } - - fn inputTypes(&mut self) -> &std::vec::Vec { - &INT16_TYPES - } - - fn outputTypes(&mut self) -> &std::vec::Vec { - &BYTES_TYPES - } - - fn activate(&mut self, _: &Context, input: &Var) -> Result, &str> { - let bytes: [u8; 16] = input.try_into()?; - self.output = bytes.as_ref().into(); - Ok(Some(self.output.0)) - } -} - -lazy_static! { - static ref NANO_PARAMETERS: Parameters = vec![( - cstr!("Size"), - shccstr!("The output string length of the created NanoID."), - INT_TYPES_SLICE - ) - .into(),]; -} - -struct NanoIDCreate { - size: i64, - output: ClonedVar, -} - -impl Default for NanoIDCreate { - fn default() -> Self { - Self { - size: 21, - output: Default::default(), - } - } -} - -impl LegacyShard for NanoIDCreate { - fn registerName() -> &'static str { - cstr!("NanoID") - } - - fn hash() -> u32 { - compile_time_crc32::crc32!("NanoID-rust-0x20250822") - } - - fn name(&mut self) -> &str { - "NanoID" - } - - fn help(&mut self) -> OptionalString { - OptionalString(shccstr!("Creates a random NanoID.")) - } - - fn inputTypes(&mut self) -> &std::vec::Vec { - &NONE_TYPES - } - - fn outputTypes(&mut self) -> &std::vec::Vec { - &STRING_TYPES - } - - fn parameters(&mut self) -> Option<&Parameters> { - Some(&NANO_PARAMETERS) - } - - fn setParam(&mut self, index: i32, value: &Var) -> Result<(), &str> { - match index { - 0 => Ok(self.size = value.try_into()?), - _ => unreachable!(), - } - } - - fn getParam(&mut self, index: i32) -> Var { - match index { - 0 => self.size.into(), - _ => unreachable!(), - } - } - - fn activate(&mut self, _: &Context, _: &Var) -> Result, &str> { - let size = self.size as usize; - let id = nanoid::nanoid!(size); - self.output = id.into(); - Ok(Some(self.output.0)) - } -} +// Snowflake shard (uses global state and custom warmup validation) lazy_static! { static ref SNOWFLAKE_GENERATOR: RwLock = @@ -325,7 +143,7 @@ impl Default for SnowflakeShard { #[shards::shard_impl] impl Shard for SnowflakeShard { fn input_types(&mut self) -> &Types { - &NONE_TYPES + &shards::types::NONE_TYPES } fn output_types(&mut self) -> &Types { @@ -370,10 +188,10 @@ impl Shard for SnowflakeShard { } pub fn register_shards() { - register_legacy_shard::(); - register_legacy_shard::(); - register_legacy_shard::(); - register_legacy_shard::(); + register_shard::(); + register_shard::(); + register_shard::(); + register_shard::(); register_legacy_shard::(); register_shard::(); } diff --git a/shards/modules/core/src/yaml.rs b/shards/modules/core/src/yaml.rs index 8c108131d76..c8df1d07bd7 100644 --- a/shards/modules/core/src/yaml.rs +++ b/shards/modules/core/src/yaml.rs @@ -1,128 +1,29 @@ use shards::core::register_shard; -use shards::shard::Shard; -use shards::types::{ClonedVar, STRING_TYPES}; -use shards::types::{Context, ExposedTypes, InstanceData, Type, Types, Var}; +use shards::simple_shard; -#[derive(shards::shard)] -#[shard_info("Yaml.ToJson", "A shard that converts YAML to JSON.")] -struct YamlToJsonShard { - #[shard_required] - required: ExposedTypes, +#[simple_shard("Yaml.ToJson", "A shard that converts YAML to JSON.")] +fn yaml_to_json(yaml: &str) -> Result { + // Deserialize YAML into a serde_json::Value + let data: serde_json::Value = serde_yml::from_str(yaml).map_err(|_| "Failed to parse YAML")?; - output: ClonedVar, -} - -impl Default for YamlToJsonShard { - fn default() -> Self { - Self { - required: ExposedTypes::new(), - output: ClonedVar::default(), - } - } -} - -#[shards::shard_impl] -impl Shard for YamlToJsonShard { - fn input_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn output_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn warmup(&mut self, ctx: &Context) -> Result<(), &str> { - self.warmup_helper(ctx)?; - - Ok(()) - } - - fn cleanup(&mut self, ctx: Option<&Context>) -> Result<(), &str> { - self.cleanup_helper(ctx)?; - - Ok(()) - } - - fn compose(&mut self, data: &InstanceData) -> Result { - self.compose_helper(data)?; - Ok(self.output_types()[0]) - } - - fn activate(&mut self, _context: &Context, input: &Var) -> Result, &str> { - let yaml: &str = input.try_into()?; - // simply convert to json using serde - - // Deserialize YAML into a serde_json::Value - let data: serde_json::Value = serde_yml::from_str(yaml).map_err(|_| "Failed to parse YAML")?; - - // Serialize the data to a JSON string - let json_string = serde_json::to_string(&data).map_err(|_| "Failed to serialize to JSON")?; + // Serialize the data to a JSON string + let json_string = serde_json::to_string(&data).map_err(|_| "Failed to serialize to JSON")?; - self.output = json_string.into(); - - Ok(Some(self.output.0)) - } -} - -#[derive(shards::shard)] -#[shard_info("Yaml.FromJson", "A shard that converts JSON to YAML.")] -struct JsonToYamlShard { - #[shard_required] - required: ExposedTypes, - - output: ClonedVar, + Ok(json_string) } -impl Default for JsonToYamlShard { - fn default() -> Self { - Self { - required: ExposedTypes::new(), - output: ClonedVar::default(), - } - } -} - -#[shards::shard_impl] -impl Shard for JsonToYamlShard { - fn input_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn output_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn warmup(&mut self, ctx: &Context) -> Result<(), &str> { - self.warmup_helper(ctx)?; - Ok(()) - } - - fn cleanup(&mut self, ctx: Option<&Context>) -> Result<(), &str> { - self.cleanup_helper(ctx)?; - Ok(()) - } - - fn compose(&mut self, data: &InstanceData) -> Result { - self.compose_helper(data)?; - Ok(self.output_types()[0]) - } - - fn activate(&mut self, _context: &Context, input: &Var) -> Result, &str> { - let json: &str = input.try_into()?; - - // Deserialize JSON into a serde_json::Value - let data: serde_json::Value = serde_json::from_str(json).map_err(|_| "Failed to parse JSON")?; - - // Serialize the data to a YAML string - let yaml_string = serde_yml::to_string(&data).map_err(|_| "Failed to serialize to YAML")?; +#[simple_shard("Yaml.FromJson", "A shard that converts JSON to YAML.")] +fn json_to_yaml(json: &str) -> Result { + // Deserialize JSON into a serde_json::Value + let data: serde_json::Value = serde_json::from_str(json).map_err(|_| "Failed to parse JSON")?; - self.output = yaml_string.into(); + // Serialize the data to a YAML string + let yaml_string = serde_yml::to_string(&data).map_err(|_| "Failed to serialize to YAML")?; - Ok(Some(self.output.0)) - } + Ok(yaml_string) } pub fn register_shards() { register_shard::(); - register_shard::(); + register_shard::(); } diff --git a/shards/modules/crypto/src/argon.rs b/shards/modules/crypto/src/argon.rs index 6658c1e6d25..ca76c458788 100644 --- a/shards/modules/crypto/src/argon.rs +++ b/shards/modules/crypto/src/argon.rs @@ -2,169 +2,64 @@ /* Copyright © 2023 Fragcolor Pte. Ltd. */ use shards::core::register_shard; -use shards::shard::Shard; -use shards::types::{ - common_type, ClonedVar, Context, ExposedTypes, InstanceData, ParamVar, Type, Types, Var, - BOOL_TYPES, STRING_TYPES, -}; +use shards::simple_shard; use argon2::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2, }; use rand_core::OsRng; -use std::convert::TryInto; - -#[derive(shards::shard)] -#[shard_info("Argon2id.Hash", "Hashes a password using the Argon2id algorithm.")] -struct Argon2idHashShard { - #[shard_required] - required: ExposedTypes, - - #[shard_param("MemoryCost", "The amount of memory to use in KiB. Default is 8192 (8 MB).", [common_type::int])] - memory_cost: ParamVar, - - #[shard_param("TimeCost", "The number of iterations to perform. Default is 4.", [common_type::int])] - time_cost: ParamVar, - - #[shard_param("Parallelism", "The degree of parallelism to use. Default is 1.", [common_type::int])] - parallelism: ParamVar, - - output: ClonedVar, -} - -impl Default for Argon2idHashShard { - fn default() -> Self { - Self { - required: ExposedTypes::new(), - memory_cost: ParamVar::new(8192.into()), // 8 MB - time_cost: ParamVar::new(4.into()), - parallelism: ParamVar::new(1.into()), - output: ClonedVar::default(), - } - } -} - -#[shards::shard_impl] -impl Shard for Argon2idHashShard { - fn input_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn output_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn warmup(&mut self, ctx: &Context) -> Result<(), &str> { - self.warmup_helper(ctx)?; - Ok(()) - } - - fn cleanup(&mut self, ctx: Option<&Context>) -> Result<(), &str> { - self.cleanup_helper(ctx)?; - Ok(()) - } - - fn compose(&mut self, data: &InstanceData) -> Result { - self.compose_helper(data)?; - // Remove the checks for None, as we now have default values - Ok(self.output_types()[0]) - } - - fn activate(&mut self, _context: &Context, input: &Var) -> Result, &str> { - let password: &str = input.try_into()?; - - let memory_cost: i64 = self.memory_cost.get().try_into()?; - let time_cost: i64 = self.time_cost.get().try_into()?; - let parallelism: i64 = self.parallelism.get().try_into()?; - // Convert parameters to u32 - let memory_cost = u32::try_from(memory_cost).map_err(|_| "Invalid memory cost")?; - let time_cost = u32::try_from(time_cost).map_err(|_| "Invalid time cost")?; - let parallelism = u32::try_from(parallelism).map_err(|_| "Invalid parallelism")?; - - // Create an Argon2 instance - let argon2 = Argon2::new( - argon2::Algorithm::Argon2id, - argon2::Version::V0x13, - argon2::Params::new(memory_cost, time_cost, parallelism, None).unwrap(), - ); - - // Generate a random salt - let salt = SaltString::generate(&mut OsRng); - - // Hash the password - let password_hash = argon2 - .hash_password(password.as_bytes(), &salt) - .map_err(|_| "Failed to hash password")?; - - // Convert the PasswordHash to a string - let hash_string = password_hash.serialize(); - - self.output = Var::ephemeral_string(hash_string.as_str()).into(); - Ok(Some(self.output.0)) - } -} - -#[derive(shards::shard)] -#[shard_info("Argon2id.Verify", "Verifies a password against an Argon2id hash.")] -struct Argon2idVerifyShard { - #[shard_required] - required: ExposedTypes, - - #[shard_param("Hash", "The Argon2id hash to verify against.", [common_type::string, common_type::string_var])] - hash: ParamVar, +#[simple_shard("Argon2id.Hash", "Hashes a password using the Argon2id algorithm.")] +fn argon2id_hash( + password: &str, + #[param("MemoryCost", "The amount of memory to use in KiB. Default is 8192 (8 MB).", default = 8192i64)] + memory_cost: i64, + #[param("TimeCost", "The number of iterations to perform. Default is 4.", default = 4i64)] + time_cost: i64, + #[param("Parallelism", "The degree of parallelism to use. Default is 1.", default = 1i64)] + parallelism: i64, +) -> Result { + // Convert parameters to u32 + let memory_cost = u32::try_from(memory_cost).map_err(|_| "Invalid memory cost")?; + let time_cost = u32::try_from(time_cost).map_err(|_| "Invalid time cost")?; + let parallelism = u32::try_from(parallelism).map_err(|_| "Invalid parallelism")?; + + // Create an Argon2 instance + let argon2 = Argon2::new( + argon2::Algorithm::Argon2id, + argon2::Version::V0x13, + argon2::Params::new(memory_cost, time_cost, parallelism, None) + .map_err(|_| "Invalid Argon2 parameters")?, + ); + + // Generate a random salt + let salt = SaltString::generate(&mut OsRng); + + // Hash the password + let password_hash = argon2 + .hash_password(password.as_bytes(), &salt) + .map_err(|_| "Failed to hash password")?; + + // Convert the PasswordHash to a string + let hash_string = password_hash.serialize(); + + Ok(hash_string.to_string()) } -impl Default for Argon2idVerifyShard { - fn default() -> Self { - Self { - required: ExposedTypes::new(), - hash: ParamVar::default(), - } - } -} - -#[shards::shard_impl] -impl Shard for Argon2idVerifyShard { - fn input_types(&mut self) -> &Types { - &STRING_TYPES - } - - fn output_types(&mut self) -> &Types { - &BOOL_TYPES - } - - fn warmup(&mut self, ctx: &Context) -> Result<(), &str> { - self.warmup_helper(ctx)?; - Ok(()) - } - - fn cleanup(&mut self, ctx: Option<&Context>) -> Result<(), &str> { - self.cleanup_helper(ctx)?; - Ok(()) - } - - fn compose(&mut self, data: &InstanceData) -> Result { - self.compose_helper(data)?; - if self.hash.is_none() { - return Err("Hash parameter is required"); - } - Ok(self.output_types()[0]) - } - - fn activate(&mut self, _context: &Context, input: &Var) -> Result, &str> { - let password: &str = input.try_into()?; - let hash: &str = self.hash.get().try_into()?; - - let parsed_hash = PasswordHash::new(hash).map_err(|_| "Failed to parse the provided hash")?; +#[simple_shard("Argon2id.Verify", "Verifies a password against an Argon2id hash.")] +fn argon2id_verify( + password: &str, + #[param_var("Hash", "The Argon2id hash to verify against.")] + hash: &str, +) -> Result { + let parsed_hash = PasswordHash::new(hash).map_err(|_| "Failed to parse the provided hash")?; - let result = Argon2::default() - .verify_password(password.as_bytes(), &parsed_hash) - .is_ok(); + let result = Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .is_ok(); - Ok(Some(result.into())) - } + Ok(result) } pub fn register_shards() { diff --git a/shards/rust/src/lib.rs b/shards/rust/src/lib.rs index d8f657b7853..91e18fcf3d9 100644 --- a/shards/rust/src/lib.rs +++ b/shards/rust/src/lib.rs @@ -43,6 +43,7 @@ pub use shards_macro::param_set; pub use shards_macro::shard; pub use shards_macro::shard_impl; pub use shards_macro::shards_enum; +pub use shards_macro::simple_shard; use crate::core::Core; pub use crate::shardsc::*; diff --git a/shards/rust/src/types/common.rs b/shards/rust/src/types/common.rs index 9db06ffd2eb..e1128d53dfb 100644 --- a/shards/rust/src/types/common.rs +++ b/shards/rust/src/types/common.rs @@ -212,6 +212,54 @@ pub type RawString = SHString; #[derive(Default, Serialize, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ClonedVar(pub Var); +/// Typed wrapper for bytes output - avoids extra allocations +#[repr(transparent)] +#[derive(Default)] +pub struct BytesOut(pub ClonedVar); + +/// Typed wrapper for string output - avoids extra allocations +#[repr(transparent)] +#[derive(Default)] +pub struct StringOut(pub ClonedVar); + +impl BytesOut { + pub fn new(data: &[u8]) -> Self { + BytesOut(Var::from(data).into()) + } +} + +impl From<&[u8]> for BytesOut { + fn from(v: &[u8]) -> Self { + BytesOut::new(v) + } +} + +impl std::ops::Deref for BytesOut { + type Target = ClonedVar; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl StringOut { + pub fn new(s: &str) -> Self { + StringOut(Var::ephemeral_string(s).into()) + } +} + +impl From<&str> for StringOut { + fn from(v: &str) -> Self { + StringOut::new(v) + } +} + +impl std::ops::Deref for StringOut { + type Target = ClonedVar; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl Ord for Var { fn cmp(&self, other: &Self) -> std::cmp::Ordering { unsafe { diff --git a/shards/rust/src/types/metadata.rs b/shards/rust/src/types/metadata.rs index 220ec923f47..17f01abdd2f 100644 --- a/shards/rust/src/types/metadata.rs +++ b/shards/rust/src/types/metadata.rs @@ -1233,15 +1233,19 @@ impl From> for ClonedVar { impl From for ClonedVar { fn from(v: std::string::String) -> Self { - let cstr = CString::new(v).unwrap(); - let tmp = Var::from(&cstr); - let res = ClonedVar(Var::default()); - unsafe { - let rv = &res.0 as *const SHVar as *mut SHVar; - let sv = &tmp as *const SHVar; - (*Core).cloneVar.unwrap_unchecked()(rv, sv); - } - res + Var::ephemeral_string(&v).into() + } +} + +impl From> for ClonedVar { + fn from(v: Vec) -> Self { + Var::from(&v[..]).into() + } +} + +impl From<[u8; 16]> for ClonedVar { + fn from(v: [u8; 16]) -> Self { + Var::from(&v).into() } } @@ -1622,6 +1626,7 @@ impl From<&[u8; 16]> for Var { } } + impl TryFrom<&Var> for SHAudio { type Error = &'static str; @@ -2016,3 +2021,4 @@ impl From<&[u8]> for Var { } } } + diff --git a/shards/rust/src/types/mod.rs b/shards/rust/src/types/mod.rs index 88604682bde..d050810752f 100644 --- a/shards/rust/src/types/mod.rs +++ b/shards/rust/src/types/mod.rs @@ -25,6 +25,7 @@ pub mod param; pub mod strings; pub mod seq; pub mod table; +pub mod shard_type; // Re-export common types that are used everywhere pub use common::*; @@ -66,3 +67,6 @@ pub use table::{ STRING_OR_NONE_SLICE, STRINGS_OR_NONE_SLICE, STRING_VAR_OR_NONE_SLICE, ANY_TABLE_VAR_NONE_SLICE, }; + +// Re-export ShardType trait for simple_shard macro +pub use shard_type::ShardType; diff --git a/shards/rust/src/types/shard_type.rs b/shards/rust/src/types/shard_type.rs new file mode 100644 index 00000000000..ccfd52351f6 --- /dev/null +++ b/shards/rust/src/types/shard_type.rs @@ -0,0 +1,176 @@ +/* SPDX-License-Identifier: BSD-3-Clause */ +/* Copyright © 2025 Fragcolor Pte. Ltd. */ + +//! ShardType trait for mapping Rust types to Shards types. +//! +//! This trait enables the simple_shard macro to automatically determine +//! the correct Shards type information from Rust types. + +use super::*; +use crate::shardsc::*; + +/// Trait for mapping Rust types to Shards types. +/// +/// Implement this trait for custom types to enable them to be used +/// with the simple_shard macro for automatic type inference. +pub trait ShardType { + /// Returns the Shards Type for this Rust type + fn shards_type() -> Type; + /// Returns a static reference to the Types vec (for input_types/output_types) + fn shards_types() -> &'static Types; + /// Returns the context variable version of this type (for parameters) + fn shards_var_type() -> Type; +} + +// Implement for common primitive types + +impl ShardType for i64 { + fn shards_type() -> Type { common_type::int } + fn shards_types() -> &'static Types { &INT_TYPES } + fn shards_var_type() -> Type { common_type::int_var } +} + +impl ShardType for i32 { + fn shards_type() -> Type { common_type::int } + fn shards_types() -> &'static Types { &INT_TYPES } + fn shards_var_type() -> Type { common_type::int_var } +} + +impl ShardType for f64 { + fn shards_type() -> Type { common_type::float } + fn shards_types() -> &'static Types { &FLOAT_TYPES } + fn shards_var_type() -> Type { common_type::float_var } +} + +impl ShardType for f32 { + fn shards_type() -> Type { common_type::float } + fn shards_types() -> &'static Types { &FLOAT_TYPES } + fn shards_var_type() -> Type { common_type::float_var } +} + +impl ShardType for bool { + fn shards_type() -> Type { common_type::bool } + fn shards_types() -> &'static Types { &BOOL_TYPES } + fn shards_var_type() -> Type { common_type::bool_var } +} + +impl ShardType for std::string::String { + fn shards_type() -> Type { common_type::string } + fn shards_types() -> &'static Types { &STRING_TYPES } + fn shards_var_type() -> Type { common_type::string_var } +} + +impl<'a> ShardType for &'a str { + fn shards_type() -> Type { common_type::string } + fn shards_types() -> &'static Types { &STRING_TYPES } + fn shards_var_type() -> Type { common_type::string_var } +} + +// Vector types - int +impl ShardType for (i64, i64) { + fn shards_type() -> Type { common_type::int2 } + fn shards_types() -> &'static Types { &INT2_TYPES } + fn shards_var_type() -> Type { common_type::int2_var } +} + +impl ShardType for (i32, i32, i32) { + fn shards_type() -> Type { common_type::int3 } + fn shards_types() -> &'static Types { &INT3_TYPES } + fn shards_var_type() -> Type { common_type::int3_var } +} + +impl ShardType for (i64, i64, i64) { + fn shards_type() -> Type { common_type::int3 } + fn shards_types() -> &'static Types { &INT3_TYPES } + fn shards_var_type() -> Type { common_type::int3_var } +} + +impl ShardType for (i32, i32, i32, i32) { + fn shards_type() -> Type { common_type::int4 } + fn shards_types() -> &'static Types { &INT4_TYPES } + fn shards_var_type() -> Type { common_type::int4_var } +} + +impl ShardType for (i64, i64, i64, i64) { + fn shards_type() -> Type { common_type::int4 } + fn shards_types() -> &'static Types { &INT4_TYPES } + fn shards_var_type() -> Type { common_type::int4_var } +} + +// Vector types - float +impl ShardType for (f64, f64) { + fn shards_type() -> Type { common_type::float2 } + fn shards_types() -> &'static Types { &FLOAT2_TYPES } + fn shards_var_type() -> Type { common_type::float2_var } +} + +impl ShardType for (f32, f32, f32) { + fn shards_type() -> Type { common_type::float3 } + fn shards_types() -> &'static Types { &FLOAT3_TYPES } + fn shards_var_type() -> Type { common_type::float3_var } +} + +impl ShardType for (f64, f64, f64) { + fn shards_type() -> Type { common_type::float3 } + fn shards_types() -> &'static Types { &FLOAT3_TYPES } + fn shards_var_type() -> Type { common_type::float3_var } +} + +impl ShardType for (f32, f32, f32, f32) { + fn shards_type() -> Type { common_type::float4 } + fn shards_types() -> &'static Types { &FLOAT4_TYPES } + fn shards_var_type() -> Type { common_type::float4_var } +} + +impl ShardType for (f64, f64, f64, f64) { + fn shards_type() -> Type { common_type::float4 } + fn shards_types() -> &'static Types { &FLOAT4_TYPES } + fn shards_var_type() -> Type { common_type::float4_var } +} + +// Bytes +impl<'a> ShardType for &'a [u8] { + fn shards_type() -> Type { common_type::bytes } + fn shards_types() -> &'static Types { &BYTES_TYPES } + fn shards_var_type() -> Type { common_type::bytes_var } +} + +impl ShardType for Vec { + fn shards_type() -> Type { common_type::bytes } + fn shards_types() -> &'static Types { &BYTES_TYPES } + fn shards_var_type() -> Type { common_type::bytes_var } +} + +// Fixed-size byte arrays (Int16 = 16 bytes) +impl ShardType for [u8; 16] { + fn shards_type() -> Type { common_type::int16 } + fn shards_types() -> &'static Types { &INT16_TYPES } + fn shards_var_type() -> Type { common_type::int16_var } +} + +// None/Unit type +impl ShardType for () { + fn shards_type() -> Type { common_type::none } + fn shards_types() -> &'static Types { &NONE_TYPES } + fn shards_var_type() -> Type { common_type::none } // None doesn't have a var type +} + +// Color +impl ShardType for SHColor { + fn shards_type() -> Type { common_type::color } + fn shards_types() -> &'static Types { &COLOR_TYPES } + fn shards_var_type() -> Type { common_type::color_var } +} + +// Typed output wrappers +impl ShardType for super::BytesOut { + fn shards_type() -> Type { common_type::bytes } + fn shards_types() -> &'static Types { &BYTES_TYPES } + fn shards_var_type() -> Type { common_type::bytes_var } +} + +impl ShardType for super::StringOut { + fn shards_type() -> Type { common_type::string } + fn shards_types() -> &'static Types { &STRING_TYPES } + fn shards_var_type() -> Type { common_type::string_var } +} diff --git a/shards/rust_macro/src/lib.rs b/shards/rust_macro/src/lib.rs index d095357ed0a..c67d9f076b8 100644 --- a/shards/rust_macro/src/lib.rs +++ b/shards/rust_macro/src/lib.rs @@ -990,3 +990,583 @@ pub fn shard_impl(_attr: TokenStream, item: TokenStream) -> TokenStream { Err(err) => err.to_compile_error(), } } + +// ============================================================================ +// Simple Shard Macro - Simplified shard definition via function attributes +// ============================================================================ + +struct SimpleShardAttrArgs { + name: LitStr, + help: LitStr, +} + +impl syn::parse::Parse for SimpleShardAttrArgs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name: LitStr = input.parse()?; + input.parse::()?; + let help: LitStr = input.parse()?; + Ok(Self { name, help }) + } +} + +struct SimpleParamInfo { + name: String, + rust_name: syn::Ident, + rust_type: syn::Type, + description: String, + default: Option, + is_var: bool, // true for ParamVar (context variables) +} + +// Helper struct to parse #[param("Name", "Desc")] or #[param("Name", "Desc", default = value)] +struct SimpleParamAttr { + name: String, + description: String, + default: Option, +} + +impl syn::parse::Parse for SimpleParamAttr { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name: LitStr = input.parse()?; + input.parse::()?; + let description: LitStr = input.parse()?; + + let default = if input.peek(syn::Token![,]) { + input.parse::()?; + let ident: syn::Ident = input.parse()?; + if ident != "default" { + return Err(syn::Error::new(ident.span(), "Expected 'default'")); + } + input.parse::()?; + Some(input.parse()?) + } else { + None + }; + + Ok(Self { + name: name.value(), + description: description.value(), + default, + }) + } +} + +fn generate_simple_shard(args: SimpleShardAttrArgs, func: syn::ItemFn) -> Result { + let shard_name = args.name.value(); + let shard_help = args.help.value(); + + // Generate struct name from shard name (e.g., "Math.Scale" -> "MathScaleShard") + let struct_name = shard_name.replace(".", ""); + let struct_id = Ident::new(&format!("{}Shard", struct_name), Span::call_site()); + let params_static_id = Ident::new( + &format!("{}_PARAMETERS", struct_name.to_uppercase()), + Span::call_site(), + ); + + // Parse function signature + let mut input_type: Option = None; + let mut input_name: Option = None; + let mut is_unit_input = false; + let mut params: Vec = Vec::new(); + + for (i, arg) in func.sig.inputs.iter().enumerate() { + let syn::FnArg::Typed(pat_type) = arg else { + return Err("Expected typed argument".into()); + }; + + let arg_type = pat_type.ty.as_ref().clone(); + + // First arg is input + if i == 0 { + // Check if input type is unit () + if let syn::Type::Tuple(tuple) = &arg_type { + if tuple.elems.is_empty() { + is_unit_input = true; + input_type = Some(arg_type); + // Use a dummy name for unit input + input_name = Some(Ident::new("_input", Span::call_site())); + continue; + } + } + + // Handle both ident patterns and wildcard patterns + let arg_name = match pat_type.pat.as_ref() { + syn::Pat::Ident(pat_ident) => pat_ident.ident.clone(), + syn::Pat::Wild(_) => Ident::new("_input", Span::call_site()), + _ => return Err("Expected identifier or wildcard pattern".into()), + }; + + input_type = Some(arg_type); + input_name = Some(arg_name); + continue; + } + + // Rest are parameters - need ident pattern + let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() else { + return Err("Expected identifier pattern for parameter".into()); + }; + + let arg_name = pat_ident.ident.clone(); + + // Rest are parameters - parse #[param(...)] or #[param_var(...)] attribute + let mut param_name = arg_name.to_string(); + let mut param_desc = String::new(); + let mut param_default: Option = None; + let mut is_var = false; + + for attr in &pat_type.attrs { + if attr.path().is_ident("param") { + let parsed: SimpleParamAttr = attr.parse_args()?; + param_name = parsed.name; + param_desc = parsed.description; + param_default = parsed.default; + } else if attr.path().is_ident("param_var") { + // For context variable parameters + let parsed: SimpleParamAttr = attr.parse_args()?; + param_name = parsed.name; + param_desc = parsed.description; + param_default = parsed.default; + is_var = true; + } + } + + params.push(SimpleParamInfo { + name: param_name, + rust_name: arg_name, + rust_type: arg_type, + description: param_desc, + default: param_default, + is_var, + }); + } + + let input_type = input_type.ok_or("Function must have at least one argument (input)")?; + let input_name = input_name.ok_or("Function must have at least one argument (input)")?; + + // Get output type from return type + let mut returns_result = false; + let mut returns_typed_out = false; // BytesOut, StringOut, etc. + let output_type = match &func.sig.output { + syn::ReturnType::Type(_, ty) => { + // Handle Result wrapper + if let syn::Type::Path(path) = ty.as_ref() { + let type_name = path.path.segments.last().map(|s| s.ident.to_string()); + + if type_name.as_deref() == Some("Result") { + returns_result = true; + // Extract T from Result + if let syn::PathArguments::AngleBracketed(args) = + &path.path.segments.last().unwrap().arguments + { + if let Some(syn::GenericArgument::Type(t)) = args.args.first() { + // Check if inner type is BytesOut/StringOut + if let syn::Type::Path(inner_path) = t { + let inner_name = inner_path.path.segments.last().map(|s| s.ident.to_string()); + if matches!(inner_name.as_deref(), Some("BytesOut") | Some("StringOut")) { + returns_typed_out = true; + } + } + t.clone() + } else { + return Err("Invalid Result type".into()); + } + } else { + return Err("Invalid Result type".into()); + } + } else if matches!(type_name.as_deref(), Some("BytesOut") | Some("StringOut")) { + returns_typed_out = true; + ty.as_ref().clone() + } else { + ty.as_ref().clone() + } + } else { + ty.as_ref().clone() + } + } + syn::ReturnType::Default => return Err("Function must have return type".into()), + }; + + // Generate struct fields + let param_fields: Vec<_> = params + .iter() + .map(|p| { + let name = &p.rust_name; + if p.is_var { + quote! { #name: shards::types::ParamVar } + } else { + quote! { #name: shards::types::ClonedVar } + } + }) + .collect(); + + // Generate default values + let param_defaults: Vec<_> = params + .iter() + .map(|p| { + let name = &p.rust_name; + let default_val = p + .default + .as_ref() + .map(|d| quote! { (#d).into() }) + .unwrap_or_else(|| quote! { Default::default() }); + + if p.is_var { + quote! { #name: shards::types::ParamVar::new(#default_val) } + } else { + quote! { #name: #default_val } + } + }) + .collect(); + + // Generate parameter info + let param_names: Vec<_> = params + .iter() + .map(|p| LitStr::new(&p.name, Span::call_site())) + .collect(); + let param_descs: Vec<_> = params + .iter() + .map(|p| LitStr::new(&p.description, Span::call_site())) + .collect(); + let param_rust_names: Vec<_> = params.iter().map(|p| &p.rust_name).collect(); + let param_types: Vec<_> = params.iter().map(|p| &p.rust_type).collect(); + let param_indices: Vec<_> = (0..params.len()) + .map(|i| LitInt::new(&format!("{}", i), Span::call_site())) + .collect(); + + // Generate parameter extraction in activate + let param_extractions: Vec<_> = params + .iter() + .map(|p| { + let name = &p.rust_name; + let ty = &p.rust_type; + if p.is_var { + quote! { + let #name: #ty = self.#name.get().as_ref().try_into()?; + } + } else { + quote! { + let #name: #ty = self.#name.0.as_ref().try_into()?; + } + } + }) + .collect(); + + // Generate warmup/cleanup calls for ParamVar + let param_warmups: Vec<_> = params + .iter() + .filter(|p| p.is_var) + .map(|p| { + let name = &p.rust_name; + quote! { self.#name.warmup(context); } + }) + .collect(); + + let param_cleanups: Vec<_> = params + .iter() + .filter(|p| p.is_var) + .map(|p| { + let name = &p.rust_name; + quote! { self.#name.cleanup(context); } + }) + .collect(); + + // The original function body + let func_body = &func.block; + let func_name = &func.sig.ident; + + // CRC for shard hash + let crc = crc32(format!("{}-rust-0x20250822", shard_name)); + + // Determine if we have params + let has_params = !params.is_empty(); + + let parameters_impl = if has_params { + quote! { + fn parameters(&mut self) -> Option<&shards::types::Parameters> { + Some(&#params_static_id) + } + } + } else { + quote! { + fn parameters(&mut self) -> Option<&shards::types::Parameters> { + None + } + } + }; + + let set_get_param_impl = if has_params { + quote! { + fn set_param(&mut self, index: i32, value: &shards::types::Var) -> std::result::Result<(), &'static str> { + match index { + #( + #param_indices => self.#param_rust_names.set_param(value), + )* + _ => Err("Invalid parameter index"), + } + } + + fn get_param(&mut self, index: i32) -> shards::types::Var { + match index { + #( + #param_indices => (&self.#param_rust_names).into(), + )* + _ => shards::types::Var::default(), + } + } + } + } else { + quote! { + fn set_param(&mut self, _index: i32, _value: &shards::types::Var) -> std::result::Result<(), &'static str> { + Err("No parameters") + } + + fn get_param(&mut self, _index: i32) -> shards::types::Var { + shards::types::Var::default() + } + } + }; + + // Generate parameter type arrays that include both base type and var type + let param_type_array_ids: Vec<_> = params + .iter() + .enumerate() + .map(|(i, _)| { + Ident::new( + &format!("{}_PARAM_{}_TYPES", struct_name.to_uppercase(), i), + Span::call_site(), + ) + }) + .collect(); + + let params_static_def = if has_params { + let param_type_arrays: Vec<_> = params + .iter() + .zip(param_type_array_ids.iter()) + .map(|(p, id)| { + let ty = &p.rust_type; + quote! { + static ref #id: shards::types::Types = vec![ + <#ty as shards::types::ShardType>::shards_type(), + <#ty as shards::types::ShardType>::shards_var_type() + ]; + } + }) + .collect(); + + quote! { + lazy_static::lazy_static! { + #(#param_type_arrays)* + static ref #params_static_id: shards::types::Parameters = vec![ + #( + ( + shards::cstr!(#param_names), + shards::shccstr!(#param_descs), + #param_type_array_ids.as_slice() + ).into() + ),* + ]; + } + } + } else { + quote! {} + }; + + // Generate activate call - differs for unit input vs normal input, and Result vs plain return + let activate_call = match (is_unit_input, returns_result) { + (true, true) => quote! { #func_name((), #(#param_rust_names),*)? }, + (true, false) => quote! { #func_name((), #(#param_rust_names),*) }, + (false, true) => quote! { + { + let #input_name: #input_type = input.try_into()?; + #func_name(#input_name, #(#param_rust_names),*)? + } + }, + (false, false) => quote! { + { + let #input_name: #input_type = input.try_into()?; + #func_name(#input_name, #(#param_rust_names),*) + } + }, + }; + + // Generate output assignment - typed outputs (BytesOut, StringOut) are already ClonedVar wrappers + let output_assignment = if returns_typed_out { + quote! { self.output = result.0; } + } else { + quote! { self.output = result.into(); } + }; + + // Generate compose calls for param_var parameters + let has_var_params = params.iter().any(|p| p.is_var); + let param_var_composes: Vec<_> = params + .iter() + .filter(|p| p.is_var) + .map(|p| { + let name = &p.rust_name; + let param_name = &p.name; + let ty = &p.rust_type; + quote! { + { + let param_types: shards::types::Types = vec![ + <#ty as shards::types::ShardType>::shards_type(), + <#ty as shards::types::ShardType>::shards_var_type() + ]; + shards::util::collect_required_variables_typed( + data, + &mut self.required, + (&self.#name).into(), + ¶m_types[..], + #param_name + )?; + } + } + }) + .collect(); + + // Generate the inner function signature based on whether it returns Result or plain type + let inner_function = if returns_result { + quote! { + #[inline] + fn #func_name(#input_name: #input_type, #(#param_rust_names: #param_types),*) -> std::result::Result<#output_type, &'static str> { + #func_body + } + } + } else { + quote! { + #[inline] + fn #func_name(#input_name: #input_type, #(#param_rust_names: #param_types),*) -> #output_type { + #func_body + } + } + }; + + // Generate the full shard implementation + let output = quote! { + // The inner function with the actual logic + #inner_function + + pub struct #struct_id { + required: shards::types::ExposedTypes, + #(#param_fields,)* + output: shards::types::ClonedVar, + } + + impl Default for #struct_id { + fn default() -> Self { + Self { + required: shards::types::ExposedTypes::new(), + #(#param_defaults,)* + output: shards::types::ClonedVar::default(), + } + } + } + + #params_static_def + + impl shards::shard::ShardGenerated for #struct_id { + fn register_name() -> &'static str { + shards::cstr!(#shard_name) + } + + fn name(&mut self) -> &str { + #shard_name + } + + fn hash() -> u32 { + #crc + } + + fn help(&mut self) -> shards::types::OptionalString { + shards::types::OptionalString(shards::shccstr!(#shard_help)) + } + + #parameters_impl + + #set_get_param_impl + + fn required_variables(&mut self) -> Option<&shards::types::ExposedTypes> { + Some(&self.required) + } + } + + impl shards::shard::ShardGeneratedOverloads for #struct_id { + fn has_compose() -> bool { #has_var_params } + fn has_warmup() -> bool { true } + fn has_mutate() -> bool { false } + fn has_crossover() -> bool { false } + fn has_get_state() -> bool { false } + fn has_set_state() -> bool { false } + fn has_reset_state() -> bool { false } + } + + impl shards::shard::Shard for #struct_id { + fn input_types(&mut self) -> &shards::types::Types { + <#input_type as shards::types::ShardType>::shards_types() + } + + fn output_types(&mut self) -> &shards::types::Types { + <#output_type as shards::types::ShardType>::shards_types() + } + + fn warmup(&mut self, context: &shards::types::Context) -> std::result::Result<(), &str> { + #(#param_warmups)* + Ok(()) + } + + fn cleanup(&mut self, context: std::option::Option<&shards::types::Context>) -> std::result::Result<(), &str> { + #(#param_cleanups)* + self.output = shards::types::ClonedVar::default(); + Ok(()) + } + + fn compose(&mut self, data: &shards::types::InstanceData) -> std::result::Result { + self.required.clear(); + #(#param_var_composes)* + Ok(<#output_type as shards::types::ShardType>::shards_type()) + } + + fn activate(&mut self, _context: &shards::types::Context, input: &shards::types::Var) -> std::result::Result, &str> { + #(#param_extractions)* + + let result = #activate_call; + #output_assignment + Ok(Some(self.output.0)) + } + } + }; + + Ok(output.into()) +} + +/// Simple shard definition via function attribute. +/// +/// This macro allows defining shards with a simple function syntax, +/// automatically generating all the boilerplate code. +/// +/// # Example +/// ```rust,ignore +/// #[simple_shard("Math.Scale", "Scales input by factor")] +/// fn scale( +/// input: i64, +/// #[param("Factor", "Scale factor", default = 2)] +/// factor: i64, +/// ) -> Result { +/// Ok(input * factor) +/// } +/// ``` +/// +/// This will generate a `MathScaleShard` struct with all necessary +/// trait implementations. +#[proc_macro_attribute] +pub fn simple_shard(attr: TokenStream, item: TokenStream) -> TokenStream { + let args = syn::parse_macro_input!(attr as SimpleShardAttrArgs); + let func = syn::parse_macro_input!(item as syn::ItemFn); + + match generate_simple_shard(args, func) { + Ok(result) => { + // eprintln!("simple_shard:\n{}", result); + result + } + Err(err) => err.to_compile_error(), + } +}