diff --git a/Cargo.lock b/Cargo.lock index ca04c24ad..e94ccd287 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addchain" @@ -344,6 +344,7 @@ dependencies = [ "logos", "lp-solvers 0.0.4", "merlin", + "num-traits", "once_cell", "pairing", "paste", diff --git a/Cargo.toml b/Cargo.toml index 6d5113a4c..c2f4f85de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ curve25519-dalek = {version = "3.2.0", features = ["serde"], optional = true} paste = "1.0" im = "15" once_cell = "1" +num-traits = "0.2" [dev-dependencies] quickcheck = "1" diff --git a/src/ir/opt/fp/extras.rs b/src/ir/opt/fp/extras.rs new file mode 100644 index 000000000..ac4042a41 --- /dev/null +++ b/src/ir/opt/fp/extras.rs @@ -0,0 +1,130 @@ +//! Extra algorithms for hint (precompute) and gadget (sub-circuit) constructions + +use crate::cfg::cfg_or_default as cfg; +use crate::ir::term::*; + +use super::FloatRewriter; + +/// Wrapper for pf_lit +pub fn new_pf_lit(v: I) -> Term +where + rug::Integer: From, +{ + pf_lit(cfg().field().new_v(v)) +} + +/// Computes a hint given a (mutable) computation, variable name, and hint function +/// +/// Syntax: +/// +/// * without children: `compute_hint![COMP, NAME, HINT]` +/// * with children: `compute_hint![COMP, NAME, HINT; ARG0, ARG1, ... ]` +/// +/// Returns a variable term associated to a precompute term that can be used in the computation. +#[macro_export] +macro_rules! compute_hint { + ($comp:expr, $name:expr, $hint_fn:path) => {{ + let hint = $hint_fn(); + let sort = check(&hint); + $comp.extend_precomputation($name.clone(), hint); + var($name, sort) + }}; + + ($comp:expr, $name:expr, $hint_fn:path; $($args:expr),+) => {{ + let hint = $hint_fn($($args),+); + let sort = check(&hint); + $comp.extend_precomputation($name.clone(), hint); + var($name, sort) + }}; +} + +/// Hints aka nondeterministic advice +/// +/// These are unconstrained precompute terms, which _should_ be enforced by other +/// mechanisms in the computation. +pub struct Hint; +/// Frequently used gadgets +pub struct Gadget<'a> { + pub rewriter: &'a mut FloatRewriter, +} + +impl Hint { + /// Precompute the sign and exponent of a float, together with its pf encoding. + pub fn decode_float(var: &Term, e_size: usize, m_size: usize) -> Term { + let var_pf = match var.op() { + Op::Var(_) => term![Op::FpToPf(cfg().field().clone()); var.clone()], + _ => panic!("Expected Op::Var but found something else"), + }; + let var_bv = term![Op::PfToBv(64); var_pf.clone()]; + + let s = term![BV_LSHR; var_bv.clone(), bv_lit(e_size + m_size, 64)]; + let shifted_v = term![BV_LSHR; var_bv.clone(), bv_lit(m_size, 64)]; + let shifted_s = term![BV_SHL; s.clone(), bv_lit(e_size, 64)]; + let e = term![BV_SUB; shifted_v, shifted_s]; + + term![ + Op::Tuple; + term![Op::UbvToPf(Box::new(cfg().field().clone())); s], + term![Op::UbvToPf(Box::new(cfg().field().clone())); e], + var_pf + ] + } + /// Precompute the number of leading zeros of mantissa within m_size bits + pub fn normalize(var: &Term, m_size: usize) -> Term { + let mantissa = term![Op::PfToBv(m_size); var.clone()]; + + let mut d = bv_lit(0, m_size); + // shift length is number of zeros + // d := ite (bit == 0) (d + 1) (d) + for i in (0..m_size).rev() { + let bit = term![Op::BvBit(i); mantissa.clone()]; + let cond = term![EQ; bit, bool_lit(false)]; + d = term![ITE; cond, term![BV_ADD; d.clone(), bv_lit(1, m_size)], d.clone()] + } + + // if mantissa is zero, return zero, else return computed shift + let is_zero = term![EQ; mantissa.clone(), bv_lit(0, m_size)]; + term![ITE; is_zero, new_pf_lit(0), term![Op::UbvToPf(Box::new(cfg().field().clone())); d.clone()]] + } + // /// Precompute the power of two of a given pf term. + // pub fn power_of_two(var: &Term) -> Term { + // // maybe set the size to m_size. come back to this later. + // let d = term![Op::PfToBv(64); var.clone()]; + // term![Op::UbvToPf(Box::new(cfg().field().clone())); term![BV_SHL; bv_lit(1, 64), d]] + // } +} + +impl Gadget<'_> { + pub fn assert(&mut self, cond: Term) { + self.rewriter.assertions.push(cond); + } + + // /// TODO: implement lookup-related machinery + // pub fn query_power_of_two(&self, comp: &mut Computation, exp: &Term) -> Term { + // // Precompute the power of two of the exponent + // let res = compute_hint!( + // comp, + // Hint::power_of_two(exp), + // format!("{}.pow2", exp.as_var_name()) + // ); + + // // continue here... + // res + // } + + /// Creates a bool term checking if pf term is 1 or 0 + pub fn is_bool(&self, a: &Term) -> Term { + // compute a*(1-a) == 0 + self.is_zero(&term![PF_MUL; a.clone(), self.pf_sub(&new_pf_lit(1), a)]) + } + + /// Creates a bool term checking if pf term a is zero + pub fn is_zero(&self, a: &Term) -> Term { + term![EQ; a.clone(), new_pf_lit(0)] + } + + /// Creates a pf term subtracting pf term b from pf term a + pub fn pf_sub(&self, a: &Term, b: &Term) -> Term { + term![PF_ADD; a.clone(), term![PF_NEG; b.clone()]] + } +} diff --git a/src/ir/opt/fp/mod.rs b/src/ir/opt/fp/mod.rs new file mode 100644 index 000000000..4e4ce6f21 --- /dev/null +++ b/src/ir/opt/fp/mod.rs @@ -0,0 +1,361 @@ +//! Floating-point IEEE-754 constructions +//! +//! Replace IR floats with their respective constructions using lookup arguments. +//! The floats themselves are stored as a tuple of components. +//! +//! A tuple elimination pass should be run afterwards. +//! Based on the implementation of + +mod extras; + +use super::visit::RewritePass; +use crate::cfg::cfg_or_default as cfg; +use crate::compute_hint; +use crate::ir::term::*; +use extras::{new_pf_lit, Gadget, Hint}; +// use lookup::{lookup_exponent, lookup_mantissa}; +use rug::Integer; + +/// Represents an IEEE-754 floating-point number +#[allow(dead_code)] +struct FloatVar { + sign: Term, + exponent: Term, + mantissa: Term, + is_abnormal: Term, +} + +// Context +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct Context { + e: usize, // Exponent bit width + m: usize, // Mantissa bit width + e_max: Integer, + e_normal_min: Integer, + e_min: Integer, +} + +#[allow(dead_code)] +impl Context { + pub fn new(e: usize, m: usize) -> Self { + let e_max = Integer::from(1) << (e - 1); + let e_normal_min = Integer::from(2) - &e_max; + let e_min = e_normal_min.clone() - (m + 1); + + Self { + e, + m, + e_max, + e_normal_min, + e_min, + } + } + + fn components_of(&self, v: u64) -> (Integer, Integer, Integer, Integer) { + let s = v >> (self.e + self.m); + let e = (v >> self.m) - (s << self.e); + let m = v - (s << (self.e + self.m)) - (e << self.m); + + let sign = Integer::from(s as i64); + + let exponent_max = Integer::from(1) << (self.e - 1); + let exponent_min = Integer::from(1) - &exponent_max; + + let mut exponent = Integer::from(e as i64) + &exponent_min; + let mut mantissa = Integer::from(m as i64); + let mut shift = 0_usize; + + for i in (0..self.m).rev() { + if ((m >> (self.m - 1 - i)) & 1) == 1 { + break; + } + shift += 1; + } + + let mantissa_is_not_zero = m != 0; + let exponent_is_min = exponent == exponent_min; + let exponent_is_max = exponent == exponent_max; + + if exponent_is_min { + exponent -= Integer::from(shift as i64); + mantissa <<= shift + 1; + } else if exponent_is_max && mantissa_is_not_zero { + mantissa = Integer::from(0); + } else { + mantissa += Integer::from(1) << self.m; + } + + let is_abnormal = if exponent_is_max { + Integer::from(1) + } else { + Integer::from(0) + }; + + (sign, exponent, mantissa, is_abnormal) + } + + pub fn new_f32_constant(&self, v: f32) -> (Integer, Integer, Integer, Integer) { + self.components_of(v.to_bits() as u64) + } + + pub fn new_f64_constant(&self, v: f64) -> (Integer, Integer, Integer, Integer) { + self.components_of(v.to_bits()) + } +} + +#[derive(Default)] +struct FloatRewriter { + assertions: Vec, +} + +impl RewritePass for FloatRewriter { + fn traverse(&mut self, computation: &mut Computation) { + self.traverse_full(computation, false, true); + + // Apply all stored assertions at the end + for assertion in self.assertions.drain(..) { + computation.assert(assertion); + } + } + + fn visit Vec>( + &mut self, + computation: &mut Computation, + orig: &Term, + _rewritten_children: F, + ) -> Option { + let mut gadget = Gadget { rewriter: self }; + + match &orig.op() { + // Destructures float variables into constrained tuples of (s: bool, e: field, m: field, a: bool). + Op::Var(v) => { + let ctx = match v.sort { + Sort::F32 => Context::new(8, 23), + Sort::F64 => Context::new(11, 52), + _ => todo!(), + }; + + // Precompute the float components + let fp_tup = compute_hint![ + computation, + format!("{}.fp_tup", v.name), + Hint::decode_float; + orig, + ctx.e, + ctx.m + ]; + let s = term(Op::Field(0), vec![fp_tup.clone()]); + let e = term(Op::Field(1), vec![fp_tup.clone()]); + let v_pf = term(Op::Field(2), vec![fp_tup.clone()]); + + let s_mul = term![PF_MUL; s.clone(), new_pf_lit(1u64 << (ctx.e + ctx.m))]; + let e_mul = term![PF_MUL; e.clone(), new_pf_lit(1u64 << ctx.m)]; + let m = term![PF_ADD; v_pf, term![PF_NEG; term![PF_ADD; s_mul, e_mul]]]; + + // Assert well-formedness of s, e and m through range checks + gadget.assert(gadget.is_bool(&s)); + gadget.assert(term![Op::PfFitsInBits(ctx.e); e.clone()]); + gadget.assert(term![Op::PfFitsInBits(ctx.m); m.clone()]); + + let exponent_min = gadget.pf_sub(&new_pf_lit(ctx.e_normal_min), &new_pf_lit(1)); + let exponent_max = new_pf_lit(ctx.e_max); + + let exponent = term![PF_ADD; e.clone(), exponent_min.clone()]; + + let mantissa_is_zero = gadget.is_zero(&m); + let _mantissa_is_not_zero = term![NOT; mantissa_is_zero.clone()]; + let _exponent_is_min = term![EQ; exponent.clone(), exponent_min.clone()]; + let _exponent_is_max = term![EQ; exponent.clone(), exponent_max]; + + // Precompute the shift length `d` for float normalization + let d = compute_hint![ + computation, + format!("{}.d", v.name), + Hint::normalize; + &m, + ctx.m + ]; + + // Assert soundness of `d` with a range check + // It should be in the range [0, ctx.m]. + let bitlen = (usize::BITS - ctx.m.leading_zeros()) as usize; + gadget.assert(term![Op::PfFitsInBits(bitlen); d.clone()]); + // continue here with the querypowerof2 gadget operation on shift + + // return the term that we rewrite orig with + Some(term![ + Op::Tuple; + term![Op::PfToBoolTrusted; s.clone()], + e, + m, + term![Op::PfToBoolTrusted; s] + ]) + } + // Rewrites float constants into tuples of (s: bool, e: field, m: field, a: bool). + Op::Const(v) => { + let comps = match v.as_ref() { + Value::F32(fp) => Some(Context::new(8, 23).new_f32_constant(*fp)), + Value::F64(fp) => Some(Context::new(11, 52).new_f64_constant(*fp)), + _ => None, + }; + comps.map(|float| { + term( + Op::Tuple, + vec![ + const_(Value::Bool(float.0 == 1)), + const_(Value::Field(cfg().field().new_v(float.1))), + const_(Value::Field(cfg().field().new_v(float.2))), + const_(Value::Bool(float.3 == 1)), + ], + ) + }) + } + Op::Eq => todo!(), + Op::FpBinOp(_fp_bin_op) => todo!(), + Op::FpBinPred(_fp_bin_pred) => todo!(), + Op::FpUnPred(_fp_un_pred) => todo!(), + Op::FpUnOp(_fp_un_op) => todo!(), + Op::BvToFp => todo!(), + Op::UbvToFp(_) => todo!(), + Op::SbvToFp(_) => todo!(), + Op::FpToFp(_) => todo!(), + Op::PfToFp(_) => todo!(), + _ => None, + } + } +} + +/// Replace floats with IEEE 754 constructions +pub fn construct_floats(c: &mut Computation) { + let mut pass = FloatRewriter::default(); + pass.traverse(c); +} + +#[cfg(test)] +mod test { + use super::*; + use crate::cfg::cfg_or_default as cfg; + + #[test] + fn const_fp32() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs) (commitments)) + (precompute () () (#t )) + (let + ( + (val #fp3.1415926535f32) + ) + val + ) + ) + ", + ); + construct_floats(&mut c); + + let expected = term( + Op::Tuple, + vec![ + const_(Value::Bool(false)), + const_(Value::Field(cfg().field().new_v(1))), + const_(Value::Field(cfg().field().new_v(13176795))), + const_(Value::Bool(false)), + ], + ); + + assert_eq!(c.outputs[0], expected); + } + + #[test] + fn const_fp64() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs) (commitments)) + (precompute () () (#t )) + (let + ( + (val #fp-3.141592653589793) + ) + val + ) + ) + ", + ); + construct_floats(&mut c); + + let expected = term( + Op::Tuple, + vec![ + const_(Value::Bool(true)), + const_(Value::Field(cfg().field().new_v(1))), + const_(Value::Field(cfg().field().new_v( + Integer::from_str_radix("7074237752028440", 10).unwrap(), + ))), + const_(Value::Bool(false)), + ], + ); + + assert_eq!(c.outputs[0], expected); + } + + // #[test] + // fn init_fp32() { + // let mut c = text::parse_computation( + // b" + // (computation + // (metadata (parties ) (inputs (a f32)) (commitments)) + // (precompute () () (#t )) + // (let + // ( + // (val a) + // ) + // val + // ) + // ) + // ", + // ); + // let values = text::parse_value_map( + // b" + // (set_default_modulus 11 + // (let ( + // (a #fp3.1415926535f32) + // ) false ; ignored + // )) + // ", + // ); + // // println!("{}", text::serialize_computation(&c)); + // construct_floats(&mut c); + // println!("{}", text::serialize_computation(&c)); + // // println!("{:#?}", c.eval_all(&values)); + // assert_eq!(vec![Value::F32(1.234)], c.eval_all(&values)); + // } + + // #[test] + // fn add_fp32() { + // let mut c = text::parse_computation( + // b" + // (computation + // (metadata (parties ) (inputs (a f32) (b f32) (return f32)) (commitments)) + // (precompute () () (#t )) + // (= return (fpadd a b)) + // ) + // ", + // ); + // let values = text::parse_value_map( + // b" + // (set_default_modulus 11 + // (let ( + // (a #fp1f32) + // (b #fp2f32) + // (return #fp3f32) + // ) false ; ignored + // )) + // ", + // ); + // construct_floats(&mut c); + // assert_eq!(vec![Value::F32(3.0)], c.eval_all(&values)); + // } +} diff --git a/src/ir/opt/mod.rs b/src/ir/opt/mod.rs index fa358ad0e..e4d1c5947 100644 --- a/src/ir/opt/mod.rs +++ b/src/ir/opt/mod.rs @@ -5,6 +5,7 @@ pub mod chall; pub mod cstore; pub mod fits_in_bits_ip; pub mod flat; +pub mod fp; pub mod inline; pub mod link; pub mod mem; @@ -59,6 +60,8 @@ pub enum Opt { DeskolemizeWitnesses, /// Check bit-constaints with challenges. FitsInBitsIp, + /// Replace floats with IEEE 754 constructions using lookup arguments + Float, } /// Run optimizations on `cs`, in this order, returning the new constraint system. @@ -169,6 +172,9 @@ pub fn opt>(mut cs: Computations, optimizations: I) Opt::FitsInBitsIp => { fits_in_bits_ip::fits_in_bits_ip(c); } + Opt::Float => { + fp::construct_floats(c); + } } info!( "After {:?}: {} terms", diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index 3e0a010a8..6be6355a2 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -2,8 +2,9 @@ use super::{ check, const_, extras, term, Array, BitVector, BoolNaryOp, BvBinOp, BvBinPred, BvNaryOp, - BvUnOp, FieldToBv, FxHashMap, IntBinOp, IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op, - PfNaryOp, PfUnOp, Sort, Term, TermMap, Value, + BvUnOp, FieldToBv, Float, FpBinOp, FpBinPred, FpUnOp, FpUnPred, FxHashMap, IntBinOp, + IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op, PfNaryOp, PfUnOp, Sort, Term, TermMap, + Value, }; use crate::cfg::cfg_or_default; @@ -165,6 +166,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> } }), Op::BoolToBv => Value::BitVector(BitVector::new(Integer::from(args[0].as_bool()), 1)), + Op::PfUnOp(o) => Value::Field({ let a = args[0].as_pf().clone(); match o { @@ -215,7 +217,6 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> }, ) }), - Op::IntBinOp(o) => Value::Int({ let a = args[0].as_int().clone(); let b = args[1].as_int().clone(); @@ -239,6 +240,139 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> IntUnOp::Neg => -a, } }), + Op::FpBinOp(o) => { + fn comp(a: T, b: T, op: FpBinOp) -> T { + match op { + FpBinOp::Add => a + b, + FpBinOp::Sub => a - b, + FpBinOp::Mul => a * b, + FpBinOp::Div => a / b, + FpBinOp::Rem => a % b, + FpBinOp::Max => a.max(b), + FpBinOp::Min => a.min(b), + } + } + match (args[0], args[1]) { + (Value::F32(_), Value::F32(_)) => { + Value::F32(comp(args[0].as_f32(), args[1].as_f32(), *o)) + } + (Value::F64(_), Value::F64(_)) => { + Value::F64(comp(args[0].as_f64(), args[1].as_f64(), *o)) + } + _ => panic!( + "Expected two F32 or F64, got LHS {} and RHS {}", + args[0], args[1] + ), + } + } + Op::FpBinPred(o) => Value::Bool({ + fn comp(a: T, b: T, op: FpBinPred) -> bool { + match op { + FpBinPred::Le => a <= b, + FpBinPred::Lt => a < b, + FpBinPred::Eq => a == b, + FpBinPred::Ge => a >= b, + FpBinPred::Gt => a > b, + } + } + match (args[0], args[1]) { + (Value::F32(_), Value::F32(_)) => comp(args[0].as_f32(), args[1].as_f32(), *o), + (Value::F64(_), Value::F64(_)) => comp(args[0].as_f64(), args[1].as_f64(), *o), + _ => panic!( + "Expected two F32 or F64, got LHS {} and RHS {}", + args[0], args[1] + ), + } + }), + Op::FpUnPred(o) => Value::Bool({ + fn comp(a: T, op: FpUnPred) -> bool { + match op { + FpUnPred::Normal => a.is_normal(), + FpUnPred::Subnormal => a.is_subnormal(), + FpUnPred::Zero => a == T::zero(), + FpUnPred::Infinite => a.is_infinite(), + FpUnPred::Nan => a.is_nan(), + FpUnPred::Negative => a.is_sign_negative(), + FpUnPred::Positive => a.is_sign_positive(), + } + } + match args[0] { + Value::F32(a) => comp(*a, *o), + Value::F64(a) => comp(*a, *o), + _ => panic!("Expected F32 or F64, got {}", args[0]), + } + }), + Op::FpUnOp(o) => { + fn comp(a: T, op: FpUnOp) -> T { + match op { + FpUnOp::Neg => -a, + FpUnOp::Abs => a.abs(), + FpUnOp::Sqrt => a.sqrt(), + FpUnOp::Round => a.round(), + } + } + match args[0] { + Value::F32(a) => Value::F32(comp(*a, *o)), + Value::F64(a) => Value::F64(comp(*a, *o)), + _ => panic!("Expected F32 or F64, got {}", args[0]), + } + } + Op::BvToFp => { + let bv = args[0].as_bv(); + let val = bv.uint(); + let w = bv.width(); + match w { + 32 => Value::F32(f32::from_bits(val.to_u32().unwrap())), + 64 => Value::F64(f64::from_bits(val.to_u64().unwrap())), + _ => panic!("{} out of bounds for {} on {:?}", w, op, args), + } + } + Op::UbvToFp(w) => { + let val = args[0].as_bv().uint(); + match w { + 0..=32 => Value::F32(val.to_f32()), + 33..=64 => Value::F64(val.to_f64()), + _ => panic!("{} out of bounds for {} on {:?}", w, op, args), + } + } + Op::SbvToFp(w) => { + let val = args[0].as_bv().as_sint(); + match w { + 0..=32 => Value::F32(val.to_f32()), + 33..=64 => Value::F64(val.to_f64()), + _ => panic!("{} out of bounds for {} on {:?}", w, op, args), + } + } + Op::FpToFp(w) => { + match (args[0], w) { + (Value::F32(v), 64) => Value::F64(*v as f64), // Promote F32 to F64 + (Value::F64(v), 32) => Value::F32(*v as f32), // Truncate F64 to F32 + (Value::F32(_), 32) | (Value::F64(_), 64) => args[0].clone(), + _ => panic!("Invalid conversion width {} (expected 32 or 64)", w), + } + } + Op::PfToFp(w) => { + let val = args[0].as_pf().i(); + match w { + 32 => Value::F32(val.to_f32()), + 64 => Value::F64(val.to_f64()), + _ => panic!( + "{} out of bounds for {} on {:?} (expected 32 or 64)", + w, op, args + ), + } + } + Op::FpToPf(fty) => { + let val = match args[0] { + Value::F32(f) => rug::Integer::from(f.to_bits()), + Value::F64(f) => rug::Integer::from(f.to_bits()), + _ => panic!( + "Expected floating-point value for {} but got {:?}", + op, args[0] + ), + }; + Value::Field(fty.new_v(val)) + } Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())), Op::PfChallenge(c) => Value::Field(eval_pf_challenge(&c.name, &c.field)), Op::Witness(_) => args[0].clone(), diff --git a/src/ir/term/fmt.rs b/src/ir/term/fmt.rs index 7575a37b6..76c1babce 100644 --- a/src/ir/term/fmt.rs +++ b/src/ir/term/fmt.rs @@ -364,6 +364,8 @@ impl DisplayIr for Op { Op::UbvToFp(a) => write!(f, "(ubv2fp {a})"), Op::SbvToFp(a) => write!(f, "(sbv2fp {a})"), Op::FpToFp(a) => write!(f, "(fp2fp {a})"), + Op::PfToFp(a) => write!(f, "(pf2fp {a})"), + Op::FpToPf(a) => write!(f, "(fp2pf {a})"), Op::PfUnOp(a) => write!(f, "{a}"), Op::PfNaryOp(a) => write!(f, "{a}"), Op::PfDiv => write!(f, "/"), diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 09d90dad6..3d901bcd1 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -26,6 +26,7 @@ pub use circ_hc::{Node, Table, Weak}; use circ_opt::FieldToBv; use fxhash::{FxHashMap, FxHashSet}; use log::debug; +use num_traits::Float; use rug::Integer; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::borrow::Borrow; @@ -119,6 +120,11 @@ pub enum Op { // dest width /// translate the number represented by the argument to a floating-point value of this width. FpToFp(usize), + /// translate the prime-field number represented by the argument to a floating-point value + /// of this width. + PfToFp(usize), + /// Floating-point value to Field + FpToPf(FieldT), /// Prime-field unary operator PfUnOp(PfUnOp), @@ -333,6 +339,52 @@ pub const INT_LE: Op = Op::IntBinPred(IntBinPred::Le); pub const INT_GT: Op = Op::IntBinPred(IntBinPred::Gt); /// integer greater than or equal pub const INT_GE: Op = Op::IntBinPred(IntBinPred::Ge); +/// floating-point addition +pub const FP_ADD: Op = Op::FpBinOp(FpBinOp::Add); +/// floating-point multiplication +pub const FP_MUL: Op = Op::FpBinOp(FpBinOp::Mul); +/// floating-point subtraction +pub const FP_SUB: Op = Op::FpBinOp(FpBinOp::Sub); +/// floating-point division +pub const FP_DIV: Op = Op::FpBinOp(FpBinOp::Div); +/// floating-point remainder +pub const FP_REM: Op = Op::FpBinOp(FpBinOp::Rem); +/// floating-point maximum +pub const FP_MAX: Op = Op::FpBinOp(FpBinOp::Max); +/// floating-point minimum +pub const FP_MIN: Op = Op::FpBinOp(FpBinOp::Min); +/// floating-point less than or equal +pub const FP_LE: Op = Op::FpBinPred(FpBinPred::Le); +/// floating-point less than +pub const FP_LT: Op = Op::FpBinPred(FpBinPred::Lt); +/// floating-point equal to +pub const FP_EQ: Op = Op::FpBinPred(FpBinPred::Eq); +/// floating-point greater than or equal +pub const FP_GE: Op = Op::FpBinPred(FpBinPred::Ge); +/// floating-point greater than +pub const FP_GT: Op = Op::FpBinPred(FpBinPred::Gt); +/// floating-point is normal +pub const FP_IS_NORM: Op = Op::FpUnPred(FpUnPred::Normal); +/// floating-point is subnormal +pub const FP_IS_SUBNORM: Op = Op::FpUnPred(FpUnPred::Subnormal); +/// floating-point is zero +pub const FP_IS_ZERO: Op = Op::FpUnPred(FpUnPred::Zero); +/// floating-point is infinite +pub const FP_IS_INF: Op = Op::FpUnPred(FpUnPred::Infinite); +/// floating-point is not-a-number +pub const FP_IS_NAN: Op = Op::FpUnPred(FpUnPred::Nan); +/// floating-point is negative +pub const FP_IS_NEG: Op = Op::FpUnPred(FpUnPred::Negative); +/// floating-point is positive +pub const FP_IS_POS: Op = Op::FpUnPred(FpUnPred::Positive); +/// floating-point unary negation +pub const FP_NEG: Op = Op::FpUnOp(FpUnOp::Neg); +/// floating-point absolute value +pub const FP_ABS: Op = Op::FpUnOp(FpUnOp::Abs); +/// floating-point square root +pub const FP_SQRT: Op = Op::FpUnOp(FpUnOp::Sqrt); +/// floating-point round +pub const FP_ROUND: Op = Op::FpUnOp(FpUnOp::Round); impl Op { /// Number of arguments for this operator. `None` if n-ary. @@ -365,6 +417,8 @@ impl Op { Op::UbvToFp(_) => Some(1), Op::SbvToFp(_) => Some(1), Op::FpToFp(_) => Some(1), + Op::PfToFp(_) => Some(1), + Op::FpToPf(_) => Some(1), Op::PfUnOp(_) => Some(1), Op::PfDiv => Some(2), Op::PfNaryOp(_) => None, @@ -1263,6 +1317,21 @@ impl Term { None } } + /// Get the underlying 32-bit floating-point constant, if possible. + pub fn as_f32_opt(&self) -> Option { + if let Some(Value::F32(v)) = self.as_value_opt() { + Some(*v) + } else { + None + } + } + /// Get the underlying 64-bit floating-point constant, if possible. + pub fn as_f64_opt(&self) -> Option { + match self.as_value_opt()? { + Value::F64(v) => Some(*v), + _ => None, + } + } /// Get the underlying prime field constant, if possible. pub fn as_pf_opt(&self) -> Option<&FieldV> { if let Some(Value::Field(b)) = self.as_value_opt() { @@ -1381,6 +1450,23 @@ impl Value { } } #[track_caller] + /// Get the underlying 32-bit floating-point constant, or panic! + pub fn as_f32(&self) -> f32 { + if let Value::F32(v) = self { + *v + } else { + panic!("Not a f32: {}", self) + } + } + #[track_caller] + /// Get the underlying 64-bit floating-point constant, or panic! + pub fn as_f64(&self) -> f64 { + match self { + Value::F64(v) => *v, + _ => panic!("Not a f64 or f32: {}", self), + } + } + #[track_caller] /// Get the underlying prime field constant, if possible. pub fn as_pf(&self) -> &FieldV { if let Value::Field(b) = self { diff --git a/src/ir/term/test.rs b/src/ir/term/test.rs index ecfb8c246..dc282ce9e 100644 --- a/src/ir/term/test.rs +++ b/src/ir/term/test.rs @@ -634,3 +634,766 @@ fn pf2bool_eval() { let expected_output = text::parse_value_map(b"(let ((output false)) false)"); assert_eq!(&actual_output, expected_output.get("output").unwrap()); } + +#[test] +fn fpsub_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp3 #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp3.333333f32 #fp0.191741f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.141592f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp3.333333333333333 #fp0.19174067974354)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.141592653589793") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp-3.402823466e+38f32 #fp2e31f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inff32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp-1.7976931348623158e+308 #fp1e292)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); +} + +#[test] +fn fpadd_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1 #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1.1111 #fp-1.1111)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1e-324 #fp-1e-324)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp3.402823466e+38f32 #fp2e31f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInff32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1.7976931348623158e+308 #fp1e+292)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fpInf #fp1)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp-Inf #fp1)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert!(const_(eval( + &text::parse_term(b"(fpadd #fpInff32 #fp-Inff32)"), + &FxHashMap::default() + )) + .as_f32_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fpadd #fpNaN #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fpadd #fpNaN #fpNaN)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1.0000001f32 #fp1e-7f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0000002f32") + ); +} + +#[test] +fn fpmul_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmul #fp3 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert!(const_(eval( + &text::parse_term(b"(fpmul #fp-Inf #fp0)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmul #fpInf #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmul #fpInf #fp-1)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert!(const_(eval( + &text::parse_term(b"(fpmul #fpNaN #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fpmul #fpNaN #fpNaN)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); +} + +#[test] +fn fpdiv_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpdiv #fp1 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpdiv #fp-1 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert!(const_(eval( + &text::parse_term(b"(fpdiv #fpNaN #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); +} + +#[test] +fn fprem_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fprem #fp3.1415 #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0.14150000000000018") + ); + assert!(const_(eval( + &text::parse_term(b"(fprem #fp3 #fp0)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fprem #fpInf #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); +} + +#[test] +fn fpmax_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmax #fp2 #fp5)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp5") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmax #fpNaN #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmax #fp-0 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-0") // interesting + ); +} + +#[test] +fn fpmin_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmin #fp-3 #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-3") + ); +} + +#[test] +fn fpneg_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp-4.5)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp4.5") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fpInf)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp-Inf)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert!(const_(eval( + &text::parse_term(b"(fpneg #fpNan)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); +} + +#[test] +fn fpabs_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpabs #fp-7)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp7") + ); +} + +#[test] +fn fpsqrt_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsqrt #fp12.25)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.5") + ); + assert!(const_(eval( + &text::parse_term(b"(fpsqrt #fp-4)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); +} + +#[test] +fn fpround_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpround #fp2.5)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpround #fp2.3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp2") + ); + assert!(const_(eval( + &text::parse_term(b"(fpround #fpNan)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); +} + +#[test] +fn fpge_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fp5 #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fp3 #fp5)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fp4 #fp4)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fpNaN #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpgt_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpgt #fp4 #fp4)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpgt #fpInf #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fple_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fple #fpNaN #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fplt_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fplt #fp-Inf #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fpeq_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpeq #fp1.0 #fp1)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpeq (fpadd #fp0.1 #fp0.2) #fp0.3)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpeq #fpNaN #fpNaN)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpnormal_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnormal #fp0.01)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnormal #fp1.17549435e-39f32)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpsubnormal_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsubnormal #fp0.01)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsubnormal #fp1.17549435e-39f32)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fpzero_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpzero #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpzero #fp-0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpzero #fp1e-45f32)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpinfinite_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fpInf)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fp-Inf)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fp1.7976931348623158e+308)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fp1.7976931348623159e+308)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fpnan_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnan #fpNaN)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnan #fpInf)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnan #fp3.14)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpnegative_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnegative #fp-2.5)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnegative #fp-0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fppositive_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fppositive #fp3.7)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fppositive #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn bv2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b00111111100000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b00000000000000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term( + b"(bv2fp #b0011111111110000000000000000000000000000000000000000000000000000)" + ), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0") + ); + assert_eq!( + const_(eval( + &text::parse_term( + b"(bv2fp #b0000000000000000000000000000000000000000000000000000000000000000)" + ), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert_eq!( + const_(eval( + &text::parse_term( + b"(bv2fp #b1000000000000000000000000000000000000000000000000000000000000000)" + ), + &FxHashMap::default() + )), + text::parse_term(b"#fp-0") + ); +} + +#[test] +fn ubv2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"((ubv2fp 32) #x0001)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((ubv2fp 64) #x0001)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((ubv2fp 64) #xffffffff)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp4294967295.0") + ); +} + +#[test] +fn sbv2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"((sbv2fp 64) #xffffffffffffffff)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-1.0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((sbv2fp 64) #x7fffffff)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp2147483647.0") + ); +} + +#[test] +fn fp2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"((fp2fp 64) #fp1.0f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f64") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((fp2fp 32) #fp16777217.0f64)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp16777216.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((fp2fp 32) #fp2.5f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp2.5f32") + ); +} + +#[test] +fn pf2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f1))"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f20))"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f17))"), + &FxHashMap::default() + )), + text::parse_term(b"#fp17f32") // shouldn't this be 0? + ); + let p: Integer = Integer::from(1) << (512 - 1); // Mersenne prime + let overflow_f32: Integer = p.clone() - 1; + assert_eq!( + const_(eval( + &text::parse_term( + format!("(set_default_modulus {p} ((pf2fp 32) #f{overflow_f32}))").as_bytes() + ), + &FxHashMap::default() + )), + text::parse_term(b"#fpInff32") + ); + assert_eq!( + const_(eval( + &text::parse_term( + format!("(set_default_modulus {p} ((pf2fp 64) #f{overflow_f32}))").as_bytes() + ), + &FxHashMap::default() + )), + text::parse_term(b"#fp6.703903964971298e+153") + ); +} diff --git a/src/ir/term/text/lex.rs b/src/ir/term/text/lex.rs index 7905cc839..5056cfb9e 100644 --- a/src/ir/term/text/lex.rs +++ b/src/ir/term/text/lex.rs @@ -19,7 +19,8 @@ pub enum Token { Bin, #[regex(br"#f-?[0-9]+(m[0-9]+)?")] Field, - // TODO: Float + #[regex(br"#fp(?i:nan|-?inf|-?\d+(?:\.\d+)?)?(?:[eE][+-]?\d+)?(f32|f64)?")] + Float, // Identifiers #[regex(br"#t|#a|#l|#m|[^()0-9#; \t\n\f][^(); \t\n\f#]*")] @@ -43,7 +44,7 @@ mod test { #[test] fn all_tokens() { - let l = Token::lexer(b"(let ((a true)(b true)) (add (sub #b01 #xFf) (div 15 -16)))"); + let l = Token::lexer(b"(let ((a true)(b true)) (add (sub #b01 #xFf) (div 15 -16)) (add #fpnan #fp0) (add #fp-inf #fpinf) (add #fp31.415926e-1f32 #fp0.31415926e+1f64))"); let tokens: Vec<_> = l.into_iter().collect(); assert_eq!( &tokens, @@ -73,6 +74,21 @@ mod test { Token::Int, Token::Close, Token::Close, + Token::Open, + Token::Ident, + Token::Float, + Token::Float, + Token::Close, + Token::Open, + Token::Ident, + Token::Float, + Token::Float, + Token::Close, + Token::Open, + Token::Ident, + Token::Float, + Token::Float, + Token::Close, Token::Close, ] ) diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 58c81030a..48f2fb91b 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -50,6 +50,8 @@ //! * bit-vector: `#xFFFF...`, `#bBBBB...` //! * field literal: `#fDD` or `#fDDmDD`. //! * In the former case, an ambient modulus must be set. +//! * floating-point literals: `#fpNan`, `#fpInf`, `#fpDD.DD` +//! * In all cases, a precision can be specified by appending the suffix `f32` or `f64`. //! * tuple: `(#t V1 ... Vn)` //! * array: `(#a Sk V N ((Vk1 Vv1) ... (Vkn Vvn)))` //! * list: `(#l Sk (V1 ... Vn))` @@ -316,6 +318,7 @@ impl<'src> IrInterp<'src> { [Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))), [Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))), [Leaf(Ident, b"fp2fp"), a] => Ok(Op::FpToFp(self.usize(a))), + [Leaf(Ident, b"pf2fp"), a] => Ok(Op::PfToFp(self.usize(a))), [Leaf(Ident, b"challenge"), name, field] => Ok(Op::new_chall( self.ident_string(name), FieldT::from(self.int(field)), @@ -413,6 +416,44 @@ impl<'src> IrInterp<'src> { _ => panic!("Expected integer, got {}", tt), } } + + /// Parse this text as a [Value::F32] or [Value::F64] literal. + fn parse_fp_literal(&mut self, lit: &str) -> Value { + let _lit = lit.to_lowercase(); + + if _lit == "inf" { + return Value::F64(f64::INFINITY); + } + if _lit == "-inf" { + return Value::F64(f64::NEG_INFINITY); + } + if _lit == "nan" { + return Value::F64(f64::NAN); + } + + // Parse as F32 when "f32" is found + if let Some(end) = _lit.find("f32") { + let num_part = &lit[..end]; + let val = num_part + .parse::() + .unwrap_or_else(|_| panic!("Invalid F32 literal '{}'", lit)); + return Value::F32(val); + } + // Parse as F64 when "f64" is found + else if let Some(end) = _lit.find("f64") { + let num_part = &lit[..end]; + let val = num_part + .parse::() + .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); + return Value::F64(val); + } + // Default: parse as F64 + let val = lit + .parse::() + .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); + Value::F64(val) + } + fn usize(&self, tt: &TokTree) -> usize { self.maybe_usize(tt).unwrap() } @@ -519,6 +560,10 @@ impl<'src> IrInterp<'src> { }; pf_lit(FieldV::new::(v, m)) } + Leaf(Token::Float, bytes) => { + let lit = std::str::from_utf8(&bytes[3..]).unwrap(); + const_(self.parse_fp_literal(lit)) + } Leaf(Ident, b"false") => bool_lit(false), Leaf(Ident, b"true") => bool_lit(true), Leaf(Ident, n) => self.get_binding(n).clone(), diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index 005ce0504..4768e427b 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -55,6 +55,8 @@ fn check_dependencies(t: &Term) -> Vec { Op::UbvToFp(_) => Vec::new(), Op::SbvToFp(_) => Vec::new(), Op::FpToFp(_) => Vec::new(), + Op::PfToFp(_) => Vec::new(), + Op::FpToPf(_) => Vec::new(), Op::PfUnOp(_) => vec![t.cs()[0].clone()], Op::PfDiv => vec![t.cs()[0].clone()], Op::PfNaryOp(_) => vec![t.cs()[0].clone()], @@ -139,6 +141,9 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result { Op::SbvToFp(32) => Ok(Sort::F32), Op::FpToFp(64) => Ok(Sort::F64), Op::FpToFp(32) => Ok(Sort::F32), + Op::PfToFp(64) => Ok(Sort::F64), + Op::PfToFp(32) => Ok(Sort::F32), + Op::FpToPf(m) => Ok(Sort::Field(m.clone())), Op::PfUnOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::PfDiv => Ok(get_ty(&t.cs()[0]).clone()), Op::PfNaryOp(_) => Ok(get_ty(&t.cs()[0]).clone()), @@ -361,6 +366,9 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result bv_or(a, "sbv-to-fp").map(|_| Sort::F32), (Op::FpToFp(64), &[a]) => fp_or(a, "fp-to-fp").map(|_| Sort::F64), (Op::FpToFp(32), &[a]) => fp_or(a, "fp-to-fp").map(|_| Sort::F32), + (Op::PfToFp(64), &[a]) => pf_or(a, "pf-to-fp").map(|_| Sort::F64), + (Op::PfToFp(32), &[a]) => pf_or(a, "pf-to-fp").map(|_| Sort::F32), + (Op::FpToPf(m), &[a]) => fp_or(a, "fp-to-pf").map(|_| Sort::Field(m.clone())), (Op::PfNaryOp(_), a) => { let ctx = "pf nary op"; all_eq_or(a.iter().cloned(), ctx) diff --git a/src/target/r1cs/mod.rs b/src/target/r1cs/mod.rs index ba253f2cb..4c0f464b6 100644 --- a/src/target/r1cs/mod.rs +++ b/src/target/r1cs/mod.rs @@ -168,7 +168,7 @@ impl Var { VarType::Chall => 0b011, VarType::FinalWit => 0b100, }; - Var(ty_repr << Self::NUMBER_BITS | number) + Var((ty_repr << Self::NUMBER_BITS) | number) } fn ty(&self) -> VarType { match self.0 >> Self::NUMBER_BITS {