diff --git a/src/smtml/dune b/src/smtml/dune index 586800ea..13a0bec7 100644 --- a/src/smtml/dune +++ b/src/smtml/dune @@ -30,7 +30,6 @@ interpret interpret_intf lexer - loc log logic mappings diff --git a/src/smtml/expr.ml b/src/smtml/expr.ml index 57cc6d87..5848172d 100644 --- a/src/smtml/expr.ml +++ b/src/smtml/expr.ml @@ -2,15 +2,17 @@ (* Copyright (C) 2023-2024 formalsec *) (* Written by the Smtml programmers *) -type t = expr Hc.hash_consed +module Hashcons = Hc + +type t = + | Imm of Value.t + | Sym of expr Hc.hash_consed and expr = - | Val of Value.t | Ptr of { base : Bitvector.t ; offset : t } - | Loc of Loc.t | Symbol of Symbol.t | List of t list | App of Symbol.t * t list @@ -24,45 +26,50 @@ and expr = | Concat of t * t | Binder of Binder.t * t list * t +let equal_of_t a b = + match (a, b) with + | Imm a, Imm b -> Value.equal a b + | Sym a, Sym b -> phys_equal a b + | (Imm _ | Sym _), _ -> false + +let[@inline] hash_of_t e = match e with Imm v -> Value.hash v | Sym e -> e.tag + module Expr = struct type t = expr let list_eq (l1 : 'a list) (l2 : 'a list) : bool = - if List.compare_lengths l1 l2 = 0 then List.for_all2 phys_equal l1 l2 + if List.compare_lengths l1 l2 = 0 then List.for_all2 equal_of_t l1 l2 else false let equal (e1 : expr) (e2 : expr) : bool = match (e1, e2) with - | Val v1, Val v2 -> Value.equal v1 v2 - | Loc a, Loc b -> Loc.compare a b = 0 | Ptr { base = b1; offset = o1 }, Ptr { base = b2; offset = o2 } -> - Bitvector.equal b1 b2 && phys_equal o1 o2 + Bitvector.equal b1 b2 && equal_of_t o1 o2 | Symbol s1, Symbol s2 -> Symbol.equal s1 s2 | List l1, List l2 -> list_eq l1 l2 | App (s1, l1), App (s2, l2) -> Symbol.equal s1 s2 && list_eq l1 l2 | Unop (t1, op1, e1), Unop (t2, op2, e2) -> - Ty.equal t1 t2 && Ty.Unop.equal op1 op2 && phys_equal e1 e2 + Ty.equal t1 t2 && Ty.Unop.equal op1 op2 && equal_of_t e1 e2 | Binop (t1, op1, e1, e3), Binop (t2, op2, e2, e4) -> - Ty.equal t1 t2 && Ty.Binop.equal op1 op2 && phys_equal e1 e2 - && phys_equal e3 e4 + Ty.equal t1 t2 && Ty.Binop.equal op1 op2 && equal_of_t e1 e2 + && equal_of_t e3 e4 | Relop (t1, op1, e1, e3), Relop (t2, op2, e2, e4) -> - Ty.equal t1 t2 && Ty.Relop.equal op1 op2 && phys_equal e1 e2 - && phys_equal e3 e4 + Ty.equal t1 t2 && Ty.Relop.equal op1 op2 && equal_of_t e1 e2 + && equal_of_t e3 e4 | Triop (t1, op1, e1, e3, e5), Triop (t2, op2, e2, e4, e6) -> - Ty.equal t1 t2 && Ty.Triop.equal op1 op2 && phys_equal e1 e2 - && phys_equal e3 e4 && phys_equal e5 e6 + Ty.equal t1 t2 && Ty.Triop.equal op1 op2 && equal_of_t e1 e2 + && equal_of_t e3 e4 && equal_of_t e5 e6 | Cvtop (t1, op1, e1), Cvtop (t2, op2, e2) -> - Ty.equal t1 t2 && Ty.Cvtop.equal op1 op2 && phys_equal e1 e2 + Ty.equal t1 t2 && Ty.Cvtop.equal op1 op2 && equal_of_t e1 e2 | Naryop (t1, op1, l1), Naryop (t2, op2, l2) -> Ty.equal t1 t2 && Ty.Naryop.equal op1 op2 && list_eq l1 l2 | Extract (e1, h1, l1), Extract (e2, h2, l2) -> - phys_equal e1 e2 && h1 = h2 && l1 = l2 - | Concat (e1, e3), Concat (e2, e4) -> phys_equal e1 e2 && phys_equal e3 e4 + equal_of_t e1 e2 && h1 = h2 && l1 = l2 + | Concat (e1, e3), Concat (e2, e4) -> equal_of_t e1 e2 && equal_of_t e3 e4 | Binder (binder1, vars1, e1), Binder (binder2, vars2, e2) -> - Binder.equal binder1 binder2 && list_eq vars1 vars2 && phys_equal e1 e2 - | ( ( Val _ | Ptr _ | Loc _ | Symbol _ | List _ | App _ | Unop _ | Binop _ - | Triop _ | Relop _ | Cvtop _ | Naryop _ | Extract _ | Concat _ - | Binder _ ) + Binder.equal binder1 binder2 && list_eq vars1 vars2 && equal_of_t e1 e2 + | ( ( Ptr _ | Symbol _ | List _ | App _ | Unop _ | Binop _ | Triop _ + | Relop _ | Cvtop _ | Naryop _ | Extract _ | Concat _ | Binder _ ) , _ ) -> false @@ -71,182 +78,193 @@ module Expr = struct let hash (e : expr) : int = match e with - | Val v -> Value.hash v - | Ptr { base; offset } -> combine (Bitvector.hash base) offset.tag - | Loc l -> Loc.hash l + | Ptr { base; offset } -> combine (Bitvector.hash base) (hash_of_t offset) | Symbol s -> Symbol.hash s - | List l -> List.fold_left (fun acc x -> combine acc x.Hc.tag) 0 l + | List l -> List.fold_left (fun acc x -> combine acc (hash_of_t x)) 0 l | App (s, es) -> let h_s = Symbol.hash s in - List.fold_left (fun acc x -> combine acc x.Hc.tag) h_s es + List.fold_left (fun acc x -> combine acc (hash_of_t x)) h_s es | Unop (ty, op, e) -> let h1 = Ty.hash ty in let h2 = combine h1 (Ty.Unop.hash op) in - combine h2 e.tag + combine h2 (hash_of_t e) | Binop (ty, op, e1, e2) -> let h = Ty.hash ty in let h = combine h (Ty.Binop.hash op) in - let h = combine h e1.tag in - combine h e2.tag + let h = combine h (hash_of_t e1) in + combine h (hash_of_t e2) | Triop (ty, op, e1, e2, e3) -> let h = Ty.hash ty in let h = combine h (Ty.Triop.hash op) in - let h = combine h e1.tag in - let h = combine h e2.tag in - combine h e3.tag + let h = combine h (hash_of_t e1) in + let h = combine h (hash_of_t e2) in + combine h (hash_of_t e3) | Relop (ty, op, e1, e2) -> let h = Ty.hash ty in let h = combine h (Ty.Relop.hash op) in - let h = combine h e1.tag in - combine h e2.tag + let h = combine h (hash_of_t e1) in + combine h (hash_of_t e2) | Cvtop (ty, op, e) -> let h = Ty.hash ty in let h = combine h (Ty.Cvtop.hash op) in - combine h e.tag + combine h (hash_of_t e) | Naryop (ty, op, es) -> let h = Ty.hash ty in let h = combine h (Ty.Naryop.hash op) in - List.fold_left (fun acc x -> combine acc x.Hc.tag) h es + List.fold_left (fun acc x -> combine acc (hash_of_t x)) h es | Extract (e, hi, lo) -> - let h = e.tag in + let h = hash_of_t e in let h = combine h hi in combine h lo - | Concat (e1, e2) -> combine e1.tag e2.tag + | Concat (e1, e2) -> combine (hash_of_t e1) (hash_of_t e2) | Binder (b, vars, e) -> let h = Hashtbl.hash b in - let h_vars = List.fold_left (fun acc x -> combine acc x.Hc.tag) h vars in - combine h_vars e.tag + let h_vars = + List.fold_left (fun acc x -> combine acc (hash_of_t x)) h vars + in + combine h_vars (hash_of_t e) end -module Hc = Hc.Make [@inlined hint] (Expr) - -let equal (hte1 : t) (hte2 : t) = phys_equal hte1 hte2 [@@inline] - -let hash (hte : t) = hte.tag [@@inline] - -module Key = struct - type nonrec t = t - - let to_int hte = hash hte +module Hc = Hashcons.Make [@inlined hint] (Expr) - let compare x y = compare (to_int x) (to_int y) -end +let[@inline] equal (a : t) (b : t) = equal_of_t a b -let[@inline] make e = Hc.hashcons e +let[@inline] hash (hte : t) = hash_of_t hte -let[@inline] view (hte : t) = hte.node +let[@inline] make e = Sym (Hc.hashcons e) -let[@inline] compare (hte1 : t) (hte2 : t) = compare hte1.tag hte2.tag +let[@inline] view hte = hte.Hashcons.node -let symbol s = make (Symbol s) +let[@inline] compare (a : t) (b : t) = + match (a, b) with + | Imm a, Imm b -> Value.compare a b + | Sym a, Sym b -> compare a.tag b.tag + | Imm _, _ -> -1 + | Sym _, _ -> 1 (** The return type of an expression *) let rec ty (hte : t) : Ty.t = - match view hte with - | Val x -> Value.type_of x - | Ptr _ -> Ty_bitv 32 - | Loc _ -> Ty_app - | Symbol x -> Symbol.type_of x - | List _ -> Ty_list - | App (sym, _) -> begin match sym.ty with Ty_none -> Ty_app | ty -> ty end - | Triop (_, Ite, _, hte1, hte2) -> - let ty1 = ty hte1 in - assert ( - let ty2 = ty hte2 in - Ty.equal ty1 ty2 ); - ty1 - | Cvtop (_, (Zero_extend m | Sign_extend m), hte) -> ( - match ty hte with Ty_bitv n -> Ty_bitv (n + m) | _ -> assert false ) - | Unop (ty, _, _) - | Binop (ty, _, _, _) - | Triop (ty, _, _, _, _) - | Relop (ty, _, _, _) - | Cvtop (ty, _, _) - | Naryop (ty, _, _) -> - ty - | Extract (_, h, l) -> Ty_bitv ((h - l) * 8) - | Concat (e1, e2) -> ( - match (ty e1, ty e2) with - | Ty_bitv n1, Ty_bitv n2 -> Ty_bitv (n1 + n2) - | t1, t2 -> - Fmt.failwith "Invalid concat of (%a) with (%a)" Ty.pp t1 Ty.pp t2 ) - | Binder (_, _, e) -> ty e + match hte with + | Imm x -> Value.type_of x + | Sym x -> begin + match view x with + | Ptr _ -> Ty_bitv 32 + | Symbol x -> Symbol.type_of x + | List _ -> Ty_list + | App (sym, _) -> begin match sym.ty with Ty_none -> Ty_app | ty -> ty end + | Triop (_, Ite, _, hte1, hte2) -> + let ty1 = ty hte1 in + assert ( + let ty2 = ty hte2 in + Ty.equal ty1 ty2 ); + ty1 + | Cvtop (_, (Zero_extend m | Sign_extend m), hte) -> ( + match ty hte with Ty_bitv n -> Ty_bitv (n + m) | _ -> assert false ) + | Unop (ty, _, _) + | Binop (ty, _, _, _) + | Triop (ty, _, _, _, _) + | Relop (ty, _, _, _) + | Cvtop (ty, _, _) + | Naryop (ty, _, _) -> + ty + | Extract (_, h, l) -> Ty_bitv ((h - l) * 8) + | Concat (e1, e2) -> ( + match (ty e1, ty e2) with + | Ty_bitv n1, Ty_bitv n2 -> Ty_bitv (n1 + n2) + | t1, t2 -> + Fmt.failwith "Invalid concat of (%a) with (%a)" Ty.pp t1 Ty.pp t2 ) + | Binder (_, _, e) -> ty e + end let rec is_symbolic (v : t) : bool = - match view v with - | Val _ | Loc _ -> false - | Symbol _ -> true - | Ptr { offset; _ } -> is_symbolic offset - | Unop (_, _, v) | Cvtop (_, _, v) | Extract (v, _, _) | Binder (_, _, v) -> - is_symbolic v - | Binop (_, _, v1, v2) | Relop (_, _, v1, v2) | Concat (v1, v2) -> - is_symbolic v1 || is_symbolic v2 - | Triop (_, _, v1, v2, v3) -> - is_symbolic v1 || is_symbolic v2 || is_symbolic v3 - | List vs | App (_, vs) | Naryop (_, _, vs) -> List.exists is_symbolic vs + match v with + | Imm _ -> false + | Sym x -> begin + match view x with + | Symbol _ -> true + | Ptr { offset; _ } -> is_symbolic offset + | Unop (_, _, v) | Cvtop (_, _, v) | Extract (v, _, _) | Binder (_, _, v) -> + is_symbolic v + | Binop (_, _, v1, v2) | Relop (_, _, v1, v2) | Concat (v1, v2) -> + is_symbolic v1 || is_symbolic v2 + | Triop (_, _, v1, v2, v3) -> + is_symbolic v1 || is_symbolic v2 || is_symbolic v3 + | List vs | App (_, vs) | Naryop (_, _, vs) -> List.exists is_symbolic vs + end let get_symbols (hte : t list) = let tbl = Hashtbl.create 64 in let rec symbols (hte : t) = - match view hte with - | Val _ | Loc _ -> () - | Ptr { offset; _ } -> symbols offset - | Symbol s -> Hashtbl.replace tbl s () - | List es -> List.iter symbols es - | App (_, es) -> List.iter symbols es - | Unop (_, _, e1) -> symbols e1 - | Binop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Triop (_, _, e1, e2, e3) -> - symbols e1; - symbols e2; - symbols e3 - | Relop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Cvtop (_, _, e) -> symbols e - | Naryop (_, _, es) -> List.iter symbols es - | Extract (e, _, _) -> symbols e - | Concat (e1, e2) -> - symbols e1; - symbols e2 - | Binder (_, vars, e) -> - List.iter symbols vars; - symbols e + match hte with + | Imm _ -> () + | Sym x -> begin + match view x with + | Ptr { offset; _ } -> symbols offset + | Symbol s -> Hashtbl.replace tbl s () + | List es -> List.iter symbols es + | App (_, es) -> List.iter symbols es + | Unop (_, _, e1) -> symbols e1 + | Binop (_, _, e1, e2) -> + symbols e1; + symbols e2 + | Triop (_, _, e1, e2, e3) -> + symbols e1; + symbols e2; + symbols e3 + | Relop (_, _, e1, e2) -> + symbols e1; + symbols e2 + | Cvtop (_, _, e) -> symbols e + | Naryop (_, _, es) -> List.iter symbols es + | Extract (e, _, _) -> symbols e + | Concat (e1, e2) -> + symbols e1; + symbols e2 + | Binder (_, vars, e) -> + List.iter symbols vars; + symbols e + end in List.iter symbols hte; Hashtbl.fold (fun k () acc -> k :: acc) tbl [] let rec pp fmt (hte : t) = - match view hte with - | Val v -> Value.pp fmt v - | Ptr { base; offset } -> Fmt.pf fmt "(Ptr %a %a)" Bitvector.pp base pp offset - | Loc l -> Fmt.pf fmt "(loc %a)" Loc.pp l - | Symbol s -> Fmt.pf fmt "@[%a@]" Symbol.pp s - | List v -> Fmt.pf fmt "@[[%a]@]" (Fmt.list ~sep:Fmt.comma pp) v - | App (s, v) -> - Fmt.pf fmt "@[(%a@ %a)@]" Symbol.pp s (Fmt.list ~sep:Fmt.comma pp) v - | Unop (ty, op, e) -> - Fmt.pf fmt "@[(%a.%a@ %a)@]" Ty.pp ty Ty.Unop.pp op pp e - | Binop (ty, op, e1, e2) -> - Fmt.pf fmt "@[(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.Binop.pp op pp e1 pp e2 - | Triop (ty, op, e1, e2, e3) -> - Fmt.pf fmt "@[(%a.%a@ %a@ %a@ %a)@]" Ty.pp ty Ty.Triop.pp op pp e1 pp - e2 pp e3 - | Relop (ty, op, e1, e2) -> - Fmt.pf fmt "@[(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.Relop.pp op pp e1 pp e2 - | Cvtop (ty, op, e) -> - Fmt.pf fmt "@[(%a.%a@ %a)@]" Ty.pp ty Ty.Cvtop.pp op pp e - | Naryop (ty, op, es) -> - Fmt.pf fmt "@[(%a.%a@ (%a))@]" Ty.pp ty Ty.Naryop.pp op - (Fmt.list ~sep:Fmt.comma pp) - es - | Extract (e, h, l) -> Fmt.pf fmt "@[(extract@ %a@ %d@ %d)@]" pp e l h - | Concat (e1, e2) -> Fmt.pf fmt "@[(++@ %a@ %a)@]" pp e1 pp e2 - | Binder (b, vars, e) -> - Fmt.pf fmt "@[(%a@ (%a)@ %a)@]" Binder.pp b (Fmt.list ~sep:Fmt.sp pp) - vars pp e + match hte with + | Imm v -> Value.pp fmt v + | Sym hte -> begin + match view hte with + | Ptr { base; offset } -> + Fmt.pf fmt "(Ptr %a %a)" Bitvector.pp base pp offset + | Symbol s -> Fmt.pf fmt "@[%a@]" Symbol.pp s + | List v -> Fmt.pf fmt "@[[%a]@]" (Fmt.list ~sep:Fmt.comma pp) v + | App (s, v) -> + Fmt.pf fmt "@[(%a@ %a)@]" Symbol.pp s + (Fmt.list ~sep:Fmt.comma pp) + v + | Unop (ty, op, e) -> + Fmt.pf fmt "@[(%a.%a@ %a)@]" Ty.pp ty Ty.Unop.pp op pp e + | Binop (ty, op, e1, e2) -> + Fmt.pf fmt "@[(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.Binop.pp op pp e1 pp + e2 + | Triop (ty, op, e1, e2, e3) -> + Fmt.pf fmt "@[(%a.%a@ %a@ %a@ %a)@]" Ty.pp ty Ty.Triop.pp op pp e1 + pp e2 pp e3 + | Relop (ty, op, e1, e2) -> + Fmt.pf fmt "@[(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.Relop.pp op pp e1 pp + e2 + | Cvtop (ty, op, e) -> + Fmt.pf fmt "@[(%a.%a@ %a)@]" Ty.pp ty Ty.Cvtop.pp op pp e + | Naryop (ty, op, es) -> + Fmt.pf fmt "@[(%a.%a@ (%a))@]" Ty.pp ty Ty.Naryop.pp op + (Fmt.list ~sep:Fmt.comma pp) + es + | Extract (e, h, l) -> + Fmt.pf fmt "@[(extract@ %a@ %d@ %d)@]" pp e l h + | Concat (e1, e2) -> Fmt.pf fmt "@[(++@ %a@ %a)@]" pp e1 pp e2 + | Binder (b, vars, e) -> + Fmt.pf fmt "@[(%a@ (%a)@ %a)@]" Binder.pp b + (Fmt.list ~sep:Fmt.sp pp) vars pp e + end let pp_list fmt (es : t list) = Fmt.hovbox (Fmt.list ~sep:Fmt.comma pp) fmt es @@ -275,11 +293,11 @@ let pp_smtml fmt (es : t list) : unit = let to_string e = Fmt.str "%a" pp e -let value (v : Value.t) : t = make (Val v) [@@inline] +let[@inline] value (v : Value.t) : t = Imm v -let ptr base offset = make (Ptr { base = Bitvector.of_int32 base; offset }) +let symbol s = make (Symbol s) -let loc l = make (Loc l) +let ptr base offset = make (Ptr { base = Bitvector.of_int32 base; offset }) let list l = make (List l) @@ -293,7 +311,7 @@ let forall vars body = binder Forall vars body let exists vars body = binder Exists vars body -let raw_unop ty op hte = make (Unop (ty, op, hte)) [@@inline] +let[@inline] raw_unop ty op hte = make (Unop (ty, op, hte)) let normalize_eq_or_ne op (ty', e1, e2) = let make_relop lhs rhs = Relop (ty', op, lhs, rhs) in @@ -303,333 +321,438 @@ let normalize_eq_or_ne op (ty', e1, e2) = match ty1 with | Ty_bitv m -> let binop = make (Binop (ty1, Sub, e1, e2)) in - let zero = make (Val (Bitv (Bitvector.make Z.zero m))) in + let zero = value (Bitv (Bitvector.make Z.zero m)) in make_relop binop zero | Ty_int -> let binop = make (Binop (ty1, Sub, e1, e2)) in - let zero = make (Val (Int Int.zero)) in + let zero = value (Int Int.zero) in make_relop binop zero | Ty_real -> let binop = make (Binop (ty1, Sub, e1, e2)) in - let zero = make (Val (Real 0.)) in + let zero = value (Real 0.) in make_relop binop zero | _ -> make_relop e1 e2 end let negate_relop (hte : t) : t = - let e = - match view hte with - | Relop (ty, Eq, e1, e2) -> normalize_eq_or_ne Ne (ty, e1, e2) - | Relop (ty, Ne, e1, e2) -> normalize_eq_or_ne Eq (ty, e1, e2) - | Relop (ty, Lt, e1, e2) -> Relop (ty, Le, e2, e1) - | Relop (ty, LtU, e1, e2) -> Relop (ty, LeU, e2, e1) - | Relop (ty, Le, e1, e2) -> Relop (ty, Lt, e2, e1) - | Relop (ty, LeU, e1, e2) -> Relop (ty, LtU, e2, e1) - | Relop (ty, Gt, e1, e2) -> Relop (ty, Le, e1, e2) - | Relop (ty, GtU, e1, e2) -> Relop (ty, LeU, e1, e2) - | Relop (ty, Ge, e1, e2) -> Relop (ty, Lt, e1, e2) - | Relop (ty, GeU, e1, e2) -> Relop (ty, LtU, e1, e2) - | _ -> Fmt.failwith "negate_relop: not a relop." - in - make e - -let unop ty op hte = - match (op, view hte) with - | Ty.Unop.(Regexp_loop _ | Regexp_star), _ -> raw_unop ty op hte - | _, Val v -> value (Eval.unop ty op v) - | Not, Unop (_, Not, hte') -> hte' - | Not, Relop (Ty_fp _, _, _, _) -> raw_unop ty op hte - | Not, Relop (_, _, _, _) -> negate_relop hte - | Neg, Unop (_, Neg, hte') -> hte' - | Trim, Cvtop (Ty_real, ToString, _) -> hte - | Head, List (hd :: _) -> hd - | Tail, List (_ :: tl) -> make (List tl) - | Reverse, List es -> make (List (List.rev es)) - | Length, List es -> value (Int (List.length es)) - | _ -> raw_unop ty op hte - -let raw_binop ty op hte1 hte2 = make (Binop (ty, op, hte1, hte2)) [@@inline] - -let rec binop ty op hte1 hte2 = - match (op, view hte1, view hte2) with - | Ty.Binop.(String_in_re | Regexp_range), _, _ -> raw_binop ty op hte1 hte2 - | op, Val v1, Val v2 -> value (Eval.binop ty op v1 v2) - | Sub, Ptr { base = b1; offset = os1 }, Ptr { base = b2; offset = os2 } -> - if Bitvector.equal b1 b2 then binop ty Sub os1 os2 - else raw_binop ty op hte1 hte2 - | Add, Ptr { base; offset }, _ -> - let m = Bitvector.numbits base in - make (Ptr { base; offset = binop (Ty_bitv m) Add offset hte2 }) - | Sub, Ptr { base; offset }, _ -> - let m = Bitvector.numbits base in - make (Ptr { base; offset = binop (Ty_bitv m) Sub offset hte2 }) - | Rem, Ptr { base; offset }, _ -> - let m = Bitvector.numbits base in - let rhs = value (Bitv base) in - let addr = binop (Ty_bitv m) Add rhs offset in - binop ty Rem addr hte2 - | Add, _, Ptr { base; offset } -> - let m = Bitvector.numbits base in - make (Ptr { base; offset = binop (Ty_bitv m) Add offset hte1 }) - | Sub, _, Ptr { base; offset } -> - let m = Bitvector.numbits base in - let base = value (Bitv base) in - binop ty Sub hte1 (binop (Ty_bitv m) Add base offset) - | (Add | Or), Val (Bitv bv), _ when Bitvector.eqz bv -> hte2 - | (And | Div | DivU | Mul | Rem | RemU), Val (Bitv bv), _ - when Bitvector.eqz bv -> - hte1 - | (Add | Or), _, Val (Bitv bv) when Bitvector.eqz bv -> hte1 - | (And | Mul), _, Val (Bitv bv) when Bitvector.eqz bv -> hte2 - | Add, Binop (ty, Add, x, { node = Val v1; _ }), Val v2 -> - let v = value (Eval.binop ty Add v1 v2) in - raw_binop ty Add x v - | Sub, Binop (ty, Sub, x, { node = Val v1; _ }), Val v2 -> - let v = value (Eval.binop ty Add v1 v2) in - raw_binop ty Sub x v - | Mul, Val (Bitv bv), _ when Bitvector.eq_one bv -> hte2 - | Mul, _, Val (Bitv bv) when Bitvector.eq_one bv -> hte1 - | Mul, Binop (ty, Mul, x, { node = Val v1; _ }), Val v2 -> - let v = value (Eval.binop ty Mul v1 v2) in - raw_binop ty Mul x v - | Add, Val v1, Binop (ty, Add, x, { node = Val v2; _ }) -> - let v = value (Eval.binop ty Add v1 v2) in - raw_binop ty Add v x - | Mul, Val v1, Binop (ty, Mul, x, { node = Val v2; _ }) -> - let v = value (Eval.binop ty Mul v1 v2) in - raw_binop ty Mul v x - | At, List es, Val (Int n) -> - (* TODO: use another datastructure? *) - begin match List.nth_opt es n with None -> assert false | Some v -> v + match hte with + | Sym { node = Relop (ty, Eq, e1, e2); _ } -> + make (normalize_eq_or_ne Ne (ty, e1, e2)) + | Sym { node = Relop (ty, Ne, e1, e2); _ } -> + make (normalize_eq_or_ne Eq (ty, e1, e2)) + | Sym { node = Relop (ty, Lt, e1, e2); _ } -> make (Relop (ty, Le, e2, e1)) + | Sym { node = Relop (ty, LtU, e1, e2); _ } -> make (Relop (ty, LeU, e2, e1)) + | Sym { node = Relop (ty, Le, e1, e2); _ } -> make (Relop (ty, Lt, e2, e1)) + | Sym { node = Relop (ty, LeU, e1, e2); _ } -> make (Relop (ty, LtU, e2, e1)) + | Sym { node = Relop (ty, Gt, e1, e2); _ } -> make (Relop (ty, Le, e1, e2)) + | Sym { node = Relop (ty, GtU, e1, e2); _ } -> make (Relop (ty, LeU, e1, e2)) + | Sym { node = Relop (ty, Ge, e1, e2); _ } -> make (Relop (ty, Lt, e1, e2)) + | Sym { node = Relop (ty, GeU, e1, e2); _ } -> make (Relop (ty, LtU, e1, e2)) + | _ -> Fmt.failwith "negate_relop: not a relop." + +let unop ty op e = + match e with + | Imm v -> Imm (Eval.unop ty op v) + | Sym { node; _ } -> begin + match (op, node) with + | Not, Unop (_, Not, e') -> e' + | Neg, Unop (_, Neg, e') -> e' + | Not, Relop (Ty_fp _, _, _, _) -> raw_unop ty op e + | Not, Relop _ -> negate_relop e + | Trim, Cvtop (Ty_real, ToString, _) -> e + | Head, List (hd :: _) -> hd + | Tail, List (_ :: tl) -> make (List tl) + | Reverse, List es -> make (List (List.rev es)) + | Length, List es -> value (Int (List.length es)) + | _ -> raw_unop ty op e + end + +let[@inline] raw_binop ty op hte1 hte2 = make (Binop (ty, op, hte1, hte2)) + +let rec binop ty op e1 e2 = + match (e1, e2) with + | Imm v1, Imm v2 -> value (Eval.binop ty op v1 v2) + | _ -> begin + match op with + | Add -> begin + match (e1, e2) with + | Imm (Bitv bv), _ when Bitvector.eqz bv -> e2 + | _, Imm (Bitv bv) when Bitvector.eqz bv -> e1 + (* Ptr Logic: Ptr + x *) + | Sym { node = Ptr { base; offset }; _ }, e + | e, Sym { node = Ptr { base; offset }; _ } -> + let m = Bitvector.numbits base in + let offset = binop (Ty_bitv m) Add offset e in + make (Ptr { base; offset }) + (* Normalization: c + x -> x + c *) + | (Imm _ as value), (Sym { node = Symbol _; _ } as sym) + | (Sym { node = Symbol _; _ } as sym), (Imm _ as value) -> + raw_binop ty Add sym value + (* Associativity: (x + c1) + c2 -> x + (c1 + c2) *) + | Sym { node = Binop (_, Add, x, Imm v1); _ }, Imm v2 + | Imm v2, Sym { node = Binop (_, Add, x, Imm v1); _ } -> + let v_sum = Eval.binop ty Add v1 v2 in + binop ty Add x (Imm v_sum) + | _ -> raw_binop ty Add e1 e2 + end + | Sub -> begin + match (e1, e2) with + | _, Imm (Bitv bv) when Bitvector.eqz bv -> e1 + (* Ptr - Ptr *) + | ( Sym { node = Ptr { base = b1; offset = o1 }; _ } + , Sym { node = Ptr { base = b2; offset = o2 }; _ } ) -> + if Bitvector.equal b1 b2 then binop ty Sub o1 o2 + else raw_binop ty Sub e1 e2 + (* Ptr - x *) + | Sym { node = Ptr { base; offset }; _ }, x -> + let m = Bitvector.numbits base in + let offset = binop (Ty_bitv m) Sub offset x in + make (Ptr { base; offset }) + (* x - Ptr *) + | x, Sym { node = Ptr { base; offset }; _ } -> + let m = Bitvector.numbits base in + let base = value (Bitv base) in + binop ty Sub x (binop (Ty_bitv m) Add base offset) + (* Associativity: (x - c1) - c2 -> x - (c1 + c2) *) + | Sym { node = Binop (_, Sub, x, Imm v1); _ }, Imm v2 -> + let v_sum = Eval.binop ty Add v1 v2 in + binop ty Sub x (Imm v_sum) + | _ -> raw_binop ty Sub e1 e2 + end + | Mul -> begin + match (e1, e2) with + | Imm (Bitv bv), _ when Bitvector.eq_one bv -> e2 + | _, Imm (Bitv bv) when Bitvector.eq_one bv -> e1 + | Imm (Bitv bv), _ when Bitvector.eqz bv -> e1 + | _, Imm (Bitv bv) when Bitvector.eqz bv -> e2 + (* Associativity *) + | Sym { node = Binop (_, Mul, x, Imm v1); _ }, Imm v2 + | Imm v2, Sym { node = Binop (_, Mul, x, Imm v1); _ } -> + let v_prod = Eval.binop ty Mul v1 v2 in + binop ty Mul x (Imm v_prod) + | _ -> raw_binop ty Mul e1 e2 + end + | Rem -> begin + match (e1, e2) with + | Sym { node = Ptr { base; offset }; _ }, _ -> + let m = Bitvector.numbits base in + let rhs = value (Bitv base) in + let addr = binop (Ty_bitv m) Add rhs offset in + binop ty Rem addr e2 + | Imm (Bitv bv), _ when Bitvector.eqz bv -> e1 (* 0 % x = 0 *) + | _ -> raw_binop ty Rem e1 e2 + end + | Or -> begin + match (e1, e2) with + | Imm (Bitv bv), _ when Bitvector.eqz bv -> e2 + | _, Imm (Bitv bv) when Bitvector.eqz bv -> e1 + | _ -> raw_binop ty Or e1 e2 + end + | And -> begin + match (e1, e2) with + | Imm (Bitv bv), _ | _, Imm (Bitv bv) -> + if Bitvector.eqz bv then begin + (* Soundness guarantee *) + assert ( + match ty with Ty_bitv m -> m = Bitvector.numbits bv | _ -> false ); + Imm (Bitv bv) + end + else raw_binop ty op e1 e2 + | _ -> raw_binop ty op e1 e2 + end + | Div | DivU | RemU -> begin + match (e1, e2) with + | Imm (Bitv bv), _ when Bitvector.eqz bv -> e1 + | _ -> raw_binop ty op e1 e2 + end + (* Lists *) + | List_append -> begin + match (e1, e2) with + (* Identity *) + | ( (Sym { node = List []; _ } | Imm (List [])) + , (Sym { node = List _; _ } as sym_list) ) + | ( (Sym { node = List _; _ } as sym_list) + , (Sym { node = List []; _ } | Imm (List [])) ) -> + sym_list + | Sym { node = List symbolic_list; _ }, Imm (List concrete_list) -> + make (List (symbolic_list @ List.map value concrete_list)) + | Imm (List concrete_list), Sym { node = List symbolic_list; _ } -> + make (List (List.map value concrete_list @ symbolic_list)) + | Sym { node = List l1; _ }, Sym { node = List l2; _ } -> + make (List (l1 @ l2)) + | _ -> raw_binop ty op e1 e2 + end + | At -> begin + match (e1, e2) with + | Sym { node = List es; _ }, Imm (Int n) -> ( + (* TODO: use another datastructure? *) + match List.nth_opt es n with + | Some v -> v + | None -> assert false ) + | _ -> raw_binop ty At e1 e2 end - | List_cons, _, List es -> make (List (hte1 :: es)) - | List_append, List _, (List [] | Val (List [])) -> hte1 - | List_append, (List [] | Val (List [])), List _ -> hte2 - | List_append, List l0, Val (List l1) -> make (List (l0 @ List.map value l1)) - | List_append, Val (List l0), List l1 -> make (List (List.map value l0 @ l1)) - | List_append, List l0, List l1 -> make (List (l0 @ l1)) - | _ -> raw_binop ty op hte1 hte2 + | List_cons -> begin + match (e1, e2) with + | _, Sym { node = List es; _ } -> make (List (e1 :: es)) + | _ -> raw_binop ty List_cons e1 e2 + end + | _ -> raw_binop ty op e1 e2 + end let raw_triop ty op e1 e2 e3 = make (Triop (ty, op, e1, e2, e3)) [@@inline] let triop ty op e1 e2 e3 = - match (op, view e1, view e2, view e3) with - | Ty.Triop.Ite, Val True, _, _ -> e2 - | Ite, Val False, _, _ -> e3 - | op, Val v1, Val v2, Val v3 -> value (Eval.triop ty op v1 v2 v3) - | Ite, _, Triop (_, Ite, c2, r1, r2), Triop (_, Ite, _, _, _) -> - let else_ = raw_triop ty Ite e1 r2 e3 in + match (op, e1, e2, e3) with + | Ty.Triop.Ite, Imm True, _, _ -> e2 + | Ite, Imm False, _, _ -> e3 + | op, Imm v1, Imm v2, Imm v3 -> + value (Eval.triop ty op v1 v2 v3) + (* The Complex Rewrite: Lifting nested ITEs + Original: if e1 then (if c2 then r1 else r2) else (if ... ) + Target: if (e1 && c2) then r1 else (if e1 then r2 else e3) + *) + | ( Ite + , _ + , Sym { node = Triop (_, Ite, c2, r1, r2); _ } + , Sym { node = Triop (_, Ite, _, _, _); _ } ) -> + let else_part = raw_triop ty Ite e1 r2 e3 in let cond = binop Ty_bool And e1 c2 in - raw_triop ty Ite cond r1 else_ + raw_triop ty Ite cond r1 else_part | _ -> raw_triop ty op e1 e2 e3 let raw_relop ty op hte1 hte2 = make (Relop (ty, op, hte1, hte2)) [@@inline] -let rec relop ty (op : Ty.Relop.t) hte1 hte2 = - let both_phys_eq = phys_equal hte1 hte2 in - let can_be_shortcuted = +let rec relop ty (op : Ty.Relop.t) e1 e2 = + (* "zero cost" check upfront *) + if equal_of_t e1 e2 then match ty with - | Ty.Ty_bool | Ty_bitv _ | Ty_int | Ty_unit -> both_phys_eq - | Ty_fp _ | Ty_app | Ty_list | Ty_real | Ty_regexp | Ty_roundingMode - | Ty_none | Ty_str -> - false - in - match (op, view hte1, view hte2) with - | (Eq | Le | Ge | LeU | GeU), _, _ when can_be_shortcuted -> value True - | (Ne | Lt | Gt | LtU | GtU), _, _ when can_be_shortcuted -> value False - | op, Val v1, Val v2 -> value (if Eval.relop ty op v1 v2 then True else False) - | Ne, Val (Real v), _ | Ne, _, Val (Real v) -> - if Float.is_nan v || Float.is_infinite v then value True - else if both_phys_eq then value False - else raw_relop ty op hte1 hte2 - | _, Val (Real v), _ | _, _, Val (Real v) -> - if Float.is_nan v || Float.is_infinite v then value False - else - (* TODO: it is possible to add a shortcut when `both_phys_eq` *) - raw_relop ty op hte1 hte2 - | Eq, _, Val Nothing | Eq, Val Nothing, _ -> value False - | Ne, _, Val Nothing | Ne, Val Nothing, _ -> value True - | Eq, _, Val (App (`Op "symbol", [ Str _ ])) - | Eq, Val (App (`Op "symbol", [ Str _ ])), _ -> - value False - | Ne, _, Val (App (`Op "symbol", [ Str _ ])) - | Ne, Val (App (`Op "symbol", [ Str _ ])), _ -> - value True - | ( Eq - , Symbol ({ ty = Ty_fp prec1; _ } as s1) - , Symbol ({ ty = Ty_fp prec2; _ } as s2) ) - when both_phys_eq || (prec1 = prec2 && Symbol.equal s1 s2) -> - raw_unop Ty_bool Not (raw_unop (Ty_fp prec1) Is_nan hte1) - | Eq, Ptr { base = b1; offset = os1 }, Ptr { base = b2; offset = os2 } -> - if both_phys_eq then value True - else if Bitvector.equal b1 b2 then relop Ty_bool Eq os1 os2 - else value False - | Ne, Ptr { base = b1; offset = os1 }, Ptr { base = b2; offset = os2 } -> - if both_phys_eq then value False - else if Bitvector.equal b1 b2 then relop Ty_bool Ne os1 os2 - else value True - | ( (LtU | LeU) - , Ptr { base = b1; offset = os1 } - , Ptr { base = b2; offset = os2 } ) -> - if both_phys_eq then value True - else if Bitvector.equal b1 b2 then relop ty op os1 os2 - else - let b1 = Value.Bitv b1 in - let b2 = Value.Bitv b2 in - value (if Eval.relop ty op b1 b2 then True else False) - | ( op - , Val (Bitv _ as n) - , Ptr { base; offset = { node = Val (Bitv _ as o); _ } } ) -> - let base = Eval.binop (Ty_bitv 32) Add (Bitv base) o in - value (if Eval.relop ty op n base then True else False) - | op, Ptr { base; offset = { node = Val (Bitv _ as o); _ } }, Val (Bitv _ as n) - -> - let base = Eval.binop (Ty_bitv 32) Add (Bitv base) o in - value (if Eval.relop ty op base n then True else False) - | op, List l1, List l2 -> relop_list op l1 l2 - | Gt, _, _ -> relop ty Lt hte2 hte1 - | GtU, _, _ -> relop ty LtU hte2 hte1 - | Ge, _, _ -> relop ty Le hte2 hte1 - | GeU, _, _ -> relop ty LeU hte2 hte1 - | _, _, _ -> raw_relop ty op hte1 hte2 + | Ty.Ty_bool | Ty_bitv _ | Ty_int | Ty_unit -> begin + match op with + | Eq | Le | Ge | LeU | GeU -> Imm True + | Ne | Lt | Gt | LtU | GtU -> Imm False + end + | _ -> dispatch_relop ty op e1 e2 + else dispatch_relop ty op e1 e2 + +and dispatch_relop ty op e1 e2 = + match (e1, e2) with + | Imm v1, Imm v2 -> Imm (if Eval.relop ty op v1 v2 then True else False) + | _ -> begin + match (op, e1, e2) with + | Gt, _, _ -> relop ty Lt e2 e1 + | GtU, _, _ -> relop ty LtU e2 e1 + | Ge, _, _ -> relop ty Le e2 e1 + | GeU, _, _ -> relop ty LeU e2 e1 + | Ne, Imm (Real v), _ | Ne, _, Imm (Real v) -> + if Float.is_nan v || Float.is_infinite v then value True + else raw_relop ty op e1 e2 + | _, Imm (Real v), _ | _, _, Imm (Real v) -> + if Float.is_nan v || Float.is_infinite v then value False + else raw_relop ty op e1 e2 + | Eq, _, Imm Nothing | Eq, Imm Nothing, _ -> value False + | Ne, _, Imm Nothing | Ne, Imm Nothing, _ -> value True + | Eq, _, Imm (App (`Op "symbol", [ Str _ ])) + | Eq, Imm (App (`Op "symbol", [ Str _ ])), _ -> + value False + | Ne, _, Imm (App (`Op "symbol", [ Str _ ])) + | Ne, Imm (App (`Op "symbol", [ Str _ ])), _ -> + value True + | ( Eq + , Sym { node = Symbol ({ ty = Ty_fp prec1; _ } as s1); _ } + , Sym { node = Symbol ({ ty = Ty_fp prec2; _ } as s2); _ } ) + when prec1 = prec2 && Symbol.equal s1 s2 -> + raw_unop Ty_bool Not (raw_unop (Ty_fp prec1) Is_nan e1) + | Eq, Sym { node = Ptr p1; _ }, Sym { node = Ptr p2; _ } -> + if Bitvector.equal p1.base p2.base then + relop Ty_bool Eq p1.offset p2.offset + else value False + | Ne, Sym { node = Ptr p1; _ }, Sym { node = Ptr p2; _ } -> + if Bitvector.equal p1.base p2.base then + relop Ty_bool Ne p1.offset p2.offset + else value True + | (LtU | LeU), Sym { node = Ptr p1; _ }, Sym { node = Ptr p2; _ } -> + if Bitvector.equal p1.base p2.base then relop ty op p1.offset p2.offset + else + let b1 = Value.Bitv p1.base in + let b2 = Value.Bitv p2.base in + value (if Eval.relop ty op b1 b2 then True else False) + | ( op + , Imm (Bitv _ as n) + , Sym { node = Ptr { base; offset = Imm (Bitv _ as o) }; _ } ) -> + let base = Eval.binop (Ty_bitv 32) Add (Bitv base) o in + value (if Eval.relop ty op n base then True else False) + | ( op + , Sym { node = Ptr { base; offset = Imm (Bitv _ as o) }; _ } + , Imm (Bitv _ as n) ) -> + let base = Eval.binop (Ty_bitv 32) Add (Bitv base) o in + value (if Eval.relop ty op base n then True else False) + | op, Sym { node = List l1; _ }, Sym { node = List l2; _ } -> + relop_list op l1 l2 + | _, _, _ -> raw_relop ty op e1 e2 + end and relop_list op l1 l2 = - match (op, l1, l2) with - | Eq, [], [] -> value True - | Eq, _, [] | Eq, [], _ -> value False - | Eq, l1, l2 -> - if not (List.compare_lengths l1 l2 = 0) then value False - else - List.fold_left2 - (fun acc a b -> - binop Ty_bool And acc - @@ - match (ty a, ty b) with - | Ty_real, Ty_real -> relop Ty_real Eq a b - | _ -> relop Ty_bool Eq a b ) - (value True) l1 l2 - | Ne, _, _ -> unop Ty_bool Not @@ relop_list Eq l1 l2 - | (Lt | LtU | Gt | GtU | Le | LeU | Ge | GeU), _, _ -> assert false + match op with + | Ne -> + (* Implement Ne as !(Eq) *) + unop Ty_bool Not (relop_list Eq l1 l2) + | Eq -> + let rec loop l1 l2 acc = + match (l1, l2) with + | [], [] -> acc + | [], _ | _, [] -> Imm False + | x :: xs, y :: ys -> ( + let use_real = + match (ty x, ty y) with Ty_real, Ty_real -> true | _ -> false + in + let res = relop (if use_real then Ty_real else Ty_bool) Eq x y in + match res with + | Imm False -> Imm False + | Imm True -> loop xs ys acc + (* If Symbolic, accumulate it and continue. *) + | _ -> loop xs ys (binop Ty_bool And acc res) ) + in + loop l1 l2 (Imm True) + | _ -> assert false let raw_cvtop ty op hte = make (Cvtop (ty, op, hte)) [@@inline] -let rec cvtop theory op hte = - match (op, view hte) with - | Ty.Cvtop.String_to_re, _ -> raw_cvtop theory op hte - | _, Val v -> value (Eval.cvtop theory op v) - | String_to_float, Cvtop (Ty_real, ToString, hte) -> hte - | ( Reinterpret_float - , Cvtop (Ty_real, Reinterpret_int, { node = Symbol { ty = Ty_int; _ }; _ }) - ) -> - hte - | Zero_extend n, Ptr { base; offset } -> - let offset = cvtop theory op offset in - make (Ptr { base = Bitvector.zero_extend n base; offset }) - | WrapI64, Ptr { base; offset } -> - let offset = cvtop theory op offset in - make (Ptr { base = Bitvector.extract base ~high:31 ~low:0; offset }) - | WrapI64, Cvtop (Ty_bitv 64, Zero_extend 32, hte) -> - assert (Ty.equal theory (ty hte) && Ty.equal theory (Ty_bitv 32)); - hte - | _ -> raw_cvtop theory op hte +let rec cvtop theory op e = + match e with + | Imm v -> value (Eval.cvtop theory op v) + | Sym hte -> begin + match (op, view hte) with + | String_to_float, Cvtop (Ty_real, ToString, hte) -> hte + | ( Reinterpret_float + , Cvtop + (Ty_real, Reinterpret_int, Sym { node = Symbol { ty = Ty_int; _ }; _ }) + ) -> + e + | Zero_extend n, Ptr ptr -> + let offset = cvtop theory op ptr.offset in + make (Ptr { base = Bitvector.zero_extend n ptr.base; offset }) + | WrapI64, Ptr ptr -> + let offset = cvtop theory op ptr.offset in + make (Ptr { base = Bitvector.extract ptr.base ~high:31 ~low:0; offset }) + | WrapI64, Cvtop (Ty_bitv 64, Zero_extend 32, hte) -> + assert (Ty.equal theory (ty hte) && Ty.equal theory (Ty_bitv 32)); + hte + | _ -> raw_cvtop theory op e + end let raw_naryop ty op es = make (Naryop (ty, op, es)) [@@inline] let naryop ty op es = - if List.for_all (fun e -> match view e with Val _ -> true | _ -> false) es - then - let vs = - List.map (fun e -> match view e with Val v -> v | _ -> assert false) es - in - value (Eval.naryop ty op vs) - else - match (ty, op, List.map view es) with + let rec get_values acc v = + match v with + | [] -> Some (List.rev acc) + | Imm v :: tl -> get_values (v :: acc) tl + | _ -> None + in + match get_values [] es with + | Some vs -> value (Eval.naryop ty op vs) + | None -> ( + match (ty, op, es) with | ( Ty_str , Concat - , [ Naryop (Ty_str, Concat, l1); Naryop (Ty_str, Concat, l2) ] ) -> + , [ Sym { node = Naryop (Ty_str, Concat, l1); _ } + ; Sym { node = Naryop (Ty_str, Concat, l2); _ } + ] ) -> raw_naryop Ty_str Concat (l1 @ l2) - | Ty_str, Concat, [ Naryop (Ty_str, Concat, htes); hte ] -> - raw_naryop Ty_str Concat (htes @ [ make hte ]) - | Ty_str, Concat, [ hte; Naryop (Ty_str, Concat, htes) ] -> - raw_naryop Ty_str Concat (make hte :: htes) - | _ -> raw_naryop ty op es + | Ty_str, Concat, [ Sym { node = Naryop (Ty_str, Concat, htes); _ }; hte ] + -> + raw_naryop Ty_str Concat (htes @ [ hte ]) + | Ty_str, Concat, [ hte; Sym { node = Naryop (Ty_str, Concat, htes); _ } ] + -> + raw_naryop Ty_str Concat (hte :: htes) + | _ -> raw_naryop ty op es ) let[@inline] raw_extract (hte : t) ~(high : int) ~(low : int) : t = make (Extract (hte, high, low)) -let extract (hte : t) ~(high : int) ~(low : int) : t = - match (view hte, high, low) with - | Val (Bitv bv), high, low -> +let extract (e : t) ~(high : int) ~(low : int) : t = + match e with + | Imm (Bitv bv) -> let high = (high * 8) - 1 in let low = low * 8 in value (Bitv (Bitvector.extract bv ~high ~low)) - | ( Cvtop - ( _ - , (Zero_extend 24 | Sign_extend 24) - , ({ node = Symbol { ty = Ty_bitv 8; _ }; _ } as sym) ) - , 1 - , 0 ) -> - sym - | Concat (_, e), h, l when Ty.size (ty e) = h - l -> e - | Concat (e, _), 8, 4 when Ty.size (ty e) = 4 -> e - | _ -> - if high - low = Ty.size (ty hte) then hte else raw_extract hte ~high ~low + | Imm _ -> assert false + | Sym hte -> begin + match (view hte, high, low) with + | ( Cvtop + ( _ + , (Zero_extend 24 | Sign_extend 24) + , (Sym { node = Symbol { ty = Ty_bitv 8; _ }; _ } as sym) ) + , 1 + , 0 ) -> + sym + | Concat (_, e), h, l when Ty.size (ty e) = h - l -> e + | Concat (e, _), 8, 4 when Ty.size (ty e) = 4 -> e + | _ -> if high - low = Ty.size (ty e) then e else raw_extract e ~high ~low + end let raw_concat (msb : t) (lsb : t) : t = make (Concat (msb, lsb)) [@@inline] (* TODO: don't rebuild so many values it generates unecessary hc lookups *) let rec concat (msb : t) (lsb : t) : t = - match (view msb, view lsb) with - | Val (Bitv a), Val (Bitv b) -> value (Bitv (Bitvector.concat a b)) - | Val (Bitv _), Concat (({ node = Val (Bitv _); _ } as b), se) -> + match (msb, lsb) with + | Imm (Bitv a), Imm (Bitv b) -> value (Bitv (Bitvector.concat a b)) + | Imm (Bitv _), Sym { node = Concat ((Imm (Bitv _) as b), se); _ } -> raw_concat (concat msb b) se - | Extract (s1, h, m1), Extract (s2, m2, l) when equal s1 s2 && m1 = m2 -> + | Sym { node = Extract (s1, h, m1); _ }, Sym { node = Extract (s2, m2, l); _ } + when equal s1 s2 && m1 = m2 -> if h - l = Ty.size (ty s1) then s1 else raw_extract s1 ~high:h ~low:l - | Extract (_, _, _), Concat (({ node = Extract (_, _, _); _ } as e2), e3) -> + | ( Sym { node = Extract (_, _, _); _ } + , Sym { node = Concat ((Sym { node = Extract (_, _, _); _ } as e2), e3); _ } + ) -> raw_concat (concat msb e2) e3 | _ -> raw_concat msb lsb -let rec simplify_expr ?(in_relop = false) (hte : t) : t = - match view hte with - | Val _ | Symbol _ | Loc _ -> hte - | Ptr { base; offset } -> - let offset = simplify_expr ~in_relop offset in - if not in_relop then make (Ptr { base; offset }) - else binop (Ty_bitv 32) Add (value (Bitv base)) offset - | List es -> make @@ List (List.map (simplify_expr ~in_relop) es) - | App (x, es) -> make @@ App (x, List.map (simplify_expr ~in_relop) es) - | Unop (ty, op, e) -> - let e = simplify_expr ~in_relop e in - unop ty op e - | Binop (ty, op, e1, e2) -> - let e1 = simplify_expr ~in_relop e1 in - let e2 = simplify_expr ~in_relop e2 in - binop ty op e1 e2 - | Relop (ty, op, e1, e2) -> - let e1 = simplify_expr ~in_relop:true e1 in - let e2 = simplify_expr ~in_relop:true e2 in - relop ty op e1 e2 - | Triop (ty, op, c, e1, e2) -> - let c = simplify_expr ~in_relop c in - let e1 = simplify_expr ~in_relop e1 in - let e2 = simplify_expr ~in_relop e2 in - triop ty op c e1 e2 - | Cvtop (ty, op, e) -> - let e = simplify_expr ~in_relop e in - cvtop ty op e - | Naryop (ty, op, es) -> - let es = List.map (simplify_expr ~in_relop) es in - naryop ty op es - | Extract (s, high, low) -> - let s = simplify_expr ~in_relop s in - extract s ~high ~low - | Concat (e1, e2) -> - let msb = simplify_expr ~in_relop e1 in - let lsb = simplify_expr ~in_relop e2 in - concat msb lsb - | Binder _ -> - (* Not simplifying anything atm *) - hte +let rec simplify_expr ?(in_relop = false) (e : t) : t = + match e with + | Imm _ -> e + | Sym hte -> begin + match view hte with + | Symbol _ -> e + | Ptr { base; offset } -> + let offset = simplify_expr ~in_relop offset in + if not in_relop then make (Ptr { base; offset }) + else binop (Ty_bitv 32) Add (value (Bitv base)) offset + | List es -> make @@ List (List.map (simplify_expr ~in_relop) es) + | App (x, es) -> make @@ App (x, List.map (simplify_expr ~in_relop) es) + | Unop (ty, op, e) -> + let e = simplify_expr ~in_relop e in + unop ty op e + | Binop (ty, op, e1, e2) -> + let e1 = simplify_expr ~in_relop e1 in + let e2 = simplify_expr ~in_relop e2 in + binop ty op e1 e2 + | Relop (ty, op, e1, e2) -> + let e1 = simplify_expr ~in_relop:true e1 in + let e2 = simplify_expr ~in_relop:true e2 in + relop ty op e1 e2 + | Triop (ty, op, c, e1, e2) -> + let c = simplify_expr ~in_relop c in + let e1 = simplify_expr ~in_relop e1 in + let e2 = simplify_expr ~in_relop e2 in + triop ty op c e1 e2 + | Cvtop (ty, op, e) -> + let e = simplify_expr ~in_relop e in + cvtop ty op e + | Naryop (ty, op, es) -> + let es = List.map (simplify_expr ~in_relop) es in + naryop ty op es + | Extract (s, high, low) -> + let s = simplify_expr ~in_relop s in + extract s ~high ~low + | Concat (e1, e2) -> + let msb = simplify_expr ~in_relop e1 in + let lsb = simplify_expr ~in_relop e2 in + concat msb lsb + | Binder _ -> + (* Not simplifying anything atm *) + e + end module Cache = Hashtbl.Make (struct type nonrec t = t @@ -660,8 +783,8 @@ module Bool = struct open Ty let of_val = function - | Val True -> Some true - | Val False -> Some false + | Imm True -> Some true + | Imm False -> Some false | _ -> None let true_ = value True @@ -673,33 +796,32 @@ module Bool = struct let v b = to_val b [@@inline] let not b = - let bexpr = view b in - match of_val bexpr with + match of_val b with | Some b -> to_val (not b) | None -> ( - match bexpr with - | Unop (Ty_bool, Not, cond) -> cond + match b with + | Sym { node = Unop (Ty_bool, Not, cond); _ } -> cond | _ -> unop Ty_bool Not b ) let equal b1 b2 = - match (view b1, view b2) with - | Val True, Val True | Val False, Val False -> true_ + match (b1, b2) with + | Imm True, Imm True | Imm False, Imm False -> true_ | _ -> relop Ty_bool Eq b1 b2 let distinct b1 b2 = - match (view b1, view b2) with - | Val True, Val False | Val False, Val True -> true_ + match (b1, b2) with + | Imm True, Imm False | Imm False, Imm True -> true_ | _ -> relop Ty_bool Ne b1 b2 let and_ b1 b2 = - match (of_val (view b1), of_val (view b2)) with + match (of_val b1, of_val b2) with | Some true, _ -> b2 | _, Some true -> b1 | Some false, _ | _, Some false -> false_ | _ -> binop Ty_bool And b1 b2 let or_ b1 b2 = - match (of_val (view b1), of_val (view b2)) with + match (of_val b1, of_val b2) with | Some false, _ -> b2 | _, Some false -> b1 | Some true, _ | _, Some true -> true_ @@ -801,78 +923,92 @@ end module Smtlib = struct let rec pp fmt (hte : t) = - match view hte with - | Val v -> Value.Smtlib.pp fmt v - | Ptr _ -> assert false - | Loc _ -> assert false - | Symbol s -> Fmt.pf fmt "@[%a@]" Symbol.pp s - | List _ -> assert false - | App _ -> assert false - | Unop (ty, op, e) -> - Fmt.pf fmt "@[(%a@ %a)@]" Ty.Smtlib.pp_unop (ty, op) pp e - | Binop (ty, op, e1, e2) -> - Fmt.pf fmt "@[(%a@ %a@ %a)@]" Ty.Smtlib.pp_binop (ty, op) pp e1 pp - e2 - | Triop _ -> assert false - | Relop (ty, op, e1, e2) -> - Fmt.pf fmt "@[(%a@ %a@ %a)@]" Ty.Smtlib.pp_relop (ty, op) pp e1 pp - e2 - | Cvtop _ -> assert false - | Naryop _ -> assert false - | Extract _ -> assert false - | Concat _ -> assert false - | Binder _ -> assert false + match hte with + | Imm v -> Value.Smtlib.pp fmt v + | Sym hte -> begin + match view hte with + | Ptr _ -> assert false + | Symbol s -> Fmt.pf fmt "@[%a@]" Symbol.pp s + | List _ -> assert false + | App _ -> assert false + | Unop (ty, op, e) -> + Fmt.pf fmt "@[(%a@ %a)@]" Ty.Smtlib.pp_unop (ty, op) pp e + | Binop (ty, op, e1, e2) -> + Fmt.pf fmt "@[(%a@ %a@ %a)@]" Ty.Smtlib.pp_binop (ty, op) pp e1 + pp e2 + | Triop _ -> assert false + | Relop (ty, op, e1, e2) -> + Fmt.pf fmt "@[(%a@ %a@ %a)@]" Ty.Smtlib.pp_relop (ty, op) pp e1 + pp e2 + | Cvtop _ -> assert false + | Naryop _ -> assert false + | Extract _ -> assert false + | Concat _ -> assert false + | Binder _ -> assert false + end end let inline_symbol_values map e = let rec aux e = - match view e with - | Val _ | Loc _ -> e - | Symbol symbol -> Option.value ~default:e (Symbol.Map.find_opt symbol map) - | Ptr e -> - let offset = aux e.offset in - make @@ Ptr { e with offset } - | List vs -> - let vs = List.map aux vs in - list vs - | App (x, vs) -> - let vs = List.map aux vs in - app x vs - | Unop (ty, op, v) -> - let v = aux v in - unop ty op v - | Binop (ty, op, v1, v2) -> - let v1 = aux v1 in - let v2 = aux v2 in - binop ty op v1 v2 - | Triop (ty, op, v1, v2, v3) -> - let v1 = aux v1 in - let v2 = aux v2 in - let v3 = aux v3 in - triop ty op v1 v2 v3 - | Cvtop (ty, op, v) -> - let v = aux v in - cvtop ty op v - | Relop (ty, op, v1, v2) -> - let v1 = aux v1 in - let v2 = aux v2 in - relop ty op v1 v2 - | Naryop (ty, op, vs) -> - let vs = List.map aux vs in - naryop ty op vs - | Extract (e, high, low) -> - let e = aux e in - extract e ~high ~low - | Concat (e1, e2) -> - let e1 = aux e1 in - let e2 = aux e2 in - concat e1 e2 - | Binder (b, vars, e) -> - let e = aux e in - binder b vars e + match e with + | Imm _ -> e + | Sym hte -> begin + match view hte with + | Symbol symbol -> + Option.value ~default:e (Symbol.Map.find_opt symbol map) + | Ptr e -> + let offset = aux e.offset in + make @@ Ptr { e with offset } + | List vs -> + let vs = List.map aux vs in + list vs + | App (x, vs) -> + let vs = List.map aux vs in + app x vs + | Unop (ty, op, v) -> + let v = aux v in + unop ty op v + | Binop (ty, op, v1, v2) -> + let v1 = aux v1 in + let v2 = aux v2 in + binop ty op v1 v2 + | Triop (ty, op, v1, v2, v3) -> + let v1 = aux v1 in + let v2 = aux v2 in + let v3 = aux v3 in + triop ty op v1 v2 v3 + | Cvtop (ty, op, v) -> + let v = aux v in + cvtop ty op v + | Relop (ty, op, v1, v2) -> + let v1 = aux v1 in + let v2 = aux v2 in + relop ty op v1 v2 + | Naryop (ty, op, vs) -> + let vs = List.map aux vs in + naryop ty op vs + | Extract (e, high, low) -> + let e = aux e in + extract e ~high ~low + | Concat (e1, e2) -> + let e1 = aux e1 in + let e2 = aux e2 in + concat e1 e2 + | Binder (b, vars, e) -> + let e = aux e in + binder b vars e + end in aux e +module Key = struct + type nonrec t = t + + let to_int hte = hash hte + + let compare x y = Prelude.compare (to_int x) (to_int y) +end + module Set = struct include Set.Make (Key) @@ -890,32 +1026,35 @@ module Set = struct let get_symbols (set : t) = let tbl = Hashtbl.create 64 in let rec symbols hte = - match view hte with - | Val _ | Loc _ -> () - | Ptr { offset; _ } -> symbols offset - | Symbol s -> Hashtbl.replace tbl s () - | List es -> List.iter symbols es - | App (_, es) -> List.iter symbols es - | Unop (_, _, e1) -> symbols e1 - | Binop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Triop (_, _, e1, e2, e3) -> - symbols e1; - symbols e2; - symbols e3 - | Relop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Cvtop (_, _, e) -> symbols e - | Naryop (_, _, es) -> List.iter symbols es - | Extract (e, _, _) -> symbols e - | Concat (e1, e2) -> - symbols e1; - symbols e2 - | Binder (_, vars, e) -> - List.iter symbols vars; - symbols e + match hte with + | Imm _ -> () + | Sym hte -> begin + match view hte with + | Ptr { offset; _ } -> symbols offset + | Symbol s -> Hashtbl.replace tbl s () + | List es -> List.iter symbols es + | App (_, es) -> List.iter symbols es + | Unop (_, _, e1) -> symbols e1 + | Binop (_, _, e1, e2) -> + symbols e1; + symbols e2 + | Triop (_, _, e1, e2, e3) -> + symbols e1; + symbols e2; + symbols e3 + | Relop (_, _, e1, e2) -> + symbols e1; + symbols e2 + | Cvtop (_, _, e) -> symbols e + | Naryop (_, _, es) -> List.iter symbols es + | Extract (e, _, _) -> symbols e + | Concat (e1, e2) -> + symbols e1; + symbols e2 + | Binder (_, vars, e) -> + List.iter symbols vars; + symbols e + end in iter symbols set; Hashtbl.fold (fun k () acc -> k :: acc) tbl [] @@ -932,16 +1071,16 @@ module Set = struct end let rec split_conjunctions (e : t) : Set.t = - match view e with - | Binop (Ty_bool, And, e1, e2) -> + match e with + | Sym { node = Binop (Ty_bool, And, e1, e2); _ } -> let s1 = split_conjunctions e1 in let s2 = split_conjunctions e2 in Set.union s1 s2 | _ -> Set.singleton e let rec split_disjunctions (e : t) : Set.t = - match view e with - | Binop (Ty_bool, Or, e1, e2) -> + match e with + | Sym { node = Binop (Ty_bool, Or, e1, e2); _ } -> let s1 = split_disjunctions e1 in let s2 = split_disjunctions e2 in Set.union s1 s2 diff --git a/src/smtml/expr_intf.ml b/src/smtml/expr_intf.ml index a024eefd..237a7129 100644 --- a/src/smtml/expr_intf.ml +++ b/src/smtml/expr_intf.ml @@ -12,16 +12,16 @@ module type S = sig (** {1 Expression Types} *) (** A term in the abstract syntax tree. *) - type t = expr Hc.hash_consed + type t = + | Imm of Value.t (** A constant value. *) + | Sym of expr Hc.hash_consed (** The different types of expressions. *) and expr = private - | Val of Value.t (** A constant value. *) | Ptr of { base : Bitvector.t (** Base address. *) ; offset : t (** Offset from base. *) } - | Loc of Loc.t (** Abstract location *) | Symbol of Symbol.t (** A symbolic variable. *) | List of t list (** A list of expressions. *) | App of Symbol.t * t list (** Function application. *) @@ -38,7 +38,7 @@ module type S = sig (** {1 Constructors and Accessors} *) (** [view term] extracts the underlying expression from a term. *) - val view : t -> expr + val view : expr Hc.hash_consed -> expr (** [hash term] computes the hash of a term. *) val hash : t -> int @@ -96,9 +96,6 @@ module type S = sig address and offset. *) val ptr : int32 -> t -> t - (** [loc l] constructs an abstract location *) - val loc : Loc.t -> t - (** [list l] constructs a list expression with the given list of expressions *) val list : t list -> t diff --git a/src/smtml/expr_raw.mli b/src/smtml/expr_raw.mli index 9a5af4fe..a8ad4403 100644 --- a/src/smtml/expr_raw.mli +++ b/src/smtml/expr_raw.mli @@ -2,5 +2,4 @@ (* Copyright (C) 2023-2025 formalsec *) (* Written by the Smtml programmers *) -include - Expr_intf.S with type expr = Expr.expr and type t = Expr.expr Hc.hash_consed +include Expr_intf.S diff --git a/src/smtml/feature_extraction.ml b/src/smtml/feature_extraction.ml index 21270c7e..659668cc 100644 --- a/src/smtml/feature_extraction.ml +++ b/src/smtml/feature_extraction.ml @@ -152,23 +152,25 @@ let string_of_naryop (naryop : Ty.Naryop.t) : string = | Concat -> "Concat" | Regexp_union -> "Regexp_union" -let string_of_expr_kind (e : Expr.expr) _ty : string = +let string_of_expr_kind (e : Expr.t) _ty : string = match e with - | Val _ -> "Val" - | Ptr _ -> "Ptr" - | Loc _ -> "Loc" - | Symbol _ -> "Symbol" - | List _ -> "List" - | App _ -> "App" - | Unop _ -> "Unop" - | Binop _ -> "Binop" - | Triop _ -> "Triop" - | Relop _ -> "Relop" - | Cvtop _ -> "Cvtop" - | Naryop _ -> "Naryop" - | Extract _ -> "Extract" - | Concat _ -> "Concat" - | Binder _ -> "Binder" + | Imm _ -> "Val" + | Sym hte -> begin + match Expr.view hte with + | Ptr _ -> "Ptr" + | Symbol _ -> "Symbol" + | List _ -> "List" + | App _ -> "App" + | Unop _ -> "Unop" + | Binop _ -> "Binop" + | Triop _ -> "Triop" + | Relop _ -> "Relop" + | Cvtop _ -> "Cvtop" + | Naryop _ -> "Naryop" + | Extract _ -> "Extract" + | Concat _ -> "Concat" + | Binder _ -> "Binder" + end (* Define all constructors you want to track *) let ctor_names = @@ -344,73 +346,76 @@ let extract_feats : Expr.t -> int StringMap.t = feats in let rec visit depth feats (e : Expr.t) = - let feats = incr_feat feats (string_of_expr_kind e.node (Expr.ty e)) in - match e.node with - | Val _ | Symbol _ -> (depth, feats) - | Ptr { offset; _ } -> visit (depth + 1) feats offset - | List lst -> - List.fold_left - (fun (depth, feats) e -> - let depth', feats = visit depth feats e in - (Int.max depth depth', feats) ) - (depth + 1, feats) - lst - | Naryop (ty, naryop, lst) -> - let feats = incr_feat feats (string_of_ty ty) in - let feats = incr_feat feats (string_of_naryop naryop) in - List.fold_left - (fun (depth, feats) e -> - let depth', feats = visit depth feats e in - (Int.max depth depth', feats) ) - (depth + 1, feats) - lst - | App (_, lst) -> - List.fold_left - (fun (depth, feats) e -> - let depth', feats = visit depth feats e in - (Int.max depth depth', feats) ) - (depth + 1, feats) - lst - | Unop (ty, unop, t) -> - let feats = incr_feat feats (string_of_ty ty) in - let feats = incr_feat feats (string_of_unop unop) in - visit (depth + 1) feats t - | Cvtop (ty, cvtop, t) -> - let feats = incr_feat feats (string_of_ty ty) in - let feats = incr_feat feats (string_of_cvtop cvtop) in - visit (depth + 1) feats t - | Extract (t, _, _) -> visit (depth + 1) feats t - | Binop (ty, binop, e1, e2) -> - let feats = incr_feat feats (string_of_ty ty) in - let feats = incr_feat feats (string_of_binop binop) in - let depth1, feats = visit (depth + 1) feats e1 in - let depth2, feats = visit (depth + 1) feats e2 in - (Int.max depth1 depth2, feats) - | Relop (ty, relop, e1, e2) -> - let feats = incr_feat feats (string_of_ty ty) in - let feats = incr_feat feats (string_of_relop relop) in - let depth1, feats = visit (depth + 1) feats e1 in - let depth2, feats = visit (depth + 1) feats e2 in - (Int.max depth1 depth2, feats) - | Concat (e1, e2) -> - let depth1, feats = visit (depth + 1) feats e1 in - let depth2, feats = visit (depth + 1) feats e2 in - (Int.max depth1 depth2, feats) - | Triop (ty, triop, e1, e2, e3) -> - let feats = incr_feat feats (string_of_ty ty) in - let feats = incr_feat feats (string_of_triop triop) in - let depth1, feats = visit (depth + 1) feats e1 in - let depth2, feats = visit (depth + 1) feats e2 in - let depth3, feats = visit (depth + 1) feats e3 in - (Int.max (Int.max depth1 depth2) depth3, feats) - | Binder (_, lst, t) -> - List.fold_left - (fun (depth, feats) e -> - let depth', feats = visit depth feats e in - (Int.max depth depth', feats) ) - (depth + 1, feats) - (t :: lst) - | Loc _ -> assert false + let feats = incr_feat feats (string_of_expr_kind e (Expr.ty e)) in + match e with + | Imm _ -> (depth, feats) + | Sym hte -> begin + match Expr.view hte with + | Symbol _ -> (depth, feats) + | Ptr { offset; _ } -> visit (depth + 1) feats offset + | List lst -> + List.fold_left + (fun (depth, feats) e -> + let depth', feats = visit depth feats e in + (Int.max depth depth', feats) ) + (depth + 1, feats) + lst + | Naryop (ty, naryop, lst) -> + let feats = incr_feat feats (string_of_ty ty) in + let feats = incr_feat feats (string_of_naryop naryop) in + List.fold_left + (fun (depth, feats) e -> + let depth', feats = visit depth feats e in + (Int.max depth depth', feats) ) + (depth + 1, feats) + lst + | App (_, lst) -> + List.fold_left + (fun (depth, feats) e -> + let depth', feats = visit depth feats e in + (Int.max depth depth', feats) ) + (depth + 1, feats) + lst + | Unop (ty, unop, t) -> + let feats = incr_feat feats (string_of_ty ty) in + let feats = incr_feat feats (string_of_unop unop) in + visit (depth + 1) feats t + | Cvtop (ty, cvtop, t) -> + let feats = incr_feat feats (string_of_ty ty) in + let feats = incr_feat feats (string_of_cvtop cvtop) in + visit (depth + 1) feats t + | Extract (t, _, _) -> visit (depth + 1) feats t + | Binop (ty, binop, e1, e2) -> + let feats = incr_feat feats (string_of_ty ty) in + let feats = incr_feat feats (string_of_binop binop) in + let depth1, feats = visit (depth + 1) feats e1 in + let depth2, feats = visit (depth + 1) feats e2 in + (Int.max depth1 depth2, feats) + | Relop (ty, relop, e1, e2) -> + let feats = incr_feat feats (string_of_ty ty) in + let feats = incr_feat feats (string_of_relop relop) in + let depth1, feats = visit (depth + 1) feats e1 in + let depth2, feats = visit (depth + 1) feats e2 in + (Int.max depth1 depth2, feats) + | Concat (e1, e2) -> + let depth1, feats = visit (depth + 1) feats e1 in + let depth2, feats = visit (depth + 1) feats e2 in + (Int.max depth1 depth2, feats) + | Triop (ty, triop, e1, e2, e3) -> + let feats = incr_feat feats (string_of_ty ty) in + let feats = incr_feat feats (string_of_triop triop) in + let depth1, feats = visit (depth + 1) feats e1 in + let depth2, feats = visit (depth + 1) feats e2 in + let depth3, feats = visit (depth + 1) feats e3 in + (Int.max (Int.max depth1 depth2) depth3, feats) + | Binder (_, lst, t) -> + List.fold_left + (fun (depth, feats) e -> + let depth', feats = visit depth feats e in + (Int.max depth depth', feats) ) + (depth + 1, feats) + (t :: lst) + end in fun expr -> let depth, feats = visit 1 StringMap.empty expr in diff --git a/src/smtml/interpret.ml b/src/smtml/interpret.ml index b6e4b365..000b2d7e 100644 --- a/src/smtml/interpret.ml +++ b/src/smtml/interpret.ml @@ -84,12 +84,12 @@ module Make (Solver : Solver_intf.S) = struct loop (eval stmt { state with stmts } ~no_strict_status) ~no_strict_status let parse_status (t : Expr.t) : [ `Sat | `Unsat | `Unknown ] option = - match Expr.view t with - | App ({ name = Simple ":status"; _ }, [ st ]) -> ( - match Expr.view st with - | Symbol { name = Simple "sat"; _ } -> Some `Sat - | Symbol { name = Simple "unsat"; _ } -> Some `Unsat - | Symbol { name = Simple "unknown"; _ } -> Some `Unknown + match t with + | Expr.Sym { node = App ({ name = Simple ":status"; _ }, [ st ]); _ } -> ( + match st with + | Sym { node = Symbol { name = Simple "sat"; _ }; _ } -> Some `Sat + | Sym { node = Symbol { name = Simple "unsat"; _ }; _ } -> Some `Unsat + | Sym { node = Symbol { name = Simple "unknown"; _ }; _ } -> Some `Unknown | _ -> Log.debug (fun k -> k "Unrecognised status value: %a" Expr.pp st); None ) diff --git a/src/smtml/loc.ml b/src/smtml/loc.ml deleted file mode 100644 index 6f6d61cb..00000000 --- a/src/smtml/loc.ml +++ /dev/null @@ -1,20 +0,0 @@ -(* SPDX-License-Identifier: MIT *) -(* Copyright (C) 2023-2024 formalsec *) -(* Written by the Smtml programmers *) - -type t = int - -let fresh = - let next = ref 0 in - fun () -> - let id = !next in - incr next; - id - -let compare = Int.compare - -let equal a b = compare a b = 0 - -let hash a = a - -let pp = Fmt.int diff --git a/src/smtml/loc.mli b/src/smtml/loc.mli deleted file mode 100644 index f20052a4..00000000 --- a/src/smtml/loc.mli +++ /dev/null @@ -1,11 +0,0 @@ -type t - -val fresh : unit -> t - -val compare : t -> t -> int - -val equal : t -> t -> bool - -val hash : t -> int - -val pp : t Fmt.t diff --git a/src/smtml/mappings.ml b/src/smtml/mappings.ml index 69d9bc7a..3c7ae712 100644 --- a/src/smtml/mappings.ml +++ b/src/smtml/mappings.ml @@ -687,126 +687,145 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct Ty.pp ty ty let get_rounding_mode ctx rm = - match Expr.view rm with - | Symbol { name = Simple ("roundNearestTiesToEven" | "RNE"); _ } -> + match rm with + | Expr.Sym + { node = + Symbol { name = Simple ("roundNearestTiesToEven" | "RNE"); _ } + ; _ + } -> (ctx, M.Float.Rounding_mode.rne) - | Symbol { name = Simple ("roundNearestTiesToAway" | "RNA"); _ } -> + | Sym + { node = + Symbol { name = Simple ("roundNearestTiesToAway" | "RNA"); _ } + ; _ + } -> (ctx, M.Float.Rounding_mode.rna) - | Symbol { name = Simple ("roundTowardPositive" | "RTP"); _ } -> + | Sym + { node = Symbol { name = Simple ("roundTowardPositive" | "RTP"); _ } + ; _ + } -> (ctx, M.Float.Rounding_mode.rtp) - | Symbol { name = Simple ("roundTowardNegative" | "RTN"); _ } -> + | Sym + { node = Symbol { name = Simple ("roundTowardNegative" | "RTN"); _ } + ; _ + } -> (ctx, M.Float.Rounding_mode.rtn) - | Symbol { name = Simple ("roundTowardZero" | "RTZ"); _ } -> + | Sym + { node = Symbol { name = Simple ("roundTowardZero" | "RTZ"); _ }; _ } + -> (ctx, M.Float.Rounding_mode.rtz) - | Symbol rm -> make_symbol ctx rm + | Sym { node = Symbol rm; _ } -> make_symbol ctx rm | _ -> Fmt.failwith "unknown rouding mode: %a" Expr.pp rm - let rec encode_expr ctx (hte : Expr.t) : symbol_ctx * M.term = - match Expr.view hte with - | Val value -> (ctx, v value) - | Ptr { base; offset } -> - let base = v (Bitv base) in - let ctx, offset = encode_expr ctx offset in - (ctx, I32.binop Add base offset) - | Symbol { name = Simple "re.all"; _ } -> (ctx, M.Re.all ()) - | Symbol { name = Simple "re.none"; _ } -> (ctx, M.Re.none ()) - | Symbol { name = Simple "re.allchar"; _ } -> (ctx, M.Re.allchar ()) - | Symbol sym -> make_symbol ctx sym - (* FIXME: add a way to support building these expressions without apps *) - | App ({ name = Simple "fp.add"; _ }, [ rm; a; b ]) -> - let ctx, a = encode_expr ctx a in - let ctx, b = encode_expr ctx b in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.add ~rm a b) - | App ({ name = Simple "fp.sub"; _ }, [ rm; a; b ]) -> - let ctx, a = encode_expr ctx a in - let ctx, b = encode_expr ctx b in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.sub ~rm a b) - | App ({ name = Simple "fp.mul"; _ }, [ rm; a; b ]) -> - let ctx, a = encode_expr ctx a in - let ctx, b = encode_expr ctx b in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.mul ~rm a b) - | App ({ name = Simple "fp.div"; _ }, [ rm; a; b ]) -> - let ctx, a = encode_expr ctx a in - let ctx, b = encode_expr ctx b in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.div ~rm a b) - | App ({ name = Simple "fp.fma"; _ }, [ rm; a; b; c ]) -> - let ctx, a = encode_expr ctx a in - let ctx, b = encode_expr ctx b in - let ctx, c = encode_expr ctx c in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.fma ~rm a b c) - | App ({ name = Simple "fp.sqrt"; _ }, [ rm; a ]) -> - let ctx, a = encode_expr ctx a in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.sqrt ~rm a) - | App ({ name = Simple "fp.roundToIntegral"; _ }, [ rm; a ]) -> - let ctx, a = encode_expr ctx a in - let ctx, rm = get_rounding_mode ctx rm in - (ctx, M.Float.round_to_integral ~rm a) - | App (sym, args) -> - let name = - match Symbol.name sym with - | Simple name -> name - | Indexed _ -> - Fmt.failwith "Unsupported uninterpreted application of: %a" - Symbol.pp sym - in - let ty = get_type @@ Symbol.type_of sym in - let tys = List.map (fun e -> get_type @@ Expr.ty e) args in - let ctx, arguments = encode_exprs ctx args in - let sym = M.Func.make name tys ty in - (ctx, M.Func.apply sym arguments) - | Unop (ty, op, e) -> - let ctx, e = encode_expr ctx e in - (ctx, unop ty op e) - | Binop (ty, op, e1, e2) -> - let ctx, e1 = encode_expr ctx e1 in - let ctx, e2 = encode_expr ctx e2 in - (ctx, binop ty op e1 e2) - | Triop (ty, op, e1, e2, e3) -> - let ctx, e1 = encode_expr ctx e1 in - let ctx, e2 = encode_expr ctx e2 in - let ctx, e3 = encode_expr ctx e3 in - (ctx, triop ty op e1 e2 e3) - | Relop (ty, op, e1, e2) -> - let ctx, e1 = encode_expr ctx e1 in - let ctx, e2 = encode_expr ctx e2 in - (ctx, relop ty op e1 e2) - | Cvtop (ty, op, e) -> - let ctx, e = encode_expr ctx e in - (ctx, cvtop ty op e) - | Naryop (ty, op, es) -> - let ctx, es = - List.fold_left - (fun (ctx, es) e -> - let ctx, e = encode_expr ctx e in - (ctx, e :: es) ) - (ctx, []) es - in - (* This is needed so arguments don't end up out of order in the operator *) - let es = List.rev es in - (ctx, naryop ty op es) - | Extract (e, h, l) -> - let ctx, e = encode_expr ctx e in - (ctx, M.Bitv.extract e ~high:((h * 8) - 1) ~low:(l * 8)) - | Concat (e1, e2) -> - let ctx, e1 = encode_expr ctx e1 in - let ctx, e2 = encode_expr ctx e2 in - (ctx, M.Bitv.concat e1 e2) - | Binder (Forall, vars, body) -> - let ctx, vars = encode_exprs ctx vars in - let ctx, body = encode_expr ctx body in - (ctx, M.forall vars body) - | Binder (Exists, vars, body) -> - let ctx, vars = encode_exprs ctx vars in - let ctx, body = encode_expr ctx body in - (ctx, M.exists vars body) - | List _ | Binder _ | Loc _ -> - Fmt.failwith "Cannot encode expression: %a" Expr.pp hte + let rec encode_expr ctx (e : Expr.t) : symbol_ctx * M.term = + match e with + | Imm value -> (ctx, v value) + | Sym hte -> begin + match Expr.view hte with + | Ptr { base; offset } -> + let base = v (Bitv base) in + let ctx, offset = encode_expr ctx offset in + (ctx, I32.binop Add base offset) + | Symbol { name = Simple "re.all"; _ } -> (ctx, M.Re.all ()) + | Symbol { name = Simple "re.none"; _ } -> (ctx, M.Re.none ()) + | Symbol { name = Simple "re.allchar"; _ } -> (ctx, M.Re.allchar ()) + | Symbol sym -> make_symbol ctx sym + (* FIXME: add a way to support building these expressions without apps *) + | App ({ name = Simple "fp.add"; _ }, [ rm; a; b ]) -> + let ctx, a = encode_expr ctx a in + let ctx, b = encode_expr ctx b in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.add ~rm a b) + | App ({ name = Simple "fp.sub"; _ }, [ rm; a; b ]) -> + let ctx, a = encode_expr ctx a in + let ctx, b = encode_expr ctx b in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.sub ~rm a b) + | App ({ name = Simple "fp.mul"; _ }, [ rm; a; b ]) -> + let ctx, a = encode_expr ctx a in + let ctx, b = encode_expr ctx b in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.mul ~rm a b) + | App ({ name = Simple "fp.div"; _ }, [ rm; a; b ]) -> + let ctx, a = encode_expr ctx a in + let ctx, b = encode_expr ctx b in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.div ~rm a b) + | App ({ name = Simple "fp.fma"; _ }, [ rm; a; b; c ]) -> + let ctx, a = encode_expr ctx a in + let ctx, b = encode_expr ctx b in + let ctx, c = encode_expr ctx c in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.fma ~rm a b c) + | App ({ name = Simple "fp.sqrt"; _ }, [ rm; a ]) -> + let ctx, a = encode_expr ctx a in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.sqrt ~rm a) + | App ({ name = Simple "fp.roundToIntegral"; _ }, [ rm; a ]) -> + let ctx, a = encode_expr ctx a in + let ctx, rm = get_rounding_mode ctx rm in + (ctx, M.Float.round_to_integral ~rm a) + | App (sym, args) -> + let name = + match Symbol.name sym with + | Simple name -> name + | Indexed _ -> + Fmt.failwith "Unsupported uninterpreted application of: %a" + Symbol.pp sym + in + let ty = get_type @@ Symbol.type_of sym in + let tys = List.map (fun e -> get_type @@ Expr.ty e) args in + let ctx, arguments = encode_exprs ctx args in + let sym = M.Func.make name tys ty in + (ctx, M.Func.apply sym arguments) + | Unop (ty, op, e) -> + let ctx, e = encode_expr ctx e in + (ctx, unop ty op e) + | Binop (ty, op, e1, e2) -> + let ctx, e1 = encode_expr ctx e1 in + let ctx, e2 = encode_expr ctx e2 in + (ctx, binop ty op e1 e2) + | Triop (ty, op, e1, e2, e3) -> + let ctx, e1 = encode_expr ctx e1 in + let ctx, e2 = encode_expr ctx e2 in + let ctx, e3 = encode_expr ctx e3 in + (ctx, triop ty op e1 e2 e3) + | Relop (ty, op, e1, e2) -> + let ctx, e1 = encode_expr ctx e1 in + let ctx, e2 = encode_expr ctx e2 in + (ctx, relop ty op e1 e2) + | Cvtop (ty, op, e) -> + let ctx, e = encode_expr ctx e in + (ctx, cvtop ty op e) + | Naryop (ty, op, es) -> + let ctx, es = + List.fold_left + (fun (ctx, es) e -> + let ctx, e = encode_expr ctx e in + (ctx, e :: es) ) + (ctx, []) es + in + (* This is needed so arguments don't end up out of order in the operator *) + let es = List.rev es in + (ctx, naryop ty op es) + | Extract (e, h, l) -> + let ctx, e = encode_expr ctx e in + (ctx, M.Bitv.extract e ~high:((h * 8) - 1) ~low:(l * 8)) + | Concat (e1, e2) -> + let ctx, e1 = encode_expr ctx e1 in + let ctx, e2 = encode_expr ctx e2 in + (ctx, M.Bitv.concat e1 e2) + | Binder (Forall, vars, body) -> + let ctx, vars = encode_exprs ctx vars in + let ctx, body = encode_expr ctx body in + (ctx, M.forall vars body) + | Binder (Exists, vars, body) -> + let ctx, vars = encode_exprs ctx vars in + let ctx, body = encode_expr ctx body in + (ctx, M.exists vars body) + | List _ | Binder _ -> + Fmt.failwith "Cannot encode expression: %a" Expr.pp e + end and encode_exprs ctx (es : Expr.t list) : symbol_ctx * M.term list = let ctx, exprs = diff --git a/src/smtml/rewrite.ml b/src/smtml/rewrite.ml index 4734114a..ec662af4 100644 --- a/src/smtml/rewrite.ml +++ b/src/smtml/rewrite.ml @@ -55,118 +55,123 @@ let rewrite_ty unknown_ty tys = | ty -> ty (** Propagates types in [type_map] and inlines [Let_in] binders *) -let rec rewrite_expr (type_map, expr_map) hte = - debug "rewrite_expr: %a@." (fun k -> k Expr.pp hte); - match Expr.view hte with - | Val _ | Loc _ -> hte - | Ptr { base; offset } -> - let base = Bitvector.to_int32 base in - Expr.ptr base (rewrite_expr (type_map, expr_map) offset) - | Symbol sym -> begin - (* Avoid rewriting well-typed symbols already *) - if not (Ty.equal Ty_none (Symbol.type_of sym)) then hte - else - match Symb_map.find_opt sym type_map with - | None -> ( - match Symb_map.find_opt sym expr_map with +let rec rewrite_expr (type_map, expr_map) e = + debug "rewrite_expr: %a@." (fun k -> k Expr.pp e); + match e with + | Imm _ -> e + | Sym hte -> begin + match Expr.view hte with + | Ptr { base; offset } -> + let base = Bitvector.to_int32 base in + Expr.ptr base (rewrite_expr (type_map, expr_map) offset) + | Symbol sym -> begin + (* Avoid rewriting well-typed symbols already *) + if not (Ty.equal Ty_none (Symbol.type_of sym)) then e + else + match Symb_map.find_opt sym type_map with + | None -> ( + match Symb_map.find_opt sym expr_map with + | None -> Fmt.failwith "Undefined symbol: %a" Symbol.pp sym + | Some expr -> expr ) + | Some ty -> Expr.symbol { sym with ty } + end + | List htes -> Expr.list (List.map (rewrite_expr (type_map, expr_map)) htes) + | App + ( ( { name = Simple ("fp.add" | "fp.sub" | "fp.mul" | "fp.div"); _ } as + sym ) + , [ rm; a; b ] ) -> + let rm = rewrite_expr (type_map, expr_map) rm in + let a = rewrite_expr (type_map, expr_map) a in + let b = rewrite_expr (type_map, expr_map) b in + let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b ] in + Expr.app { sym with ty } [ rm; a; b ] + | App (({ name = Simple "fp.fma"; _ } as sym), [ rm; a; b; c ]) -> + let rm = rewrite_expr (type_map, expr_map) rm in + let a = rewrite_expr (type_map, expr_map) a in + let b = rewrite_expr (type_map, expr_map) b in + let c = rewrite_expr (type_map, expr_map) c in + let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b; Expr.ty c ] in + Expr.app { sym with ty } [ rm; a; b; c ] + | App + ( ({ name = Simple ("fp.sqrt" | "fp.roundToIntegral"); _ } as sym) + , [ rm; a ] ) -> + let rm = rewrite_expr (type_map, expr_map) rm in + let a = rewrite_expr (type_map, expr_map) a in + let ty = rewrite_ty Ty_none [ Expr.ty a ] in + Expr.app { sym with ty } [ rm; a ] + | App (sym, htes) -> + let sym = + match Symb_map.find_opt sym type_map with | None -> Fmt.failwith "Undefined symbol: %a" Symbol.pp sym - | Some expr -> expr ) - | Some ty -> Expr.symbol { sym with ty } - end - | List htes -> Expr.list (List.map (rewrite_expr (type_map, expr_map)) htes) - | App - ( ({ name = Simple ("fp.add" | "fp.sub" | "fp.mul" | "fp.div"); _ } as sym) - , [ rm; a; b ] ) -> - let rm = rewrite_expr (type_map, expr_map) rm in - let a = rewrite_expr (type_map, expr_map) a in - let b = rewrite_expr (type_map, expr_map) b in - let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b ] in - Expr.app { sym with ty } [ rm; a; b ] - | App (({ name = Simple "fp.fma"; _ } as sym), [ rm; a; b; c ]) -> - let rm = rewrite_expr (type_map, expr_map) rm in - let a = rewrite_expr (type_map, expr_map) a in - let b = rewrite_expr (type_map, expr_map) b in - let c = rewrite_expr (type_map, expr_map) c in - let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b; Expr.ty c ] in - Expr.app { sym with ty } [ rm; a; b; c ] - | App - ( ({ name = Simple ("fp.sqrt" | "fp.roundToIntegral"); _ } as sym) - , [ rm; a ] ) -> - let rm = rewrite_expr (type_map, expr_map) rm in - let a = rewrite_expr (type_map, expr_map) a in - let ty = rewrite_ty Ty_none [ Expr.ty a ] in - Expr.app { sym with ty } [ rm; a ] - | App (sym, htes) -> - let sym = - match Symb_map.find_opt sym type_map with - | None -> Fmt.failwith "Undefined symbol: %a" Symbol.pp sym - | Some ty -> { sym with ty } - in - Expr.app sym (List.map (rewrite_expr (type_map, expr_map)) htes) - | Unop (ty, op, hte) -> - let hte = rewrite_expr (type_map, expr_map) hte in - let ty = rewrite_ty ty [ Expr.ty hte ] in - Expr.unop ty op hte - | Binop (ty, op, hte1, hte2) -> - let hte1 = rewrite_expr (type_map, expr_map) hte1 in - let hte2 = rewrite_expr (type_map, expr_map) hte2 in - let ty = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in - Expr.binop ty op hte1 hte2 - | Triop (ty, op, hte1, hte2, hte3) -> - let hte1 = rewrite_expr (type_map, expr_map) hte1 in - let hte2 = rewrite_expr (type_map, expr_map) hte2 in - let hte3 = rewrite_expr (type_map, expr_map) hte3 in - Expr.triop ty op hte1 hte2 hte3 - | Relop (ty, ((Eq | Ne) as op), hte1, hte2) when not (Ty.equal Ty_none ty) -> - let hte1 = rewrite_expr (type_map, expr_map) hte1 in - let hte2 = rewrite_expr (type_map, expr_map) hte2 in - Expr.relop ty op hte1 hte2 - | Relop (ty, op, hte1, hte2) -> - let hte1 = rewrite_expr (type_map, expr_map) hte1 in - let hte2 = rewrite_expr (type_map, expr_map) hte2 in - let ty = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in - Expr.relop ty op hte1 hte2 - | Cvtop (ty, op, hte) -> - let hte = rewrite_expr (type_map, expr_map) hte in - let ty = rewrite_ty ty [ Expr.ty hte ] in - Expr.cvtop ty op hte - | Naryop (ty, op, htes) -> - let htes = List.map (rewrite_expr (type_map, expr_map)) htes in - Expr.naryop ty op htes - | Extract (hte, h, l) -> - let hte = rewrite_expr (type_map, expr_map) hte in - Expr.extract hte ~high:h ~low:l - | Concat (hte1, hte2) -> - let hte1 = rewrite_expr (type_map, expr_map) hte1 in - let hte2 = rewrite_expr (type_map, expr_map) hte2 in - Expr.concat hte1 hte2 - | Binder (Let_in, vars, e) -> - (* Then, we rewrite the types of the expr *) - let expr_map = - List.fold_left - (fun map e -> - match Expr.view e with - | App (sym, [ e ]) -> - (* Searches the outer expr_map. Because I don't think the list of + | Some ty -> { sym with ty } + in + Expr.app sym (List.map (rewrite_expr (type_map, expr_map)) htes) + | Unop (ty, op, hte) -> + let hte = rewrite_expr (type_map, expr_map) hte in + let ty = rewrite_ty ty [ Expr.ty hte ] in + Expr.unop ty op hte + | Binop (ty, op, hte1, hte2) -> + let hte1 = rewrite_expr (type_map, expr_map) hte1 in + let hte2 = rewrite_expr (type_map, expr_map) hte2 in + let ty = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in + Expr.binop ty op hte1 hte2 + | Triop (ty, op, hte1, hte2, hte3) -> + let hte1 = rewrite_expr (type_map, expr_map) hte1 in + let hte2 = rewrite_expr (type_map, expr_map) hte2 in + let hte3 = rewrite_expr (type_map, expr_map) hte3 in + Expr.triop ty op hte1 hte2 hte3 + | Relop (ty, ((Eq | Ne) as op), hte1, hte2) when not (Ty.equal Ty_none ty) + -> + let hte1 = rewrite_expr (type_map, expr_map) hte1 in + let hte2 = rewrite_expr (type_map, expr_map) hte2 in + Expr.relop ty op hte1 hte2 + | Relop (ty, op, hte1, hte2) -> + let hte1 = rewrite_expr (type_map, expr_map) hte1 in + let hte2 = rewrite_expr (type_map, expr_map) hte2 in + let ty = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in + Expr.relop ty op hte1 hte2 + | Cvtop (ty, op, hte) -> + let hte = rewrite_expr (type_map, expr_map) hte in + let ty = rewrite_ty ty [ Expr.ty hte ] in + Expr.cvtop ty op hte + | Naryop (ty, op, htes) -> + let htes = List.map (rewrite_expr (type_map, expr_map)) htes in + Expr.naryop ty op htes + | Extract (hte, h, l) -> + let hte = rewrite_expr (type_map, expr_map) hte in + Expr.extract hte ~high:h ~low:l + | Concat (hte1, hte2) -> + let hte1 = rewrite_expr (type_map, expr_map) hte1 in + let hte2 = rewrite_expr (type_map, expr_map) hte2 in + Expr.concat hte1 hte2 + | Binder (Let_in, vars, e) -> + (* Then, we rewrite the types of the expr *) + let expr_map = + List.fold_left + (fun map e -> + match e with + | Expr.Sym { node = App (sym, [ e ]); _ } -> + (* Searches the outer expr_map. Because I don't think the list of var bindings are in scope for themselves? *) - let e = rewrite_expr (type_map, expr_map) e in - Symb_map.add sym e map - | _ -> assert false ) - expr_map vars - in - rewrite_expr (type_map, expr_map) e - | Binder (((Forall | Exists) as quantifier), vars, e) -> - let type_map, vars = - List.fold_left - (fun (map, vars) e -> - match Expr.view e with - | App (sym, [ e ]) -> - let ty = Expr.ty e in - (Symb_map.add sym ty map, Expr.symbol { sym with ty } :: vars) - | _ -> assert false ) - (type_map, []) vars - in - Expr.binder quantifier vars (rewrite_expr (type_map, expr_map) e) + let e = rewrite_expr (type_map, expr_map) e in + Symb_map.add sym e map + | _ -> assert false ) + expr_map vars + in + rewrite_expr (type_map, expr_map) e + | Binder (((Forall | Exists) as quantifier), vars, e) -> + let type_map, vars = + List.fold_left + (fun (map, vars) e -> + match e with + | Expr.Sym { node = App (sym, [ e ]); _ } -> + let ty = Expr.ty e in + (Symb_map.add sym ty map, Expr.symbol { sym with ty } :: vars) + | _ -> assert false ) + (type_map, []) vars + in + Expr.binder quantifier vars (rewrite_expr (type_map, expr_map) e) + end (** Acccumulates types of symbols in [type_map] and calls rewrite_expr *) let rewrite_cmd type_map cmd = diff --git a/src/smtml/smtlib.ml b/src/smtml/smtlib.ml index a0db7ecf..1a717bca 100644 --- a/src/smtml/smtlib.ml +++ b/src/smtml/smtlib.ml @@ -129,8 +129,8 @@ module Term = struct Expr.value (Bitv (Bitvector.make int (len - 2))) let colon ?loc (symbol : t) (term : t) : t = - match Expr.view symbol with - | Symbol s -> + match symbol with + | Expr.Sym { node = Symbol s; _ } -> (* Hack: var bindings are 1 argument lambdas *) Log.debug (fun k -> k "colon: unknown '%a' making app" Expr.pp symbol); Expr.app s [ term ] @@ -139,14 +139,18 @@ module Term = struct Expr.pp term let make_fp_binop symbol (op : Ty.Binop.t) rm a b = - match Expr.view rm with - | Symbol { name = Simple "roundNearestTiesToEven"; _ } -> + match rm with + | Expr.Sym + { node = Symbol { name = Simple "roundNearestTiesToEven"; _ }; _ } -> Expr.raw_binop Ty_none op a b | _ -> Expr.app symbol [ rm; a; b ] let apply ?loc (id : t) (args : t list) : t = - match Expr.view id with - | Symbol ({ namespace = Term; name = Simple name; _ } as symbol) -> begin + match id with + | Expr.Sym + { node = Symbol ({ namespace = Term; name = Simple name; _ } as symbol) + ; _ + } -> begin match (name, args) with | "-", [ a ] -> Expr.raw_unop Ty_none Neg a | "not", [ a ] -> Expr.raw_unop Ty_bool Not a @@ -229,11 +233,7 @@ module Term = struct | "bvsge", [ a; b ] -> Expr.raw_relop Ty_none Ge a b | "bvuge", [ a; b ] -> Expr.raw_relop Ty_none GeU a b | "concat", [ a; b ] -> Expr.raw_concat a b - | ( "fp" - , [ { node = Val (Bitv sign); _ } - ; { node = Val (Bitv eb); _ } - ; { node = Val (Bitv i); _ } - ] ) -> + | "fp", [ Imm (Bitv sign); Imm (Bitv eb); Imm (Bitv i) ] -> let fp = Bitvector.(concat sign (concat eb i)) in let fp_sz = Bitvector.numbits fp in if fp_sz = 32 then Expr.value (Num (F32 (Bitvector.to_int32 fp))) @@ -253,12 +253,13 @@ module Term = struct | "fp.mul", [ rm; a; b ] -> make_fp_binop symbol Mul rm a b | "fp.div", [ rm; a; b ] -> make_fp_binop symbol Div rm a b | ( "fp.sqrt" - , [ { node = Symbol { name = Simple "roundNearestTiesToEven"; _ }; _ } + , [ Sym + { node = Symbol { name = Simple "roundNearestTiesToEven"; _ }; _ } ; a ] ) -> Expr.raw_unop Ty_none Sqrt a | "fp.rem", [ a; b ] -> Expr.raw_binop Ty_none Rem a b - | "fp.roundToIntegral", [ rm; a ] -> begin + | "fp.roundToIntegral", [ Sym rm; a ] -> begin match Expr.view rm with | Symbol { name = Simple "roundNearestTiesToEven"; _ } -> Expr.raw_unop Ty_none Nearest a @@ -281,10 +282,13 @@ module Term = struct Log.debug (fun k -> k "apply: unknown %a making app" Symbol.pp symbol); Expr.app symbol args end - | Symbol ({ name = Simple _; namespace = Attr; _ } as attr) -> + | Sym + { node = Symbol ({ name = Simple _; namespace = Attr; _ } as attr); _ } + -> Log.debug (fun k -> k "apply: unknown %a making app" Symbol.pp attr); Expr.app attr args - | Symbol { name = Indexed { basename; indices }; _ } -> begin + | Sym { node = Symbol { name = Indexed { basename; indices }; _ }; _ } -> + begin match (basename, indices, args) with | "extract", [ h; l ], [ a ] -> let high = @@ -315,17 +319,18 @@ module Term = struct Expr.raw_unop Ty_regexp (Regexp_loop (i1, i2)) a | ( "to_fp" , [ "11"; "53" ] - , [ { node = - Symbol { name = Simple ("roundNearestTiesToEven" | "RNE"); _ } - ; _ - } + , [ Sym + { node = + Symbol { name = Simple ("roundNearestTiesToEven" | "RNE"); _ } + ; _ + } ; a ] ) -> Expr.raw_cvtop (Ty_fp 64) PromoteF32 a | _ -> Fmt.failwith "%acould not parse indexed app: %a" pp_loc loc Expr.pp id end - | Symbol id -> + | Sym { node = Symbol id; _ } -> Log.debug (fun k -> k "apply: unknown %a making app" Symbol.pp id); Expr.app id args | _ -> @@ -372,12 +377,16 @@ module Statement = struct let datatypes ?loc:_ = assert false let fun_decl ?loc id ts1 ts2 return_sort = - match (id, ts1, ts2, Expr.view return_sort) with - | id, [], [], Symbol sort -> Declare_const { id; sort } - | id, [], args, Symbol sort -> + match (id, ts1, ts2, return_sort) with + | id, [], [], Expr.Sym { node = Symbol sort; _ } -> + Declare_const { id; sort } + | id, [], args, Sym { node = Symbol sort; _ } -> let args = List.map - (fun e -> match Expr.view e with Symbol s -> s | _ -> assert false) + (fun e -> + match e with + | Expr.Sym { node = Symbol s; _ } -> s + | _ -> assert false ) args in Declare_fun { id; args; sort } diff --git a/test/unit/test_expr.ml b/test/unit/test_expr.ml index 2f4f62f5..4b76e2df 100644 --- a/test/unit/test_expr.ml +++ b/test/unit/test_expr.ml @@ -21,15 +21,21 @@ let test_hc _ = let open Infix in let length0 = Expr.Hc.length () in let ty = Ty.Ty_bitv 32 in - assert (symbol "x" ty == symbol "x" ty); - assert (symbol "x" ty != symbol "y" ty); + assert ( + match (symbol "x" ty, symbol "x" ty) with + | Sym a, Sym b -> a == b + | _ -> false ); + assert ( + match (symbol "x" ty, symbol "y" ty) with + | Sym a, Sym b -> a != b + | _ -> false ); let left_a = symbol "x" ty in let right_a = symbol "y" ty in let left_b = symbol "x" ty in let right_b = symbol "y" ty in let a = Expr.binop ty Add left_a right_a in let b = Expr.binop ty Add left_b right_b in - assert (a == b); + assert (match (a, b) with Sym a, Sym b -> a == b | _ -> false); (* There should be only 3 elements added in the hashcons table: *) (* 1. x *) (* 2. y *) @@ -257,7 +263,7 @@ let test_binop_simplifications _ = (binop32 Mul x (int32 4l)); check (binop32 Mul (int32 2l) (binop32 Mul x (int32 2l))) - (binop32 Mul (int32 4l) x) + (binop32 Mul x (int32 4l)) let test_binop = [ "test_binop_int" >:: test_binop_int @@ -533,7 +539,7 @@ let test_simplify_assoc _ = check (Expr.simplify sym) (Expr.raw_binop Ty_int Add x (int 13)); let binary = Expr.raw_binop Ty_int Add x (int 10) in let sym = Expr.raw_binop Ty_int Add (int 3) binary in - check (Expr.simplify sym) (Expr.raw_binop Ty_int Add (int 13) x) + check (Expr.simplify sym) (Expr.raw_binop Ty_int Add x (int 13)) let test_simplify_extract_i8 _ = let open Infix in