From f71704b564a0c9330dac899b8ab0710e9dc1adff Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Mon, 17 Feb 2025 19:23:24 +0100 Subject: [PATCH 01/10] Added floats to the IR textual format --- src/ir/term/text/lex.rs | 20 +++++++++++++++++-- src/ir/term/text/mod.rs | 43 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/src/ir/term/text/lex.rs b/src/ir/term/text/lex.rs index 7905cc839..5056cfb9e 100644 --- a/src/ir/term/text/lex.rs +++ b/src/ir/term/text/lex.rs @@ -19,7 +19,8 @@ pub enum Token { Bin, #[regex(br"#f-?[0-9]+(m[0-9]+)?")] Field, - // TODO: Float + #[regex(br"#fp(?i:nan|-?inf|-?\d+(?:\.\d+)?)?(?:[eE][+-]?\d+)?(f32|f64)?")] + Float, // Identifiers #[regex(br"#t|#a|#l|#m|[^()0-9#; \t\n\f][^(); \t\n\f#]*")] @@ -43,7 +44,7 @@ mod test { #[test] fn all_tokens() { - let l = Token::lexer(b"(let ((a true)(b true)) (add (sub #b01 #xFf) (div 15 -16)))"); + let l = Token::lexer(b"(let ((a true)(b true)) (add (sub #b01 #xFf) (div 15 -16)) (add #fpnan #fp0) (add #fp-inf #fpinf) (add #fp31.415926e-1f32 #fp0.31415926e+1f64))"); let tokens: Vec<_> = l.into_iter().collect(); assert_eq!( &tokens, @@ -73,6 +74,21 @@ mod test { Token::Int, Token::Close, Token::Close, + Token::Open, + Token::Ident, + Token::Float, + Token::Float, + Token::Close, + Token::Open, + Token::Ident, + Token::Float, + Token::Float, + Token::Close, + Token::Open, + Token::Ident, + Token::Float, + Token::Float, + Token::Close, Token::Close, ] ) diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 58c81030a..3ae5195d5 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -50,6 +50,8 @@ //! * bit-vector: `#xFFFF...`, `#bBBBB...` //! * field literal: `#fDD` or `#fDDmDD`. //! * In the former case, an ambient modulus must be set. +//! * floating-point literals: `#fpNan`, `#fpInf`, `#fpDD.DD` +//! * In all cases, a precision can be specified as a suffix `f32` or `f64`. //! * tuple: `(#t V1 ... Vn)` //! * array: `(#a Sk V N ((Vk1 Vv1) ... (Vkn Vvn)))` //! * list: `(#l Sk (V1 ... Vn))` @@ -413,6 +415,38 @@ impl<'src> IrInterp<'src> { _ => panic!("Expected integer, got {}", tt), } } + + /// Parse this text as a double-precision floating-point number and return the value + /// with a boolean indicating whether it was marked as 'f32'. + fn parse_fp_literal(&mut self, lit: &str) -> (f64, bool) { + let _lit = lit.to_lowercase(); + if _lit == "inf" { + return (f64::INFINITY, false); + } + if _lit == "-inf" { + return (-f64::INFINITY, false); + } + if _lit == "nan" { + return (f64::NAN, false); + } + // Parse as f32 or (by default) as f64 + if let Some(end) = _lit.find("f32") { + let num_part = &lit[..end]; + let val = num_part.parse::() + .unwrap_or_else(|_| panic!("Invalid F32 literal '{}'", lit)); + return (val, true); + } else if let Some(end) = _lit.find("f64") { + let num_part = &lit[..end]; + let val = num_part.parse::() + .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); + return (val, false); + } else { + let val = lit.parse::() + .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); + (val, false) + } + } + fn usize(&self, tt: &TokTree) -> usize { self.maybe_usize(tt).unwrap() } @@ -519,6 +553,15 @@ impl<'src> IrInterp<'src> { }; pf_lit(FieldV::new::(v, m)) } + Leaf(Token::Float, bytes) => { + let lit = std::str::from_utf8(&bytes[3..]).unwrap(); + let (f64_val, is_f32) = self.parse_fp_literal(lit); + if is_f32 { + const_(Value::F32(f64_val as f32)) + } else { + const_(Value::F64(f64_val)) + } + } Leaf(Ident, b"false") => bool_lit(false), Leaf(Ident, b"true") => bool_lit(true), Leaf(Ident, n) => self.get_binding(n).clone(), From 748e1234f983b4b4797c26bc0e884e79bf15333d Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Tue, 18 Feb 2025 08:45:03 +0100 Subject: [PATCH 02/10] Introduced Op::PfToFp, added placeholders for FP ops in IR eval --- src/ir/term/eval.rs | 11 +++++++++++ src/ir/term/fmt.rs | 1 + src/ir/term/mod.rs | 4 ++++ src/ir/term/text/mod.rs | 3 ++- src/ir/term/ty.rs | 5 +++++ 5 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index 3e0a010a8..2f91a112f 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -165,6 +165,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> } }), Op::BoolToBv => Value::BitVector(BitVector::new(Integer::from(args[0].as_bool()), 1)), + Op::PfUnOp(o) => Value::Field({ let a = args[0].as_pf().clone(); match o { @@ -216,6 +217,16 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> ) }), + Op::FpBinOp(o) => unimplemented!("Op::FpBinOp({o}) not implemented"), + Op::FpBinPred(o) => unimplemented!("Op::FpBinPred({o}) not implemented"), + Op::FpUnPred(o) => unimplemented!("Op::FpUnPred({o}) not implemented"), + Op::FpUnOp(o) => unimplemented!("Op::FpUnOp({o}) not implemented"), + Op::BvToFp => unimplemented!("Op::BvToFp not implemented"), + Op::UbvToFp(w) => unimplemented!("Op::UbvToFp({w}) not implemented"), + Op::SbvToFp(w) => unimplemented!("Op::SbvToFp({w}) not implemented"), + Op::FpToFp(w) => unimplemented!("Op::FpToFp({w}) not implemented"), + Op::PfToFp(w) => unimplemented!("Op::PfToFp({w}) not implemented"), + Op::IntBinOp(o) => Value::Int({ let a = args[0].as_int().clone(); let b = args[1].as_int().clone(); diff --git a/src/ir/term/fmt.rs b/src/ir/term/fmt.rs index 7575a37b6..436b59da4 100644 --- a/src/ir/term/fmt.rs +++ b/src/ir/term/fmt.rs @@ -364,6 +364,7 @@ impl DisplayIr for Op { Op::UbvToFp(a) => write!(f, "(ubv2fp {a})"), Op::SbvToFp(a) => write!(f, "(sbv2fp {a})"), Op::FpToFp(a) => write!(f, "(fp2fp {a})"), + Op::PfToFp(a) => write!(f, "(pf2fp {a})"), Op::PfUnOp(a) => write!(f, "{a}"), Op::PfNaryOp(a) => write!(f, "{a}"), Op::PfDiv => write!(f, "/"), diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 09d90dad6..a108a6e37 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -119,6 +119,9 @@ pub enum Op { // dest width /// translate the number represented by the argument to a floating-point value of this width. FpToFp(usize), + /// translate the prime-field number represented by the argument to a floating-point value + /// of this width. + PfToFp(usize), /// Prime-field unary operator PfUnOp(PfUnOp), @@ -365,6 +368,7 @@ impl Op { Op::UbvToFp(_) => Some(1), Op::SbvToFp(_) => Some(1), Op::FpToFp(_) => Some(1), + Op::PfToFp(_) => Some(1), Op::PfUnOp(_) => Some(1), Op::PfDiv => Some(2), Op::PfNaryOp(_) => None, diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 3ae5195d5..d7db5ddc9 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -51,7 +51,7 @@ //! * field literal: `#fDD` or `#fDDmDD`. //! * In the former case, an ambient modulus must be set. //! * floating-point literals: `#fpNan`, `#fpInf`, `#fpDD.DD` -//! * In all cases, a precision can be specified as a suffix `f32` or `f64`. +//! * In all cases, a precision can be specified by appending the suffix `f32` or `f64`. //! * tuple: `(#t V1 ... Vn)` //! * array: `(#a Sk V N ((Vk1 Vv1) ... (Vkn Vvn)))` //! * list: `(#l Sk (V1 ... Vn))` @@ -318,6 +318,7 @@ impl<'src> IrInterp<'src> { [Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))), [Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))), [Leaf(Ident, b"fp2fp"), a] => Ok(Op::FpToFp(self.usize(a))), + [Leaf(Ident, b"pf2fp"), a] => Ok(Op::PfToFp(self.usize(a))), [Leaf(Ident, b"challenge"), name, field] => Ok(Op::new_chall( self.ident_string(name), FieldT::from(self.int(field)), diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index 005ce0504..a33135c38 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -55,6 +55,7 @@ fn check_dependencies(t: &Term) -> Vec { Op::UbvToFp(_) => Vec::new(), Op::SbvToFp(_) => Vec::new(), Op::FpToFp(_) => Vec::new(), + Op::PfToFp(_) => Vec::new(), Op::PfUnOp(_) => vec![t.cs()[0].clone()], Op::PfDiv => vec![t.cs()[0].clone()], Op::PfNaryOp(_) => vec![t.cs()[0].clone()], @@ -139,6 +140,8 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result { Op::SbvToFp(32) => Ok(Sort::F32), Op::FpToFp(64) => Ok(Sort::F64), Op::FpToFp(32) => Ok(Sort::F32), + Op::PfToFp(64) => Ok(Sort::F64), + Op::PfToFp(32) => Ok(Sort::F32), Op::PfUnOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::PfDiv => Ok(get_ty(&t.cs()[0]).clone()), Op::PfNaryOp(_) => Ok(get_ty(&t.cs()[0]).clone()), @@ -361,6 +364,8 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result bv_or(a, "sbv-to-fp").map(|_| Sort::F32), (Op::FpToFp(64), &[a]) => fp_or(a, "fp-to-fp").map(|_| Sort::F64), (Op::FpToFp(32), &[a]) => fp_or(a, "fp-to-fp").map(|_| Sort::F32), + (Op::PfToFp(64), &[a]) => pf_or(a, "pf-to-fp").map(|_| Sort::F64), + (Op::PfToFp(32), &[a]) => pf_or(a, "pf-to-fp").map(|_| Sort::F32), (Op::PfNaryOp(_), a) => { let ctx = "pf nary op"; all_eq_or(a.iter().cloned(), ctx) From f87ca5cbcebf9de3f614a0508e7866aca4cb08b5 Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Fri, 21 Feb 2025 11:11:20 +0100 Subject: [PATCH 03/10] Implemented FP Ops and added tests --- Cargo.lock | 3 +- Cargo.toml | 1 + src/ir/term/eval.rs | 126 ++++++- src/ir/term/mod.rs | 36 ++ src/ir/term/test.rs | 731 ++++++++++++++++++++++++++++++++++++++++ src/ir/term/text/mod.rs | 40 ++- 6 files changed, 903 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ca04c24ad..e94ccd287 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addchain" @@ -344,6 +344,7 @@ dependencies = [ "logos", "lp-solvers 0.0.4", "merlin", + "num-traits", "once_cell", "pairing", "paste", diff --git a/Cargo.toml b/Cargo.toml index 6d5113a4c..c2f4f85de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ curve25519-dalek = {version = "3.2.0", features = ["serde"], optional = true} paste = "1.0" im = "15" once_cell = "1" +num-traits = "0.2" [dev-dependencies] quickcheck = "1" diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index 2f91a112f..f35319b5c 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -3,7 +3,7 @@ use super::{ check, const_, extras, term, Array, BitVector, BoolNaryOp, BvBinOp, BvBinPred, BvNaryOp, BvUnOp, FieldToBv, FxHashMap, IntBinOp, IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op, - PfNaryOp, PfUnOp, Sort, Term, TermMap, Value, + PfNaryOp, PfUnOp, Sort, Term, TermMap, Value, Float, FpBinOp, FpBinPred, FpUnPred, FpUnOp, }; use crate::cfg::cfg_or_default; @@ -216,17 +216,6 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> }, ) }), - - Op::FpBinOp(o) => unimplemented!("Op::FpBinOp({o}) not implemented"), - Op::FpBinPred(o) => unimplemented!("Op::FpBinPred({o}) not implemented"), - Op::FpUnPred(o) => unimplemented!("Op::FpUnPred({o}) not implemented"), - Op::FpUnOp(o) => unimplemented!("Op::FpUnOp({o}) not implemented"), - Op::BvToFp => unimplemented!("Op::BvToFp not implemented"), - Op::UbvToFp(w) => unimplemented!("Op::UbvToFp({w}) not implemented"), - Op::SbvToFp(w) => unimplemented!("Op::SbvToFp({w}) not implemented"), - Op::FpToFp(w) => unimplemented!("Op::FpToFp({w}) not implemented"), - Op::PfToFp(w) => unimplemented!("Op::PfToFp({w}) not implemented"), - Op::IntBinOp(o) => Value::Int({ let a = args[0].as_int().clone(); let b = args[1].as_int().clone(); @@ -250,6 +239,119 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> IntUnOp::Neg => -a, } }), + Op::FpBinOp(o) => { + // Promote to f64 if either operand is f64 + let promote_to_f64 = matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); + fn comp(a: T, b: T, op: FpBinOp) -> T { + match op { + FpBinOp::Add => a + b, + FpBinOp::Sub => a - b, + FpBinOp::Mul => a * b, + FpBinOp::Div => a / b, + FpBinOp::Rem => a % b, + FpBinOp::Max => a.max(b), + FpBinOp::Min => a.min(b), + } + } + if promote_to_f64 { + Value::F64(comp(args[0].as_f64(), args[1].as_f64(), *o)) + } else { + Value::F32(comp(args[0].as_f32(), args[1].as_f32(), *o)) + } + } + Op::FpBinPred(o) => Value::Bool({ + // Promote to f64 if either operand is f64 + let promote_to_f64 = matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); + fn comp(a: T, b: T, op: FpBinPred) -> bool { + match op { + FpBinPred::Le => a <= b, + FpBinPred::Lt => a < b, + FpBinPred::Eq => a == b, + FpBinPred::Ge => a >= b, + FpBinPred::Gt => a > b, + } + } + if promote_to_f64 { + comp(args[0].as_f64(), args[1].as_f64(), *o) + } else { + comp(args[0].as_f32(), args[1].as_f32(), *o) + } + }), + Op::FpUnPred(o) => Value::Bool({ + fn comp(a: T, op: FpUnPred) -> bool { + match op { + FpUnPred::Normal => a.is_normal(), + FpUnPred::Subnormal => a.is_subnormal(), + FpUnPred::Zero => a == T::zero(), + FpUnPred::Infinite => a.is_infinite(), + FpUnPred::Nan => a.is_nan(), + FpUnPred::Negative => a.is_sign_negative(), + FpUnPred::Positive => a.is_sign_positive(), + } + } + match args[0] { + Value::F32(a) => comp(*a, *o), + Value::F64(a) => comp(*a, *o), + _ => panic!("Expected F32 or F64, got {}", args[0]), + } + }), + Op::FpUnOp(o) => { + fn comp(a: T, op: FpUnOp) -> T { + match op { + FpUnOp::Neg => -a, + FpUnOp::Abs => a.abs(), + FpUnOp::Sqrt => a.sqrt(), + FpUnOp::Round => a.round(), + } + } + match args[0] { + Value::F32(a) => Value::F32(comp(*a, *o)), + Value::F64(a) => Value::F64(comp(*a, *o)), + _ => panic!("Expected F32 or F64, got {}", args[0]), + } + } + Op::BvToFp => { + let bv = args[0].as_bv(); + let val = bv.uint(); + let w = bv.width(); + match w { + 32 => Value::F32(f32::from_bits(val.to_u32().unwrap())), + 64 => Value::F64(f64::from_bits(val.to_u64().unwrap())), + _ => panic!("{} out of bounds for {} on {:?}", w, op, args), + } + } + Op::UbvToFp(w) => { + let val = args[0].as_bv().uint(); + match w { + 0..=32 => Value::F32(val.to_f32()), + 33..=64 => Value::F64(val.to_f64()), + _ => panic!("{} out of bounds for {} on {:?}", w, op, args), + } + } + Op::SbvToFp(w) => { + let val = args[0].as_bv().as_sint(); + match w { + 0..=32 => Value::F32(val.to_f32()), + 33..=64 => Value::F64(val.to_f64()), + _ => panic!("{} out of bounds for {} on {:?}", w, op, args), + } + } + Op::FpToFp(w) => { + match (args[0], w) { + (Value::F32(v), 64) => Value::F64(*v as f64), // Promote F32 to F64 + (Value::F64(v), 32) => Value::F32(*v as f32), // Truncate F64 to F32 + (Value::F32(_), 32) | (Value::F64(_), 64) => args[0].clone(), + _ => panic!("Invalid conversion width {} (expected 32 or 64)", w), + } + } + Op::PfToFp(w) => { + let val = args[0].as_pf().i(); + match w { + 32 => Value::F32(val.to_f32()), + 64 => Value::F64(val.to_f64()), + _ => panic!("{} out of bounds for {} on {:?} (expected 32 or 64)", w, op, args), + } + } Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())), Op::PfChallenge(c) => Value::Field(eval_pf_challenge(&c.name, &c.field)), Op::Witness(_) => args[0].clone(), diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index a108a6e37..90f23ff2a 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -27,6 +27,7 @@ use circ_opt::FieldToBv; use fxhash::{FxHashMap, FxHashSet}; use log::debug; use rug::Integer; +use num_traits::Float; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::borrow::Borrow; use std::cell::Cell; @@ -336,6 +337,7 @@ pub const INT_LE: Op = Op::IntBinPred(IntBinPred::Le); pub const INT_GT: Op = Op::IntBinPred(IntBinPred::Gt); /// integer greater than or equal pub const INT_GE: Op = Op::IntBinPred(IntBinPred::Ge); +// TODO: add floating-point operator abbreviations impl Op { /// Number of arguments for this operator. `None` if n-ary. @@ -1267,6 +1269,22 @@ impl Term { None } } + /// Get the underlying 32-bit floating-point constant, if possible. + pub fn as_f32_opt(&self) -> Option { + if let Some(Value::F32(v)) = self.as_value_opt() { + Some(*v) + } else { + None + } + } + /// Get the underlying 64-bit floating-point constant, if possible. + pub fn as_f64_opt(&self) -> Option { + match self.as_value_opt()? { + Value::F64(v) => Some(*v), + Value::F32(v) => Some(*v as f64), // Floating-point promotion + _ => None, + } + } /// Get the underlying prime field constant, if possible. pub fn as_pf_opt(&self) -> Option<&FieldV> { if let Some(Value::Field(b)) = self.as_value_opt() { @@ -1385,6 +1403,24 @@ impl Value { } } #[track_caller] + /// Get the underlying 32-bit floating-point constant, or panic! + pub fn as_f32(&self) -> f32 { + if let Value::F32(v) = self { + *v + } else { + panic!("Not a f32: {}", self) + } + } + #[track_caller] + /// Get the underlying 64-bit floating-point constant, or panic! + pub fn as_f64(&self) -> f64 { + match self { + Value::F32(v) => *v as f64, // Floating-point promotion + Value::F64(v) => *v, + _ => panic!("Not a f64 or f32: {}", self), + } + } + #[track_caller] /// Get the underlying prime field constant, if possible. pub fn as_pf(&self) -> &FieldV { if let Value::Field(b) = self { diff --git a/src/ir/term/test.rs b/src/ir/term/test.rs index ecfb8c246..ea26dedbb 100644 --- a/src/ir/term/test.rs +++ b/src/ir/term/test.rs @@ -634,3 +634,734 @@ fn pf2bool_eval() { let expected_output = text::parse_value_map(b"(let ((output false)) false)"); assert_eq!(&actual_output, expected_output.get("output").unwrap()); } + +#[test] +fn fpsub_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp3 #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp3.333333f32 #fp0.191741f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.141592f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp3.333333333333333 #fp0.19174067974354)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.141592653589793") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp-3.402823466e+38f32 #fp2e31f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inff32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsub #fp-1.7976931348623158e+308 #fp1e292)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); +} + +#[test] +fn fpadd_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1 #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1.1111 #fp-1.1111)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1e-324 #fp-1e-324)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp3.402823466e+38f32 #fp2e31f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInff32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1.7976931348623158e+308 #fp1e+292)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fpInf #fp1)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp-Inf #fp1)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpadd #fpInff32 #fp-Inff32)"), + &FxHashMap::default() + )).as_f32_opt().unwrap().is_nan() + ); + assert!( + const_(eval( + &text::parse_term(b"(fpadd #fpNaN #fp2)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); + assert!( + const_(eval( + &text::parse_term(b"(fpadd #fpNaN #fpNaN)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpadd #fp1.0000001f32 #fp1e-7f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0000002f32") + ); +} + +#[test] +fn fpmul_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmul #fp3 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpmul #fp-Inf #fp0)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmul #fpInf #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmul #fpInf #fp-1)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpmul #fpNaN #fp2)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); + assert!( + const_(eval( + &text::parse_term(b"(fpmul #fpNaN #fpNaN)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); +} + +#[test] +fn fpdiv_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpdiv #fp1 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpdiv #fp-1 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpdiv #fpNaN #fp2)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); +} + +#[test] +fn fprem_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fprem #fp3.1415 #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0.14150000000000018") + ); + assert!( + const_(eval( + &text::parse_term(b"(fprem #fp3 #fp0)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); + assert!( + const_(eval( + &text::parse_term(b"(fprem #fpInf #fp2)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); +} + +#[test] +fn fpmax_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmax #fp2 #fp5)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp5") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmax #fpNaN #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmax #fp-0 #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); +} + +#[test] +fn fpmin_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpmin #fp-3 #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-3") + ); +} + +#[test] +fn fpneg_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp-4.5)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp4.5") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fpInf)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-Inf") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpneg #fp-Inf)"), + &FxHashMap::default() + )), + text::parse_term(b"#fpInf") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpneg #fpNan)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); +} + +#[test] +fn fpabs_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpabs #fp-7)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp7") + ); +} + +#[test] +fn fpsqrt_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsqrt #fp12.25)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3.5") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpsqrt #fp-4)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); +} + +#[test] +fn fpround_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpround #fp2.5)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp3") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpround #fp2.3)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp2") + ); + assert!( + const_(eval( + &text::parse_term(b"(fpround #fpNan)"), + &FxHashMap::default() + )).as_f64_opt().unwrap().is_nan() + ); +} + +#[test] +fn fpge_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fp5 #fp3)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fp3 #fp5)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fp4 #fp4)"), + &FxHashMap::default())), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpge #fpNaN #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpgt_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpgt #fp4 #fp4)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fplt #fpInf #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fple_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fple #fpNaN #fp2)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fplt_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fplt #fp-Inf #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fpeq_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpeq #fp1.0 #fp1)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpeq (fpadd #fp0.1 #fp0.2) #fp0.3)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpeq #fpNaN #fpNaN)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpnormal_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnormal #fp0.01)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnormal #fp1.17549435e-39f32)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpsubnormal_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsubnormal #fp0.01)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + + assert_eq!( + const_(eval( + &text::parse_term(b"(fpsubnormal #fp1.17549435e-39f32)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fpzero_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpzero #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpzero #fp-0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpzero #fp1e-45f32)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpinfinite_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fpInf)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fp-Inf)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fp1.7976931348623158e+308)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpinfinite #fp1.7976931348623159e+308)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fpnan_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnan #fpNaN)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnan #fpInf)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnan #fp3.14)"), + &FxHashMap::default() + )), + text::parse_term(b"false") + ); +} + +#[test] +fn fpnegative_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnegative #fp-2.5)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fpnegative #fp-0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn fppositive_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(fppositive #fp3.7)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(fppositive #fp0)"), + &FxHashMap::default() + )), + text::parse_term(b"true") + ); +} + +#[test] +fn bv2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b00111111100000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b00000000000000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b0011111111110000000000000000000000000000000000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b0000000000000000000000000000000000000000000000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"(bv2fp #b1000000000000000000000000000000000000000000000000000000000000000)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-0") + ); +} + +#[test] +fn ubv2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"((ubv2fp 32) #x0001)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((ubv2fp 64) #x0001)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((ubv2fp 64) #xffffffff)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp4294967295.0") + ); +} + +#[test] +fn sbv2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"((sbv2fp 64) #xffffffffffffffff)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp-1.0") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((sbv2fp 64) #x7fffffff)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp2147483647.0") + ); +} + +#[test] +fn fp2fp_tests() { + assert_eq!( + const_(eval( + &text::parse_term(b"((fp2fp 64) #fp1.0f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp1.0f64") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((fp2fp 32) #fp16777217.0f64)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp16777216.0f32") + ); + assert_eq!( + const_(eval( + &text::parse_term(b"((fp2fp 32) #fp2.5f32)"), + &FxHashMap::default() + )), + text::parse_term(b"#fp2.5f32") + ); +} + +#[test] +fn pf2fp_tests() { + assert_eq!( + const_(eval(&text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f1))"), &FxHashMap::default())), + text::parse_term(b"#fp1.0f32") + ); + assert_eq!( + const_(eval(&text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f20))"), &FxHashMap::default())), + text::parse_term(b"#fp3.0f32") + ); + assert_eq!( + const_(eval(&text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f17))"), &FxHashMap::default())), + text::parse_term(b"#fp17f32") // shouldn't this be 0? + ); + let p: Integer = Integer::from(1) << 512 - 1; // Mersenne prime + let overflow_f32: Integer = p.clone() - 1; + assert_eq!( + const_(eval( + &text::parse_term(format!("(set_default_modulus {p} ((pf2fp 32) #f{overflow_f32}))").as_bytes()), + &FxHashMap::default() + )), + text::parse_term(b"#fpInff32") + ); + assert_eq!( + const_(eval( + &text::parse_term(format!("(set_default_modulus {p} ((pf2fp 64) #f{overflow_f32}))").as_bytes()), + &FxHashMap::default() + )), + text::parse_term(b"#fp6.703903964971298e+153") + ); +} diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index d7db5ddc9..d5086074c 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -417,35 +417,38 @@ impl<'src> IrInterp<'src> { } } - /// Parse this text as a double-precision floating-point number and return the value - /// with a boolean indicating whether it was marked as 'f32'. - fn parse_fp_literal(&mut self, lit: &str) -> (f64, bool) { + /// Parse this text as a [Value::F32] or [Value::F64] literal. + fn parse_fp_literal(&mut self, lit: &str) -> Value { let _lit = lit.to_lowercase(); + if _lit == "inf" { - return (f64::INFINITY, false); + return Value::F64(f64::INFINITY); } if _lit == "-inf" { - return (-f64::INFINITY, false); + return Value::F64(f64::NEG_INFINITY); } if _lit == "nan" { - return (f64::NAN, false); + return Value::F64(f64::NAN); } - // Parse as f32 or (by default) as f64 + + // Parse as F32 when "f32" is found if let Some(end) = _lit.find("f32") { let num_part = &lit[..end]; - let val = num_part.parse::() + let val = num_part.parse::() .unwrap_or_else(|_| panic!("Invalid F32 literal '{}'", lit)); - return (val, true); - } else if let Some(end) = _lit.find("f64") { + return Value::F32(val); + } + // Parse as F64 when "f64" is found + else if let Some(end) = _lit.find("f64") { let num_part = &lit[..end]; let val = num_part.parse::() .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); - return (val, false); - } else { - let val = lit.parse::() - .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); - (val, false) + return Value::F64(val); } + // Default: parse as F64 + let val = lit.parse::() + .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); + Value::F64(val) } fn usize(&self, tt: &TokTree) -> usize { @@ -556,12 +559,7 @@ impl<'src> IrInterp<'src> { } Leaf(Token::Float, bytes) => { let lit = std::str::from_utf8(&bytes[3..]).unwrap(); - let (f64_val, is_f32) = self.parse_fp_literal(lit); - if is_f32 { - const_(Value::F32(f64_val as f32)) - } else { - const_(Value::F64(f64_val)) - } + const_(self.parse_fp_literal(lit)) } Leaf(Ident, b"false") => bool_lit(false), Leaf(Ident, b"true") => bool_lit(true), From 632a38bebe1afede53facf9246da9c8328123f32 Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Fri, 21 Feb 2025 11:33:44 +0100 Subject: [PATCH 04/10] Apply fmt and fix linting errors --- src/ir/term/eval.rs | 16 +++- src/ir/term/mod.rs | 2 +- src/ir/term/test.rs | 196 +++++++++++++++++++++++----------------- src/ir/term/text/mod.rs | 15 +-- 4 files changed, 135 insertions(+), 94 deletions(-) diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index f35319b5c..a265087cc 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -2,8 +2,9 @@ use super::{ check, const_, extras, term, Array, BitVector, BoolNaryOp, BvBinOp, BvBinPred, BvNaryOp, - BvUnOp, FieldToBv, FxHashMap, IntBinOp, IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op, - PfNaryOp, PfUnOp, Sort, Term, TermMap, Value, Float, FpBinOp, FpBinPred, FpUnPred, FpUnOp, + BvUnOp, FieldToBv, Float, FpBinOp, FpBinPred, FpUnOp, FpUnPred, FxHashMap, IntBinOp, + IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op, PfNaryOp, PfUnOp, Sort, Term, TermMap, + Value, }; use crate::cfg::cfg_or_default; @@ -241,7 +242,8 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> }), Op::FpBinOp(o) => { // Promote to f64 if either operand is f64 - let promote_to_f64 = matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); + let promote_to_f64 = + matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); fn comp(a: T, b: T, op: FpBinOp) -> T { match op { FpBinOp::Add => a + b, @@ -261,7 +263,8 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> } Op::FpBinPred(o) => Value::Bool({ // Promote to f64 if either operand is f64 - let promote_to_f64 = matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); + let promote_to_f64 = + matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); fn comp(a: T, b: T, op: FpBinPred) -> bool { match op { FpBinPred::Le => a <= b, @@ -349,7 +352,10 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> match w { 32 => Value::F32(val.to_f32()), 64 => Value::F64(val.to_f64()), - _ => panic!("{} out of bounds for {} on {:?} (expected 32 or 64)", w, op, args), + _ => panic!( + "{} out of bounds for {} on {:?} (expected 32 or 64)", + w, op, args + ), } } Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())), diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 90f23ff2a..bb5e76414 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -26,8 +26,8 @@ pub use circ_hc::{Node, Table, Weak}; use circ_opt::FieldToBv; use fxhash::{FxHashMap, FxHashSet}; use log::debug; -use rug::Integer; use num_traits::Float; +use rug::Integer; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::borrow::Borrow; use std::cell::Cell; diff --git a/src/ir/term/test.rs b/src/ir/term/test.rs index ea26dedbb..47d97f97f 100644 --- a/src/ir/term/test.rs +++ b/src/ir/term/test.rs @@ -725,24 +725,27 @@ fn fpadd_tests() { )), text::parse_term(b"#fp-Inf") ); - assert!( - const_(eval( - &text::parse_term(b"(fpadd #fpInff32 #fp-Inff32)"), - &FxHashMap::default() - )).as_f32_opt().unwrap().is_nan() - ); - assert!( - const_(eval( - &text::parse_term(b"(fpadd #fpNaN #fp2)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); - assert!( - const_(eval( - &text::parse_term(b"(fpadd #fpNaN #fpNaN)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpadd #fpInff32 #fp-Inff32)"), + &FxHashMap::default() + )) + .as_f32_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fpadd #fpNaN #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fpadd #fpNaN #fpNaN)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); assert_eq!( const_(eval( &text::parse_term(b"(fpadd #fp1.0000001f32 #fp1e-7f32)"), @@ -761,12 +764,13 @@ fn fpmul_tests() { )), text::parse_term(b"#fp0") ); - assert!( - const_(eval( - &text::parse_term(b"(fpmul #fp-Inf #fp0)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpmul #fp-Inf #fp0)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); assert_eq!( const_(eval( &text::parse_term(b"(fpmul #fpInf #fp2)"), @@ -781,18 +785,20 @@ fn fpmul_tests() { )), text::parse_term(b"#fp-Inf") ); - assert!( - const_(eval( - &text::parse_term(b"(fpmul #fpNaN #fp2)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); - assert!( - const_(eval( - &text::parse_term(b"(fpmul #fpNaN #fpNaN)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpmul #fpNaN #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fpmul #fpNaN #fpNaN)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); } #[test] @@ -811,12 +817,13 @@ fn fpdiv_tests() { )), text::parse_term(b"#fp-Inf") ); - assert!( - const_(eval( - &text::parse_term(b"(fpdiv #fpNaN #fp2)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpdiv #fpNaN #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); } #[test] @@ -828,18 +835,20 @@ fn fprem_tests() { )), text::parse_term(b"#fp0.14150000000000018") ); - assert!( - const_(eval( - &text::parse_term(b"(fprem #fp3 #fp0)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); - assert!( - const_(eval( - &text::parse_term(b"(fprem #fpInf #fp2)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fprem #fp3 #fp0)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); + assert!(const_(eval( + &text::parse_term(b"(fprem #fpInf #fp2)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); } #[test] @@ -915,12 +924,13 @@ fn fpneg_tests() { )), text::parse_term(b"#fpInf") ); - assert!( - const_(eval( - &text::parse_term(b"(fpneg #fpNan)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpneg #fpNan)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); } #[test] @@ -943,12 +953,13 @@ fn fpsqrt_tests() { )), text::parse_term(b"#fp3.5") ); - assert!( - const_(eval( - &text::parse_term(b"(fpsqrt #fp-4)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpsqrt #fp-4)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); } #[test] @@ -967,12 +978,13 @@ fn fpround_tests() { )), text::parse_term(b"#fp2") ); - assert!( - const_(eval( - &text::parse_term(b"(fpround #fpNan)"), - &FxHashMap::default() - )).as_f64_opt().unwrap().is_nan() - ); + assert!(const_(eval( + &text::parse_term(b"(fpround #fpNan)"), + &FxHashMap::default() + )) + .as_f64_opt() + .unwrap() + .is_nan()); } #[test] @@ -994,7 +1006,8 @@ fn fpge_tests() { assert_eq!( const_(eval( &text::parse_term(b"(fpge #fp4 #fp4)"), - &FxHashMap::default())), + &FxHashMap::default() + )), text::parse_term(b"true") ); assert_eq!( @@ -1245,21 +1258,27 @@ fn bv2fp_tests() { ); assert_eq!( const_(eval( - &text::parse_term(b"(bv2fp #b0011111111110000000000000000000000000000000000000000000000000000)"), + &text::parse_term( + b"(bv2fp #b0011111111110000000000000000000000000000000000000000000000000000)" + ), &FxHashMap::default() )), text::parse_term(b"#fp1.0") ); assert_eq!( const_(eval( - &text::parse_term(b"(bv2fp #b0000000000000000000000000000000000000000000000000000000000000000)"), + &text::parse_term( + b"(bv2fp #b0000000000000000000000000000000000000000000000000000000000000000)" + ), &FxHashMap::default() )), text::parse_term(b"#fp0") ); assert_eq!( const_(eval( - &text::parse_term(b"(bv2fp #b1000000000000000000000000000000000000000000000000000000000000000)"), + &text::parse_term( + b"(bv2fp #b1000000000000000000000000000000000000000000000000000000000000000)" + ), &FxHashMap::default() )), text::parse_term(b"#fp-0") @@ -1337,29 +1356,42 @@ fn fp2fp_tests() { #[test] fn pf2fp_tests() { assert_eq!( - const_(eval(&text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f1))"), &FxHashMap::default())), + const_(eval( + &text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f1))"), + &FxHashMap::default() + )), text::parse_term(b"#fp1.0f32") ); assert_eq!( - const_(eval(&text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f20))"), &FxHashMap::default())), + const_(eval( + &text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f20))"), + &FxHashMap::default() + )), text::parse_term(b"#fp3.0f32") ); assert_eq!( - const_(eval(&text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f17))"), &FxHashMap::default())), + const_(eval( + &text::parse_term(b"(set_default_modulus 17 ((pf2fp 32) #f17))"), + &FxHashMap::default() + )), text::parse_term(b"#fp17f32") // shouldn't this be 0? ); - let p: Integer = Integer::from(1) << 512 - 1; // Mersenne prime + let p: Integer = Integer::from(1) << (512 - 1); // Mersenne prime let overflow_f32: Integer = p.clone() - 1; assert_eq!( const_(eval( - &text::parse_term(format!("(set_default_modulus {p} ((pf2fp 32) #f{overflow_f32}))").as_bytes()), + &text::parse_term( + format!("(set_default_modulus {p} ((pf2fp 32) #f{overflow_f32}))").as_bytes() + ), &FxHashMap::default() )), text::parse_term(b"#fpInff32") ); assert_eq!( const_(eval( - &text::parse_term(format!("(set_default_modulus {p} ((pf2fp 64) #f{overflow_f32}))").as_bytes()), + &text::parse_term( + format!("(set_default_modulus {p} ((pf2fp 64) #f{overflow_f32}))").as_bytes() + ), &FxHashMap::default() )), text::parse_term(b"#fp6.703903964971298e+153") diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index d5086074c..48f2fb91b 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -420,7 +420,7 @@ impl<'src> IrInterp<'src> { /// Parse this text as a [Value::F32] or [Value::F64] literal. fn parse_fp_literal(&mut self, lit: &str) -> Value { let _lit = lit.to_lowercase(); - + if _lit == "inf" { return Value::F64(f64::INFINITY); } @@ -430,23 +430,26 @@ impl<'src> IrInterp<'src> { if _lit == "nan" { return Value::F64(f64::NAN); } - + // Parse as F32 when "f32" is found if let Some(end) = _lit.find("f32") { let num_part = &lit[..end]; - let val = num_part.parse::() + let val = num_part + .parse::() .unwrap_or_else(|_| panic!("Invalid F32 literal '{}'", lit)); return Value::F32(val); - } + } // Parse as F64 when "f64" is found else if let Some(end) = _lit.find("f64") { let num_part = &lit[..end]; - let val = num_part.parse::() + let val = num_part + .parse::() .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); return Value::F64(val); } // Default: parse as F64 - let val = lit.parse::() + let val = lit + .parse::() .unwrap_or_else(|_| panic!("Invalid F64 literal '{}'", lit)); Value::F64(val) } From 68a9c7a440d4782f2d10b449e161cdc302dfd2a7 Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Fri, 21 Feb 2025 12:06:09 +0100 Subject: [PATCH 05/10] Fixed linting error in target/r1cs/mod.rs --- src/target/r1cs/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/r1cs/mod.rs b/src/target/r1cs/mod.rs index ba253f2cb..4c0f464b6 100644 --- a/src/target/r1cs/mod.rs +++ b/src/target/r1cs/mod.rs @@ -168,7 +168,7 @@ impl Var { VarType::Chall => 0b011, VarType::FinalWit => 0b100, }; - Var(ty_repr << Self::NUMBER_BITS | number) + Var((ty_repr << Self::NUMBER_BITS) | number) } fn ty(&self) -> VarType { match self.0 >> Self::NUMBER_BITS { From 1be551478c8d93960e5c41a2466f8d4aaba61049 Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Fri, 21 Feb 2025 12:27:56 +0100 Subject: [PATCH 06/10] Fixed 2 tests with minor mistakes --- src/ir/term/test.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ir/term/test.rs b/src/ir/term/test.rs index 47d97f97f..dc282ce9e 100644 --- a/src/ir/term/test.rs +++ b/src/ir/term/test.rs @@ -872,7 +872,7 @@ fn fpmax_tests() { &text::parse_term(b"(fpmax #fp-0 #fp0)"), &FxHashMap::default() )), - text::parse_term(b"#fp0") + text::parse_term(b"#fp-0") // interesting ); } @@ -1030,7 +1030,7 @@ fn fpgt_tests() { ); assert_eq!( const_(eval( - &text::parse_term(b"(fplt #fpInf #fp0)"), + &text::parse_term(b"(fpgt #fpInf #fp0)"), &FxHashMap::default() )), text::parse_term(b"true") From 7052691bedc0f312948f65d26319bf2b852631d6 Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Tue, 25 Feb 2025 11:10:31 +0100 Subject: [PATCH 07/10] Removed implicit float promotion in FpBinOp and FpBinPred --- src/ir/term/eval.rs | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index a265087cc..b4bdf61d7 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -241,9 +241,6 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> } }), Op::FpBinOp(o) => { - // Promote to f64 if either operand is f64 - let promote_to_f64 = - matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); fn comp(a: T, b: T, op: FpBinOp) -> T { match op { FpBinOp::Add => a + b, @@ -255,16 +252,20 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> FpBinOp::Min => a.min(b), } } - if promote_to_f64 { - Value::F64(comp(args[0].as_f64(), args[1].as_f64(), *o)) - } else { - Value::F32(comp(args[0].as_f32(), args[1].as_f32(), *o)) + match (args[0], args[1]) { + (Value::F32(_), Value::F32(_)) => { + Value::F32(comp(args[0].as_f32(), args[1].as_f32(), *o)) + } + (Value::F64(_), Value::F64(_)) => { + Value::F64(comp(args[0].as_f64(), args[1].as_f64(), *o)) + } + _ => panic!( + "Expected two F32 or F64, got LHS {} and RHS {}", + args[0], args[1] + ), } } Op::FpBinPred(o) => Value::Bool({ - // Promote to f64 if either operand is f64 - let promote_to_f64 = - matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_)); fn comp(a: T, b: T, op: FpBinPred) -> bool { match op { FpBinPred::Le => a <= b, @@ -274,10 +275,13 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> FpBinPred::Gt => a > b, } } - if promote_to_f64 { - comp(args[0].as_f64(), args[1].as_f64(), *o) - } else { - comp(args[0].as_f32(), args[1].as_f32(), *o) + match (args[0], args[1]) { + (Value::F32(_), Value::F32(_)) => comp(args[0].as_f32(), args[1].as_f32(), *o), + (Value::F64(_), Value::F64(_)) => comp(args[0].as_f64(), args[1].as_f64(), *o), + _ => panic!( + "Expected two F32 or F64, got LHS {} and RHS {}", + args[0], args[1] + ), } }), Op::FpUnPred(o) => Value::Bool({ From 247dea298dae564ef3af9d5db6a3467f1e3a7b7e Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Tue, 25 Feb 2025 11:26:15 +0100 Subject: [PATCH 08/10] Removed implicit conversion of F32 into F64 from as_f64_opt and as_f64 --- src/ir/term/mod.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index bb5e76414..1fab9bad1 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -1281,7 +1281,6 @@ impl Term { pub fn as_f64_opt(&self) -> Option { match self.as_value_opt()? { Value::F64(v) => Some(*v), - Value::F32(v) => Some(*v as f64), // Floating-point promotion _ => None, } } @@ -1415,7 +1414,6 @@ impl Value { /// Get the underlying 64-bit floating-point constant, or panic! pub fn as_f64(&self) -> f64 { match self { - Value::F32(v) => *v as f64, // Floating-point promotion Value::F64(v) => *v, _ => panic!("Not a f64 or f32: {}", self), } From 254228384b692d91515b619dbd92062c2fc64d0c Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Wed, 12 Mar 2025 15:22:48 +0100 Subject: [PATCH 09/10] Initial progress on compiler pass for IEEE754 * Introduced new optimization pass called Float * Extended IR with Op::FpToPf * Added new FP operations constants * visit function fully handles Op::Const * Op::Var is being worked on. Todo: * Implement all required hints and gadgets * Extend FLoatRewriter to handle lookups and table updates * Extend FloatRewriter to transform lookups into interactive proofs * Complete Op::Var case in visit fn before rest --- src/ir/opt/fp/extras.rs | 130 +++++++++++++++ src/ir/opt/fp/mod.rs | 361 ++++++++++++++++++++++++++++++++++++++++ src/ir/opt/mod.rs | 6 + src/ir/term/eval.rs | 11 ++ src/ir/term/fmt.rs | 1 + src/ir/term/mod.rs | 50 +++++- src/ir/term/ty.rs | 3 + 7 files changed, 561 insertions(+), 1 deletion(-) create mode 100644 src/ir/opt/fp/extras.rs create mode 100644 src/ir/opt/fp/mod.rs diff --git a/src/ir/opt/fp/extras.rs b/src/ir/opt/fp/extras.rs new file mode 100644 index 000000000..ac4042a41 --- /dev/null +++ b/src/ir/opt/fp/extras.rs @@ -0,0 +1,130 @@ +//! Extra algorithms for hint (precompute) and gadget (sub-circuit) constructions + +use crate::cfg::cfg_or_default as cfg; +use crate::ir::term::*; + +use super::FloatRewriter; + +/// Wrapper for pf_lit +pub fn new_pf_lit(v: I) -> Term +where + rug::Integer: From, +{ + pf_lit(cfg().field().new_v(v)) +} + +/// Computes a hint given a (mutable) computation, variable name, and hint function +/// +/// Syntax: +/// +/// * without children: `compute_hint![COMP, NAME, HINT]` +/// * with children: `compute_hint![COMP, NAME, HINT; ARG0, ARG1, ... ]` +/// +/// Returns a variable term associated to a precompute term that can be used in the computation. +#[macro_export] +macro_rules! compute_hint { + ($comp:expr, $name:expr, $hint_fn:path) => {{ + let hint = $hint_fn(); + let sort = check(&hint); + $comp.extend_precomputation($name.clone(), hint); + var($name, sort) + }}; + + ($comp:expr, $name:expr, $hint_fn:path; $($args:expr),+) => {{ + let hint = $hint_fn($($args),+); + let sort = check(&hint); + $comp.extend_precomputation($name.clone(), hint); + var($name, sort) + }}; +} + +/// Hints aka nondeterministic advice +/// +/// These are unconstrained precompute terms, which _should_ be enforced by other +/// mechanisms in the computation. +pub struct Hint; +/// Frequently used gadgets +pub struct Gadget<'a> { + pub rewriter: &'a mut FloatRewriter, +} + +impl Hint { + /// Precompute the sign and exponent of a float, together with its pf encoding. + pub fn decode_float(var: &Term, e_size: usize, m_size: usize) -> Term { + let var_pf = match var.op() { + Op::Var(_) => term![Op::FpToPf(cfg().field().clone()); var.clone()], + _ => panic!("Expected Op::Var but found something else"), + }; + let var_bv = term![Op::PfToBv(64); var_pf.clone()]; + + let s = term![BV_LSHR; var_bv.clone(), bv_lit(e_size + m_size, 64)]; + let shifted_v = term![BV_LSHR; var_bv.clone(), bv_lit(m_size, 64)]; + let shifted_s = term![BV_SHL; s.clone(), bv_lit(e_size, 64)]; + let e = term![BV_SUB; shifted_v, shifted_s]; + + term![ + Op::Tuple; + term![Op::UbvToPf(Box::new(cfg().field().clone())); s], + term![Op::UbvToPf(Box::new(cfg().field().clone())); e], + var_pf + ] + } + /// Precompute the number of leading zeros of mantissa within m_size bits + pub fn normalize(var: &Term, m_size: usize) -> Term { + let mantissa = term![Op::PfToBv(m_size); var.clone()]; + + let mut d = bv_lit(0, m_size); + // shift length is number of zeros + // d := ite (bit == 0) (d + 1) (d) + for i in (0..m_size).rev() { + let bit = term![Op::BvBit(i); mantissa.clone()]; + let cond = term![EQ; bit, bool_lit(false)]; + d = term![ITE; cond, term![BV_ADD; d.clone(), bv_lit(1, m_size)], d.clone()] + } + + // if mantissa is zero, return zero, else return computed shift + let is_zero = term![EQ; mantissa.clone(), bv_lit(0, m_size)]; + term![ITE; is_zero, new_pf_lit(0), term![Op::UbvToPf(Box::new(cfg().field().clone())); d.clone()]] + } + // /// Precompute the power of two of a given pf term. + // pub fn power_of_two(var: &Term) -> Term { + // // maybe set the size to m_size. come back to this later. + // let d = term![Op::PfToBv(64); var.clone()]; + // term![Op::UbvToPf(Box::new(cfg().field().clone())); term![BV_SHL; bv_lit(1, 64), d]] + // } +} + +impl Gadget<'_> { + pub fn assert(&mut self, cond: Term) { + self.rewriter.assertions.push(cond); + } + + // /// TODO: implement lookup-related machinery + // pub fn query_power_of_two(&self, comp: &mut Computation, exp: &Term) -> Term { + // // Precompute the power of two of the exponent + // let res = compute_hint!( + // comp, + // Hint::power_of_two(exp), + // format!("{}.pow2", exp.as_var_name()) + // ); + + // // continue here... + // res + // } + + /// Creates a bool term checking if pf term is 1 or 0 + pub fn is_bool(&self, a: &Term) -> Term { + // compute a*(1-a) == 0 + self.is_zero(&term![PF_MUL; a.clone(), self.pf_sub(&new_pf_lit(1), a)]) + } + + /// Creates a bool term checking if pf term a is zero + pub fn is_zero(&self, a: &Term) -> Term { + term![EQ; a.clone(), new_pf_lit(0)] + } + + /// Creates a pf term subtracting pf term b from pf term a + pub fn pf_sub(&self, a: &Term, b: &Term) -> Term { + term![PF_ADD; a.clone(), term![PF_NEG; b.clone()]] + } +} diff --git a/src/ir/opt/fp/mod.rs b/src/ir/opt/fp/mod.rs new file mode 100644 index 000000000..30c81fde2 --- /dev/null +++ b/src/ir/opt/fp/mod.rs @@ -0,0 +1,361 @@ +//! Floating-point IEEE-754 constructions +//! +//! Replace IR floats with their respective constructions using lookup arguments. +//! The floats themselves are stored as a tuple of components. +//! +//! A tuple elimination pass should be run afterwards. +//! Based on the implementation of https://github.com/tumberger/zk-Location + +mod extras; + +use super::visit::RewritePass; +use crate::cfg::cfg_or_default as cfg; +use crate::compute_hint; +use crate::ir::term::*; +use extras::{new_pf_lit, Gadget, Hint}; +// use lookup::{lookup_exponent, lookup_mantissa}; +use rug::Integer; + +/// Represents an IEEE-754 floating-point number +#[allow(dead_code)] +struct FloatVar { + sign: Term, + exponent: Term, + mantissa: Term, + is_abnormal: Term, +} + +// Context +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct Context { + e: usize, // Exponent bit width + m: usize, // Mantissa bit width + e_max: Integer, + e_normal_min: Integer, + e_min: Integer, +} + +#[allow(dead_code)] +impl Context { + pub fn new(e: usize, m: usize) -> Self { + let e_max = Integer::from(1) << (e - 1); + let e_normal_min = Integer::from(2) - &e_max; + let e_min = e_normal_min.clone() - (m + 1); + + Self { + e, + m, + e_max, + e_normal_min, + e_min, + } + } + + fn components_of(&self, v: u64) -> (Integer, Integer, Integer, Integer) { + let s = v >> (self.e + self.m); + let e = (v >> self.m) - (s << self.e); + let m = v - (s << (self.e + self.m)) - (e << self.m); + + let sign = Integer::from(s as i64); + + let exponent_max = Integer::from(1) << (self.e - 1); + let exponent_min = Integer::from(1) - &exponent_max; + + let mut exponent = Integer::from(e as i64) + &exponent_min; + let mut mantissa = Integer::from(m as i64); + let mut shift = 0_usize; + + for i in (0..self.m).rev() { + if ((m >> (self.m - 1 - i)) & 1) == 1 { + break; + } + shift += 1; + } + + let mantissa_is_not_zero = m != 0; + let exponent_is_min = exponent == exponent_min; + let exponent_is_max = exponent == exponent_max; + + if exponent_is_min { + exponent -= Integer::from(shift as i64); + mantissa <<= shift + 1; + } else if exponent_is_max && mantissa_is_not_zero { + mantissa = Integer::from(0); + } else { + mantissa += Integer::from(1) << self.m; + } + + let is_abnormal = if exponent_is_max { + Integer::from(1) + } else { + Integer::from(0) + }; + + (sign, exponent, mantissa, is_abnormal) + } + + pub fn new_f32_constant(&self, v: f32) -> (Integer, Integer, Integer, Integer) { + self.components_of(v.to_bits() as u64) + } + + pub fn new_f64_constant(&self, v: f64) -> (Integer, Integer, Integer, Integer) { + self.components_of(v.to_bits()) + } +} + +#[derive(Default)] +struct FloatRewriter { + assertions: Vec, +} + +impl RewritePass for FloatRewriter { + fn traverse(&mut self, computation: &mut Computation) { + self.traverse_full(computation, false, true); + + // Apply all stored assertions at the end + for assertion in self.assertions.drain(..) { + computation.assert(assertion); + } + } + + fn visit Vec>( + &mut self, + computation: &mut Computation, + orig: &Term, + _rewritten_children: F, + ) -> Option { + let mut gadget = Gadget { rewriter: self }; + + match &orig.op() { + // Destructures float variables into constrained tuples of (s: bool, e: field, m: field, a: bool). + Op::Var(v) => { + let ctx = match v.sort { + Sort::F32 => Context::new(8, 23), + Sort::F64 => Context::new(11, 52), + _ => todo!(), + }; + + // Precompute the float components + let fp_tup = compute_hint![ + computation, + format!("{}.fp_tup", v.name), + Hint::decode_float; + orig, + ctx.e, + ctx.m + ]; + let s = term(Op::Field(0), vec![fp_tup.clone()]); + let e = term(Op::Field(1), vec![fp_tup.clone()]); + let v_pf = term(Op::Field(2), vec![fp_tup.clone()]); + + let s_mul = term![PF_MUL; s.clone(), new_pf_lit(1u64 << (ctx.e + ctx.m))]; + let e_mul = term![PF_MUL; e.clone(), new_pf_lit(1u64 << ctx.m)]; + let m = term![PF_ADD; v_pf, term![PF_NEG; term![PF_ADD; s_mul, e_mul]]]; + + // Assert well-formedness of s, e and m through range checks + gadget.assert(gadget.is_bool(&s)); + gadget.assert(term![Op::PfFitsInBits(ctx.e); e.clone()]); + gadget.assert(term![Op::PfFitsInBits(ctx.m); m.clone()]); + + let exponent_min = gadget.pf_sub(&new_pf_lit(ctx.e_normal_min), &new_pf_lit(1)); + let exponent_max = new_pf_lit(ctx.e_max); + + let exponent = term![PF_ADD; e.clone(), exponent_min.clone()]; + + let mantissa_is_zero = gadget.is_zero(&m); + let _mantissa_is_not_zero = term![NOT; mantissa_is_zero.clone()]; + let _exponent_is_min = term![EQ; exponent.clone(), exponent_min.clone()]; + let _exponent_is_max = term![EQ; exponent.clone(), exponent_max]; + + // Precompute the shift length `d` for float normalization + let d = compute_hint![ + computation, + format!("{}.d", v.name), + Hint::normalize; + &m, + ctx.m + ]; + + // Assert soundness of `d` with a range check + // It should be in the range [0, ctx.m]. + let bitlen = (usize::BITS - ctx.m.leading_zeros()) as usize; + gadget.assert(term![Op::PfFitsInBits(bitlen); d.clone()]); + // continue here with the querypowerof2 gadget operation on shift + + // return the term that we rewrite orig with + Some(term![ + Op::Tuple; + term![Op::PfToBoolTrusted; s.clone()], + e, + m, + term![Op::PfToBoolTrusted; s] + ]) + } + // Rewrites float constants into tuples of (s: bool, e: field, m: field, a: bool). + Op::Const(v) => { + let comps = match v.as_ref() { + Value::F32(fp) => Some(Context::new(8, 23).new_f32_constant(*fp)), + Value::F64(fp) => Some(Context::new(11, 52).new_f64_constant(*fp)), + _ => None, + }; + comps.map(|float| { + term( + Op::Tuple, + vec![ + const_(Value::Bool(float.0 == 1)), + const_(Value::Field(cfg().field().new_v(float.1))), + const_(Value::Field(cfg().field().new_v(float.2))), + const_(Value::Bool(float.3 == 1)), + ], + ) + }) + } + Op::Eq => todo!(), + Op::FpBinOp(_fp_bin_op) => todo!(), + Op::FpBinPred(_fp_bin_pred) => todo!(), + Op::FpUnPred(_fp_un_pred) => todo!(), + Op::FpUnOp(_fp_un_op) => todo!(), + Op::BvToFp => todo!(), + Op::UbvToFp(_) => todo!(), + Op::SbvToFp(_) => todo!(), + Op::FpToFp(_) => todo!(), + Op::PfToFp(_) => todo!(), + _ => None, + } + } +} + +/// Replace floats with IEEE 754 constructions +pub fn construct_floats(c: &mut Computation) { + let mut pass = FloatRewriter::default(); + pass.traverse(c); +} + +#[cfg(test)] +mod test { + use super::*; + use crate::cfg::cfg_or_default as cfg; + + #[test] + fn const_fp32() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs) (commitments)) + (precompute () () (#t )) + (let + ( + (val #fp3.1415926535f32) + ) + val + ) + ) + ", + ); + construct_floats(&mut c); + + let expected = term( + Op::Tuple, + vec![ + const_(Value::Bool(false)), + const_(Value::Field(cfg().field().new_v(1))), + const_(Value::Field(cfg().field().new_v(13176795))), + const_(Value::Bool(false)), + ], + ); + + assert_eq!(c.outputs[0], expected); + } + + #[test] + fn const_fp64() { + let mut c = text::parse_computation( + b" + (computation + (metadata (parties ) (inputs) (commitments)) + (precompute () () (#t )) + (let + ( + (val #fp-3.141592653589793) + ) + val + ) + ) + ", + ); + construct_floats(&mut c); + + let expected = term( + Op::Tuple, + vec![ + const_(Value::Bool(true)), + const_(Value::Field(cfg().field().new_v(1))), + const_(Value::Field(cfg().field().new_v( + Integer::from_str_radix("7074237752028440", 10).unwrap(), + ))), + const_(Value::Bool(false)), + ], + ); + + assert_eq!(c.outputs[0], expected); + } + + // #[test] + // fn init_fp32() { + // let mut c = text::parse_computation( + // b" + // (computation + // (metadata (parties ) (inputs (a f32)) (commitments)) + // (precompute () () (#t )) + // (let + // ( + // (val a) + // ) + // val + // ) + // ) + // ", + // ); + // let values = text::parse_value_map( + // b" + // (set_default_modulus 11 + // (let ( + // (a #fp3.1415926535f32) + // ) false ; ignored + // )) + // ", + // ); + // // println!("{}", text::serialize_computation(&c)); + // construct_floats(&mut c); + // println!("{}", text::serialize_computation(&c)); + // // println!("{:#?}", c.eval_all(&values)); + // assert_eq!(vec![Value::F32(1.234)], c.eval_all(&values)); + // } + + // #[test] + // fn add_fp32() { + // let mut c = text::parse_computation( + // b" + // (computation + // (metadata (parties ) (inputs (a f32) (b f32) (return f32)) (commitments)) + // (precompute () () (#t )) + // (= return (fpadd a b)) + // ) + // ", + // ); + // let values = text::parse_value_map( + // b" + // (set_default_modulus 11 + // (let ( + // (a #fp1f32) + // (b #fp2f32) + // (return #fp3f32) + // ) false ; ignored + // )) + // ", + // ); + // construct_floats(&mut c); + // assert_eq!(vec![Value::F32(3.0)], c.eval_all(&values)); + // } +} diff --git a/src/ir/opt/mod.rs b/src/ir/opt/mod.rs index fa358ad0e..e4d1c5947 100644 --- a/src/ir/opt/mod.rs +++ b/src/ir/opt/mod.rs @@ -5,6 +5,7 @@ pub mod chall; pub mod cstore; pub mod fits_in_bits_ip; pub mod flat; +pub mod fp; pub mod inline; pub mod link; pub mod mem; @@ -59,6 +60,8 @@ pub enum Opt { DeskolemizeWitnesses, /// Check bit-constaints with challenges. FitsInBitsIp, + /// Replace floats with IEEE 754 constructions using lookup arguments + Float, } /// Run optimizations on `cs`, in this order, returning the new constraint system. @@ -169,6 +172,9 @@ pub fn opt>(mut cs: Computations, optimizations: I) Opt::FitsInBitsIp => { fits_in_bits_ip::fits_in_bits_ip(c); } + Opt::Float => { + fp::construct_floats(c); + } } info!( "After {:?}: {} terms", diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index b4bdf61d7..6be6355a2 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -362,6 +362,17 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> ), } } + Op::FpToPf(fty) => { + let val = match args[0] { + Value::F32(f) => rug::Integer::from(f.to_bits()), + Value::F64(f) => rug::Integer::from(f.to_bits()), + _ => panic!( + "Expected floating-point value for {} but got {:?}", + op, args[0] + ), + }; + Value::Field(fty.new_v(val)) + } Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())), Op::PfChallenge(c) => Value::Field(eval_pf_challenge(&c.name, &c.field)), Op::Witness(_) => args[0].clone(), diff --git a/src/ir/term/fmt.rs b/src/ir/term/fmt.rs index 436b59da4..76c1babce 100644 --- a/src/ir/term/fmt.rs +++ b/src/ir/term/fmt.rs @@ -365,6 +365,7 @@ impl DisplayIr for Op { Op::SbvToFp(a) => write!(f, "(sbv2fp {a})"), Op::FpToFp(a) => write!(f, "(fp2fp {a})"), Op::PfToFp(a) => write!(f, "(pf2fp {a})"), + Op::FpToPf(a) => write!(f, "(fp2pf {a})"), Op::PfUnOp(a) => write!(f, "{a}"), Op::PfNaryOp(a) => write!(f, "{a}"), Op::PfDiv => write!(f, "/"), diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index 1fab9bad1..3d901bcd1 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -123,6 +123,8 @@ pub enum Op { /// translate the prime-field number represented by the argument to a floating-point value /// of this width. PfToFp(usize), + /// Floating-point value to Field + FpToPf(FieldT), /// Prime-field unary operator PfUnOp(PfUnOp), @@ -337,7 +339,52 @@ pub const INT_LE: Op = Op::IntBinPred(IntBinPred::Le); pub const INT_GT: Op = Op::IntBinPred(IntBinPred::Gt); /// integer greater than or equal pub const INT_GE: Op = Op::IntBinPred(IntBinPred::Ge); -// TODO: add floating-point operator abbreviations +/// floating-point addition +pub const FP_ADD: Op = Op::FpBinOp(FpBinOp::Add); +/// floating-point multiplication +pub const FP_MUL: Op = Op::FpBinOp(FpBinOp::Mul); +/// floating-point subtraction +pub const FP_SUB: Op = Op::FpBinOp(FpBinOp::Sub); +/// floating-point division +pub const FP_DIV: Op = Op::FpBinOp(FpBinOp::Div); +/// floating-point remainder +pub const FP_REM: Op = Op::FpBinOp(FpBinOp::Rem); +/// floating-point maximum +pub const FP_MAX: Op = Op::FpBinOp(FpBinOp::Max); +/// floating-point minimum +pub const FP_MIN: Op = Op::FpBinOp(FpBinOp::Min); +/// floating-point less than or equal +pub const FP_LE: Op = Op::FpBinPred(FpBinPred::Le); +/// floating-point less than +pub const FP_LT: Op = Op::FpBinPred(FpBinPred::Lt); +/// floating-point equal to +pub const FP_EQ: Op = Op::FpBinPred(FpBinPred::Eq); +/// floating-point greater than or equal +pub const FP_GE: Op = Op::FpBinPred(FpBinPred::Ge); +/// floating-point greater than +pub const FP_GT: Op = Op::FpBinPred(FpBinPred::Gt); +/// floating-point is normal +pub const FP_IS_NORM: Op = Op::FpUnPred(FpUnPred::Normal); +/// floating-point is subnormal +pub const FP_IS_SUBNORM: Op = Op::FpUnPred(FpUnPred::Subnormal); +/// floating-point is zero +pub const FP_IS_ZERO: Op = Op::FpUnPred(FpUnPred::Zero); +/// floating-point is infinite +pub const FP_IS_INF: Op = Op::FpUnPred(FpUnPred::Infinite); +/// floating-point is not-a-number +pub const FP_IS_NAN: Op = Op::FpUnPred(FpUnPred::Nan); +/// floating-point is negative +pub const FP_IS_NEG: Op = Op::FpUnPred(FpUnPred::Negative); +/// floating-point is positive +pub const FP_IS_POS: Op = Op::FpUnPred(FpUnPred::Positive); +/// floating-point unary negation +pub const FP_NEG: Op = Op::FpUnOp(FpUnOp::Neg); +/// floating-point absolute value +pub const FP_ABS: Op = Op::FpUnOp(FpUnOp::Abs); +/// floating-point square root +pub const FP_SQRT: Op = Op::FpUnOp(FpUnOp::Sqrt); +/// floating-point round +pub const FP_ROUND: Op = Op::FpUnOp(FpUnOp::Round); impl Op { /// Number of arguments for this operator. `None` if n-ary. @@ -371,6 +418,7 @@ impl Op { Op::SbvToFp(_) => Some(1), Op::FpToFp(_) => Some(1), Op::PfToFp(_) => Some(1), + Op::FpToPf(_) => Some(1), Op::PfUnOp(_) => Some(1), Op::PfDiv => Some(2), Op::PfNaryOp(_) => None, diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index a33135c38..4768e427b 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -56,6 +56,7 @@ fn check_dependencies(t: &Term) -> Vec { Op::SbvToFp(_) => Vec::new(), Op::FpToFp(_) => Vec::new(), Op::PfToFp(_) => Vec::new(), + Op::FpToPf(_) => Vec::new(), Op::PfUnOp(_) => vec![t.cs()[0].clone()], Op::PfDiv => vec![t.cs()[0].clone()], Op::PfNaryOp(_) => vec![t.cs()[0].clone()], @@ -142,6 +143,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result { Op::FpToFp(32) => Ok(Sort::F32), Op::PfToFp(64) => Ok(Sort::F64), Op::PfToFp(32) => Ok(Sort::F32), + Op::FpToPf(m) => Ok(Sort::Field(m.clone())), Op::PfUnOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::PfDiv => Ok(get_ty(&t.cs()[0]).clone()), Op::PfNaryOp(_) => Ok(get_ty(&t.cs()[0]).clone()), @@ -366,6 +368,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result fp_or(a, "fp-to-fp").map(|_| Sort::F32), (Op::PfToFp(64), &[a]) => pf_or(a, "pf-to-fp").map(|_| Sort::F64), (Op::PfToFp(32), &[a]) => pf_or(a, "pf-to-fp").map(|_| Sort::F32), + (Op::FpToPf(m), &[a]) => fp_or(a, "fp-to-pf").map(|_| Sort::Field(m.clone())), (Op::PfNaryOp(_), a) => { let ctx = "pf nary op"; all_eq_or(a.iter().cloned(), ctx) From e620a06387b8e1679fe3fe8f0588b9961bcaf0e1 Mon Sep 17 00:00:00 2001 From: Lorenzo Rota Date: Wed, 12 Mar 2025 15:26:20 +0100 Subject: [PATCH 10/10] Fixed doc error --- src/ir/opt/fp/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ir/opt/fp/mod.rs b/src/ir/opt/fp/mod.rs index 30c81fde2..4e4ce6f21 100644 --- a/src/ir/opt/fp/mod.rs +++ b/src/ir/opt/fp/mod.rs @@ -4,7 +4,7 @@ //! The floats themselves are stored as a tuple of components. //! //! A tuple elimination pass should be run afterwards. -//! Based on the implementation of https://github.com/tumberger/zk-Location +//! Based on the implementation of mod extras;