From 9170b6e3beeca6529955d4a1fa66452f2fddc45b Mon Sep 17 00:00:00 2001 From: Ariel Kellison Date: Sat, 7 Mar 2026 07:52:10 -0500 Subject: [PATCH 1/2] Add documentation, refactor proofs, fix sum_is_finite --- .gitignore | 2 + C/matrix_model.v | 3 - C/spec_densemat.v | 2 +- Makefile.coq.local | 2 +- _CoqProject | 2 + accuracy_proofs/common.v | 936 +++++++++------ accuracy_proofs/dot_acc.v | 279 +++-- accuracy_proofs/dot_acc_lemmas.v | 1753 +++++++++++++++------------- accuracy_proofs/dotprod_model.v | 1088 +++++++++++------- accuracy_proofs/float_acc_lems.v | 1267 +++++++++++--------- accuracy_proofs/fma_dot_acc.v | 261 +++-- accuracy_proofs/fma_is_finite.v | 717 ++++++------ accuracy_proofs/gemm_acc.v | 558 ++++++--- accuracy_proofs/gemv_acc.v | 393 ++++--- accuracy_proofs/libvalidsdp.v | 26 +- accuracy_proofs/mv_mathcomp.v | 1854 +++++++++++++++++------------- accuracy_proofs/real_lemmas.v | 57 + accuracy_proofs/sum_acc.v | 532 +++++---- accuracy_proofs/sum_is_finite.v | 467 +++++--- accuracy_proofs/sum_model.v | 1031 ++++++++++------- accuracy_proofs/vec_op_acc.v | 521 +++++---- accuracy_proofs/vecnorm_acc.v | 158 ++- header.html | 23 + html/index.html | 2 +- 24 files changed, 7105 insertions(+), 4829 deletions(-) create mode 100644 accuracy_proofs/real_lemmas.v create mode 100644 header.html diff --git a/.gitignore b/.gitignore index 7f79b27..7014311 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ Makefile.coq Makefile.coq.conf CoqMakefile CoqMakefile.conf +.history/* .*.aux .*.d *.a @@ -48,3 +49,4 @@ html/LAProof.* html/genindex.html html/*.css html/toc.html +.vscode/settings.json diff --git a/C/matrix_model.v b/C/matrix_model.v index 72ccf00..b1b52c1 100644 --- a/C/matrix_model.v +++ b/C/matrix_model.v @@ -32,9 +32,6 @@ Unset Printing Implicit Defensive. Set Bullet Behavior "Strict Subproofs". (* end show *) -Definition neg_zero {t}: ftype t := Binary.B754_zero (fprec t) (femax t) true. 
- - Lemma map_inj: forall [T1 T2] (f: T1 -> T2) (H: injective f) (al bl: list T1), map f al = map f bl -> al=bl. Proof. induction al; destruct bl; simpl; intros; inversion H0; clear H0; subst; auto. diff --git a/C/spec_densemat.v b/C/spec_densemat.v index ebf1637..2f279f4 100644 --- a/C/spec_densemat.v +++ b/C/spec_densemat.v @@ -353,7 +353,7 @@ Definition densemat_clear_spec := program correct, there's no need for dynamic bounds checking. The precondition of the function enforces that [0 <= i < m] and [0 <= j < n]. It does so by construction - of the dependently typed value [X], where the last component is a pair [(i: 'I_[m], j: 'I[n]). + of the dependently typed value [X], where the last component is a pair [(i: 'I_[m], j: 'I[n])]. *) Definition densematn_get_spec := DECLARE _densematn_get diff --git a/Makefile.coq.local b/Makefile.coq.local index 95be9e3..79ac1ea 100644 --- a/Makefile.coq.local +++ b/Makefile.coq.local @@ -1,5 +1,5 @@ # COQDOC publishing -COQDOCEXTRAFLAGS= -g --no-lib-name --index genindex --lib-subtitles --interpolate -utf8 +COQDOCEXTRAFLAGS= -g --no-lib-name --with-header header.html --index genindex --lib-subtitles --interpolate -utf8 accuracy: accuracy_proofs/export.vo C: C/verif_alloc.vo C/verif_sparse.vo C/verif_sparse_byrows.vo C/VSU_densemat.vo C/verif_build_csr.vo diff --git a/_CoqProject b/_CoqProject index 1c46633..db64212 100644 --- a/_CoqProject +++ b/_CoqProject @@ -2,6 +2,7 @@ COQEXTRAFLAGS = "-w -notation-overridden,-ambiguous-paths,-overwriting-delimiting-key,-notation-incompatible-prefix" accuracy_proofs/preamble.v +accuracy_proofs/real_lemmas.v accuracy_proofs/common.v accuracy_proofs/float_acc_lems.v accuracy_proofs/dotprod_model.v @@ -11,6 +12,7 @@ accuracy_proofs/dot_acc_lemmas.v accuracy_proofs/dot_acc.v accuracy_proofs/vecnorm_acc.v accuracy_proofs/fma_dot_acc.v +accuracy_proofs/sum_is_finite.v accuracy_proofs/fma_is_finite.v accuracy_proofs/mv_mathcomp.v accuracy_proofs/gemv_acc.v diff --git a/accuracy_proofs/common.v 
b/accuracy_proofs/common.v index 526376f..45c30e7 100644 --- a/accuracy_proofs/common.v +++ b/accuracy_proofs/common.v @@ -1,205 +1,391 @@ -(* This file contains basic definitions and lemmas common to all other files in - the repository. *) - -From LAProof.accuracy_proofs Require Import preamble. - -Definition rounded t r:= -(Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) r). - -Definition neg_zero {t: type} := Binary.B754_zero (fprec t) (femax t) true. -Definition pos_zero {t: type} := Binary.B754_zero (fprec t) (femax t) false. -Definition Beq_dec_t {t: type} - (x y : ftype t) : {x = y} + {x <> y}. - apply (Beq_dec (fprec t) (femax t) x y). - Defined. +(** * Common Definitions and Lemmas for Floating-Point Accuracy Proofs + + This file provides foundational definitions and lemmas used throughout + the LAProof library. It establishes the core vocabulary + for reasoning about floating-point rounding, error bounds, and + accumulation of rounding errors in numerical computations. + + The main concepts defined here are: + + - _Floating-point rounding_: The function [rounded] captures + round-to-nearest-even (RNE) rounding in radix-2 floating point, + parameterized by a floating-point type << t >> that fixes the precision + [fprec t] and exponent range [femax t]. + + - _Special floating-point values_: << neg_zero >> and << pos_zero >> + represent the IEEE 754 signed zero values for a given type << t >>, + useful when reasoning about sign-sensitive floating-point operations. + + - _Zero testing_: << iszero >> is a boolean predicate on floating-point + values that returns << true >> exactly when the value is an IEEE 754 + zero (of either sign). The lemma [iszeroR_iszeroF] connects this + structural test to the real-number interpretation [FT2R x = 0]. + + - _Counting nonzeros_: << nnzF >> (resp. << nnzR >>) counts the number of + nonzero elements in a list of floating-point (resp. real) values. 
+ Several lemmas formalize the behavior in the case when the nonzero + count is zero. + + - _Default relative error bound_ [default_rel]: + + << default_rel = (1/2) * 2^(-(fprec t) + 1) >> + + This is << 2^(-(fprec t)) >>, the unit roundoff for round-to-nearest (half of machine epsilon) for type << t >>. + It satisfies [default_rel > 0] and numerous inequalities involving + [1 + default_rel] and its powers that are needed in error analysis. + + - _Default absolute error bound_ [default_abs]: + + << default_abs = (1/2) * 2^(3 - femax t - fprec t) >> + + This bounds the absolute error introduced when rounding a subnormal + result. It satisfies << default_abs <= default_rel >>. + + - _Relative error accumulation factor_ [g n]: + + << g n = (1 + default_rel)^n - 1 >> + + This bounds the accumulated relative rounding error after << n >> + floating-point operations. Key properties include [g_pos], + [le_g_Sn] (monotonicity), and the recurrence + [one_plus_d_mul_g], which expresses how one additional rounding + step advances the bound. + + - _Mixed absolute/relative error accumulation factor_ [g1 n1 n2]: + + << g1 n1 n2 = INR n1 * default_abs * (1 + g n2) >> + + This bounds the accumulated absolute rounding error when + << n1 >> subnormal rounding errors, each of size [default_abs], are + each amplified by a factor of at most << (1 + default_rel)^n2 >> in subsequent + operations. Numerous lemmas establish how [g1] grows as + its arguments increase, supporting inductive error analyses. + + _Hint database_: All positivity, monotonicity, and ordering lemmas + for [default_rel], [default_abs], [g], and [g1] are registered in + the [commonDB] hint database for use with [auto] and [eauto]. +*) + +From LAProof.accuracy_proofs Require Import preamble real_lemmas. + +Import Zorder. + +(** ** Global Definitions and Setup + + The definitions below are independent of any particular floating-point + type and are available without opening [Section WithType]. 
*) + +Definition rounded t r := + Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) r. + +Definition neg_zero {t : type} := Binary.B754_zero (fprec t) (femax t) true. +Definition pos_zero {t : type} := Binary.B754_zero (fprec t) (femax t) false. + +Definition Beq_dec_t {t : type} (x y : ftype t) : {x = y} + {x <> y} := + Beq_dec (fprec t) (femax t) x y. Create HintDb commonDB discriminated. -Global Hint Resolve - bpow_gt_0 bpow_ge_0 pos_INR lt_0_INR pow_le: commonDB. +Global Hint Resolve + bpow_gt_0 bpow_ge_0 pos_INR lt_0_INR pow_le : commonDB. Delimit Scope R_scope with Re. Open Scope R_scope. -Lemma rev_list_rev: @rev = @List.rev. +Lemma rev_list_rev : @rev = @List.rev. Proof. -apply FunctionalExtensionality.functional_extensionality_dep; intro T. -apply FunctionalExtensionality.functional_extensionality; intro al. -unfold rev. -change @catrev with rev_append. -rewrite rev_append_rev app_nil_r //. + apply FunctionalExtensionality.functional_extensionality_dep; intro T. + apply FunctionalExtensionality.functional_extensionality; intro al. + unfold rev. + change @catrev with rev_append. + rewrite rev_append_rev app_nil_r //. Qed. -Lemma size_not_empty_nat {A} (l: seq A) : l <> [] -> Nat.le 1 (size l). +Lemma size_not_empty_nat {A} (l : seq A) : l <> [] -> Nat.le 1 (size l). Proof. -intros. -destruct l; try congruence; compute; lia. + intros. + destruct l; try congruence; compute; lia. Qed. - Section WithType. -Context {NAN: FPCore.Nans} {t : type}. - -Definition iszero {t} (x: ftype t) : bool := - match x with Binary.B754_zero _ _ _ => true | _ => false end. +Context {NAN : FPCore.Nans} {t : type}. -Lemma iszeroR_iszeroF: forall x: ftype t, Binary.is_finite x -> FT2R x = R0 -> iszero x. -Proof. -destruct x; intros; auto. -exfalso; clear - H0. -rewrite /FT2R /= /Defs.F2R /= in H0. -destruct s; simpl in H0. -- -assert (IZR (Z.neg m) * bpow Zaux.radix2 e < 0)%Re; [clear | lra]; rewrite /IZR. 
-apply Rmult_neg_pos. -move :(IPR_gt_0 m); lra. -move :(bpow_gt_0 Zaux.radix2 e); lra. -- -assert (0 < IZR (Z.pos m) * bpow Zaux.radix2 e)%Re; [clear | lra]; rewrite /IZR. -apply Rmult_pos_pos. -apply IPR_gt_0. -apply bpow_gt_0. -Qed. - -(** Number of nonzeros *) +(** ** Zero Predicates *) -Definition nnzF: seq (ftype t) -> nat := - count (fun x => negb (iszero x)). +Definition iszero {t} (x : ftype t) : bool := + match x with + | Binary.B754_zero _ _ _ => true + | _ => false + end. -Definition nnzR: seq R -> nat := - count (fun x => negb (0 == x)). - -Lemma nnzF_zero l: (nnzF l == 0%nat) = (size l == count iszero l). +Lemma iszeroR_iszeroF : forall x : ftype t, + Binary.is_finite x -> FT2R x = R0 -> iszero x. +Proof. + destruct x; intros; auto. + exfalso; clear - H0. + rewrite /FT2R /= /Defs.F2R /= in H0. + destruct s; simpl in H0. + - assert (IZR (Z.neg m) * bpow Zaux.radix2 e < 0)%Re; [clear | lra]. + rewrite /IZR. + apply Rmult_neg_pos. + + move: (IPR_gt_0 m); lra. + + move: (bpow_gt_0 Zaux.radix2 e); lra. + - assert (0 < IZR (Z.pos m) * bpow Zaux.radix2 e)%Re; [clear | lra]. + rewrite /IZR. + apply Rmult_pos_pos. + + apply IPR_gt_0. + + apply bpow_gt_0. +Qed. + +(** *** Number of Nonzeros *) + +Definition nnzF : seq (ftype t) -> nat := + count (fun x => negb (iszero x)). + +Definition nnzR : seq R -> nat := + count (fun x => negb (0 == x)). + +Lemma nnzF_zero l : (nnzF l == 0%nat) = (size l == count iszero l). Proof. -rewrite /nnzF (eq_sym (size l)) -all_count. -elim l => // a l' IH. -case :a => //. + rewrite /nnzF (eq_sym (size l)) -all_count. + elim l => // a l' IH. + case: a => //. Qed. -Lemma nnzR_zero l: (nnzR l == 0%nat) = (size l == count (eq_op 0) l). +Lemma nnzR_zero l : (nnzR l == 0%nat) = (size l == count (eq_op 0) l). Proof. -rewrite /nnzR (eq_sym (size l)) -all_count. -elim l => // a l' IH /=. -rewrite -{}IH /=; case (0 == a); lia. + rewrite /nnzR (eq_sym (size l)) -all_count. + elim l => // a l' IH /=. + rewrite -{}IH /=; case (0 == a); lia. Qed. 
-Lemma nnzF_lemma l: (nnzF l == 0%nat) = all iszero l. +Lemma nnzF_lemma l : (nnzF l == 0%nat) = all iszero l. Proof. -rewrite !nnzF_zero all_count (eq_sym (size _)) //. + rewrite !nnzF_zero all_count (eq_sym (size _)) //. Qed. -Lemma nnzR_lemma l: (nnzR l == 0%nat) = (all (eq_op R0) l). +Lemma nnzR_lemma l : (nnzR l == 0%nat) = (all (eq_op R0) l). Proof. -rewrite !nnzR_zero all_count (eq_sym (size _)) //. + rewrite !nnzR_zero all_count (eq_sym (size _)) //. Qed. -Lemma nnzF_is_zero_cons a l: nnzF (a::l) == 0%nat -> nnzF l == 0%nat. +Lemma nnzF_is_zero_cons a l : nnzF (a :: l) == 0%nat -> nnzF l == 0%nat. Proof. -rewrite !nnzF_lemma (all_cat _ [:: a] l) => /andP [H H'] //. + rewrite !nnzF_lemma (all_cat _ [:: a] l) => /andP [H H'] //. Qed. -Lemma nnzR_is_zero_cons a l: nnzR (a::l) == 0%nat -> nnzR l == 0%nat. +Lemma nnzR_is_zero_cons a l : nnzR (a :: l) == 0%nat -> nnzR l == 0%nat. Proof. -rewrite !nnzR_lemma (all_cat _ [:: a] l) => /andP [H H'] //. + rewrite !nnzR_lemma (all_cat _ [:: a] l) => /andP [H H'] //. Qed. -Lemma nnzR_cons l : - nnzR (0%Re :: l) == nnzR l. +Lemma nnzR_cons l : nnzR (0%Re :: l) == nnzR l. Proof. -rewrite /= eq_refl //. + rewrite /= eq_refl //. Qed. +(** ** Error Bound Constants + + Fundamental error parameters for a floating-point type << t >>. + All ordering and positivity lemmas are collected in [commonDB]. *) + Definition default_rel : R := / 2 * Raux.bpow Zaux.radix2 (- fprec t + 1). Definition default_abs : R := / 2 * Raux.bpow Zaux.radix2 (3 - femax t - fprec t). -Lemma default_rel_sep_0 : - default_rel <> R0. + +Lemma default_rel_sep_0 : default_rel <> R0. Proof. -apply Rabs_lt_pos; -rewrite Rabs_pos_eq; [apply Rmult_lt_0_compat; try Lra.nra | - apply Rmult_le_pos; try Lra.nra]; auto with commonDB. + apply Rabs_lt_pos. + rewrite Rabs_pos_eq; + [ apply Rmult_lt_0_compat; try Lra.nra + | apply Rmult_le_pos; try Lra.nra ]; + auto with commonDB. Qed. Hint Resolve default_rel_sep_0 : commonDB. -Lemma default_rel_gt_0 : - 0 < default_rel. 
-Proof. apply Rmult_lt_0_compat; try nra; -auto with commonDB. +Lemma default_rel_gt_0 : 0 < default_rel. +Proof. + apply Rmult_lt_0_compat; try nra; auto with commonDB. Qed. Hint Resolve default_rel_gt_0 : commonDB. - -Lemma default_rel_ge_0 : - 0 <= default_rel. + +Lemma default_rel_ge_0 : 0 <= default_rel. Proof. apply Rlt_le; auto with commonDB. Qed. Hint Resolve default_rel_ge_0 : commonDB. -Lemma default_rel_plus_1_ge_1: - 1 <= 1 + default_rel. -Proof. -rewrite Rplus_comm. -apply Rcomplements.Rle_minus_l; field_simplify. -auto with commonDB. +Lemma default_rel_plus_1_ge_1 : 1 <= 1 + default_rel. +Proof. + rewrite Rplus_comm. + apply Rcomplements.Rle_minus_l; field_simplify. + auto with commonDB. Qed. Hint Resolve default_rel_plus_1_ge_1 : commonDB. -Lemma default_rel_plus_0_ge_1: - 0 <= 1 + default_rel. -Proof. suff: 1 <= 1 + default_rel; try nra; auto with commonDB. Qed. +Lemma default_rel_plus_0_ge_1 : 0 <= 1 + default_rel. +Proof. + suff: 1 <= 1 + default_rel; try nra; auto with commonDB. +Qed. Hint Resolve default_rel_plus_0_ge_1 : commonDB. -Lemma default_rel_plus_1_gt_1: - 1 < 1 + default_rel. +Lemma default_rel_plus_1_gt_1 : 1 < 1 + default_rel. Proof. -rewrite Rplus_comm; apply Rcomplements.Rlt_minus_l; - field_simplify; auto with commonDB. + rewrite Rplus_comm; apply Rcomplements.Rlt_minus_l; + field_simplify; auto with commonDB. Qed. Hint Resolve default_rel_plus_1_gt_1 : commonDB. -Lemma default_rel_plus_1_gt_0 : - 0 < 1 + default_rel. -Proof. -eapply Rlt_trans with 1; [nra | ]. -auto with commonDB. +Lemma default_rel_plus_1_gt_0 : 0 < 1 + default_rel. +Proof. + eapply Rlt_trans with 1; [nra |]; auto with commonDB. Qed. Hint Resolve default_rel_plus_1_gt_0 : commonDB. -Lemma default_rel_plus_1_ge_1' n: - 1 <= (1 + default_rel) ^ n. -Proof. -induction n; simpl; auto; try nra. -eapply Rle_trans with (1 * 1); try nra. -apply Rmult_le_compat; try nra. -auto with commonDB. +Lemma default_rel_plus_1_ge_1' n : 1 <= (1 + default_rel) ^ n. +Proof. 
+ induction n; simpl; auto; try nra. + eapply Rle_trans with (1 * 1); try nra. + apply Rmult_le_compat; try nra. + auto with commonDB. Qed. -Hint Resolve default_rel_plus_1_ge_1': commonDB. +Hint Resolve default_rel_plus_1_ge_1' : commonDB. -Lemma default_abs_gt_0 : - 0 < default_abs . -Proof. -unfold default_abs. -apply Rmult_lt_0_compat; auto with commonDB; nra. +Lemma default_abs_gt_0 : 0 < default_abs. +Proof. + unfold default_abs. + apply Rmult_lt_0_compat; auto with commonDB; nra. Qed. -Hint Resolve default_abs_gt_0: commonDB. +Hint Resolve default_abs_gt_0 : commonDB. -Lemma default_abs_ge_0 : - 0 <= default_abs . +Lemma default_abs_ge_0 : 0 <= default_abs. Proof. apply Rlt_le; auto with commonDB. Qed. -Hint Resolve default_abs_ge_0: commonDB. +Hint Resolve default_abs_ge_0 : commonDB. + +Lemma abs_le_rel : default_abs <= default_rel. +Proof. + apply: Rmult_le_compat; try nra; auto with commonDB. + apply: bpow_le => //. + pose proof fprec_gt_one t; pose proof fprec_lt_femax t; lia. +Qed. + +(** [fmax] is << 2^(femax t) >>, the overflow threshold of type << t >>: a strict upper bound on the magnitude of every finite value (it is not itself representable). *) + +Definition fmax := bpow Zaux.radix2 (femax t). -Lemma abs_le_rel : - default_abs <= default_rel. +Lemma bpow_femax_lb : (1 < femax t)%Z. Proof. -apply: Rmult_le_compat; try nra; auto with commonDB. -apply: bpow_le => //; pose proof fprec_gt_one t; pose proof fprec_lt_femax t; lia. + pose proof fprec_gt_one t as Hfprec. + pose proof fprec_lt_femax t as Hlt. + eapply Z.lt_trans with (fprec t); auto. +Qed. + +Lemma bpow_fmax_lb_4 : + 4 <= fmax. +Proof. + pose proof fprec_gt_one t as Hfprec. + pose proof fprec_lt_femax t as Hlt. + pose proof bpow_femax_lb. + eapply Rle_trans with (bpow Zaux.radix2 2). + - unfold bpow; simpl; nra. + - apply bpow_le; lia. +Qed. + +Lemma bpow_fprec_lb_2 : + 2 <= bpow Zaux.radix2 (fprec t). +Proof. + pose proof fprec_gt_one t as Hfprec. 
+ eapply Rle_trans with (bpow Zaux.radix2 1). + - unfold bpow; simpl; nra. + - apply bpow_le; lia. +Qed. 
+ +(** ** Upper bounds on [default_abs] and [default_rel] *) + +(** The default absolute rounding error [default_abs] is at most [fmax]. *) + +Lemma default_abs_le_fmax : + default_abs <= fmax. +Proof. + replace fmax with (1 * fmax) by nra. + unfold default_abs, fmax. + apply Rmult_le_compat; try nra. + - apply bpow_ge_0. + - apply bpow_le. + apply Z.le_sub_le_add_r. + apply Z.le_sub_le_add_r. + eapply Z.le_trans with (fprec t + fprec t + femax t)%Z; + [ | repeat apply Zplus_le_compat_r; + apply Z.lt_le_incl; + apply (fprec_lt_femax t) + ]. + eapply Z.le_trans with (fprec t + fprec t + fprec t)%Z; + [ | repeat apply Zplus_le_compat_l; + apply Z.lt_le_incl; + apply fprec_lt_femax ]. + eapply Z.le_trans with (1 + fprec t + fprec t)%Z; + [ | repeat apply Zplus_le_compat_r; + apply Z.lt_le_incl; + apply fprec_gt_one ]. + eapply Z.le_trans with (1 + 1 + fprec t)%Z; + [ | repeat apply Zplus_le_compat_r; + repeat apply Zplus_le_compat_l; + apply Z.lt_le_incl; + apply fprec_gt_one ]. + eapply Z.le_trans with (1 + 1 + 1)%Z; + [ lia + | repeat apply Zplus_le_compat_r; + repeat apply Zplus_le_compat_l; + apply Z.lt_le_incl; + apply fprec_gt_one ]. +Qed. + +(** [default_abs t] is at most 1. *) + +Lemma default_abs_ub : + default_abs <= 1. +Proof. + pose proof (abs_le_rel) as H. + eapply Rle_trans; [apply H |]. + unfold default_rel. rewrite bpow_plus bpow_opp. + replace (bpow _ 1) with 2 by (simpl; nra). + refine (Rle_trans _ (1 / bpow Zaux.radix2 (fprec t)) _ _ _). + - nra. + - apply Rdiv_le_left. + + apply bpow_gt_0. + + refine (Rle_trans _ 2 _ _ _); try nra. + rewrite Rmult_1_l; apply bpow_fprec_lb_2. +Qed. + +(** [default_rel t] is at most 1. *) + +Lemma default_rel_ub : + default_rel <= 1. +Proof. + unfold default_rel. + pose proof bpow_gt_0 Zaux.radix2 (fprec t) as Hpos. + rewrite !bpow_plus. + rewrite <- !Rmult_assoc. + rewrite Rmult_comm. + rewrite <- !Rmult_assoc. + replace (bpow Zaux.radix2 1 * / 2) with 1 by (simpl; nra). + rewrite !bpow_opp. 
+ rewrite !Rcomplements.Rle_div_r. + - field_simplify; try nra. + replace 1 with (bpow Zaux.radix2 0) by (simpl; auto). + apply bpow_le. + pose proof fprec_gt_one t; lia. + - apply Rlt_gt. + replace (/ bpow Zaux.radix2 (fprec t)) + with (1 / bpow Zaux.radix2 (fprec t)) by nra. + apply Rdiv_lt_0_compat; try nra. Qed. End WithType. -Global Hint Resolve +Global Hint Resolve default_rel_sep_0 default_rel_gt_0 default_rel_ge_0 @@ -213,318 +399,312 @@ Global Hint Resolve default_rel_plus_0_ge_1 : commonDB. +(** ** Error Accumulation Factors + + [g n] and [g1 n1 n2] bound accumulated rounding errors over sequences + of floating-point operations. The [commonDB] hint database is populated + with the lemmas below to support automated error bound proofs. *) + Section WithType. -Context {NAN: FPCore.Nans} {t: type}. +Context {NAN : FPCore.Nans} {t : type}. Notation D := (@default_rel t). Notation E := (@default_abs t). -Definition g (n: nat) : R := ((1 + D) ^ n - 1). +(** *** Relative Error Factor [g] *) -Lemma g_pos n: - 0 <= g n. -Proof. -unfold g. induction n. -simpl; nra. eapply Rle_trans; [apply IHn| apply Rplus_le_compat; try nra]. -simpl. eapply Rle_trans with (1 * (1+D )^n); try nra. -apply Rmult_le_compat; try nra. rewrite Rplus_comm. apply Rcomplements.Rle_minus_l. -field_simplify. -auto with commonDB. +Definition g (n : nat) : R := (1 + D) ^ n - 1. + +Lemma g_pos n : 0 <= g n. +Proof. + unfold g; induction n; simpl; try nra. + eapply Rle_trans; [apply IHn | apply Rplus_le_compat; try nra]. + eapply Rle_trans with (1 * (1 + D) ^ n); try nra. + apply Rmult_le_compat; try nra. + rewrite Rplus_comm; apply Rcomplements.Rle_minus_l. + field_simplify; auto with commonDB. Qed. Hint Resolve g_pos : commonDB. -Lemma le_g_Sn n : - g n <= g (S n). -Proof. -induction n; unfold g; simpl. - { field_simplify; auto with commonDB. } - unfold g in IHn. eapply Rplus_le_compat; try nra. - eapply Rmult_le_compat_l. - apply Rplus_le_le_0_compat; try nra; try apply default_rel_ge_0. 
- rewrite tech_pow_Rmult. apply Rle_pow; try lia. - rewrite Rplus_comm. apply Rcomplements.Rle_minus_l. - field_simplify; auto with commonDB. +Lemma le_g_Sn n : g n <= g (S n). +Proof. + induction n; unfold g; simpl. + - field_simplify; auto with commonDB. + - unfold g in IHn; eapply Rplus_le_compat; try nra. + eapply Rmult_le_compat_l. + + apply Rplus_le_le_0_compat; try nra; try apply default_rel_ge_0. + + rewrite tech_pow_Rmult; apply Rle_pow; try lia. + rewrite Rplus_comm; apply Rcomplements.Rle_minus_l. + field_simplify; auto with commonDB. Qed. Hint Resolve le_g_Sn : commonDB. -Lemma d_le_g n: -D <= g (n + 1). -Proof. unfold g. induction n; simpl; field_simplify; try nra. -eapply Rle_trans; [apply IHn|]. -apply Rplus_le_compat_r. -replace (D * (1 + D) ^ (n + 1) + (1 + D) ^ (n + 1)) - with ((1+D)^(n+1)*(D+1)) by nra. -eapply Rle_trans with ((1 + D ) ^ (n + 1) * 1); try nra. -eapply Rmult_le_compat; try nra. -{ apply pow_le. apply Fourier_util.Rle_zero_pos_plus1 ; auto with commonDB. } -apply Rcomplements.Rle_minus_l. field_simplify; auto with commonDB. +Lemma d_le_g n : D <= g (n + 1). +Proof. + unfold g. + apply Rcomplements.Rle_minus_r. + rewrite Rplus_comm. + replace (1 + D) with ((1 + D) ^ 1) at 1 by nra. + apply Rle_pow; [| lia]. + auto with commonDB. Qed. Hint Resolve d_le_g : commonDB. - -Lemma d_le_g_1 n: -(1<= n)%nat -> D <= g n. -Proof. -intros; unfold g. -eapply Rle_trans with ((1 + D )^1 - 1). -field_simplify; nra. -apply Rplus_le_compat; try nra. -apply Rle_pow; try lia. -auto with commonDB. Qed. +Lemma d_le_g_1 n : (1 <= n)%nat -> D <= g n. +Proof. + intros; unfold g. + eapply Rle_trans with ((1 + D) ^ 1 - 1). + - field_simplify; nra. + - apply Rplus_le_compat; try nra. + apply Rle_pow; try lia. + auto with commonDB. +Qed. Hint Resolve d_le_g_1 : commonDB. -Lemma one_plus_d_mul_g a n: - (1 + D ) * g n * a + D * a = g (n + 1) * a. -Proof. unfold g. rewrite Rmult_minus_distr_l. rewrite tech_pow_Rmult. -field_simplify. f_equal. 
rewrite Rmult_comm; repeat f_equal; lia. +Lemma one_plus_d_mul_g a n : + (1 + D) * g n * a + D * a = g (n + 1) * a. +Proof. + unfold g; rewrite Rmult_minus_distr_l; rewrite tech_pow_Rmult. + field_simplify; f_equal; rewrite Rmult_comm; repeat f_equal; lia. Qed. Hint Resolve one_plus_d_mul_g : commonDB. -Definition g1 (n1: nat) (n2: nat) : R := - INR n1 * E* (1 + g n2 ). +(** *** Mixed Error Factor [g1] *) -Lemma g1_pos n m : 0 <= g1 n m. -Proof. unfold g1. -apply Rmult_le_pos; try apply pos_INR. -apply Rmult_le_pos; try apply pos_INR. -apply default_abs_ge_0. unfold g; field_simplify. -apply pow_le. -apply Fourier_util.Rle_zero_pos_plus1. -auto with commonDB. +Definition g1 (n1 n2 : nat) : R := + INR n1 * E * (1 + g n2). + +Lemma g1_pos n m : 0 <= g1 n m. +Proof. + unfold g1. + apply Rmult_le_pos; try apply pos_INR. + apply Rmult_le_pos; try apply pos_INR. + apply default_abs_ge_0. + unfold g; field_simplify. + apply pow_le. + apply Fourier_util.Rle_zero_pos_plus1. + auto with commonDB. Qed. Hint Resolve g1_pos : commonDB. -Lemma one_plus_d_mul_g1 n: -(1 <= n )%nat -> -g1 n (n - 1) * (1 + D ) = g1 n n. +Lemma one_plus_d_mul_g1 n : + (1 <= n)%nat -> + g1 n (n - 1) * (1 + D) = g1 n n. Proof. -intros. -unfold g1, g; field_simplify. -symmetry. replace n with (S (n-1)) at 2. -rewrite <- tech_pow_Rmult. -field_simplify; nra. -rewrite <- Nat.sub_succ_l; auto; lia. + intros. + unfold g1, g; field_simplify. + symmetry; replace n with (S (n - 1)) at 2. + - rewrite <- tech_pow_Rmult; field_simplify; nra. + - rewrite <- Nat.sub_succ_l; auto; lia. Qed. Hint Resolve g1_pos : commonDB. -Lemma one_plus_d_mul_g1' n m: -g1 n m * (1 + D) = g1 n (S m). +Lemma one_plus_d_mul_g1' n m : + g1 n m * (1 + D) = g1 n (S m). Proof. -intros. -unfold g1, g; field_simplify. -symmetry. -rewrite <- tech_pow_Rmult. -field_simplify; nra. + intros. + unfold g1, g; field_simplify. + symmetry; rewrite <- tech_pow_Rmult; field_simplify; nra. Qed. Hint Resolve g1_pos : commonDB. 
-Hint Resolve fprec_lt_femax :commonDB. -Lemma e_le_g1 n: -(1 <= n )%nat -> -E <= g1 n n. +Hint Resolve fprec_lt_femax : commonDB. + +Lemma e_le_g1 n : (1 <= n)%nat -> E <= g1 n n. Proof. -intros; unfold g1. eapply Rle_trans with (1 * E * 1); try nra. -apply: Rmult_le_compat; first (field_simplify; auto with commonDB); try nra. -apply: Rmult_le_compat => //; auto with commonDB; try nra. -replace 1 with (INR 1) by (simpl; nra). -apply le_INR; auto with commonDB; lia. -rewrite Rplus_comm -Rcomplements.Rle_minus_l; field_simplify; -auto with commonDB. + intros; unfold g1. + eapply Rle_trans with (1 * E * 1); try nra. + apply: Rmult_le_compat; first (field_simplify; auto with commonDB); try nra. + apply: Rmult_le_compat => //; auto with commonDB; try nra. + - replace 1 with (INR 1) by (simpl; nra). + apply le_INR; auto with commonDB; lia. + - rewrite Rplus_comm -Rcomplements.Rle_minus_l; field_simplify; + auto with commonDB. Qed. Hint Resolve e_le_g1 : commonDB. - -Lemma plus_d_e_g1_le' n m: -(1 <= n )%nat -> (1 <= m)%nat -> -g1 n m + (1 + D) * E <= g1 (S n) m. -Proof. -intros; replace (S n) with (n + 1)%nat by lia. -rewrite /g1; field_simplify. -replace (INR (n + 1)) with (INR n + 1). -rewrite !Rmult_plus_distr_l !Rmult_1_r --Rplus_assoc -!Rmult_plus_distr_l Rmult_comm. -apply: Rplus_le_compat_r. -rewrite Rplus_comm -Rplus_assoc. -apply: Rplus_le_compat; try nra. -rewrite Rplus_comm. -apply: Rplus_le_compat; try nra. -apply: Rmult_le_compat_l; auto with commonDB. -field_simplify. -apply: Rminus_plus_le_minus. -rewrite Rplus_comm. -suff H1: (1 + D)^1 <= (1 + D) ^ m; try nra. -apply: Rle_pow; auto with commonDB. -lia. -rewrite plus_INR; simpl; nra. +Lemma plus_d_e_g1_le' n m : + (1 <= n)%nat -> (1 <= m)%nat -> + g1 n m + (1 + D) * E <= g1 (S n) m. +Proof. + intros; replace (S n) with (n + 1)%nat by lia. + rewrite /g1; field_simplify. + replace (INR (n + 1)) with (INR n + 1). + - rewrite !Rmult_plus_distr_l !Rmult_1_r + -Rplus_assoc -!Rmult_plus_distr_l Rmult_comm. 
+ apply: Rplus_le_compat_r. + rewrite Rplus_comm -Rplus_assoc. + apply: Rplus_le_compat; try nra. + rewrite Rplus_comm. + apply: Rplus_le_compat; try nra. + apply: Rmult_le_compat_l; auto with commonDB. + field_simplify. + apply: Rminus_plus_le_minus. + rewrite Rplus_comm. + suff H1 : (1 + D) ^ 1 <= (1 + D) ^ m; try nra. + apply: Rle_pow; auto with commonDB; lia. + - rewrite plus_INR; simpl; nra. Qed. Hint Resolve plus_d_e_g1_le' : commonDB. - -Lemma mult_d_e_g1_le' n m: -(1 <= n )%nat -> (1 <= m)%nat -> -g1 n m * (1 + D) + E <= g1 (S n) (S m). -Proof. -intros; replace (S n) with (n + 1)%nat by lia. -replace (S m) with (m + 1)%nat by lia. -unfold g1, g; field_simplify. -replace (INR (n + 1)) with (INR n + 1) by - (rewrite plus_INR; simpl; nra). -replace (INR (m + 1)) with (INR m + 1) by - (rewrite plus_INR; simpl; nra). -rewrite !Rmult_plus_distr_l !Rmult_1_r. replace -(INR n * E * (1 + D) ^ m * D + -INR n * E * (1 + D) ^ m) with -(INR n * E * (1 + D) ^ m * (1 + D)) by nra. -rewrite !Rmult_plus_distr_r. -apply: Rplus_le_compat. -rewrite !Rmult_assoc Rmult_comm !Rmult_assoc. -apply: Rmult_le_compat_l; try nra. -apply: Rmult_le_compat_l; auto with commonDB. -rewrite -Rmult_assoc Rmult_comm. -apply: Rmult_le_compat_l; auto with commonDB. -rewrite Rmult_comm tech_pow_Rmult. -replace (S m) with (m + 1)%nat by lia; nra. -replace (E) with (E * 1) at 1 by nra. -apply Rmult_le_compat_l; [apply default_abs_ge_0 | ]; -auto with commonDB. +Lemma mult_d_e_g1_le' n m : + (1 <= n)%nat -> (1 <= m)%nat -> + g1 n m * (1 + D) + E <= g1 (S n) (S m). +Proof. + intros. + replace (S n) with (n + 1)%nat by lia. + replace (S m) with (m + 1)%nat by lia. + unfold g1, g; field_simplify. + replace (INR (n + 1)) with (INR n + 1) by + (rewrite plus_INR; simpl; nra). + replace (INR (m + 1)) with (INR m + 1) by + (rewrite plus_INR; simpl; nra). + rewrite !Rmult_plus_distr_l !Rmult_1_r. 
+ replace + (INR n * E * (1 + D) ^ m * D + INR n * E * (1 + D) ^ m) + with + (INR n * E * (1 + D) ^ m * (1 + D)) by nra. + rewrite !Rmult_plus_distr_r. + apply: Rplus_le_compat. + - rewrite !Rmult_assoc Rmult_comm !Rmult_assoc. + apply: Rmult_le_compat_l; try nra. + apply: Rmult_le_compat_l; auto with commonDB. + rewrite -Rmult_assoc Rmult_comm. + apply: Rmult_le_compat_l; auto with commonDB. + rewrite Rmult_comm tech_pow_Rmult. + replace (S m) with (m + 1)%nat by lia; nra. + - replace E with (E * 1) at 1 by nra. + apply Rmult_le_compat_l; [apply default_abs_ge_0 |]; + auto with commonDB. Qed. Hint Resolve mult_d_e_g1_le' : commonDB. -Lemma plus_d_e_g1_le n: -(1 <= n )%nat -> -g1 n n + (1 + D) * E <= g1 (S n) n. -Proof. auto with commonDB. Qed. +Lemma plus_d_e_g1_le n : + (1 <= n)%nat -> + g1 n n + (1 + D) * E <= g1 (S n) n. +Proof. auto with commonDB. Qed. Hint Resolve plus_d_e_g1_le : commonDB. -Lemma plus_e_g1_le n: -g1 n n + E <= g1 (S n) n. -Proof. -rewrite /g1. -replace (S n) with (n + 1)%nat by lia. -replace (INR (n + 1)) with (INR n + 1). -rewrite Rmult_assoc Rmult_assoc. -apply: Rle_trans; - [ apply: Rle_refl| rewrite Rmult_plus_distr_r]. -apply: Rplus_le_compat_l. -rewrite Rmult_plus_distr_l Rmult_1_l Rmult_1_r. -suff : E + 0 * 0 <= E + E * g n; first by nra. -apply: Rplus_le_compat_l. -apply: Rmult_le_compat; try nra; -auto with commonDB. -rewrite plus_INR; simpl; nra. +Lemma plus_e_g1_le n : g1 n n + E <= g1 (S n) n. +Proof. + rewrite /g1. + replace (S n) with (n + 1)%nat by lia. + replace (INR (n + 1)) with (INR n + 1). + - rewrite Rmult_assoc Rmult_assoc. + apply: Rle_trans; + [apply: Rle_refl | rewrite Rmult_plus_distr_r]. + apply: Rplus_le_compat_l. + rewrite Rmult_plus_distr_l Rmult_1_l Rmult_1_r. + suff : E + 0 * 0 <= E + E * g n; first by nra. + apply: Rplus_le_compat_l. + apply: Rmult_le_compat; try nra; auto with commonDB. + - rewrite plus_INR; simpl; nra. Qed. Hint Resolve plus_e_g1_le : commonDB. 
-Lemma g1n_le_g1Sn n: -(1 <= n )%nat -> -g1 n (n - 1) <= g1 (S n) (S (n - 1)). -Proof. -rewrite /g1 => Hn. -replace (S n) with (n + 1)%nat by lia. -replace (INR (n + 1)) with (INR n + 1). -apply: Rmult_le_compat. -apply: Rmult_le_pos; auto with commonDB. -rewrite /g; field_simplify; apply pow_le; -auto with commonDB. -apply: Rmult_le_compat; try nra; auto with commonDB. -apply: Rplus_le_compat_l; auto with commonDB. -rewrite plus_INR; simpl; nra. +Lemma g1n_le_g1Sn n : + (1 <= n)%nat -> + g1 n (n - 1) <= g1 (S n) (S (n - 1)). +Proof. + rewrite /g1 => Hn. + replace (S n) with (n + 1)%nat by lia. + replace (INR (n + 1)) with (INR n + 1). + - apply: Rmult_le_compat. + + apply: Rmult_le_pos; auto with commonDB. + + rewrite /g; field_simplify; apply pow_le; auto with commonDB. + + apply: Rmult_le_compat; try nra; auto with commonDB. + + apply: Rplus_le_compat_l; auto with commonDB. + - rewrite plus_INR; simpl; nra. Qed. Hint Resolve g1n_le_g1Sn : commonDB. -Lemma g1n_le_g1Sn' n: -g1 n n <= g1 (S n) (S n). -Proof. -rewrite /g1. -replace (S n) with (n + 1)%nat by lia. -replace (INR (n + 1)) with (INR n + 1). -apply: Rmult_le_compat. -apply: Rmult_le_pos; auto with commonDB. -rewrite /g; field_simplify; apply pow_le; -auto with commonDB. -apply: Rmult_le_compat; try nra; auto with commonDB. -apply: Rplus_le_compat_l; auto with commonDB. -rewrite addnC. -auto with commonDB. -rewrite plus_INR; simpl; nra. +Lemma g1n_le_g1Sn' n : g1 n n <= g1 (S n) (S n). +Proof. + rewrite /g1. + replace (S n) with (n + 1)%nat by lia. + replace (INR (n + 1)) with (INR n + 1). + - apply: Rmult_le_compat. + + apply: Rmult_le_pos; auto with commonDB. + + rewrite /g; field_simplify; apply pow_le; auto with commonDB. + + apply: Rmult_le_compat; try nra; auto with commonDB. + + apply: Rplus_le_compat_l; auto with commonDB. + rewrite addnC; auto with commonDB. + - rewrite plus_INR; simpl; nra. Qed. Hint Resolve g1n_le_g1Sn' : commonDB. 
-Lemma Rplus_le_lt_compat a1 a2 b1 b2 : - a1 <= a2 -> b1 < b2 -> a1 + b1 < a2 + b2. -Proof. nra. Qed. - -Lemma Rmult_le_lt_compat a1 a2 b1 b2 : - 0 < a1 -> 0 < b1 -> a1 < a2 -> b1 <= b2 -> a1 * b1 < a2 * b2. -Proof. nra. Qed. - -Lemma g1n_lt_g1Sn n: -(1 <= n )%nat -> -g1 n (n - 1) < g1 (S n) (S (n - 1)). -Proof. -rewrite /g1 => Hn. -replace (S n) with (n + 1)%nat by lia. -apply: Rmult_lt_compat. -apply: Rmult_le_pos; auto with commonDB. -rewrite /g; field_simplify; apply pow_le; -auto with commonDB. -apply: Rmult_le_lt_compat; try nra; auto with commonDB. -apply lt_0_INR; lia. -suff : INR n < INR n + 1 ; simpl; try nra. -move => H. -rewrite plus_INR; simpl; nra. -rewrite /g; field_simplify. -apply: Rlt_pow; auto with commonDB. -suff : 0 < D; try nra; auto with commonDB. -Qed. +Lemma g1n_lt_g1Sn n : + (1 <= n)%nat -> + g1 n (n - 1) < g1 (S n) (S (n - 1)). +Proof. + rewrite /g1 => Hn. + replace (S n) with (n + 1)%nat by lia. + apply: Rmult_lt_compat. + - apply: Rmult_le_pos; auto with commonDB. + - rewrite /g; field_simplify; apply pow_le; auto with commonDB. + - apply: Rmult_le_lt_compat; try nra; auto with commonDB. + + apply lt_0_INR; lia. + + suff : INR n < INR n + 1; simpl; try nra. + move => H; rewrite plus_INR; simpl; nra. + - rewrite /g; field_simplify. + apply: Rlt_pow; auto with commonDB. + suff : 0 < D; try nra; auto with commonDB. +Qed. + +(** ** Floating-Point Operation Lemmas + + Structural identities: rounding is idempotent on exact FP values + ([round_FT2R]); signed-zero behavior of [BPLUS] and [BMINUS]; + commutativity of [BPLUS]; and the equivalence [BMINUS x y = BPLUS x (BOPP y)]. *) Lemma round_FT2R a : - (Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) (FT2R a)) = @FT2R t a. -Proof. -rewrite Generic_fmt.round_generic //. -apply Binary.generic_format_B2R. 
+ Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) (FT2R a) = @FT2R t a. +Proof. + rewrite Generic_fmt.round_generic //. + apply Binary.generic_format_B2R. Qed. - -Lemma BMINUS_neg_zero: forall (c: ftype t), feq (BMINUS neg_zero (BOPP c)) c. +Lemma BMINUS_neg_zero : forall c : ftype t, feq (BMINUS neg_zero (BOPP c)) c. Proof. destruct c; try destruct s; reflexivity. Qed. -Lemma foldl_congr: forall (op: ftype t -> ftype t -> ftype t) - (Hop: forall x y, feq x y -> forall x' y', feq x' y' -> - feq (op x x') (op y y')) - (u v: ftype t) al bl, - feq u v -> Forall2 feq al bl -> feq (foldl op u al) (foldl op v bl). +Lemma foldl_congr : + forall (op : ftype t -> ftype t -> ftype t) + (Hop : forall x y, feq x y -> forall x' y', feq x' y' -> + feq (op x x') (op y y')) + (u v : ftype t) al bl, + feq u v -> Forall2 feq al bl -> feq (foldl op u al) (foldl op v bl). Proof. -intros. -revert u v H bl H0; induction al; destruct bl; simpl; intros; inversion H0; clear H0; subst; auto. + intros. + revert u v H bl H0; induction al; destruct bl; simpl; intros; + inversion H0; clear H0; subst; auto. Qed. -Lemma BPLUS_neg_zero: forall (c: ftype t), feq (BPLUS c neg_zero) c. +Lemma BPLUS_neg_zero : forall c : ftype t, feq (BPLUS c neg_zero) c. Proof. destruct c; try destruct s; reflexivity. Qed. -Lemma BPLUS_comm: forall (x y: ftype t), feq (BPLUS x y) (BPLUS y x). +Lemma BPLUS_comm : forall x y : ftype t, feq (BPLUS x y) (BPLUS y x). Proof. -destruct x, y; try destruct s; try destruct s0; try reflexivity; -unfold BPLUS, BINOP, feq, Binary.Bplus, Binary.BSN2B, BinarySingleNaN.SF2B; simpl; -rewrite (Z.min_comm e1 e); -rewrite ?(Pos.add_comm (fst (SpecFloat.shl_align m0 e1 (Z.min e e1)))). -1,4: destruct (BinarySingleNaN.SF2B _ _); simpl; auto. -1,2: destruct (BinarySingleNaN.binary_normalize _ _ _ _ _ _ _ _); simpl; auto. 
+ destruct x, y; try destruct s; try destruct s0; try reflexivity; + unfold BPLUS, BINOP, feq, Binary.Bplus, Binary.BSN2B, BinarySingleNaN.SF2B; simpl; + rewrite (Z.min_comm e1 e); + rewrite ?(Pos.add_comm (fst (SpecFloat.shl_align m0 e1 (Z.min e e1)))). + 1,4: destruct (BinarySingleNaN.SF2B _ _); simpl; auto. + 1,2: destruct (BinarySingleNaN.binary_normalize _ _ _ _ _ _ _ _); simpl; auto. Qed. -Lemma MINUS_PLUS_BOPP: forall x y: ftype t, feq (BMINUS x y) (BPLUS x (BOPP y)). +Lemma MINUS_PLUS_BOPP : forall x y : ftype t, feq (BMINUS x y) (BPLUS x (BOPP y)). Proof. -destruct x, y; try destruct s; try destruct s0; try reflexivity; -unfold BMINUS, BPLUS, BINOP, BOPP, UNOP, feq, Binary.Bplus, Binary.Bminus, - Binary.BSN2B, BinarySingleNaN.SF2B, Binary.build_nan; simpl. -1,4: destruct (BinarySingleNaN.binary_normalize _ _ _ _ _ _ _ _); auto. -1,2: destruct (BinarySingleNaN.SF2B _ _); auto. + destruct x, y; try destruct s; try destruct s0; try reflexivity; + unfold BMINUS, BPLUS, BINOP, BOPP, UNOP, feq, Binary.Bplus, Binary.Bminus, + Binary.BSN2B, BinarySingleNaN.SF2B, Binary.build_nan; simpl. + 1,4: destruct (BinarySingleNaN.binary_normalize _ _ _ _ _ _ _ _); auto. + 1,2: destruct (BinarySingleNaN.SF2B _ _); auto. Qed. -End WithType. +End WithType. -Global Hint Resolve +Global Hint Resolve g_pos le_g_Sn d_le_g @@ -542,12 +722,18 @@ Global Hint Resolve g1n_lt_g1Sn : commonDB. +(** ** Automation + + [field_simplify_Rabs] reduces a goal of the form [Rabs e <= _] + by simplifying the expression [e] and splitting denominator non-zero + side conditions. *) + Ltac field_simplify_Rabs := -match goal with -|- Rabs ?a <= _ => -field_simplify a; -(repeat split; -try match goal with |-?z <> 0 => -field_simplify z (*; Interval.Tactic.interval *) -end) -end. + match goal with + | |- Rabs ?a <= _ => + field_simplify a; + repeat split; + try match goal with + | |- ?z <> 0 => field_simplify z + end + end. 
\ No newline at end of file diff --git a/accuracy_proofs/dot_acc.v b/accuracy_proofs/dot_acc.v index a68064e..2d2de00 100644 --- a/accuracy_proofs/dot_acc.v +++ b/accuracy_proofs/dot_acc.v @@ -1,128 +1,219 @@ -(** This file contains three main theorems for the accuracy of the non-fma - dot product : dotprod_mixed_error, dotprod_forward_error, - and sparse_dotprod_forward_error. *) +(** * Dot Product Accuracy Proofs (Non-FMA) -From LAProof.accuracy_proofs Require Import preamble common - dotprod_model sum_model - float_acc_lems dot_acc_lemmas. + This file establishes floating-point accuracy bounds for the non-fused + multiply-add (non-FMA) dot product computation over lists of floating-point + numbers. It provides both mixed-error and forward-error analyses, as well + as a refined bound for sparse vectors. + + ** Error Factors + + Throughout, the accumulated relative error factor is + %$g(n) = (1 + \mathbf{u})^n - 1$%#\(g(n) = (1 + \mathbf{u})^n - 1\)# and + the mixed absolute error factor is + %$g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))$%#\(g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))\)#, + where %$\mathbf{u}$%#\(\mathbf{u}\)# is the unit roundoff and + %$\eta$%#\(\eta\)# is the underflow threshold for the given floating-point type. + Both are defined in [common]. + + ** Main Results + + - [dotprod_mixed_error]: Shows that the computed dot product can be + expressed as an exact dot product of slightly perturbed inputs plus a + small absolute error term. Each input component is perturbed by a + relative factor bounded by %$g(n)$%#\(g(n)\)#, and the absolute + residual is bounded by %$g_1(n, n)$%#\(g_1(n,n)\)#, where + %$n$%#\(n\)# is the vector length. 
+ + - [dotprod_forward_error]: Bounds the absolute forward error + %$|\mathtt{fl}(v_1 \cdot v_2) - v_1 \cdot v_2|$%#\(|\mathtt{fl}(v_1 \cdot v_2) - v_1 \cdot v_2|\)# + by %$g(n)\,(|v_1| \cdot |v_2|) + g_1(n,\, n-1)$%#\(g(n)\,(|v_1| \cdot |v_2|) + g_1(n,\,n-1)\)#, + where %$|v|$%#\(|v|\)# denotes componentwise absolute value. + + - [sparse_dotprod_forward_error]: Refines [dotprod_forward_error] for + sparse inputs by replacing the full vector length %$n$%#\(n\)# with + the number of nonzero entries %$n_{\mathrm{nz}}$%#\(n_{\mathrm{nz}}\)#, + giving a tighter bound when the input vectors are sparse. + + ** Dependencies + + This file relies on the following modules from [LAProof.accuracy_proofs]: + - [preamble]: basic setup and notation, + - [common]: shared definitions including << nnzR >> and << neg_zero >>, + - [dotprod_model]: the floating-point model [dotprodF] and its relational + characterization [dotprodF_rel_fold_right], + - [sum_model]: summation model infrastructure, + - [float_acc_lems]: generic floating-point accuracy lemmas, + - [dot_acc_lemmas]: the core relational accuracy lemmas + [dotprod_mixed_error_rel], [dotprod_forward_error_rel], and + [sparse_dotprod_forward_error_rel] that drive the proofs here. + + ** Structure + + The file is organised into two [Section]s: + - [MixedError]: proves [dotprod_mixed_error]. + - [ForwardError]: proves [dotprod_forward_error] and + [sparse_dotprod_forward_error]. +*) + +From LAProof.accuracy_proofs Require Import + preamble + common + dotprod_model + sum_model + float_acc_lems + dot_acc_lemmas. Require Import Reals. + Open Scope R. -Section MixedError. -Context {NAN: FPCore.Nans} {t : type}. +(* ------------------------------------------------------------------ *) +Section MixedError. -Notation g := (@g t). -Notation g1 := (@g1 t). -Notation D := (@default_rel t). -Notation E := (@default_abs t). +Context {NAN : FPCore.Nans} {t : type}. + +Notation g := (@g t). +Notation g1 := (@g1 t). 
+Notation D := (@default_rel t). +Notation E := (@default_abs t). Notation neg_zero := (@common.neg_zero t). -Variables (v1 v2: list (ftype t)). -Hypothesis Hlen: size v1 = size v2. -Hypothesis Hfin: Binary.is_finite (dotprodF v1 v2) = true. +Variables (v1 v2 : list (ftype t)). -Lemma dotprod_mixed_error: +Hypothesis Hlen : size v1 = size v2. +Hypothesis Hfin : Binary.is_finite (dotprodF v1 v2) = true. + +(** [dotprod_mixed_error] expresses the computed dot product as an exact inner + product of component-wise perturbed inputs plus a small absolute offset. + The relative perturbation on each input component is bounded by + %$g(n)$%#\(g(n)\)# and the absolute residual by %$g_1(n,n)$%#\(g_1(n,n)\)#. *) + +Lemma dotprod_mixed_error : exists (u : list R) (eta : R), - size u = size v2 /\ - FT2R (dotprodF v1 v2) = dotprodR u (map FT2R v2) + eta /\ - (forall n, (n < size v2)%nat -> exists delta, - nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ Rabs delta <= g (size v2)) /\ - Rabs eta <= g1 (size v2) (size v2). + size u = size v2 + /\ FT2R (dotprodF v1 v2) = dotprodR u (map FT2R v2) + eta + /\ (forall n, (n < size v2)%nat -> + exists delta, + nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) + /\ Rabs delta <= g (size v2)) + /\ Rabs eta <= g1 (size v2) (size v2). Proof. -intros. -assert (size (zip v1 v2) = size v1) by - (rewrite size_zip; lia). -assert (Hlenr : size (rev v1) = size (rev v2)) by (rewrite !size_rev; auto). -rewrite <- size_rev in Hlen. -pose proof dotprodF_rel_fold_right v1 v2 as H1. -move :Hlen; rewrite size_rev => Hlen'. -rewrite rev_zip in H1; auto. -pose proof (dotprod_mixed_error_rel (rev v1) (rev v2) Hlenr (dotprodF v1 v2) H1 Hfin) as - (u & eta & H2 & H3 & H4 & H5). -exists (rev u), eta; repeat split; auto. -- -move :H2; rewrite !size_rev //. -- -pose proof dotprodR_rel u (map FT2R (rev v2)). -assert (dotprodR (rev u) (map FT2R v2) = FT2R (dotprodF v1 v2) - eta). -eapply R_dot_prod_rel_eq; eauto. 
-rewrite -dotprodR_rev; [ | rewrite size_map; rewrite size_rev in H2; auto]. -rewrite -map_rev; auto. -nra. -- -rewrite !size_rev in H4, H2, H5. -intros. -assert ((size u - S n < size v2)%nat) by lia. -specialize (H4 (size u - S n)%nat H6). -rewrite nth_rev in H4; [ | rewrite Hlen' // ]. -rewrite nth_rev; [ | lia]. -destruct H4 as (delta & Hn & HD). -exists delta; split. -rewrite Hn; repeat f_equal. -rewrite Hlen'. -rewrite H2. -rewrite <- Nat.sub_succ_l. -simpl. lia. -lia. -apply HD. -- -rewrite !size_rev in H5; auto. + assert (Hzip : size (zip v1 v2) = size v1) by + (rewrite size_zip; lia). + assert (Hlenr : size (rev v1) = size (rev v2)) by + (rewrite !size_rev; auto). + rewrite <- size_rev in Hlen. + pose proof dotprodF_rel_fold_right v1 v2 as Hfold. + move: Hlen; rewrite size_rev => Hlen'. + rewrite rev_zip in Hfold; auto. + pose proof + (dotprod_mixed_error_rel + (rev v1) (rev v2) Hlenr (dotprodF v1 v2) Hfold Hfin) + as (u & eta & Hsize_u & Hval_eq & Helem_bound & Heta_bound). + exists (rev u), eta. + repeat split. + - (* size *) + move: Hsize_u; rewrite !size_rev //. + - (* value equation *) + pose proof dotprodR_rel u (map FT2R (rev v2)) as Hdot_rel. + assert (Heq : dotprodR (rev u) (map FT2R v2) = FT2R (dotprodF v1 v2) - eta). + { eapply R_dot_prod_rel_eq; eauto. + rewrite -dotprodR_rev; + [ | rewrite size_map; rewrite size_rev in Hsize_u; auto]. + rewrite -map_rev; auto. } + nra. + - (* per-element bound *) + rewrite !size_rev in Helem_bound, Hsize_u, Heta_bound. + intros n Hn. + assert (Hlt : (size u - S n < size v2)%nat) by lia. + specialize (Helem_bound (size u - S n)%nat Hlt). + rewrite nth_rev in Helem_bound; [ | rewrite Hlen' //]. + rewrite nth_rev; [ | lia]. + destruct Helem_bound as (delta & Hval & Hdelta). + exists delta; split. + + rewrite Hval; repeat f_equal. + rewrite Hlen' Hsize_u. + rewrite <- Nat.sub_succ_l; [simpl; lia | lia]. + + exact Hdelta. + - (* eta bound *) + rewrite !size_rev in Heta_bound; auto. Qed. End MixedError. 
-Section ForwardError. -Context {NAN: FPCore.Nans} {t : type}. +(* ------------------------------------------------------------------ *) +Section ForwardError. + +Context {NAN : FPCore.Nans} {t : type}. Variables v1 v2 : list (ftype t). + Notation v1R := (map FT2R v1). Notation v2R := (map FT2R v2). Notation v1R' := (map Rabs v1R). Notation v2R' := (map Rabs v2R). Notation n := (size v2). -Notation g := (@g t). +Notation g := (@g t). Notation g1 := (@g1 t). -Hypothesis Hlen: size v1 = size v2. -Hypothesis Hfin: Binary.is_finite (dotprodF v1 v2) = true. +Hypothesis Hlen : size v1 = size v2. +Hypothesis Hfin : Binary.is_finite (dotprodF v1 v2) = true. + +(** [dotprod_forward_error] bounds the absolute forward error of the computed + dot product by %$g(n)\,(|v_1| \cdot |v_2|) + g_1(n,\,n-1)$%#\(g(n)\,(|v_1| \cdot |v_2|) + g_1(n,\,n-1)\)#, + where %$|v|$%#\(|v|\)# denotes the componentwise absolute-value vector. *) -Lemma dotprod_forward_error: - Rabs (FT2R (dotprodF v1 v2) - dotprodR v1R v2R ) - <= g n * dotprodR v1R' v2R' + g1 n (n - 1). +Lemma dotprod_forward_error : + Rabs (FT2R (dotprodF v1 v2) - dotprodR v1R v2R) + <= g n * dotprodR v1R' v2R' + g1 n (n - 1). Proof. -intros. -pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. -pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. - simpl in HB, HC. rewrite <- map_rev in HC, HB. rewrite <- map_rev in HC. -pose proof dotprod_forward_error_rel (rev (zip v1 v2)) - (dotprodF v1 v2) (dotprodF_rel_fold_right _ _ ) Hfin - (dotprodR v1R v2R) (dotprodR v1R' v2R') HB HC. -rewrite size_rev size_zip Hlen minnn in H. -auto. + pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. + pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. + simpl in HB, HC. + rewrite <- map_rev in HC, HB. + rewrite <- map_rev in HC. + pose proof + (dotprod_forward_error_rel + (rev (zip v1 v2)) + (dotprodF v1 v2) + (dotprodF_rel_fold_right _ _) + Hfin + (dotprodR v1R v2R) + (dotprodR v1R' v2R') + HB HC) as H. 
+ rewrite size_rev size_zip Hlen minnn in H. + exact H. Qed. Notation nnzR := (common.nnzR v1R). -Lemma sparse_dotprod_forward_error: - Rabs (FT2R (dotprodF v1 v2) - dotprodR v1R v2R ) <= - g nnzR * dotprodR v1R' v2R' + g1 nnzR (nnzR - 1). -Proof. -intros. -pose proof dotprodF_rel_fold_right v1 v2 as HA. -pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. -pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. - simpl in HB, HC. rewrite <- map_rev in HC, HB. +(** [sparse_dotprod_forward_error] refines [dotprod_forward_error] for sparse + inputs: the vector length %$n$%#\(n\)# is replaced by the number of + nonzero entries %$n_{\mathrm{nz}}$%#\(n_{\mathrm{nz}}\)#, yielding + %$g(n_{\mathrm{nz}})\,(|v_1| \cdot |v_2|) + g_1(n_{\mathrm{nz}},\,n_{\mathrm{nz}}-1)$%#\(g(n_{\mathrm{nz}})\,(|v_1| \cdot |v_2|) + g_1(n_{\mathrm{nz}},\,n_{\mathrm{nz}}-1)\)#. *) + +Lemma sparse_dotprod_forward_error : + Rabs (FT2R (dotprodF v1 v2) - dotprodR v1R v2R) + <= g nnzR * dotprodR v1R' v2R' + g1 nnzR (nnzR - 1). +Proof. + pose proof dotprodF_rel_fold_right v1 v2 as HA. + pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. + pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. + simpl in HB, HC. + rewrite <- map_rev in HC, HB. rewrite <- map_rev in HC. -pose proof sparse_dotprod_forward_error_rel (rev v1) (rev v2). - rewrite !size_rev -rev_zip in H; auto. -specialize (H Hlen (dotprodF v1 v2) HA Hfin (dotprodR v1R v2R) - (dotprodR v1R' v2R') HB HC). -rewrite map_rev in H. -unfold common.nnzR, nnzR in H. -rewrite !count_rev in H. -auto. + pose proof sparse_dotprod_forward_error_rel (rev v1) (rev v2) as H. + rewrite !size_rev -rev_zip in H; auto. + specialize (H Hlen + (dotprodF v1 v2) HA Hfin + (dotprodR v1R v2R) + (dotprodR v1R' v2R') + HB HC). + rewrite map_rev in H. + unfold common.nnzR, nnzR in H. + rewrite !count_rev in H. + exact H. Qed. -End ForwardError. +End ForwardError. 
\ No newline at end of file diff --git a/accuracy_proofs/dot_acc_lemmas.v b/accuracy_proofs/dot_acc_lemmas.v index 431ca41..f31e80f 100644 --- a/accuracy_proofs/dot_acc_lemmas.v +++ b/accuracy_proofs/dot_acc_lemmas.v @@ -1,902 +1,1095 @@ -(* This file contatins lemmas for the accuracy of the fma and non-fma dot products. - These lemmas are used to prove the main accuracy theorems in dot_acc.v and fma_dot_acc.v. - The theorems use the inductive definitions R_dot_prod_rel and dot_prod_rel, - which are a bit easier (for me) to work with at a low level then dotprodF and dotprodR. *) +(** * Forward and Mixed Error Bounds for Floating-Point Dot Products -From LAProof.accuracy_proofs Require Import preamble common - dotprod_model - float_acc_lems. + This file establishes accuracy lemmas for both FMA-based and non-FMA-based + floating-point dot product computations. These results are foundational + building blocks used in [dot_acc.v] and [fma_dot_acc.v] to prove the main + accuracy theorems. + + The proofs are structured around three inductive relations: + - [dot_prod_rel] : relates a list of floating-point pairs to their + computed (non-FMA) floating-point dot product. + - [fma_dot_prod_rel] : relates a list of floating-point pairs to their + computed FMA-based floating-point dot product. + - [R_dot_prod_rel] : the analogous real-arithmetic dot product relation, + used to state the ideal (exact) result. + + The error bounds are expressed using the standard relative and absolute + error model parameters: [default_rel] and [default_abs]. + + ** Summary of Main Results + + *** Forward Error Bounds + + - [dotprod_forward_error_rel]: Bounds the absolute error of a non-FMA + floating-point dot product relative to the exact real dot product. + + - [fma_dotprod_forward_error_rel]: The analogous forward error bound for + FMA-based dot products.
+ + *** Sparse Forward Error Bounds + + - [sparse_dotprod_forward_error_rel]: A tighter forward error bound for + non-FMA dot products that exploits zero entries in one of the input + vectors, replacing the vector length with the number of nonzeros. + + - [sparse_fma_dotprod_forward_error]: The analogous sparse bound for + FMA-based dot products. + + *** Mixed Error Bounds + + - [dotprod_mixed_error_rel]: Shows that the computed non-FMA dot product + can be expressed as an exact dot product of slightly perturbed inputs + plus a small absolute error term. + + - [fma_dotprod_mixed_error_rel]: The analogous mixed error representation + for FMA-based dot products, with a slightly tighter absolute error term. +*) + +From LAProof.accuracy_proofs Require Import + preamble + common + dotprod_model + float_acc_lems. + +(** * Section 1: Forward Error Bound — Non-FMA Dot Product *) Section ForwardErrorRel1. -(* forward error bound for non-fma dot product using inductive rels *) -Context {NAN: FPCore.Nans} {t : type}. +(** + Forward error bound for the non-FMA dot product, using the inductive + relation [dot_prod_rel]. +*) + +Context {NAN : FPCore.Nans} {t : type}. -Notation g := (@g t). +Notation g := (@g t). Notation g1 := (@g1 t). -Notation D := (@default_rel t). -Notation E := (@default_abs t). - -Lemma dotprod_forward_error_rel: - forall (vF: seq (ftype t * ftype t)) - (fp : ftype t) - (Hfp : dot_prod_rel vF fp) - (Hfin: Binary.is_finite fp = true) - (rp rp_abs : R) - (Hrp : R_dot_prod_rel (map FR2 vF) rp) - (Hra : R_dot_prod_rel (map Rabsp (map FR2 vF)) rp_abs), - Rabs (FT2R fp - rp) <= g (size vF) * rp_abs + g1 (size vF) (size vF - 1). +Notation D := (@default_rel t). +Notation E := (@default_abs t). + + +(** [dotprod_forward_error_rel] establishes the standard forward error bound + for the non-FMA floating-point dot product. 
+ + Given: + - vF : a list of floating-point input pairs, + - fp : the computed (finite) floating-point dot product satisfying + [dot_prod_rel vF fp], + - rp : the exact real dot product satisfying [R_dot_prod_rel (map FR2 vF) rp], + - rp_abs : the dot product of absolute values satisfying + [R_dot_prod_rel (map Rabsp (map FR2 vF)) rp_abs], + + the absolute error satisfies: + << |FT2R fp - rp| <= g(|vF|) * rp_abs + g1(|vF|, |vF| - 1) >> +*) + +Lemma dotprod_forward_error_rel : + forall (vF : seq (ftype t * ftype t)) + (fp : ftype t) + (Hfp : dot_prod_rel vF fp) + (Hfin : Binary.is_finite fp = true) + (rp rp_abs : R) + (Hrp : R_dot_prod_rel (map FR2 vF) rp) + (Hra : R_dot_prod_rel (map Rabsp (map FR2 vF)) rp_abs), + Rabs (FT2R fp - rp) <= g (size vF) * rp_abs + g1 (size vF) (size vF - 1). Proof. induction vF. -{ -intros; -inversion Hrp; -inversion Hfp; -inversion Hra; -subst. -rewrite /g /g1 /= !Rminus_diag Rabs_R0 !Rmult_0_l. lra. -} +{ (* base case: empty list *) + intros; + inversion Hrp; inversion Hfp; inversion Hra; subst. + rewrite /g /g1 /= !Rminus_diag Rabs_R0 !Rmult_0_l; lra. } intros. rename vF into l. -assert (Hl: l = [::] \/ l <> [::]) +assert (Hl : l = [::] \/ l <> [::]) by (destruct l; auto; right; congruence). destruct Hl. -- (* case empty l *) -subst; simpl. -rewrite (R_dot_prod_rel_single rp (FR2 a)); auto. -inversion Hfp. inversion H2. subst. -destruct (BPLUS_correct _ _ Hfin) as [[A Hz] ?]. -rewrite Bplus_0R; auto. -destruct (BMULT_accurate' _ _ A) as (d' & e' & Hed' & Hd' & He' & B). -unfold g1, g; simpl. -inversion Hra. inversion H4; subst. -rewrite {}B Rmult_1_r !Rplus_0_r. -field_simplify. -field_simplify_Rabs. destruct a; simpl. -eapply Rle_trans. apply Rabs_triang. -rewrite Rabs_mult. -eapply Rle_trans. -apply Rplus_le_compat. apply Rmult_le_compat; try apply Rabs_pos. -apply Rle_refl. apply Hd'. apply He'. -rewrite Rmult_comm. -apply Rplus_le_compat; try nra. -rewrite Rmult_assoc. -rewrite - Rabs_mult; nra. 
-- (* non-empty l *) -intros; inversion Hfp; -inversion Hrp; inversion Hra; subst. -(destruct (BPLUS_finite_e _ _ Hfin) as (A & B)). -(* IHl *) -specialize (IHvF s H3 B s0 s1 H7 H11). -destruct (BPLUS_accurate' (BMULT (fst a) (snd a)) s Hfin) as (d' & Hd'& Hplus); -rewrite Hplus; clear Hplus. -destruct (BMULT_accurate' (fst a) (snd a) A) as (d & e & Hed & Hd& He& Hmul); -rewrite Hmul; clear Hmul. -(* algebra *) -apply size_not_empty_nat in H. -destruct a; cbv [ FR2 Rabsp fst snd]. -simpl. -set (n:= size l) in *. -set (F:= FT2R f * FT2R f0). -field_simplify_Rabs. -replace (F * d * d' + F * d + F * d' + e * d' + e + FT2R s * d' + FT2R s - s0) with -((F * d * d' + F * d + F * d' + FT2R s * d') + (FT2R s - s0) + (1 + d') * e) by nra. -eapply Rle_trans; - [ apply Rabs_triang | ]. -eapply Rle_trans; [ apply Rplus_le_compat; [eapply Rle_trans; [ apply Rabs_triang | ] |] | ]. -apply Rplus_le_compat_l; apply IHvF . -rewrite Rabs_mult; apply Rmult_le_compat_l; [apply Rabs_pos | apply He]. -rewrite Rplus_assoc. -eapply Rle_trans; - [ apply Rplus_le_compat_r ; eapply Rle_trans; [ apply Rabs_triang | ] | ]. -apply Rplus_le_compat_l; rewrite Rabs_mult; rewrite Rmult_comm; - apply Rmult_le_compat; [ apply Rabs_pos| apply Rabs_pos| apply Hd' | ]. -{ apply Rabs_le_minus in IHvF. - assert (Hs: Rabs (FT2R s) <= g (size l) * s1 + g1 (size l) (size l - 1) + s1). -{ eapply Rle_trans; [apply IHvF | ]. apply Rplus_le_compat_l. - rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 l) s1); auto. - apply (dot_prod_sum_rel_R_Rabs (map FR2 l)); auto. } -apply Hs. } -field_simplify. -fold D E n. -rewrite !Rplus_assoc. -replace (Rabs (F * d * d' + (F * d + F * d')) + -(D * g n * s1 + - (D * s1 + - (D * g1 n (n - 1) + - ( s1 * g n + (g1 n (n - 1) + Rabs (1 + d') * E)))))) with -(Rabs (F * d * d' + (F * d + F * d')) + ((1+ D) * g n * s1 + D * s1) + - (D * g1 n (n - 1) + (g1 n (n - 1) + Rabs (1 + d') * E))) by nra. -replace (n.+1-1)%nat with n by lia. 
-replace (s1 * g (S n) + (g (S n) * Rabs (FT2R f) * Rabs (FT2R f0) + g1 (S n) n)) -with (g (S n) * Rabs (FT2R f * FT2R f0) + s1 * g (S n) + g1 (S n) n) by -(rewrite Rmult_assoc -Rabs_mult; nra). -apply Rplus_le_compat. -apply Rplus_le_compat. -eapply Rle_trans; - [ apply Rabs_triang | ]. -eapply Rle_trans; - [ apply Rplus_le_compat; [rewrite !Rabs_mult| eapply Rle_trans; [apply Rabs_triang| ]] | ]. -apply Rmult_le_compat; [rewrite -!Rabs_mult; try apply Rabs_pos | apply Rabs_pos| | apply Hd']. -apply Rmult_le_compat_l; [rewrite -!Rabs_mult; try apply Rabs_pos | apply Hd ]. -apply Rplus_le_compat; rewrite Rabs_mult. -apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd ]. -apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd' ]. -rewrite -!Rabs_mult. -fold D F. replace (Rabs F * D * D + (Rabs F * D + Rabs F * D)) with - ( ((1 + D)*(1+D) - 1) * Rabs F ) by nra. -apply Rmult_le_compat_r; try apply Rabs_pos; unfold D, g. -apply Rplus_le_compat; try nra. -rewrite <- tech_pow_Rmult. -apply Rmult_le_compat_l. -eapply Rle_trans with 1; try nra; apply default_rel_plus_1_ge_1. -eapply Rle_trans with ((1 + D)^1); try nra. -fold D; nra. -apply Rle_pow; auto with commonDB. -apply Req_le; rewrite one_plus_d_mul_g. -rewrite Rmult_comm. -repeat f_equal; try lia. -rewrite <- Rplus_assoc. -eapply Rle_trans; [apply Rplus_le_compat_l; - apply Rmult_le_compat_r; [ unfold E; apply default_abs_ge_0| eapply Rle_trans] | ]. -apply Rabs_triang. rewrite Rabs_R1. -apply Rplus_le_compat_l; apply Hd'. -rewrite !Rmult_plus_distr_r. rewrite Rmult_1_l. -rewrite <- !Rplus_assoc. -replace (D * g1 n (n - 1) + g1 n (n - 1)) with (g1 n (n-1) * (1+D)) by nra. -rewrite one_plus_d_mul_g1; [ | lia]. -rewrite Rplus_assoc. -replace (E + D * E) with ((1+D) * E) by nra. -eapply Rle_trans; [apply plus_d_e_g1_le; lia | apply Req_le; f_equal;lia]. +- (* case: singleton list *) + subst; simpl. + rewrite (R_dot_prod_rel_single rp (FR2 a)); auto. + inversion Hfp. inversion H2; subst. 
+ destruct (BPLUS_correct _ _ Hfin) as [[A Hz] ?]. + rewrite Bplus_0R; auto. + destruct (BMULT_accurate' _ _ A) as (d' & e' & Hed' & Hd' & He' & B). + unfold g1, g; simpl. + inversion Hra. inversion H4; subst. + rewrite {}B Rmult_1_r !Rplus_0_r. + field_simplify. field_simplify_Rabs. + destruct a; simpl. + eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_mult. + eapply Rle_trans. + { apply Rplus_le_compat. + - apply Rmult_le_compat; try apply Rabs_pos. + + apply Rle_refl. + + apply Hd'. + - apply He'. } + rewrite Rmult_comm. + apply Rplus_le_compat; try nra. + rewrite Rmult_assoc; rewrite -Rabs_mult; nra. +- (* case: non-empty tail *) + intros; inversion Hfp; inversion Hrp; inversion Hra; subst. + destruct (BPLUS_finite_e _ _ Hfin) as (A & B). + (* apply induction hypothesis to the tail *) + specialize (IHvF s H3 B s0 s1 H7 H11). + destruct (BPLUS_accurate' (BMULT (fst a) (snd a)) s Hfin) as (d' & Hd' & Hplus). + rewrite Hplus; clear Hplus. + destruct (BMULT_accurate' (fst a) (snd a) A) as (d & e & Hed & Hd & He & Hmul). + rewrite Hmul; clear Hmul. + apply size_not_empty_nat in H. + destruct a; cbv [FR2 Rabsp fst snd]; simpl. + set (n := size l) in *. + set (F := FT2R f * FT2R f0). + field_simplify_Rabs. + replace (F * d * d' + F * d + F * d' + e * d' + e + FT2R s * d' + FT2R s - s0) + with ((F * d * d' + F * d + F * d' + FT2R s * d') + (FT2R s - s0) + (1 + d') * e) + by nra. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [eapply Rle_trans; [apply Rabs_triang |] |] |]. + { apply Rplus_le_compat_l; apply IHvF. } + { rewrite Rabs_mult; apply Rmult_le_compat_l; [apply Rabs_pos | apply He]. } + rewrite Rplus_assoc. + eapply Rle_trans; + [apply Rplus_le_compat_r; + eapply Rle_trans; [apply Rabs_triang |] |]. + apply Rplus_le_compat_l. + rewrite Rabs_mult Rmult_comm. + apply Rmult_le_compat; [apply Rabs_pos | apply Rabs_pos | apply Hd' |]. + { apply Rabs_le_minus in IHvF. 
+ assert (Hs : Rabs (FT2R s) <= g (size l) * s1 + g1 (size l) (size l - 1) + s1). + { eapply Rle_trans; [apply IHvF |]. + apply Rplus_le_compat_l. + rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 l) s1); auto. + apply (dot_prod_sum_rel_R_Rabs (map FR2 l)); auto. } + apply Hs. } + field_simplify. + fold D E n. + rewrite !Rplus_assoc. + replace (Rabs (F * d * d' + (F * d + F * d')) + + (D * g n * s1 + + (D * s1 + + (D * g1 n (n - 1) + + (s1 * g n + (g1 n (n - 1) + Rabs (1 + d') * E)))))) + with (Rabs (F * d * d' + (F * d + F * d')) + + ((1 + D) * g n * s1 + D * s1) + + (D * g1 n (n - 1) + (g1 n (n - 1) + Rabs (1 + d') * E))) + by nra. + replace (n.+1 - 1)%nat with n by lia. + replace (s1 * g (S n) + (g (S n) * Rabs (FT2R f) * Rabs (FT2R f0) + g1 (S n) n)) + with (g (S n) * Rabs (FT2R f * FT2R f0) + s1 * g (S n) + g1 (S n) n) + by (rewrite Rmult_assoc -Rabs_mult; nra). + apply Rplus_le_compat; [apply Rplus_le_compat |]. + + (* bound on |F * d * d' + (F * d + F * d')| *) + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [rewrite !Rabs_mult + | eapply Rle_trans; [apply Rabs_triang |]] |]. + { apply Rmult_le_compat; + [rewrite -!Rabs_mult; apply Rabs_pos | apply Rabs_pos | | apply Hd']. + apply Rmult_le_compat_l; [rewrite -!Rabs_mult; apply Rabs_pos | apply Hd]. } + { apply Rplus_le_compat; rewrite Rabs_mult. + - apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd]. + - apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd']. } + rewrite -!Rabs_mult. + fold D F. + replace (Rabs F * D * D + (Rabs F * D + Rabs F * D)) + with (((1 + D) * (1 + D) - 1) * Rabs F) by nra. + apply Rmult_le_compat_r; [apply Rabs_pos |]. + unfold D, g. + apply Rplus_le_compat; try nra. + rewrite <- tech_pow_Rmult. + apply Rmult_le_compat_l; + [eapply Rle_trans with 1; try nra; apply default_rel_plus_1_ge_1 |]. + eapply Rle_trans with ((1 + D) ^ 1); try nra. + fold D; nra. + apply Rle_pow; auto with commonDB. 
+ + (* bound on the g-weighted sum term *) + apply Req_le. + rewrite one_plus_d_mul_g. + rewrite Rmult_comm. + repeat f_equal; try lia. + + (* bound on the g1 absolute error term *) + rewrite <- Rplus_assoc. + eapply Rle_trans; + [apply Rplus_le_compat_l; + apply Rmult_le_compat_r; + [unfold E; apply default_abs_ge_0 + | eapply Rle_trans; [apply Rabs_triang |]] |]. + { rewrite Rabs_R1. apply Rplus_le_compat_l; apply Hd'. } + rewrite !Rmult_plus_distr_r Rmult_1_l. + rewrite <- !Rplus_assoc. + replace (D * g1 n (n - 1) + g1 n (n - 1)) + with (g1 n (n - 1) * (1 + D)) by nra. + rewrite one_plus_d_mul_g1; [| lia]. + rewrite Rplus_assoc. + replace (E + D * E) with ((1 + D) * E) by nra. + eapply Rle_trans; [apply plus_d_e_g1_le; lia |]. + apply Req_le; f_equal; lia. Qed. -End ForwardErrorRel1. +End ForwardErrorRel1. + +(** * Section 2: Forward Error Bound — FMA Dot Product *) Section ForwardErrorRel2. -(* forward error bound for fma dot product using inductive rels *) -Context {NAN: FPCore.Nans} {t : type}. +(** + Forward error bound for the FMA-based dot product, using the inductive + relation [fma_dot_prod_rel]. The bound takes the same form as the non-FMA + case but uses FMA accuracy lemmas internally. +*) + +Context {NAN : FPCore.Nans} {t : type}. -Variable (vF : list (ftype t * ftype t)). -Notation vR := (map FR2 vF). -Notation vR' := (map Rabsp (map FR2 vF)). +Variable (vF : list (ftype t * ftype t)). +Notation vR := (map FR2 vF). +Notation vR' := (map Rabsp (map FR2 vF)). -Variable (fp : ftype t). -Hypothesis Hfp : fma_dot_prod_rel vF fp. -Hypothesis Hfin: Binary.is_finite fp = true. +Variable (fp : ftype t). +Hypothesis Hfp : fma_dot_prod_rel vF fp. +Hypothesis Hfin : Binary.is_finite fp = true. -Variable (rp rp_abs : R). +Variable (rp rp_abs : R). Hypothesis Hrp : R_dot_prod_rel vR rp. -Hypothesis Hra : R_dot_prod_rel vR' rp_abs. +Hypothesis Hra : R_dot_prod_rel vR' rp_abs. -Notation g := (@g t). +Notation g := (@g t). Notation g1 := (@g1 t). 
-Notation D := (@default_rel t). -Notation E := (@default_abs t). +Notation D := (@default_rel t). +Notation E := (@default_abs t). -Lemma fma_dotprod_forward_error_rel: - Rabs (FT2R fp - rp) <= g (size vF) * rp_abs + g1 (size vF) (size vF - 1). +(** [fma_dotprod_forward_error_rel] establishes the standard forward error bound + for the FMA-based floating-point dot product. + + The absolute error satisfies: + + << |FT2R fp - rp| <= g(|vF|) * rp_abs + g1(|vF|, |vF| - 1) >> + + where << rp >> is the exact real dot product and << rp_abs >> is the dot product + of absolute values. +*) + +Lemma fma_dotprod_forward_error_rel : + Rabs (FT2R fp - rp) <= g (size vF) * rp_abs + g1 (size vF) (size vF - 1). Proof. revert Hfp Hrp Hra Hfin. revert fp rp rp_abs. induction vF. -{ -intros; -inversion Hrp; -inversion Hfp; -inversion Hra; -subst. -unfold g, g1; simpl. -rewrite !Rminus_diag Rabs_R0; -field_simplify; try apply default_rel_sep_0; - try apply Stdlib.Rdiv_pos_compat; try nra; -apply default_rel_gt_0. -} +{ (* base case: empty list *) + intros; + inversion Hrp; inversion Hfp; inversion Hra; subst. + unfold g, g1; simpl. + rewrite !Rminus_diag Rabs_R0. + field_simplify; try apply default_rel_sep_0; + try apply Stdlib.Rdiv_pos_compat; try nra; + apply default_rel_gt_0. } intros. -assert (Hl: l = [] \/ l <> []). -destruct l; auto. -right. -eapply hd_error_some_nil; simpl; auto. +assert (Hl : l = [] \/ l <> []). +{ destruct l; auto. + right; eapply hd_error_some_nil; simpl; auto. } destruct Hl. -(* list (a0 :: a :: l) *) -(* case empty l *) -{ -subst; simpl. -rewrite (R_dot_prod_rel_single rp (FR2 a)); [ | auto]. -inversion Hfp. inversion H2. subst. -pose proof fma_accurate' (fst a) (snd a) (Zconst t 0) Hfin as Hacc. -destruct Hacc as (e & d & Hz & He & Hd & A). rewrite A; clear A. -inversion Hra; inversion H3; subst. -unfold g1, g; simpl. -rewrite Rmult_1_r. rewrite !Rplus_0_r. -replace (1 + @default_rel t - 1) with (@default_rel t) by nra. -field_simplify_Rabs. 
destruct a; simpl. -rewrite Rminus_diag Rplus_0_r Rmult_1_r Rmult_1_l. -eapply Rle_trans. apply Rabs_triang. -apply Rplus_le_compat; try nra. -rewrite (Rmult_comm D). -rewrite !Rabs_mult. - apply Rmult_le_compat; try apply Rabs_pos; try apply Rle_refl; - try apply Rabs_pos; auto. -rewrite <- Rabs_mult. -apply Rabs_pos. -} -(* non-empty l *) -intros; inversion Hfp; -inversion Hrp; inversion Hra; subst. -(destruct (BFMA_finite_e _ _ _ Hfin) as (A & B & C)). -(* IHl *) -specialize (IHl s s0 s1 H3 H7 H11 C). -pose proof (fma_accurate' (fst a) (snd a) s Hfin) as Hplus. -destruct Hplus as (d' & e'& Hz & Hd'& He'& Hplus); rewrite Hplus; - clear Hplus. -(* algebra *) -destruct a; cbv [ FR2 Rabsp fst snd]. -simpl. -set (n:= size l). -field_simplify_Rabs. -replace (FT2R f * FT2R f0 * d' + FT2R s * d' + FT2R s + e' - s0) with - (d' * (FT2R f * FT2R f0) + d' * FT2R s + (FT2R s - s0) + e') by nra. -eapply Rle_trans; - [ apply Rabs_triang | eapply Rle_trans; [ apply Rplus_le_compat_r; apply Rabs_triang - | ] ]. -eapply Rle_trans; - [ apply Rplus_le_compat_r | ]. -apply Rplus_le_compat_r. -apply Rabs_triang. -eapply Rle_trans; - [apply Rplus_le_compat_r; apply Rplus_le_compat_l | ]. -apply IHl. -eapply Rle_trans; - [apply Rplus_le_compat; [apply Rplus_le_compat_r| apply He' ] | ]. -apply Rplus_le_compat. -rewrite Rabs_mult; -apply Rmult_le_compat_r; try apply Rabs_pos; -apply Hd'. -rewrite Rabs_mult; -apply Rmult_le_compat; try apply Rabs_pos. -apply Hd'. -apply Rabs_le_minus in IHl. -assert (Hs: Rabs (FT2R s) <= - g (size l) * s1 + g1 (size l) (size l - 1) + s1). -{ eapply Rle_trans. apply IHl. -apply Rplus_le_compat_l. -rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 l) s1); auto. -apply (dot_prod_sum_rel_R_Rabs (map FR2 l)); auto. } -apply Hs. -fold n. -set (F:=Rabs (FT2R f * FT2R f0)). -rewrite !Rmult_plus_distr_l. 
-replace (D * F + (D * (g n * s1) + D * g1 n (n - 1) + D * s1) + -(g n * s1 + g1 n (n - 1)) + E) with -(D * F + ((1 + D) * g n * s1 + D * s1) + g1 n (n - 1) * (1 + D) + E) by nra. -rewrite one_plus_d_mul_g. rewrite one_plus_d_mul_g1. -rewrite Rplus_assoc. -apply Rplus_le_compat. -apply Rplus_le_compat. -rewrite <- Rabs_mult. fold F. -apply Rmult_le_compat_r. -unfold F; apply Rabs_pos. -apply d_le_g_1; lia. -apply Rmult_le_compat_r. -rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 l) s1); auto. apply Rabs_pos. -apply Req_le; f_equal; auto; lia. -replace (n.+1-1)%nat with n by lia. -apply plus_e_g1_le. -unfold n; destruct l; try congruence; simpl; lia. +- (* case: singleton list *) + subst; simpl. + rewrite (R_dot_prod_rel_single rp (FR2 a)); [| auto]. + inversion Hfp. inversion H2; subst. + pose proof fma_accurate' (fst a) (snd a) (Zconst t 0) Hfin as Hacc. + destruct Hacc as (e & d & Hz & He & Hd & A). + rewrite A; clear A. + inversion Hra; inversion H3; subst. + unfold g1, g; simpl. + rewrite Rmult_1_r !Rplus_0_r. + replace (1 + @default_rel t - 1) with (@default_rel t) by nra. + field_simplify_Rabs. + destruct a; simpl. + rewrite Rminus_diag Rplus_0_r Rmult_1_r Rmult_1_l. + eapply Rle_trans; [apply Rabs_triang |]. + apply Rplus_le_compat; try nra. + rewrite Rmult_comm !Rabs_mult. + apply Rmult_le_compat; try apply Rabs_pos; try apply Rle_refl; auto. + rewrite <- Rabs_mult; apply Rabs_pos. +- (* case: non-empty tail *) + intros; inversion Hfp; inversion Hrp; inversion Hra; subst. + destruct (BFMA_finite_e _ _ _ Hfin) as (A & B & C). + (* apply induction hypothesis to the tail *) + specialize (IHl s s0 s1 H3 H7 H11 C). + pose proof (fma_accurate' (fst a) (snd a) s Hfin) as Hplus. + destruct Hplus as (d' & e' & Hz & Hd' & He' & Hplus). + rewrite Hplus; clear Hplus. + destruct a; cbv [FR2 Rabsp fst snd]; simpl. + set (n := size l). + field_simplify_Rabs. 
+ replace (FT2R f * FT2R f0 * d' + FT2R s * d' + FT2R s + e' - s0) + with (d' * (FT2R f * FT2R f0) + d' * FT2R s + (FT2R s - s0) + e') + by nra. + eapply Rle_trans; + [apply Rabs_triang + | eapply Rle_trans; + [apply Rplus_le_compat_r; apply Rabs_triang |]]. + eapply Rle_trans; + [apply Rplus_le_compat_r |]. + { apply Rplus_le_compat_r; apply Rabs_triang. } + eapply Rle_trans; + [apply Rplus_le_compat_r; apply Rplus_le_compat_l |]. + { apply IHl. } + eapply Rle_trans; + [apply Rplus_le_compat; [apply Rplus_le_compat_r | apply He'] |]. + { apply Rplus_le_compat. + - rewrite Rabs_mult. + apply Rmult_le_compat_r; [apply Rabs_pos | apply Hd']. + - rewrite Rabs_mult. + apply Rmult_le_compat; try apply Rabs_pos. + + apply Hd'. + + apply Rabs_le_minus in IHl. + assert (Hs : Rabs (FT2R s) <= + g (size l) * s1 + g1 (size l) (size l - 1) + s1). + { eapply Rle_trans; [apply IHl |]. + apply Rplus_le_compat_l. + rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 l) s1); auto. + apply (dot_prod_sum_rel_R_Rabs (map FR2 l)); auto. } + apply Hs. } + fold n. + set (F := Rabs (FT2R f * FT2R f0)). + rewrite !Rmult_plus_distr_l. + replace (D * F + (D * (g n * s1) + D * g1 n (n - 1) + D * s1) + + (g n * s1 + g1 n (n - 1)) + E) + with (D * F + ((1 + D) * g n * s1 + D * s1) + + g1 n (n - 1) * (1 + D) + E) + by nra. + rewrite one_plus_d_mul_g one_plus_d_mul_g1. + rewrite Rplus_assoc. + apply Rplus_le_compat; [apply Rplus_le_compat |]. + + rewrite <- Rabs_mult; fold F. + apply Rmult_le_compat_r; [unfold F; apply Rabs_pos |]. + apply d_le_g_1; lia. + + apply Rmult_le_compat_r. + { rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 l) s1); auto. + apply Rabs_pos. } + apply Req_le; f_equal; auto; lia. + + replace (n.+1 - 1)%nat with n by lia. + apply plus_e_g1_le. + + unfold n; destruct l; try congruence; simpl; lia. Qed. End ForwardErrorRel2. -Section MixedErrorRel1. -(* mixed error bound for non-fma dot product using inductive rels *) -Context {NAN: FPCore.Nans} {t : type}. 
+(** * Section 3: Mixed Error Bound — Non-FMA Dot Product *) + +Section MixedErrorRel1. +(** + Mixed (componentwise relative + global absolute) error bound for the + non-FMA dot product. -Notation g := (@g t). + This section establishes that the computed dot product can be + expressed as an exact dot product of perturbed inputs, up to a small + absolute error. +*) + +Context {NAN : FPCore.Nans} {t : type}. + +Notation g := (@g t). Notation g1 := (@g1 t). -Notation D := (@default_rel t). -Notation E := (@default_abs t). +Notation D := (@default_rel t). +Notation E := (@default_abs t). Variables (v1 v2 : list (ftype t)). Hypothesis Hlen : size v1 = size v2. -Notation vF := (zip v1 v2). +Notation vF := (zip v1 v2). -Variable (fp : ftype t). -Hypothesis Hfp : dot_prod_rel vF fp. -Hypothesis Hfin: Binary.is_finite fp = true. +Variable (fp : ftype t). +Hypothesis Hfp : dot_prod_rel vF fp. +Hypothesis Hfin : Binary.is_finite fp = true. Notation neg_zero := (@common.neg_zero t). -(* mixed error bound *) -Lemma dotprod_mixed_error_rel: +(** [dotprod_mixed_error_rel] establishes the mixed error representation for + the non-FMA floating-point dot product. +*) + +Lemma dotprod_mixed_error_rel : exists (u : list R) (eta : R), size u = size v2 /\ R_dot_prod_rel (zip u (map FT2R v2)) (FT2R fp - eta) /\ - (forall n, (n < size v2)%nat -> exists delta, - nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ Rabs delta <= g (size v2)) /\ + (forall n, (n < size v2)%nat -> + exists delta, + nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ + Rabs delta <= g (size v2)) /\ Rabs eta <= g1 (size v2) (size v2). Proof. revert Hfp Hfin Hlen. revert fp v1. induction v2. -{ simpl; intros. replace v1 with (@nil (ftype t)) in * by (symmetry; apply length_zero_iff_nil; auto). - exists [], 0; repeat split; - [inversion Hfp; subst; rewrite Rminus_0_r; simpl; auto; - apply R_dot_prod_rel_nil | | rewrite Rabs_R0; unfold g1, g; simpl; nra ]. 
- intros; exists 0; split; - [ assert (n = 0)%nat by lia; subst; simpl; nra | rewrite Rabs_R0; unfold g; nra]. -} +{ (* base case: empty v2 *) + simpl; intros. + replace v1 with (@nil (ftype t)) in * + by (symmetry; apply length_zero_iff_nil; auto). + exists [], 0; repeat split. + - inversion Hfp; subst. + rewrite Rminus_0_r; simpl; auto. + apply R_dot_prod_rel_nil. + - intros; exists 0; split. + + assert (n = 0)%nat by lia; subst; simpl; nra. + + rewrite Rabs_R0; unfold g; nra. + - rewrite Rabs_R0; unfold g1, g; simpl; nra. } intros. - destruct v1; intros. - { simpl in Hlen. pose proof Nat.neq_0_succ (size l); try contradiction. } - assert (Hv1: l = [] \/ l <> []). - destruct l; auto. right. - eapply hd_error_some_nil; simpl; auto. - assert (Hlen1: size l0 = size l) by (simpl in Hlen; auto). - destruct Hv1. - assert (l0 = []). { subst l; destruct l0; auto; discriminate. } - subst; clear Hlen1. -{ (* case singleton lists *) -clear IHl. inversion Hfp; subst. -inversion H2; subst; clear H2. - simpl in Hfp, Hfin; unfold fst, snd. - destruct (BPLUS_correct _ _ Hfin) as [[Hfa _] Hplus]. -destruct (BMULT_correct _ _ Hfa) as [[Ha Hf] _]. -destruct (BMULT_accurate' f a Hfa) as (d & e & Hed & Hd & He & Hacc). -exists [FT2R f * (1 +d)], e; repeat split. -{ simpl. rewrite Hplus. simpl. -rewrite Rplus_0_r. -rewrite round_FT2R Hacc. -replace (FT2R f * FT2R a * (1 + d) + e - e) with - (FT2R f * (1 + d) * FT2R a + 0) by (simpl; nra). -apply R_dot_prod_rel_cons; apply R_dot_prod_rel_nil. } -{ intros; exists d; split; auto. simpl in H. - destruct n. { simpl; auto. } - lia. -eapply Rle_trans; [apply Hd| apply d_le_g_1; simpl; auto]. -} -eapply Rle_trans; [apply He|]. apply e_le_g1; simpl in *; auto. -} -(* case cons lists*) -(* apply IH *) -move :(size_not_empty_nat l H) => Hlen3. -inversion Hfp; subst. -unfold fst, snd in Hfin, Hfp; unfold fst, snd. -destruct (BPLUS_finite_e _ _ Hfin) as (A & B). -destruct (BMULT_finite_e _ _ A) as (C & _). 
-(* IHl *) -specialize (IHl s l0 H3 B Hlen1). -(* construct u *) -destruct (BPLUS_accurate' (BMULT f a) s Hfin) as (d' & Hd'& Hplus); -rewrite Hplus; clear Hplus. -destruct (BMULT_accurate' f a A) as (d & e & Hed & Hd& He& Hmul); -rewrite Hmul; clear Hmul. -destruct IHl as (u & eta & Hlenu & Hurel & Hun & Heta). -exists (FT2R f * (1+d) * (1 + d') :: map (Rmult (1+d')) u), - (e * (1 + d') + eta * (1 + d')). -repeat split. -{ simpl. rewrite size_map; auto. } -{ pose proof dot_prod_zip_map_Rmult (1+d') u (map FT2R l) (FT2R s - eta). -rewrite size_map in H0. specialize (H0 Hlenu Hurel); simpl. -replace - ((FT2R f * FT2R a * (1 + d) + e + FT2R s) * (1 + d') - - (e * (1 + d') + eta * (1 + d'))) -with - (FT2R f * (1 + d) * (1 + d') * FT2R a + (FT2R s - eta) * (1 + d')) by nra. -apply R_dot_prod_rel_cons; rewrite Rmult_comm; auto. } -{ intros. destruct n. simpl. -{ simpl. exists ((1 + d) * (1 + d') -1); split. - { field_simplify; nra. } - { field_simplify_Rabs. eapply Rle_trans; [apply Rabs_triang|]. - eapply Rle_trans; [apply Rplus_le_compat; [apply Rabs_triang | apply Hd' ] |]. - eapply Rle_trans; [apply Rplus_le_compat_r; apply Rplus_le_compat; [|apply Hd] | ]. - rewrite Rabs_mult. apply Rmult_le_compat; - [apply Rabs_pos | apply Rabs_pos | apply Hd | apply Hd']. - eapply Rle_trans with ((1 + D) * (1 + D) - 1); try nra. - unfold g. apply Rplus_le_compat; try nra. - rewrite <- tech_pow_Rmult; apply Rmult_le_compat; try nra; try - (eapply Rle_trans with 1; try nra; apply (default_rel_plus_1_ge_1)). - eapply Rle_trans with ((1 + D) ^ 1); try nra. - apply Rle_pow; try - (eapply Rle_trans with 1; try nra; apply (default_rel_plus_1_ge_1)). - rewrite <- Hlen1; auto. lia. } -} -simpl in H0; assert (Hn: (n < size l)%nat) by lia. -specialize (Hun n Hn); - destruct Hun as (delta & Hun & Hdelta). simpl; -replace 0 with (Rmult (1+d') 0) by nra. rewrite (nth_map R0); [ | lia]. -rewrite Hun. -exists ( (1+d') * (1+delta) -1). -split; [nra | ]. -field_simplify_Rabs. 
-eapply Rle_trans; [apply Rabs_triang | ]. -eapply Rle_trans; [apply Rplus_le_compat; [ apply Rabs_triang | apply Hdelta] | ]. -eapply Rle_trans; [apply Rplus_le_compat_r; [rewrite Rabs_mult] | ]. -apply Rplus_le_compat; [apply Rmult_le_compat; try apply Rabs_pos | ]. -apply Hd'. -apply Hdelta. -apply Hd'. -replace (D * g (size l) + D + g (size l)) with -((1 + D) * g (size l) *1 + D *1) by nra. -rewrite one_plus_d_mul_g. -rewrite Rmult_1_r. -apply Req_le; f_equal; lia. -} -simpl. -eapply Rle_trans; [apply Rabs_triang| ]. -eapply Rle_trans; [apply Rplus_le_compat; [rewrite Rabs_mult| rewrite Rabs_mult] | ]. -eapply Rmult_le_compat; try apply Rabs_pos. -apply He. -eapply Rle_trans; [apply Rabs_triang | rewrite Rabs_R1; apply Rplus_le_compat_l; apply Hd']. -eapply Rmult_le_compat; try apply Rabs_pos. -apply Heta. -eapply Rle_trans; [apply Rabs_triang | rewrite Rabs_R1; apply Rplus_le_compat_l; apply Hd']. -rewrite Rplus_comm. rewrite one_plus_d_mul_g1'. -assert (Hp: (1 <= S (size l))%nat) by lia. -pose proof @plus_d_e_g1_le' t (size l) (S (size l)) ltac:(lia) Hp as HYP; clear Hp. -eapply Rle_trans; [| apply HYP]; apply Req_le; nra. +destruct v1; intros. +{ simpl in Hlen. + pose proof Nat.neq_0_succ (size l); contradiction. } +assert (Hv1 : l = [] \/ l <> []). +{ destruct l; auto. + right; eapply hd_error_some_nil; simpl; auto. } +assert (Hlen1 : size l0 = size l) by (simpl in Hlen; auto). +destruct Hv1. +assert (l0 = []) by (subst l; destruct l0; auto; discriminate). +subst; clear Hlen1. +- (* case: singleton lists *) + clear IHl. + inversion Hfp; subst. + inversion H2; subst; clear H2. + simpl in Hfp, Hfin; unfold fst, snd. + destruct (BPLUS_correct _ _ Hfin) as [[Hfa _] Hplus]. + destruct (BMULT_correct _ _ Hfa) as [[Ha Hf] _]. + destruct (BMULT_accurate' f a Hfa) as (d & e & Hed & Hd & He & Hacc). + exists [FT2R f * (1 + d)], e; repeat split. + + simpl. + rewrite Hplus; simpl. + rewrite Rplus_0_r round_FT2R Hacc. 
+ replace (FT2R f * FT2R a * (1 + d) + e - e) + with (FT2R f * (1 + d) * FT2R a + 0) by (simpl; nra). + apply R_dot_prod_rel_cons; apply R_dot_prod_rel_nil. + + intros; exists d; split; auto. + simpl in H. + destruct n; [simpl; auto | lia]. + eapply Rle_trans; [apply Hd | apply d_le_g_1; simpl; auto]. + + eapply Rle_trans; [apply He |]. + apply e_le_g1; simpl in *; auto. +- (* case: cons lists — apply induction hypothesis *) + move : (size_not_empty_nat l H) => Hlen3. + inversion Hfp; subst. + unfold fst, snd in Hfin, Hfp; unfold fst, snd. + destruct (BPLUS_finite_e _ _ Hfin) as (A & B). + destruct (BMULT_finite_e _ _ A) as (C & _). + specialize (IHl s l0 H3 B Hlen1). + destruct (BPLUS_accurate' (BMULT f a) s Hfin) as (d' & Hd' & Hplus). + rewrite Hplus; clear Hplus. + destruct (BMULT_accurate' f a A) as (d & e & Hed & Hd & He & Hmul). + rewrite Hmul; clear Hmul. + destruct IHl as (u & eta & Hlenu & Hurel & Hun & Heta). + exists (FT2R f * (1 + d) * (1 + d') :: map (Rmult (1 + d')) u), + (e * (1 + d') + eta * (1 + d')). + repeat split. + + simpl; rewrite size_map; auto. + + (* show the exact dot product relation holds *) + pose proof dot_prod_zip_map_Rmult (1 + d') u (map FT2R l) (FT2R s - eta). + rewrite size_map in H0. + specialize (H0 Hlenu Hurel); simpl. + replace ((FT2R f * FT2R a * (1 + d) + e + FT2R s) * (1 + d') - + (e * (1 + d') + eta * (1 + d'))) + with (FT2R f * (1 + d) * (1 + d') * FT2R a + (FT2R s - eta) * (1 + d')) + by nra. + apply R_dot_prod_rel_cons; rewrite Rmult_comm; auto. + + (* componentwise relative error bounds *) + intros. + destruct n. + { simpl. + exists ((1 + d) * (1 + d') - 1); split. + - field_simplify; nra. + - field_simplify_Rabs. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [apply Rabs_triang | apply Hd'] |]. + eapply Rle_trans; + [apply Rplus_le_compat_r; + apply Rplus_le_compat; [| apply Hd] |]. + { rewrite Rabs_mult. + apply Rmult_le_compat; try apply Rabs_pos; [apply Hd | apply Hd']. 
} + eapply Rle_trans with ((1 + D) * (1 + D) - 1); try nra. + unfold g. + apply Rplus_le_compat; try nra. + rewrite <- tech_pow_Rmult. + apply Rmult_le_compat; try nra; + try (eapply Rle_trans with 1; try nra; apply default_rel_plus_1_ge_1). + eapply Rle_trans with ((1 + D) ^ 1); try nra. + apply Rle_pow; + try (eapply Rle_trans with 1; try nra; apply default_rel_plus_1_ge_1). + rewrite <- Hlen1; auto; lia. } + simpl in H0. + assert (Hn : (n < size l)%nat) by lia. + specialize (Hun n Hn). + destruct Hun as (delta & Hun & Hdelta). + simpl. + replace 0 with (Rmult (1 + d') 0) by nra. + rewrite (nth_map R0); [| lia]. + rewrite Hun. + exists ((1 + d') * (1 + delta) - 1). + split; [nra |]. + field_simplify_Rabs. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; [apply Rabs_triang | apply Hdelta] |]. + eapply Rle_trans; + [apply Rplus_le_compat_r; rewrite Rabs_mult |]. + { apply Rplus_le_compat; + [apply Rmult_le_compat; try apply Rabs_pos; [apply Hd' | apply Hdelta] |]. + apply Hd'. } + replace (D * g (size l) + D + g (size l)) + with ((1 + D) * g (size l) * 1 + D * 1) by nra. + rewrite one_plus_d_mul_g Rmult_1_r. + apply Req_le; f_equal; lia. + + (* global absolute error bound *) + simpl. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [rewrite Rabs_mult | rewrite Rabs_mult] |]. + { eapply Rmult_le_compat; try apply Rabs_pos. + - apply He. + - eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_R1. + apply Rplus_le_compat_l; apply Hd'. } + { eapply Rmult_le_compat; try apply Rabs_pos. + - apply Heta. + - eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_R1. + apply Rplus_le_compat_l; apply Hd'. } + rewrite Rplus_comm one_plus_d_mul_g1'. + assert (Hp : (1 <= S (size l))%nat) by lia. + pose proof @plus_d_e_g1_le' t (size l) (S (size l)) ltac:(lia) Hp as HYP. + eapply Rle_trans; [| apply HYP]. + apply Req_le; nra. Qed. End MixedErrorRel1. 
+(** * Section 4: Mixed Error Bound — FMA Dot Product *) + Section MixedErrorRel2. +(** + Mixed (componentwise relative + global absolute) error bound for the + FMA-based dot product. + + The structure mirrors [MixedErrorRel1] but uses [fma_dot_prod_rel] and + FMA accuracy lemmas, yielding the tighter absolute error bound + due to one fewer rounding step. +*) -Context {NAN: FPCore.Nans} {t : type}. +Context {NAN : FPCore.Nans} {t : type}. -Notation g := (@g t). +Notation g := (@g t). Notation g1 := (@g1 t). -Notation D := (@default_rel t). -Notation E := (@default_abs t). +Notation D := (@default_rel t). +Notation E := (@default_abs t). Variables (v1 v2 : list (ftype t)). Hypothesis Hlen : size v1 = size v2. -Notation vF := (zip v1 v2). +Notation vF := (zip v1 v2). -Variable (fp : ftype t). -Hypothesis Hfp : fma_dot_prod_rel vF fp. -Hypothesis Hfin: Binary.is_finite fp = true. +Variable (fp : ftype t). +Hypothesis Hfp : fma_dot_prod_rel vF fp. +Hypothesis Hfin : Binary.is_finite fp = true. Notation neg_zero := (@common.neg_zero t). -(* mixed error bounds *) -Lemma fma_dotprod_mixed_error_rel: +(** [fma_dotprod_mixed_error_rel] establishes the mixed error representation for + the FMA-based floating-point dot product. + + The bound on the absolute error is one step tighter than the non-FMA case + because each FMA step introduces only one rounding error. +*) + +Lemma fma_dotprod_mixed_error_rel : exists (u : list R) (eta : R), size u = size v1 /\ R_dot_prod_rel (zip u (map FT2R v2)) (FT2R fp - eta) /\ - (forall n, (n < size v2)%nat -> exists delta, - nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ Rabs delta <= g (size v2)) /\ + (forall n, (n < size v2)%nat -> + exists delta, + nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ + Rabs delta <= g (size v2)) /\ Rabs eta <= g1 (size v2) (size v2 - 1). Proof. revert Hfp Hfin Hlen. revert fp v1. induction v2. -{ simpl; intros. replace v1 with (@nil (ftype t)) in * by (symmetry; apply length_zero_iff_nil; auto). 
- exists [], 0; repeat split; - [inversion Hfp; subst; rewrite Rminus_0_r; simpl; auto; - apply R_dot_prod_rel_nil | | rewrite Rabs_R0; unfold g1, g; simpl; nra ]. - intros; exists 0; split; - [ assert (n = 0)%nat by lia; subst; simpl; try nra | - rewrite Rabs_R0; unfold g; try nra]. -} +{ (* base case: empty v2 *) + simpl; intros. + replace v1 with (@nil (ftype t)) in * + by (symmetry; apply length_zero_iff_nil; auto). + exists [], 0; repeat split. + - inversion Hfp; subst. + rewrite Rminus_0_r; simpl; auto. + apply R_dot_prod_rel_nil. + - intros; exists 0; split. + + assert (n = 0)%nat by lia; subst; simpl; try nra. + + rewrite Rabs_R0; unfold g; try nra. + - rewrite Rabs_R0; unfold g1, g; simpl; nra. } intros. - destruct v1; intros. - { simpl in Hlen. pose proof Nat.neq_0_succ (size l); try contradiction. } - assert (Hv1: l = [] \/ l <> []). - destruct l; auto. right. - eapply hd_error_some_nil; simpl; auto. - assert (Hlen1: size l0 = size l) by (simpl in Hlen; auto). - destruct Hv1. - assert (l0 = []). { destruct l0,l; auto; discriminate. } - subst; clear Hlen1. -{ -inversion Hfp; subst. -inversion H2; subst; clear H2. -simpl in Hfp, Hfin. -pose proof fma_accurate' f a (Zconst t 0) Hfin as Hacc. -destruct Hacc as (d & e & Hde & Hd & He& Hacc). -exists [FT2R f * (1 +d)], e; repeat split. -{ simpl. rewrite Hacc FT2R_Zconst_0 Rplus_0_r. - replace ((FT2R f * FT2R a) * (1 + d) + e - e) with - (FT2R f * (1 + d) * FT2R a + 0) by (simpl; nra). -apply R_dot_prod_rel_cons; apply R_dot_prod_rel_nil. } -{ intros; exists d; split; auto. simpl in H. - destruct n. { simpl; auto. } - lia. -eapply Rle_trans; [apply Hd| apply d_le_g_1; simpl; auto]. -} -eapply Rle_trans; [apply He|]. unfold g1, g; simpl; nra. -} - (* apply IH *) -pose proof (size_not_empty_nat l H) as Hlen3. -inversion Hfp; subst. -(destruct (BFMA_finite_e _ _ _ Hfin) as (A' & B' & C')). -specialize (IHl s l0). -destruct IHl as (u & eta & Hlenu & A & B & C ); auto. -(* construct u0 *) -simpl in Hfin. 
-pose proof fma_accurate' f a s Hfin as Hacc; -destruct Hacc as (d & e & Hz & Hd & He & Hacc). -unfold fst, snd; rewrite Hacc. -exists (FT2R f * (1+d) :: map (Rmult (1+d)) u), (e + eta * (1 + d)). -repeat split. -{ simpl. rewrite size_map; auto. } -{ pose proof dot_prod_zip_map_Rmult (1+d) u (map FT2R l) (FT2R s - eta). -rewrite size_map in H0. -rewrite Hlen1 in Hlenu. -specialize (H0 Hlenu A); simpl. -replace ((FT2R f * FT2R a + FT2R s) * (1 + d) + e - (e + eta * (1 + d))) with -(FT2R f * (1 + d) * FT2R a + (FT2R s - eta)*(1+d)) by nra. -apply R_dot_prod_rel_cons. rewrite Rmult_comm; auto. } -{ intros. destruct n. simpl. -{ simpl. exists d; split; auto. eapply Rle_trans; [apply Hd| ]. apply d_le_g_1. lia. } -assert (n []). +{ destruct l; auto. + right; eapply hd_error_some_nil; simpl; auto. } +assert (Hlen1 : size l0 = size l) by (simpl in Hlen; auto). +destruct Hv1. +assert (l0 = []) by (destruct l0, l; auto; discriminate). +subst; clear Hlen1. +- (* case: singleton lists *) + inversion Hfp; subst. + inversion H2; subst; clear H2. + simpl in Hfp, Hfin. + pose proof fma_accurate' f a (Zconst t 0) Hfin as Hacc. + destruct Hacc as (d & e & Hde & Hd & He & Hacc). + exists [FT2R f * (1 + d)], e; repeat split. + + simpl. + rewrite Hacc FT2R_Zconst_0 Rplus_0_r. + replace ((FT2R f * FT2R a) * (1 + d) + e - e) + with (FT2R f * (1 + d) * FT2R a + 0) by (simpl; nra). + apply R_dot_prod_rel_cons; apply R_dot_prod_rel_nil. + + intros; exists d; split; auto. + simpl in H. + destruct n; [simpl; auto | lia]. + eapply Rle_trans; [apply Hd | apply d_le_g_1; simpl; auto]. + + eapply Rle_trans; [apply He |]. + unfold g1, g; simpl; nra. +- (* case: cons lists — apply induction hypothesis *) + pose proof (size_not_empty_nat l H) as Hlen3. + inversion Hfp; subst. + destruct (BFMA_finite_e _ _ _ Hfin) as (A' & B' & C'). + specialize (IHl s l0). + destruct IHl as (u & eta & Hlenu & A & B & C); auto. + simpl in Hfin. + pose proof fma_accurate' f a s Hfin as Hacc. 
+ destruct Hacc as (d & e & Hz & Hd & He & Hacc). + unfold fst, snd; rewrite Hacc. + exists (FT2R f * (1 + d) :: map (Rmult (1 + d)) u), + (e + eta * (1 + d)). + repeat split. + + simpl; rewrite size_map; auto. + + (* show the exact dot product relation holds *) + pose proof dot_prod_zip_map_Rmult (1 + d) u (map FT2R l) (FT2R s - eta). + rewrite size_map in H0. + rewrite Hlen1 in Hlenu. + specialize (H0 Hlenu A); simpl. + replace ((FT2R f * FT2R a + FT2R s) * (1 + d) + e - (e + eta * (1 + d))) + with (FT2R f * (1 + d) * FT2R a + (FT2R s - eta) * (1 + d)) + by nra. + apply R_dot_prod_rel_cons; rewrite Rmult_comm; auto. + + (* componentwise relative error bounds *) + intros. + destruct n. + { simpl. + exists d; split; auto. + eapply Rle_trans; [apply Hd |]. + apply d_le_g_1; lia. } + assert (n < size l)%nat by (simpl in H0; lia); clear H0. + specialize (B n H1). + destruct B as (delta & B & HB); simpl. + replace 0 with (Rmult (1 + d) 0) by nra. + rewrite (nth_map R0); [| lia]. + rewrite B. + exists ((1 + d) * (1 + delta) - 1). + split; [nra |]. + field_simplify_Rabs. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; [apply Rabs_triang | apply HB] |]. + eapply Rle_trans; + [apply Rplus_le_compat_r; rewrite Rabs_mult |]. + { apply Rplus_le_compat; + [apply Rmult_le_compat; try apply Rabs_pos; [apply Hd | apply HB] |]. + apply Hd. } + replace (D * g (size l) + D + g (size l)) + with ((1 + D) * g (size l) * 1 + D * 1) by nra. + rewrite one_plus_d_mul_g Rmult_1_r. + apply Req_le; f_equal; lia. + + (* global absolute error bound *) + simpl. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [apply He | rewrite Rabs_mult] |]. + { eapply Rmult_le_compat; try apply Rabs_pos. + - apply C. + - eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_R1. + eapply Rle_trans; [apply Rplus_le_compat_l; apply Hd |]. + apply Rle_refl. } + rewrite one_plus_d_mul_g1. 
+ 2: { destruct l; [contradiction | simpl; lia]. } + unfold g1; field_simplify. + rewrite Rplus_assoc. + apply Rplus_le_compat. + { apply Rmult_le_compat; try apply g_pos. + - apply Rmult_le_pos; [apply default_abs_ge_0 | apply pos_INR]. + - apply Rmult_le_compat; try apply default_abs_ge_0; try apply pos_INR. + + apply Req_le; auto. + + apply le_INR; lia. + - apply Req_le; f_equal; auto; lia. } + set (n := size l). + replace (INR (S n)) with (INR n + 1)%R. + { apply Req_le. + unfold GRing.one, GRing.add; simpl; nra. } + apply transitivity with (INR (n + 1)). + { rewrite plus_INR; simpl; auto. } + f_equal; simpl; lia. Qed. End MixedErrorRel2. +(** * Section 5: Sparse Forward Error Bound — Non-FMA Dot Product *) -Section SparseErrorRel1. -(* sparse forward error bound for non-fma dot product using inductive rels *) -Context {NAN: FPCore.Nans} {t : type} . +Section SparseErrorRel1. +(** + Sparse forward error bound for the non-FMA dot product. + + When a vector in the product contains many zeros, the error bound can be + sharpened by replacing the vector length with the number of nonzero entries. + This follows from the observation that multiplying by zero + contributes no rounding error, and thus such terms can be excluded from + the error accumulation. +*) + +Context {NAN : FPCore.Nans} {t : type}. Variables (v1 v2 : list (ftype t)). Hypothesis (Hlen : size v1 = size v2). -Variable (fp : ftype t). -Hypothesis Hfp : dot_prod_rel (zip v1 v2) fp. -Hypothesis Hfin: Binary.is_finite fp = true. +Variable (fp : ftype t). +Hypothesis Hfp : dot_prod_rel (zip v1 v2) fp. +Hypothesis Hfin : Binary.is_finite fp = true. -Notation v1R := (map FT2R v1). +Notation v1R := (map FT2R v1). +Notation nnzR := (nnzR v1R). -Variable (rp rp_abs : R). +Variable (rp rp_abs : R). Hypothesis Hrp : R_dot_prod_rel (map FR2 (zip v1 v2)) rp. -Hypothesis Hra : R_dot_prod_rel (map Rabsp (map FR2 (zip v1 v2))) rp_abs. +Hypothesis Hra : R_dot_prod_rel (map Rabsp (map FR2 (zip v1 v2))) rp_abs. 
-Notation g := (@common.g t). +Notation g := (@common.g t). Notation g1 := (@common.g1 t). -Notation nnzR := (nnzR v1R). -Lemma sparse_dotprod_forward_error_rel: - Rabs (FT2R fp - rp) <= g nnzR * rp_abs + g1 nnzR (nnzR - 1). +(** [sparse_dotprod_forward_error_rel] establishes the sparse forward error + bound for the non-FMA dot product. This is tighter than the dense bound. +*) + +Lemma sparse_dotprod_forward_error_rel : + Rabs (FT2R fp - rp) <= g nnzR * rp_abs + g1 nnzR (nnzR - 1). Proof. revert Hlen Hfp Hfin Hrp Hra. revert rp rp_abs fp v2. induction v1; intros. -{ simpl in Hlen; symmetry in Hlen; apply length_zero_iff_nil in Hlen; subst. -inversion Hfp; inversion Hrp; subst; simpl; field_simplify_Rabs. - rewrite Rabs_Ropp Rabs_R0. +{ (* base case: empty v1 *) + simpl in Hlen; symmetry in Hlen; apply length_zero_iff_nil in Hlen; subst. + inversion Hfp; inversion Hrp; subst; simpl; field_simplify_Rabs. + rewrite Rabs_Ropp Rabs_R0. apply Rplus_le_le_0_compat; auto with commonDB. - apply Rmult_le_pos; auto with commonDB. - rewrite <- (R_dot_prod_rel_Rabs_eq [] rp_abs); auto; + apply Rmult_le_pos; auto with commonDB. + rewrite <- (R_dot_prod_rel_Rabs_eq [] rp_abs); auto. apply Rabs_pos. } destruct v2; try discriminate. assert (Hlen1 : size l = size l0) by (simpl; auto). -set (n2:= (common.nnzR (map FT2R l))%nat) in *. -inversion Hrp. inversion Hfp. inversion Hra; subst. +set (n2 := (common.nnzR (map FT2R l))%nat) in *. +inversion Hrp. inversion Hfp. inversion Hra; subst. simpl in Hfin. destruct (BPLUS_correct _ _ Hfin) as [[Haf Hs0] Hplus]. specialize (IHl s s1 s0 l0 Hlen1 H6 Hs0 H2 H10). -simpl fst. simpl snd. -(* reason by cases on the head of the list *) -destruct (Req_EM_T (FT2R a) 0%R). -(* start head of list is zero *) -{ simpl map. rewrite e. - move :(nnzR_cons [seq FT2R i | i <- l]) => /eqP H8. rewrite {}H8. -replace (FT2R (BPLUS (BMULT a f) s0)) with (FT2R s0). -change GRing.zero with R0. -field_simplify_Rabs. -eapply Rle_trans; [apply IHl|]. 
-apply Req_le; f_equal; try nra. unfold n2, common.nnzR. -rewrite Rabs_R0 Rmult_0_l Rplus_0_l. -simpl count. auto. -rewrite {}Hplus. -pose proof Bmult_0R a f Haf e as A. - destruct A as [A|A] ; auto; rewrite A. -by rewrite Rplus_0_l round_FT2R. -by rewrite FT2R_pos_zero Rplus_0_l round_FT2R. - } (* end head of list is zero *) -(* start head of list is non-zero *) -unfold common.nnzR. -simpl count. -replace (0 != FT2R a) with true by (symmetry; apply /eqP; auto). -simpl. -set (l1:= (map FT2R l)) in *. -change (count _ _) with n2. -(* start case on nnz = case on nnz in tail *) -assert (H7: (n2 = 0)%nat \/ (1<=n2)%nat) by lia; destruct H7 as [H7|H7]. -(* tail all zeros *) -{ rewrite H7. - assert (H0: eq_op n2 0%N) by lia. - pose proof R_dot_prod_rel_nnzR l l0 Hlen1 s H2 H0; subst. -pose proof dot_prod_rel_nnzR l l0 Hlen1 s0 H6 Hs0 H0. -pose proof R_dot_prod_rel_nnzR_abs l l0 Hlen1 s1 H10 H0; subst. -rewrite Bplus_0R; auto. -destruct (BMULT_accurate' a f Haf) - as (d' & e' & Hed' & Hd' & He' & Hacc). -rewrite Hacc; clear Hacc. -unfold g1, g. -simpl. -field_simplify; field_simplify_Rabs. -eapply Rle_trans; [apply Rabs_triang | ]. -apply Rplus_le_compat. -rewrite Rabs_mult. -rewrite Rmult_comm. -rewrite Rabs_mult. rewrite Rmult_assoc. -apply Rmult_le_compat_r; auto with commonDB. -rewrite <- Rabs_mult; apply Rabs_pos. -eapply Rle_trans; [apply He'| ]; auto with commonDB; nra. -} -(* tail not all zeros *) -destruct (BPLUS_accurate' (BMULT a f) s0 Hfin) - as (d' & Hd' & Hacc). -rewrite Hacc; clear Hacc. -destruct (BMULT_accurate' a f Haf) - as (d & e & Hed & Hd & He & Hacc). -rewrite Hacc; clear Hacc. -set (F:= FT2R a * FT2R f ). -field_simplify_Rabs. -replace (F * d * d' + F * d + F * d' + e * d' + e + FT2R s0 * d' + FT2R s0 - s) with -((F * d * d' + F * d + F * d' + FT2R s0 * d') + (FT2R s0 - s) + (1 + d') * e) by nra. -eapply Rle_trans; - [ apply Rabs_triang | ]. -eapply Rle_trans; - [ apply Rplus_le_compat; [eapply Rle_trans; - [ apply Rabs_triang | ] |] | ]. 
-apply Rplus_le_compat_l; apply IHl . -rewrite Rabs_mult; apply Rmult_le_compat_l; [apply Rabs_pos | apply He]. -rewrite Rplus_assoc. -eapply Rle_trans; - [ apply Rplus_le_compat_r ; eapply Rle_trans; [ apply Rabs_triang | ] | ]. -apply Rplus_le_compat_l; rewrite Rabs_mult; rewrite Rmult_comm; - apply Rmult_le_compat; [ apply Rabs_pos| apply Rabs_pos| apply Hd' | ]. -{ apply Rabs_le_minus in IHl. - assert (Hs: Rabs (FT2R s0) <= - g n2 * s1 + g1 n2 (n2 - 1) + s1). - { eapply Rle_trans. apply IHl. - apply Rplus_le_compat. - apply Rplus_le_compat. - apply Rmult_le_compat; auto with commonDB; try apply Rle_refl. - rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto; - apply Rabs_pos. - apply Rle_refl. - rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto; - apply (dot_prod_sum_rel_R_Rabs (map FR2 (zip l l0))); auto. } -apply Hs. } -field_simplify. -unfold g1, g in IHl. -field_simplify in IHl. -set (D:= default_rel). -set (E:= default_abs). -rewrite !Rplus_assoc. -match goal with |-context[?A<= ?B] => -replace A with (Rabs (F * d * d' + (F * d + F * d')) + ((1+ D) * g n2 * s1 + D * s1) + - (D * g1 n2 (n2 - 1) + (g1 n2 (n2 -1) + Rabs (1 + d') * E))) by nra; -replace B with -(g (S n2) * Rabs F + s1 * g (S n2) + g1 (S n2) (S n2 - 1) ) - by (rewrite Rmult_assoc -Rabs_mult -/F; replace (1 + n2)%nat with n2.+1 by lia; nra) -end. -apply Rplus_le_compat. -apply Rplus_le_compat. -unfold g. -eapply Rle_trans; - [ apply Rabs_triang | ]. -eapply Rle_trans; - [ apply Rplus_le_compat; [rewrite !Rabs_mult| eapply Rle_trans; [apply Rabs_triang| ]] | ]. -apply Rmult_le_compat; - [rewrite -!Rabs_mult; apply Rabs_pos | apply Rabs_pos| | apply Hd']. -apply Rmult_le_compat_l; [rewrite -!Rabs_mult; apply Rabs_pos | apply Hd ]. -apply Rplus_le_compat; rewrite Rabs_mult. -apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd ]. -apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd' ]. -rewrite -!Rabs_mult. -fold D F. 
replace (Rabs F * D * D + (Rabs F * D + Rabs F * D)) with - ( ((1 + D)*(1+D) - 1) * Rabs F) by nra. -apply Rmult_le_compat_r; try apply Rabs_pos; unfold D, g. -apply Rplus_le_compat; try nra. -rewrite <- tech_pow_Rmult. -apply Rmult_le_compat_l; auto with commonDB. -eapply Rle_trans with ((1 + D)^1); try nra. -fold D; nra. -apply Rle_pow; auto with commonDB. -lia. -apply Req_le. unfold D,E. rewrite one_plus_d_mul_g. -rewrite Rmult_comm. -repeat f_equal; try lia. -rewrite <- !Rplus_assoc. -replace (D * g1 n2 (n2 - 1) + g1 n2 (n2 - 1)) with (g1 n2 (n2-1) * (1+D)) by nra. -unfold D. -rewrite one_plus_d_mul_g1; auto. -eapply Rle_trans; [apply Rplus_le_compat_l |]. -apply Rmult_le_compat_r; unfold E; auto with commonDB. -assert (Rabs (1 + d') <= 1 + D). -eapply Rle_trans; [apply Rabs_triang| rewrite Rabs_R1]. -apply Rplus_le_compat_l; apply Hd'. -apply H. -eapply Rle_trans; [apply plus_d_e_g1_le; auto| apply Req_le; f_equal;lia]. +simpl fst; simpl snd. +(* case split on whether the head entry is zero *) +destruct (Req_EM_T (FT2R a) 0%R). +- (* head is zero: contributes no rounding error *) + simpl map; rewrite e. + move : (nnzR_cons [seq FT2R i | i <- l]) => /eqP H8; rewrite {}H8. + replace (FT2R (BPLUS (BMULT a f) s0)) with (FT2R s0). + + change GRing.zero with R0. + field_simplify_Rabs. + eapply Rle_trans; [apply IHl |]. + apply Req_le; f_equal; try nra. + unfold n2, common.nnzR. + rewrite Rabs_R0 Rmult_0_l Rplus_0_l. + simpl count; auto. + + rewrite {}Hplus. + pose proof Bmult_0R a f Haf e as A. + destruct A as [A | A]; auto; rewrite A. + * by rewrite Rplus_0_l round_FT2R. + * by rewrite FT2R_pos_zero Rplus_0_l round_FT2R. +- (* head is nonzero: contributes to the error budget *) + unfold common.nnzR. + simpl count. + replace (0 != FT2R a) with true by (symmetry; apply /eqP; auto). + simpl. + set (l1 := (map FT2R l)) in *. + change (count _ _) with n2. + (* case split on whether the tail has any nonzeros *) + assert (H7 : (n2 = 0)%nat \/ (1 <= n2)%nat) by lia. 
+ destruct H7 as [H7 | H7]. + + (* tail is all zeros: only one nonzero product *) + rewrite H7. + assert (H0 : eq_op n2 0%N) by lia. + pose proof R_dot_prod_rel_nnzR l l0 Hlen1 s H2 H0; subst. + pose proof dot_prod_rel_nnzR l l0 Hlen1 s0 H6 Hs0 H0. + pose proof R_dot_prod_rel_nnzR_abs l l0 Hlen1 s1 H10 H0; subst. + rewrite Bplus_0R; auto. + destruct (BMULT_accurate' a f Haf) as (d' & e' & Hed' & Hd' & He' & Hacc). + rewrite Hacc; clear Hacc. + unfold g1, g; simpl. + field_simplify; field_simplify_Rabs. + eapply Rle_trans; [apply Rabs_triang |]. + apply Rplus_le_compat. + { rewrite Rabs_mult Rmult_comm Rabs_mult Rmult_assoc. + apply Rmult_le_compat_r; auto with commonDB. + rewrite <- Rabs_mult; apply Rabs_pos. } + eapply Rle_trans; [apply He' |]; auto with commonDB; nra. + + (* tail has nonzeros: full inductive step *) + destruct (BPLUS_accurate' (BMULT a f) s0 Hfin) as (d' & Hd' & Hacc). + rewrite Hacc; clear Hacc. + destruct (BMULT_accurate' a f Haf) as (d & e & Hed & Hd & He & Hacc). + rewrite Hacc; clear Hacc. + set (F := FT2R a * FT2R f). + field_simplify_Rabs. + replace (F * d * d' + F * d + F * d' + e * d' + e + FT2R s0 * d' + FT2R s0 - s) + with ((F * d * d' + F * d + F * d' + FT2R s0 * d') + + (FT2R s0 - s) + (1 + d') * e) + by nra. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [eapply Rle_trans; [apply Rabs_triang |] |] |]. + { apply Rplus_le_compat_l; apply IHl. } + { rewrite Rabs_mult; apply Rmult_le_compat_l; [apply Rabs_pos | apply He]. } + rewrite Rplus_assoc. + eapply Rle_trans; + [apply Rplus_le_compat_r; + eapply Rle_trans; [apply Rabs_triang |] |]. + apply Rplus_le_compat_l. + rewrite Rabs_mult Rmult_comm. + apply Rmult_le_compat; [apply Rabs_pos | apply Rabs_pos | apply Hd' |]. + { apply Rabs_le_minus in IHl. + assert (Hs : Rabs (FT2R s0) <= + g n2 * s1 + g1 n2 (n2 - 1) + s1). + { eapply Rle_trans; [apply IHl |]. + apply Rplus_le_compat. + - apply Rplus_le_compat. 
+ + apply Rmult_le_compat; auto with commonDB; try apply Rle_refl. + rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto. + apply Rabs_pos. + + apply Rle_refl. + - rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto. + apply (dot_prod_sum_rel_R_Rabs (map FR2 (zip l l0))); auto. } + apply Hs. } + field_simplify. + unfold g1, g in IHl; field_simplify in IHl. + set (D := default_rel). + set (E := default_abs). + rewrite !Rplus_assoc. + match goal with |- context [?A <= ?B] => + replace A with + (Rabs (F * d * d' + (F * d + F * d')) + + ((1 + D) * g n2 * s1 + D * s1) + + (D * g1 n2 (n2 - 1) + (g1 n2 (n2 - 1) + Rabs (1 + d') * E))) + by nra; + replace B with + (g (S n2) * Rabs F + s1 * g (S n2) + g1 (S n2) (S n2 - 1)) + by (rewrite Rmult_assoc -Rabs_mult -/F; + replace (1 + n2)%nat with n2.+1 by lia; nra) + end. + apply Rplus_le_compat; [apply Rplus_le_compat |]. + * (* bound on |F * d * d' + (F * d + F * d')| *) + unfold g. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; + [apply Rplus_le_compat; + [rewrite !Rabs_mult + | eapply Rle_trans; [apply Rabs_triang |]] |]. + { apply Rmult_le_compat; + [rewrite -!Rabs_mult; apply Rabs_pos | apply Rabs_pos | | apply Hd']. + apply Rmult_le_compat_l; [rewrite -!Rabs_mult; apply Rabs_pos | apply Hd]. } + { apply Rplus_le_compat; rewrite Rabs_mult. + - apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd]. + - apply Rmult_le_compat_l; [apply Rabs_pos | apply Hd']. } + rewrite -!Rabs_mult. + fold D F. + replace (Rabs F * D * D + (Rabs F * D + Rabs F * D)) + with (((1 + D) * (1 + D) - 1) * Rabs F) by nra. + apply Rmult_le_compat_r; [apply Rabs_pos |]. + unfold D, g. + apply Rplus_le_compat; try nra. + rewrite <- tech_pow_Rmult. + apply Rmult_le_compat_l; auto with commonDB. + eapply Rle_trans with ((1 + D) ^ 1); try nra. + fold D; nra. + apply Rle_pow; auto with commonDB; lia. + * (* bound on the g-weighted sum term *) + apply Req_le. + unfold D, E. + rewrite one_plus_d_mul_g Rmult_comm. 
+ repeat f_equal; try lia. + * (* bound on the g1 absolute error term *) + rewrite <- !Rplus_assoc. + replace (D * g1 n2 (n2 - 1) + g1 n2 (n2 - 1)) + with (g1 n2 (n2 - 1) * (1 + D)) by nra. + unfold D. + rewrite one_plus_d_mul_g1; auto. + eapply Rle_trans; [apply Rplus_le_compat_l |]. + { apply Rmult_le_compat_r; unfold E; auto with commonDB. + assert (Rabs (1 + d') <= 1 + D). + { eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_R1; apply Rplus_le_compat_l; apply Hd'. } + apply H. } + eapply Rle_trans; [apply plus_d_e_g1_le; auto |]. + apply Req_le; f_equal; lia. Qed. End SparseErrorRel1. -Section SparseErrorRel2. -(* sparse forward error bound for fma dot product using inductive rels *) -Context {NAN: FPCore.Nans} {t : type}. +(** * Section 6: Sparse Forward Error Bound — FMA Dot Product *) + +Section SparseErrorRel2. +(** + Sparse forward error bound for the FMA-based dot product. + + As in [SparseErrorRel1], zero entries in [v1] do not contribute to + rounding error accumulation. +*) + +Context {NAN : FPCore.Nans} {t : type}. Variables (v1 v2 : list (ftype t)). Hypothesis (Hlen : size v1 = size v2). -Variable (fp : ftype t). -Hypothesis Hfp : fma_dot_prod_rel (zip v1 v2) fp. -Hypothesis Hfin: Binary.is_finite fp = true. +Variable (fp : ftype t). +Hypothesis Hfp : fma_dot_prod_rel (zip v1 v2) fp. +Hypothesis Hfin : Binary.is_finite fp = true. -Notation v1R := (map FT2R v1). -Notation vR := (map FR2 (zip v1 v2)). -Notation vR' := (map Rabsp (map FR2 (zip v1 v2))). +Notation v1R := (map FT2R v1). +Notation vR := (map FR2 (zip v1 v2)). +Notation vR' := (map Rabsp (map FR2 (zip v1 v2))). +Notation nnzR := (nnzR v1R). -Variable (rp rp_abs : R). +Variable (rp rp_abs : R). Hypothesis Hrp : R_dot_prod_rel vR rp. -Hypothesis Hra : R_dot_prod_rel vR' rp_abs. +Hypothesis Hra : R_dot_prod_rel vR' rp_abs. -Notation g := (@common.g t). +Notation g := (@common.g t). Notation g1 := (@common.g1 t). -Notation D := (@default_rel t). -Notation E := (@default_abs t).
-Notation nnzR := (nnzR v1R). +Notation D := (@default_rel t). +Notation E := (@default_abs t). + +(** [sparse_fma_dotprod_forward_error] establishes the sparse forward error + bound for the FMA dot product. +*) -Lemma sparse_fma_dotprod_forward_error: - Rabs (FT2R fp - rp) <= g nnzR * rp_abs + g1 nnzR (nnzR - 1). +Lemma sparse_fma_dotprod_forward_error : + Rabs (FT2R fp - rp) <= g nnzR * rp_abs + g1 nnzR (nnzR - 1). Proof. revert Hlen Hfp Hfin Hrp Hra. revert rp rp_abs fp v2. induction v1; intros. -{ simpl in Hlen; symmetry in Hlen; apply length_zero_iff_nil in Hlen; subst. -inversion Hfp; inversion Hrp; subst; simpl; field_simplify_Rabs. - rewrite Rabs_Ropp Rabs_R0. +{ (* base case: empty v1 *) + simpl in Hlen; symmetry in Hlen; apply length_zero_iff_nil in Hlen; subst. + inversion Hfp; inversion Hrp; subst; simpl; field_simplify_Rabs. + rewrite Rabs_Ropp Rabs_R0. apply Rplus_le_le_0_compat; auto with commonDB. - apply Rmult_le_pos; auto with commonDB. - rewrite <- (R_dot_prod_rel_Rabs_eq [] rp_abs); auto; + apply Rmult_le_pos; auto with commonDB. + rewrite <- (R_dot_prod_rel_Rabs_eq [] rp_abs); auto. apply Rabs_pos. } destruct v2; try discriminate. assert (Hlen1 : size l = size l0) by (simpl; auto). -set (n2:= (common.nnzR (map FT2R l))%nat) in *. +set (n2 := (common.nnzR (map FT2R l))%nat) in *. inversion Hrp. inversion Hfp. inversion Hra; subst. -simpl fst in *. simpl snd in *. -destruct (BFMA_correct _ _ _ Hfin) as [[Ha [Hf Hs0]] Hfma]. +simpl fst in *; simpl snd in *. +destruct (BFMA_correct _ _ _ Hfin) as [[Ha [Hf Hs0]] Hfma]. specialize (IHl s s1 s0 l0 Hlen1 H6 Hs0 H2 H10). -(* reason by cases on the head of the list *) -destruct (Req_EM_T (FT2R a) 0%R). -(* start head of list is zero *) -{ simpl map. rewrite e. - move :(nnzR_cons [seq FT2R i | i <- l]) => /eqP H. - rewrite {}H. -replace (FT2R (BFMA a f s0)) with (FT2R s0). -change GRing.zero with R0. -field_simplify_Rabs. -eapply Rle_trans; [apply IHl|]. -apply Req_le; f_equal; try nra. 
unfold n2, common.nnzR. -rewrite Rabs_R0 Rmult_0_l Rplus_0_l //. -pose proof Bfma_mult_0R a f s0 Hfin as A; destruct A; auto; rewrite A. - } (* end head of list is zero *) -(* start head of list is non-zero *) -unfold common.nnzR. -move :n => /eqP. rewrite eq_sym => n. -rewrite /= n /=. -set (l1:= (map FT2R l)) in *. -change (count _ l1) with n2. -(* start case on nnz = case on nnz in tail *) -assert (A: (n2 = O)%nat \/ (1<=n2)%nat) by lia; destruct A. -(* tail all zeros *) -{ rewrite H. - assert (n2 == 0)%N by lia. -pose proof R_dot_prod_rel_nnzR l l0 Hlen1 s H2 H0; subst. -pose proof fma_dot_prod_rel_nnzR l l0 Hlen1 s0 H6 Hs0 H0. -pose proof R_dot_prod_rel_nnzR_abs l l0 Hlen1 s1 H10 H0; subst. -destruct (fma_accurate' a f s0 Hfin) as (e & d & ed & He & Hd & Hacc). -rewrite Hacc; clear Hacc. -rewrite H1. -unfold g1, g. -simpl; field_simplify; field_simplify_Rabs. -eapply Rle_trans; [apply Rabs_triang | ]. -apply Rplus_le_compat. -rewrite Rabs_mult. -rewrite Rmult_comm. -rewrite Rabs_mult. rewrite Rmult_assoc. -apply Rmult_le_compat_r; auto with commonDB. -rewrite <- Rabs_mult; apply Rabs_pos. -eapply Rle_trans; [apply Hd| ]; auto with commonDB; nra. -} -pose proof (fma_accurate' a f s0 Hfin) as Hplus. -destruct Hplus as (d' & e'& Hz & Hd'& He'& Hplus); rewrite Hplus; - clear Hplus. -(* algebra *) -field_simplify_Rabs. -replace (FT2R a * FT2R f * d' + FT2R s0 * d' + FT2R s0 + e' - s) with - (d' * (FT2R a * FT2R f) + d' * FT2R s0 + (FT2R s0 - s) + e') by nra. -eapply Rle_trans; - [ apply Rabs_triang | eapply Rle_trans; [ apply Rplus_le_compat_r; apply Rabs_triang - | ] ]. -eapply Rle_trans; - [ apply Rplus_le_compat_r | ]. -apply Rplus_le_compat_r. -apply Rabs_triang. -eapply Rle_trans; - [apply Rplus_le_compat_r; apply Rplus_le_compat_l | ]. -apply IHl. -eapply Rle_trans; - [apply Rplus_le_compat; [apply Rplus_le_compat_r| apply He' ] | ]. -apply Rplus_le_compat. -rewrite Rabs_mult; -apply Rmult_le_compat_r; try apply Rabs_pos; -apply Hd'. 
-rewrite Rabs_mult; -apply Rmult_le_compat; try apply Rabs_pos. -apply Hd'. -apply Rabs_le_minus in IHl. -assert (Hs: Rabs (FT2R s0) <= - g n2 * s1 + g1 n2 (n2 - 1) + s1). -{ eapply Rle_trans. apply IHl. -apply Rplus_le_compat_l. -rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto. -apply (dot_prod_sum_rel_R_Rabs (map FR2 (zip l l0))); auto. } -apply Hs. -set (F:=Rabs (FT2R a * FT2R f)). -rewrite !Rmult_plus_distr_l. -replace (D * F + (D * (g n2 * s1) + D * g1 n2 (n2 - 1) + D * s1) + -(g n2 * s1 + g1 n2 (n2 - 1)) + E) with -(D * F + ((1 + D) * g n2 * s1 + D * s1) + g1 n2 (n2 - 1) * (1 + D) + E) by nra. -rewrite one_plus_d_mul_g. rewrite one_plus_d_mul_g1; auto. -rewrite Rplus_assoc. -apply Rplus_le_compat. -apply Rplus_le_compat. -rewrite <- Rabs_mult. fold F. -apply Rmult_le_compat_r. -unfold F; apply Rabs_pos. -apply d_le_g_1; lia. -apply Rmult_le_compat_r. -rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto. apply Rabs_pos. -apply Req_le; f_equal; auto; lia. -eapply Rle_trans. -apply plus_e_g1_le; auto. -replace (1 + n2 - 1)%nat with n2 by lia. -change (1 + n2)%nat with n2.+1. -lra. +(* case split on whether the head entry is zero *) +destruct (Req_EM_T (FT2R a) 0%R). +- (* head is zero: contributes no rounding error *) + simpl map; rewrite e. + move : (nnzR_cons [seq FT2R i | i <- l]) => /eqP H. + rewrite {}H. + replace (FT2R (BFMA a f s0)) with (FT2R s0). + + change GRing.zero with R0. + field_simplify_Rabs. + eapply Rle_trans; [apply IHl |]. + apply Req_le; f_equal; try nra. + unfold n2, common.nnzR. + rewrite Rabs_R0 Rmult_0_l Rplus_0_l //. + + pose proof Bfma_mult_0R a f s0 Hfin as A. + destruct A; auto; rewrite A. +- (* head is nonzero: contributes to the error budget *) + unfold common.nnzR. + move : n => /eqP; rewrite eq_sym => n. + rewrite /= n /=. + set (l1 := (map FT2R l)) in *. + change (count _ l1) with n2. + (* case split on whether the tail has any nonzeros *) + assert (A : (n2 = O)%nat \/ (1 <= n2)%nat) by lia. 
+ destruct A. + + (* tail is all zeros: only one nonzero product *) + rewrite H. + assert (n2 == 0)%N by lia. + pose proof R_dot_prod_rel_nnzR l l0 Hlen1 s H2 H0; subst. + pose proof fma_dot_prod_rel_nnzR l l0 Hlen1 s0 H6 Hs0 H0. + pose proof R_dot_prod_rel_nnzR_abs l l0 Hlen1 s1 H10 H0; subst. + destruct (fma_accurate' a f s0 Hfin) as (e & d & ed & He & Hd & Hacc). + rewrite Hacc H1; clear Hacc. + unfold g1, g; simpl. + field_simplify; field_simplify_Rabs. + eapply Rle_trans; [apply Rabs_triang |]. + apply Rplus_le_compat. + { rewrite Rabs_mult Rmult_comm Rabs_mult Rmult_assoc. + apply Rmult_le_compat_r; auto with commonDB. + rewrite <- Rabs_mult; apply Rabs_pos. } + eapply Rle_trans; [apply Hd |]; auto with commonDB; nra. + + (* tail has nonzeros: full inductive step *) + pose proof (fma_accurate' a f s0 Hfin) as Hplus. + destruct Hplus as (d' & e' & Hz & Hd' & He' & Hplus). + rewrite Hplus; clear Hplus. + field_simplify_Rabs. + replace (FT2R a * FT2R f * d' + FT2R s0 * d' + FT2R s0 + e' - s) + with (d' * (FT2R a * FT2R f) + d' * FT2R s0 + (FT2R s0 - s) + e') + by nra. + eapply Rle_trans; + [apply Rabs_triang + | eapply Rle_trans; + [apply Rplus_le_compat_r; apply Rabs_triang |]]. + eapply Rle_trans; [apply Rplus_le_compat_r |]. + { apply Rplus_le_compat_r; apply Rabs_triang. } + eapply Rle_trans; + [apply Rplus_le_compat_r; apply Rplus_le_compat_l |]. + { apply IHl. } + eapply Rle_trans; + [apply Rplus_le_compat; [apply Rplus_le_compat_r | apply He'] |]. + { apply Rplus_le_compat. + - rewrite Rabs_mult. + apply Rmult_le_compat_r; [apply Rabs_pos | apply Hd']. + - rewrite Rabs_mult. + apply Rmult_le_compat; try apply Rabs_pos. + + apply Hd'. + + apply Rabs_le_minus in IHl. + assert (Hs : Rabs (FT2R s0) <= + g n2 * s1 + g1 n2 (n2 - 1) + s1). + { eapply Rle_trans; [apply IHl |]. + apply Rplus_le_compat_l. + rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto. + apply (dot_prod_sum_rel_R_Rabs (map FR2 (zip l l0))); auto. } + apply Hs. 
} + set (F := Rabs (FT2R a * FT2R f)). + rewrite !Rmult_plus_distr_l. + replace (D * F + + (D * (g n2 * s1) + D * g1 n2 (n2 - 1) + D * s1) + + (g n2 * s1 + g1 n2 (n2 - 1)) + E) + with (D * F + ((1 + D) * g n2 * s1 + D * s1) + + g1 n2 (n2 - 1) * (1 + D) + E) + by nra. + rewrite one_plus_d_mul_g one_plus_d_mul_g1; auto. + rewrite Rplus_assoc. + apply Rplus_le_compat; [apply Rplus_le_compat |]. + * rewrite <- Rabs_mult; fold F. + apply Rmult_le_compat_r; [unfold F; apply Rabs_pos |]. + apply d_le_g_1; lia. + * apply Rmult_le_compat_r. + { rewrite <- (R_dot_prod_rel_Rabs_eq (map FR2 (zip l l0)) s1); auto. + apply Rabs_pos. } + apply Req_le; f_equal; auto; lia. + * eapply Rle_trans; [apply plus_e_g1_le; auto |]. + replace (1 + n2 - 1)%nat with n2 by lia. + change (1 + n2)%nat with n2.+1. + lra. Qed. -End SparseErrorRel2. +End SparseErrorRel2. \ No newline at end of file diff --git a/accuracy_proofs/dotprod_model.v b/accuracy_proofs/dotprod_model.v index 9b6d3d5..8aa85fe 100644 --- a/accuracy_proofs/dotprod_model.v +++ b/accuracy_proofs/dotprod_model.v @@ -1,575 +1,853 @@ +(** * Dot Product Definitions and Properties + + This file develops the theory of dot products in three settings: + floating-point arithmetic (with and without FMA), and real-number + arithmetic. The three treatments are connected by relational + characterizations that make it straightforward to transfer + accuracy bounds from the real model to the floating-point + implementations. + + ** Overview of contents + + - [dotprod]: a generic, parameterised dot-product computed via + [foldl] over a zipped pair of lists. Specialised to + - [dotprodF] - IEEE floating-point (BMULT / BPLUS), + - [fma_dotprod] - IEEE floating-point with fused multiply-add + (BFMA), and + - [dotprodR] - real arithmetic (Rmult / Rplus). 
+ + - Inductive relations that characterize the value produced by each + variant: + - [dot_prod_rel] for [dotprodF], + - [fma_dot_prod_rel] for [fma_dotprod], + - [dotprod_any] / [dotprod_any'] for arbitrary reassociation / + reordering of a floating-point dot product, + - [R_dot_prod_rel] for [dotprodR]. + + - Key lemmas connecting implementations to their relations: + - [dotprodF_rel_fold_right] - [dot_prod_rel] holds for + [dotprodF]. + - [fma_dot_prod_rel_fold_right] - [fma_dot_prod_rel] holds for + [fma_dotprod]. + - [dotprodR_rel] - [R_dot_prod_rel] holds for [dotprodR]. + - [dotprod_rel_dotprod_any] - every result of [dot_prod_rel] + can be witnessed by [dotprod_any]. + + - Bound lemmas for the real dot product: + - [dotprodR_rel_bound'] - |⟨u,v⟩| ≤ n·a when each component + has absolute value at most √a. + - [dotprodR_rel_bound''] - same bound but for the dot product + of absolute values ⟨|u|,|v|⟩. + - [dot_prod_sum_rel_R_Rabs] - |⟨u,v⟩| ≤ ⟨|u|,|v|⟩. + + - Non-zero-detection lemmas ([Section NonZeroDP]): when every + entry of a vector has real value 0 the floating-point (and real) + dot products are 0. +*) + From LAProof.accuracy_proofs Require Import preamble common float_acc_lems. Require Import FunctionalExtensionality. Require Import Permutation. -Definition dotprod {A} (mult plus: A -> A -> A) (zero : A) (v1 v2: list A):A := - foldl (Basics.flip plus) zero (map (uncurry mult) (zip v1 v2)). +(** ** Generic dot product + + [dotprod mult plus zero v1 v2] computes the inner product of two + lists using the supplied binary operations and additive + identity. The result is obtained by zipping the two lists, + multiplying pointwise, and summing with [foldl]. *) + +Definition dotprod {A} (mult plus : A -> A -> A) (zero : A) + (v1 v2 : list A) : A := + foldl (Basics.flip plus) zero (map (uncurry mult) (zip v1 v2)). +(** If the left argument is [nil] the result is zero, regardless of + the right argument.
*) + Lemma dotprod_nil_l : - forall A (l : list A) - (mult plus : A -> A -> A) (zero : A), dotprod mult plus zero nil l = zero. + forall A (l : list A) (mult plus : A -> A -> A) (zero : A), + dotprod mult plus zero nil l = zero. Proof. destruct l; auto. Qed. +(** If the right argument is [nil] the result is zero, regardless of + the left argument. *) + Lemma dotprod_nil_r : - forall A (l : list A) - (mult plus : A -> A -> A) (zero : A), dotprod mult plus zero l nil = zero. + forall A (l : list A) (mult plus : A -> A -> A) (zero : A), + dotprod mult plus zero l nil = zero. Proof. destruct l; auto. Qed. +(** When the right list is a singleton, the dot product reduces + to a single multiplication. *) + Lemma dotprod_single : - forall A (l : list A) - (mult plus : A -> A -> A) (zero a2: A) - (Hpz : forall y, plus y zero = y) - (Hmz : forall y, mult zero y = zero), -let a1 := nth zero l 0 in -dotprod mult plus zero l [a2] = mult a1 a2. + forall A (l : list A) + (mult plus : A -> A -> A) (zero a2 : A) + (Hpz : forall y, plus y zero = y) + (Hmz : forall y, mult zero y = zero), + let a1 := nth zero l 0 in + dotprod mult plus zero l [a2] = mult a1 a2. Proof. -intros; simpl; destruct l. -- rewrite dotprod_nil_l. subst a1. simpl; auto. -- unfold dotprod. rewrite /= {2}/Basics.flip Hpz. destruct l; auto. + intros; simpl; destruct l. + - rewrite dotprod_nil_l. subst a1. simpl; auto. + - unfold dotprod. rewrite /= {2}/Basics.flip Hpz. destruct l; auto. Qed. +(* ------------------------------------------------------------------ *) +(** ** Floating-point dot product *) +(* ------------------------------------------------------------------ *) + Section DotProdFloat. Context {NAN : FPCore.Nans} {t : type}. -Definition dotprodF : list (ftype t) -> list (ftype t) -> ftype t := - dotprod BMULT BPLUS pos_zero. 
- -Inductive dot_prod_rel : - list (ftype t * ftype t) -> ftype t -> Prop := -| dot_prod_rel_nil : dot_prod_rel nil pos_zero -| dot_prod_rel_cons : forall l (xy : ftype t * ftype t) s, - dot_prod_rel l s -> - dot_prod_rel (xy::l) (BPLUS (BMULT (fst xy) (snd xy)) s). - -Inductive dotprod_any' : forall (h: nat) (v: list (ftype t * ftype t)) (s: ftype t), Prop := -| Dotprod_Any_1: forall x, dotprod_any' O [x] (BMULT (fst x) (snd x)) -| Dotprod_Any_split: forall n1 n2 al bl a b, - dotprod_any' n1 al a -> dotprod_any' n2 bl b -> dotprod_any' (S (Nat.max n1 n2)) (al++bl) (BPLUS a b) -| Dotprod_Any_perm: forall n al bl s, Permutation al bl -> dotprod_any' n al s -> dotprod_any' n bl s. - -Inductive dotprod_any : forall (h: nat) (v: list (ftype t * ftype t)) (s: ftype t), Prop := -| Dotprod_Any_None: dotprod_any O nil pos_zero -| Dotprod_Any_Some: forall n v s, dotprod_any' n v s -> dotprod_any n v s. - -Lemma dotprod_rel_dotprod_any: forall (z: ftype t) (v: list (ftype t * ftype t)) s (Hz: iszero z), - Forall (fun xy => Binary.is_finite (BMULT (fst xy) (snd xy))) v -> - dot_prod_rel v s -> - exists s', feq s s' /\ dotprod_any (Nat.pred (size v)) v s'. +(** [dotprodF v1 v2] is the standard left-to-right floating-point dot + product using IEEE multiplication [BMULT] and addition [BPLUS], + starting from the positive zero. *) + +Definition dotprodF (v1 v2 : list (ftype t)) : ftype t := + dotprod BMULT BPLUS pos_zero v1 v2. + +(** [dot_prod_rel l s] is the inductive relation that characterizes the + value s obtained by computing the dot product of the list of + pairs l in left-to-right order using IEEE arithmetic. *) + +Inductive dot_prod_rel : + list (ftype t * ftype t) -> ftype t -> Prop := +| dot_prod_rel_nil : + dot_prod_rel nil pos_zero +| dot_prod_rel_cons : + forall l (xy : ftype t * ftype t) s, + dot_prod_rel l s -> + dot_prod_rel (xy :: l) + (BPLUS (BMULT (fst xy) (snd xy)) s). 
+ +(** [dotprod_any' h v s] witnesses that s can be obtained as the + floating-point dot product of the pairs in v by _any_ + parenthesisation of depth at most h, including arbitrary + permutations of the summands (via [Dotprod_Any_perm]). This + relation supports accuracy analyses that do not depend on a + particular evaluation order. *) + +Inductive dotprod_any' : + forall (h : nat) (v : list (ftype t * ftype t)) + (s : ftype t), Prop := +| Dotprod_Any_1 : + forall x, + dotprod_any' O [x] (BMULT (fst x) (snd x)) +| Dotprod_Any_split : + forall n1 n2 al bl a b, + dotprod_any' n1 al a -> + dotprod_any' n2 bl b -> + dotprod_any' (S (Nat.max n1 n2)) (al ++ bl) (BPLUS a b) +| Dotprod_Any_perm : + forall n al bl s, + Permutation al bl -> + dotprod_any' n al s -> + dotprod_any' n bl s. + +(** [dotprod_any h v s] extends [dotprod_any'] to handle the empty + list, where the dot product is positive zero. *) + +Inductive dotprod_any : + forall (h : nat) (v : list (ftype t * ftype t)) + (s : ftype t), Prop := +| Dotprod_Any_None : + dotprod_any O nil pos_zero +| Dotprod_Any_Some : + forall n v s, + dotprod_any' n v s -> + dotprod_any n v s. + +(** Every value related to a list by [dot_prod_rel] is (up to [feq]) + also witnessed by [dotprod_any], provided all pairwise products + are finite. This is the key lemma that bridges the sequential + IEEE relation and the order-independent [dotprod_any] relation. *) + +Lemma dotprod_rel_dotprod_any : + forall (z : ftype t) (v : list (ftype t * ftype t)) s + (Hz : iszero z) + (Hfin : Forall (fun xy => + Binary.is_finite (BMULT (fst xy) (snd xy))) v), + dot_prod_rel v s -> + exists s', feq s s' /\ dotprod_any (Nat.pred (size v)) v s'. Proof. -destruct v as [ | [x y] v]; intros * Hz Hfin H. -- -destruct z; try discriminate; clear Hfin. -inversion H; clear H; subst; (eexists; split; [ | constructor]; reflexivity). -- -revert x y s z Hfin Hz H; induction v as [ | [x y] v]; simpl; intros. -+ -inversion H; clear H; subst. 
-inversion Hfin; clear Hfin; subst. rename H2 into Hfin. rename H1 into Hfin1. -inversion H3; clear H3; subst. -simpl in *. -eexists. -split; [ | constructor; constructor]. -simpl. -apply BPLUS_0_r. apply strict_feq_refl; auto. -+ -inversion Hfin; clear Hfin; subst. rename H3 into Hfin. rename H2 into Hfin1. -inversion H; clear H; subst. -specialize (IHv x y s0 z Hfin Hz H3). -simpl in *. -change (cons (x0,y0) (cons (x,y) v)) with ([(x0,y0)] ++ cons (x,y) v). -replace (S (size v)) with (S (Nat.max O (size v))) by lia. -destruct IHv as [s1 [? ?]]. -eexists. -inversion H0; clear H0; subst. -simpl in H1. -split. -2:{ constructor 2. -eapply Dotprod_Any_split; auto. -apply Dotprod_Any_1. -eassumption. -} -clear z Hz H3 H1. -rewrite H; auto. + destruct v as [ | [x y] v]; intros * Hz Hfin H. + - destruct z; try discriminate; clear Hfin. + inversion H; clear H; subst; + (eexists; split; [ | constructor]; reflexivity). + - revert x y s z Hfin Hz H; + induction v as [ | [x y] v]; simpl; intros. + + inversion H; clear H; subst. + inversion Hfin; clear Hfin; subst. + rename H2 into Hfin. rename H1 into Hfin1. + inversion H3; clear H3; subst. + simpl in *. + eexists. + split; [ | constructor; constructor]. + simpl. + apply BPLUS_0_r. apply strict_feq_refl; auto. + + inversion Hfin; clear Hfin; subst. + rename H3 into Hfin. rename H2 into Hfin1. + inversion H; clear H; subst. + specialize (IHv x y s0 z Hfin Hz H3). + simpl in *. + change (cons (x0, y0) (cons (x, y) v)) + with ([(x0, y0)] ++ cons (x, y) v). + replace (S (size v)) with (S (Nat.max O (size v))) by lia. + destruct IHv as [s1 [? ?]]. + eexists. + inversion H0; clear H0; subst. + simpl in H1. + split. + 2:{ constructor 2. + eapply Dotprod_Any_split; auto. + apply Dotprod_Any_1. + eassumption. } + clear z Hz H3 H1. + rewrite H; auto. Qed. +(** [dot_prod_rel] characterizes [dotprodF]: the relation holds for + the reversed zip of the two input lists. 
*) + Lemma dotprodF_rel_fold_right : -forall (v1 v2: list (ftype t)), + forall (v1 v2 : list (ftype t)), dot_prod_rel (rev (zip v1 v2)) (dotprodF v1 v2). Proof. -intros v1 v2. unfold dotprodF, dotprod. -rewrite -(revK (map _ (zip v1 v2))) foldl_rev -map_rev. -induction (rev _) as [ | [x y] l]; constructor; auto. + intros v1 v2. unfold dotprodF, dotprod. + rewrite -(revK (map _ (zip v1 v2))) foldl_rev -map_rev. + induction (rev _) as [ | [x y] l]; constructor; auto. Qed. End DotProdFloat. +(* ------------------------------------------------------------------ *) +(** ** Floating-point dot product with FMA *) +(* ------------------------------------------------------------------ *) + Section DotProdFMA. Context {NAN : FPCore.Nans} {t : type}. -(* FMA dot-product *) -Definition fma_dotprod (v1 v2: list (ftype t)) : ftype t := - foldl (fun s x12 => BFMA (fst x12) (snd x12) s) pos_zero (zip v1 v2). - -Inductive fma_dot_prod_rel : - list (ftype t * ftype t) -> ftype t -> Prop := -| fma_dot_prod_rel_nil : fma_dot_prod_rel nil (Zconst t 0) -| fma_dot_prod_rel_cons : forall l (xy : ftype t * ftype t) s, - fma_dot_prod_rel l s -> - fma_dot_prod_rel (xy::l) (BFMA (fst xy) (snd xy) s). - -(* NOTE: There is no fma_dotprod_any, because that doesn't really fit the FMA model. - That is, dotprod_any keeps its intermediate results in the same precision as the input values, - and that's not very FMA-like. *) - -Lemma fma_dot_prod_rel_fold_right : -forall (v1 v2: list (ftype t)), +(** [fma_dotprod v1 v2] computes the dot product of [v1] and [v2] + using IEEE fused multiply-add. *) + +Definition fma_dotprod (v1 v2 : list (ftype t)) : ftype t := + foldl (fun s x12 => BFMA (fst x12) (snd x12) s) + pos_zero (zip v1 v2). + +(** [fma_dot_prod_rel l s] is the FMA analogue of [dot_prod_rel]: + s is the value obtained by accumulating [BFMA] from the left + over the pair list l.
*) + +Inductive fma_dot_prod_rel : + list (ftype t * ftype t) -> ftype t -> Prop := +| fma_dot_prod_rel_nil : + fma_dot_prod_rel nil (Zconst t 0) +| fma_dot_prod_rel_cons : + forall l (xy : ftype t * ftype t) s, + fma_dot_prod_rel l s -> + fma_dot_prod_rel (xy :: l) + (BFMA (fst xy) (snd xy) s). + +(** Note: there is no [fma_dotprod_any] analogue of [dotprod_any], + because the FMA model is inherently sequential - it accumulates + products one at a time into a single running sum and does not + naturally support arbitrary reassociation. *) + +(** [fma_dot_prod_rel] characterizes [fma_dotprod]: the relation holds + for the reversed zip of the two input lists. *) + +Lemma fma_dot_prod_rel_fold_right : + forall (v1 v2 : list (ftype t)), fma_dot_prod_rel (rev (zip v1 v2)) (fma_dotprod v1 v2). Proof. -intros v1 v2. - unfold fma_dotprod. -rewrite -{2}(revK (zip v1 v2)) foldl_rev. -induction (rev _). -{ simpl; auto. apply fma_dot_prod_rel_nil. } -simpl. apply fma_dot_prod_rel_cons. auto. + intros v1 v2. + unfold fma_dotprod. + rewrite -{2}(revK (zip v1 v2)) foldl_rev. + induction (rev _). + { simpl; auto. apply fma_dot_prod_rel_nil. } + simpl. apply fma_dot_prod_rel_cons. auto. Qed. End DotProdFMA. +(* ------------------------------------------------------------------ *) +(** ** Real-number dot product *) +(* ------------------------------------------------------------------ *) Section RealDotProd. -Definition dotprodR: forall l1 l2 : seq R, R:= - dotprod Rmult Rplus 0%R. - -Inductive R_dot_prod_rel : list (R * R) -> R -> Prop := -| R_dot_prod_rel_nil : R_dot_prod_rel nil 0%R -| R_dot_prod_rel_cons : forall l xy s, - R_dot_prod_rel l s -> - R_dot_prod_rel (xy::l) (fst xy * snd xy + s). +(** [dotprodR l1 l2] is the exact real dot product of l1 and l2, + defined as an instance of the generic [dotprod] over ℝ. *) + +Definition dotprodR (l1 l2 : seq R) : R := + dotprod Rmult Rplus 0%R l1 l2. 
+ +(** [R_dot_prod_rel l s] is the real analogue of [dot_prod_rel]: + s is the value of the dot product of the pair list l in ℝ. *) + +Inductive R_dot_prod_rel : list (R * R) -> R -> Prop := +| R_dot_prod_rel_nil : + R_dot_prod_rel nil 0%R +| R_dot_prod_rel_cons : + forall l xy s, + R_dot_prod_rel l s -> + R_dot_prod_rel (xy :: l) (fst xy * snd xy + s). + +(** The value witnessed by [R_dot_prod_rel] is unique. *) Lemma R_dot_prod_rel_eq : - forall l a b - (Ha: R_dot_prod_rel l a) - (Hb: R_dot_prod_rel l b), a = b. + forall l a b, + R_dot_prod_rel l a -> + R_dot_prod_rel l b -> + a = b. Proof. -induction l. -{ intros; inversion Ha; inversion Hb; auto. } -intros; inversion Ha; inversion Hb; subst; f_equal. -apply IHl; auto. + induction l. + { intros a b Ha Hb. inversion Ha; inversion Hb; auto. } + intros a0 b0 Ha Hb; inversion Ha; inversion Hb; subst; f_equal. + apply IHl; auto. Qed. -Definition Rabsp p : R * R := (Rabs (fst p), Rabs (snd p)). +(** [Rabsp p] replaces each component of the pair p by its absolute + value. It is used to build the absolute-value dot product that bounds + |⟨u, v⟩|. *) + +Definition Rabsp (p : R * R) : R * R := + (Rabs (fst p), Rabs (snd p)). -Definition FR2 {t: type} (x12: ftype t * ftype t) := (FT2R (fst x12), FT2R (snd x12)). +(** [FR2 x12] converts a pair of floating-point values to a pair of + real numbers using [FT2R]. *) + +Definition FR2 {t : type} (x12 : ftype t * ftype t) : R * R := + (FT2R (fst x12), FT2R (snd x12)). -Lemma FT2R_FR2 t : +(** Convenience rewriting rule: (FT2R a, FT2R a0) = FR2 (a, a0). *) + +Lemma FT2R_FR2 t : forall a a0 : ftype t, (FT2R a, FT2R a0) = FR2 (a, a0). Proof. reflexivity. Qed. -Definition sum_fold: list R -> R := foldr Rplus 0%R. +(** [sum_fold l] sums the elements of l using [foldr]. *) + +Definition sum_fold (l : list R) : R := foldr Rplus 0%R l. + +(** [dotprodR nil u = 0] for any u. *) -Lemma dotprodR_nil_l u: -dotprodR nil u = 0. +Lemma dotprodR_nil_l u : dotprodR nil u = 0. Proof. 
intros; apply dotprod_nil_l. Qed. -Lemma dotprodR_nil_r u: -dotprodR u nil = 0. +(** [dotprodR u nil = 0] for any u. *) + +Lemma dotprodR_nil_r u : dotprodR u nil = 0. Proof. intros; apply dotprod_nil_r. Qed. -Lemma flip_Rplus: Basics.flip Rplus = Rplus. -Proof. -rewrite /Basics.flip; -do 2 (apply FunctionalExtensionality.functional_extensionality; intro); lra. +(** [Basics.flip Rplus] is propositionally equal to [Rplus] because + real addition is commutative. *) + +Lemma flip_Rplus : Basics.flip Rplus = Rplus. +Proof. + rewrite /Basics.flip; + do 2 (apply FunctionalExtensionality.functional_extensionality; intro); lra. Qed. -Lemma Rplus_rewrite : (fun x y => x + y)%Re = Rplus. +Lemma Rplus_rewrite : (fun x y => x + y)%Re = Rplus. Proof. reflexivity. Qed. -Lemma sum_rev l: sum_fold l = sum_fold (rev l). +(** The sum [sum_fold l] equals [sum_fold (rev l)] because real + addition is commutative and associative. *) + +Lemma sum_rev l : sum_fold l = sum_fold (rev l). Proof. -rewrite /sum_fold -foldl_rev foldl_foldr. -f_equal; do 2 (apply FunctionalExtensionality.functional_extensionality; intro); lra. -hnf; intros; lra. -hnf; intros; lra. + rewrite /sum_fold -foldl_rev foldl_foldr. + f_equal; + do 2 (apply FunctionalExtensionality.functional_extensionality; intro); lra. + hnf; intros; lra. + hnf; intros; lra. Qed. +(** [R_dot_prod_rel] characterizes [dotprodR]: for any v1 and v2, + [R_dot_prod_rel (zip v1 v2) (dotprodR v1 v2)]. *) + Lemma dotprodR_rel : -forall (v1 v2: list R) , + forall (v1 v2 : list R), R_dot_prod_rel (zip v1 v2) (dotprodR v1 v2). Proof. -intros; unfold dotprodR, dotprod. -induction (zip v1 v2). -{ simpl. apply R_dot_prod_rel_nil. } -evar (z: R). -replace (foldl _ _ _) with z. -apply R_dot_prod_rel_cons; apply IHl. -subst z. -clear. -rewrite !foldl_foldr; [ | compute; intros; lra..]. -destruct a as [x y]; simpl. -rewrite Rplus_comm //. + intros; unfold dotprodR, dotprod. + induction (zip v1 v2). + { simpl. apply R_dot_prod_rel_nil. } + evar (z : R). 
+ replace (foldl _ _ _) with z. + apply R_dot_prod_rel_cons; apply IHl. + subst z. + clear. + rewrite !foldl_foldr; [ | compute; intros; lra..]. + destruct a as [x y]; simpl. + rewrite Rplus_comm //. Qed. -Lemma dotprodR_rel_inj: forall l s1 s2, - R_dot_prod_rel l s1 -> R_dot_prod_rel l s2 -> s1=s2. +(** [R_dot_prod_rel] is functional: the pair list l determines the value s uniquely. *) + +Lemma dotprodR_rel_inj : + forall l s1 s2, + R_dot_prod_rel l s1 -> + R_dot_prod_rel l s2 -> + s1 = s2. Proof. -induction l; intros; inversion H; clear H; inversion H0; clear H1; subst; f_equal; auto. + induction l; intros; + inversion H; clear H; inversion H0; clear H1; subst; f_equal; auto. Qed. -Lemma dotprodR_rev : forall (v1 v2: list R) , - size v1 = size v2 -> - dotprodR v1 (rev v2) = dotprodR (rev v1) v2. +(** Reversing the second argument is equivalent to reversing the + first, when both lists have the same length. *) + +Lemma dotprodR_rev : + forall (v1 v2 : list R), + size v1 = size v2 -> + dotprodR v1 (rev v2) = dotprodR (rev v1) v2. Proof. -intros. -rewrite /dotprodR /dotprod -{1}(revK v1) -rev_zip ?size_rev //. -rewrite {2}flip_Rplus map_rev foldl_rev foldl_foldr //; compute; intros; lra. + intros. + rewrite /dotprodR /dotprod + -{1}(revK v1) -rev_zip ?size_rev //. + rewrite {2}flip_Rplus map_rev foldl_rev foldl_foldr //; + compute; intros; lra. Qed. -Lemma map_FR2_zip: forall {t} (v1 v2: seq (ftype t)), - map FR2 (zip v1 v2) = zip (map FT2R v1) (map FT2R v2). +(** Zipping then mapping [FR2] is the same as mapping [FT2R] + componentwise before zipping. *) + +Lemma map_FR2_zip : + forall {t} (v1 v2 : seq (ftype t)), + map FR2 (zip v1 v2) = zip (map FT2R v1) (map FT2R v2). Proof. -induction v1; destruct v2; simpl; f_equal; auto. + induction v1; destruct v2; simpl; f_equal; auto. Qed. -Lemma map_Rabsp_zip: forall (v1 v2: seq R), - map Rabsp (zip v1 v2) = zip (map Rabs v1) (map Rabs v2). +(** Zipping then mapping [Rabsp] is the same as mapping [Rabs] + componentwise before zipping.
*) + +Lemma map_Rabsp_zip : + forall (v1 v2 : seq R), + map Rabsp (zip v1 v2) = zip (map Rabs v1) (map Rabs v2). Proof. -induction v1; destruct v2; simpl; f_equal; auto. + induction v1; destruct v2; simpl; f_equal; auto. Qed. +(** The real dot product of the real images of v1 and v2 satisfies + [R_dot_prod_rel] on the reversed [FR2]-mapped zip, and equals the + [sum_fold] of the pointwise products. *) + Lemma R_dot_prod_rel_fold_right t : -forall (v1 v2: list (ftype t)) , - size v1 = size v2 -> - let prods := map (uncurry Rmult) (map FR2 (zip v1 v2)) in + forall (v1 v2 : list (ftype t)), + size v1 = size v2 -> + let prods := + map (uncurry Rmult) (map FR2 (zip v1 v2)) in R_dot_prod_rel (rev (map FR2 (zip v1 v2))) (sum_fold prods). Proof. -intros. -subst prods. -rewrite map_FR2_zip. -move :(dotprodR_rel (rev (map FT2R v1)) (rev (map FT2R v2))). -rewrite dotprodR_rev ?size_rev ?size_map // revK /sum_fold /dotprodR /dotprod - foldl_foldr //. -2,3: compute; intros; lra. -rewrite -rev_zip ?size_map ?flip_Rplus //. + intros. + subst prods. + rewrite map_FR2_zip. + move :(dotprodR_rel (rev (map FT2R v1)) (rev (map FT2R v2))). + rewrite dotprodR_rev ?size_rev ?size_map // revK + /sum_fold /dotprodR /dotprod + foldl_foldr //. + 2,3: compute; intros; lra. + rewrite -rev_zip ?size_map ?flip_Rplus //. Qed. +(** Variant of [R_dot_prod_rel_fold_right] expressing the sum as + [dotprodR (map FT2R v1) (map FT2R v2)]. *) + Lemma R_dot_prod_rel_fold_right' t : -forall (v1 v2: list (ftype t)) , - size v1 = size v2 -> - let prods := map (uncurry Rmult) (map FR2 (zip v1 v2)) in - R_dot_prod_rel (rev (map FR2 (zip v1 v2))) (dotprodR (map FT2R v1) (map FT2R v2)). + forall (v1 v2 : list (ftype t)), + size v1 = size v2 -> + let prods := + map (uncurry Rmult) (map FR2 (zip v1 v2)) in + R_dot_prod_rel (rev (map FR2 (zip v1 v2))) + (dotprodR (map FT2R v1) (map FT2R v2)). Proof. -intros. -replace (dotprodR _ _) with (sum_fold prods). -apply R_dot_prod_rel_fold_right; auto. 
-rewrite sum_rev /sum_fold /dotprodR /dotprod -foldl_rev revK /prods map_FR2_zip //. + intros. + replace (dotprodR _ _) with (sum_fold prods). + apply R_dot_prod_rel_fold_right; auto. + rewrite sum_rev /sum_fold /dotprodR /dotprod + -foldl_rev revK /prods map_FR2_zip //. Qed. +(** [R_dot_prod_rel] for the absolute-value dot product: the reversed zip + of [Rabsp ∘ FR2] satisfies the relation with value [sum_fold prods]. *) + Lemma R_dot_prod_rel_fold_right_Rabs t : -forall (v1 v2: list (ftype t)) , - size v1 = size v2 -> - let prods := map (uncurry Rmult) (map Rabsp (map FR2 (zip v1 v2))) in - R_dot_prod_rel (rev (map Rabsp (map FR2 (zip v1 v2)))) (sum_fold prods). + forall (v1 v2 : list (ftype t)), + size v1 = size v2 -> + let prods := + map (uncurry Rmult) (map Rabsp (map FR2 (zip v1 v2))) in + R_dot_prod_rel (rev (map Rabsp (map FR2 (zip v1 v2)))) + (sum_fold prods). Proof. -intros. -subst prods. -rewrite map_FR2_zip map_Rabsp_zip. -move :(dotprodR_rel (rev (map Rabs (map FT2R v1))) (rev (map Rabs (map FT2R v2)))). -rewrite dotprodR_rev ?size_rev ?size_map // revK /sum_fold /dotprodR /dotprod - foldl_foldr //. -2,3: compute; intros; lra. -rewrite -rev_zip ?size_map ?flip_Rplus //. + intros. + subst prods. + rewrite map_FR2_zip map_Rabsp_zip. + move :(dotprodR_rel (rev (map Rabs (map FT2R v1))) + (rev (map Rabs (map FT2R v2)))). + rewrite dotprodR_rev ?size_rev ?size_map // revK + /sum_fold /dotprodR /dotprod + foldl_foldr //. + 2,3: compute; intros; lra. + rewrite -rev_zip ?size_map ?flip_Rplus //. Qed. +(** Variant of [R_dot_prod_rel_fold_right_Rabs] expressing the sum as + [dotprodR (map Rabs (map FT2R v1)) (map Rabs (map FT2R v2))]. *) + Lemma R_dot_prod_rel_fold_right_Rabs' t : -forall (v1 v2: list (ftype t)) , - size v1 = size v2 -> - let prods := map (uncurry Rmult) (map Rabsp (map FR2 (zip v1 v2))) in - R_dot_prod_rel (rev (map Rabsp (map FR2 (zip v1 v2)))) (dotprodR (map Rabs (map FT2R v1)) (map Rabs (map FT2R v2))). 
+ forall (v1 v2 : list (ftype t)), + size v1 = size v2 -> + let prods := + map (uncurry Rmult) (map Rabsp (map FR2 (zip v1 v2))) in + R_dot_prod_rel (rev (map Rabsp (map FR2 (zip v1 v2)))) + (dotprodR (map Rabs (map FT2R v1)) + (map Rabs (map FT2R v2))). Proof. -intros. -replace (dotprodR _ _) with (sum_fold prods). -apply R_dot_prod_rel_fold_right_Rabs; auto. -rewrite sum_rev /sum_fold /dotprodR /dotprod -foldl_rev revK /prods map_FR2_zip map_Rabsp_zip //. + intros. + replace (dotprodR _ _) with (sum_fold prods). + apply R_dot_prod_rel_fold_right_Rabs; auto. + rewrite sum_rev /sum_fold /dotprodR /dotprod + -foldl_rev revK /prods map_FR2_zip map_Rabsp_zip //. Qed. -Lemma R_dot_prod_rel_single rs a: -R_dot_prod_rel [::a] rs -> rs = (fst a * snd a). +(** If the pair list is a singleton a, then the dot product + relation forces [rs = fst a * snd a]. *) + +Lemma R_dot_prod_rel_single rs a : + R_dot_prod_rel [:: a] rs -> rs = fst a * snd a. Proof. -intros. -inversion H. -inversion H3; subst. -apply Rplus_0_r. + intros. + inversion H. + inversion H3; subst. + apply Rplus_0_r. Qed. -Lemma R_dot_prod_rel_single' a: -R_dot_prod_rel [::a] (fst a * snd a). +(** The converse of [R_dot_prod_rel_single]: the singleton relation + holds with value [fst a * snd a]. *) + +Lemma R_dot_prod_rel_single' a : + R_dot_prod_rel [:: a] (fst a * snd a). Proof. -replace (fst a * snd a)%Re with (fst a * snd a + 0)%Re by apply Rplus_0_r. -apply R_dot_prod_rel_cons; apply R_dot_prod_rel_nil. + replace (fst a * snd a)%Re with (fst a * snd a + 0)%Re + by apply Rplus_0_r. + apply R_dot_prod_rel_cons; apply R_dot_prod_rel_nil. Qed. +(** When all pairs in [l] have been replaced by their absolute values + (via [Rabsp]), the resulting sum is non-negative, so [Rabs s = s]. *) + Lemma R_dot_prod_rel_Rabs_eq : -forall l s, -R_dot_prod_rel (map Rabsp l) s -> Rabs s = s. + forall l s, + R_dot_prod_rel (map Rabsp l) s -> Rabs s = s. Proof. -induction l; intros; inversion H; clear H; subst. -apply Rabs_R0. 
-unfold Rabsp. destruct a; simpl. -replace (Rabs(Rabs r * Rabs r0 + s0))%Re with - (Rabs r * Rabs r0 + s0)%Re; try nra. -symmetry. -rewrite Rabs_pos_eq; try nra. -apply Rplus_le_le_0_compat. -apply Rmult_le_pos; -apply Rabs_pos. -rewrite <- IHl; try apply Rabs_pos; auto. + induction l; intros; inversion H; clear H; subst. + apply Rabs_R0. + unfold Rabsp. destruct a; simpl. + replace (Rabs (Rabs r * Rabs r0 + s0))%Re + with (Rabs r * Rabs r0 + s0)%Re; + try nra. + symmetry. + rewrite Rabs_pos_eq; try nra. + apply Rplus_le_le_0_compat. + apply Rmult_le_pos; apply Rabs_pos. + rewrite <- IHl; try apply Rabs_pos; auto. Qed. +(** The absolute value of the exact dot product is bounded by the + absolute-value dot product: [|⟨u, v⟩| ≤ ⟨|u|, |v|⟩]. *) + Lemma dot_prod_sum_rel_R_Rabs : -forall l s1 s2, -R_dot_prod_rel l s1 -> R_dot_prod_rel (map Rabsp l) s2 -> Rabs s1 <= Rabs s2. + forall l s1 s2, + R_dot_prod_rel l s1 -> + R_dot_prod_rel (map Rabsp l) s2 -> + Rabs s1 <= Rabs s2. Proof. -induction l. -{ intros. -inversion H. -inversion H0. -nra. } -intros. -inversion H; subst; clear H. -inversion H0; subst; clear H0. -unfold Rabsp; destruct a; simpl. -eapply Rle_trans; [ -apply Rabs_triang |]. -replace (Rabs (Rabs r * Rabs r0 + s0))%Re with - (Rabs r * Rabs r0 + s0)%Re. -eapply Rplus_le_compat; try nra. -rewrite Rabs_mult; nra. -rewrite <- (R_dot_prod_rel_Rabs_eq l); auto. -symmetry. -rewrite Rabs_pos_eq; try nra. -apply Rplus_le_le_0_compat. -apply Rmult_le_pos; -apply Rabs_pos. -rewrite <- (R_dot_prod_rel_Rabs_eq l); auto. -apply Rabs_pos. + induction l. + { intros. + inversion H. inversion H0. nra. } + intros. + inversion H; subst; clear H. + inversion H0; subst; clear H0. + unfold Rabsp; destruct a; simpl. + eapply Rle_trans; [apply Rabs_triang |]. + replace (Rabs (Rabs r * Rabs r0 + s0))%Re + with (Rabs r * Rabs r0 + s0)%Re. + eapply Rplus_le_compat; try nra. + rewrite Rabs_mult; nra. + rewrite <- (R_dot_prod_rel_Rabs_eq l); auto. + symmetry. 
+ rewrite Rabs_pos_eq; try nra. + apply Rplus_le_le_0_compat. + apply Rmult_le_pos; apply Rabs_pos. + rewrite <- (R_dot_prod_rel_Rabs_eq l); auto. + apply Rabs_pos. Qed. -Lemma dot_prod_zip_map_Rmult a u v r: -size u = size v -> -R_dot_prod_rel (zip u v) r -> -R_dot_prod_rel (zip (map (Rmult a) u) v) (a * r). +(** Scaling the left factor of every pair by a scalar << a >> scales the + dot product by a. Formally, if [R_dot_prod_rel (zip u v) r] + then [R_dot_prod_rel (zip (map (Rmult a) u) v) (a * r)]. *) + +Lemma dot_prod_zip_map_Rmult a u v r : + size u = size v -> + R_dot_prod_rel (zip u v) r -> + R_dot_prod_rel (zip (map (Rmult a) u) v) (a * r). Proof. -intros. -move :(dotprodR_rel u v) => H1. -move :(dotprodR_rel_inj _ _ _ H0 H1) => H2. -subst r. -clear H0 H1. -move :(dotprodR_rel (map (Rmult a) u) v) => H3. -replace (Rmult a (dotprodR u v)) with (dotprodR (map (Rmult a) u) v); auto. -clear - H. -unfold dotprodR, dotprod. -rewrite !foldl_foldr. -2,3,4,5: compute; intros; lra. -revert v H; induction u; destruct v; intros; inversion H; clear H; subst; simpl. -compute; lra. -rewrite IHu; auto. -rewrite {1 3}/Basics.flip. -lra. + intros. + move :(dotprodR_rel u v) => H1. + move :(dotprodR_rel_inj _ _ _ H0 H1) => H2. + subst r. clear H0 H1. + move :(dotprodR_rel (map (Rmult a) u) v) => H3. + replace (Rmult a (dotprodR u v)) with (dotprodR (map (Rmult a) u) v); auto. + clear - H. + unfold dotprodR, dotprod. + rewrite !foldl_foldr. + 2,3,4,5: compute; intros; lra. + revert v H; induction u; destruct v; intros; inversion H; clear H; subst; + simpl. + compute; lra. + rewrite IHu; auto. + rewrite {1 3}/Basics.flip. lra. Qed. +(** Given a floating-point dot-product computation witnessed by + [dot_prod_rel], there exists a real value rp related to the + real-image list by [R_dot_prod_rel]. 
*) + Lemma dotprod_rel_R_exists {NAN : FPCore.Nans} {t : type} : - forall (l : list (ftype t * ftype t)) (fp : ftype t) - (Hfp : dot_prod_rel l fp), - exists rp, R_dot_prod_rel (map FR2 l) rp. + forall (l : list (ftype t * ftype t)) (fp : ftype t), + dot_prod_rel l fp -> + exists rp, R_dot_prod_rel (map FR2 l) rp. Proof. -intros ?. induction l. -{ simpl; exists 0. apply R_dot_prod_rel_nil. } -intros. inversion Hfp; subst. -destruct (IHl s H2) as (rs & Hrs); clear IHl. -exists (FT2R (fst a) * FT2R (snd a) + rs); simpl. -apply R_dot_prod_rel_cons; auto. + intros ?. induction l. + { simpl; exists 0. apply R_dot_prod_rel_nil. } + intros ? H2. inversion H2; subst. + destruct (IHl s H3) as (rs & Hrs); clear IHl. + exists (FT2R (fst a) * FT2R (snd a) + rs); simpl. + apply R_dot_prod_rel_cons; auto. Qed. +(** FMA analogue of [dotprod_rel_R_exists]: given a computation + witnessed by [fma_dot_prod_rel], there exists a real value related + to the real-image list. *) + Lemma dotprod_rel_R_exists_fma {NAN : FPCore.Nans} {t : type} : - forall (l : list (ftype t * ftype t)) (fp : ftype t) - (Hfp : fma_dot_prod_rel l fp), - exists rp, R_dot_prod_rel (map FR2 l) rp. + forall (l : list (ftype t * ftype t)) (fp : ftype t), + fma_dot_prod_rel l fp -> + exists rp, R_dot_prod_rel (map FR2 l) rp. Proof. -intros ?. induction l. -{ simpl; exists 0. apply R_dot_prod_rel_nil. } -intros. inversion Hfp; subst. -destruct (IHl s H2) as (rs & Hrs); clear IHl. -exists (FT2R (fst a) * FT2R (snd a) + rs); simpl. -apply R_dot_prod_rel_cons; auto. + intros ?. induction l. + { simpl; exists 0. apply R_dot_prod_rel_nil. } + intros ? H2. inversion H2; subst. + destruct (IHl s H3) as (rs & Hrs); clear IHl. + exists (FT2R (fst a) * FT2R (snd a) + rs); simpl. + apply R_dot_prod_rel_cons; auto. Qed. +(** FMA analogue for the absolute-value dot product: given + [fma_dot_prod_rel l fp], there exists a real value related to + [map Rabsp (map FR2 l)]. 
*) + Lemma sum_rel_R_abs_exists_fma {NAN : FPCore.Nans} {t : type} : - forall (l : list (ftype t * ftype t)) (fp : ftype t) - (Hfp : fma_dot_prod_rel l fp), - exists rp, R_dot_prod_rel (map Rabsp (map FR2 l)) rp. + forall (l : list (ftype t * ftype t)) (fp : ftype t), + fma_dot_prod_rel l fp -> + exists rp, R_dot_prod_rel (map Rabsp (map FR2 l)) rp. Proof. -intros ?. induction l. -{ simpl; exists 0. apply R_dot_prod_rel_nil. } -intros. inversion Hfp; subst. -destruct (IHl s H2) as (rs & Hrs); clear IHl. -exists (Rabs (FT2R (fst a)) * Rabs (FT2R (snd a)) + rs); simpl. -apply R_dot_prod_rel_cons; auto. + intros ?. induction l. + { simpl; exists 0. apply R_dot_prod_rel_nil. } + intros ? H2. inversion H2; subst. + destruct (IHl s H3) as (rs & Hrs); clear IHl. + exists (Rabs (FT2R (fst a)) * Rabs (FT2R (snd a)) + rs); simpl. + apply R_dot_prod_rel_cons; auto. Qed. -Lemma dotprodR_rel_bound' : - forall (t : type) (l : list (ftype t * ftype t)) (rp a: R) - (Ha : 0 <= a) - (Hrp : R_dot_prod_rel (map FR2 l) rp) - (Hin : forall x, In x l -> Rabs (FT2R (fst x)) <= sqrt a /\ Rabs (FT2R (snd x)) <= sqrt a), - Rabs rp <= INR (length l) * a. +(** Component-wise bound on the real dot product: + if every component of the pairs in [l] has absolute value at most + [√a], then [|⟨u, v⟩| ≤ n · a], where [n = length l]. + This is a standard building block for rounding error bounds. *) + +Lemma dotprodR_rel_bound' : + forall (t : type) (l : list (ftype t * ftype t)) (rp a : R) + (Ha : 0 <= a) + (Hrp : R_dot_prod_rel (map FR2 l) rp) + (Hin : forall x, In x l -> + Rabs (FT2R (fst x)) <= sqrt a /\ + Rabs (FT2R (snd x)) <= sqrt a), + Rabs rp <= INR (length l) * a. Proof. -induction l; intros. -{ inversion Hrp; subst; simpl; rewrite Rabs_R0; nra. } - inversion Hrp; subst. - eapply Rle_trans; [apply Rabs_triang|]. - eapply Rle_trans; [apply Rplus_le_compat | ]. + induction l; intros. + { inversion Hrp; subst; simpl; rewrite Rabs_R0; nra. } + inversion Hrp; subst. 
+ eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans; [apply Rplus_le_compat |]. rewrite Rabs_mult; apply Rmult_le_compat; try apply Rabs_pos. apply Hin; simpl; auto. apply Hin; simpl; auto. - apply IHl; auto; [ apply Ha| intros; apply Hin; simpl; auto]. + apply IHl; auto; + [ apply Ha | intros; apply Hin; simpl; auto]. rewrite sqrt_def; auto. apply Req_le; - replace (length (a::l)) with ( S(length l)) by auto. + replace (length (a :: l)) with (S (length l)) by auto. rewrite S_INR; nra. Qed. -Lemma dotprodR_rel_bound'' : - forall (t : type) (l : list (ftype t * ftype t)) (rs_abs a: R) - (Ha : 0 <= a) - (Hrp : R_dot_prod_rel (map Rabsp (map FR2 l)) rs_abs) - (Hin : forall x, In x l -> Rabs (FT2R (fst x)) <= sqrt a /\ Rabs (FT2R (snd x)) <= sqrt a), - rs_abs <= INR (length l) * a. +(** Variant of [dotprodR_rel_bound'] for the absolute-value + relation: if [R_dot_prod_rel (map Rabsp (map FR2 l)) rs_abs] and + every component has absolute value at most √a, then + rs_abs ≤ n · a. *) + +Lemma dotprodR_rel_bound'' : + forall (t : type) (l : list (ftype t * ftype t)) (rs_abs a : R) + (Ha : 0 <= a) + (Hrp : R_dot_prod_rel (map Rabsp (map FR2 l)) rs_abs) + (Hin : forall x, In x l -> + Rabs (FT2R (fst x)) <= sqrt a /\ + Rabs (FT2R (snd x)) <= sqrt a), + rs_abs <= INR (length l) * a. Proof. -induction l; intros; inversion Hrp; clear Hrp; subst. -compute; nra. - eapply Rle_trans; [ apply Rplus_le_compat | ]. - apply Rmult_le_compat; - [ destruct a; simpl; apply Rabs_pos | destruct a; simpl; apply Rabs_pos | | ]. + induction l; intros; inversion Hrp; clear Hrp; subst. + compute; nra. + eapply Rle_trans; [apply Rplus_le_compat |]. + apply Rmult_le_compat; + [ destruct a; simpl; apply Rabs_pos + | destruct a; simpl; apply Rabs_pos | | ]. apply Hin; simpl; auto. apply Hin; simpl; auto. - apply IHl; auto; [ apply Ha| intros; apply Hin; simpl; auto]. + apply IHl; auto; + [ apply Ha | intros; apply Hin; simpl; auto]. rewrite sqrt_def; auto. 
apply Req_le; - replace (length (a::l)) with ( S(length l)) by auto. + replace (length (a :: l)) with (S (length l)) by auto. rewrite S_INR; nra. Qed. - End RealDotProd. +(* ------------------------------------------------------------------ *) +(** ** Non-zero detection for dot products *) +(* ------------------------------------------------------------------ *) Section NonZeroDP. -Context {NAN: FPCore.Nans} {t : type}. +Context {NAN : FPCore.Nans} {t : type}. Variables (v1 v2 : list (ftype t)). Hypothesis (Hlen : size v1 = size v2). Notation v1R := (map FT2R v1). -Lemma Req_eq: forall x y, Req_bool x y = eq_op x y. +(** [Req_bool] and the boolean equality [eq_op] coincide on ℝ. *) + +Lemma Req_eq : forall x y, Req_bool x y = eq_op x y. Proof. -intros. -destruct (Req_bool_spec x y); symmetry; apply /eqP ; auto. + intros. + destruct (Req_bool_spec x y); symmetry; apply /eqP; auto. Qed. +(** If every component of [v1] has real value 0 (i.e. [nnzR v1R = 0]), + and the IEEE dot product [fp] is finite, then [FT2R fp = 0]. + This justifies early exit when the first operand is the zero + vector. *) + Lemma dot_prod_rel_nnzR : -forall -(fp : ftype t) -(Hfp : dot_prod_rel (zip v1 v2) fp) -(Hfin: Binary.is_finite fp = true), -nnzR v1R == 0%nat -> FT2R fp = 0. + forall (fp : ftype t) + (Hfp : dot_prod_rel (zip v1 v2) fp) + (Hfin : Binary.is_finite fp = true), + nnzR v1R == 0%nat -> FT2R fp = 0. Proof. -intros. -rewrite nnzR_lemma in H. -revert H Hfp Hlen Hfin. revert v2 fp. -induction v1; intros. destruct v2; try discriminate; inversion Hfp; auto. -inversion Hfp; subst. -rewrite /pos_zero /Zconst => //=. -destruct xy => //=. -simpl BPLUS in Hfin, Hfp. -destruct v2 as [ | v2a v2r]; [discriminate |]. -inversion H0; clear H0; subst. -move :H => /= /andP [H H0]. -move : (BPLUS_correct _ _ Hfin) => [[H2 H3] H4]. -rewrite {}H4. -have Hs: FT2R s = 0 by (apply (IHl v2r) => //; auto). -rewrite Hs Rplus_0_r. -have Ha: FT2R a = 0 by move: H => /eqP //. 
-rewrite (proj2 (BMULT_correct _ _ H2)). -rewrite Ha Rmult_0_l !Generic_fmt.round_0 //. + (* Rewrite the hypothesis with [nnzR_lemma], then induct on v1. *) + intros. + rewrite nnzR_lemma in H. + revert H Hfp Hlen Hfin. revert v2 fp. + induction v1; intros. + - destruct v2; try discriminate; inversion Hfp; auto. + - inversion Hfp; subst. + rewrite /pos_zero /Zconst => //=. + destruct xy => //=. + simpl BPLUS in Hfin, Hfp. + destruct v2 as [ | v2a v2r]; [discriminate |]. + inversion H0; clear H0; subst. + move :H => /= /andP [H H0]. + move : (BPLUS_correct _ _ Hfin) => [[H2 H3] H4]. + rewrite {}H4. + have Hs: FT2R s = 0 by (apply (IHl v2r) => //; auto). + rewrite Hs Rplus_0_r. + have Ha: FT2R a = 0 by move: H => /eqP //. + rewrite (proj2 (BMULT_correct _ _ H2)). + rewrite Ha Rmult_0_l !Generic_fmt.round_0 //. Qed. +(** FMA analogue of [dot_prod_rel_nnzR]: when [nnzR v1R = 0] and the + FMA dot product is finite, [FT2R fp = 0]. *) + Lemma fma_dot_prod_rel_nnzR : -forall -(fp : ftype t) -(Hfp : fma_dot_prod_rel (zip v1 v2) fp) -(Hfin: Binary.is_finite fp = true), -nnzR v1R == 0%nat -> FT2R fp = 0. + forall (fp : ftype t) + (Hfp : fma_dot_prod_rel (zip v1 v2) fp) + (Hfin : Binary.is_finite fp = true), + nnzR v1R == 0%nat -> FT2R fp = 0. Proof. -intros. -rewrite nnzR_lemma in H. -move : v2 fp H Hfp Hlen Hfin. -clear Hlen. -induction v1; destruct v0; intros; inversion Hlen; clear Hlen. -inversion Hfp; auto. -inversion Hfp; clear Hfp; subst. -rewrite /Zconst => //=. -move :H => /= /andP [H8 H9]. -move : (BFMA_correct _ _ _ Hfin) => /= [[H2 [H3 H6]] H7]. -rewrite H7. -rewrite (IHl _ _ H9 H4); auto. -move :H8 => /eqP => H8. -rewrite -H8. -rewrite Rplus_0_r Rmult_0_l !Generic_fmt.round_0 //. + intros. + rewrite nnzR_lemma in H. + move : v2 fp H Hfp Hlen Hfin. + clear Hlen. + induction v1; destruct v0; intros; inversion Hlen; clear Hlen. + - inversion Hfp; auto. + - inversion Hfp; clear Hfp; subst. + rewrite /Zconst => //=. + move :H => /= /andP [H8 H9]. + move : (BFMA_correct _ _ _ Hfin) => /= [[H2 [H3 H6]] H7]. + rewrite H7.
+ rewrite (IHl _ _ H9 H4); auto. + move :H8 => /eqP => H8. + rewrite -H8. + rewrite Rplus_0_r Rmult_0_l !Generic_fmt.round_0 //. Qed. +(** Real analogue: when [nnzR v1R = 0], the real dot product [rp] + obtained via [R_dot_prod_rel (map FR2 (zip v1 v2)) rp] is 0. *) + Lemma R_dot_prod_rel_nnzR : -forall -(rp : R) -(Hrp : R_dot_prod_rel (map FR2 (zip v1 v2)) rp), -nnzR v1R == 0%nat -> rp = 0. + forall (rp : R) + (Hrp : R_dot_prod_rel (map FR2 (zip v1 v2)) rp), + nnzR v1R == 0%nat -> rp = 0. Proof. -intros ? ? H. -rewrite nnzR_lemma in H. -revert v2 rp H Hrp Hlen. -induction v1; intros. -destruct v2; try discriminate; auto. -inversion Hrp; auto. -destruct v2; try discriminate; auto. -inversion Hrp; subst. -unfold FR2, fst, snd. -move :H => /= /andP [H H0]. -move :H => /eqP H. -simpl in Hlen. -rewrite -H Rmult_0_l. -rewrite (IHl _ _ H0 H3). lra. lia. + intros ? ? H. + rewrite nnzR_lemma in H. + revert v2 rp H Hrp Hlen. + induction v1; intros. + - destruct v2; try discriminate; auto. + inversion Hrp; auto. + - destruct v2; try discriminate; auto. + inversion Hrp; subst. + unfold FR2, fst, snd. + move :H => /= /andP [H H0]. + move :H => /eqP H. + simpl in Hlen. + rewrite -H Rmult_0_l. + rewrite (IHl _ _ H0 H3). lra. lia. Qed. +(** Absolute-value dot product analogue of [R_dot_prod_rel_nnzR]: + when [nnzR v1R = 0], the absolute-value sum [rp_abs] is 0. *) + Lemma R_dot_prod_rel_nnzR_abs : -forall -(rp_abs : R) -(Hra : R_dot_prod_rel (map Rabsp (map FR2 (zip v1 v2))) rp_abs), -nnzR v1R == 0%nat -> rp_abs = 0. + forall (rp_abs : R) + (Hra : R_dot_prod_rel (map Rabsp (map FR2 (zip v1 v2))) rp_abs), + nnzR v1R == 0%nat -> rp_abs = 0. Proof. -intros ? ? H. -rewrite nnzR_lemma in H. -revert H Hra Hlen. revert v2 rp_abs . -induction v1; intros. -simpl in *. inversion Hra. auto. -destruct v2; try discriminate; auto. -destruct v2; try discriminate. -inversion Hra; subst. -unfold FR2, Rabsp, fst, snd. -move :H => /= /andP [H H0]. -move :H => /eqP H. -simpl in Hlen. 
-rewrite -H Rabs_R0 Rmult_0_l (IHl _ _ H0 H3). -lra. lia. + intros ? ? H. + rewrite nnzR_lemma in H. + revert H Hra Hlen. revert v2 rp_abs. + induction v1; intros. + - simpl in *. inversion Hra. auto. + destruct v2; try discriminate; auto. + - destruct v2; try discriminate. + inversion Hra; subst. + unfold FR2, Rabsp, fst, snd. + move :H => /= /andP [H H0]. + move :H => /eqP H. + simpl in Hlen. + rewrite -H Rabs_R0 Rmult_0_l (IHl _ _ H0 H3). + lra. lia. Qed. - End NonZeroDP. \ No newline at end of file diff --git a/accuracy_proofs/float_acc_lems.v b/accuracy_proofs/float_acc_lems.v index 7a656a2..1c54cff 100644 --- a/accuracy_proofs/float_acc_lems.v +++ b/accuracy_proofs/float_acc_lems.v @@ -1,640 +1,879 @@ -(* This file contains lemmas regarding the accuracy of floating point - operations such as BPLUS, BFMA, and BMULT. *) +(** * Floating-Point Operation Accuracy + + This file establishes accuracy lemmas for the basic floating-point + operations [BPLUS], [BMINUS], [BMULT], and [BFMA] as defined in the + VCFloat library. Each operation is analyzed in the + _round-to-nearest-even_ (RNE) rounding mode. 
+ + The central results take the form of the standard _rounding error model_: + given that a floating-point operation does not overflow, its result [fl(op)] + satisfies a bound of the form + + %\[ \mathtt{fl}(a \mathbin{\mathrm{op}} b) = (a \mathbin{\mathrm{op}} b)(1 + \delta) + \varepsilon \]% + #\[ \mathtt{fl}(a \mathbin{\mathrm{op}} b) = (a \mathbin{\mathrm{op}} b)(1 + \delta) + \varepsilon \]# + + where the relative error %$\delta$%#\(\delta\)# and absolute error %$\varepsilon$%#\(\varepsilon\)# satisfy + + %\[ |\delta| \leq \mathbf{u}, \qquad |\varepsilon| \leq \eta, \qquad \delta \cdot \varepsilon = 0 \]% + #\[ |\delta| \leq \mathbf{u}, \qquad |\varepsilon| \leq \eta, \qquad \delta \cdot \varepsilon = 0 \]# + + and where %$\mathbf{u}$%#\(\mathbf{u}\)# denotes the _unit roundoff_ and %$\eta$%#\(\eta\)# denotes + the _underflow threshold_ for the given floating-point type << t >>. + The mutual exclusion condition %$\delta \cdot \varepsilon = 0$%#\(\delta \cdot \varepsilon = 0\)# reflects + the standard decomposition: in the normal range only relative error is incurred, while in the + subnormal range only absolute error is incurred. + + For [BPLUS] and [BMINUS] the error model simplifies to + %$\mathtt{fl}(a \mathbin{\mathrm{op}} b) = (a \mathbin{\mathrm{op}} b)(1 + \delta)$%#\(\mathtt{fl}(a \mathbin{\mathrm{op}} b) = (a \mathbin{\mathrm{op}} b)(1 + \delta)\)#, + since addition and subtraction on floating-point numbers that are + already representable incur only relative error. + + ** Structure + + The file is organized as follows: + + - _Signed zero lemmas_: behavior of operations when one argument is zero. + - _Signed zero facts_: [FT2R] evaluations at signed zeros. + - _No-overflow predicates_: [fma_no_overflow], [Bmult_no_overflow], + [Bplus_no_overflow], [Bminus_no_overflow]. 
+ - _Generic rounding lemma_: [generic_round_property], which extracts + the %$(\delta, \varepsilon)$%#\((\delta, \varepsilon)\)# decomposition from the Flocq library for + an arbitrary real number. + - _Per-operation accuracy and finiteness lemmas_ for [BFMA], [BMULT], + [BPLUS], and [BMINUS], each in two forms: + - a form assuming a no-overflow hypothesis on real-valued arguments; + - a primed form ([fma_accurate'], [BMULT_accurate'], etc.) assuming + only that the _floating-point result_ is finite, from which the + no-overflow condition is derived automatically. + - _Correctness lemmas_ ([BFMA_correct], [BMULT_correct], [BPLUS_correct]): + these combine the finiteness of inputs and the rounding identity into + a single conclusion. +*) From LAProof.accuracy_proofs Require Import preamble common. Section GenFloat. -Context {NAN: FPCore.Nans} {t : type} . - -Lemma Bmult_0R (a f: ftype t) : -Binary.is_finite (BMULT a f) -> -FT2R a = 0 -> -(BMULT a f) = neg_zero \/ (BMULT a f) = pos_zero. + +Context {NAN : FPCore.Nans} {t : type}. + +(** ** Signed Zero Lemmas + + Behavior of [BMULT], [BPLUS], and [BFMA] when one operand evaluates + to zero under [FT2R]. These are used in higher-level proofs to + eliminate zero-valued terms from error bounds. *) + +Lemma Bmult_0R (a f : ftype t) : + Binary.is_finite (BMULT a f) -> + FT2R a = 0 -> + (BMULT a f) = neg_zero \/ (BMULT a f) = pos_zero. Proof. rewrite /BMULT/BINOP //= /pos_zero/neg_zero/Binary.Bmult. - destruct f,a,s,s0 => //=; + destruct f, a, s, s0 => //=; move => FIN HA; try discriminate FIN; auto; try apply Float_prop.eq_0_F2R in HA; repeat (destruct m0; try discriminate HA). Qed. -Lemma Bplus_0R (a f: ftype t) : -Binary.is_finite (BPLUS a f) -> -FT2R f = 0 -> -FT2R (BPLUS a f) = FT2R a. +Lemma Bplus_0R (a f : ftype t) : + Binary.is_finite (BPLUS a f) -> + FT2R f = 0 -> + FT2R (BPLUS a f) = FT2R a. Proof. - rewrite /BMULT/BINOP //= + rewrite /BMULT/BINOP //= /pos_zero/neg_zero/Binary.Bmult. 
- destruct f,a,s,s0 => //=; + destruct f, a, s, s0 => //=; move => FIN HA; try discriminate FIN; auto; try apply Float_prop.eq_0_F2R in HA; repeat (destruct m0; try discriminate HA). Qed. -Lemma Bfma_mult_0R (a f s : ftype t): -Binary.is_finite (BFMA a f s) -> -FT2R a = 0 -> -FT2R (BFMA a f s) = FT2R s. -Proof. - rewrite /BMULT/BINOP //= . - destruct f; +Lemma Bfma_mult_0R (a f s : ftype t) : + Binary.is_finite (BFMA a f s) -> + FT2R a = 0 -> + FT2R (BFMA a f s) = FT2R s. +Proof. + rewrite /BMULT/BINOP //=. + destruct f; destruct a; - destruct s; - destruct s0; destruct s1; destruct s => //=; + destruct s; + destruct s0; destruct s1; destruct s => //=; move => FIN HA; try discriminate FIN; auto; try apply Float_prop.eq_0_F2R in HA; repeat (destruct m0; try discriminate HA). Qed. -Fact neg_zero_is_finite: Binary.is_finite (@neg_zero t). -Proof. reflexivity. Qed. +(** ** Signed Zero Values + + Signed zeros are finite, and their + real-valued interpretations and that of [Zconst t 0] under [FT2R] + are all zero. These are used directly in arithmetic simplifications. *) -Fact FT2R_neg_zero : FT2R (@neg_zero t) = 0. +Fact neg_zero_is_finite : + Binary.is_finite (@neg_zero t). Proof. reflexivity. Qed. -Fact FT2R_pos_zero : FT2R (@pos_zero t) = 0. +Fact FT2R_neg_zero : + FT2R (@neg_zero t) = 0. Proof. reflexivity. Qed. -Fact FT2R_Zconst_0 : FT2R (Zconst t 0) = 0. +Fact FT2R_pos_zero : + FT2R (@pos_zero t) = 0. Proof. reflexivity. Qed. -Definition fma_no_overflow (x y z: R) : Prop := - (Rabs (rounded t (x * y + z)) < Raux.bpow Zaux.radix2 (femax t))%R. +Fact FT2R_Zconst_0 : + FT2R (Zconst t 0) = 0. +Proof. reflexivity. Qed. + +Notation fmax := (@fmax t). + +(** ** No-Overflow Predicates + + Each predicate asserts that the RNE-rounded exact result of the + corresponding operation has absolute value strictly less than + [fmax], i.e., the result does not overflow the finite + floating-point range (subnormal results are allowed).
These are the preconditions for the main + accuracy lemmas and are derived automatically in the primed variants. *) + +Definition fma_no_overflow (x y z : R) : Prop := + (Rabs (rounded t (x * y + z)) < fmax)%R. + +Definition Bmult_no_overflow (x y : R) : Prop := + (Rabs (rounded t (x * y)) < fmax)%R. -Definition Bmult_no_overflow (x y: R) : Prop := - (Rabs (rounded t (x * y)) < Raux.bpow Zaux.radix2 (femax t))%R. +Definition Bplus_no_overflow (x y : R) : Prop := + (Rabs (Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (x + y)) < fmax)%R. + +Definition Bminus_no_overflow (x y : R) : Prop := + (Rabs (Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (x - y)) < fmax)%R. Notation D := (@default_rel t). Notation E := (@default_abs t). -Lemma generic_round_property: - forall (x: R), -exists delta epsilon : R, - delta * epsilon = 0 /\ - (Rabs delta <= D)%R /\ - (Rabs epsilon <= E)%R /\ - Generic_fmt.round Zaux.radix2 - (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - x = (x * (1+delta)+epsilon)%Re. +(** ** Generic Rounding Error Decomposition + + The lemma below is the shared foundation for all per-operation accuracy + results. It packages the Flocq [error_N_FLT] theorem into the + %$(\delta, \varepsilon)$%#\((\delta, \varepsilon)\)# form used throughout this library. *) + +(** [generic_round_property] is the fundamental %$(\delta,\varepsilon)$%#\((\delta,\varepsilon)\)# + decomposition of RNE rounding error in the Flocq FLT format. + + The condition %$\delta * \varepsilon = 0$%#\(\delta * \varepsilon = 0\)# + asserts that at least one of the two error terms is zero (both vanish when rounding is exact), reflecting the fact + that the two sources of rounding error are mutually exclusive: + + - In the _normal range_, error is purely relative: %$\varepsilon = 0$%#\(\varepsilon = 0\)#.
+ + - In the _subnormal range_, error + is purely absolute: %$\delta = 0$%#\(\delta = 0\)# + + No rounding event can simultaneously be in both regimes, so the two + error terms never both appear at once. *) + +Lemma generic_round_property : + forall (x : R), + exists delta epsilon : R, + delta * epsilon = 0 /\ + (Rabs delta <= D)%R /\ + (Rabs epsilon <= E)%R /\ + Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + x = (x * (1 + delta) + epsilon)%Re. Proof. -intros. -destruct (Relative.error_N_FLT Zaux.radix2 (SpecFloat.emin (fprec t) (femax t)) (fprec t) - (fprec_gt_0 t) (fun x0 : Z => negb (Z.even x0)) x) - as [delta [epsilon [? [? [? ?]]]]]. -exists delta, epsilon. -repeat split; auto. + intros. + destruct (Relative.error_N_FLT Zaux.radix2 + (SpecFloat.emin (fprec t) (femax t)) (fprec t) + (fprec_gt_0 t) (fun x0 : Z => negb (Z.even x0)) x) + as [delta [epsilon [? [? [? ?]]]]]. + exists delta, epsilon. + repeat split; auto. Qed. -Lemma fma_accurate : - forall (x: ftype t) (FINx: Binary.is_finite x) - (y: ftype t) (FINy: Binary.is_finite y) - (z: ftype t) (FINz: Binary.is_finite z) - (FIN: fma_no_overflow (FT2R x) (FT2R y) (FT2R z)), +(** ** Fused Multiply-Add *) + +(** [fma_accurate] establishes the standard rounding error model for the + fused multiply-add operation. *) + +Lemma fma_accurate : + forall (x : ftype t) (FINx : Binary.is_finite x) + (y : ftype t) (FINy : Binary.is_finite y) + (z : ftype t) (FINz : Binary.is_finite z) + (FIN : fma_no_overflow (FT2R x) (FT2R y) (FT2R z)), exists delta, exists epsilon, - delta * epsilon = 0 /\ - Rabs delta <= D /\ - Rabs epsilon <= E /\ - (FT2R (BFMA x y z) = (FT2R x * FT2R y + FT2R z) * (1+delta) + epsilon)%Re. + delta * epsilon = 0 /\ + Rabs delta <= D /\ + Rabs epsilon <= E /\ + (FT2R (BFMA x y z) = (FT2R x * FT2R y + FT2R z) * (1 + delta) + epsilon)%Re. Proof. -move => x FINx y FINy z FINz FIN. 
-pose proof (Binary.Bfma_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) - BinarySingleNaN.mode_NE x y z FINx FINy FINz). -fold (@FT2R t) in H. -fold (@BFMA NAN t) in H. -cbv zeta in H. -pose proof ( - Raux.Rlt_bool_spec - (Rabs - (Generic_fmt.round Zaux.radix2 - (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode - BinarySingleNaN.mode_NE) (FT2R x * FT2R y + FT2R z))) - (Raux.bpow Zaux.radix2 (femax t))). -destruct H0. -- -destruct H as [? _]. -rewrite H. -apply generic_round_property. -- -red in FIN. unfold rounded in FIN. -Lra.lra. + move => x FINx y FINy z FINz FIN. + pose proof (Binary.Bfma_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y z FINx FINy FINz). + fold (@FT2R t) in H. + fold (@BFMA NAN t) in H. + cbv zeta in H. + pose proof ( + Raux.Rlt_bool_spec + (Rabs + (Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R x * FT2R y + FT2R z))) + fmax). + unfold fmax in *. + destruct H0. + - destruct H as [? _]. + rewrite H. + apply generic_round_property. + - red in FIN. unfold rounded in FIN. + unfold fmax in *. + Lra.lra. Qed. -Lemma is_finite_fma_no_overflow : - forall (x y z: ftype t) - (HFINb : Binary.is_finite (BFMA x y z)), +(** [is_finite_fma_no_overflow] shows that finiteness of the result of an FMA + implies the no-overflow condition on the exact value %$x \cdot y + z$%#\(x \cdot y + z\)#. + This allows the primed form [fma_accurate'] to work directly from + a finiteness hypothesis. *) + +Lemma is_finite_fma_no_overflow : + forall (x y z : ftype t) + (HFINb : Binary.is_finite (BFMA x y z)), fma_no_overflow (FT2R x) (FT2R y) (FT2R z). Proof. -intros. -red. set (ov:= bpow Zaux.radix2 (femax t)). 
-pose proof Rle_or_lt ov (Rabs (rounded t (FT2R x * FT2R y + FT2R z))) as Hor; - destruct Hor; auto. -apply Rlt_bool_false in H. -assert (HFIN: Binary.is_finite x /\ - Binary.is_finite y /\ - Binary.is_finite z) - by (destruct x,y,z; destruct s; destruct s0; destruct s1; - simpl in *; try discriminate; repeat split; auto). -destruct HFIN as (A & B & C). -unfold rounded, ov in H. -pose proof (Binary.Bfma_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) - BinarySingleNaN.mode_NE x y z A B C) as H1. -simpl in H1, H. -rewrite H in H1; clear H. -fold (BFMA x y z) in *. -destruct (BFMA x y z); discriminate. + intros. + red. set (ov := bpow Zaux.radix2 (femax t)). + pose proof Rle_or_lt ov (Rabs (rounded t (FT2R x * FT2R y + FT2R z))) as Hor; + destruct Hor; auto. + apply Rlt_bool_false in H. + assert (HFIN : Binary.is_finite x /\ + Binary.is_finite y /\ + Binary.is_finite z) + by (destruct x, y, z; destruct s; destruct s0; destruct s1; + simpl in *; try discriminate; repeat split; auto). + destruct HFIN as (A & B & C). + unfold rounded, ov in H. + pose proof (Binary.Bfma_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y z A B C) as H1. + simpl in H1, H. + rewrite H in H1; clear H. + fold (BFMA x y z) in *. + destruct (BFMA x y z); discriminate. Qed. +(** [BFMA_finite_e] extracts finiteness of each individual argument from + finiteness of the FMA result. This is a standard regularity + property: the IEEE 754 FMA result is non-finite (an infinity or a NaN) + whenever any input is non-finite; for instance, a zero-times-infinity + product yields a NaN. *) + Lemma BFMA_finite_e : - forall (a f u : ftype t) - (Hfin : Binary.is_finite (BFMA a f u)), - Binary.is_finite a /\ - Binary.is_finite f /\ - Binary.is_finite u.
+ forall (a f u : ftype t) + (Hfin : Binary.is_finite (BFMA a f u)), + Binary.is_finite a /\ + Binary.is_finite f /\ + Binary.is_finite u. +Proof. + intros. + repeat split; + destruct a, f, u; destruct s; destruct s0; destruct s1; + try discriminate; auto. +Qed. + +(** [is_finite_fma_no_overflow']: If all three inputs to an FMA are finite + and no overflow occurs, then the FMA result is finite. *) + +Lemma is_finite_fma_no_overflow' : + forall (x y z : ftype t) + (Hfinx : Binary.is_finite x = true) + (Hfiny : Binary.is_finite y = true) + (Hfinz : Binary.is_finite z = true) + (Hov : fma_no_overflow (FT2R x) (FT2R y) (FT2R z)), + Binary.is_finite (BFMA x y z) = true. Proof. -intros. -repeat split; -destruct a,f,u; destruct s; destruct s0; destruct s1; try discriminate; auto. + intros x y z Hfinx Hfiny Hfinz Hov. + pose proof (Binary.Bfma_correct + (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE + x y z Hfinx Hfiny Hfinz) as H. + cbv zeta in H. + rewrite Rlt_bool_true in H. + - destruct H as [_ [HFIN _]]; exact HFIN. + - move: Hov; by rewrite /fma_no_overflow /rounded. Qed. -Lemma fma_accurate' : - forall (x y z : ftype t) - (FIN: Binary.is_finite (BFMA x y z)), + +(** [fma_accurate'] is the _finiteness-hypothesis_ form of [fma_accurate]: + it requires only that the result is finite, deriving the + no-overflow condition and input finiteness internally. *) + +Lemma fma_accurate' : + forall (x y z : ftype t) + (FIN : Binary.is_finite (BFMA x y z)), exists delta, exists epsilon, - delta * epsilon = 0 /\ - Rabs delta <= @default_rel t /\ - Rabs epsilon <= @default_abs t /\ - (FT2R (BFMA x y z) = (FT2R x * FT2R y + FT2R z) * (1+delta) + epsilon)%Re. + delta * epsilon = 0 /\ + Rabs delta <= @default_rel t /\ + Rabs epsilon <= @default_abs t /\ + (FT2R (BFMA x y z) = (FT2R x * FT2R y + FT2R z) * (1 + delta) + epsilon)%Re. +Proof. + intros. 
+ pose proof (BFMA_finite_e _ _ _ FIN) as H; + destruct H as (A & B & C). + apply fma_accurate => //. + apply is_finite_fma_no_overflow; auto. +Qed. + +(** [BFMA_correct] combines [BFMA_finite_e] and [is_finite_fma_no_overflow] + to give the full correctness statement: finiteness of the result + implies finiteness of all inputs and a rounding identity. *) + +Lemma BFMA_correct (a b s : ftype t) : + Binary.is_finite (BFMA a b s) -> + (Binary.is_finite a /\ Binary.is_finite b /\ Binary.is_finite s) /\ + FT2R (BFMA a b s) = + Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R a * FT2R b + FT2R s). Proof. -intros. -pose proof (BFMA_finite_e _ _ _ FIN) as H; - destruct H as (A & B & C). -apply fma_accurate => //. -apply is_finite_fma_no_overflow; auto. + intros * FIN. + pose proof (is_finite_fma_no_overflow a b s FIN) as H4; + apply Rlt_bool_true in H4; + unfold common.rounded in H4. + assert (H : Binary.is_finite a = true /\ + Binary.is_finite b = true /\ + Binary.is_finite s = true). + { unfold BFMA, BINOP in FIN. + destruct a, b, s; auto; destruct s0, s1, s; discriminate. } + split; auto. + destruct H as [? [? ?]]. + pose proof (Binary.Bfma_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE + a b s H H0 H1) as H3; cbv zeta in H3. + fold (@FT2R t) in H3. + rewrite {}H4 in H3. + fold (BFMA a b s) in H3. + apply H3. Qed. -Lemma BMULT_accurate : - forall (x y : ftype t) (FIN: Bmult_no_overflow (FT2R x) (FT2R y)), +(** ** Floating-Point Multiplication *) + +(** [BMULT_accurate] establishes the rounding error model for multiplication, + which computes %$\mathtt{fl}(x \cdot y)$%#\(\mathtt{fl}(x \cdot y)\)# under RNE rounding. 
*) + +Lemma BMULT_accurate : + forall (x y : ftype t) + (FIN : Bmult_no_overflow (FT2R x) (FT2R y)), exists delta, exists epsilon, - delta * epsilon = 0 /\ - Rabs delta <= @default_rel t /\ - Rabs epsilon <= @default_abs t /\ - (FT2R (BMULT x y) = (FT2R x * FT2R y) * (1+delta) + epsilon)%Re. + delta * epsilon = 0 /\ + Rabs delta <= @default_rel t /\ + Rabs epsilon <= @default_abs t /\ + (FT2R (BMULT x y) = (FT2R x * FT2R y) * (1 + delta) + epsilon)%Re. Proof. -intros. -pose proof (Binary.Bmult_correct (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) - (FPCore.mult_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE x y). -cbv zeta in H. -pose proof ( - Raux.Rlt_bool_spec - (Rabs - (Generic_fmt.round Zaux.radix2 - (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode - BinarySingleNaN.mode_NE) - (Binary.B2R _ _ x * Binary.B2R _ _ y))) - (Raux.bpow Zaux.radix2 (femax t))). -fold (@FT2R t) in H,H0. -destruct H0. -- destruct H as [? _]. - unfold BMULT, BINOP. - rewrite {}H. - apply generic_round_property. -- -red in FIN. unfold rounded in FIN. -lra. + intros. + pose proof (Binary.Bmult_correct (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.mult_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y). + cbv zeta in H. + pose proof ( + Raux.Rlt_bool_spec + (Rabs + (Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (Binary.B2R _ _ x * Binary.B2R _ _ y))) fmax). + fold (@FT2R t) in H, H0. + unfold fmax in *. + destruct H0. + - destruct H as [? _]. + unfold BMULT, BINOP. + rewrite {}H. + apply generic_round_property. + - red in FIN. unfold rounded in FIN. + unfold fmax in *. + lra. Qed. +(** [is_finite_BMULT_no_overflow] shows that finiteness of [BMULT x y] + implies that the exact product %$x \cdot y$%#\(x \cdot y\)# does not overflow. 
*) + Lemma is_finite_BMULT_no_overflow : - forall (x y : ftype t) - (HFINb : Binary.is_finite (BMULT x y) ), + forall (x y : ftype t) + (HFINb : Binary.is_finite (BMULT x y)), Bmult_no_overflow (FT2R x) (FT2R y). Proof. -intros. -pose proof Rle_or_lt (bpow Zaux.radix2 (femax t)) - (Rabs (rounded t (FT2R x * FT2R y))) as Hor; - destruct Hor; auto. -apply Rlt_bool_false in H; red. -unfold rounded in H. -pose proof (Binary.Bmult_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.mult_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE x y) as - H0. -rewrite {}H in H0. -unfold BMULT, BINOP in HFINb. -destruct ((Binary.Bmult _ _ _ _ _ _ x y)); -simpl; try discriminate. + intros. + pose proof Rle_or_lt (bpow Zaux.radix2 (femax t)) + (Rabs (rounded t (FT2R x * FT2R y))) as Hor; + destruct Hor; auto. + apply Rlt_bool_false in H; red. + unfold rounded in H. + pose proof (Binary.Bmult_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.mult_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y) as H0. + rewrite {}H in H0. + unfold BMULT, BINOP in HFINb. + destruct (Binary.Bmult _ _ _ _ _ _ x y); + simpl; try discriminate. Qed. -Lemma BMULT_accurate' : - forall - (x y : ftype t) - (FIN: Binary.is_finite (BMULT x y)), +(** [BMULT_accurate'] is the _finiteness-hypothesis_ form of + [BMULT_accurate]. *) + +Lemma BMULT_accurate' : + forall (x y : ftype t) + (FIN : Binary.is_finite (BMULT x y)), exists delta, exists epsilon, - delta * epsilon = 0 /\ - Rabs delta <= @default_rel t /\ - Rabs epsilon <= @default_abs t /\ - (FT2R (BMULT x y) = (FT2R x * FT2R y) * (1+delta) + epsilon)%Re. + delta * epsilon = 0 /\ + Rabs delta <= @default_rel t /\ + Rabs epsilon <= @default_abs t /\ + (FT2R (BMULT x y) = (FT2R x * FT2R y) * (1 + delta) + epsilon)%Re. Proof. -intros. -pose proof BMULT_accurate x y (is_finite_BMULT_no_overflow x y FIN); auto. + intros. 
+ pose proof BMULT_accurate x y (is_finite_BMULT_no_overflow x y FIN); auto. Qed. +(** [BMULT_finite_e] extracts finiteness of each factor from finiteness + of the product. *) + Lemma BMULT_finite_e : - forall (a b : ftype t) (Hfin : Binary.is_finite (BMULT a b)), - Binary.is_finite a /\ Binary.is_finite b. + forall (a b : ftype t) + (Hfin : Binary.is_finite (BMULT a b)), + Binary.is_finite a /\ Binary.is_finite b. Proof. -unfold BMULT, BINOP; intros. -destruct a,b; inversion Hfin; clear Hfin; subst; auto. + unfold BMULT, BINOP; intros. + destruct a, b; inversion Hfin; clear Hfin; subst; auto. Qed. -Lemma BPLUS_finite_e : - forall (a b : ftype t) (Hfin : Binary.is_finite (BPLUS a b)), - Binary.is_finite a /\ Binary.is_finite b. -Proof. -unfold BPLUS, BINOP; intros. -destruct a,b; inversion Hfin; clear Hfin; subst; simpl; auto. -destruct s,s0; discriminate; auto. -Qed. +(** [BMULT_correct] gives the full correctness statement for multiplication: + finiteness of the result implies finiteness of the operands and a + rounding identity. *) -Lemma BMINUS_finite_sub : - forall (a b : ftype t) (Hfin : Binary.is_finite (BMINUS a b)), - Binary.is_finite a /\ Binary.is_finite b. +Lemma BMULT_correct (a b : ftype t) : + Binary.is_finite (BMULT a b) -> + (Binary.is_finite a /\ Binary.is_finite b) /\ + FT2R (BMULT a b) = + Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R a * FT2R b). Proof. -unfold BMINUS, BINOP; intros. -destruct a,b; inversion Hfin; clear Hfin; subst; simpl; auto. -destruct s,s0; discriminate; auto. + intros * FIN. + pose proof (is_finite_BMULT_no_overflow a b FIN). + apply Rlt_bool_true in H. 
+ assert (Binary.is_finite a = true /\ Binary.is_finite b = true) + by (unfold is_true in *; unfold BMULT, BINOP in FIN; + destruct a, b; + simpl in FIN; split; try discriminate; auto; + match goal with + | H : Binary.is_finite + (Binary.Bplus _ _ _ _ _ _ + (Binary.B754_infinity _ _ ?s) + (Binary.B754_infinity _ _ ?s0)) = _ |- + Binary.is_finite _ = _ => + destruct s; destruct s0; try discriminate; auto + end). + split; auto. + destruct H0. + pose proof (Binary.Bmult_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.mult_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE a b). + rewrite {}H in H2. + apply H2. Qed. +(** ** Floating-Point Addition *) -Definition Bplus_no_overflow (x y: R) : Prop := - (Rabs ( Generic_fmt.round Zaux.radix2 - (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode - BinarySingleNaN.mode_NE) (x + y )) < Raux.bpow Zaux.radix2 (femax t))%R. - +(** [BPLUS_B2R_zero] is a special case: adding the zero constant + to a finite value returns a result with the same + real value. *) -Lemma BPLUS_B2R_zero (a : ftype t): +Lemma BPLUS_B2R_zero (a : ftype t) : Binary.is_finite a -> FT2R (BPLUS a (Zconst t 0)) = FT2R a. Proof. -unfold BPLUS, BINOP, Zconst; intros; -destruct a; -unfold neg_zero; simpl; try discriminate; auto. -destruct s; simpl; auto. + unfold BPLUS, BINOP, Zconst; intros; + destruct a; + unfold neg_zero; simpl; try discriminate; auto. + destruct s; simpl; auto. Qed. +(** [BPLUS_accurate] establishes the standard relative rounding error + model for floating-point addition [BPLUS x y]. *) + Lemma BPLUS_accurate : - forall (x : ftype t) (FINx: Binary.is_finite x) - (y : ftype t) (FINy: Binary.is_finite y) - (FIN: Bplus_no_overflow (FT2R x) (FT2R y)), - exists delta, - Rabs delta <= @default_rel t /\ - (FT2R (BPLUS x y ) = (FT2R x + FT2R y) * (1+delta))%Re. 
+ forall (x : ftype t) (FINx : Binary.is_finite x) + (y : ftype t) (FINy : Binary.is_finite y) + (FIN : Bplus_no_overflow (FT2R x) (FT2R y)), + exists delta, + Rabs delta <= @default_rel t /\ + (FT2R (BPLUS x y) = (FT2R x + FT2R y) * (1 + delta))%Re. Proof. -intros. -pose proof (Binary.Bplus_correct (fprec t) (femax t) (fprec_gt_0 t) - (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) - BinarySingleNaN.mode_NE x y FINx FINy). -cbv zeta in H. -pose proof ( - Raux.Rlt_bool_spec - (Rabs - (Generic_fmt.round Zaux.radix2 - (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode - BinarySingleNaN.mode_NE) - (Binary.B2R _ _ x + Binary.B2R _ _ y))) - (Raux.bpow Zaux.radix2 (femax t))). -destruct H0. -- -destruct H as [? _]. -unfold BPLUS, BINOP. -fold (@FT2R t) in *. -rewrite {}H. -assert (A: Generic_fmt.generic_format Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (FT2R x) ) - by apply Binary.generic_format_B2R. -assert (B: Generic_fmt.generic_format Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (FT2R y) ) - by apply Binary.generic_format_B2R. -assert (H1 := Plus_error.FLT_plus_error_N_ex Zaux.radix2 (SpecFloat.emin (fprec t) (femax t)) - (fprec t) (fun x0 : Z => negb (Z.even x0)) (FT2R x) (FT2R y) A B). -unfold Relative.u_ro in H1. fold (@default_rel t) in H1. -destruct H1 as (d & Hd & Hd'). -assert (H1: Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - (FT2R x + FT2R y) = Generic_fmt.round Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (Generic_fmt.Znearest (fun x0 : Z => negb (Z.even x0))) - (FT2R x + FT2R y)) by auto. -rewrite <- H1 in Hd'. clear H1. -rewrite {}Hd'. -exists d; split; auto. -eapply Rle_trans; [apply Hd |]. -apply Rdiv_le_left. -apply Fourier_util.Rlt_zero_pos_plus1. -apply default_rel_gt_0. -eapply Rle_trans with (@default_rel t * 1); try nra. 
-- -red in FIN. -fold (@FT2R t) in *. -lra. + intros. + pose proof (Binary.Bplus_correct (fprec t) (femax t) (fprec_gt_0 t) + (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y FINx FINy). + cbv zeta in H. + pose proof ( + Raux.Rlt_bool_spec + (Rabs + (Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (Binary.B2R _ _ x + Binary.B2R _ _ y))) + (Raux.bpow Zaux.radix2 (femax t))). + destruct H0. + - destruct H as [? _]. + unfold BPLUS, BINOP. + fold (@FT2R t) in *. + rewrite {}H. + assert (A : Generic_fmt.generic_format Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (FT2R x)) + by apply Binary.generic_format_B2R. + assert (B : Generic_fmt.generic_format Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (FT2R y)) + by apply Binary.generic_format_B2R. + assert (H1 := Plus_error.FLT_plus_error_N_ex Zaux.radix2 + (SpecFloat.emin (fprec t) (femax t)) + (fprec t) (fun x0 : Z => negb (Z.even x0)) + (FT2R x) (FT2R y) A B). + unfold Relative.u_ro in H1. fold (@default_rel t) in H1. + destruct H1 as (d & Hd & Hd'). + assert (H1 : Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R x + FT2R y) = + Generic_fmt.round Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (Generic_fmt.Znearest (fun x0 : Z => negb (Z.even x0))) + (FT2R x + FT2R y)) by auto. + rewrite <- H1 in Hd'. clear H1. + rewrite {}Hd'. + exists d; split; auto. + eapply Rle_trans; [apply Hd |]. + apply Rdiv_le_left. + apply Fourier_util.Rlt_zero_pos_plus1. + apply default_rel_gt_0. + eapply Rle_trans with (@default_rel t * 1); try nra. + - red in FIN. + fold (@FT2R t) in *. + unfold fmax in *. + lra. Qed. 
+(** [is_finite_sum_no_overflow] shows that finiteness of the result + implies the no-overflow condition on the exact sum %$x + y$%#\(x + y\)#. *) + Lemma is_finite_sum_no_overflow : - forall (x y: ftype t) - (HFINb : Binary.is_finite (BPLUS x y)), + forall (x y : ftype t) + (HFINb : Binary.is_finite (BPLUS x y)), Bplus_no_overflow (FT2R x) (FT2R y). Proof. -intros. -pose proof Rle_or_lt (bpow Zaux.radix2 (femax t)) (Rabs (rounded t (FT2R x + FT2R y))) as Hor; - destruct Hor; auto. -apply Rlt_bool_false in H. -assert (HFIN: Binary.is_finite x = true /\ Binary.is_finite y = true). -{ unfold BPLUS, BINOP in HFINb. - destruct x,y; - simpl in *; split; try discriminate; auto; - destruct s; destruct s0; simpl in *; try discriminate; auto. } -unfold rounded in H. -destruct HFIN as (A & B). -pose proof (Binary.Bplus_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - x y A B) as - H0. -rewrite H in H0; -destruct H0 as ( C & _). -unfold BPLUS, BINOP in HFINb. -destruct ((Binary.Bplus (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) - (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE x y)); -simpl; try discriminate. + intros. + pose proof Rle_or_lt (bpow Zaux.radix2 (femax t)) + (Rabs (rounded t (FT2R x + FT2R y))) as Hor; + destruct Hor; auto. + apply Rlt_bool_false in H. + assert (HFIN : Binary.is_finite x = true /\ Binary.is_finite y = true). + { unfold BPLUS, BINOP in HFINb. + destruct x, y; + simpl in *; split; try discriminate; auto; + destruct s; destruct s0; simpl in *; try discriminate; auto. } + unfold rounded in H. + destruct HFIN as (A & B). + pose proof (Binary.Bplus_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y A B) as H0. + rewrite H in H0; + destruct H0 as (C & _). + unfold BPLUS, BINOP in HFINb. 
+ destruct (Binary.Bplus (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y); + simpl; try discriminate. Qed. +(** [no_overflow_sum_is_finite] is the converse direction: if both + summands are finite and the exact sum does not overflow, then + the result is finite. *) + Lemma no_overflow_sum_is_finite : - forall (x y: ftype t) - (H1 : Binary.is_finite x) - (H2 : Binary.is_finite y) - (Hov : Bplus_no_overflow (FT2R x) (FT2R y)), + forall (x y : ftype t) + (H1 : Binary.is_finite x) + (H2 : Binary.is_finite y) + (Hov : Bplus_no_overflow (FT2R x) (FT2R y)), Binary.is_finite (BPLUS x y). Proof. -unfold Bplus_no_overflow, BPLUS, BINOP; intros. -pose proof (Binary.Bplus_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - x y H1 H2) as - H0. -remember (Rlt_bool _ _ ) as HB; destruct HB. -destruct H0 as (_ & HP &_); auto. -exfalso. -fold (@FT2R t) in *. -unfold Rlt_bool in HeqHB. -remember (Rcompare _ _) as HR; destruct HR; try discriminate. -symmetry in HeqHR. -apply Rcompare_Eq_inv in HeqHR. -nra. -symmetry in HeqHR. -apply Rcompare_Gt_inv in HeqHR. -nra. + unfold Bplus_no_overflow, BPLUS, BINOP; intros. + pose proof (Binary.Bplus_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y H1 H2) as H0. + remember (Rlt_bool _ _) as HB; destruct HB. + - destruct H0 as (_ & HP & _); auto. + - exfalso. + fold (@FT2R t) in *. + unfold Rlt_bool in HeqHB. + remember (Rcompare _ _) as HR; destruct HR; try discriminate. + + symmetry in HeqHR. + apply Rcompare_Eq_inv in HeqHR. + unfold fmax in *. + nra. + + symmetry in HeqHR. + apply Rcompare_Gt_inv in HeqHR. + unfold fmax in *. + nra. Qed. +(** [BPLUS_accurate'] is the _finiteness-hypothesis_ form of + [BPLUS_accurate]. 
*) + Lemma BPLUS_accurate' : - forall (x y : ftype t) - (FIN: Binary.is_finite (BPLUS x y)), - exists delta, - Rabs delta <= @default_rel t /\ - (FT2R (BPLUS x y ) = (FT2R x + FT2R y) * (1+delta))%Re. + forall (x y : ftype t) + (FIN : Binary.is_finite (BPLUS x y)), + exists delta, + Rabs delta <= @default_rel t /\ + (FT2R (BPLUS x y) = (FT2R x + FT2R y) * (1 + delta))%Re. Proof. -unfold BPLUS, BINOP. -intros. -eapply BPLUS_accurate. -1,2: destruct x,y; simpl; try discriminate; auto; - destruct s; destruct s0; simpl; try discriminate; auto. -apply is_finite_sum_no_overflow; auto. + unfold BPLUS, BINOP. + intros. + eapply BPLUS_accurate. + 1, 2: destruct x, y; simpl; try discriminate; auto; + destruct s; destruct s0; simpl; try discriminate; auto. + apply is_finite_sum_no_overflow; auto. Qed. -Definition Bminus_no_overflow (x y: R) : Prop := - (Rabs ( Generic_fmt.round Zaux.radix2 - (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode - BinarySingleNaN.mode_NE) (x - y )) < Raux.bpow Zaux.radix2 (femax t))%R. - +(** [BPLUS_finite_e] extracts finiteness of each summand from finiteness + of the sum. *) -Lemma is_finite_minus_no_overflow : - forall (x y: ftype t) - (HFINb : Binary.is_finite (BMINUS x y)), - Bminus_no_overflow (FT2R x) (FT2R y). +Lemma BPLUS_finite_e : + forall (a b : ftype t) + (Hfin : Binary.is_finite (BPLUS a b)), + Binary.is_finite a /\ Binary.is_finite b. Proof. -intros. -pose proof Rle_or_lt (bpow Zaux.radix2 (femax t)) (Rabs (rounded t (FT2R x - FT2R y))) as Hor; - destruct Hor; auto. -apply Rlt_bool_false in H. -assert (HFIN: Binary.is_finite x = true /\ Binary.is_finite y = true). -{ unfold BMINUS, BINOP in HFINb. - destruct x,y; - simpl in *; split; try discriminate; auto; - destruct s; destruct s0; simpl in *; try discriminate; auto. } -destruct HFIN as (A & B). -unfold rounded in H. 
-pose proof (Binary.Bminus_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - x y A B) as - H0. -rewrite H in H0; -destruct H0 as ( C & _). -unfold BMINUS, BINOP in HFINb. -destruct ((Binary.Bminus (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) - (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - x y)); -simpl; try discriminate. + unfold BPLUS, BINOP; intros. + destruct a, b; inversion Hfin; clear Hfin; subst; simpl; auto. + destruct s, s0; discriminate; auto. Qed. -Lemma no_overflow_minus_is_finite : - forall (x y: ftype t) - (H1 : Binary.is_finite x) - (H2 : Binary.is_finite y) - (Hov : Bminus_no_overflow (FT2R x) (FT2R y)), - Binary.is_finite (BMINUS x y). +(** [BPLUS_correct] gives the full correctness statement for + floating-point addition: + finiteness of the result implies finiteness of each summand and + relates the floating-point result to the rounded real result. *) + +Lemma BPLUS_correct (a b : ftype t) : + Binary.is_finite (BPLUS a b) -> + (Binary.is_finite a /\ Binary.is_finite b) /\ + FT2R (BPLUS a b) = + Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R a + FT2R b). Proof. -unfold Bminus_no_overflow, BMINUS, BINOP; intros. -pose proof (Binary.Bminus_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - x y H1 H2) as - H0. -remember (Rlt_bool _ _ ) as HB; destruct HB. -destruct H0 as (_ & HP &_); auto. -exfalso. -unfold Rlt_bool in HeqHB. -fold (@FT2R t) in *. -remember (Rcompare _ _) as HR; destruct HR; try discriminate. -symmetry in HeqHR. -apply Rcompare_Eq_inv in HeqHR. -nra. -symmetry in HeqHR. -apply Rcompare_Gt_inv in HeqHR. -nra. + intros * FIN. + pose proof (is_finite_sum_no_overflow a b FIN). + apply Rlt_bool_true in H. 
+ assert (Binary.is_finite a /\ Binary.is_finite b) + by (unfold is_true in *; unfold BPLUS, BINOP in FIN; + destruct a, b; + simpl in FIN; split; try discriminate; auto; + match goal with + | H : Binary.is_finite + (Binary.Bplus _ _ _ _ _ _ + (Binary.B754_infinity _ _ ?s) + (Binary.B754_infinity _ _ ?s0)) = _ |- + Binary.is_finite _ = _ => + destruct s; destruct s0; try discriminate; auto + end). + split; auto. + destruct H0. + pose proof (Binary.Bplus_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE a b H0 H1) as H3. + rewrite {}H in H3. + apply H3. Qed. +(** ** Floating-Point Subtraction *) + +(** [BMINUS_accurate] establishes the pure relative rounding error model + for floating-point subtraction. *) + Lemma BMINUS_accurate : - forall (x : ftype t) (FINx: Binary.is_finite x) - (y : ftype t) (FINy: Binary.is_finite y) - (FIN: Bminus_no_overflow (FT2R x) (FT2R y)), - exists delta, - Rabs delta <= @default_rel t /\ - (FT2R (BMINUS x y ) = (FT2R x - FT2R y) * (1+delta))%Re. + forall (x : ftype t) (FINx : Binary.is_finite x) + (y : ftype t) (FINy : Binary.is_finite y) + (FIN : Bminus_no_overflow (FT2R x) (FT2R y)), + exists delta, + Rabs delta <= @default_rel t /\ + (FT2R (BMINUS x y) = (FT2R x - FT2R y) * (1 + delta))%Re. Proof. -intros. -pose proof (Binary.Bminus_correct (fprec t) (femax t) (fprec_gt_0 t) - (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) - BinarySingleNaN.mode_NE x y FINx FINy). -cbv zeta in H. -pose proof ( - Raux.Rlt_bool_spec - (Rabs - (Generic_fmt.round Zaux.radix2 + intros. + pose proof (Binary.Bminus_correct (fprec t) (femax t) (fprec_gt_0 t) + (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y FINx FINy). + cbv zeta in H. 
+ pose proof ( + Raux.Rlt_bool_spec + (Rabs + (Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (Binary.B2R _ _ x - Binary.B2R _ _ y))) + (Raux.bpow Zaux.radix2 (femax t))). + fold (@FT2R t) in *. + destruct H0. + - destruct H as [? _]. + unfold BMINUS, BINOP. + rewrite H. + assert (A : Generic_fmt.generic_format Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (FT2R x)) + by apply Binary.generic_format_B2R. + assert (B : Generic_fmt.generic_format Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (- FT2R y)). + { apply Generic_fmt.generic_format_opp. + apply Binary.generic_format_B2R. } + pose proof Plus_error.FLT_plus_error_N_ex Zaux.radix2 + (SpecFloat.emin (fprec t) (femax t)) + (fprec t) (fun x0 : Z => negb (Z.even x0)) + (FT2R x) (- FT2R y) A B. + unfold Relative.u_ro in H1. fold (@default_rel t) in H1. + destruct H1 as (d & Hd & Hd'). + assert (Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode - BinarySingleNaN.mode_NE) - (Binary.B2R _ _ x - Binary.B2R _ _ y))) - (Raux.bpow Zaux.radix2 (femax t))). -fold (@FT2R t) in *. -destruct H0. -- -destruct H as [? _]. -unfold BMINUS, BINOP. -rewrite H. -assert (A: Generic_fmt.generic_format Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (FT2R x) ). -apply Binary.generic_format_B2R. -assert (B: Generic_fmt.generic_format Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (-FT2R y) ). -apply Generic_fmt.generic_format_opp. -apply Binary.generic_format_B2R. -pose proof Plus_error.FLT_plus_error_N_ex Zaux.radix2 (SpecFloat.emin (fprec t) (femax t)) - (fprec t) (fun x0 : Z => negb (Z.even x0)) (FT2R x) (-FT2R y) A B. -unfold Relative.u_ro in H1. fold (@default_rel t) in H1. -destruct H1 as (d & Hd & Hd'). 
-assert ( Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - (FT2R x - FT2R y) = Generic_fmt.round Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (Generic_fmt.Znearest (fun x0 : Z => negb (Z.even x0))) - (FT2R x - FT2R y)) by auto. -replace (_ +- _) with ( FT2R x - FT2R y) in Hd' by nra. -rewrite <- H1 in Hd'. clear H1. -rewrite {}Hd'. -exists d; split; auto. -eapply Rle_trans; [apply Hd |]. -apply Rdiv_le_left. -apply Fourier_util.Rlt_zero_pos_plus1. -apply default_rel_gt_0. -eapply Rle_trans with (@default_rel t * 1); try nra. -- -red in FIN. lra. + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R x - FT2R y) = + Generic_fmt.round Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (Generic_fmt.Znearest (fun x0 : Z => negb (Z.even x0))) + (FT2R x - FT2R y)) by auto. + replace (_ + - _) with (FT2R x - FT2R y) in Hd' by nra. + rewrite <- H1 in Hd'. clear H1. + rewrite {}Hd'. + exists d; split; auto. + eapply Rle_trans; [apply Hd |]. + apply Rdiv_le_left. + apply Fourier_util.Rlt_zero_pos_plus1. + apply default_rel_gt_0. + eapply Rle_trans with (@default_rel t * 1); try nra. + - red in FIN. + unfold fmax in *. + lra. Qed. -Lemma BMINUS_accurate' : - forall (x y : ftype t) - (FIN: Binary.is_finite (BMINUS x y)), - exists delta, - Rabs delta <= @default_rel t /\ - (FT2R (BMINUS x y ) = (FT2R x - FT2R y) * (1+delta))%Re. -Proof. -intros. -eapply BMINUS_accurate. -1,2: unfold BMINUS, BINOP in FIN; -destruct x,y; simpl; try discriminate; auto; - destruct s; destruct s0; simpl; try discriminate; auto. -apply is_finite_minus_no_overflow; auto. -Qed. +(** [is_finite_minus_no_overflow] shows that finiteness of the floating-point + result implies that the rounded difference does not overflow.
*) -Lemma BPLUS_correct (a b: ftype t): - Binary.is_finite (BPLUS a b)-> - (Binary.is_finite a /\ Binary.is_finite b) /\ - FT2R (BPLUS a b) = - Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - (FT2R a + FT2R b). +Lemma is_finite_minus_no_overflow : + forall (x y : ftype t) + (HFINb : Binary.is_finite (BMINUS x y)), + Bminus_no_overflow (FT2R x) (FT2R y). Proof. -intros * FIN. -pose proof (is_finite_sum_no_overflow a b FIN). -apply Rlt_bool_true in H. -assert (Binary.is_finite a /\ Binary.is_finite b) - by (unfold is_true in *; unfold BPLUS, BINOP in FIN; - destruct a,b; - simpl in FIN; split; try discriminate; auto ; - match goal with | H: Binary.is_finite - (Binary.Bplus _ _ _ _ _ _ (Binary.B754_infinity _ _ ?s) - (Binary.B754_infinity _ _ ?s0)) = _ |- Binary.is_finite _ = _ => - destruct s; destruct s0; try discriminate; auto end). -split; auto. -destruct H0. - pose proof (Binary.Bplus_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - a b H0 H1) as H3. -rewrite {}H in H3. -apply H3. + intros. + pose proof Rle_or_lt (bpow Zaux.radix2 (femax t)) + (Rabs (rounded t (FT2R x - FT2R y))) as Hor; + destruct Hor; auto. + apply Rlt_bool_false in H. + assert (HFIN : Binary.is_finite x = true /\ Binary.is_finite y = true). + { unfold BMINUS, BINOP in HFINb. + destruct x, y; + simpl in *; split; try discriminate; auto; + destruct s; destruct s0; simpl in *; try discriminate; auto. } + destruct HFIN as (A & B). + unfold rounded in H. + pose proof (Binary.Bminus_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y A B) as H0. + rewrite H in H0; + destruct H0 as (C & _). + unfold BMINUS, BINOP in HFINb. 
+ destruct (Binary.Bminus (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y); + simpl; try discriminate. Qed. -Lemma BMULT_correct (a b: ftype t): - Binary.is_finite (BMULT a b) -> - (Binary.is_finite a /\ - Binary.is_finite b) /\ - FT2R (BMULT a b) = - Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - (FT2R a * FT2R b). +(** [no_overflow_minus_is_finite] is the converse direction for + subtraction: if both operands are finite and the rounded difference + does not overflow, then the result is finite. *) + +Lemma no_overflow_minus_is_finite : + forall (x y : ftype t) + (H1 : Binary.is_finite x) + (H2 : Binary.is_finite y) + (Hov : Bminus_no_overflow (FT2R x) (FT2R y)), + Binary.is_finite (BMINUS x y). Proof. -intros * FIN. -pose proof (is_finite_BMULT_no_overflow a b FIN). -apply Rlt_bool_true in H. -assert (Binary.is_finite a = true /\ Binary.is_finite b = true) - by (unfold is_true in *; unfold BMULT, BINOP in FIN; - destruct a,b; - simpl in FIN; split; try discriminate; auto ; - match goal with | H: Binary.is_finite - (Binary.Bplus _ _ _ _ _ _ (Binary.B754_infinity _ _ ?s) - (Binary.B754_infinity _ _ ?s0)) = _ |- Binary.is_finite _ = _ => - destruct s; destruct s0; try discriminate; auto end). -split; auto. -destruct H0. -pose proof (Binary.Bmult_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.mult_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - a b). -rewrite {}H in H2. -apply H2. + unfold Bminus_no_overflow, BMINUS, BINOP; intros. + pose proof (Binary.Bminus_correct (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y H1 H2) as H0. + remember (Rlt_bool _ _) as HB; destruct HB. + - destruct H0 as (_ & HP & _); auto. + - exfalso. + unfold Rlt_bool in HeqHB.
+ fold (@FT2R t) in *. + remember (Rcompare _ _) as HR; destruct HR; try discriminate. + + symmetry in HeqHR. + apply Rcompare_Eq_inv in HeqHR. + unfold fmax in *. + nra. + + symmetry in HeqHR. + apply Rcompare_Gt_inv in HeqHR. + unfold fmax in *. + nra. Qed. +(** [BMINUS_accurate'] is the _finiteness-hypothesis_ form of + [BMINUS_accurate]. *) -Lemma BFMA_correct (a b s: ftype t) : - Binary.is_finite (BFMA a b s) -> - (Binary.is_finite a /\ Binary.is_finite b /\ Binary.is_finite s) /\ - FT2R (BFMA a b s) = - Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - (FT2R a * FT2R b + FT2R s). +Lemma BMINUS_accurate' : + forall (x y : ftype t) + (FIN : Binary.is_finite (BMINUS x y)), + exists delta, + Rabs delta <= @default_rel t /\ + (FT2R (BMINUS x y) = (FT2R x - FT2R y) * (1 + delta))%Re. Proof. -intros * FIN. -pose proof (is_finite_fma_no_overflow a b s FIN) as H4; apply Rlt_bool_true in H4; - unfold common.rounded in H4. -assert (H : Binary.is_finite a = true /\ Binary.is_finite b = true /\ Binary.is_finite s = true). - unfold BFMA, BINOP in FIN. - destruct a,b,s; auto; destruct s0,s1,s; discriminate. -split; auto. -destruct H as [? [? ?]]. -pose proof (Binary.Bfma_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) BinarySingleNaN.mode_NE - a b s H H0 H1) as H3; cbv zeta in H3. -fold (@FT2R t) in H3. -rewrite {}H4 in H3. -fold (BFMA a b s) in H3. -apply H3. + intros. + eapply BMINUS_accurate. + 1, 2: unfold BMINUS, BINOP in FIN; + destruct x, y; simpl; try discriminate; auto; + destruct s; destruct s0; simpl; try discriminate; auto. + apply is_finite_minus_no_overflow; auto. Qed. -End GenFloat. +(** [BMINUS_finite_sub] extracts finiteness of each operand from + finiteness of the difference. *) +Lemma BMINUS_finite_sub : + forall (a b : ftype t) + (Hfin : Binary.is_finite (BMINUS a b)), + Binary.is_finite a /\ Binary.is_finite b. 
+Proof. + unfold BMINUS, BINOP; intros. + destruct a, b; inversion Hfin; clear Hfin; subst; simpl; auto. + destruct s, s0; discriminate; auto. +Qed. +End GenFloat. \ No newline at end of file diff --git a/accuracy_proofs/fma_dot_acc.v b/accuracy_proofs/fma_dot_acc.v index c911a94..8d0ac17 100644 --- a/accuracy_proofs/fma_dot_acc.v +++ b/accuracy_proofs/fma_dot_acc.v @@ -1,130 +1,191 @@ -(** This file contains three main theorems for the accuracy of the fma - dot product : fma_dotprod_mixed_error, fma_dotprod_forward_error, - and fma_sparse_dotprod_forward_error. *) +(** * FMA Dot Product Accuracy Theorems -From LAProof.accuracy_proofs Require Import preamble common + This file establishes three main accuracy theorems for the fused + multiply-add (FMA) dot product computation [fma_dotprod]. + + The significance of these results is that they provide rigorous + floating-point error bounds for dot products computed using FMA + instructions, which are critical in numerical linear algebra. + + ** Error Factors + + Throughout, the accumulated relative error factor is + %$g(n) = (1 + \mathbf{u})^n - 1$%#\(g(n) = (1 + \mathbf{u})^n - 1\)# and + the mixed absolute error factor is + %$g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))$%#\(g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))\)#, + where %$\mathbf{u}$%#\(\mathbf{u}\)# is the unit roundoff and + %$\eta$%#\(\eta\)# is the underflow threshold for the given type. + Both are defined in [common]. + + ** Main Results + + - [fma_dotprod_mixed_error]: Shows that the FMA-computed dot product + can be expressed as an exact dot product of slightly perturbed inputs + plus a small absolute error term. + + - [fma_dotprod_forward_error]: Bounds the absolute forward error + %$|\mathtt{fl}_{\mathrm{fma}}(v_1 \cdot v_2) - v_1 \cdot v_2|$%#\(|\mathtt{fl}_{\mathrm{fma}}(v_1 \cdot v_2) - v_1 \cdot v_2|\)#. 
+ + - [fma_sparse_dotprod_forward_error]: Refines [fma_dotprod_forward_error] + for sparse inputs by replacing the full vector length %$n$%#\(n\)# with + the number of nonzero entries, + giving a tighter bound when the input vectors are sparse. + + ** Dependencies + + This file relies on: + - [preamble], [common]: basic setup and shared definitions + - [dotprod_model], [sum_model]: relational models of FMA dot product + and summation + - [float_acc_lems]: elementary floating-point accuracy lemmas + - [dot_acc_lemmas]: dot-product-specific accuracy lemmas +*) + +From LAProof.accuracy_proofs Require Import preamble common dotprod_model sum_model float_acc_lems dot_acc_lemmas. -Section MixedError. -Context {NAN: FPCore.Nans} {t : type}. +(** * Mixed Error *) +Section MixedError. +Context {NAN : FPCore.Nans} {t : type}. -Notation g := (@g t). -Notation g1 := (@g1 t). -Notation D := (@default_rel t). -Notation E := (@default_abs t). +Notation g := (@g t). +Notation g1 := (@g1 t). +Notation D := (@default_rel t). +Notation E := (@default_abs t). Notation neg_zero := (@common.neg_zero t). -Variables (v1 v2: list (ftype t)). -Hypothesis Hlen: size v1 = size v2. -Hypothesis Hfin: Binary.is_finite(fma_dotprod v1 v2) = true. - -Lemma fma_dotprod_mixed_error: +Variables (v1 v2 : list (ftype t)). +Hypothesis Hlen : size v1 = size v2. +Hypothesis Hfin : Binary.is_finite (fma_dotprod v1 v2) = true. + +(** [fma_dotprod_mixed_error] expresses the FMA-computed dot product as an + exact inner product of component-wise perturbed inputs plus a small + absolute offset. The relative perturbation on each input component is + bounded by %$g(n)$%#\(g(n)\)# and the absolute residual by + %$g_1(n,\,n-1)$%#\(g_1(n,n-1)\)#. 
*) + +Lemma fma_dotprod_mixed_error : exists (u : list R) (eta : R), size u = size v1 /\ FT2R (fma_dotprod v1 v2) = dotprodR u (map FT2R v2) + eta /\ - (forall n, (n < size v2)%nat -> exists delta, - nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ Rabs delta <= g (size v2)) /\ - Rabs eta <= g1 (size v2) (size v2 -1). + (forall n, (n < size v2)%nat -> + exists delta, + nth 0 u n = FT2R (nth neg_zero v1 n) * (1 + delta) /\ + Rabs delta <= g (size v2)) /\ + Rabs eta <= g1 (size v2) (size v2 - 1). Proof. -intros. -assert (size (zip v1 v2) = size v1) by (rewrite size_zip; lia). -assert (Hlenr : size (rev v1) = size (rev v2)) by (rewrite !size_rev; auto). -rewrite <- size_rev in Hlen. -pose proof fma_dot_prod_rel_fold_right v1 v2 as H1. -rewrite rev_zip in H1. 2: revert Hlen; rewrite size_rev; auto. -revert Hlen. -rewrite size_rev; intro. -pose proof (fma_dotprod_mixed_error_rel (rev v1) (rev v2) Hlenr (fma_dotprod v1 v2) H1 Hfin) as - (u & eta & H2 & H3 & H4 & H5). -exists (rev u), eta; repeat split; auto. -- -rewrite !size_rev in H2|-*; auto. -- -pose proof dotprodR_rel u (map FT2R (rev v2)). -assert (dotprodR (rev u) (map FT2R v2) = FT2R (fma_dotprod v1 v2) - eta). -eapply R_dot_prod_rel_eq; eauto. -rewrite <- dotprodR_rev, <- map_rev. auto. -rewrite size_rev in H2; rewrite size_map; auto; lia. -nra. -- -rewrite !size_rev in H4. -intros. -rewrite size_rev in H2. -assert ((size u - S n < size v2)%nat) by lia. -specialize (H4 (size u - S n)%nat H6). -rewrite nth_rev in H4. 2: rewrite Hlen //. -rewrite nth_rev. 2: rewrite H2 Hlen //. -destruct H4 as (delta & Hn & HD). -exists delta; split. -rewrite Hn; repeat f_equal. -rewrite Hlen. -rewrite H2. -rewrite <- Nat.sub_succ_l. -simpl. lia. -rewrite Hlen. lia. -apply HD. -- rewrite !size_rev in H5; auto. + assert (Hzip : size (zip v1 v2) = size v1) by (rewrite size_zip; lia). + assert (Hlenr : size (rev v1) = size (rev v2)) by (rewrite !size_rev; auto). + rewrite <- size_rev in Hlen. 
+ pose proof fma_dot_prod_rel_fold_right v1 v2 as Hrel. + rewrite rev_zip in Hrel. 2: revert Hlen; rewrite size_rev; auto. + revert Hlen; rewrite size_rev; intro Hlen. + pose proof (fma_dotprod_mixed_error_rel + (rev v1) (rev v2) Hlenr + (fma_dotprod v1 v2) Hrel Hfin) + as (u & eta & Hsize & Heq & Hbnd & Heta). + exists (rev u), eta. + repeat split. + - (* size rev u = size v1 *) + rewrite !size_rev in Hsize |-*; auto. + - (* FT2R (fma_dotprod v1 v2) = dotprodR (rev u) (map FT2R v2) + eta *) + pose proof dotprodR_rel u (map FT2R (rev v2)) as Hdot. + assert (Heqr : dotprodR (rev u) (map FT2R v2) = + FT2R (fma_dotprod v1 v2) - eta). + { eapply R_dot_prod_rel_eq; eauto. + rewrite <- dotprodR_rev, <- map_rev; auto. + rewrite size_rev in Hsize; rewrite size_map; auto; lia. } + nra. + - (* per-index bound on entries of rev u *) + rewrite !size_rev in Hbnd. + intros n Hn. + rewrite size_rev in Hsize. + assert (Hlt : (size u - S n < size v2)%nat) by lia. + specialize (Hbnd (size u - S n)%nat Hlt). + rewrite nth_rev in Hbnd. 2: rewrite Hlen //. + rewrite nth_rev. 2: rewrite Hsize Hlen //. + destruct Hbnd as (delta & Hnth & Hdelta). + exists delta; split. + + rewrite Hnth; repeat f_equal. + rewrite Hlen Hsize. rewrite <- Nat.sub_succ_l; simpl; lia. + + exact Hdelta. + - (* |eta| <= g1(|v2|, |v2| - 1) *) + rewrite !size_rev in Heta; auto. Qed. End MixedError. -Section ForwardError. -Context {NAN: FPCore.Nans} {t : type}. +(** * Forward Error *) +Section ForwardError. +Context {NAN : FPCore.Nans} {t : type}. Variables v1 v2 : list (ftype t). -Notation v1R := (map FT2R v1). -Notation v2R := (map FT2R v2). -Notation v1R' := (map Rabs v1R). -Notation v2R' := (map Rabs v2R). -Notation n := (size v2). - -Notation g := (@g t). -Notation g1 := (@g1 t). + +Notation v1R := (map FT2R v1). +Notation v2R := (map FT2R v2). +Notation v1R' := (map Rabs v1R). +Notation v2R' := (map Rabs v2R). +Notation n := (size v2). +Notation g := (@g t). +Notation g1 := (@g1 t). 
Notation neg_zero := (@common.neg_zero t). -Hypothesis Hlen: size v1 = size v2. -Hypothesis Hfin: Binary.is_finite(fma_dotprod v1 v2) = true. +Hypothesis Hlen : size v1 = size v2. +Hypothesis Hfin : Binary.is_finite (fma_dotprod v1 v2) = true. -Lemma fma_dotprod_forward_error: - Rabs (FT2R (fma_dotprod v1 v2) - dotprodR v1R v2R ) - <= g n * dotprodR v1R' v2R' + g1 n (n - 1). +(** [fma_dotprod_forward_error] bounds the absolute forward error of the + FMA-computed dot product. *) + +Lemma fma_dotprod_forward_error : + Rabs (FT2R (fma_dotprod v1 v2) - dotprodR v1R v2R) + <= g n * dotprodR v1R' v2R' + g1 n (n - 1). Proof. -intros. -pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. -pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. - simpl in HB, HC. rewrite <- map_rev in HC, HB. + pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. + pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. + simpl in HB, HC. + rewrite <- map_rev in HC, HB. rewrite <- map_rev in HC. -pose proof fma_dotprod_forward_error_rel (rev (zip v1 v2)) - (fma_dotprod v1 v2) (fma_dot_prod_rel_fold_right _ _ ) Hfin - (dotprodR v1R v2R) (dotprodR v1R' v2R') HB HC. -rewrite size_rev size_zip Hlen minnn in H. -auto. + pose proof fma_dotprod_forward_error_rel + (rev (zip v1 v2)) + (fma_dotprod v1 v2) + (fma_dot_prod_rel_fold_right _ _) + Hfin + (dotprodR v1R v2R) + (dotprodR v1R' v2R') + HB HC as H. + rewrite size_rev size_zip Hlen minnn in H. + exact H. Qed. Notation nnzR := (common.nnzR v1R). -Lemma fma_sparse_dotprod_forward_error: - Rabs (FT2R (fma_dotprod v1 v2) - dotprodR v1R v2R ) <= - g nnzR * dotprodR v1R' v2R' + g1 nnzR (nnzR - 1). -Proof. -intros. -pose proof fma_dot_prod_rel_fold_right v1 v2 as HA. -pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. -pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. - simpl in HB, HC. rewrite <- map_rev in HC, HB. 
+(** [fma_sparse_dotprod_forward_error] refines [fma_dotprod_forward_error] + for sparse inputs by replacing the full vector length %$n$%#\(n\)# with + the number of nonzero entries. *) + +Lemma fma_sparse_dotprod_forward_error : + Rabs (FT2R (fma_dotprod v1 v2) - dotprodR v1R v2R) + <= g nnzR * dotprodR v1R' v2R' + g1 nnzR (nnzR - 1). +Proof. + pose proof fma_dot_prod_rel_fold_right v1 v2 as HA. + pose proof R_dot_prod_rel_fold_right' t v1 v2 Hlen as HB. + pose proof R_dot_prod_rel_fold_right_Rabs' t v1 v2 Hlen as HC. + simpl in HB, HC. + rewrite <- map_rev in HC, HB. rewrite <- map_rev in HC. -pose proof sparse_fma_dotprod_forward_error (rev v1) (rev v2). - rewrite !size_rev -rev_zip in H; auto. -specialize (H Hlen (fma_dotprod v1 v2) HA Hfin (dotprodR v1R v2R) - (dotprodR v1R' v2R') HB HC). -rewrite map_rev in H. -unfold common.nnzR, nnzR in H. -rewrite !count_rev in H. -auto. + pose proof sparse_fma_dotprod_forward_error (rev v1) (rev v2) as H. + rewrite !size_rev -rev_zip in H; auto. + specialize (H Hlen + (fma_dotprod v1 v2) HA Hfin + (dotprodR v1R v2R) + (dotprodR v1R' v2R') + HB HC). + rewrite map_rev in H. + unfold common.nnzR, nnzR in H. + rewrite !count_rev in H. + exact H. Qed. -End ForwardError. - - +End ForwardError. \ No newline at end of file diff --git a/accuracy_proofs/fma_is_finite.v b/accuracy_proofs/fma_is_finite.v index a51f664..38917c6 100644 --- a/accuracy_proofs/fma_is_finite.v +++ b/accuracy_proofs/fma_is_finite.v @@ -1,412 +1,381 @@ -From LAProof.accuracy_proofs Require Import preamble common dotprod_model sum_model - float_acc_lems dot_acc_lemmas sum_acc. -Import Zorder. +(** * Finite Sum from Bounded Inputs -Section NAN. -Context {NAN: FPCore.Nans} {t : type}. + This file establishes key lemmas for proving finiteness of floating-point + dot products computed via fused multiply-add (FMA) operations. 
The central + result, [finite_sum_from_bounded], shows that if all input vector elements + are bounded in absolute value, then the accumulated + floating-point dot product remains finite. -Definition fmax := bpow Zaux.radix2 (femax t). + The proof strategy relies on: + - A bound function [fun_bnd t n] that is non-increasing in n, + ensuring that element-wise bounds suffice to prevent overflow at each + FMA step. + - Forward error analysis for FMA-based dot products, connecting floating-point + accumulation to real-valued dot products. -Lemma is_finite_fma_no_overflow' : - forall (x y z: ftype t) - (Hfinx:Binary.is_finite x = true) - (Hfiny:Binary.is_finite y = true) - (Hfinz:Binary.is_finite z = true) - (Hov : @fma_no_overflow t (FT2R x) (FT2R y) (FT2R z)), -Binary.is_finite (BFMA x y z) = true. -Proof. -intros. -pose proof (Binary.Bfma_correct (fprec t) (femax t) - (fprec_gt_0 t) (fprec_lt_femax t) - (FPCore.fma_nan (fprec t) (femax t) (fprec_gt_one t)) - BinarySingleNaN.mode_NE - x y z Hfinx Hfiny Hfinz). -cbv zeta in H. -rewrite Rlt_bool_true in H. -move :H => [] _ [] HFIN _. -auto. -move: Hov. by rewrite /fma_no_overflow /rounded. -Qed. + ** Key Definitions -Definition fun_bnd (t : type) (n : nat) := -let x := (fmax - @default_abs t) / (1 + @default_rel t) - @g1 t n (n-1) in -let y := 1 / (1 + INR n * (@g t (n - 1) + 1)) in x * y. - -Lemma rdiv_lt (a b: R) : - 0 < b -> 0 < a -> b < a -> / a < / b. -Proof. -intros. -replace (/b) with (1/b) by nra. -apply Rdiv_lt_right; auto. -rewrite Rmult_comm. -apply Rdiv_lt_left; auto. -nra. -Qed. + - [fun_bnd t n] : A per-element magnitude bound such that dot products of + vectors of length n whose entries satisfy this bound do not overflow. -Lemma rdiv_le (a b: R) : - 0 < b -> 0 < a -> b <= a -> / a <= / b. -Proof. -intros. -replace (/b) with (1/b) by nra. -apply Rcomplements.Rle_div_r; auto. -rewrite Rmult_comm. -apply Rdiv_le_left; auto. -nra. -Qed.
+ ** Key Lemmas -Lemma rdiv_mult_eq : -forall a b, b <> 0 -> a/b = a * (1/b) . -Proof. -(intros; field_simplify; nra). -Qed. + - [fun_bnd_le] : The bound [fun_bnd t n] is non-increasing in n. + - [fun_bound_pos] : The bound [fun_bnd t n] is non-negative. + - [finite_sum_from_bounded] : Element-wise bounds on inputs + imply finiteness of the entire FMA dot product accumulation. +*) -Lemma Rminus_le_minus a b c : - b <= c -> a - c <= a - b. -Proof. nra. Qed. +From LAProof.accuracy_proofs Require Import + preamble + real_lemmas + common + dotprod_model + sum_model + float_acc_lems + dot_acc_lemmas + sum_acc. -Lemma Rminus_lt_minus a b c : - b < c -> a - c < a - b. -Proof. nra. Qed. +Import Zorder. -Lemma defualt_abs_le_fmax : -@default_abs t <= fmax. -Proof. -replace (fmax) with (1 * fmax) by nra. -unfold default_abs, fmax; apply Rmult_le_compat; try nra. -apply bpow_ge_0. -apply bpow_le. -apply Z.le_sub_le_add_r. -apply Z.le_sub_le_add_r. -eapply Z.le_trans with (fprec t + fprec t + femax t)%Z; - [ | repeat apply Zplus_le_compat_r; apply Z.lt_le_incl; apply fprec_lt_femax]. -eapply Z.le_trans with (fprec t + fprec t + fprec t)%Z; -[ | repeat apply Zplus_le_compat_l;apply Z.lt_le_incl; apply fprec_lt_femax ]. -eapply Z.le_trans with (1 + fprec t + fprec t)%Z; -[ | repeat apply Zplus_le_compat_r;apply Z.lt_le_incl;apply fprec_gt_one]. -eapply Z.le_trans with (1 + 1 + fprec t)%Z; -[ | repeat apply Zplus_le_compat_r; repeat apply Zplus_le_compat_l; apply Z.lt_le_incl; -apply fprec_gt_one]. -eapply Z.le_trans with (1 + 1 + 1)%Z; -[ lia | repeat apply Zplus_le_compat_r; repeat apply Zplus_le_compat_l; apply Z.lt_le_incl; -apply fprec_gt_one]. -Qed. +Section NAN. -Lemma bpow_femax_lb : -4 <= bpow Zaux.radix2 (femax t). -Proof. -pose proof fprec_gt_one t. -pose proof fprec_lt_femax t. -assert (1 < femax t)%Z. -eapply Z.lt_trans with (fprec t); auto. -eapply Rle_trans with (bpow Zaux.radix2 2). -unfold bpow; simpl; nra. -apply bpow_le; lia. -Qed. 
+Context {NAN : FPCore.Nans} {t : type}. -Lemma bpow_fprec_lb : -2 <= bpow Zaux.radix2 (fprec t). -Proof. -pose proof fprec_gt_one t. -eapply Rle_trans with (bpow Zaux.radix2 1). -unfold bpow; simpl; nra. -apply bpow_le; lia. -Qed. +Notation fmax := (@fmax t). -Lemma default_abs_ub : -@default_abs t <= 1. -Proof. -pose proof (@abs_le_rel t). -eapply Rle_trans. apply H. -rewrite /default_rel bpow_plus bpow_opp. -replace (bpow _ 1) with 2. -refine (Rle_trans _ (1/bpow Zaux.radix2 (fprec t)) _ _ _); - [try nra | apply Rdiv_le_left ]. -apply bpow_gt_0. -refine (Rle_trans _ 2 _ _ _); try nra. -rewrite Rmult_1_l. apply bpow_fprec_lb. -simpl; nra. -Qed. +(** ** Bound Function *) -Lemma default_rel_ub : -@default_rel t <= 1. -Proof. -unfold default_rel. -pose proof bpow_gt_0 Zaux.radix2 (fprec t). -rewrite !bpow_plus. -rewrite <- !Rmult_assoc. -rewrite Rmult_comm. -rewrite <- !Rmult_assoc. -replace (bpow Zaux.radix2 1 * / 2) with 1; [|simpl;nra]. -rewrite !bpow_opp !Rcomplements.Rle_div_r. -field_simplify; try nra. -replace 1 with (bpow Zaux.radix2 0). -apply bpow_le. -pose proof fprec_gt_one t; lia. -simpl; auto. -apply Rlt_gt; -replace (/ bpow Zaux.radix2 (fprec t)) with (1 / bpow Zaux.radix2 (fprec t)) by nra; -apply Rdiv_lt_0_compat; try nra. -Qed. +(** [fun_bnd t n] is used to construct a bound that will not cause + overflow during FMA-based dot product of vectors of length n. *) + +Definition fun_bnd (t : type) (n : nat) := + let x := (fmax - @default_abs t) / (1 + @default_rel t) - @g1 t n (n - 1) in + let y := 1 / (1 + INR n * (@g t (n - 1) + 1)) in + x * y. + + +(** ** Positivity of [fun_bnd] Components *) +(** The numerator of [fun_bnd] is non-negative when the g1 bound holds. *) -Lemma fun_bnd_pos_1 : -forall n -(Hn : @g1 t (n + 1) n <= fmax ), -0 <= (fmax - @default_abs t) / (1 + @default_rel t) - @g1 t n (n-1). +Lemma fun_bnd_pos_1 : + forall (n : nat) + (Hn : @g1 t (n + 1) n <= fmax), + 0 <= (fmax - @default_abs t) / (1 + @default_rel t) - @g1 t n (n - 1). 
Proof. -intros; -apply Rle_0_minus. apply Generic_proof.Rdiv_le_mult_pos; -[apply default_rel_plus_1_gt_0 | apply Rminus_plus_le_minus]. -assert (Hn': (n=0)%nat \/ (1<=n)%nat) by lia; destruct Hn'; subst. -{ simpl. unfold g1, g. simpl; field_simplify. apply defualt_abs_le_fmax. } -assert (Hn': (n = 1)%nat \/ (1 < n)%nat) by lia; destruct Hn'; subst. -{ simpl. unfold g1, g. simpl; field_simplify. -eapply Rle_trans. -apply Rplus_le_compat. -apply Rmult_le_compat. -apply default_abs_ge_0. -apply default_rel_ge_0. -apply default_abs_ub. -apply default_rel_ub. -apply Rmult_le_compat_l; try nra. -apply default_abs_ub. -eapply Rle_trans; [| apply bpow_femax_lb]; nra. } -eapply Rle_trans. apply mult_d_e_g1_le'; try lia. -replace (S n) with (n + 1)%nat by lia. -replace (S (n - 1)) with n by lia; auto. + intros n Hn. + apply Rle_0_minus. + apply Generic_proof.Rdiv_le_mult_pos; + [ apply default_rel_plus_1_gt_0 + | apply Rminus_plus_le_minus ]. + assert (Hn0 : (n = 0)%nat \/ (1 <= n)%nat) by lia. + destruct Hn0 as [Hn0 | Hn0]; subst. + { (* n = 0 *) + simpl; unfold g1, g; simpl; field_simplify. + apply default_abs_le_fmax. } + assert (Hn1 : (n = 1)%nat \/ (1 < n)%nat) by lia. + destruct Hn1 as [Hn1 | Hn1]; subst. + { (* n = 1 *) + simpl; unfold g1, g; simpl; field_simplify. + eapply Rle_trans. + - apply Rplus_le_compat. + + apply Rmult_le_compat; + [ apply default_abs_ge_0 + | apply default_rel_ge_0 + | apply default_abs_ub + | apply default_rel_ub ]. + + apply Rmult_le_compat_l; try nra. + apply default_abs_ub. + - eapply Rle_trans; [| apply bpow_fmax_lb_4]; nra. } + (* n > 1 *) + eapply Rle_trans. + - apply mult_d_e_g1_le'; try lia. + - replace (S n) with (n + 1)%nat by lia. + replace (S (n - 1)) with n by lia. + auto. Qed. +(** ** Monotonicity of [fun_bnd] *) -Lemma fun_bnd_le (n : nat) : -forall (Hn : @g1 t (S n + 1) (S n) <= fmax), -fun_bnd t (S n) <= fun_bnd t n. +(** [fun_bnd t n] is non-increasing. 
*) + +Lemma fun_bnd_le (n : nat) : + forall (Hn : @g1 t (S n + 1) (S n) <= fmax), + fun_bnd t (S n) <= fun_bnd t n. Proof. -assert (Hn': (n=0)%nat \/ (1<=n)%nat) by lia; destruct Hn'; subst. -{ intros; simpl. unfold fun_bnd. apply Rmult_le_compat; try apply Rabs_pos. -apply fun_bnd_pos_1; auto. simpl. unfold g. simpl; field_simplify; nra. -apply Rminus_le_minus. simpl. unfold g1, g; field_simplify; simpl. -field_simplify. apply default_abs_ge_0. -simpl; unfold g; field_simplify; simpl; try nra. } -intros; unfold fun_bnd. -assert (0 < 1 + INR (S n) * (@g t (S n - 1) + 1)). -{ -apply Rplus_lt_le_0_compat; try nra. -apply Rmult_le_pos; try apply pos_INR. -apply Rplus_le_le_0_compat; try nra; apply g_pos. } -assert ( -INR n * @g t (n - 1) + INR n + 1 > 0). -{ -apply Rplus_lt_le_0_compat; try nra. -apply Rplus_le_lt_0_compat; [| apply lt_0_INR; lia]. -apply Rmult_le_pos; try apply pos_INR. -apply g_pos. } -apply Rmult_le_compat; try apply Rabs_pos. -apply fun_bnd_pos_1; auto. -apply Rdiv_le_0_compat_Raux; try nra. -apply Rminus_le_minus. -replace (S n -1)%nat with (S (n-1))%nat by lia. -apply g1n_le_g1Sn; auto. -apply Rcomplements.Rle_div_r. -apply Rlt_gt. -replace (/ (1 + INR (S n) * (@g t (S n - 1) + 1))) with - (1/(1 + INR (S n) * (@g t (S n - 1) + 1))) by nra. -apply Rdiv_lt_0_compat; try nra. -field_simplify; try nra. -apply Rcomplements.Rle_div_r; try nra. -rewrite Rmult_1_l. -apply Rplus_le_compat; try nra. -apply Rplus_le_compat. -apply Rmult_le_compat; [ apply pos_INR | apply g_pos | | ]. -apply le_INR; lia. -replace (S n - 1)%nat with (S (n-1))%nat by lia. -apply le_g_Sn. -apply le_INR; lia. + assert (Hn0 : (n = 0)%nat \/ (1 <= n)%nat) by lia. + destruct Hn0 as [Hn0 | Hn0]; subst. + { (* n = 0 *) + intros Hn; simpl; unfold fun_bnd. + apply Rmult_le_compat; try apply Rabs_pos. + - apply fun_bnd_pos_1; auto. + - simpl; unfold g; simpl; field_simplify; nra. 
+ - apply Rminus_le_minus; simpl; unfold g1, g; + field_simplify; simpl; field_simplify; + apply default_abs_ge_0. + - simpl; unfold g; field_simplify; simpl; nra. } + (* n >= 1 *) + intros Hn; unfold fun_bnd. + assert (Hpos_Sn : 0 < 1 + INR (S n) * (@g t (S n - 1) + 1)). + { apply Rplus_lt_le_0_compat; try nra. + apply Rmult_le_pos; + [ apply pos_INR + | apply Rplus_le_le_0_compat; try nra; apply g_pos ]. } + assert (Hpos_n : INR n * @g t (n - 1) + INR n + 1 > 0). + { apply Rplus_lt_le_0_compat; try nra. + apply Rplus_le_lt_0_compat; + [ apply Rmult_le_pos; [apply pos_INR | apply g_pos] + | apply lt_0_INR; lia ]. } + apply Rmult_le_compat; try apply Rabs_pos. + - apply fun_bnd_pos_1; auto. + - apply Rdiv_le_0_compat_Raux; try nra. + - apply Rminus_le_minus. + replace (S n - 1)%nat with (S (n - 1))%nat by lia. + apply g1n_le_g1Sn; auto. + - apply Rcomplements.Rle_div_r. + apply Rlt_gt. + replace (/ (1 + INR (S n) * (@g t (S n - 1) + 1))) + with (1 / (1 + INR (S n) * (@g t (S n - 1) + 1))) by nra. + apply Rdiv_lt_0_compat; try nra. + field_simplify; try nra. + apply Rcomplements.Rle_div_r; try nra. + rewrite Rmult_1_l. + apply Rplus_le_compat; try nra. + apply Rplus_le_compat. + + apply Rmult_le_compat; + [ apply pos_INR + | apply g_pos + | apply le_INR; lia + | replace (S n - 1)%nat with (S (n - 1))%nat by lia; + apply le_g_Sn ]. + + apply le_INR; lia. Qed. +(** ** List Splitting Lemmas *) + +(** The two halves of l have equal length. *) Lemma length_split {A : Type} (l : list (A * A)) : -length (fst (List.split l)) = length (snd (List.split l)). -Proof. -induction l; [simpl; auto | ]. -destruct a; simpl; destruct (List.split l); simpl. -simpl in IHl; lia. + length (fst (List.split l)) = length (snd (List.split l)). +Proof. + induction l as [| a l IHl]; [simpl; auto |]. + destruct a; simpl. + destruct (List.split l); simpl in *; lia. Qed. -Lemma fun_bound_pos n : -forall (Hn : @g1 t (n + 1) n <= fmax), -0 <= fun_bnd t n. +(** [fun_bnd t n] is non-negative. 
*) + +Lemma fun_bound_pos (n : nat) : + forall (Hn : @g1 t (n + 1) n <= fmax), + 0 <= fun_bnd t n. Proof. -intros; -unfold fun_bnd; apply Rmult_le_pos. -apply fun_bnd_pos_1; auto. -apply Rdiv_le_0_compat_Raux; try nra. -apply Rplus_lt_le_0_compat; try nra. -apply Rmult_le_pos; try apply pos_INR. -apply Rplus_le_le_0_compat; try nra; apply g_pos. + intros Hn; unfold fun_bnd. + apply Rmult_le_pos. + - apply fun_bnd_pos_1; auto. + - apply Rdiv_le_0_compat_Raux; try nra. + apply Rplus_lt_le_0_compat; try nra. + apply Rmult_le_pos; + [ apply pos_INR + | apply Rplus_le_le_0_compat; try nra; apply g_pos ]. Qed. +(** Splitting and recombining a list of pairs is the identity. *) + Lemma combine_split {A : Type} (l : list (A * A)) : -combine (fst (List.split l)) (snd (List.split l)) = l. + combine (fst (List.split l)) (snd (List.split l)) = l. Proof. -induction l; [simpl; auto | ]. -destruct a; simpl; destruct (List.split l); simpl. -simpl in IHl; rewrite IHl; auto. + induction l as [| a l IHl]; [simpl; auto |]. + destruct a; simpl. + destruct (List.split l); simpl in *. + rewrite IHl; auto. Qed. +(** ** Main Result: Finiteness from Bounded Inputs *) -Lemma finite_sum_from_bounded : - forall (v1 v2: list (ftype t)) - (fp : ftype t) - (Hfp: fma_dot_prod_rel (List.combine v1 v2) fp) - (Hn : @g1 t (S (length (List.combine v1 v2)) + 1) (S (length (List.combine v1 v2))) <= fmax ), - (forall x, In x (List.combine v1 v2) -> - Binary.is_finite (fst x) = true /\ - Binary.is_finite (snd x) = true /\ - Rabs (FT2R (fst x)) < sqrt (fun_bnd t (length (List.combine v1 v2))) /\ - Rabs (FT2R (snd x)) < sqrt (fun_bnd t (length (List.combine v1 v2))))-> - Binary.is_finite fp = true. +(** If every pair [(x, y)] in the combined input list satisfies: + - both x and y are finite, + - |FT2R x| and |FT2R y| are strictly bounded, + + and the g1 overflow condition holds for the list length, then the + FMA dot product accumulation fp is finite. 
+ + The proof proceeds by induction on the input list, applying the single-step + [is_finite_fma_no_overflow'] lemma and the forward error bound + [fma_dotprod_forward_error_rel] at each step. *) + +Lemma finite_sum_from_bounded : + forall (v1 v2 : list (ftype t)) + (fp : ftype t) + (Hfp : fma_dot_prod_rel (List.combine v1 v2) fp) + (Hn : @g1 t (S (length (List.combine v1 v2)) + 1) + (S (length (List.combine v1 v2))) <= fmax), + (forall x, In x (List.combine v1 v2) -> + Binary.is_finite (fst x) = true /\ + Binary.is_finite (snd x) = true /\ + Rabs (FT2R (fst x)) < sqrt (fun_bnd t (length (List.combine v1 v2))) /\ + Rabs (FT2R (snd x)) < sqrt (fun_bnd t (length (List.combine v1 v2)))) -> + Binary.is_finite fp = true. Proof. -intros ? ? . -induction (List.combine v1 v2). -{ intros; inversion Hfp; subst; auto. } -{ intros. inversion Hfp; subst. -assert (Hn' : @g1 t (S (length l) + 1) (S (length l)) <= fmax). -{ eapply Rle_trans; [ | apply Hn]; simpl. set (n:= (length l + 1)%nat). - replace (length l) with (n-1)%nat by lia. - replace ((n - 1).+1 + 1)%nat with (n.+1) by lia. - replace ((n - 1).+2 + 1)%nat with (n.+1.+1) by lia. - replace ((n-1).+1)%nat with (n.+1-1)%nat by lia. - apply g1n_le_g1Sn; lia. } -assert (Hin: forall x : (ftype t * ftype t), - In x l -> Binary.is_finite (fst x) = true /\ - Binary.is_finite (snd x) = true /\ - Rabs (FT2R (fst x)) < sqrt (fun_bnd t (length l)) /\ - Rabs (FT2R (snd x)) < sqrt (fun_bnd t (length l))). - { intros. repeat split; [apply H; simpl; auto | apply H; simpl; auto | | ]. - eapply Rlt_le_trans; [apply H; simpl; auto | apply sqrt_le_1_alt; apply fun_bnd_le; auto ]. - eapply Rlt_le_trans; [apply H; simpl; auto | apply sqrt_le_1_alt; apply fun_bnd_le; auto ]. } -assert (Hfina: Binary.is_finite (fst a) = true /\ - Binary.is_finite (snd a) = true) by - (split; apply H; simpl; auto); destruct Hfina as (Hfina1 & Hfina2). -specialize (IHl s H3 Hn' Hin). -apply is_finite_fma_no_overflow'; auto. -unfold fma_no_overflow, rounded. 
-destruct (@generic_round_property t (FT2R (fst a) * FT2R (snd a) + FT2R s)) as - (del & eps & Hed & Hd & He & Hrn ); -rewrite Hrn; clear Hrn. -destruct (dotprod_rel_R_exists_fma l s H3) as (rs & Hrs). -destruct (sum_rel_R_abs_exists_fma l s H3) as (rs_abs & Habs). -pose proof fma_dotprod_forward_error_rel l - s H3 IHl rs rs_abs Hrs Habs as Hacc. -apply Rabs_le_minus in Hacc. -set (n:=(length l)) in *. -assert (Hacc' : Rabs (FT2R s) <= (@g t n + 1) * rs_abs + @g1 t n (n - 1)). -{ eapply Rle_trans. -apply Hacc. replace (g n * rs_abs + g1 n (n - 1) + Rabs rs) -with ((@g t n * rs_abs + Rabs rs) + @g1 t n (n - 1)) by nra. -apply Rplus_le_compat_r. -field_simplify. -apply Rplus_le_compat_l. -rewrite <- (@R_dot_prod_rel_Rabs_eq (map FR2 l)); try nra; auto. -apply (@dot_prod_sum_rel_R_Rabs (map FR2 l)); auto. } clear Hacc. -pose proof dotprodR_rel_bound' as C. -pose proof dotprodR_rel_bound'' as D. -eapply Rle_lt_trans; [apply Rabs_triang |]. -rewrite Rabs_mult. -eapply Rle_lt_trans; [apply Rplus_le_compat |]. -apply Rmult_le_compat; try apply Rabs_pos. -eapply Rle_trans; [apply Rabs_triang |]. -apply Rplus_le_compat. -rewrite Rabs_mult. -apply Rmult_le_compat; try apply Rabs_pos. -apply Rlt_le; apply H; simpl; auto. -apply Rlt_le; apply H; simpl; auto. -eapply Rle_trans. -apply Hacc'. -apply Rplus_le_compat_r. -apply Rmult_le_compat_l. -apply Rplus_le_le_0_compat; try nra. apply g_pos. -apply D. -apply fun_bound_pos. -apply Hn'. -apply Habs. -intros; split; apply Rlt_le; apply H; simpl; auto. -assert (HD: Rabs (1 + del) <= (1 + @default_rel t )). -{ eapply Rle_trans; [apply Rabs_triang| rewrite Rabs_R1; apply Rplus_le_compat_l]. -eapply Rle_trans; [apply Hd |]; nra. } -apply HD. -apply He. -rewrite sqrt_def. -{ -(*algebra*) -unfold fun_bnd. -replace (length (a :: l)) with (S n) by (simpl; lia). -set (x:= (@g t ((S n) - 1) + 1)). -set (y:= (1 + INR (S n) * x)). -set (z:= @g1 t (S n) ((S n) - 1)). -set (u := ((fmax - @default_abs t) / (1 + @default_rel t) - z) * (1 / y)). 
-rewrite <- !Rplus_assoc. -replace (( u + (@g t n + 1) * (INR (length l) * u))) - with ( u * (1 + (@g t n + 1) * (INR (length l)))) - by nra. -apply Rcomplements.Rlt_minus_r. -apply Rcomplements.Rlt_div_r. -apply Rlt_gt; apply default_rel_plus_1_gt_0. -apply Rcomplements.Rlt_minus_r. -assert (0 < 1 + (@g t n + 1) * INR (length l)). -{ apply Rplus_lt_le_0_compat; try nra. -apply Rmult_le_pos; try apply pos_INR. -apply Rplus_le_le_0_compat; try nra; apply g_pos. } -apply Rcomplements.Rlt_div_r; auto. -assert (0 < 1 / (1 + INR (S (length l)) * (@g t (S (length l) - 1) + 1))). -{ apply Rcomplements.Rdiv_lt_0_compat; try nra. -apply Rplus_lt_le_0_compat; try nra. -apply Rmult_le_pos; try apply pos_INR. -apply Rplus_le_le_0_compat; try nra; apply g_pos. } -assert (0 < 1 + INR (S n) * (@g t (S n - 1) + 1)). -{ -apply Rplus_lt_le_0_compat; try nra. -apply Rmult_le_pos; try apply pos_INR. -apply Rplus_le_le_0_compat; try nra; apply g_pos. } -rewrite rdiv_mult_eq; try nra. -unfold u, z, y, x. -apply Rmult_lt_compat. -apply fun_bnd_pos_1; auto. -apply Rlt_le; auto. -unfold fmax. -apply Rminus_lt_minus. -replace n with (length l). -assert (Hl: l = [] \/ l <> []). -destruct l; auto. -right. -eapply hd_error_some_nil; simpl; auto. -destruct Hl. subst. -simpl. unfold g1, g; field_simplify; simpl. field_simplify; apply default_abs_gt_0. -apply size_not_empty_nat in H4. -change @length with @size in *. -replace (S (size l) - 1)%nat with (S (size l - 1))%nat by lia. -apply g1n_lt_g1Sn; auto. lia. -subst n; auto. -apply Rcomplements.Rlt_div_r. -apply Rlt_gt. -replace (/ (1 + INR (S n) * (@g t (S n - 1) + 1))) with - (1/(1 + INR (S n) * (@g t (S n - 1) + 1))) by nra. -apply Rdiv_lt_0_compat; try nra. -field_simplify; try nra. -apply Rcomplements.Rlt_div_r; try nra. -rewrite Rmult_1_l. -apply Rplus_lt_le_compat; try nra. -apply Rplus_le_lt_compat. -rewrite Rmult_comm. -apply Rmult_le_compat; [ apply pos_INR | apply g_pos | | ]. -apply le_INR; lia. 
-replace (S n - 1)%nat with (n)%nat by lia; try nra. -unfold n. -apply lt_INR; lia. -} -apply fun_bound_pos; auto. -} + intros v1 v2. + induction (List.combine v1 v2) as [| a l IHl]. + { (* base case: empty list *) + intros fp Hfp Hn Hbnd. + inversion Hfp; subst; auto. } + { (* inductive case: a :: l *) + intros fp Hfp Hn Hbnd. + inversion Hfp; subst. + (* Establish the g1 bound for the shorter list *) + assert (Hn' : @g1 t (S (length l) + 1) (S (length l)) <= fmax). + { eapply Rle_trans; [| apply Hn]; simpl. + set (n := (length l + 1)%nat). + replace (length l) with (n - 1)%nat by lia. + replace ((n - 1).+1 + 1)%nat with (n.+1) by lia. + replace ((n - 1).+2 + 1)%nat with (n.+1.+1) by lia. + replace ((n - 1).+1)%nat with (n.+1 - 1)%nat by lia. + apply g1n_le_g1Sn; lia. } + (* Propagate element-wise bounds to the tail *) + assert (Hin : forall x : ftype t * ftype t, + In x l -> + Binary.is_finite (fst x) = true /\ + Binary.is_finite (snd x) = true /\ + Rabs (FT2R (fst x)) < sqrt (fun_bnd t (length l)) /\ + Rabs (FT2R (snd x)) < sqrt (fun_bnd t (length l))). + { intros x Hx; repeat split; + [ apply Hbnd; simpl; auto + | apply Hbnd; simpl; auto + | eapply Rlt_le_trans; + [ apply Hbnd; simpl; auto + | apply sqrt_le_1_alt; apply fun_bnd_le; auto ] + | eapply Rlt_le_trans; + [ apply Hbnd; simpl; auto + | apply sqrt_le_1_alt; apply fun_bnd_le; auto ] ]. } + (* Finiteness of the head elements *) + assert (Hfina : Binary.is_finite (fst a) = true /\ + Binary.is_finite (snd a) = true). + { split; apply Hbnd; simpl; auto. } + destruct Hfina as [Hfina1 Hfina2]. + (* Apply the inductive hypothesis to obtain finiteness of the tail accumulator *) + specialize (IHl s H2 Hn' Hin). + (* Reduce to showing no overflow for the outermost FMA *) + apply is_finite_fma_no_overflow'; auto. + unfold fma_no_overflow, rounded. + destruct (@generic_round_property t (FT2R (fst a) * FT2R (snd a) + FT2R s)) + as (del & eps & Hed & Hd & He & Hrn). + rewrite Hrn; clear Hrn. 
+ (* Obtain real-valued dot product witnesses *) + destruct (dotprod_rel_R_exists_fma l s H2) as (rs & Hrs). + destruct (sum_rel_R_abs_exists_fma l s H2) as (rs_abs & Habs). + (* Forward error bound for the tail accumulator *) + pose proof fma_dotprod_forward_error_rel l s H2 IHl rs rs_abs Hrs Habs as Hacc. + apply Rabs_le_minus in Hacc. + set (n := length l) in *. + (* Bound the absolute value of the partial sum *) + assert (Hacc' : Rabs (FT2R s) <= + (@g t n + 1) * rs_abs + @g1 t n (n - 1)). + { eapply Rle_trans; [apply Hacc |]. + replace (g n * rs_abs + g1 n (n - 1) + Rabs rs) + with ((@g t n * rs_abs + Rabs rs) + @g1 t n (n - 1)) by nra. + apply Rplus_le_compat_r. + field_simplify. + apply Rplus_le_compat_l. + rewrite <- (@R_dot_prod_rel_Rabs_eq (map FR2 l)); try nra; auto. + apply (@dot_prod_sum_rel_R_Rabs (map FR2 l)); auto. } + clear Hacc. + pose proof dotprodR_rel_bound' as C. + pose proof dotprodR_rel_bound'' as D. + (* Upper bound on the FMA output *) + eapply Rle_lt_trans; [apply Rabs_triang |]. + rewrite Rabs_mult. + eapply Rle_lt_trans; [apply Rplus_le_compat |]. + { apply Rmult_le_compat; try apply Rabs_pos. + - eapply Rle_trans; [apply Rabs_triang |]. + apply Rplus_le_compat. + + rewrite Rabs_mult. + apply Rmult_le_compat; try apply Rabs_pos; + apply Rlt_le; apply Hbnd; simpl; auto. + + eapply Rle_trans; [apply Hacc' |]. + apply Rplus_le_compat_r. + apply Rmult_le_compat_l; + [ apply Rplus_le_le_0_compat; try nra; apply g_pos | apply D ]. + * apply fun_bound_pos; apply Hn'. + * apply Habs. + * intros; split; apply Rlt_le; apply Hbnd; simpl; auto. + - assert (HD : Rabs (1 + del) <= 1 + @default_rel t). + { eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_R1; apply Rplus_le_compat_l. + eapply Rle_trans; [apply Hd |]; nra. } + apply HD. + } + apply He. + (* Final algebraic inequality using fun_bnd structure *) + rewrite sqrt_def. + { unfold fun_bnd. + replace (length (a :: l)) with (S n) by (simpl; lia). + set (x := (@g t (S n - 1) + 1)). 
+ set (y := (1 + INR (S n) * x)). + set (z := @g1 t (S n) (S n - 1)). + set (u := ((fmax - @default_abs t) / (1 + @default_rel t) - z) * (1 / y)). + rewrite <- !Rplus_assoc. + replace (u + (@g t n + 1) * (INR (length l) * u)) + with (u * (1 + (@g t n + 1) * INR (length l))) by nra. + apply Rcomplements.Rlt_minus_r. + apply Rcomplements.Rlt_div_r; + [apply Rlt_gt; apply default_rel_plus_1_gt_0 |]. + apply Rcomplements.Rlt_minus_r. + assert (Hpos_n : 0 < 1 + (@g t n + 1) * INR (length l)). + { apply Rplus_lt_le_0_compat; try nra. + apply Rmult_le_pos; + [ apply Rplus_le_le_0_compat; try nra; apply g_pos + | apply pos_INR ]. } + apply Rcomplements.Rlt_div_r; auto. + assert (Hpos_y : 0 < 1 / (1 + INR (S (length l)) * + (@g t (S (length l) - 1) + 1))). + { apply Rcomplements.Rdiv_lt_0_compat; try nra. + apply Rplus_lt_le_0_compat; try nra. + apply Rmult_le_pos; + [ apply pos_INR + | apply Rplus_le_le_0_compat; try nra; apply g_pos ]. } + assert (Hpos_Sn : 0 < 1 + INR (S n) * (@g t (S n - 1) + 1)). + { apply Rplus_lt_le_0_compat; try nra. + apply Rmult_le_pos; + [ apply pos_INR + | apply Rplus_le_le_0_compat; try nra; apply g_pos ]. } + rewrite rdiv_mult_eq; try nra. + unfold u, z, y, x. + apply Rmult_lt_compat; + [apply fun_bnd_pos_1; auto | apply Rlt_le; auto | |]. + { unfold fmax. + apply Rminus_lt_minus. + replace n with (length l) by (subst n; auto). + assert (Hl : l = [] \/ l <> []). + { destruct l; [left; auto | right]. + eapply hd_error_some_nil; simpl; auto. } + destruct Hl as [Hl | Hl]. + - subst; simpl; unfold g1, g; field_simplify; simpl; + field_simplify; apply default_abs_gt_0. + - apply size_not_empty_nat in Hl. + change @length with @size in *. + replace (S (size l) - 1)%nat with (S (size l - 1))%nat by lia. + apply g1n_lt_g1Sn; auto; lia. } + { apply Rcomplements.Rlt_div_r. + - apply Rlt_gt. + replace (/ (1 + INR (S n) * (@g t (S n - 1) + 1))) + with (1 / (1 + INR (S n) * (@g t (S n - 1) + 1))) by nra. + apply Rdiv_lt_0_compat; try nra. 
+ - field_simplify; try nra. + apply Rcomplements.Rlt_div_r; try nra. + rewrite Rmult_1_l. + apply Rplus_lt_le_compat; try nra. + apply Rplus_le_lt_compat. + + rewrite Rmult_comm. + apply Rmult_le_compat; + [ apply pos_INR + | apply g_pos + | apply le_INR; lia + | replace (S n - 1)%nat with n%nat by lia; nra ]. + + unfold n; apply lt_INR; lia. } } + apply fun_bound_pos; auto. } Qed. - End NAN. \ No newline at end of file diff --git a/accuracy_proofs/gemm_acc.v b/accuracy_proofs/gemm_acc.v index 64b4d44..bd764bb 100644 --- a/accuracy_proofs/gemm_acc.v +++ b/accuracy_proofs/gemm_acc.v @@ -1,195 +1,413 @@ -From LAProof.accuracy_proofs Require Import preamble common - dotprod_model sum_model dot_acc float_acc_lems mv_mathcomp gemv_acc vec_op_acc. +(** * Matrix Multiplication Forward and Mixed Error Analysis -Section MMERROR. -(* forward error matrix multiplication *) -Context {NAN: FPCore.Nans} {t : FPStdLib.type}. + This module establishes rigorous rounding error bounds for floating-point + matrix operations, including matrix multiplication, scalar-matrix + multiplication, matrix addition, and the general matrix multiply-accumulate + (GEMM) operation. -Notation g := (@common.g t). + ** Error Factors + + Throughout, the accumulated relative error factor is + %$g(n) = (1 + \mathbf{u})^n - 1$%#\(g(n) = (1 + \mathbf{u})^n - 1\)# and + the mixed absolute error factor is + %$g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))$%#\(g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))\)#, + where %$\mathbf{u}$%#\(\mathbf{u}\)# is the unit roundoff and + %$\eta$%#\(\eta\)# is the underflow threshold for the given type. + Both are defined in [common]. + + ** Error Bound Taxonomy + + The theorems in this file fall into three categories: + + _Pure forward error bounds_ characterize the absolute difference between + the computed result and the exact result in terms of the input data. + No reference is made to a nearby exact problem. 
+ + _Mixed forward-backward error bounds_ express the computed result as the + exact result of a slightly perturbed input (backward component) plus a + small additive residual (forward component). The input perturbation is + bounded relative to the original input data, and the residual is an + additive error matrix satisfying entry-wise absolute bounds. + + _Pure backward error bounds_ express the computed result as the exact + result of a perturbed input problem, with no forward error term. + + ** Key Results + + - [MMC_error] _(mixed)_: Shows that each column of the floating-point product + %$\mathtt{fl}(AB)$%#\(\mathtt{fl}(AB)\)# equals the exact product of a slightly perturbed + %$A$%#\(A\)# with that column of %$B$%#\(B\)#, plus a small absolute error. The perturbation + is bounded entry-wise by %$g(n)$%#\(g(n)\)# relative to %$A$%#\(A\)#, and the + absolute residual by %$g_1(n,n)$%#\(g_1(n,n)\)# per entry. + + - [scaleM_error] _(mixed)_: Shows that floating-point scalar-matrix + multiplication equals exact scaling of a slightly perturbed matrix plus + a small entry-wise absolute error. The relative perturbation is bounded + by %$\mathbf{u}$%#\(\mathbf{u}\)# and the absolute residual by %$\eta$%#\(\eta\)#. + + - [sMMC_error] _(mixed)_: Composes [MMC_error] and [scaleM_error] to give + a structured decomposition of %$\mathtt{fl}(x \cdot (AB))$%#\(\mathtt{fl}(x \cdot (AB))\)# + with backward perturbations from both the matrix product and the scaling + step, together with forward absolute errors from each. + + - [mat_sum_error] _(pure backward)_: Shows that floating-point matrix + addition equals the exact sum of two slightly perturbed matrices, with + each entry perturbed by a relative factor bounded by %$\mathbf{u}$%#\(\mathbf{u}\)#. + No forward error term appears. 
+ + - [mat_axpby_error] _(mixed)_: Bounds %$\mathtt{fl}(xA + yB)$%#\(\mathtt{fl}(xA + yB)\)# by + combining mixed errors from each scaling step with a backward error from + the floating-point addition, yielding relative perturbations of the + inputs and small absolute forward errors. + + - [GEMM_error] _(mixed)_: Master theorem for + %$\mathtt{fl}(s_1(AB) + s_2 Y)$%#\(\mathtt{fl}(s_1(AB) + s_2 Y)\)#. Decomposes the + full GEMM result into backward perturbation components and forward + absolute errors from matrix multiplication, scalar scaling, and + matrix addition. +*) + +From LAProof.accuracy_proofs Require Import + preamble common dotprod_model sum_model dot_acc + float_acc_lems mv_mathcomp gemv_acc vec_op_acc. + +Section MMERROR. + +(** We work in an abstract floating-point context [NAN] (specifying NaN + behavior) and over an abstract floating-point type << t >>. *) + +Context {NAN : FPCore.Nans} {t : FPStdLib.type}. + +Notation g := (@common.g t). Notation g1 := (@common.g1 t). -Theorem MMC_error: - forall m n p (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) - (Hfin: F.finitemx (F.mulmx A B)), - exists (E eta: 'M[R]_(m,p)), - map_mx FT2R (F.mulmx A B) = (map_mx FT2R A *m map_mx FT2R B + E + eta)%Ri - /\ (forall k: 'I_p, - exists E0: 'M[R]_(m,n), - col k E = E0 *m col k (map_mx FT2R B) /\ - (forall i j, Rabs (E0 i j) <= g n * Rabs (map_mx FT2R A i j))) - /\ forall i j, Rabs (eta i j) <= g1 n n. +(** ** Matrix Multiplication Error + + [MMC_error] establishes that the floating-point matrix product equals the + exact product plus a column-wise backward perturbation and a small + entry-wise absolute offset. The relative column perturbation is bounded + by %$g(n)$%#\(g(n)\)# and the absolute residual by %$g_1(n,n)$%#\(g_1(n,n)\)# per entry, + where %$n$%#\(n\)# is the inner dimension. 
*) + +Theorem MMC_error : + forall m n p + (A : 'M[ftype t]_(m, n)) + (B : 'M[ftype t]_(n, p)) + (Hfin : F.finitemx (F.mulmx A B)), + exists (E eta : 'M[R]_(m, p)), + map_mx FT2R (F.mulmx A B) + = (map_mx FT2R A *m map_mx FT2R B + E + eta)%Ri + /\ (forall k : 'I_p, + exists E0 : 'M[R]_(m, n), + col k E = E0 *m col k (map_mx FT2R B) + /\ (forall i j, + Rabs (E0 i j) <= g n * Rabs (map_mx FT2R A i j))) + /\ (forall i j, Rabs (eta i j) <= g1 n n). Proof. -move => m n p. -elim: p. -- move => A B Hfin. -exists (const_mx 0), (const_mx 0). -repeat split. -+ apply /matrixP. move => i j. destruct j; lia. -+ move => k; destruct k; lia. -+ move => i j; destruct j; lia. -- move => p IH A B Hfin. -change (p.+1) with (1+p)%nat in *. -rewrite -(hsubmxK B) F.mulmx_row map_row_mx. -destruct (IH A (rsubmx B)) as [E [eta [Heq [HE Heta]]]]. { - move => i j. move :(Hfin i (rshift 1 j)). rewrite /F.mulmx !mxE col_rsubmx //. -} -clear IH. rewrite {}Heq. -destruct (mat_vec_mul_mixed_error A (lsubmx B)) as [E1 [eta1 [Heq1 [HE1 Heta1]]]]. { - move => i j. move :(Hfin i (lshift p j)). rewrite /F.mulmx !mxE col_lsubmx //. -} -rewrite {}Heq1. -exists (row_mx (E1 *m map_mx FT2R (lsubmx B)) E), (row_mx eta1 eta). -repeat split. -+ -rewrite map_lsubmx map_rsubmx hsubmxK -add_row_mx mulmxDl -add_row_mx. -f_equal. -f_equal. -rewrite -mul_mx_row hsubmxK //. -+ -move => k. -case_splitP k. -* exists E1. split => //. - rewrite colKl map_row_mx colKl !col_id //. -* destruct (HE k) as (E0 & Heq2 & HE0). - exists E0; split => //. - rewrite colKr map_row_mx colKr //. -+ -move => i j. -case_splitP j. -rewrite row_mxEl //. -rewrite row_mxEr //. + move => m n p. + elim: p. + - (** Base case: p = 0 columns; all column index types are empty. *) + move => A B Hfin. + exists (const_mx 0), (const_mx 0). + repeat split. + + apply /matrixP => i j; destruct j; lia. + + move => k; destruct k; lia. + + move => i j; destruct j; lia. 
+ - (** Inductive step: split B into its leftmost column-block and the + remaining columns, apply the induction hypothesis to the right block, + and apply [mat_vec_mul_mixed_error] to the left block. *) + move => p IH A B Hfin. + change (p.+1) with (1 + p)%nat in *. + rewrite -(hsubmxK B) F.mulmx_row map_row_mx. + + (** Apply the induction hypothesis to the right submatrix of B. *) + destruct (IH A (rsubmx B)) as [E [eta [Heq [HE Heta]]]]. { + move => i j. + move : (Hfin i (rshift 1 j)). + rewrite /F.mulmx !mxE col_rsubmx //. + } + clear IH. + rewrite {}Heq. + + (** Apply the matrix-vector mixed error lemma to the left submatrix. *) + destruct (mat_vec_mul_mixed_error A (lsubmx B)) + as [E1 [eta1 [Heq1 [HE1 Heta1]]]]. { + move => i j. + move : (Hfin i (lshift p j)). + rewrite /F.mulmx !mxE col_lsubmx //. + } + rewrite {}Heq1. + + exists (row_mx (E1 *m map_mx FT2R (lsubmx B)) E), + (row_mx eta1 eta). + repeat split. + + (** Reassemble the block-column equation. *) + rewrite map_lsubmx map_rsubmx hsubmxK + -add_row_mx mulmxDl -add_row_mx. + f_equal; f_equal. + rewrite -mul_mx_row hsubmxK //. + + (** Column-wise backward error bound for E. *) + move => k. + case_splitP k. + * exists E1; split => //. + rewrite colKl map_row_mx colKl !col_id //. + * destruct (HE k) as (E0 & Heq2 & HE0). + exists E0; split => //. + rewrite colKr map_row_mx colKr //. + + (** Entry-wise forward absolute error bound for eta. *) + move => i j. + case_splitP j. + * rewrite row_mxEl //. + * rewrite row_mxEr //. Qed. -Theorem scaleM_error: - forall m n (A: 'M[ftype t]_(m,n)) (x: ftype t) - (Hfin: F.finitemx (F.scalemx x A)), - exists (E eta: 'M[R]_(m,n)), - map_mx FT2R (F.scalemx x A) = - scalemx (FT2R x) (map_mx FT2R A + E) + eta - /\ (forall i j, Rabs (E i j) <= @default_rel t * Rabs (map_mx FT2R A i j)) - /\ (forall i j, Rabs (eta i j) <= @default_abs t). 
+(** ** Scalar-Matrix Multiplication Error + + [scaleM_error] establishes that floating-point scalar-matrix + multiplication %$\mathtt{fl}(x \cdot A)$%#\(\mathtt{fl}(x \cdot A)\)# equals exact scaling of + a slightly perturbed matrix plus a small entry-wise absolute offset. + The relative perturbation is bounded entry-wise by the unit roundoff + %$\mathbf{u}$%#\(\mathbf{u}\)# and the absolute residual by the underflow threshold + %$\eta$%#\(\eta\)#. *) + +Theorem scaleM_error : + forall m n + (A : 'M[ftype t]_(m, n)) + (x : ftype t) + (Hfin : F.finitemx (F.scalemx x A)), + exists (E eta : 'M[R]_(m, n)), + map_mx FT2R (F.scalemx x A) + = scalemx (FT2R x) (map_mx FT2R A + E) + eta + /\ (forall i j, + Rabs (E i j) <= @default_rel t * Rabs (map_mx FT2R A i j)) + /\ (forall i j, + Rabs (eta i j) <= @default_abs t). Proof. -intros. -apply Fscalemx_mixed_error in Hfin. -destruct Hfin as [e [eta [? [? ?]]]]. -exists e, eta; split; auto. split; intros; auto. -destruct (H0 i j) as [d [? ?]]. -rewrite H2 Rabs_mult Rmult_comm. -apply /RleP; apply Rmult_le_compat_r. -apply Rabs_pos. -apply /RleP; auto. + intros m n A x Hfin. + apply Fscalemx_mixed_error in Hfin. + destruct Hfin as [e [eta [Heq [He Heta]]]]. + exists e, eta. + split; auto. + split; intros i j; auto. + destruct (He i j) as [d [Hd Hbd]]. + rewrite Hd Rabs_mult Rmult_comm. + apply /RleP. + apply Rmult_le_compat_r. + - apply Rabs_pos. + - apply /RleP; auto. Qed. 
-Theorem sMMC_error: - forall m n p (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) (x: ftype t) - (Hfin: F.finitemx (F.scalemx x (F.mulmx A B))), - exists E1 E eta1 eta: 'M[R]_(m,p), - map_mx FT2R (F.scalemx x (F.mulmx A B)) = - scalemx (FT2R x) (((map_mx FT2R A *m map_mx FT2R B + E1) + eta1) + E) + eta - /\ (forall k: 'I_p, exists E0, - col k E1 = E0 *m (col k (map_mx FT2R B)) /\ - (forall i j, Rabs (E0 i j) <= g n * Rabs (map_mx FT2R A i j))) - /\ (forall i j, Rabs (eta1 i j) <= g1 n n) - /\ (forall i j, Rabs (eta i j) <= @default_abs t) - /\ (forall i j, Rabs (E i j) <= @default_rel t * Rabs (((map_mx FT2R A *m map_mx FT2R B + E1) + eta1)%Ri i j)). +(** ** Scaled Matrix Product Error + + [sMMC_error] composes [MMC_error] and [scaleM_error] to give a + structured decomposition of + %$\mathtt{fl}(x \cdot (AB))$%#\(\mathtt{fl}(x \cdot (AB))\)#. The result carries + backward perturbations from the matrix product (bounded by %$g(n)$%#\(g(n)\)# + column-wise) and from the scaling step (bounded by %$\mathbf{u}$%#\(\mathbf{u}\)# + entry-wise), together with forward absolute errors from each, bounded by + %$g_1(n,n)$%#\(g_1(n,n)\)# and %$\eta$%#\(\eta\)# respectively. *) + +Theorem sMMC_error : + forall m n p + (A : 'M[ftype t]_(m, n)) + (B : 'M[ftype t]_(n, p)) + (x : ftype t) + (Hfin : F.finitemx (F.scalemx x (F.mulmx A B))), + exists E1 E eta1 eta : 'M[R]_(m, p), + map_mx FT2R (F.scalemx x (F.mulmx A B)) + = scalemx (FT2R x) + (((map_mx FT2R A *m map_mx FT2R B + E1) + eta1) + E) + eta + /\ (forall k : 'I_p, + exists E0, + col k E1 = E0 *m col k (map_mx FT2R B) + /\ (forall i j, + Rabs (E0 i j) <= g n * Rabs (map_mx FT2R A i j))) + /\ (forall i j, Rabs (eta1 i j) <= g1 n n) + /\ (forall i j, Rabs (eta i j) <= @default_abs t) + /\ (forall i j, + Rabs (E i j) <= + @default_rel t * + Rabs (((map_mx FT2R A *m map_mx FT2R B + E1) + eta1)%Ri i j)). Proof. -move => m n p A B x Hfin. -destruct (scaleM_error _ _ (F.mulmx A B) x Hfin) - as (E & eta & Heq & HE & Heta). 
-rewrite Heq. -destruct (MMC_error _ _ _ A B) - as (E1 & eta1 & Heq1 & HE1 & Heta1). - apply (F.finitemx_scalemx_e _ _ Hfin). -rewrite Heq1. -exists E1, E, eta1, eta; repeat split => //. -move => i j. -move :(HE i j). -rewrite Heq1 //. + move => m n p A B x Hfin. + + (** Decompose the outer scaling error for x * (A*B). *) + destruct (scaleM_error _ _ (F.mulmx A B) x Hfin) + as (E & eta & Heq & HE & Heta). + rewrite Heq. + + (** Decompose the matrix-multiplication error for A*B, + propagating finiteness from [F.scalemx x (F.mulmx A B)]. *) + destruct (MMC_error _ _ _ A B) + as (E1 & eta1 & Heq1 & HE1 & Heta1). { + apply (F.finitemx_scalemx_e _ _ Hfin). + } + rewrite Heq1. + + exists E1, E, eta1, eta. + repeat split => //. + move => i j. + move : (HE i j). + rewrite Heq1 //. Qed. -Theorem mat_sum_error: - forall m n (A B: 'M[ftype t]_(m,n)) - (Hfin: F.finitemx (F.addmx A B)), - exists EA EB: 'M[R]_(m,n), - map_mx FT2R (F.addmx A B) = - (map_mx FT2R A + EA) + (map_mx FT2R B + EB) - /\ (forall i j, exists d, EA i j = map_mx FT2R A i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, EB i j = map_mx FT2R B i j * d /\ Rabs d <= @default_rel t). +(** ** Matrix Addition Error + + [mat_sum_error] establishes that floating-point matrix addition + %$\mathtt{fl}(A + B)$%#\(\mathtt{fl}(A + B)\)# equals the exact sum of two slightly + perturbed matrices. Each entry of both perturbation matrices is bounded + in relative terms by the unit roundoff %$\mathbf{u}$%#\(\mathbf{u}\)#. This is a + pure backward result: no forward error term appears. *) + +Theorem mat_sum_error : + forall m n + (A B : 'M[ftype t]_(m, n)) + (Hfin : F.finitemx (F.addmx A B)), + exists EA EB : 'M[R]_(m, n), + map_mx FT2R (F.addmx A B) + = (map_mx FT2R A + EA) + (map_mx FT2R B + EB) + /\ (forall i j, exists d, + EA i j = map_mx FT2R A i j * d /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + EB i j = map_mx FT2R B i j * d /\ Rabs d <= @default_rel t). Proof. -intros. 
-destruct (Faddmx_mixed_error A B Hfin) as [EA [EB [Heq [HA HB]]]]. -exists EA, EB; repeat split; auto. + intros m n A B Hfin. + destruct (Faddmx_mixed_error A B Hfin) as [EA [EB [Heq [HA HB]]]]. + exists EA, EB. + repeat split; auto. Qed. -Theorem mat_axpby_error: - forall [m n] (A B: 'M[ftype t]_(m,n)) (x y: ftype t) - (Hfin: F.finitemx (F.addmx (F.scalemx x A) (F.scalemx y B))), - exists EA EB ea eb eta1 eta2: 'M[R]_(m,n), - map_mx FT2R (F.addmx (F.scalemx x A) (F.scalemx y B)) = - scalemx (FT2R x) (map_mx FT2R A + EA) + eta1 + ea - + scalemx (FT2R y) (map_mx FT2R B + EB) + eta2 + eb - /\ (forall i j, Rabs (EA i j) <= @default_rel t * Rabs (map_mx FT2R A i j)) - /\ (forall i j, Rabs (EB i j) <= @default_rel t * Rabs (map_mx FT2R B i j)) - /\ (forall i j, exists d, ea i j = (scalemx (FT2R x) (map_mx FT2R A + EA) + eta1) i j * d - /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, eb i j = (scalemx (FT2R y) (map_mx FT2R B + EB) + eta2) i j * d - /\ Rabs d <= @default_rel t) - /\ (forall i j, Rabs (eta1 i j) <= @default_abs t) - /\ (forall i j, Rabs (eta2 i j) <= @default_abs t). +(** ** Scaled Matrix Sum Error + + [mat_axpby_error] bounds the floating-point operation + %$\mathtt{fl}(xA + yB)$%#\(\mathtt{fl}(xA + yB)\)# by combining the mixed errors from + each scaling step with a backward error from the floating-point addition. + The result decomposes into relative perturbations of %$A$%#\(A\)# and %$B$%#\(B\)# + bounded by %$\mathbf{u}$%#\(\mathbf{u}\)# and irreducible absolute forward errors + from each scaling, bounded by %$\eta$%#\(\eta\)#. 
*) + +Theorem mat_axpby_error : + forall [m n] + (A B : 'M[ftype t]_(m, n)) + (x y : ftype t) + (Hfin : F.finitemx + (F.addmx (F.scalemx x A) (F.scalemx y B))), + exists EA EB ea eb eta1 eta2 : 'M[R]_(m, n), + map_mx FT2R (F.addmx (F.scalemx x A) (F.scalemx y B)) + = scalemx (FT2R x) (map_mx FT2R A + EA) + eta1 + ea + + scalemx (FT2R y) (map_mx FT2R B + EB) + eta2 + eb + /\ (forall i j, + Rabs (EA i j) <= @default_rel t * Rabs (map_mx FT2R A i j)) + /\ (forall i j, + Rabs (EB i j) <= @default_rel t * Rabs (map_mx FT2R B i j)) + /\ (forall i j, exists d, + ea i j + = (scalemx (FT2R x) (map_mx FT2R A + EA) + eta1) i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + eb i j + = (scalemx (FT2R y) (map_mx FT2R B + EB) + eta2) i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, Rabs (eta1 i j) <= @default_abs t) + /\ (forall i j, Rabs (eta2 i j) <= @default_abs t). Proof. -move => m n A B x y Hfin. -destruct (mat_sum_error _ _ (F.scalemx x A) (F.scalemx y B)) - as (ea & eb & HEQ & H1 & H2) => //. -destruct (scaleM_error _ _ A x) as - (EA & eta1 & Heqx & H6 & H7) => //. - apply (F.finitemx_addmx_e _ _ Hfin). -destruct (scaleM_error _ _ B y) as - (EB & eta2 & Heqy & H12 & H13) => //. -apply (F.finitemx_addmx_e _ _ Hfin). -rewrite {}HEQ. -rewrite {}Heqx in H1|-*. -rewrite {}Heqy in H2|-*. -exists EA, EB, ea, eb, eta1, eta2; + move => m n A B x y Hfin. + + (** Decompose the outer addition as a pure backward error. *) + destruct (mat_sum_error _ _ (F.scalemx x A) (F.scalemx y B)) + as (ea & eb & HEQ & H1 & H2) => //. + + (** Decompose the mixed error for the scaling x * A. *) + destruct (scaleM_error _ _ A x) + as (EA & eta1 & Heqx & H6 & H7). { + apply (F.finitemx_addmx_e _ _ Hfin). + } + + (** Decompose the mixed error for the scaling y * B. *) + destruct (scaleM_error _ _ B y) + as (EB & eta2 & Heqy & H12 & H13). { + apply (F.finitemx_addmx_e _ _ Hfin). + } + + rewrite {}HEQ {}Heqx in H1 |- *. + rewrite {}Heqy in H2 |- *. 
+ exists EA, EB, ea, eb, eta1, eta2. repeat split => //. -rewrite !addrA //. + rewrite !addrA //. Qed. -Theorem GEMM_error: - forall [m n p] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) (Y: 'M[ftype t]_(m,p)) - (s1 s2: ftype t) - (Hfin: F.finitemx (F.addmx (F.scalemx s1 (F.mulmx A B)) (F.scalemx s2 Y))), - exists ab1 ab2 ab3 ab4 ab5 y1 y2 y3: 'M[R]_(m,p), - map_mx FT2R (F.addmx (F.scalemx s1 (F.mulmx A B)) (F.scalemx s2 Y)) = - (scalemx (FT2R s1) ((((map_mx FT2R A *m map_mx FT2R B)+ ab1) + ab2) + ab3) + ab4) + ab5 + - ((scalemx (FT2R s2) (map_mx FT2R Y + y1) + y2) + y3) - /\ (forall k: 'I_p, exists E0, - col k ab1 = E0 *m col k (map_mx FT2R B) /\ - (forall i j, Rabs (E0 i j) <= g n * Rabs (map_mx FT2R A i j))) - /\ (forall i j, Rabs (ab2 i j) <= g1 n n) - /\ (forall i j, Rabs (ab3 i j) <= @default_rel t * Rabs ((((map_mx FT2R A *m map_mx FT2R B) + ab1) + ab2)%Ri i j)) - /\ (forall i j, Rabs (y1 i j) <= @default_rel t * Rabs (map_mx FT2R Y i j)) - /\ (forall i j, exists d, ab5 i j = (scalemx (FT2R s1) ((((map_mx FT2R A *m map_mx FT2R B) + ab1) + ab2) + ab3 )+ ab4) i j * d - /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, y3 i j = (scalemx (FT2R s2) (map_mx FT2R Y + y1) + y2) i j * d - /\ Rabs d <= @default_rel t) - /\ (forall i j, Rabs (ab4 i j) <= @default_abs t) - /\ (forall i j, Rabs (y2 i j) <= @default_abs t). +(** ** General GEMM Error + + [GEMM_error] is the master error theorem for the floating-point GEMM + operation %$\mathtt{fl}(s_1(AB) + s_2 Y)$%#\(\mathtt{fl}(s_1(AB) + s_2 Y)\)#. It + decomposes the result into backward perturbation components and forward + absolute errors arising from matrix multiplication (bounded by %$g(n)$%#\(g(n)\)# + and %$g_1(n,n)$%#\(g_1(n,n)\)#), scalar scaling (bounded by %$\mathbf{u}$%#\(\mathbf{u}\)# and + %$\eta$%#\(\eta\)#), and matrix addition (backward, bounded by %$\mathbf{u}$%#\(\mathbf{u}\)#). + The proof composes [mat_axpby_error] and [MMC_error]. 
*) + +Theorem GEMM_error : + forall [m n p] + (A : 'M[ftype t]_(m, n)) + (B : 'M[ftype t]_(n, p)) + (Y : 'M[ftype t]_(m, p)) + (s1 s2 : ftype t) + (Hfin : F.finitemx + (F.addmx (F.scalemx s1 (F.mulmx A B)) (F.scalemx s2 Y))), + exists ab1 ab2 ab3 ab4 ab5 y1 y2 y3 : 'M[R]_(m, p), + map_mx FT2R + (F.addmx (F.scalemx s1 (F.mulmx A B)) (F.scalemx s2 Y)) + = (scalemx (FT2R s1) + ((((map_mx FT2R A *m map_mx FT2R B) + ab1) + ab2) + ab3) + + ab4) + ab5 + + ((scalemx (FT2R s2) (map_mx FT2R Y + y1) + y2) + y3) + /\ (forall k : 'I_p, + exists E0, + col k ab1 = E0 *m col k (map_mx FT2R B) + /\ (forall i j, + Rabs (E0 i j) <= g n * Rabs (map_mx FT2R A i j))) + /\ (forall i j, Rabs (ab2 i j) <= g1 n n) + /\ (forall i j, + Rabs (ab3 i j) <= + @default_rel t * + Rabs ((((map_mx FT2R A *m map_mx FT2R B) + ab1) + ab2)%Ri i j)) + /\ (forall i j, + Rabs (y1 i j) <= @default_rel t * Rabs (map_mx FT2R Y i j)) + /\ (forall i j, exists d, + ab5 i j + = (scalemx (FT2R s1) + ((((map_mx FT2R A *m map_mx FT2R B) + ab1) + ab2) + ab3) + + ab4) i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + y3 i j + = (scalemx (FT2R s2) (map_mx FT2R Y + y1) + y2) i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, Rabs (ab4 i j) <= @default_abs t) + /\ (forall i j, Rabs (y2 i j) <= @default_abs t). Proof. -intros. -(* compose errors from axpby and MMC *) -destruct (mat_axpby_error (F.mulmx A B) Y s1 s2) - as (ab3 & y1 & ab5 & y3 & ab4 & y2 & Heq1 & Hab3 & Hy1 & Hab5 & Hy3 & Hab4 & Hy2) => //. -destruct (MMC_error _ _ _ A B) - as (ab1 & ab2 & Heq2 & Hab1 & Hab2) => //. - apply F.finitemx_addmx_e in Hfin; destruct Hfin as [Hfin _]. - apply (F.finitemx_scalemx_e _ _ Hfin). -rewrite {}Heq1. -rewrite {}Heq2 in Hab5,Hab3|-*. -exists ab1, ab2, ab3, ab4, ab5, y1, y2, y3; repeat split => //. -rewrite !addrA //. -Qed. + intros m n p A B Y s1 s2 Hfin. -End MMERROR. 
+ (** Decompose the axpby structure for s1*(A*B) + s2*Y, obtaining + backward addition errors ab5, y3 and forward scaling errors + ab4, y2, together with backward scaling perturbations ab3, y1. *) + destruct (mat_axpby_error (F.mulmx A B) Y s1 s2) + as (ab3 & y1 & ab5 & y3 & ab4 & y2 + & Heq1 & Hab3 & Hy1 & Hab5 & Hy3 & Hab4 & Hy2) => //. + (** Decompose the matrix-multiplication error for A*B, propagating + finiteness from the s1*(A*B) factor of Hfin. *) + destruct (MMC_error _ _ _ A B) + as (ab1 & ab2 & Heq2 & Hab1 & Hab2). { + apply F.finitemx_addmx_e in Hfin. + destruct Hfin as [Hfin _]. + apply (F.finitemx_scalemx_e _ _ Hfin). + } + rewrite {}Heq1 {}Heq2 in Hab5, Hab3 |- *. + exists ab1, ab2, ab3, ab4, ab5, y1, y2, y3. + repeat split => //. + rewrite !addrA //. +Qed. +End MMERROR. \ No newline at end of file diff --git a/accuracy_proofs/gemv_acc.v b/accuracy_proofs/gemv_acc.v index 085fddc..b55a64d 100644 --- a/accuracy_proofs/gemv_acc.v +++ b/accuracy_proofs/gemv_acc.v @@ -1,190 +1,245 @@ -From LAProof.accuracy_proofs Require Import preamble common - dotprod_model sum_model - dot_acc float_acc_lems mv_mathcomp. +(** * Matrix-Vector Multiplication Error Bounds + + This file establishes mixed and forward error bounds for floating-point + matrix-vector multiplication, building on the dot product and summation + accuracy results. + + ** Error Factors + + Throughout, the accumulated relative error factor is + %$g(n) = (1 + \mathbf{u})^n - 1$%#\(g(n) = (1 + \mathbf{u})^n - 1\)# and + the mixed absolute error factor is + %$g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))$%#\(g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))\)#, + where %$\mathbf{u}$%#\(\mathbf{u}\)# is the unit roundoff and + %$\eta$%#\(\eta\)# is the underflow threshold for the given type. + Both are defined in [common]. 
+ + ** Main Results + + - [vec_vec_mul_mixed_error]: Shows that the floating-point row-times-column + dot product can be expressed as an exact product of a componentwise-perturbed + row vector with the exact column vector, plus a small absolute residual. + + - [mat_vec_mul_mixed_error]: Shows that the floating-point matrix-vector + product can be expressed as an exact product of a componentwise-perturbed + matrix with the exact input vector, plus a small absolute residual. + Proved by applying [vec_vec_mul_mixed_error] row by row. + + - [mat_vec_mul_forward_error]: Bounds the absolute forward error of the + floating-point matrix-vector product in the vector max-norm by + %$g(n) \cdot \|A\| \cdot \|B\| + g_1(n, n)$%#\(g(n) \cdot \|A\| \cdot \|B\| + g_1(n,n)\)#, + where %$\|A\|$%#\(\|A\|\)# is the matrix infinity norm and + %$\|B\|$%#\(\|B\|\)# is the vector max-norm. + + ** Dependencies + + This file relies on: + - [preamble], [common]: basic setup and shared definitions + - [dotprod_model], [sum_model]: relational models of dot product and summation + - [dot_acc], [float_acc_lems]: accuracy lemmas + - [mv_mathcomp]: floating-point matrix/vector operations and norm definitions +*) + +From LAProof.accuracy_proofs Require Import preamble common + dotprod_model sum_model + dot_acc float_acc_lems mv_mathcomp. From mathcomp.algebra_tactics Require Import ring. -Section WithNAN. -(* mixed error bounds over lists *) -Context {NAN: FPCore.Nans} {t : type}. +Section WithNAN. + +Context {NAN : FPCore.Nans} {t : type}. -Notation g := (@common.g t). +Notation g := (@common.g t). Notation g1 := (@common.g1 t). 
-Lemma vec_vec_mul_mixed_error: - forall [n] (A: 'M[ftype t]_(1,n)) (B: 'M[ftype t]_(n,1)) - (Hfin: F.finitemx (F.mulmx A B)), +(** ** Row-Vector Times Column-Vector: Mixed Error Bound *) + +(** [vec_vec_mul_mixed_error] shows that the floating-point row-times-column + dot product can be expressed as an exact product of a componentwise-perturbed + row vector with the exact column vector, plus a small absolute residual. *) + +Lemma vec_vec_mul_mixed_error : + forall [n] (A : 'M[ftype t]_(1,n)) (B : 'M[ftype t]_(n,1)) + (Hfin : F.finitemx (F.mulmx A B)), exists (E : 'M[R]_(1,n)) (eta : 'M[R]_(1,1)), - map_mx FT2R (F.mulmx A B) = ((map_mx FT2R A + E) *m (map_mx FT2R B) + eta)%Ri + map_mx FT2R (F.mulmx A B) + = ((map_mx FT2R A + E) *m (map_mx FT2R B) + eta)%Ri /\ (forall i j, Rabs (E i j) <= g n * Rabs (map_mx FT2R A i j)) /\ (forall i j, Rabs (eta i j) <= g1 n n). Proof. -intros *. -rewrite F.mulmx_dotprodF. -move => Hfin. -specialize (Hfin ord0 ord0). rewrite mxE in Hfin. -assert (Hlen: size (seq_of_rV A) = size (seq_of_rV B^T)). -unfold seq_of_rV. rewrite !size_map size_ord_enum //. -destruct (dotprod_mixed_error (seq_of_rV A) (seq_of_rV (trmx B)) Hlen Hfin) - as [u [eta [ Hu [Heq [HD ?]]]]]. -exists ((\row_j nth R0 u j) - map_mx FT2R A)%Ri, (const_mx eta). -repeat split. -- -rewrite map_const_mx. -rewrite {}Heq. -rewrite (addrC (map_mx _ _)) subrK. -apply /matrixP. intros i j. rewrite !ord1. clear i j. -rewrite !mxE. -change (GRing.add ?A ?B) with (Rplus A B). -f_equal. -rewrite index_ord_enum. -rewrite (unlock (bigop_unlock)). -unfold reducebig, comp, applybig. -unfold dotprodR, dotprod. -rewrite foldl_foldr. -2, 3: compute; intros; lra. -unfold seq_of_rV. -rewrite -!map_comp. -rewrite /seq_of_rV size_map size_ord_enum in Hu. -move :(nth_ord_enum_lemma R0 u). rewrite Hu => Hu'. -rewrite {1}Hu'. clear Hu'. -rewrite zip_map -map_comp foldr_map. -simpl. -f_equal. - apply FunctionalExtensionality.functional_extensionality; intro i. 
- apply FunctionalExtensionality.functional_extensionality; intro x. -rewrite flip_Rplus. -rewrite !mxE. -reflexivity. -- -intros i j. rewrite ord1; clear i. -destruct (HD j). -unfold seq_of_rV. rewrite size_map size_ord_enum. pose proof (ltn_ord j). lia. -rewrite !mxE. -destruct H0. -rewrite {}H0. -unfold seq_of_rV in H1|-*. -rewrite size_map size_ord_enum in H1. -rewrite (nth_map j). -2: rewrite size_ord_enum; pose proof (ltn_ord j); lia. -rewrite nth_ord_enum'. -change (A 0 j)%Ri with (A ord0 j). -set Aj := FT2R (A ord0 j). -change ((Aj * (1 + x))%Re - Aj)%Ri with ((Aj * (1 + x)) - Aj)%Ri. -replace (Aj * (1 + x) - Aj)%Ri with (Aj*x)%Ri by ring. -rewrite mulrC. -rewrite Rabs_mult. -apply /RleP. -apply Rmult_le_compat_r; auto. -apply Rabs_pos. -- -move => i j; rewrite !ord1 mxE; clear i j. -move :H; rewrite /seq_of_rV size_map size_ord_enum //. -move => H; apply /RleP; auto. + intros *. + rewrite F.mulmx_dotprodF. + move => Hfin. + specialize (Hfin ord0 ord0). rewrite mxE in Hfin. + assert (Hlen : size (seq_of_rV A) = size (seq_of_rV B^T)). + { unfold seq_of_rV. rewrite !size_map size_ord_enum //. } + destruct (dotprod_mixed_error (seq_of_rV A) (seq_of_rV (trmx B)) Hlen Hfin) + as [u [eta [Hu [Heq [HD Heta]]]]]. + exists ((\row_j nth R0 u j) - map_mx FT2R A)%Ri, (const_mx eta). + repeat split. + - rewrite map_const_mx. + rewrite {}Heq. + rewrite (addrC (map_mx _ _)) subrK. + apply /matrixP. intros i j. rewrite !ord1. clear i j. + rewrite !mxE. + change (GRing.add ?A ?B) with (Rplus A B). + f_equal. + rewrite index_ord_enum. + rewrite (unlock (bigop_unlock)). + unfold reducebig, comp, applybig. + unfold dotprodR, dotprod. + rewrite foldl_foldr. + 2, 3: compute; intros; lra. + unfold seq_of_rV. + rewrite -!map_comp. + rewrite /seq_of_rV size_map size_ord_enum in Hu. + move :(nth_ord_enum_lemma R0 u). rewrite Hu => Hu'. + rewrite {1}Hu'. clear Hu'. + rewrite zip_map -map_comp foldr_map. + simpl. + f_equal. 
+ apply FunctionalExtensionality.functional_extensionality; intro i. + apply FunctionalExtensionality.functional_extensionality; intro x. + rewrite flip_Rplus. + rewrite !mxE. + reflexivity. + - intros i j. rewrite ord1; clear i. + destruct (HD j) as [d [Hval Hbd]]. + { unfold seq_of_rV. rewrite size_map size_ord_enum. + pose proof (ltn_ord j). lia. } + rewrite !mxE. + rewrite {}Hval. + unfold seq_of_rV in Hbd |- *. + rewrite size_map size_ord_enum in Hbd. + rewrite (nth_map j). + 2: { rewrite size_ord_enum; pose proof (ltn_ord j); lia. } + rewrite nth_ord_enum'. + change (A 0 j)%Ri with (A ord0 j). + set Aj := FT2R (A ord0 j). + change ((Aj * (1 + d))%Re - Aj)%Ri with ((Aj * (1 + d)) - Aj)%Ri. + replace (Aj * (1 + d) - Aj)%Ri with (Aj * d)%Ri by ring. + rewrite mulrC. + rewrite Rabs_mult. + apply /RleP. + apply Rmult_le_compat_r; auto. + apply Rabs_pos. + - move => i j; rewrite !ord1 mxE; clear i j. + move :Heta; rewrite /seq_of_rV size_map size_ord_enum //. + move => Heta; apply /RleP; auto. Qed. -Lemma mat_vec_mul_mixed_error: - forall [m n] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,1)) - (Hfin: F.finitemx (F.mulmx A B)), +(** ** General Matrix-Vector Product: Mixed Error Bound *) + +(** [mat_vec_mul_mixed_error] shows that the floating-point matrix-vector + product can be expressed as an exact product of a componentwise-perturbed + matrix with the exact input vector, plus a small absolute residual. + Proved by applying [vec_vec_mul_mixed_error] row by row. *) + +Lemma mat_vec_mul_mixed_error : + forall [m n] (A : 'M[ftype t]_(m,n)) (B : 'M[ftype t]_(n,1)) + (Hfin : F.finitemx (F.mulmx A B)), exists (E : 'M[R]_(m,n)) (eta : 'M[R]_(m,1)), - map_mx FT2R (F.mulmx A B) = ((map_mx FT2R A + E) *m (map_mx FT2R B) + eta)%Ri + map_mx FT2R (F.mulmx A B) + = ((map_mx FT2R A + E) *m (map_mx FT2R B) + eta)%Ri /\ (forall i j, Rabs (E i j) <= g n * Rabs (map_mx FT2R A i j)) /\ (forall i j, Rabs (eta i j) <= g1 n n). Proof. -intros. -revert m A Hfin. -induction m; intros. 
-- -exists (const_mx R0), (const_mx R0). -split; [apply matrixP | split]; intros i j; destruct i; lia. -- -change (m.+1) with (1+m)%nat in A,Hfin|-*. -destruct (IHm (dsubmx A)) as [E2 [eta2 [? [? ?]]]]. { - move => i j. specialize (Hfin (rshift 1 i) j). - unfold F.mulmx in Hfin|-*. - rewrite mxE in Hfin. rewrite mxE row_dsubmx //. -} -clear IHm. -destruct (vec_vec_mul_mixed_error (usubmx A) B) as [E1 [eta1 [? [? ?]]]]. { - move => i j. specialize (Hfin (lshift m i) j). - unfold F.mulmx in Hfin|-*. - rewrite mxE in Hfin. rewrite mxE row_usubmx //. -} -exists (col_mx E1 E2), (col_mx eta1 eta2). -split; [ | split]. -+ -replace (F.mulmx A B) with (col_mx (F.mulmx (usubmx A) B) (F.mulmx (dsubmx A) B)). 2:{ - clear. - unfold F.mulmx. apply /matrixP. move => i j. -destruct (splitP i) as [i'|i']; - [replace i with (@lshift 1 m i'); [ | apply ord_inj; simpl; auto] - |replace i with (@rshift 1 m i'); [ | apply ord_inj; simpl; lia]]. - rewrite col_mxEu !mxE row_usubmx //. - rewrite col_mxEd !mxE row_dsubmx //. -} -rewrite map_col_mx {}H {}H2 map_usubmx map_dsubmx. -set A' := map_mx FT2R A. -set B' := map_mx FT2R B. -rewrite !mulmxDl. -rewrite -!add_col_mx. -f_equal. -f_equal. -rewrite mul_usub_mx mul_dsub_mx vsubmxK //. -rewrite mul_col_mx //. -+ -move => i j. -rewrite -(vsubmxK A) map_col_mx. -destruct (splitP i) as [i'|i']; - [replace i with (@lshift 1 m i'); [ | apply ord_inj; simpl; auto] - |replace i with (@rshift 1 m i'); [ | apply ord_inj; simpl; lia]]. - move :(H3 i' j). rewrite !col_mxEu !mxE //. - move :(H0 i' j). rewrite !col_mxEd !mxE //. -+ -move => i j. -destruct (splitP i) as [i'|i']; - [replace i with (@lshift 1 m i'); [ | apply ord_inj; simpl; auto] - |replace i with (@rshift 1 m i'); [ | apply ord_inj; simpl; lia]]. - move :(H4 i' j). rewrite !col_mxEu //. - move :(H1 i' j). rewrite !col_mxEd //. -Qed. + intros. + revert m A Hfin. + induction m; intros. + - exists (const_mx R0), (const_mx R0). + split; [apply matrixP | split]; intros i j; destruct i; lia. 
+ - change (m.+1) with (1 + m)%nat in A, Hfin |- *. + destruct (IHm (dsubmx A)) as [E2 [eta2 [HE2 [HB2 Heta2]]]]. + { move => i j. specialize (Hfin (rshift 1 i) j). + unfold F.mulmx in Hfin |- *. + rewrite mxE in Hfin. rewrite mxE row_dsubmx //. } + clear IHm. + destruct (vec_vec_mul_mixed_error (usubmx A) B) as [E1 [eta1 [HE1 [HB1 Heta1]]]]. + { move => i j. specialize (Hfin (lshift m i) j). + unfold F.mulmx in Hfin |- *. + rewrite mxE in Hfin. rewrite mxE row_usubmx //. } + exists (col_mx E1 E2), (col_mx eta1 eta2). + split; [ | split]. + + replace (F.mulmx A B) + with (col_mx (F.mulmx (usubmx A) B) (F.mulmx (dsubmx A) B)). + 2: { clear. + unfold F.mulmx. apply /matrixP. move => i j. + destruct (splitP i) as [i'|i']; + [ replace i with (@lshift 1 m i'); + [ | apply ord_inj; simpl; auto] + | replace i with (@rshift 1 m i'); + [ | apply ord_inj; simpl; lia]]. + rewrite col_mxEu !mxE row_usubmx //. + rewrite col_mxEd !mxE row_dsubmx //. } + rewrite map_col_mx {}HE1 {}HE2 map_usubmx map_dsubmx. + set A' := map_mx FT2R A. + set B' := map_mx FT2R B. + rewrite !mulmxDl. + rewrite -!add_col_mx. + f_equal. + f_equal. + rewrite mul_usub_mx mul_dsub_mx vsubmxK //. + rewrite mul_col_mx //. + + move => i j. + rewrite -(vsubmxK A) map_col_mx. + destruct (splitP i) as [i'|i']; + [ replace i with (@lshift 1 m i'); + [ | apply ord_inj; simpl; auto] + | replace i with (@rshift 1 m i'); + [ | apply ord_inj; simpl; lia]]. + move :(HB1 i' j). rewrite !col_mxEu !mxE //. + move :(HB2 i' j). rewrite !col_mxEd !mxE //. + + move => i j. + destruct (splitP i) as [i'|i']; + [ replace i with (@lshift 1 m i'); + [ | apply ord_inj; simpl; auto] + | replace i with (@rshift 1 m i'); + [ | apply ord_inj; simpl; lia]]. + move :(Heta1 i' j). rewrite !col_mxEu //. + move :(Heta2 i' j). rewrite !col_mxEd //. +Qed. 
+ +(** ** Matrix-Vector Product: Forward Error Bound *) + +(** [mat_vec_mul_forward_error] bounds the absolute forward error of the + floating-point matrix-vector product in the vector max-norm by + %$g(n) \cdot \|A\|_M \cdot \|B\| + g_1(n, n)$%#\(g(n) \cdot \|A\|_M \cdot \|B\| + g_1(n,n)\)#, + where [normM] is the matrix infinity norm and [normv] is the vector max-norm. *) Theorem mat_vec_mul_forward_error : - forall [m n] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,1)) - (Hfin: F.finitemx (F.mulmx A B)), - normv (map_mx FT2R (F.mulmx A B) - (map_mx FT2R A *m map_mx FT2R B)) - <= (g n * normM (map_mx FT2R A) * normv (map_mx FT2R B)) + g1 n n. + forall [m n] (A : 'M[ftype t]_(m,n)) (B : 'M[ftype t]_(n,1)) + (Hfin : F.finitemx (F.mulmx A B)), + normv (map_mx FT2R (F.mulmx A B) - (map_mx FT2R A *m map_mx FT2R B)) + <= (g n * normM (map_mx FT2R A) * normv (map_mx FT2R B)) + g1 n n. Proof. -intros. -destruct (mat_vec_mul_mixed_error _ _ Hfin) as (E & eta & HE & H1 & H2). -rewrite {}HE mulmxDl. -set Ar := map_mx FT2R A. -set Br := map_mx FT2R B. -have H0: (Ar *m Br + E *m Br + eta - Ar *m Br = E *m Br + eta)%Ri. -rewrite -!addrA addrC addrA -addrA addNr addr0 //. -rewrite {}H0. -eapply (le_trans (normv_triang _ _ _)). -apply lerD. -eapply (le_trans (subMultNorm _ _ _ _ )). -apply ler_pM => //. -apply normM_pos. -apply normv_pos. -rewrite /normM mulrC big_max_mul. -apply: le_bigmax2 => i0 _. -rewrite /sum_abs. -rewrite big_mul => [ | i b | ]; [ | ring | ]. -- -apply ler_sum => i _. -rewrite mulrC -/Ar //. -- -apply /RleP; auto with commonDB. -- -apply /RleP; auto with commonDB. -- -rewrite /normv. -apply bigmax_le => [|i _]. -apply /RleP; auto with commonDB. -auto. + intros. + destruct (mat_vec_mul_mixed_error _ _ Hfin) as (E & eta & HE & H1 & H2). + rewrite {}HE mulmxDl. + set Ar := map_mx FT2R A. + set Br := map_mx FT2R B. + have H0 : (Ar *m Br + E *m Br + eta - Ar *m Br = E *m Br + eta)%Ri. + { rewrite -!addrA addrC addrA -addrA addNr addr0 //. } + rewrite {}H0. 
+ eapply (le_trans (normv_triang _ _ _)). + apply lerD. + eapply (le_trans (subMultNorm _ _ _ _)). + apply ler_pM => //. + apply normM_pos. + apply normv_pos. + rewrite /normM mulrC big_max_mul. + apply: le_bigmax2 => i0 _. + rewrite /sum_abs. + rewrite big_mul => [ | i b | ]; [ | ring | ]. + - apply ler_sum => i _. + rewrite mulrC -/Ar //. + - apply /RleP; auto with commonDB. + - apply /RleP; auto with commonDB. + - rewrite /normv. + apply bigmax_le => [| i _]. + apply /RleP; auto with commonDB. + auto. Qed. -End WithNAN. - +End WithNAN. \ No newline at end of file diff --git a/accuracy_proofs/libvalidsdp.v b/accuracy_proofs/libvalidsdp.v index 2d0a839..85c2535 100644 --- a/accuracy_proofs/libvalidsdp.v +++ b/accuracy_proofs/libvalidsdp.v @@ -74,12 +74,6 @@ Section WithNaN. Context {NAN: FPCore.Nans} {t : type}. -Definition default_rel : R := - / 2 * Raux.bpow Zaux.radix2 (- fprec t + 1). - -Definition default_abs : R := - / 2 * Raux.bpow Zaux.radix2 (3 - femax t - fprec t). - Lemma prec_lt_emax: @flocq_float.prec (fprecp t) 0. -rewrite /eta. +Lemma default_abs_nonzero: eps <> 0. apply Rmult_integral_contrapositive. split. lra. rewrite bpow_powerRZ. @@ -112,9 +105,6 @@ apply powerRZ_NOR. simpl. lra. Qed. -Definition iszero {t} (x: ftype t) : bool := - match x with Binary.B754_zero _ _ _ => true | _ => false end. - Fixpoint fsum_l2r_rec [n: nat] (c : F) : F^n -> F := match n with | 0%N => fun _ => c @@ -135,16 +125,6 @@ Definition ytilded [k : nat] (c : F) (a b : F^k) (bk : F) := Definition ytildes [k : nat] (c : F) (a : F^k):= BSQRT (stilde c a a). - -Lemma BPLUS_B2R_zero (a : ftype t): - Binary.is_finite a -> - FT2R (BPLUS a (Zconst t 0)) = FT2R a. -Proof. -unfold BPLUS, BINOP, Zconst; intros; -destruct a; simpl; try discriminate; auto. -destruct s; simpl; auto. -Qed. - Lemma format_FT2R: forall (x: ftype t), is_true (@flocq_float.format (fprecp t) (femax t) (FT2R x)). Proof. move => x. 
diff --git a/accuracy_proofs/mv_mathcomp.v b/accuracy_proofs/mv_mathcomp.v index 5ac9242..829b689 100644 --- a/accuracy_proofs/mv_mathcomp.v +++ b/accuracy_proofs/mv_mathcomp.v @@ -1,7 +1,43 @@ -(* This file contains theorems connecting MathComp operations on - matrices and vectors to operations on lists. *) - -From LAProof.accuracy_proofs Require Import preamble common +(** * Matrix and Vector Operations: MathComp–List Correspondence + + This file establishes the formal connection between MathComp matrix/vector + operations and their list-based counterparts, with applications to + floating-point arithmetic via the VCFloat framework. + + ** Main Results + + - [Fmulmx_matrix_vector_mult]: the floating-point matrix–vector product + computed via [F.FMA_mulmx] on MathComp matrices equals the list-based + [matrix_vector_mult], bridging the MathComp and list-based worlds. + + ** Structure + + - _Utility definitions_ for norms ([normv], [normM], [sum_abs]) and + sequence extraction ([seq_of_rV]). + - _Tactics_ [ordify] and [case_splitP] for working with ordinal indices + and block-structured matrices. + - _Enumeration lemmas_ connecting [ord_enum], [index_enum], and list + operations ([nth_ord_enum'], [nth_index_enum], [index_ord_enum], etc.). + - _Norm lemmas_ including positivity ([normv_pos], [normM_pos]), + the submultiplicative inequality ([subMultNorm]), and the triangle + inequality ([normv_triang]). + - _Module F_ defining floating-point matrix/vector operations + ([F.dotprod], [F.mulmx], [F.FMA_mulmx], etc.) and proving structural + lemmas such as row/column block decompositions. + - _Conversion lemmas_ between MathComp matrices and list-of-lists + representations ([listlist_of_mx], [mx_of_listlist], [list_of_cV], + [cV_of_list]), including roundtrip identities and size invariants. 
+ + ** Dependencies + + This file relies on: + - [preamble], [common]: basic setup and shared definitions + - [dotprod_model], [sum_model]: relational models of dot product and summation + - [dot_acc]: dot product accuracy lemmas + - [float_acc_lems]: elementary floating-point accuracy lemmas +*) + +From LAProof.accuracy_proofs Require Import preamble common dotprod_model sum_model dot_acc float_acc_lems. From mathcomp.algebra_tactics Require Import ring. @@ -9,100 +45,186 @@ From mathcomp.algebra_tactics Require Import ring. Open Scope ring_scope. Open Scope order_scope. -Definition sum_abs {m n} (A: 'M[R]_(m,n)) i : R:= \sum_j (Rabs (A i j)). -Definition normv {m} (v: 'cV[R]_m) : R:= \big[maxr/0]_(i < m) Rabs (v i 0%Ri). -Definition normM {m n} (A: 'M[R]_(m,n)) : R:= \big[maxr/0]_i (sum_abs A i). -Definition seq_of_rV {T}[n] (x: 'rV[T]_n) := map (x ord0) (ord_enum n). +(** ** Norm and sequence definitions *) + +(** [sum_abs A i] is the L1 row norm of row << i >> of matrix << A >>. *) + +Definition sum_abs {m n} (A : 'M[R]_(m,n)) i : R := + \sum_j (Rabs (A i j)). + +(** [normv v] is the infinity norm of a column vector << v >>, i.e., + the maximum of the absolute values of its entries. *) + +Definition normv {m} (v : 'cV[R]_m) : R := + \big[maxr/0]_(i < m) Rabs (v i 0%Ri). + +(** [normM A] is the infinity matrix norm of << A >>, i.e., the maximum + over rows of the L1 row norms. *) + +Definition normM {m n} (A : 'M[R]_(m,n)) : R := + \big[maxr/0]_i (sum_abs A i). + +(** [seq_of_rV x] converts a MathComp row vector to a plain list. *) + +Definition seq_of_rV {T} [n] (x : 'rV[T]_n) := + map (x ord0) (ord_enum n). + +(** ** Tactics *) + +(** [ordify n i] replaces a variable << i >> of type [Z] or [nat] with a + corresponding ordinal << i : 'I_n >>, rewriting every occurrence of the old + << i >> as the coercion of the new one ([nat_of_ord i], or its [Z] analogue). *) -(** Given a variable [i] of type [Z] or [nat], replace it everywhere with a variable [i] of type ['I_n], - appropriately coerced. 
*) Ltac ordify n i := let Hi := fresh "H" i in - let Hj := fresh "H" i in - let j := fresh "i" in - match type of i with ?t => let t' := eval hnf in t in match t' with - | Z => assert (Hi: Datatypes.is_true (ssrnat.leq (S (Z.to_nat i)) n)) by lia; - set (j := @Ordinal n (Z.to_nat i) Hi); - assert (Hj : i = Z.of_nat (nat_of_ord j)) by (simpl; lia) - | nat => assert (Hi: Datatypes.is_true (ssrnat.leq (S i) n)) by lia; - set (j := @Ordinal n i Hi); - assert (Hj : i = nat_of_ord j) by (simpl; lia) - end end; - clearbody j; clear Hi; - subst i; - rename j into i. - + let Hj := fresh "H" i in + let j := fresh "i" in + match type of i with ?t => + let t' := eval hnf in t in + match t' with + | Z => + assert (Hi : Datatypes.is_true (ssrnat.leq (S (Z.to_nat i)) n)) by lia; + set (j := @Ordinal n (Z.to_nat i) Hi); + assert (Hj : i = Z.of_nat (nat_of_ord j)) by (simpl; lia) + | nat => + assert (Hi : Datatypes.is_true (ssrnat.leq (S i) n)) by lia; + set (j := @Ordinal n i Hi); + assert (Hj : i = nat_of_ord j) by (simpl; lia) + end + end; + clearbody j; clear Hi; + subst i; + rename j into i. + +(** [case_splitP j] destructs an ordinal << j : 'I_(a + b) >> into its left + ([lshift]) and right ([rshift]) components, rewriting << j >> in the goal + accordingly. *) + Ltac case_splitP j := - tryif clearbody j then fail "case_splitP requires a variable, but got a local definition" j - else tryif is_var j then idtac else fail "case_splitP requires a variable, but got" j; - match type of j with 'I_(addn ?a ?b) => - let i := fresh "j" in let H := fresh in - destruct (splitP j) as [i H | i H]; - [replace j with (@lshift a b i); [ | apply ord_inj; simpl; lia] - |replace j with (@rshift a b i); [ | apply ord_inj; simpl; lia]]; - clear j H; rename i into j - end. 
- -(** Example of how to use case_splitP *) -Local Remark mul_mx_row' [R : pzSemiRingType] m n p1 p2 - (A: 'M[R]_(m,n)) (Bl: 'M[R]_(n,p1)) (Br: 'M[R]_(n,p2)): + tryif clearbody j then + fail "case_splitP requires a variable, but got a local definition" j + else + tryif is_var j then idtac + else fail "case_splitP requires a variable, but got" j; + match type of j with 'I_(addn ?a ?b) => + let i := fresh "j" in + let H := fresh in + destruct (splitP j) as [i H | i H]; + [ replace j with (@lshift a b i); + [ | apply ord_inj; simpl; lia ] + | replace j with (@rshift a b i); + [ | apply ord_inj; simpl; lia ] ]; + clear j H; rename i into j + end. + +(** *** Example uses of [case_splitP] *) + +(** Proof of << A *m row_mx Bl Br = row_mx (A *m Bl) (A *m Br) >> + using [case_splitP]. *) + +Local Remark mul_mx_row' [R : pzSemiRingType] m n p1 p2 + (A : 'M[R]_(m,n)) + (Bl : 'M[R]_(n,p1)) + (Br : 'M[R]_(n,p2)) : A *m row_mx Bl Br = row_mx (A *m Bl) (A *m Br). Proof. -apply /matrixP => i j. -case_splitP j. -rewrite row_mxEl !mxE . apply eq_bigr. move => k _; rewrite row_mxEl//. -rewrite row_mxEr !mxE . apply eq_bigr. move => k _; rewrite row_mxEr//. + apply /matrixP => i j. + case_splitP j. + - rewrite row_mxEl !mxE. + apply eq_bigr => k _. + rewrite row_mxEl //. + - rewrite row_mxEr !mxE. + apply eq_bigr => k _. + rewrite row_mxEr //. Qed. -(** Example of how the mathcomp experts do this another way, from mathcomp.algebra.matrix *) -Local Remark mul_mx_row'' [R : pzSemiRingType] m n p1 p2 (A : 'M[R]_(m, n)) (Bl : 'M_(n, p1)) (Br : 'M_(n, p2)) : +(** Alternative proof of [mul_mx_row'] following the MathComp style + from [mathcomp.algebra.matrix]. *) + +Local Remark mul_mx_row'' [R : pzSemiRingType] m n p1 p2 + (A : 'M[R]_(m, n)) + (Bl : 'M_(n, p1)) + (Br : 'M_(n, p2)) : A *m row_mx Bl Br = row_mx (A *m Bl) (A *m Br). Proof. -apply/matrixP=> i k; rewrite !mxE. -by case defk: (split k) => /[!mxE]; under eq_bigr do rewrite mxE defk. + apply/matrixP => i k; rewrite !mxE. 
+ by case defk: (split k) => /[!mxE]; + under eq_bigr do rewrite mxE defk. Qed. -Lemma nth_List_nth: forall {A: Type} (d: A) (l: seq.seq A) (n: nat), +(** ** Enumeration and list lemmas *) + +(** [seq.nth] and [List.nth] agree for any list. *) + +Lemma nth_List_nth : forall {A : Type} (d : A) (l : seq.seq A) (n : nat), seq.nth d l n = List.nth n l d. Proof. - move => A d l. elim : l => [//= n | //= h t IH n]. + move => A d l. + elim : l => [n | h t IH n]. - by case : n. - - case: n. by []. move => n. by rewrite /= IH. + - by case: n => [// | n]; rewrite /= IH. Qed. -Lemma pred_lt: forall [n: nat], (0 < n -> n.-1 < n)%nat. +(** The predecessor of a positive natural number is strictly smaller. *) + +Lemma pred_lt : forall [n : nat], (0 < n -> n.-1 < n)%nat. Proof. - move => n Hn. by rewrite ltn_predL. + move => n Hn. + by rewrite ltn_predL. Qed. -Definition pred_ord [n: nat] (Hn: (0 < n)%nat) : 'I_n := Ordinal (pred_lt Hn). +(** The ordinal << n-1 : 'I_n >> for a positive << n >>. *) + +Definition pred_ord [n : nat] (Hn : (0 < n)%nat) : 'I_n := + Ordinal (pred_lt Hn). + +(** The finite enumeration of ['I_n] has size << n >>. *) -Lemma ordinal_enum_size: forall n: nat, +Lemma ordinal_enum_size : forall n : nat, size (Finite.enum (ordinal n)) = n. Proof. - move => n. have: size ([seq val i | i <- enum 'I_n]) = n. rewrite val_enum_ord. by apply: size_iota. - rewrite size_map. unfold enum. rewrite size_map //. + move => n. + have: size ([seq val i | i <- enum 'I_n]) = n. + rewrite val_enum_ord. by apply: size_iota. + rewrite size_map. + unfold enum. + rewrite size_map //. Qed. -Lemma size_ord_enum: forall n, size (ord_enum n) = n. +(** [ord_enum n] has size << n >>. *) + +Lemma size_ord_enum : forall n, size (ord_enum n) = n. Proof. - move => n. - have : size (ord_enum n) = size ([seq val i | i <- ord_enum n]) by rewrite size_map. - by rewrite val_ord_enum size_iota. + move => n. 
+ have Hsize : size (ord_enum n) = size ([seq val i | i <- ord_enum n]) + by rewrite size_map. + by rewrite Hsize val_ord_enum size_iota. Qed. -Lemma nth_index_enum: forall {n: nat} (x: 'I_n) y, +(** The << i >>-th element of [index_enum (ordinal n)] is << i >> itself. *) + +Lemma nth_index_enum : forall {n : nat} (x : 'I_n) y, seq.nth y (index_enum (ordinal n)) x = x. Proof. move => n x y. - have nth_ord := (@nth_ord_enum n y x). unfold enum in nth_ord. move: nth_ord. - rewrite (@nth_map _ y) //. by rewrite ordinal_enum_size. -Qed. + have nth_ord := (@nth_ord_enum n y x). + unfold enum in nth_ord. + move: nth_ord. + rewrite (@nth_map _ y) //. + by rewrite ordinal_enum_size. +Qed. + +(** The << i >>-th element of [ord_enum n] is << i >> itself. *) -Lemma nth_ord_enum': forall n (i: 'I_n) x, seq.nth x (ord_enum n) i = i. +Lemma nth_ord_enum' : forall n (i : 'I_n) x, seq.nth x (ord_enum n) i = i. Proof. - move => n i x. have Hv := val_ord_enum n. have Hmap := @nth_map 'I_n x nat x val i (ord_enum n). - move : Hmap. rewrite Hv size_ord_enum nth_iota =>[//=|//]. rewrite add0n. move => H. + move => n i x. + have Hv := val_ord_enum n. + have Hmap := @nth_map 'I_n x nat x val i (ord_enum n). + move : Hmap. + rewrite Hv size_ord_enum nth_iota =>[//=|//]. + rewrite add0n. + move => H. (*some annoying stuff about equality of ordinals vs nats*) have : nat_of_ord ( seq.nth x (ord_enum n) i) == nat_of_ord i. rewrite {2}H. by []. by []. @@ -110,10 +232,14 @@ Proof. by move => /eqP Heq. Qed. +(** [index_enum (ordinal n)] equals [ord_enum n]. *) -Lemma index_ord_enum: forall (n: nat), (index_enum (ordinal n)) = ord_enum n. +Lemma index_ord_enum : forall n : nat, + index_enum (ordinal n) = ord_enum n. Proof. - move => n. have: (0 <= n)%nat by []. rewrite leq_eqVlt => /orP[/eqP Hn0 | Hnpos]. + move => n. + have: (0 <= n)%nat by []. + rewrite leq_eqVlt => /orP[/eqP Hn0 | Hnpos]. - subst. rewrite /ord_enum /= /index_enum /=. apply size0nil. apply ordinal_enum_size. 
- apply (eq_from_nth (x0:=pred_ord Hnpos)). + rewrite ordinal_enum_size size_ord_enum //. @@ -122,785 +248,973 @@ Proof. rewrite nth_index_enum nth_ord_enum' //. Qed. +(** [seq_of_rV x] has size << n >>. *) -Lemma size_seq_of_rV : forall {T} [n] x, size (@seq_of_rV T n x) = n. +Lemma size_seq_of_rV : forall {T} [n] x, + size (@seq_of_rV T n x) = n. Proof. -intros. -rewrite /seq_of_rV size_map size_ord_enum //. + intros. + rewrite /seq_of_rV size_map size_ord_enum //. Qed. -Lemma nth_seq_of_rV: forall {T}[n](d: T)(x: 'rV[T]_n) (i: 'I_n), nth d (seq_of_rV x) i = x ord0 i. +(** The << i >>-th element of [seq_of_rV x] is [x ord0 i]. *) + +Lemma nth_seq_of_rV : forall {T} [n] (d : T) (x : 'rV[T]_n) (i : 'I_n), + nth d (seq_of_rV x) i = x ord0 i. Proof. -intros. -pose proof (ltn_ord i). -rewrite /seq_of_rV (nth_map i d) ?nth_ord_enum' // size_ord_enum //. -Qed. - -(* generally useful lemmmas for max operator *) -Lemma maxrC : @commutative R R maxr. - Proof. rewrite /commutative => x y. - rewrite -!RmaxE. apply Rmax_comm. Qed. - -Lemma maxrA : @associative R maxr. - Proof. rewrite /associative => x y z. - rewrite -!RmaxE. apply Rmax_assoc. Qed. - -Lemma big_mul {n:nat} (F : ordinal n -> R) op a: -(forall i b, op (F i) b * a = op (F i * a) (b * a)) -> -R0 <= a -> \big[op/0]_(i0 < n) (F i0) * a = \big[op/0]_(i0 < n) (F i0 * a). -Proof. -destruct n. -- -intros. - rewrite (unlock (bigop_unlock)). - unfold reducebig, comp, applybig. simpl. rewrite index_ord_enum. simpl. - apply Rmult_0_l. -- -revert F a. elim: n => /= // [F a Hc Ha| n0 IH F a Hc Ha]. -rewrite !big_ord_recl !big_ord0/= //. -rewrite (Hc ord0 0) mul0r //. -rewrite big_ord_recl => /= //. -etransitivity. -2 : rewrite big_ord_recl => /= //. -rewrite Hc. -rewrite IH => //. -Qed. - -Lemma big_max_mul {n:nat} (F : ordinal n -> R) a: -R0 <= a -> \big[maxr/0]_(i0 < n) (F i0) * a = \big[maxr/0]_(i0 < n) (F i0 * a). -Proof. -move => Ha. -apply big_mul => //. -move => i b. 
-change (maxr (F i) b * a = maxr (F i * a) (b * a))%Ri. -rewrite maxr_pMl //. -Qed. - -(* Lemmas about norm defs *) - - -Lemma normv_pos {m} (v: 'cV[R]_m) : R0 <= normv v. -Proof. -rewrite /normr/normv. -elim/big_ind: _ => //[x y Hx Hy| i _]. -rewrite -RmaxE. eapply le_trans; [apply Hy|]. -apply /RleP; apply Rmax_r. -apply /RleP; apply Rabs_pos. -Qed. - -Lemma normM_pos [m n] (A: 'M[R]_(m,n)) : R0 <= normM A. -Proof. -rewrite /normr/normM . -elim/big_ind: _ => //[x y Hx Hy| i _]. -rewrite -RmaxE/Rmax. destruct Rle_dec => //. -rewrite /sum_abs. -elim/big_ind: _ => //[x y Hx Hy| j _]. -apply addr_ge0 => //. -apply /RleP; apply Rabs_pos. -Qed. - -Lemma Rabs_sum (n:nat) : forall (F : ordinal n -> R), -Rabs (\sum_j F j) <= \sum_j Rabs (F j). -Proof. -destruct n. -- intros. - rewrite (unlock (bigop_unlock)). - unfold reducebig, comp, applybig. simpl. - rewrite index_ord_enum. simpl. rewrite Rabs_R0. apply /RleP. reflexivity. -- -elim : n => [F | n IH F]. -rewrite !big_ord_recr!big_ord0/=. - eapply le_trans ; [apply Rleb_norm_add| rewrite Rabs_R0; apply lerD => /= //]. -eapply le_trans. -1, 2: rewrite big_ord_recr /=. apply Rleb_norm_add. -apply lerD => /= //. -Qed. - - -Lemma subMultNorm m n (A: 'M[R]_(m,n)) (u : 'cV_n) : - normv ( A *m u ) <= normM A * normv u. -Proof. -destruct m. -- -rewrite /normr /normM /normv. - rewrite (unlock (bigop_unlock)). - unfold reducebig, comp, applybig. simpl. - rewrite index_ord_enum. simpl. - set xx := foldr _ _ _. clearbody xx. - apply /RleP. change (0 <= 0*xx)%Re. rewrite Rmult_0_l. reflexivity. -- -remember (normv u) as umax. -rewrite /normr /normM /normv /sum_abs /= big_max_mul. -apply: le_bigmax2 => i0 _. -rewrite mxE => /=. -eapply le_trans. -apply Rabs_sum . -elim/big_rec2: _ => // [ |i1 y1 y2 _ Hy]. -apply mulr_ge0 => //. -rewrite Hequmax; apply normv_pos. -rewrite mulrDl. -apply lerD => //. -rewrite Rabs_mult. -apply ler_pM => //. -1,2: apply /RleP; apply Rabs_pos. -rewrite Hequmax/normv. -by apply /le_bigmax. -rewrite Hequmax. 
- apply normv_pos. -Qed. - -Lemma normv_triang m (u v: 'cV_m) : - normv ( u + v ) <= normv u + normv v. -Proof. -rewrite {1}/normv. -apply: bigmax_le => [ | i _]. -apply addr_ge0; apply normv_pos. -rewrite mxE => /=. -eapply le_trans. -apply Rleb_norm_add. apply lerD; -apply: le_bigmax => [ | i _]. -Qed. - - -Local Definition crazy (T: Type): 'I_0 -> T. -intro. destruct X. lia. -Defined. - -Lemma exists_mx: forall {T} [m n] (F: 'I_m -> 'I_n -> T -> Prop), - (forall i j, exists x, F i j x) -> - exists A: 'M[T]_(m,n), forall i j, F i j (A i j). -Proof. -intros. -induction m. -- -exists (\matrix_(i,j) crazy T i). intros. destruct i. lia. -- -change (m.+1) with (1+m)%nat. -destruct (IHm (fun i j => F (rshift 1 i) j)). -intros. apply H. -assert (exists A1: 'M[T]_(1,n), forall j, F ord0 j (A1 ord0 j)). { - clear IHm x H0. - induction n. exists (\matrix_(i,j) crazy T j). intros. destruct j; lia. - destruct (IHn (fun i j => F i (rshift 1 j))). intros. apply H. - destruct (H ord0 ord0). - exists (row_mx (@const_mx _ 1 1 x0) x). intros. - change (n.+1) with (1 + n)%nat in j |-*. - destruct (splitP j). - replace j with (@lshift 1 n j0). + pose proof (ltn_ord i). + rewrite /seq_of_rV (nth_map i d) ?nth_ord_enum' // size_ord_enum //. +Qed. + +(** ** Lemmas about the [maxr] operator *) + +(** [maxr] is commutative. *) + +Lemma maxrC : @commutative R R maxr. +Proof. + rewrite /commutative => x y. + rewrite -!RmaxE. + apply Rmax_comm. +Qed. + +(** [maxr] is associative. *) + +Lemma maxrA : @associative R maxr. +Proof. + rewrite /associative => x y z. + rewrite -!RmaxE. + apply Rmax_assoc. +Qed. + +(** Scalar multiplication distributes over a big << op >>-fold when << op >> + is "linear" in the sense expressed by the hypothesis << Hc >>. *) + +Lemma big_mul {n : nat} (F : ordinal n -> R) op a + (Hc: forall i b, op (F i) b * a = op (F i * a) (b * a)): + R0 <= a -> + \big[op/0]_(i0 < n) (F i0) * a = \big[op/0]_(i0 < n) (F i0 * a). +Proof. + destruct n. + - intros Ha. 
+ rewrite (unlock (bigop_unlock)). + unfold reducebig, comp, applybig; simpl. + rewrite index_ord_enum; simpl. + apply Rmult_0_l. + - revert Hc. revert F a. + elim: n => [F a Hc Ha | n0 IH F a Hc Ha]. + + rewrite !big_ord_recl !big_ord0 /= //. + rewrite (Hc ord0 0) mul0r //. + + rewrite big_ord_recl => /=. + etransitivity. + 2: rewrite big_ord_recl => /= //. + rewrite Hc IH //. +Qed. + +(** Scalar multiplication distributes over a big [maxr]-fold for + nonnegative scalars. *) + +Lemma big_max_mul {n : nat} (F : ordinal n -> R) a : + R0 <= a -> + \big[maxr/0]_(i0 < n) (F i0) * a = \big[maxr/0]_(i0 < n) (F i0 * a). +Proof. + move => Ha. + apply big_mul => //. + move => i b. + change (maxr (F i) b * a = maxr (F i * a) (b * a))%Ri. + rewrite maxr_pMl //. +Qed. + +(** ** Norm lemmas *) + +(** [normv v] is nonnegative for any vector << v >>. *) + +Lemma normv_pos {m} (v : 'cV[R]_m) : R0 <= normv v. +Proof. + rewrite /normr /normv. + elim/big_ind: _ => // [x y Hx Hy | i _]. + - rewrite -RmaxE. + eapply le_trans; [apply Hy |]. + apply /RleP; apply Rmax_r. + - apply /RleP; apply Rabs_pos. +Qed. + +(** [normM A] is nonnegative for any matrix << A >>. *) + +Lemma normM_pos [m n] (A : 'M[R]_(m,n)) : R0 <= normM A. +Proof. + rewrite /normr /normM. + elim/big_ind: _ => // [x y Hx Hy | i _]. + - rewrite -RmaxE /Rmax. + destruct Rle_dec => //. + - rewrite /sum_abs. + elim/big_ind: _ => // [x y Hx Hy | j _]. + + apply addr_ge0 => //. + + apply /RleP; apply Rabs_pos. +Qed. + +(** Triangle inequality for absolute values under a finite sum. *) + +Lemma Rabs_sum (n : nat) : forall (F : ordinal n -> R), + Rabs (\sum_j F j) <= \sum_j Rabs (F j). +Proof. + destruct n. + - intros F. + rewrite (unlock (bigop_unlock)). + unfold reducebig, comp, applybig; simpl. + rewrite index_ord_enum; simpl. + rewrite Rabs_R0. + apply /RleP; reflexivity. + - elim : n => [F | n IH F]. + + rewrite !big_ord_recr !big_ord0 /=. + eapply le_trans; [apply Rleb_norm_add |]. + rewrite Rabs_R0. 
+ apply lerD => /= //. + + eapply le_trans. + 1, 2: rewrite big_ord_recr /=. + apply Rleb_norm_add. + apply lerD => /= //. +Qed. + +(** Submultiplicativity: [‖A u‖_∞ ≤ ‖A‖_∞ · ‖u‖_∞]. *) + +Lemma subMultNorm m n (A : 'M[R]_(m,n)) (u : 'cV_n) : + normv (A *m u) <= normM A * normv u. +Proof. + destruct m. + - rewrite /normr /normM /normv. + rewrite (unlock (bigop_unlock)). + unfold reducebig, comp, applybig; simpl. + rewrite index_ord_enum; simpl. + set xx := foldr _ _ _; clearbody xx. + apply /RleP. + change (0 <= 0 * xx)%Re. + rewrite Rmult_0_l; reflexivity. + - remember (normv u) as umax. + rewrite /normr /normM /normv /sum_abs /= big_max_mul. + apply: le_bigmax2 => i0 _. + rewrite mxE => /=. + eapply le_trans; [apply Rabs_sum |]. + elim/big_rec2: _ => // [| i1 y1 y2 _ Hy]. + + apply mulr_ge0 => //. + rewrite Hequmax; apply normv_pos. + + rewrite mulrDl. + apply lerD => //. + rewrite Rabs_mult. + apply ler_pM => //. + 1, 2: apply /RleP; apply Rabs_pos. + rewrite Hequmax /normv. + by apply /le_bigmax. + + rewrite Hequmax. + apply normv_pos. +Qed. + +(** Triangle inequality for [normv]: [‖u + v‖_∞ ≤ ‖u‖_∞ + ‖v‖_∞]. *) + +Lemma normv_triang m (u v : 'cV_m) : + normv (u + v) <= normv u + normv v. +Proof. + rewrite {1}/normv. + apply: bigmax_le => [| i _]. + - apply addr_ge0; apply normv_pos. + - rewrite mxE => /=. + eapply le_trans; [apply Rleb_norm_add |]. + apply lerD; + apply: le_bigmax => [| i _]. +Qed. + +(** ** Auxiliary definitions *) + +(** An eliminator for the empty type ['I_0]. *) + +Local Definition crazy (T : Type) : 'I_0 -> T. +Proof. intro H. destruct H. lia. Defined. + +(** If witnesses exist for every entry, then a matrix with those entries exists. *) + +Lemma exists_mx : forall {T} [m n] (F : 'I_m -> 'I_n -> T -> Prop), + (forall i j, exists x, F i j x) -> + exists A : 'M[T]_(m,n), forall i j, F i j (A i j). +Proof. + intros. + induction m. + - + exists (\matrix_(i,j) crazy T i). intros. destruct i. lia. + - + change (m.+1) with (1+m)%nat. 
+ destruct (IHm (fun i j => F (rshift 1 i) j)). + intros. apply H. + assert (exists A1: 'M[T]_(1,n), forall j, F ord0 j (A1 ord0 j)). { + clear IHm x H0. + induction n. exists (\matrix_(i,j) crazy T j). intros. destruct j; lia. + destruct (IHn (fun i j => F i (rshift 1 j))). intros. apply H. + destruct (H ord0 ord0). + exists (row_mx (@const_mx _ 1 1 x0) x). + intros. + change (n.+1) with (1 + n)%nat in j |-*. + destruct (splitP j). + replace j with (@lshift 1 n j0). + 2: apply ord_inj; simpl; auto. + rewrite row_mxEl. rewrite mxE. + replace (lshift n j0) with (@ord0 n); auto. + rewrite ord1; apply ord_inj; simpl; auto. + replace j with (@rshift 1 n k). + 2: apply ord_inj; simpl; lia. + rewrite row_mxEr. + apply H0. + } + destruct H1 as [A1 ?]. + change (m.+1) with (1 + m)%nat. + exists (col_mx A1 x). + intros. + destruct (splitP i) as [i0|i0]. + + + replace i with (@lshift 1 m i0). 2: apply ord_inj; simpl; auto. - rewrite row_mxEl. rewrite mxE. - replace (lshift n j0) with (@ord0 n); auto. - rewrite ord1; apply ord_inj; simpl; auto. - replace j with (@rshift 1 n k). - 2: apply ord_inj; simpl; lia. - rewrite row_mxEr. - apply H0. -} -destruct H1 as [A1 ?]. -change (m.+1) with (1 + m)%nat. -exists (col_mx A1 x). -intros. -destruct (splitP i) as [i0|i0]. -+ -replace i with (@lshift 1 m i0). - 2: apply ord_inj; simpl; auto. -rewrite col_mxEu. -replace (lshift m i0) with (@ord0 m). -2: rewrite ord1; apply ord_inj; simpl; auto. -rewrite ord1. -apply H1. -+ -replace i with (@rshift 1 m i0). -2: apply ord_inj; simpl; auto. -rewrite col_mxEd. -apply H0. -Qed. - -Lemma rev_ord_enum: forall n, rev (ord_enum n) = map (@rev_ord n) (ord_enum n). -Proof. -intros. -assert (map (@nat_of_ord n) (rev (ord_enum n)) = map (@nat_of_ord n) (map (@rev_ord n) (ord_enum n))). -2:{ -set a := rev (ord_enum n) in H|-*; clearbody a. -set b := map (@rev_ord _) _ in H|-*; clearbody b. -revert b H; induction a; destruct b; intros; try discriminate; simpl; auto. -inversion H; clear H; subst. 
-f_equal; auto. -apply ord_inj; auto. -} -rewrite -map_comp map_rev val_ord_enum. -transitivity (map (fun y => subn n (S y)) (map (@nat_of_ord n) (ord_enum n))). -2: rewrite -map_comp /comp //. -unfold ord_enum. -rewrite pmap_filter. -2: intro; simpl; unfold insub; destruct idP; simpl in *; auto. -transitivity (map (fun y => subn n (S y)) (iota 0 n)). -2:{ -set u := O. - f_equal. symmetry. apply /all_filterP. -replace (fun x: nat => isSome (insub x)) with (fun x => x leq (S x) (addn (S n) u)) with (fun x : nat => leq (S x) (addn n (S u))); auto. -apply FunctionalExtensionality.functional_extensionality; intro x; lia. -} -apply nth_ext with (d:=O) (d':=O); change @length with @size. -rewrite size_rev size_map //. -intros. -rewrite size_rev size_iota in H. -rewrite -!nth_List_nth. -rewrite nth_rev. -2: rewrite size_iota; lia. -rewrite size_iota. -rewrite nth_iota. -2: lia. -rewrite (nth_map O). -2: rewrite size_iota; lia. -rewrite nth_iota; try lia. -Qed. - -Lemma nth_ord_enum_lemma: - forall [T] (d: T) (u: seq T), - u = map (nth d u \o @nat_of_ord (size u)) (ord_enum (size u)). -Proof. -intros. -rewrite map_comp val_ord_enum map_nth_iota0 // take_size //. -Qed. - -Lemma sumR_sum: forall (x: seq R), sumR x = \sum_(i in 'I_(size x)) nth R0 x (nat_of_ord i). -Proof. -intros. -rewrite /sumR (unlock bigop_unlock) - /reducebig /comp /applybig /= index_ord_enum. - rewrite {1}(nth_ord_enum_lemma R0 x). - rewrite foldr_map //. -Qed. - -Module F. (* Floating-point math-comp matrix and vector operations *) - -Section WithNAN. -Context {NAN: FPCore.Nans} {t : type}. - -Definition sum [n: nat] (x: 'I_n -> ftype t) : ftype t := - \big[BPLUS / neg_zero]_i x (rev_ord i). - -Definition dotprod [n: nat] (x: 'rV[ftype t]_n) (y: 'cV[ftype t]_n) : ftype t := - \big[BPLUS / pos_zero]_i (BMULT (x ord0 (rev_ord i)) (y (rev_ord i) ord0)). - -Definition FMA_dotprod [n: nat] (x: 'rV[ftype t]_n) (y: 'cV[ftype t]_n) : ftype t := - fma_dotprod (seq_of_rV x) (seq_of_rV y^T). 
- -Definition mulmx [m n p] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) := - \matrix_(i,k) dotprod (row i A) (col k B). - -Definition FMA_mulmx [m n p] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) := - \matrix_(i,k) FMA_dotprod (row i A) (col k B). - -Definition scalemx [m n] (a: ftype t) (M: 'M[ftype t]_(m,n)) := + rewrite col_mxEu. + replace (lshift m i0) with (@ord0 m). + 2: rewrite ord1; apply ord_inj; simpl; auto. + rewrite ord1. + apply H1. + + + replace i with (@rshift 1 m i0). + 2: apply ord_inj; simpl; auto. + rewrite col_mxEd. + apply H0. +Qed. + +(** Reversing [ord_enum n] yields the list of [rev_ord] images. *) + +Lemma rev_ord_enum : forall n, + rev (ord_enum n) = map (@rev_ord n) (ord_enum n). +Proof. + intros n. + assert (Hnat : map (@nat_of_ord n) (rev (ord_enum n)) = + map (@nat_of_ord n) (map (@rev_ord n) (ord_enum n))). { + rewrite -map_comp map_rev val_ord_enum. + transitivity (map (fun y => subn n (S y)) (map (@nat_of_ord n) (ord_enum n))). + 2: { rewrite -map_comp /comp //. } + unfold ord_enum. + rewrite pmap_filter. + 2: { intro; simpl; unfold insub; destruct idP; simpl in *; auto. } + transitivity (map (fun y => subn n (S y)) (iota 0 n)). + 2: { + set u := O. + f_equal. symmetry. apply /all_filterP. + replace (fun x : nat => isSome (insub x)) + with (fun x => x < n + u)%N. + 2: { + subst u; apply FunctionalExtensionality.functional_extensionality. + intro x. rewrite addn0. + unfold insub; destruct idP; auto. + } + clearbody u. + revert u; induction n; simpl; intros; auto. + apply /andP; split; [lia |]. + specialize (IHn (S u)). + replace (fun x : nat => leq (S x) (addn (S n) u)) + with (fun x : nat => leq (S x) (addn n (S u))); auto. + apply FunctionalExtensionality.functional_extensionality. + intro x; lia. + } + apply nth_ext with (d := O) (d' := O); + change @length with @size. + - rewrite size_rev size_map //. + - intros i Hi. + rewrite size_rev size_iota in Hi. + rewrite -!nth_List_nth nth_rev. + 2: rewrite size_iota; lia. 
+ rewrite size_iota nth_iota. + 2: lia. + rewrite (nth_map O). + 2: rewrite size_iota; lia. + rewrite nth_iota; try lia. + } + set a := rev (ord_enum n) in Hnat |-*; clearbody a. + set b := map (@rev_ord _) _ in Hnat |-*; clearbody b. + revert b Hnat. + induction a; destruct b; intros Hnat; try discriminate; simpl; auto. + inversion Hnat as [Heq]; clear Hnat; subst. + f_equal; auto. + apply ord_inj; auto. +Qed. + +(** Any list [u] is the image of [nth d u] composed with the ordinal + enumeration of [size u]. *) +Lemma nth_ord_enum_lemma : forall [T] (d : T) (u : seq T), + u = map (nth d u \o @nat_of_ord (size u)) (ord_enum (size u)). +Proof. + intros. + rewrite map_comp val_ord_enum map_nth_iota0 // take_size //. +Qed. + +(** [sumR x] equals the big sum over ordinals of [nth R0 x i]. *) + +Lemma sumR_sum : forall (x : seq R), + sumR x = \sum_(i in 'I_(size x)) nth R0 x (nat_of_ord i). +Proof. + intros x. + rewrite /sumR (unlock bigop_unlock) + /reducebig /comp /applybig /= index_ord_enum. + rewrite {1}(nth_ord_enum_lemma R0 x). + rewrite foldr_map //. +Qed. + +(** ** Module F: Floating-point MathComp matrix/vector operations *) + +Module F. + +(** This module defines floating-point analogues of the standard matrix + and vector operations using MathComp's matrix type. The operations + are parameterized over a floating-point type << t >> and a NaN payload + [NAN]. *) + +Section WithNAN. +Context {NAN : FPCore.Nans} {t : type}. + +(** [sum x] computes the floating-point sum of the values of [x] over + all ordinal indices, accumulating in reverse order. *) +Definition sum [n : nat] (x : 'I_n -> ftype t) : ftype t := + \big[BPLUS / neg_zero]_i x (rev_ord i). + +(** [dotprod x y] computes the floating-point dot product of row vector + [x] and column vector [y] using pairwise [BMULT] and [BPLUS]. *) +Definition dotprod [n : nat] (x : 'rV[ftype t]_n) (y : 'cV[ftype t]_n) + : ftype t := + \big[BPLUS / pos_zero]_i (BMULT (x ord0 (rev_ord i)) (y (rev_ord i) ord0)). 
+ +(** [FMA_dotprod x y] computes the dot product of [x] and [y] using + fused multiply-add ([fma_dotprod]) on their list representations. *) +Definition FMA_dotprod [n : nat] (x : 'rV[ftype t]_n) (y : 'cV[ftype t]_n) + : ftype t := + fma_dotprod (seq_of_rV x) (seq_of_rV y^T). + +(** [mulmx A B] is the floating-point matrix product, with each entry + computed via [dotprod]. *) +Definition mulmx [m n p] (A : 'M[ftype t]_(m,n)) (B : 'M[ftype t]_(n,p)) := + \matrix_(i,k) dotprod (row i A) (col k B). + +(** [FMA_mulmx A B] is the floating-point matrix product using + [FMA_dotprod] for each entry. *) +Definition FMA_mulmx [m n p] (A : 'M[ftype t]_(m,n)) (B : 'M[ftype t]_(n,p)) := + \matrix_(i,k) FMA_dotprod (row i A) (col k B). + +(** [scalemx a M] scales every entry of << M >> by << a >> using [BMULT]. *) + +Definition scalemx [m n] (a : ftype t) (M : 'M[ftype t]_(m,n)) := map_mx (BMULT a) M. -Definition addmx [m n] (A B: 'M[ftype t]_(m,n)) : 'M[ftype t]_(m,n) := +(** [addmx A B] adds two matrices entry-wise using [BPLUS]. *) + +Definition addmx [m n] (A B : 'M[ftype t]_(m,n)) : 'M[ftype t]_(m,n) := \matrix_(i,j) BPLUS (A i j) (B i j). -Lemma mulmx_row: - forall m n p1 p2 (A: 'M[ftype t]_(m,n)) (Bl: 'M_(n,p1)) (Br: 'M_(n,p2)), +(** [mulmx] distributes over right block-row matrices. *) + +Lemma mulmx_row : + forall m n p1 p2 + (A : 'M[ftype t]_(m,n)) + (Bl : 'M_(n,p1)) + (Br : 'M_(n,p2)), mulmx A (row_mx Bl Br) = row_mx (mulmx A Bl) (mulmx A Br). Proof. -intros. -apply /matrixP => i j. -case_splitP j. - rewrite row_mxEl !mxE -col_lsubmx row_mxKl //. - rewrite row_mxEr !mxE -col_rsubmx row_mxKr //. + intros. + apply /matrixP => i j. + case_splitP j. + - rewrite row_mxEl !mxE -col_lsubmx row_mxKl //. + - rewrite row_mxEr !mxE -col_rsubmx row_mxKr //. Qed. -Lemma FMA_mulmx_row: - forall m n p1 p2 (A: 'M[ftype t]_(m,n)) (Bl: 'M_(n,p1)) (Br: 'M_(n,p2)), +(** [FMA_mulmx] distributes over right block-row matrices. 
*) + +Lemma FMA_mulmx_row : + forall m n p1 p2 + (A : 'M[ftype t]_(m,n)) + (Bl : 'M_(n,p1)) + (Br : 'M_(n,p2)), FMA_mulmx A (row_mx Bl Br) = row_mx (FMA_mulmx A Bl) (FMA_mulmx A Br). Proof. -intros. -apply /matrixP => i j. -case_splitP j. - rewrite row_mxEl !mxE -col_lsubmx row_mxKl //. - rewrite row_mxEr !mxE -col_rsubmx row_mxKr //. + intros. + apply /matrixP => i j. + case_splitP j. + - rewrite row_mxEl !mxE -col_lsubmx row_mxKl //. + - rewrite row_mxEr !mxE -col_rsubmx row_mxKr //. Qed. -Lemma mulmx_col: - forall m1 m2 n p (Au: 'M[ftype t]_(m1,n)) (Ad: 'M[ftype t]_(m2,n)) (B: 'M_(n,p)), +(** [mulmx] distributes over left block-column matrices. *) + +Lemma mulmx_col : + forall m1 m2 n p + (Au : 'M[ftype t]_(m1,n)) + (Ad : 'M[ftype t]_(m2,n)) + (B : 'M_(n,p)), mulmx (col_mx Au Ad) B = col_mx (mulmx Au B) (mulmx Ad B). Proof. -intros. -apply /matrixP => i j. -case_splitP i. - rewrite col_mxEu !mxE -row_usubmx col_mxKu //. - rewrite col_mxEd !mxE -row_dsubmx col_mxKd //. + intros. + apply /matrixP => i j. + case_splitP i. + - rewrite col_mxEu !mxE -row_usubmx col_mxKu //. + - rewrite col_mxEd !mxE -row_dsubmx col_mxKd //. Qed. -Lemma FMA_mulmx_col: - forall m1 m2 n p (Au: 'M[ftype t]_(m1,n)) (Ad: 'M[ftype t]_(m2,n)) (B: 'M_(n,p)), +(** [FMA_mulmx] distributes over left block-column matrices. *) + +Lemma FMA_mulmx_col : + forall m1 m2 n p + (Au : 'M[ftype t]_(m1,n)) + (Ad : 'M[ftype t]_(m2,n)) + (B : 'M_(n,p)), FMA_mulmx (col_mx Au Ad) B = col_mx (FMA_mulmx Au B) (FMA_mulmx Ad B). Proof. -intros. -apply /matrixP => i j. -case_splitP i. - rewrite col_mxEu !mxE -row_usubmx col_mxKu //. - rewrite col_mxEd !mxE -row_dsubmx col_mxKd //. + intros. + apply /matrixP => i j. + case_splitP i. + - rewrite col_mxEu !mxE -row_usubmx col_mxKu //. + - rewrite col_mxEd !mxE -row_dsubmx col_mxKd //. Qed. -Lemma sum_sumF: forall [n] (x: 'I_n -> ftype t), sum x = sumF (map x (ord_enum n)). +(** [sum x] equals the list-based [sumF] applied to the image of [x] + over [ord_enum n]. 
*) +Lemma sum_sumF : forall [n] (x : 'I_n -> ftype t), + sum x = sumF (map x (ord_enum n)). Proof. - intros. - rewrite /sum /sumF (unlock bigop_unlock) /reducebig /comp /applybig - -(revK (map x _)) foldl_rev -map_rev rev_ord_enum -map_comp foldr_map index_ord_enum //. + intros. + rewrite /sum /sumF (unlock bigop_unlock) /reducebig /comp /applybig + -(revK (map x _)) foldl_rev -map_rev rev_ord_enum + -map_comp foldr_map index_ord_enum //. Qed. -Lemma dotprod_dotprodF: - forall [n] (x: 'rV[ftype t]_n) (y: 'cV[ftype t]_n), +(** [dotprod x y] equals the list-based [dotprodF] applied to the + list representations of [x] and [y^T]. *) +Lemma dotprod_dotprodF : + forall [n] (x : 'rV[ftype t]_n) (y : 'cV[ftype t]_n), dotprod x y = dotprodF (seq_of_rV x) (seq_of_rV (trmx y)). Proof. -intros. - rewrite /dotprod /seq_of_rV /dotprodF /dotprod_model.dotprod !ord1. - rewrite (unlock bigop_unlock). - unfold reducebig, comp, applybig. - rewrite -(revK (map (uncurry _) _)). - rewrite foldl_rev. - simpl. - rewrite index_ord_enum. - rewrite zip_map -map_comp. - rewrite -map_rev rev_ord_enum -map_comp. - rewrite foldr_map. - f_equal. - simpl. - apply FunctionalExtensionality.functional_extensionality; intro i. - apply FunctionalExtensionality.functional_extensionality; intro z. - rewrite !mxE. reflexivity. + intros. + rewrite /dotprod /seq_of_rV /dotprodF /dotprod_model.dotprod !ord1. + rewrite (unlock bigop_unlock). + unfold reducebig, comp, applybig. + rewrite -(revK (map (uncurry _) _)). + rewrite foldl_rev; simpl. + rewrite index_ord_enum zip_map -map_comp. + rewrite -map_rev rev_ord_enum -map_comp foldr_map. + f_equal; simpl. + apply FunctionalExtensionality.functional_extensionality; intro i. + apply FunctionalExtensionality.functional_extensionality; intro z. + rewrite !mxE; reflexivity. Qed. -Lemma mulmx_dotprodF: - forall [n] (A: 'M[ftype t]_(1,n)) (B: 'M[ftype t]_(n,1)), - mulmx A B = const_mx (dotprodF (seq_of_rV A) (seq_of_rV (trmx B))). 
+(** For [1 × n] and [n × 1] matrices, [mulmx A B] equals the constant + matrix whose sole entry is [dotprodF (seq_of_rV A) (seq_of_rV B^T)]. *) +Lemma mulmx_dotprodF : + forall [n] (A : 'M[ftype t]_(1,n)) (B : 'M[ftype t]_(n,1)), + mulmx A B = const_mx (dotprodF (seq_of_rV A) (seq_of_rV (trmx B))). Proof. -intros. - unfold mulmx. apply /matrixP. move => i k. rewrite !mxE row_id col_id. - apply dotprod_dotprodF. + intros. + unfold mulmx. + apply /matrixP => i k. + rewrite !mxE row_id col_id. + apply dotprod_dotprodF. Qed. -Lemma FMA_mulmx_fma_dotprod: - forall [n] (A: 'M[ftype t]_(1,n)) (B: 'M[ftype t]_(n,1)), - FMA_mulmx A B = const_mx (fma_dotprod (seq_of_rV A) (seq_of_rV (trmx B))). +(** For [1 × n] and [n × 1] matrices, [FMA_mulmx A B] equals the + constant matrix whose sole entry is + [fma_dotprod (seq_of_rV A) (seq_of_rV B^T)]. *) +Lemma FMA_mulmx_fma_dotprod : + forall [n] (A : 'M[ftype t]_(1,n)) (B : 'M[ftype t]_(n,1)), + FMA_mulmx A B = const_mx (fma_dotprod (seq_of_rV A) (seq_of_rV (trmx B))). Proof. -intros. - unfold mulmx. apply /matrixP. move => i k. rewrite !mxE row_id col_id //. + intros. + unfold mulmx. + apply /matrixP => i k. + rewrite !mxE row_id col_id //. Qed. -Definition finitemx [m n] (A: 'M[ftype t]_(m,n)) : Prop := - (forall i j, Binary.is_finite (A i j)). +(** [finitemx A] asserts that every entry of [A] is a finite + floating-point number. *) +Definition finitemx [m n] (A : 'M[ftype t]_(m,n)) : Prop := + forall i j, Binary.is_finite (A i j). + +(** If [addmx A B] is finite entry-wise, then both [A] and [B] are. *) -Lemma finitemx_addmx_e: forall [m n] (A B: 'M[ftype t]_(m,n)), +Lemma finitemx_addmx_e : forall [m n] (A B : 'M[ftype t]_(m,n)), finitemx (addmx A B) -> finitemx A /\ finitemx B. Proof. -rewrite /addmx /finitemx => m n A B Hfin. -split => i j; specialize (Hfin i j); rewrite mxE in Hfin; apply BPLUS_finite_e in Hfin; apply Hfin. + rewrite /addmx /finitemx => m n A B Hfin. 
+ split => i j; + specialize (Hfin i j); + rewrite mxE in Hfin; + apply BPLUS_finite_e in Hfin; + apply Hfin. Qed. -Lemma finitemx_scalemx_e: forall [m n] (c: ftype t) (A: 'M[ftype t]_(m,n)), +(** If [scalemx c A] is finite entry-wise, then [A] is. *) + +Lemma finitemx_scalemx_e : forall [m n] (c : ftype t) (A : 'M[ftype t]_(m,n)), finitemx (scalemx c A) -> finitemx A. Proof. -rewrite /scalemx /finitemx => m n c A Hfin i j. -specialize (Hfin i j). rewrite mxE in Hfin. apply BMULT_finite_e in Hfin; apply Hfin. + rewrite /scalemx /finitemx => m n c A Hfin i j. + specialize (Hfin i j). + rewrite mxE in Hfin. + apply BMULT_finite_e in Hfin. + apply Hfin. Qed. End WithNAN. End F. -Definition listlist_of_mx {T} [m n: nat] (A: 'M[T]_(m,n)) : list (list T) := - map (fun i: 'I_m => map (A i) (ord_enum n)) (ord_enum m). - -Definition list_of_cV {T} [n: nat] (V: 'cV[T]_n) : list T := - map (fun i => V i ord0) (ord_enum n). - -Definition mx_of_listlist {T} {d: T} (rows cols: nat) (mval: list (list T)) : 'M[T]_(rows, cols) := - \matrix_(i,j) seq.nth (d: T) (seq.nth nil mval i) j. - -Definition cV_of_list {T} {d: T} (n: nat) (vval: list T) : 'cV[T]_n := - \matrix_(i,j) seq.nth (d:T) vval i. - -Definition matrix_cols_nat {T} (m: list (list T)) (cols: nat) := - Forall (fun r => size r = cols) m. - -Lemma listlist_of_mx_of_listlist: - forall {t} {d} rows cols (mval: list (list (ftype t))), - rows = Datatypes.length mval -> - matrix_cols_nat mval cols -> - listlist_of_mx (@mx_of_listlist _ d rows cols mval) = mval. -Proof. -intros. -unfold listlist_of_mx, mx_of_listlist. -eapply (nth_ext _ _ nil nil). -rewrite length_map -H. apply size_ord_enum. -intros i Hi. -rewrite -!nth_List_nth. -rewrite length_map in Hi. -change @length with @size in Hi,H. -rewrite size_ord_enum in Hi. -rewrite (nth_ord_enum_lemma nil mval) -H. -f_equal. -f_equal. -apply FunctionalExtensionality.functional_extensionality; intro j. -rewrite map_comp /comp. -rewrite val_ord_enum. -rewrite map_nth_iota. 2: lia. 
-rewrite drop0. -replace (take rows mval) with mval. -2: rewrite H take_size //. -rewrite (nth_ord_enum_lemma d (nth nil mval j)). -replace (size _) with cols. -2:{ clear i Hi. - red in H0. rewrite Forall_nth in H0. specialize (H0 j nil). rewrite nth_List_nth. symmetry; apply H0. - change @length with @size; rewrite -H. pose proof (ltn_ord j). lia. -} -f_equal. -simpl. -clear i Hi. -rename j into i. -apply FunctionalExtensionality.functional_extensionality; intro j. -rewrite mxE /comp //. -Qed. - -Lemma mx_of_listlist_of_mx: - forall {T} {d:T} rows cols (A: 'M[T]_(rows,cols)), - @mx_of_listlist _ d rows cols (listlist_of_mx A) = A. -Proof. -intros. -apply matrixP => i j. -rewrite /mx_of_listlist mxE /listlist_of_mx. -rewrite (nth_map i). -2: rewrite size_ord_enum; apply ltn_ord. -rewrite (nth_map j). -2: rewrite size_ord_enum; apply ltn_ord. -rewrite !nth_ord_enum'. -auto. -Qed. - -Lemma list_of_cV_of_list: - forall {T} {d:T} n (vval: list T), - size vval = n -> - list_of_cV (@cV_of_list _ d n vval) = vval. -Proof. -intros. -unfold list_of_cV, cV_of_list. -apply (nth_ext _ _ d d). -rewrite length_map -H. apply size_ord_enum. -intros i Hi. -rewrite -!nth_List_nth. -rewrite length_map in Hi. -change @length with @size in Hi,H. -rewrite size_ord_enum in Hi. -rewrite (nth_ord_enum_lemma d vval) -H. -f_equal. -f_equal. -(* apply FunctionalExtensionality.functional_extensionality; intro j. *) -rewrite map_comp /comp. -rewrite val_ord_enum. -rewrite map_nth_iota. 2: lia. -rewrite drop0 take_size. -apply FunctionalExtensionality.functional_extensionality; intro j. -rewrite mxE //. -Qed. - -Lemma cV_of_list_of_cV: - forall {T} `{d:T} n (x: 'cV[T]_n), +(** ** Conversions between MathComp matrices and list-of-lists *) + +(** [listlist_of_mx A] converts a MathComp matrix to a list of rows, + each row being a list of entries. *) +Definition listlist_of_mx {T} [m n : nat] (A : 'M[T]_(m,n)) : list (list T) := + map (fun i : 'I_m => map (A i) (ord_enum n)) (ord_enum m). 
+ +(** [list_of_cV V] converts a MathComp column vector to a plain list. *) + +Definition list_of_cV {T} [n : nat] (V : 'cV[T]_n) : list T := + map (fun i => V i ord0) (ord_enum n). + +(** [mx_of_listlist rows cols mval] builds a MathComp matrix from a + list-of-lists [mval], using [d] as the default element for out-of-bounds + accesses. *) +Definition mx_of_listlist {T} {d : T} (rows cols : nat) (mval : list (list T)) + : 'M[T]_(rows, cols) := + \matrix_(i,j) seq.nth (d : T) (seq.nth nil mval i) j. + +(** [cV_of_list n vval] builds a MathComp column vector from a list [vval], + using [d] as the default element for out-of-bounds accesses. *) +Definition cV_of_list {T} {d : T} (n : nat) (vval : list T) : 'cV[T]_n := + \matrix_(i,j) seq.nth (d : T) vval i. + +(** [matrix_cols_nat m cols] asserts that every row in [m] has length [cols]. *) + +Definition matrix_cols_nat {T} (m : list (list T)) (cols : nat) := + Forall (fun r => size r = cols) m. + +(** Round-trip: converting a list-of-lists to a MathComp matrix and back + recovers the original list-of-lists, provided the dimensions match. *) +Lemma listlist_of_mx_of_listlist : + forall {t} {d} rows cols (mval : list (list (ftype t))), + rows = Datatypes.length mval -> + matrix_cols_nat mval cols -> + listlist_of_mx (@mx_of_listlist _ d rows cols mval) = mval. +Proof. + intros t d rows cols mval Hrows Hcols. + unfold listlist_of_mx, mx_of_listlist. + eapply (nth_ext _ _ nil nil). + - rewrite length_map -Hrows. apply size_ord_enum. + - intros i Hi. + rewrite -!nth_List_nth. + rewrite length_map in Hi. + change @length with @size in Hi, Hrows. + rewrite size_ord_enum in Hi. + rewrite (nth_ord_enum_lemma nil mval) -Hrows. + f_equal; f_equal. + apply FunctionalExtensionality.functional_extensionality; intro j. + rewrite map_comp /comp val_ord_enum. + rewrite map_nth_iota; [| lia]. + rewrite drop0. + replace (take rows mval) with mval. + 2: rewrite Hrows take_size //. + rewrite (nth_ord_enum_lemma d (nth nil mval j)). 
+ replace (size _) with cols. + 2: { + clear i Hi. + red in Hcols; rewrite Forall_nth in Hcols. + specialize (Hcols j nil). + rewrite nth_List_nth. + symmetry; apply Hcols. + change @length with @size; rewrite -Hrows. + pose proof (ltn_ord j); lia. + } + f_equal; simpl. + clear i Hi; rename j into i. + apply FunctionalExtensionality.functional_extensionality; intro j. + rewrite mxE /comp //. +Qed. + +(** Round-trip: converting a MathComp matrix to a list-of-lists and back + recovers the original matrix. *) +Lemma mx_of_listlist_of_mx : + forall {T} {d : T} rows cols (A : 'M[T]_(rows,cols)), + @mx_of_listlist _ d rows cols (listlist_of_mx A) = A. +Proof. + intros. + apply matrixP => i j. + rewrite /mx_of_listlist mxE /listlist_of_mx. + rewrite (nth_map i). + 2: rewrite size_ord_enum; apply ltn_ord. + rewrite (nth_map j). + 2: rewrite size_ord_enum; apply ltn_ord. + rewrite !nth_ord_enum'; auto. +Qed. + +(** Round-trip: converting a list to a column vector and back recovers + the original list, provided the sizes match. *) +Lemma list_of_cV_of_list : + forall {T} {d : T} n (vval : list T), + size vval = n -> + list_of_cV (@cV_of_list _ d n vval) = vval. +Proof. + intros T d n vval Hsize. + unfold list_of_cV, cV_of_list. + apply (nth_ext _ _ d d). + - rewrite length_map -Hsize. apply size_ord_enum. + - intros i Hi. + rewrite -!nth_List_nth. + rewrite length_map in Hi. + change @length with @size in Hi, Hsize. + rewrite size_ord_enum in Hi. + rewrite (nth_ord_enum_lemma d vval) -Hsize. + f_equal; f_equal. + rewrite map_comp /comp val_ord_enum. + rewrite map_nth_iota; [| lia]. + rewrite drop0 take_size. + apply FunctionalExtensionality.functional_extensionality; intro j. + rewrite mxE //. +Qed. + +(** Round-trip: converting a column vector to a list and back recovers + the original vector. *) +Lemma cV_of_list_of_cV : + forall {T} `{d : T} n (x : 'cV[T]_n), @cV_of_list _ d n (list_of_cV x) = x. Proof. -intros. -apply matrixP => i j. 
-rewrite /mx_of_listlist mxE /listlist_of_mx. -rewrite (nth_map i). -2: rewrite size_ord_enum; apply ltn_ord. -rewrite !ord1. -f_equal. -apply nth_ord_enum'. + intros. + apply matrixP => i j. + rewrite /mx_of_listlist mxE /listlist_of_mx. + rewrite (nth_map i). + 2: rewrite size_ord_enum; apply ltn_ord. + rewrite !ord1. + f_equal. + apply nth_ord_enum'. Qed. -Lemma matrix_rows_listlist_of_mx: forall {T} [rows cols] (A: 'M[T]_(rows,cols)), - size (listlist_of_mx A) = rows. +(** [listlist_of_mx A] has [rows] rows. *) + +Lemma matrix_rows_listlist_of_mx : forall {T} [rows cols] (A : 'M[T]_(rows,cols)), + size (listlist_of_mx A) = rows. Proof. -intros. -unfold listlist_of_mx. rewrite size_map. apply size_ord_enum. + intros. + unfold listlist_of_mx. + rewrite size_map. + apply size_ord_enum. Qed. -Lemma matrix_cols_listlist_of_mx: forall {T} [rows cols] (A: 'M[T]_(rows,cols)), +(** Every row of [listlist_of_mx A] has length [cols]. *) + +Lemma matrix_cols_listlist_of_mx : forall {T} [rows cols] (A : 'M[T]_(rows,cols)), matrix_cols_nat (listlist_of_mx A) cols. Proof. -intros. -unfold matrix_cols_nat, listlist_of_mx. -apply Forall_map, Forall_forall. -intros; simpl. -rewrite size_map. apply mv_mathcomp.size_ord_enum. + intros. + unfold matrix_cols_nat, listlist_of_mx. + apply Forall_map, Forall_forall. + intros; simpl. + rewrite size_map. + apply mv_mathcomp.size_ord_enum. Qed. -Lemma size_list_of_cV: forall {T} [n] (vval: 'cV[T]_n), +(** [list_of_cV vval] has length [n]. *) + +Lemma size_list_of_cV : forall {T} [n] (vval : 'cV[T]_n), size (list_of_cV vval) = n. Proof. -intros. -rewrite /list_of_cV size_map size_ord_enum //. + intros. + rewrite /list_of_cV size_map size_ord_enum //. Qed. +(** The [i]-th element of [list_of_cV vval] is [vval i ord0]. *) -Lemma nth_list_of_cV: - forall {T} {d:T} [n] (vval: 'cV[T]_n) (i: 'I_n), - nth d (list_of_cV vval) (nat_of_ord i) = vval i ord0. 
+Lemma nth_list_of_cV : + forall {T} {d : T} [n] (vval : 'cV[T]_n) (i : 'I_n), + nth d (list_of_cV vval) (nat_of_ord i) = vval i ord0. Proof. -intros. -rewrite /list_of_cV (nth_map i) ?nth_ord_enum' // size_ord_enum. -apply ltn_ord. + intros. + rewrite /list_of_cV (nth_map i) ?nth_ord_enum' // size_ord_enum. + apply ltn_ord. Qed. -Definition list_dotprod {NAN: FPCore.Nans} {t: type} (v1 v2: list (ftype t)) : ftype t := - foldl (fun s x12 => BFMA (fst x12) (snd x12) s) (Zconst t 0) (zip v1 v2) . +(** ** List-based floating-point operations *) + +(** [list_dotprod v1 v2] computes the dot product of [v1] and [v2] using + fused multiply-add, accumulating from left to right with initial + value [0]. *) +Definition list_dotprod {NAN : FPCore.Nans} {t : type} + (v1 v2 : list (ftype t)) : ftype t := + foldl (fun s x12 => BFMA (fst x12) (snd x12) s) (Zconst t 0) (zip v1 v2). + +(** [matrix_vector_mult m v] applies the matrix [m] (given as a list of rows) + to the vector [v] (a list), computing each entry via [list_dotprod]. *) +Definition matrix_vector_mult {NAN : FPCore.Nans} {t : type} + (m : list (list (ftype t))) (v : list (ftype t)) : list (ftype t) := + map (fun row => list_dotprod row v) m. + +(** [list_of_cV] commutes with [col_mx]: stacking two column vectors and + converting to a list is the same as concatenating their list + representations. *) +Lemma list_of_cV_col_mx : forall {T} n1 n2 (x : 'cV[T]_n1) (y : 'cV[T]_n2), + list_of_cV (col_mx x y) = list_of_cV x ++ list_of_cV y. +Proof. + intros T n1 n2 x y. + assert (Hn1: (n1 = O \/ 0 < n1)%N) by lia. + destruct Hn1 as [Hn1 | Hn1]. + - subst. + rewrite /list_of_cV /col_mx; simpl. + apply eq_in_map. + red; simpl; intros x0 _. + clear. + rewrite mxE. + change n2 with (addn O n2) in x0. + case_splitP x0. { destruct x0; lia. } + f_equal; apply ord_inj; simpl; reflexivity. + - assert (d : T). + { destruct n1; try lia; apply (x ord0 ord0). } + rewrite /list_of_cV /col_mx. + apply eq_from_nth with d. 
+ + rewrite size_cat !size_map !size_ord_enum //. + + intros i. + rewrite size_map !size_ord_enum => Hi. + rewrite nth_cat size_map size_ord_enum. + rewrite (nth_map (Ordinal Hi)) ?size_ord_enum // mxE. + assert (Hnth : nth (Ordinal Hi) (ord_enum (n1 + n2)) i = Ordinal Hi). + { change i with (nat_of_ord (Ordinal Hi)). + rewrite nth_ord_enum' //. } + rewrite Hnth. + destruct (i < n1)%N eqn:Hlt. + * unfold split; simpl. + destruct (ltnP i n1); try lia. + rewrite (nth_map (Ordinal i0)). + 2: rewrite size_ord_enum //. + change i with (nat_of_ord (Ordinal i0)). + rewrite nth_ord_enum' //. + * unfold split; simpl. + destruct (ltnP i n1); try lia. + assert (Hlt2 : is_true (i - n1 < n2)%N) by lia. + rewrite (nth_map (Ordinal Hlt2)). + 2: rewrite size_ord_enum //. + change (i - n1)%nat with (nat_of_ord (Ordinal Hlt2)). + rewrite nth_ord_enum' //. + f_equal; apply ord_inj; simpl; auto. +Qed. -Definition matrix_vector_mult {NAN: FPCore.Nans}{t: type} (m: list (list (ftype t))) (v: list (ftype t)) : list (ftype t) := - map (fun row => list_dotprod row v) m. +(** Mapping a constant function yields a [repeat]. *) -Lemma list_of_cV_col_mx: forall {T} n1 n2 (x: 'cV[T]_n1) (y: 'cV[T]_n2), - list_of_cV (col_mx x y) = list_of_cV x ++ list_of_cV y. +Lemma map_const_len : forall {A B} (c : B) (al : list A), + map (fun _ => c) al = repeat c (length al). Proof. -intros. -assert (n1 = O \/ 0< n1)%N by lia. -destruct H. -subst. -- rewrite /list_of_cV /col_mx. simpl. - apply eq_in_map. - red; simpl; intros. - clear H. - rewrite mxE. - change n2 with (addn O n2) in x0. - case_splitP x0. destruct x0; lia. - f_equal. - apply ord_inj. simpl. reflexivity. -- - assert (d: T). destruct n1; try lia; apply (x ord0 ord0). - rewrite /list_of_cV /col_mx. - apply eq_from_nth with d. - rewrite size_cat !size_map !size_ord_enum //. - intros i. - rewrite size_map !size_ord_enum => Hi. - rewrite nth_cat. - rewrite size_map size_ord_enum. - rewrite (nth_map (Ordinal Hi)) ?size_ord_enum // mxE. 
- assert (nth (Ordinal Hi) (ord_enum (n1+n2)) i = Ordinal Hi). - change i with (nat_of_ord (Ordinal Hi)). - rewrite nth_ord_enum' //. - rewrite H0. - destruct (i c) al = repeat c (length al). -Proof. -induction al; simpl; intros; f_equal; auto. -Qed. - -Lemma listlist_of_mx_col_mx: forall {T} n1 n2 m (A: 'M[T]_(n1,m)) (B: 'M[T]_(n2,m)), + induction al; simpl; intros; f_equal; auto. +Qed. + +(** [listlist_of_mx] commutes with [col_mx]: stacking two matrices and + converting to a list-of-lists is the same as appending their + list-of-lists representations. *) +Lemma listlist_of_mx_col_mx : + forall {T} n1 n2 m (A : 'M[T]_(n1,m)) (B : 'M[T]_(n2,m)), listlist_of_mx (col_mx A B) = listlist_of_mx A ++ listlist_of_mx B. -intros. -assert (m = 0 \/ 0 < m)%N by lia. destruct H as [Hm | Hm]. { - subst m. rewrite /listlist_of_mx. change (ord_enum 0) with (@nil 'I_0). simpl. - rewrite !map_const_len. - change @length with @size. rewrite !size_ord_enum. - rewrite repeat_app //. -} -assert (n1 = O \/ 0< n1)%N by lia. -destruct H as [Hn1 | Hn1]. -subst. -- rewrite /list_of_cV /col_mx. simpl. - apply eq_in_map; intros i _. - apply eq_in_map; intros j _. - rewrite mxE. - change n2 with (addn O n2) in i. - simpl in *. - case_splitP i. destruct i; lia. - f_equal. - apply ord_inj. simpl. reflexivity. -- - assert (d: T). destruct n1,m; try lia; apply (A ord0 ord0). - rewrite /list_of_cV /col_mx. - apply eq_from_nth with nil. - rewrite size_cat !size_map !size_ord_enum //. - intros i. - rewrite size_map !size_ord_enum => Hi. - rewrite nth_cat. - rewrite size_map size_ord_enum. - rewrite (nth_map (Ordinal Hi)) ?size_ord_enum //. - apply eq_from_nth with d. { - rewrite size_map size_ord_enum. - destruct (leq (S i) n1) eqn:?H. - assert (HA := matrix_cols_listlist_of_mx A). - red in HA. rewrite Forall_nth in HA. specialize (HA i nil). - change @length with @size in HA. - rewrite matrix_rows_listlist_of_mx in HA. - specialize (HA ltac:(lia)). - rewrite -nth_List_nth in HA. auto. 
- assert (HB := matrix_cols_listlist_of_mx B). - red in HB. rewrite Forall_nth in HB. specialize (HB (i-n1)%nat nil). - change @length with @size in HB. - rewrite matrix_rows_listlist_of_mx in HB. - specialize (HB ltac:(lia)). - rewrite -nth_List_nth in HB. auto. - } - rewrite size_map size_ord_enum => j Hj. - rewrite (nth_map (Ordinal Hj)). - 2: rewrite size_ord_enum //. - change j with (nat_of_ord (Ordinal Hj)). - rewrite nth_ord_enum'. - assert (nth (Ordinal Hi) (ord_enum (n1+n2)) i = Ordinal Hi). - change i with (nat_of_ord (Ordinal Hi)). - rewrite nth_ord_enum' //. - rewrite H. - rewrite mxE. - destruct (i A=B. -Proof. -intros. -apply matrixP. intros i j. -assert (m=O \/ n = O \/ 0 - cols = size vval -> - matrix_cols_nat mval cols -> - matrix_vector_mult mval vval = list_of_cV (F.FMA_mulmx (@mx_of_listlist _ (Zconst t 0) rows cols mval) - (@cV_of_list _ (Zconst t 0) cols vval)). -Proof. -intros. -subst rows. -destruct (size vval) eqn:Hcols. -- -destruct cols; try discriminate. -destruct vval; try discriminate. -clear H0 Hcols. -assert (mval = List.repeat nil (size mval)). -induction H1; auto. simpl. f_equal; auto. destruct x; auto; discriminate. -rewrite H. -set n := size mval. -clearbody n. clear mval H H1. -change @size with @length. rewrite repeat_length. -induction n. reflexivity. -simpl. -rewrite {}IHn. -replace (mx_of_listlist (S n) 0 (cons nil (repeat nil n))) with - (col_mx (@mx_of_listlist _ (Zconst t 0) 1 0 nil) (@mx_of_listlist _ (Zconst t 0) n 0 (repeat nil n))). -2: apply /matrixP => i j; destruct j; lia. -change (S n) with (addn 1 n). -rewrite F.FMA_mulmx_col. -set (u := F.FMA_mulmx _ _). -clearbody u. -rewrite /list_dotprod /=. -rewrite list_of_cV_col_mx. -rewrite {2}/list_of_cV. -set one := ord_enum 1. compute in one. destruct idP; try lia. subst one. -simpl. -f_equal. -rewrite /F.mulmx /mx_of_listlist /cV_of_list mxE //. -- -assert (0 < cols)%N by lia. rewrite -H0 in Hcols. clear n H0. -induction H1; [reflexivity | ]. -simpl. 
-replace (mx_of_listlist (S (size l)) cols (cons x l)) - with (col_mx (@mx_of_listlist _ (Zconst t 0) 1 cols (cons x nil)) (@mx_of_listlist _ (Zconst t 0) (size l) cols l)). -+ -change (S ?A) with (addn 1 A). -rewrite F.FMA_mulmx_col. -rewrite list_of_cV_col_mx. -replace (list_of_cV _) with [:: list_dotprod x vval]. -simpl. f_equal. -apply IHForall. -rewrite /list_of_cV. -set one := ord_enum 1. compute in one. destruct idP; try lia. subst one. -simpl. f_equal. -rewrite mxE /F.FMA_mulmx /F.FMA_dotprod /fma_dotprod /list_dotprod. -f_equal. f_equal. -* apply (@eq_from_nth _ pos_zero). -rewrite size_seq_of_rV //. -move => j Hj. rewrite H0 in Hj. ordify cols j. -rewrite nth_seq_of_rV !mxE //. -* apply (@eq_from_nth _ pos_zero). -rewrite size_seq_of_rV //. -move => j Hj. rewrite Hcols in Hj. ordify cols j. -rewrite nth_seq_of_rV !mxE //. -+ -change (S (size l)) with (addn 1 (size l)). -apply listlist_of_mx_inj. -rewrite listlist_of_mx_of_listlist. -2: simpl; change @length with @size; lia. -2: constructor; auto. -rewrite listlist_of_mx_col_mx. -rewrite !listlist_of_mx_of_listlist; auto. -constructor; auto. +Proof. + intros T n1 n2 m A B. + assert (Hm : (m = 0 \/ 0 < m)%N) by lia. + destruct Hm as [Hm | Hm]. { + subst m. + rewrite /listlist_of_mx. + change (ord_enum 0) with (@nil 'I_0); simpl. + rewrite !map_const_len. + change @length with @size. + rewrite !size_ord_enum repeat_app //. + } + assert (Hn1 : (n1 = O \/ 0 < n1)%N) by lia. + destruct Hn1 as [Hn1 | Hn1]. + - subst. + rewrite /list_of_cV /col_mx; simpl. + apply eq_in_map; intros i _. + apply eq_in_map; intros j _. + rewrite mxE. + change n2 with (addn O n2) in i; simpl in *. + case_splitP i. { destruct i; lia. } + f_equal; apply ord_inj; simpl; reflexivity. + - assert (d : T). + { destruct n1, m; try lia; apply (A ord0 ord0). } + rewrite /list_of_cV /col_mx. + apply eq_from_nth with nil. + + rewrite size_cat !size_map !size_ord_enum //. + + intros i. + rewrite size_map !size_ord_enum => Hi. 
+ rewrite nth_cat size_map size_ord_enum. + rewrite (nth_map (Ordinal Hi)) ?size_ord_enum //. + apply eq_from_nth with d. { + rewrite size_map size_ord_enum. + destruct (leq (S i) n1) eqn:Hlt. + - have HA := matrix_cols_listlist_of_mx A. + red in HA; rewrite Forall_nth in HA. + specialize (HA i nil). + change @length with @size in HA. + rewrite matrix_rows_listlist_of_mx in HA. + specialize (HA ltac:(lia)). + rewrite -nth_List_nth in HA; auto. + - have HB := matrix_cols_listlist_of_mx B. + red in HB; rewrite Forall_nth in HB. + specialize (HB (i - n1)%nat nil). + change @length with @size in HB. + rewrite matrix_rows_listlist_of_mx in HB. + specialize (HB ltac:(lia)). + rewrite -nth_List_nth in HB; auto. + } + rewrite size_map size_ord_enum => j Hj. + rewrite (nth_map (Ordinal Hj)). + 2: rewrite size_ord_enum //. + change j with (nat_of_ord (Ordinal Hj)). + rewrite nth_ord_enum'. + assert (Hnth : nth (Ordinal Hi) (ord_enum (n1 + n2)) i = Ordinal Hi). + { change i with (nat_of_ord (Ordinal Hi)). + rewrite nth_ord_enum' //. } + rewrite Hnth mxE. + destruct (i < n1)%N eqn:Hlt. + * unfold split; simpl. + destruct (ltnP i n1); try lia. + rewrite (nth_map (Ordinal i0)). + 2: rewrite size_ord_enum //. + change i with (nat_of_ord (Ordinal i0)). + rewrite nth_ord_enum' //. + rewrite (nth_map (Ordinal Hj)). + 2: rewrite size_ord_enum //. + change j with (nat_of_ord (Ordinal Hj)). + rewrite nth_ord_enum' //. + * unfold split; simpl. + destruct (ltnP i n1); try lia. + assert (Hlt2 : is_true (i - n1 < n2)%N) by lia. + rewrite (nth_map (Ordinal Hlt2)). + 2: rewrite size_ord_enum //. + change (i - n1)%nat with (nat_of_ord (Ordinal Hlt2)). + rewrite nth_ord_enum' //. + rewrite (nth_map (Ordinal Hj)). + 2: rewrite size_ord_enum //. + f_equal; apply ord_inj; simpl; auto. + change j with (nat_of_ord (Ordinal Hj)). + rewrite nth_ord_enum' //. Qed. +(** [listlist_of_mx] is injective. 
*) +Lemma listlist_of_mx_inj : forall {T} [m n] (A B : 'M[T]_(m,n)), + listlist_of_mx A = listlist_of_mx B -> A = B. +Proof. + intros T m n A B Heq. + apply matrixP; intros i j. + assert (Hdim : (m = O \/ n = O \/ 0 < m /\ 0 < n)%N) by lia. + destruct Hdim as [Hm | [Hn | [Hm Hn]]]. + - subst; destruct i; lia. + - subst; destruct j; lia. + - assert (d : T) by + (destruct m; destruct n; try lia; apply (A ord0 ord0)). + assert (Hnth : nth d (nth nil (listlist_of_mx A) i) j = + nth d (nth nil (listlist_of_mx B) i) j). + { rewrite Heq; auto. } + clear - Hnth. + rewrite /listlist_of_mx in Hnth. + pose proof (ltn_ord i) as Hi. + pose proof (ltn_ord j) as Hj. + rewrite !(nth_map i) in Hnth. 2, 3: rewrite size_ord_enum; auto. + rewrite !(nth_map j) in Hnth. 2, 3: rewrite size_ord_enum; auto. + rewrite !nth_ord_enum' in Hnth. + auto. +Qed. + +(** ** Main theorem: floating-point matrix–vector multiplication *) + +(** [Fmulmx_matrix_vector_mult] is the central result connecting the + list-based [matrix_vector_mult] (using [list_dotprod]) to the + MathComp-based [F.FMA_mulmx]. Given: + - [mval]: a list-of-lists of floating-point values with [rows] rows + and [cols] columns, + - [vval]: a list of [cols] floating-point values, + the list [matrix_vector_mult mval vval] equals the column vector + [F.FMA_mulmx (mx_of_listlist mval) (cV_of_list vval)] converted + back to a list. *) +Lemma Fmulmx_matrix_vector_mult : + forall {NAN : FPCore.Nans} {t} rows cols + (mval : list (list (ftype t))) + (vval : list (ftype t)), + rows = size mval -> + cols = size vval -> + matrix_cols_nat mval cols -> + matrix_vector_mult mval vval = + list_of_cV (F.FMA_mulmx + (@mx_of_listlist _ (Zconst t 0) rows cols mval) + (@cV_of_list _ (Zconst t 0) cols vval)). +Proof. + intros NAN t rows cols mval vval Hrows Hcols Hcol_sizes. + subst rows. + destruct (size vval) eqn:Hsz. + - destruct cols; try discriminate. + destruct vval; try discriminate. + clear Hcols Hsz. 
+ assert (Hmval : mval = List.repeat nil (size mval)). { + induction Hcol_sizes; auto. + simpl; f_equal; auto. + destruct x; auto; discriminate. + } + rewrite Hmval. + set n := size mval; clearbody n. + clear mval Hmval Hcol_sizes. + change @size with @length. + rewrite repeat_length. + induction n; [reflexivity |]. + simpl. + rewrite {}IHn. + replace (mx_of_listlist (S n) 0 (cons nil (repeat nil n))) + with (col_mx (@mx_of_listlist _ (Zconst t 0) 1 0 nil) + (@mx_of_listlist _ (Zconst t 0) n 0 (repeat nil n))). + 2: apply /matrixP => i j; destruct j; lia. + change (S n) with (addn 1 n). + rewrite F.FMA_mulmx_col. + set u := F.FMA_mulmx _ _; clearbody u. + rewrite /list_dotprod /=. + rewrite list_of_cV_col_mx. + rewrite {2}/list_of_cV. + set one := ord_enum 1. + compute in one. + destruct idP; try lia. + subst one; simpl. + f_equal. + rewrite /F.mulmx /mx_of_listlist /cV_of_list mxE //. + - assert (Hcols_pos : (0 < cols)%N) by lia. + rewrite -Hcols in Hsz; clear Hcols_pos. + induction Hcol_sizes as [| x l Hx Hl IH]; [reflexivity |]. + simpl. + replace (mx_of_listlist (S (size l)) cols (cons x l)) + with (col_mx (@mx_of_listlist _ (Zconst t 0) 1 cols (cons x nil)) + (@mx_of_listlist _ (Zconst t 0) (size l) cols l)). + + change (S ?A) with (addn 1 A). + rewrite F.FMA_mulmx_col list_of_cV_col_mx. + replace (list_of_cV _) with [:: list_dotprod x vval]. + { simpl; f_equal; apply IH. } + rewrite /list_of_cV. + set one := ord_enum 1. + compute in one. + destruct idP; try lia. + subst one; simpl; f_equal. + rewrite mxE /F.FMA_mulmx /F.FMA_dotprod /fma_dotprod /list_dotprod. + f_equal; f_equal. + * apply (@eq_from_nth _ pos_zero). + { rewrite size_seq_of_rV //. } + move => j Hj. + rewrite Hx in Hj. + ordify cols j. + rewrite nth_seq_of_rV !mxE //. + * apply (@eq_from_nth _ pos_zero). + { rewrite size_seq_of_rV //. } + move => j Hj. + rewrite Hsz in Hj. + ordify cols j. + rewrite nth_seq_of_rV !mxE //. + + change (S (size l)) with (addn 1 (size l)). 
+ apply listlist_of_mx_inj. + rewrite listlist_of_mx_of_listlist. + 2: simpl; change @length with @size; lia. + 2: constructor; auto. + rewrite listlist_of_mx_col_mx. + rewrite !listlist_of_mx_of_listlist; auto. + constructor; auto. +Qed. \ No newline at end of file diff --git a/accuracy_proofs/real_lemmas.v b/accuracy_proofs/real_lemmas.v new file mode 100644 index 0000000..1cdfa1f --- /dev/null +++ b/accuracy_proofs/real_lemmas.v @@ -0,0 +1,57 @@ +(** ** Real Arithmetic Auxiliary Lemmas *) + +Require Import Coq.Reals.Reals. +Require Import Coq.Reals.RIneq. +Require Import Psatz. + +Open Scope R_scope. + +(** The reciprocal is strictly anti-monotone on positive reals: [0 < b < a -> /a < /b]. *) + +Lemma rdiv_lt (a b : R) : + 0 < b -> 0 < a -> b < a -> / a < / b. +Proof. + intros Ha Hb Hlt. + apply Rinv_lt_contravar. + - nra. + - nra. +Qed. + +(** The reciprocal is anti-monotone on positive reals (non-strict): [0 < b <= a -> /a <= /b]. *) + +Lemma rdiv_le (a b : R) : + 0 < b -> 0 < a -> b <= a -> / a <= / b. +Proof. + intros Ha Hb Hle. + apply Rinv_le_contravar; nra. +Qed. + +(** Division equals multiplication by reciprocal. *) + +Lemma rdiv_mult_eq (a b : R) : + b <> 0 -> a / b = a * (1 / b). +Proof. nra. Qed. + +(** Subtraction is anti-monotone in the subtrahend (non-strict). *) + +Lemma Rminus_le_minus (a b c : R) : + b <= c -> a - c <= a - b. +Proof. nra. Qed. + +(** Subtraction is anti-monotone in the subtrahend (strict). *) + +Lemma Rminus_lt_minus (a b c : R) : + b < c -> a - c < a - b. +Proof. nra. Qed. + +(** Addition is compatible with mixed [<=]/[<] ordering. *) + +Lemma Rplus_le_lt_compat (a1 a2 b1 b2 : R) : + a1 <= a2 -> b1 < b2 -> a1 + b1 < a2 + b2. +Proof. nra. Qed. + +(** Multiplication is compatible with mixed [<]/[<=] ordering. *) + +Lemma Rmult_le_lt_compat (a1 a2 b1 b2 : R) : + 0 < a1 -> 0 < b1 -> a1 < a2 -> b1 <= b2 -> a1 * b1 < a2 * b2. +Proof. nra. Qed.
\ No newline at end of file diff --git a/accuracy_proofs/sum_acc.v b/accuracy_proofs/sum_acc.v index c8f65da..98e9566 100644 --- a/accuracy_proofs/sum_acc.v +++ b/accuracy_proofs/sum_acc.v @@ -1,267 +1,353 @@ -(*This file contains two theorems: forward and backward error bounds for - the sum of two floating point lists; the functional model for - the summation is defined in sum_model.v.*) +(** * Floating-Point Summation Accuracy Theorems -From LAProof.accuracy_proofs Require Import preamble common - sum_model - float_acc_lems . -Require LAProof.accuracy_proofs.mv_mathcomp. + This file establishes backward and forward error bounds for floating-point + summation, in both list-based and ordinal-indexed forms. + + ** Error Factors + + Throughout, the accumulated relative error factor is + %$g(n) = (1 + \mathbf{u})^n - 1$%#\(g(n) = (1 + \mathbf{u})^n - 1\)#, + where %$\mathbf{u}$%#\(\mathbf{u}\)# is the unit roundoff for the given + floating-point type. It is defined in [common]. + + ** Main Results + + - [bSUM]: Shows that the computed floating-point sum can be expressed as + the exact sum of a slightly perturbed input list. Each input element is + perturbed by a relative factor bounded by + %$g(n-1)$%#\(g(n-1)\)#, where %$n$%#\(n\)# is the list length. + + - [Fsum_backward_error]: The ordinal-indexed analogue of [bSUM], expressing + the same backward error bound for functions indexed by finite ordinals. + + - [fSUM]: Bounds the absolute forward error of the computed sum by + %$g(n) \cdot \sum |x_i|$%#\(g(n) \cdot \sum |x_i|\)#, where + %$n$%#\(n\)# is the list length. + + - [Fsum_forward_error]: The ordinal-indexed analogue of [fSUM]. + + - [sum_forward_error_permute]: Shows that the forward error bound is stable + under permutation of the input list, so the bound holds regardless of the + order in which elements are summed. 
+ + ** Dependencies + + This file relies on: + - [preamble], [common]: basic setup and shared definitions + - [sum_model]: the floating-point summation model [sumF] and [sumR] + - [float_acc_lems]: elementary floating-point accuracy lemmas + +*) + +From LAProof.accuracy_proofs Require Import + preamble + common + sum_model + float_acc_lems. +Require LAProof.accuracy_proofs.mv_mathcomp. Require Import Permutation. -Section WithNan . -Context {NAN: FPCore.Nans} {t: type}. -Notation g := (@g t). +Section WithNan. + +Context {NAN : FPCore.Nans} {t : type}. -Notation D := (@default_rel t). +Notation g := (@g t). +Notation D := (@default_rel t). +Notation neg_zero := (@common.neg_zero t). -Notation neg_zero := (@common.neg_zero t). +(** ** Backward Error: List Version *) + +(** [bSUM] expresses the computed floating-point sum as the exact real sum of + a slightly perturbed input list. Each element of the perturbed list differs + from the corresponding input by a relative factor bounded by + %$g(n-1)$%#\(g(n-1)\)#, where %$n$%#\(n\)# is the list length. *) Theorem bSUM : - forall (x: list (ftype t)) (Hfin: Binary.is_finite (sumF x)), - exists (x': list R), + forall (x : list (ftype t)) + (Hfin : Binary.is_finite (sumF x)), + exists (x' : list R), size x' = size x /\ FT2R (sumF x) = sumR x' /\ - (forall n, (n < size x')%nat -> exists delta, - nth 0 x' n = FT2R (nth neg_zero x n) * (1 + delta) /\ Rabs delta <= g (size x' - 1)). + (forall n, + (n < size x')%nat -> + exists delta, + nth 0 x' n = FT2R (nth neg_zero x n) * (1 + delta) /\ + Rabs delta <= g (size x' - 1)). Proof. -move => x. +move=> x. rewrite /sumF -(revK x) foldl_rev size_rev. -induction (rev x) as [ | a l] => Hfin; clear x. -- -exists []; repeat split; auto => //=. +induction (rev x) as [| a l] => Hfin; clear x. +- (* base case: empty list *) + exists []; repeat split; auto => //=. - (* case a::l *) -have Hl: l = [] \/ l <> []. { - destruct l; auto. right; congruence. -} -destruct Hl. 
-+ (* case empty l *) - subst; simpl in *; - destruct (BPLUS_finite_e _ _ Hfin). - exists [FT2R a]; split; [ simpl; auto | split ; - [rewrite Bplus_0R|] ] => //. - unfold sumR; simpl; nra. - intros. exists 0; simpl in H1; split. - rewrite Rplus_0_r Rmult_1_r. - have H3: ((n = 1)%nat \/ (n = 0)%nat) by lia. - destruct H3; subst; auto. - rewrite Rabs_R0 /g /=. lra. -+ (* case non-empty l *) -simpl in *. -destruct (BPLUS_finite_e _ _ Hfin) as (A & B). -(* IHl *) -pose proof (size_not_empty_nat l H) as Hlen1. -specialize (IHl B). -destruct IHl as (l' & Hlen' & Hsum & Hdel); auto. -rewrite {1}/Basics.flip in Hfin. -(* construct l'0 *) -pose proof (BPLUS_accurate' _ _ Hfin) as Hplus. -destruct Hplus as (d' & Hd'& Hplus). -rewrite /Basics.flip in Hsum,B,Hplus|-*. -change (fun x z => @BPLUS NAN t x z) with (@BPLUS _ t) in Hsum,B,Hplus |- *. -exists (map (Rmult (1+d')) l' ++ [:: FT2R a * (1+d')]); repeat split. -* rewrite size_cat size_map /= Hlen' addnC //. -* rewrite {}Hplus Hsum Rmult_plus_distr_r -sumR_app_cons cats0 sumR_mult //. -* move => n H1. - rewrite nth_cat. - rewrite size_cat size_map in H1|-*. simpl size in H1. - destruct (n < size l')%N eqn:Hn. - -- rewrite (nth_map R0); [ | lia]. - specialize (Hdel n Hn). - destruct Hdel as (d & Hd1 & Hd2). - exists ( (1+d') * (1+d) -1). - rewrite {}Hd1. split. - ++ fold (ftype t). - rewrite rev_cons nth_rcons size_rev. - destruct (n < size l)%N eqn:Hn'; [ | lia]. nra. - ++ field_simplify_Rabs. - eapply Rle_trans; [apply Rabs_triang | eapply Rle_trans; [apply Rplus_le_compat_r; apply Rabs_triang | ] ]. -rewrite Rabs_mult. -replace (Rabs d' * Rabs d + Rabs d' + Rabs d ) with - ((1 + Rabs d') * Rabs d + Rabs d' ) by nra. -eapply Rle_trans; [apply Rplus_le_compat | ]. -apply Rmult_le_compat; try apply Rabs_pos. -apply Fourier_util.Rle_zero_pos_plus1; try apply Rabs_pos. -apply Rplus_le_compat_l; apply Hd'. -apply Hd2. apply Hd'. -replace ((1 + D) * g (size l' - 1) + D) with -((1 + D) * g (size l' - 1) * 1 + D * 1) by nra. 
-rewrite one_plus_d_mul_g; apply Req_le; rewrite Rmult_1_r /=. f_equal; lia. - -- - fold (ftype t). - assert (n = size l') by lia. subst n. - rewrite nth_rev /= ; [ | lia]. - rewrite -Hlen'. do 2 replace (_ - _)%N with O by lia. simpl. - exists d'; split; auto. - eapply Rle_trans; [ apply Hd' | ]. - apply d_le_g_1. lia. + have Hl : l = [] \/ l <> []. { + destruct l; auto; right; congruence. + } + destruct Hl as [Hl | Hl]. + + (* case empty l *) + subst; simpl in *; + destruct (BPLUS_finite_e _ _ Hfin) as [Ha Hb]. + exists [FT2R a]; split; [simpl; auto | split; + [rewrite Bplus_0R|]] => //. + unfold sumR; simpl; nra. + intros n Hn; exists 0; simpl in Hn; split. + rewrite Rplus_0_r Rmult_1_r. + have H3 : (n = 1)%nat \/ (n = 0)%nat by lia. + destruct H3 as [Hn1 | Hn0]; subst; auto. + rewrite Rabs_R0 /g /=; lra. + + (* case non-empty l *) + simpl in *. + destruct (BPLUS_finite_e _ _ Hfin) as [Ha Hb]. + (* IHl *) + pose proof (size_not_empty_nat l Hl) as Hlen1. + specialize (IHl Hb). + destruct IHl as (l' & Hlen' & Hsum & Hdel); auto. + rewrite {1}/Basics.flip in Hfin. + (* construct l'0 *) + pose proof (BPLUS_accurate' _ _ Hfin) as Hplus. + destruct Hplus as (d' & Hd' & Hplus). + rewrite /Basics.flip in Hsum, Hb, Hplus |- *. + change (fun x z => @BPLUS NAN t x z) with (@BPLUS _ t) in Hsum, Hb, Hplus |- *. + exists (map (Rmult (1+d')) l' ++ [:: FT2R a * (1+d')]); repeat split. + * rewrite size_cat size_map /= Hlen' addnC //. + * rewrite {}Hplus Hsum Rmult_plus_distr_r -sumR_app_cons cats0 sumR_mult //. + * move=> n Hn. + rewrite nth_cat. + rewrite size_cat size_map in Hn |- *; simpl size in Hn. + destruct (n < size l')%N eqn:Hn_lt. + -- rewrite (nth_map R0); [| lia]. + specialize (Hdel n Hn_lt). + destruct Hdel as (d & Hd1 & Hd2). + exists ((1+d') * (1+d) - 1). + rewrite {}Hd1; split. + ++ fold (ftype t). + rewrite rev_cons nth_rcons size_rev. + destruct (n < size l)%N eqn:Hn'; [| lia]; nra. + ++ field_simplify_Rabs. 
+ eapply Rle_trans; + [apply Rabs_triang | + eapply Rle_trans; + [apply Rplus_le_compat_r; apply Rabs_triang |]]. + rewrite Rabs_mult. + replace (Rabs d' * Rabs d + Rabs d' + Rabs d) + with ((1 + Rabs d') * Rabs d + Rabs d') by nra. + eapply Rle_trans; [apply Rplus_le_compat |]. + apply Rmult_le_compat; try apply Rabs_pos. + apply Fourier_util.Rle_zero_pos_plus1; try apply Rabs_pos. + apply Rplus_le_compat_l; apply Hd'. + apply Hd2. apply Hd'. + replace ((1 + D) * g (size l' - 1) + D) + with ((1 + D) * g (size l' - 1) * 1 + D * 1) by nra. + rewrite one_plus_d_mul_g; apply Req_le. + rewrite Rmult_1_r /=; f_equal; lia. + -- fold (ftype t). + assert (Hn_eq : n = size l') by lia; subst n. + rewrite nth_rev /=; [| lia]. + rewrite -Hlen'; do 2 replace (_ - _)%N with O by lia; simpl. + exists d'; split; auto. + eapply Rle_trans; [apply Hd' |]. + apply d_le_g_1; lia. Qed. +(** ** Backward Error: Indexed Version *) + +(** [Fsum_backward_error] lifts [bSUM] to functions indexed by finite + ordinals. The computed sum equals the exact sum of a perturbed family, + with each element perturbed by a relative factor bounded by + %$g(n-1)$%#\(g(n-1)\)#. *) + Theorem Fsum_backward_error : - forall [n] (x: 'I_n -> ftype t) (Hfin: Binary.is_finite (mv_mathcomp.F.sum x)), - exists (x': 'I_n -> R), + forall [n] (x : 'I_n -> ftype t) + (Hfin : Binary.is_finite (mv_mathcomp.F.sum x)), + exists (x' : 'I_n -> R), FT2R (mv_mathcomp.F.sum x) = \sum_i x' i /\ - (forall i: 'I_n, exists delta, - x' i = FT2R (x i) * (1 + delta) /\ Rabs delta <= g (n-1)). + (forall i : 'I_n, + exists delta, + x' i = FT2R (x i) * (1 + delta) /\ + Rabs delta <= g (n - 1)). Proof. -intros. -have :(Binary.is_finite (sumF (map x (ord_enum n)))). -rewrite -mv_mathcomp.F.sum_sumF //. -move => Hfin'. -destruct (bSUM _ Hfin') as [x' [H [H0 H1]]]. -rewrite size_map mv_mathcomp.size_ord_enum in H. subst n. -exists (nth R0 x'). -split. -rewrite mv_mathcomp.F.sum_sumF. rewrite H0 mv_mathcomp.sumR_sum //. -move => i. 
-destruct (H1 i) as [delta [H2 H3]]. -destruct i; simpl; lia. -exists delta. -rewrite {}H2. -change GRing.mul with Rmult. -change GRing.add with Rplus. -change (GRing.one _) with 1%Re. -split; auto. -clear H3. -f_equal. -f_equal. -clear. -destruct (size x'); clear x'. -simpl. -destruct i; lia. -rewrite (nth_map (@ord0 n) common.neg_zero). -rewrite mv_mathcomp.nth_ord_enum' //. -rewrite mv_mathcomp.size_ord_enum. -pose proof ltn_ord i. lia. +move=> n x Hfin. +have Hfin' : Binary.is_finite (sumF (map x (ord_enum n))). +{ rewrite -mv_mathcomp.F.sum_sumF //. } +destruct (bSUM _ Hfin') as [x' [Hsize [Hsum Hdel]]]. +rewrite size_map mv_mathcomp.size_ord_enum in Hsize; subst n. +exists (nth R0 x'); split. +- rewrite mv_mathcomp.F.sum_sumF Hsum mv_mathcomp.sumR_sum //. +- move=> i. + destruct (Hdel i) as [delta [H2 H3]]. + { destruct i; simpl; lia. } + exists delta. + rewrite {}H2. + change GRing.mul with Rmult. + change GRing.add with Rplus. + change (GRing.one _) with 1%Re. + split; auto. + clear H3; f_equal; f_equal; clear. + destruct (size x'); clear x'. + { simpl; destruct i; lia. } + rewrite (nth_map (@ord0 n) common.neg_zero). + rewrite mv_mathcomp.nth_ord_enum' //. + rewrite mv_mathcomp.size_ord_enum. + pose proof ltn_ord i; lia. Qed. +(** ** Forward Error: List Version *) + +(** [fSUM] bounds the absolute forward error of the computed floating-point + sum by %$g(n) \cdot \sum |x_i|$%#\(g(n) \cdot \sum |x_i|\)#, where + %$n$%#\(n\)# is the list length and %$|x_i|$%#\(|x_i|\)# denotes the + componentwise absolute values of the inputs. *) + Theorem fSUM : - forall (x: list (ftype t)) (Hfin: Binary.is_finite (sumF x)), - Rabs ((sumR (map FT2R x)) - FT2R (sumF x)) <= g (size x) * (sumR (map Rabs (map FT2R x))). + forall (x : list (ftype t)) + (Hfin : Binary.is_finite (sumF x)), + Rabs (sumR (map FT2R x) - FT2R (sumF x)) <= + g (size x) * sumR (map Rabs (map FT2R x)). Proof. -move => x. +move=> x. rewrite -(revK x). -induction (rev x); clear x => Hfin. 
-- unfold g; subst; simpl. rewrite Rminus_0_r Rabs_R0; nra. +induction (rev x) as [| a l]; clear x => Hfin. +- (* base case: empty list *) + unfold g; subst; simpl. + rewrite Rminus_0_r Rabs_R0; nra. - (* case a::l *) -assert (Hl: l = [] \/ l <> []). -destruct l; auto; right; congruence. -destruct Hl. -+ (* case empty l *) -subst. unfold g; simpl; subst. -destruct (BPLUS_finite_e _ _ Hfin) as (A & B). -rewrite Bplus_0R; auto. -field_simplify_Rabs; field_simplify; rewrite Rabs_R0. -apply Rmult_le_pos; auto with commonDB; apply Rabs_pos. -+ (* case non-empty l *) -rewrite /sumF foldl_rev /= in Hfin. -change (fun x z : ftype t => Basics.flip BPLUS z x) with (@BPLUS _ t) in Hfin. -destruct (BPLUS_finite_e _ _ Hfin) as (A & B). -(* IHl *) -rewrite -foldl_rev in B. -specialize (IHl B). -(* accuracy rewrites *) -destruct (BPLUS_accurate' _ _ Hfin) as (d' & Hd'& Hplus). -move :IHl. -rewrite /sumF. -rewrite !foldl_rev. -change (fun x z : ftype t => Basics.flip BPLUS z x) with (@BPLUS _ t). -rewrite !map_rev !sumR_rev !size_rev => IHl. -simpl. -rewrite {}Hplus. -(* algebra *) -field_simplify_Rabs. -set s0 := sumR (map FT2R l). - set (s := foldr _ _ l). -replace (- FT2R a * d' + s0 - FT2R s * d' - FT2R s) with - ((s0 - FT2R s) - d' * (FT2R s + FT2R a)) by nra. -eapply Rle_trans; - [ apply Rabs_triang | eapply Rle_trans; [ apply Rplus_le_compat_r - | rewrite !Rabs_Ropp] ]. -apply IHl. -eapply Rle_trans; - [apply Rplus_le_compat_l | ]. - rewrite Rabs_mult. apply Rmult_le_compat; try apply Rabs_pos. - apply Hd'. - eapply Rle_trans; [apply Rabs_triang | apply Rplus_le_compat_r]. - rewrite Rabs_minus_sym in IHl; apply Rabs_le_minus in IHl. apply IHl. -rewrite !Rmult_plus_distr_l; rewrite <- !Rplus_assoc. -set (s1 := sumR (map Rabs (map FT2R l))). -replace (g (size l ) * s1 + D * (g (size l ) * s1)) with - ((1+ D) * g (size l) * s1) by nra. -eapply Rle_trans; [apply Rplus_le_compat_r; - apply Rplus_le_compat_l; apply Rmult_le_compat_l; try apply Rabs_pos|]. -apply default_rel_ge_0. 
-apply sumR_le_sumRabs. -rewrite sumRabs_Rabs. -rewrite one_plus_d_mul_g. -rewrite Rplus_comm. -apply size_not_empty_nat in H. -apply Rplus_le_compat. -apply Rmult_le_compat; try apply Rabs_pos; - try apply default_rel_ge_0; try nra. -apply d_le_g_1; lia. -apply Req_le; f_equal. -f_equal. lia. + assert (Hl : l = [] \/ l <> []) by (destruct l; auto; right; congruence). + destruct Hl as [Hl | Hl]. + + (* case empty l *) + subst; unfold g; simpl. + destruct (BPLUS_finite_e _ _ Hfin) as [Ha Hb]. + rewrite Bplus_0R; auto. + field_simplify_Rabs; field_simplify; rewrite Rabs_R0. + apply Rmult_le_pos; auto with commonDB; apply Rabs_pos. + + (* case non-empty l *) + rewrite /sumF foldl_rev /= in Hfin. + change (fun x z : ftype t => Basics.flip BPLUS z x) with (@BPLUS _ t) in Hfin. + destruct (BPLUS_finite_e _ _ Hfin) as [Ha Hb]. + (* IHl *) + rewrite -foldl_rev in Hb. + specialize (IHl Hb). + (* accuracy rewrites *) + destruct (BPLUS_accurate' _ _ Hfin) as (d' & Hd' & Hplus). + move :IHl. + rewrite /sumF !foldl_rev. + change (fun x z : ftype t => Basics.flip BPLUS z x) with (@BPLUS _ t). + rewrite !map_rev !sumR_rev !size_rev => IHl. + simpl. + rewrite {}Hplus. + (* algebra *) + field_simplify_Rabs. + set s0 := sumR (map FT2R l). + set (s := foldr _ _ l). + replace (- FT2R a * d' + s0 - FT2R s * d' - FT2R s) + with ((s0 - FT2R s) - d' * (FT2R s + FT2R a)) by nra. + eapply Rle_trans; + [apply Rabs_triang | + eapply Rle_trans; [apply Rplus_le_compat_r | rewrite !Rabs_Ropp]]. + apply IHl. + eapply Rle_trans; [apply Rplus_le_compat_l |]. + rewrite Rabs_mult; apply Rmult_le_compat; try apply Rabs_pos. + apply Hd'. + eapply Rle_trans; [apply Rabs_triang | apply Rplus_le_compat_r]. + rewrite Rabs_minus_sym in IHl; apply Rabs_le_minus in IHl; apply IHl. + rewrite !Rmult_plus_distr_l -!Rplus_assoc. + set (s1 := sumR (map Rabs (map FT2R l))). + replace (g (size l) * s1 + D * (g (size l) * s1)) + with ((1 + D) * g (size l) * s1) by nra. 
+ eapply Rle_trans; + [apply Rplus_le_compat_r; + apply Rplus_le_compat_l; + apply Rmult_le_compat_l; try apply Rabs_pos |]. + apply default_rel_ge_0. + apply sumR_le_sumRabs. + rewrite sumRabs_Rabs one_plus_d_mul_g Rplus_comm. + apply size_not_empty_nat in Hl. + apply Rplus_le_compat. + apply Rmult_le_compat; try apply Rabs_pos; + try apply default_rel_ge_0; try nra. + apply d_le_g_1; lia. + apply Req_le; f_equal; f_equal; lia. Qed. -Lemma Fsum_forward_error: - forall [n] (x: 'I_n -> ftype t) (Hfin: Binary.is_finite (mv_mathcomp.F.sum x)), - Rabs (\sum_i (FT2R (x i)) - FT2R (mv_mathcomp.F.sum x)) <= g n * (\sum_i (Rabs (FT2R (x i)))). +(** ** Forward Error: Indexed Version *) + +(** [Fsum_forward_error] lifts [fSUM] to functions indexed by finite + ordinals, giving the same %$g(n) \cdot \sum |x_i|$%#\(g(n) \cdot \sum |x_i|\)# + bound for the absolute forward error of the indexed sum. *) + +Lemma Fsum_forward_error : + forall [n] (x : 'I_n -> ftype t) + (Hfin : Binary.is_finite (mv_mathcomp.F.sum x)), + Rabs (\sum_i (FT2R (x i)) - FT2R (mv_mathcomp.F.sum x)) <= + g n * (\sum_i (Rabs (FT2R (x i)))). Proof. -intros. -have :(Binary.is_finite (sumF (map x (ord_enum n)))). -rewrite -mv_mathcomp.F.sum_sumF //. -move => Hfin'. -move :(fSUM _ Hfin') => H. +move=> n x Hfin. +have Hfin' : Binary.is_finite (sumF (map x (ord_enum n))). +{ rewrite -mv_mathcomp.F.sum_sumF //. } +move: (fSUM _ Hfin') => H. rewrite !mv_mathcomp.sumR_sum !size_map !mv_mathcomp.size_ord_enum -map_comp in H. rewrite mv_mathcomp.F.sum_sumF. -match goal with H: Rle (Rabs (Rminus ?A _)) (Rmult _ ?B) |- Rle (Rabs (Rminus ?A' _)) (Rmult _ ?B') => - replace A' with A; [replace B' with B | ]; auto; clear +match goal with +| H : Rle (Rabs (Rminus ?A _)) (Rmult _ ?B) + |- Rle (Rabs (Rminus ?A' _)) (Rmult _ ?B') => + replace A' with A; [replace B' with B |]; auto; clear end. -- -apply eq_bigr => i _. -destruct n. destruct i; lia. -rewrite -map_comp. -rewrite (nth_map (@ord0 n) R0). 
-rewrite mv_mathcomp.nth_ord_enum' //. -rewrite mv_mathcomp.size_ord_enum //. -- -apply eq_bigr => i _. -destruct n. destruct i; lia. -rewrite (nth_map (@ord0 n) R0). -rewrite mv_mathcomp.nth_ord_enum' //. -rewrite mv_mathcomp.size_ord_enum //. +- apply eq_bigr => i _. + destruct n; [destruct i; lia |]. + rewrite -map_comp. + rewrite (nth_map (@ord0 n) R0). + rewrite mv_mathcomp.nth_ord_enum' //. + rewrite mv_mathcomp.size_ord_enum //. +- apply eq_bigr => i _. + destruct n; [destruct i; lia |]. + rewrite (nth_map (@ord0 n) R0). + rewrite mv_mathcomp.nth_ord_enum' //. + rewrite mv_mathcomp.size_ord_enum //. Qed. +(** ** Forward Error Under Permutation *) + +(** [sum_forward_error_permute'] is an auxiliary lemma: when two lists are + permutations of each other, the forward error bound for the computed sum + of either list can be expressed using the length of the original list. + Used internally by [sum_forward_error_permute]. *) + Lemma sum_forward_error_permute' : - forall (x x0: list (ftype t)) - (Hfin: Binary.is_finite (sumF x)) - (Hfin0: Binary.is_finite (sumF x0)) - (Hper: Permutation x x0), - Rabs ((sumR (map FT2R x0)) - FT2R (sumF x0)) <= g (size x) * (sumR (map Rabs (map FT2R x0))). + forall (x x0 : list (ftype t)) + (Hfin : Binary.is_finite (sumF x)) + (Hfin0 : Binary.is_finite (sumF x0)) + (Hper : Permutation x x0), + Rabs ((sumR (map FT2R x0)) - FT2R (sumF x0)) <= + g (size x) * (sumR (map Rabs (map FT2R x0))). Proof. -intros. -eapply Rle_trans. -apply (fSUM x0 Hfin0). +move=> x x0 Hfin Hfin0 Hper. +eapply Rle_trans; [apply (fSUM x0 Hfin0) |]. apply Req_le; f_equal. -change @size with @length. +change @size with @length. rewrite (Permutation_length Hper); auto. Qed. +(** [sum_forward_error_permute] shows that the forward error bound for the + computed sum is invariant under permutation of the input. 
If two lists + are permutations of each other, the absolute forward error for either + computed sum is bounded by + %$g(n) \cdot \sum |x_i|$%#\(g(n) \cdot \sum |x_i|\)# + using the shared element set and length %$n$%#\(n\)#. *) + Theorem sum_forward_error_permute : - forall (x x0: list (ftype t)) - (Hfin: Binary.is_finite (sumF x)) - (Hfin0: Binary.is_finite (sumF x0)) - (Hper: Permutation x x0), - Rabs ((sumR (map FT2R x)) - FT2R (sumF x0)) <= g (size x) * (sumR (map Rabs (map FT2R x))). + forall (x x0 : list (ftype t)) + (Hfin : Binary.is_finite (sumF x)) + (Hfin0 : Binary.is_finite (sumF x0)) + (Hper : Permutation x x0), + Rabs ((sumR (map FT2R x)) - FT2R (sumF x0)) <= + g (size x) * (sumR (map Rabs (map FT2R x))). Proof. -intros. -rewrite (sumR_permute (map FT2R x) (map FT2R x0)); [|apply Permutation_map; auto]. -eapply Rle_trans. -apply sum_forward_error_permute'; eauto. +move=> x x0 Hfin Hfin0 Hper. +rewrite (sumR_permute (map FT2R x) (map FT2R x0)); + [| apply Permutation_map; auto]. +eapply Rle_trans; [apply sum_forward_error_permute'; eauto |]. apply Req_le; f_equal; symmetry. -f_equal. apply Permutation_length; auto. +f_equal; apply Permutation_length; auto. apply sumR_permute. repeat apply Permutation_map; auto. Qed. diff --git a/accuracy_proofs/sum_is_finite.v b/accuracy_proofs/sum_is_finite.v index 5e6e85f..ff9884d 100644 --- a/accuracy_proofs/sum_is_finite.v +++ b/accuracy_proofs/sum_is_finite.v @@ -1,208 +1,301 @@ -From LAProof.accuracy_proofs Require Import preamble - common op_defs dotprod_model sum_model float_acc_lems - fma_dot_acc sum_acc. +(** * Finite Sum Boundedness + + This file establishes conditions under which floating-point summation + remains finite (non-infinite, non-NaN). The central contributions are: + + - [is_finite_sum_no_overflow']: a helper lemma showing that the floating-point + sum is finite whenever both summands are finite and no overflow occurs. 
+ + - [fun_bnd]: a bound on individual list elements sufficient to guarantee + that their floating-point sum does not overflow. This bound decreases + monotonically as the list length grows, reflecting the accumulation of + rounding error. + + - [fun_bnd_le]: monotonicity of [fun_bnd], showing that the per-element + bound for a list of length S n is no greater than that for length n. + + - [finite_sum_from_bounded]: the main theorem. Given a list of finite + floating-point values each bounded in absolute value, + the floating-point sum fs satisfying [sum_rel_Ft] is + finite. The proof proceeds by induction on the list, using the + [fun_bnd_le] monotonicity lemma to discharge the inductive hypothesis + and explicit rounding-error arithmetic to close the overflow bound. +*) + +From LAProof.accuracy_proofs Require Import + preamble + real_lemmas + common + dotprod_model + sum_model + float_acc_lems + fma_dot_acc + sum_acc. Section NAN. -Variable NAN: FPCore.Nans. -Definition fmax (t: type) := bpow Zaux.radix2 (femax t). +Variable NAN : FPCore.Nans. + +(** ** Overflow-free addition preserves finiteness + If x and y are finite floating-point numbers and their sum does + not overflow (in the sense of [Bplus_no_overflow]), then their sum is + finite. *) + Lemma is_finite_sum_no_overflow' (t : type) : - forall (x y: ftype t) - (Hfinx: Binary.is_finite x = true) - (Hfiny: Binary.is_finite y = true) - (Hov : @Bplus_no_overflow t (FT2R x) (FT2R y)), - Binary.is_finite (BPLUS x y ) = true. + forall (x y : ftype t) + (Hfinx : Binary.is_finite x = true) + (Hfiny : Binary.is_finite y = true) + (Hov : @Bplus_no_overflow t (FT2R x) (FT2R y)), + Binary.is_finite (BPLUS x y) = true. Proof. -intros. -pose proof (Binary.Bplus_correct (fprec t) (femax t) (fprec_gt_0 t) (fprec_lt_femax t) - (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) - BinarySingleNaN.mode_NE x y Hfinx Hfiny ). -unfold Bplus_no_overflow, FT2R in Hov. -apply Rlt_bool_true in Hov. 
-rewrite Hov in H; simpl in H; destruct H as (_ & B & _); simpl; auto. + intros x y Hfinx Hfiny Hov. + pose proof (Binary.Bplus_correct + (fprec t) (femax t) + (fprec_gt_0 t) (fprec_lt_femax t) + (FPCore.plus_nan (fprec t) (femax t) (fprec_gt_one t)) + BinarySingleNaN.mode_NE x y Hfinx Hfiny) as Hcorrect. + unfold Bplus_no_overflow, FT2R in Hov. + apply Rlt_bool_true in Hov. + rewrite Hov in Hcorrect. + simpl in Hcorrect. + destruct Hcorrect as (_ & HB & _). + simpl; auto. Qed. +(** ** Per-element bound for finite summation + + [fun_bnd t n] is the maximum absolute value that each element of an + n-element list may have while still guaranteeing that the + floating-point sum of the list is finite. It is defined as + + << fmax t / (1 + default_rel t) / (1 + INR n * (g t (n - 1) + 1)) >> + + where the function g encodes the standard (1 + eps)^k - 1 rounding-error + growth factor. *) + Definition fun_bnd (t : type) (n : nat) := -fmax t / (1 + @default_rel t) * 1 / (1 + INR n * (@g t (n - 1) + 1)) . - -Lemma rdiv_lt (a b: R) : - 0 < b -> 0 < a -> b < a -> / a < / b. -Proof. -intros. -replace (/b) with (1/b) by nra. -apply Rdiv_lt_right; auto. -rewrite Rmult_comm. -apply Rdiv_lt_left; auto. -nra. -Qed. + (@fmax t) / (1 + @default_rel t) * 1 / (1 + INR n * (@g t (n - 1) + 1)). -Lemma rdiv_le (a b: R) : - 0 < b -> 0 < a -> b <= a -> / a <= / b. -Proof. -intros. -replace (/b) with (1/b) by nra. -apply Rcomplements.Rle_div_r; auto. -rewrite Rmult_comm. -apply Rdiv_le_left; auto. -nra. -Qed. +(** ** Monotonicity of [fun_bnd] -Lemma rdiv_mult_eq : -forall a b, b <> 0 -> a/b = a * (1/b) . + The per-element bound is non-increasing in the list length: a longer + list requires each element to be smaller to keep the sum finite. *) + +Lemma fun_bnd_le (t : type) (n : nat) : + fun_bnd t (S n) <= fun_bnd t n. Proof. -(intros; field_simplify; nra). + unfold fun_bnd. + apply Rmult_le_compat_l. + - rewrite Rmult_1_r. + apply Rcomplements.Rdiv_le_0_compat. + unfold fmax; apply bpow_ge_0. + eapply Rlt_trans with 1; try nra. 
+ apply default_rel_plus_1_gt_1. + - apply rdiv_le; try ( + apply Rplus_lt_le_0_compat; try nra; + apply Rmult_le_pos; [apply pos_INR| ]; + apply Rplus_le_le_0_compat; try nra; + apply g_pos ). + apply Rplus_le_compat_l. + apply Rmult_le_compat; [apply pos_INR | | |]. + apply Rplus_le_le_0_compat; try nra; apply g_pos. + apply le_INR; try lia. + unfold g; field_simplify. + apply Rle_pow. + apply default_rel_plus_1_ge_1. + simpl; lia. Qed. -Lemma fun_bnd_le (t : type) (n : nat) : -fun_bnd t (S n) <= fun_bnd t n. -Proof. -intros; unfold fun_bnd. apply Rmult_le_compat_l. -rewrite Rmult_1_r. -apply Rcomplements.Rdiv_le_0_compat. -unfold fmax; apply bpow_ge_0. -eapply Rlt_trans with 1; try nra. -apply default_rel_plus_1_gt_1. -apply rdiv_le; try ( -apply Rplus_lt_le_0_compat; try nra; -apply Rmult_le_pos; [apply pos_INR| ]; -apply Rplus_le_le_0_compat; try nra; -apply g_pos ). -apply Rplus_le_compat_l. -apply Rmult_le_compat; [apply pos_INR | | |]. -apply Rplus_le_le_0_compat; try nra; apply g_pos. -apply le_INR; try lia. -unfold g; field_simplify. -apply Rle_pow. -apply default_rel_plus_1_ge_1. -simpl; lia. -Qed. +(** ** Main theorem: element-wise bound implies finite sum + If every element of l is finite and sufficiently bounded + then the floating-point sum fs satisfying [sum_rel_Ft l fs] is finite. -Lemma finite_sum_from_bounded : - forall (t: type) (l: list (ftype t)) - (fs : ftype t) - (Hfs: sum_rel_Ft l fs), - (forall x, In x l -> - Binary.is_finite x = true /\ Rabs (FT2R x) < fun_bnd t (length l)) -> - Binary.is_finite fs = true. + The proof proceeds by induction on l, using [fun_bnd_le] to transfer + the per-element bound to the tail, then closing the overflow condition + via explicit rounding-error arithmetic and algebraic manipulation. 
*) + +Lemma finite_sum_from_bounded : + forall (t : type) (l : list (ftype t)) + (fs : ftype t) + (Hfs : sum_rel_Ft l fs), + (forall x, In x l -> + Binary.is_finite x = true /\ Rabs (FT2R x) < fun_bnd t (length l)) -> + Binary.is_finite fs = true. Proof. -intros ?. -induction l. -{ intros; inversion Hfs; subst; simpl; auto. } -{ intros. inversion Hfs; subst. -assert (Hin: forall x : ftype t, - In x l -> Binary.is_finite x = true /\ - Rabs (FT2R x) < fun_bnd t (length l)). - { intros. split; [apply H; simpl; auto | ]. - eapply Rlt_le_trans; [apply H; simpl; auto | ]. - apply fun_bnd_le. } -assert (Hfina : Binary.is_finite a = true) by - (apply H; simpl; auto). -unfold sum. -fold (@sum_rel_Ft NAN t) in H3. -specialize (IHl s H3 Hin). -apply is_finite_sum_no_overflow'; auto. -unfold Bplus_no_overflow. -assert (A: Generic_fmt.generic_format Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (FT2R a) ) by (apply Binary.generic_format_B2R). -assert (B: Generic_fmt.generic_format Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (FT2R s) ) by (apply Binary.generic_format_B2R). -destruct (Plus_error.FLT_plus_error_N_ex Zaux.radix2 (SpecFloat.emin (fprec t) (femax t)) - (fprec t) (fun x0 : Z => negb (Z.even x0)) (FT2R a) (FT2R s) A B) as (d & Hd & Hd'). -unfold Relative.u_ro in Hd. fold (@default_rel t) in Hd. -assert ( H1: Generic_fmt.round Zaux.radix2 (SpecFloat.fexp (fprec t) (femax t)) - (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) - (FT2R a + FT2R s) = Generic_fmt.round Zaux.radix2 - (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) - (Generic_fmt.Znearest (fun x0 : Z => negb (Z.even x0))) - (FT2R a + FT2R s)) by auto. -rewrite <- H1 in Hd'; clear H1; rewrite Hd'; clear Hd'. -destruct (sum_rel_R_exists l s H3) as (rs & Hrs). -destruct (sum_rel_R_abs_exists l s H3) as (rs_abs & Habs). -Search sum_rel_Ft sumF. -assert (H3': s = sumF l) by (apply sum_rel_Ft_fold; auto). 
-assert (IHl': Binary.is_finite (sumF l)) by (rewrite -H3'; auto). -pose proof @fSUM NAN t l IHl' as H1. rewrite <- H3' in H1. -pose proof sum_rel_bound' as C. -pose proof sum_rel_bound'' as D. -rewrite Rabs_minus_sym in H1. -apply Rabs_le_minus in H1. -eapply Rle_lt_trans. -rewrite Rabs_mult. -apply Rmult_le_compat; try apply Rabs_pos. -eapply Rle_trans; [apply Rabs_triang | apply Rplus_le_compat ]. -apply Rlt_le; apply H; simpl; auto. -assert (Rabs (FT2R s) <= (@g t (length l - 1) + 1) * rs_abs). -{ eapply Rle_trans; [apply H1| field_simplify; apply Rplus_le_compat_l]. - eapply Rle_trans; [ eapply sum_rel_R_Rabs; [apply Hrs | apply Habs] |] . - eapply Req_le; eapply sum_rel_R_Rabs_eq; apply Habs. } -eapply Rle_trans. -apply H0. -apply Rmult_le_compat_l. -apply Rplus_le_le_0_compat; try nra. apply g_pos. -apply D. apply Habs. -intros. apply Rlt_le. apply H; simpl; auto. -assert (HD: Rabs (1 + d) <= (1 + default_rel t )). -{ eapply Rle_trans; [apply Rabs_triang| rewrite Rabs_R1; apply Rplus_le_compat_l]. -eapply Rle_trans; [apply Hd |]. -apply Rdiv_le_left. -apply Fourier_util.Rlt_zero_pos_plus1. -apply default_rel_gt_0. -eapply Rle_trans with (default_rel t * 1); try nra. } -apply HD. -(*algebra*) -unfold fun_bnd; rewrite Rmult_1_r. -set (x:= (g t (length (a :: l) - 1) + 1)). -set (y:= (1 + INR (length (a :: l)) * x)). -set (z:= (fmax t / (1 + default_rel t) / y)). -replace ((z + (g t (length l - 1) + 1) * (INR (length l) * z))) - with (z * (1 + (g t (length l - 1) + 1) * (INR (length l)))) - by nra. -rewrite Rmult_comm. -rewrite <- Rmult_assoc. -assert (Hy : 0 < y). -{ unfold y. - apply Rplus_lt_le_0_compat; try nra. - apply Rmult_le_pos; [apply pos_INR|]. - unfold x. apply Rplus_le_le_0_compat; try nra. - apply g_pos. } -assert (Hy' : y <> 0). { apply Stdlib.Rlt_neq_sym; auto. } -assert (H0: (1 + default_rel t) * z = fmax t / y). -{ unfold z; field_simplify; auto. -split; auto. -apply tech_Rplus; try nra. -apply default_rel_gt_0. } -rewrite H0. 
-rewrite rdiv_mult_eq; auto. -replace (bpow Zaux.radix2 (femax t)) with - (bpow Zaux.radix2 (femax t) * 1) by nra. -rewrite Rmult_assoc. -apply Rmult_lt_compat_l. apply bpow_gt_0. -rewrite Rmult_comm. -rewrite <- rdiv_mult_eq; auto. -apply Rdiv_lt_left; auto. -rewrite Rmult_1_l. -unfold y. -apply Rplus_le_lt_compat; try nra. -rewrite Rmult_comm. -eapply Rle_lt_trans with (INR (length l) * x). -apply Rmult_le_compat_l; [apply pos_INR|]. -unfold x. -apply Rplus_le_compat_r. -unfold g. -apply Rplus_le_compat_r. -apply Rle_pow. -apply default_rel_plus_1_ge_1. -simpl; lia. -apply Rmult_lt_compat_r. -unfold x. -apply Rle_lt_0_plus_1; apply g_pos. -apply lt_INR; simpl; try lia. } -Qed. + intros t. + induction l as [| a l IHl]. + - (* Base case: empty list *) + intros fs Hfs _. + inversion Hfs; subst; simpl; auto. + - (* Inductive case: list is [a :: l] *) + intros fs Hfs Hbnd. + inversion Hfs; subst. + + (* Transfer the bound to the tail l using [fun_bnd_le]. *) + assert (Hin : forall x : ftype t, + In x l -> + Binary.is_finite x = true /\ + Rabs (FT2R x) < fun_bnd t (length l)). + { intros x Hx. + split. + - apply Hbnd; simpl; auto. + - eapply Rlt_le_trans. + + apply Hbnd; simpl; auto. + + apply fun_bnd_le. } + (* The head element a is finite since it is in the list *) + assert (Hfina: Binary.is_finite a = true) by + (exact (proj1 (Hbnd a (in_eq a l)))). + fold (@sum_rel_Ft NAN t) in H2. + (* Apply the inductive hypothesis to obtain finiteness of the partial sum s. *) + specialize (IHl s H2 Hin). + + apply is_finite_sum_no_overflow'; auto. + + (* Establish the no-overflow condition for a + s. *) + unfold Bplus_no_overflow. + + (* Generic format witnesses for FT2R a and FT2R s. *) + assert (A : Generic_fmt.generic_format Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (FT2R a)) + by apply Binary.generic_format_B2R. 
+ assert (B : Generic_fmt.generic_format Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (FT2R s)) + by apply Binary.generic_format_B2R. + + (* Obtain rounding error factor << d >> satisfying + round(a + s) = (a + s) * (1 + d) with |d| <= default_rel. *) + destruct (Plus_error.FLT_plus_error_N_ex + Zaux.radix2 + (SpecFloat.emin (fprec t) (femax t)) + (fprec t) + (fun x0 : Z => negb (Z.even x0)) + (FT2R a) (FT2R s) A B) + as (d & Hd & Hd'). + unfold Relative.u_ro in Hd. + fold (@default_rel t) in Hd. + + (* Rewrite the rounding mode to match Hd'. *) + assert (H1 : Generic_fmt.round Zaux.radix2 + (SpecFloat.fexp (fprec t) (femax t)) + (BinarySingleNaN.round_mode BinarySingleNaN.mode_NE) + (FT2R a + FT2R s) + = Generic_fmt.round Zaux.radix2 + (FLT.FLT_exp (SpecFloat.emin (fprec t) (femax t)) (fprec t)) + (Generic_fmt.Znearest (fun x0 : Z => negb (Z.even x0))) + (FT2R a + FT2R s)) + by auto. + rewrite <- H1 in Hd'; clear H1. + rewrite Hd'; clear Hd'. + + (* Witnesses for the real-valued sum and its absolute-value sum. *) + destruct (sum_rel_R_exists l s H2) as (rs & Hrs). + destruct (sum_rel_R_abs_exists l s H2) as (rs_abs & Habs). + + (* Use [fSUM] to bound |FT2R s| in terms of the rounding-error growth + factor g applied to the absolute-value sum rs_abs. *) + assert (H3' : s = sumF (rev l)). + { apply (sum_rel_Ft_fold (rev l)). rewrite revK. exact H2. } + assert (IHl' : Binary.is_finite (sumF (rev l))) by (rewrite <- H3'; auto). + assert (Hrev1 : sumR (map FT2R (rev l)) = rs). + { rewrite map_rev sumR_rev. symmetry. apply sum_rel_R_fold. exact Hrs. } + assert (Hrev2 : sumR (map Rabs (map FT2R (rev l))) = rs_abs). + { rewrite map_rev map_rev sumR_rev. symmetry. apply sum_rel_R_fold. exact Habs. } + pose proof (@fSUM NAN t (rev l) IHl') as Hfsum. + rewrite <- H3' in Hfsum. + rewrite Hrev1 in Hfsum. + rewrite Hrev2 in Hfsum. + rewrite size_rev in Hfsum. + rewrite Rabs_minus_sym in Hfsum. + apply Rabs_le_minus in Hfsum. 
+ + (* Bound |FT2R s| <= (g(n) + 1) * rs_abs. *) + assert (Hs_abs : Rabs (FT2R s) <= + (@g t (length l) + 1) * rs_abs). + { eapply Rle_trans; [apply Hfsum |]. + assert (Hrsle : Rabs rs <= rs_abs). + { eapply Rle_trans. + - eapply sum_rel_R_Rabs; [apply Hrs | apply Habs]. + - eapply Req_le; eapply sum_rel_R_Rabs_eq; apply Habs. } + change @size with @length. + ring_simplify. + apply Rplus_le_compat_l; exact Hrsle. } + + (* Bound |1 + d| <= 1 + default_rel. *) + assert (HD : Rabs (1 + d) <= 1 + @default_rel t). + { eapply Rle_trans; [apply Rabs_triang |]. + rewrite Rabs_R1. + apply Rplus_le_compat_l. + eapply Rle_trans; [apply Hd |]. + apply Rdiv_le_left. + - apply Fourier_util.Rlt_zero_pos_plus1, default_rel_gt_0. + - eapply Rle_trans with (@default_rel t * 1); try nra. } + + (* Combine bounds: |(a + s)*(1+d)| <= (1 + default_rel) * (|a| + |s|) + and then close using [fun_bnd] algebra. *) + eapply Rle_lt_trans. + { rewrite Rabs_mult. + apply Rmult_le_compat; try apply Rabs_pos. + - eapply Rle_trans; [apply Rabs_triang |]. + apply Rplus_le_compat. + + apply Rlt_le, Hbnd; simpl; auto. + + eapply Rle_trans; [apply Hs_abs |]. + apply Rmult_le_compat_l. + * apply Rplus_le_le_0_compat; try nra; apply g_pos. + * apply sum_rel_bound''. + -- apply Habs. + -- intros x Hx; apply Rlt_le, Hbnd; simpl; auto. + - apply HD. } + + (* Pure algebraic closure: the accumulated bound fits within [fmax t]. *) + unfold fun_bnd; rewrite Rmult_1_r. + assert (Heq_sub : (length (a :: l) - 1)%nat = length l) + by (simpl; rewrite subSS subn0; reflexivity). + rewrite Heq_sub. + set (x := @g t (length l) + 1). + set (y := 1 + INR (length (a :: l)) * x). + set (z := (@fmax t) / (1 + @default_rel t) / y). + change @size with @length. + + replace (z + x * (INR (length l) * z)) + with (z * (1 + x * INR (length l))) + by nra. + assert (Hy : 0 < y). + { unfold y. + apply Rplus_lt_le_0_compat; try nra. + apply Rmult_le_pos; [apply pos_INR |]. + unfold x; apply Rplus_le_le_0_compat; try nra; apply g_pos. 
} + assert (Hy' : y <> 0) by (apply Stdlib.Rlt_neq_sym; auto). + + (* Simplify (1 + default_rel) * z = fmax / y. *) + assert (H0 : (1 + @default_rel t) * z = (@fmax t) / y). + { unfold z; field_simplify; auto. + split; auto. + apply tech_Rplus; try nra. + apply default_rel_gt_0. } + + (* 1 + x * INR (length l) < y. *) + assert (Hineq : 1 + x * INR (length l) < y). + { unfold y. + apply Rplus_le_lt_compat; try nra. + rewrite Rmult_comm. + apply Rmult_lt_compat_r; [unfold x; apply Rle_lt_0_plus_1, g_pos |]. + apply lt_INR; simpl; lia. } + + (* Rewrite via H0: z*(1+M)*(1+D) = (fmax/y)*(1+M). *) + replace (z * (1 + x * INR (length l)) * (1 + @default_rel t)) + with ((@fmax t) / y * (1 + x * INR (length l))) + by (rewrite <- H0; ring). + unfold fmax. + apply Rlt_le_trans with (bpow Zaux.radix2 (femax t) / y * y). + { apply Rmult_lt_compat_l; [| exact Hineq]. + unfold Rdiv; apply Rmult_lt_0_compat; [apply bpow_gt_0 | apply Rinv_pos; exact Hy]. } + unfold Rdiv; rewrite Rmult_assoc; rewrite Rinv_l; [lra | exact Hy']. +Qed. End NAN. \ No newline at end of file diff --git a/accuracy_proofs/sum_model.v b/accuracy_proofs/sum_model.v index f28d811..07043c3 100644 --- a/accuracy_proofs/sum_model.v +++ b/accuracy_proofs/sum_model.v @@ -1,506 +1,729 @@ -(* This file contains floating point functional models for the summation of - lists, as well as theorems regarding their equivalence. *) +(** * Floating-Point Summation: Functional Models and Equivalences -From LAProof.accuracy_proofs Require Import preamble common. + This file defines and relates several functional models for floating-point + summation of lists. It provides both real-valued and floating-point + relational specifications of summation, as well as a non-deterministic + "any-order" floating-point summation predicate. The key contributions are: + - [sum_rel]: A general inductive relation specifying summation over a list + given a default element and a binary operation. 
Instantiated to both + real ([sum_rel_R]) and floating-point ([sum_rel_Ft]) arithmetic. + + - [sum_any'], [sum_any]: An inductive predicate capturing floating-point + summation in _any_ order and _any_ binary tree structure, modulo + permutation. This is useful for reasoning about implementations that + reorder or restructure summation for efficiency. + + - [sumR]: A fold-based functional definition of real summation, shown + equivalent to [sum_rel_R]. + + - [sumF]: A fold-based functional definition of floating-point summation, + shown equivalent to [sum_rel_Ft]. + + Key lemmas include: + + - [sum_rel_sum_any]: Every [sum_rel] summation can be realized as a + [sum_any] summation (up to floating-point equality [feq]). + + - [sum_rel_R_abs], [sum_rel_R_Rabs]: Bounds on the absolute value of a + real sum in terms of the sum of absolute values of its elements. + + - [sum_rel_bound], [sum_rel_bound'], [sum_rel_bound'']: Uniform bounds + on the magnitude of a real sum given elementwise bounds. + + - [sum_rel_R_permute], [sumR_permute]: Summation is invariant under + permutation of the input list. + + - [subtract_loop_sum_any]: The subtraction loop idiom (used in, e.g., + Cholesky decomposition implementations) can be related to [sum_any], + enabling accuracy theorems for [sum_rel] to transfer to subtraction loops. +*) + +From LAProof.accuracy_proofs Require Import preamble common float_acc_lems. Require Import Permutation. -Inductive sum_rel {A : Type} (default: A) (sum_op : A -> A -> A) : list A -> A -> Prop := -| sum_rel_nil : sum_rel default sum_op [] default -| sum_rel_cons : forall l a s, - sum_rel default sum_op l s -> - sum_rel default sum_op (a::l) (sum_op a s). +(** ** General Summation Relation -Definition sum_rel_R := @sum_rel R 0%R Rplus. + [sum_rel default sum_op l s] holds when << s >> is the result of folding + [sum_op] over the list << l >> from the right, starting from << default >>. 
+ The empty list yields << default >>, and a cons << a :: l >> yields + [sum_op a s] where << s >> is the sum of << l >>. *) + +Inductive sum_rel {A : Type} (default : A) (sum_op : A -> A -> A) + : list A -> A -> Prop := + | sum_rel_nil : sum_rel default sum_op [] default + | sum_rel_cons : forall l a s, + sum_rel default sum_op l s -> + sum_rel default sum_op (a :: l) (sum_op a s). -Inductive sum_any' {NAN: FPCore.Nans} {t}: forall (h: nat) (v: list (ftype t)) (s: ftype t), Prop := -| Sum_Any_1: forall x, sum_any' O [x] x -| Sum_Any_split: forall n1 n2 al bl a b, - sum_any' n1 al a -> sum_any' n2 bl b -> sum_any' (S (Nat.max n1 n2)) (al++bl) (BPLUS a b) -| Sum_Any_perm: forall n al bl s, Permutation al bl -> sum_any' n al s -> sum_any' n bl s. +(** [sum_rel_R] is [sum_rel] instantiated to real-number addition. *) -Inductive sum_any {NAN: FPCore.Nans} {t:type}: forall (h: nat) (v: list (ftype t)) (s: ftype t), Prop := -| Sum_Any_None: sum_any O nil pos_zero -| Sum_Any_Some: forall n v s, sum_any' n v s -> sum_any n v s. +Definition sum_rel_R := @sum_rel R 0%R Rplus. -Lemma sum_rel_sum_any: forall {NAN: FPCore.Nans} {t} z (v: list (ftype t)) s (Hz: iszero z), - sum_rel z BPLUS v s -> - exists s', feq s s' /\ sum_any (Nat.pred (size v)) v s'. +(** ** Any-Order Floating-Point Summation + + [sum_any' h v s] captures the idea that a list << v >> of floating-point + values can be summed in _any_ binary tree structure of depth at most << h >>, + up to reordering. The constructors are: + + - [Sum_Any_1]: A singleton list trivially sums to its element. + - [Sum_Any_split]: Two sublists can be summed independently and their + results combined with [BPLUS]. + - [Sum_Any_perm]: The summation result is invariant under permutation + of the input list. 
*) + +Inductive sum_any' {NAN : FPCore.Nans} {t} : + forall (h : nat) (v : list (ftype t)) (s : ftype t), Prop := + | Sum_Any_1 : forall x, + sum_any' O [x] x + | Sum_Any_split : forall n1 n2 al bl a b, + sum_any' n1 al a -> + sum_any' n2 bl b -> + sum_any' (S (Nat.max n1 n2)) (al ++ bl) (BPLUS a b) + | Sum_Any_perm : forall n al bl s, + Permutation al bl -> + sum_any' n al s -> + sum_any' n bl s. + +(** [sum_any h v s] extends [sum_any'] to handle the empty list, which sums + to << pos_zero >>. *) + +Inductive sum_any {NAN : FPCore.Nans} {t : type} : + forall (h : nat) (v : list (ftype t)) (s : ftype t), Prop := + | Sum_Any_None : sum_any O nil pos_zero + | Sum_Any_Some : forall n v s, + sum_any' n v s -> + sum_any n v s. + +(** ** Equivalence Between [sum_rel] and [sum_any] + + Every [sum_rel] floating-point summation (starting from a zero value) + can be realized as a [sum_any] summation, up to floating-point equality. + The height of the [sum_any] tree is << Nat.pred (size v) >>. *) + +Lemma sum_rel_sum_any : + forall {NAN : FPCore.Nans} {t} z (v : list (ftype t)) s + (Hz : iszero z), + sum_rel z BPLUS v s -> + exists s', feq s s' /\ sum_any (Nat.pred (size v)) v s'. Proof. -destruct v; intros. -- -destruct z; try discriminate; clear Hz; -inversion H; clear H; subst; (eexists; split; [ | constructor]; reflexivity). -- -revert f s z Hz H; induction v; simpl; intros. -+ -inversion H; clear H; subst. -inversion H3; clear H3; subst. -destruct s0; try discriminate. -destruct s; -(eexists; split; [ | constructor; constructor]; -destruct f; try reflexivity; -destruct s; reflexivity). -+ -inversion H; clear H; subst. -specialize (IHv a s0 z Hz H3). -change (cons f (cons a v)) with ([f] ++ cons a v). -replace (S (size v)) with (S (Nat.max O (size v))) by lia. -destruct IHv as [s1 [? ?]]. -eexists. -inversion H0; clear H0; subst. -simpl in H1. -split. -2:{ constructor 2. -eapply Sum_Any_split; auto. -apply Sum_Any_1. -eassumption. -} -clear z Hz H3 H1. -rewrite H; auto. 
+ destruct v; intros. + - (* empty list *) + destruct z; try discriminate; clear Hz; + inversion H; clear H; subst; + (eexists; split; [ | constructor]; reflexivity). + - (* non-empty list *) + revert f s z Hz H; induction v; simpl; intros. + + (* singleton list *) + inversion H; clear H; subst. + inversion H3; clear H3; subst. + destruct s0; try discriminate. + destruct s; + (eexists; split; [ | constructor; constructor]; + destruct f; try reflexivity; + destruct s; reflexivity). + + (* cons case *) + inversion H; clear H; subst. + specialize (IHv a s0 z Hz H3). + change (cons f (cons a v)) with ([f] ++ cons a v). + replace (S (size v)) with (S (Nat.max O (size v))) by lia. + destruct IHv as [s1 [Hfeq Hany]]. + eexists. + inversion Hany; clear Hany; subst. + simpl in H. + split. + 2: { constructor 2. + eapply Sum_Any_split; auto. + apply Sum_Any_1. + eassumption. } + clear z Hz H3 H. + rewrite Hfeq; auto. Qed. +(** ** Bounds via Absolute Value Summation *) + +(** A real sum is bounded above by the sum of absolute values of its + elements. *) + Lemma sum_rel_R_abs : -forall l s1 s2, -sum_rel_R l s1 -> sum_rel_R (map Rabs l) s2 -> s1 <= s2. + forall l s1 s2, + sum_rel_R l s1 -> + sum_rel_R (map Rabs l) s2 -> + s1 <= s2. Proof. -induction l. -- -intros. -inversion H. -inversion H0. -nra. -- -intros. -inversion H; subst; clear H. -inversion H0; subst; clear H0. -eapply Rplus_le_compat. -apply Rle_abs. -fold sum_rel_R in H4. -fold sum_rel_R in H3. -apply IHl; -auto. + induction l. + - intros s1 s2 H1 H2. + inversion H1; inversion H2; nra. + - intros s1 s2 H1 H2. + inversion H1; subst; clear H1. + inversion H2; subst; clear H2. + eapply Rplus_le_compat; [apply Rle_abs |]. + fold sum_rel_R in H4, H3. + apply IHl; auto. Qed. -Lemma sum_rel_R_Rabs_pos : -forall l s, -sum_rel_R (map Rabs l) s -> 0 <= s. +(** The sum of absolute values is non-negative. *) + +Lemma sum_rel_R_Rabs_pos : + forall l s, + sum_rel_R (map Rabs l) s -> + 0 <= s. Proof. -induction l. -- -intros. 
-inversion H; compute; nra. -- -intros. -inversion H; subst; clear H. -fold sum_rel_R in H3. -specialize (IHl s0 H3). -apply Rplus_le_le_0_compat; auto; - try apply Rabs_pos. + induction l. + - intros s H. + inversion H; compute; nra. + - intros s H. + inversion H; subst; clear H. + fold sum_rel_R in H3. + specialize (IHl s0 H3). + apply Rplus_le_le_0_compat; auto; + apply Rabs_pos. Qed. +(** The sum of absolute values equals its own absolute value (i.e., it is + non-negative, so [Rabs s = s]). *) + Lemma sum_rel_R_Rabs_eq : -forall l s, -sum_rel_R (map Rabs l) s -> Rabs s = s. + forall l s, + sum_rel_R (map Rabs l) s -> + Rabs s = s. Proof. -induction l. -- -intros. -inversion H. -apply Rabs_R0. -- -intros. -inversion H; subst; clear H. -replace (Rabs(Rabs a + s0)) with - (Rabs a + s0); try nra. -symmetry. -rewrite Rabs_pos_eq; try nra. -apply Rplus_le_le_0_compat. -apply Rabs_pos. -eapply Rle_trans with (Rabs s0). -apply Rabs_pos. -eapply Req_le. -apply IHl. -fold sum_rel_R in H3. -auto. + induction l. + - intros s H. + inversion H. + apply Rabs_R0. + - intros s H. + inversion H; subst; clear H. + replace (Rabs (Rabs a + s0)) with (Rabs a + s0); try nra. + symmetry. + rewrite Rabs_pos_eq; try nra. + apply Rplus_le_le_0_compat. + + apply Rabs_pos. + + eapply Rle_trans with (Rabs s0). + * apply Rabs_pos. + * eapply Req_le. + fold sum_rel_R in H3. + apply IHl; auto. Qed. - +(** The absolute value of any real sum is bounded by the sum of absolute + values: [Rabs s1 <= Rabs s2] when << s2 >> is the sum of [Rabs] applied + elementwise to << l >>. *) + Lemma sum_rel_R_Rabs : -forall l s1 s2, -sum_rel_R l s1 -> sum_rel_R (map Rabs l) s2 -> Rabs s1 <= Rabs s2. + forall l s1 s2, + sum_rel_R l s1 -> + sum_rel_R (map Rabs l) s2 -> + Rabs s1 <= Rabs s2. Proof. -induction l. -- -intros. -inversion H. -inversion H0. -nra. -- -intros. -inversion H; subst; clear H. -inversion H0; subst; clear H0. -fold sum_rel_R in H4. -fold sum_rel_R in H3. -eapply Rle_trans. -apply Rabs_triang. 
-replace (Rabs(Rabs a + s0)) with - (Rabs a + s0). -eapply Rplus_le_compat; try nra. -eapply Rle_trans with (Rabs s0). -fold sum_rel_R in H4. -fold sum_rel_R in H3. -apply IHl; auto. -apply Req_le. -eapply sum_rel_R_Rabs_eq; apply H3. -symmetry. -rewrite Rabs_pos_eq; try nra. -apply Rplus_le_le_0_compat. -apply Rabs_pos. -eapply Rle_trans with (Rabs s0). -apply Rabs_pos. -apply Req_le. -eapply sum_rel_R_Rabs_eq; apply H3. + induction l. + - intros s1 s2 H1 H2. + inversion H1; inversion H2; nra. + - intros s1 s2 H1 H2. + inversion H1; subst; clear H1. + inversion H2; subst; clear H2. + fold sum_rel_R in H4, H3. + eapply Rle_trans; [apply Rabs_triang |]. + replace (Rabs (Rabs a + s0)) with (Rabs a + s0). + + eapply Rplus_le_compat; try nra. + eapply Rle_trans with (Rabs s0). + * apply IHl; auto. + * apply Req_le. + eapply sum_rel_R_Rabs_eq; apply H3. + + symmetry. + rewrite Rabs_pos_eq; try nra. + apply Rplus_le_le_0_compat. + * apply Rabs_pos. + * eapply Rle_trans with (Rabs s0). + -- apply Rabs_pos. + -- apply Req_le. + eapply sum_rel_R_Rabs_eq; apply H3. Qed. +(** ** Singleton Summation Lemmas *) + +(** A [sum_rel_R] derivation for a singleton list determines the sum uniquely. *) + Lemma sum_rel_R_single : -forall (a : R) (fs : R), sum_rel_R [a] fs -> fs = a. + forall (a : R) (fs : R), + sum_rel_R [a] fs -> + fs = a. Proof. -intros. -inversion H; auto. -inversion H3. subst. -apply Rplus_0_r. + intros a fs H. + inversion H; subst; auto. + inversion H3; subst. + apply Rplus_0_r. Qed. +(** A real value is the [sum_rel_R] of its own singleton list. *) + Lemma sum_rel_R_single' : -forall (a : R) , sum_rel_R [a] a. + forall (a : R), + sum_rel_R [a] a. Proof. -intros. -unfold sum_rel_R. -replace a with (a + 0) at 2 by nra. -apply sum_rel_cons. apply sum_rel_nil. -Qed. + intros a. + unfold sum_rel_R. + replace a with (a + 0) at 2 by nra. + apply sum_rel_cons. + apply sum_rel_nil. +Qed. 
+(** ** Permutation Invariance *) + +(** Inserting an element into an arbitrary position in the middle of a + split list preserves the [sum_rel_R] sum (with the element added). *) + Lemma sum_rel_R_app_cons : -forall l' l'' a s, -sum_rel_R (l' ++ l'') s -> -sum_rel_R (l' ++ a :: l'') (a + s). + forall l' l'' a s, + sum_rel_R (l' ++ l'') s -> + sum_rel_R (l' ++ a :: l'') (a + s). Proof. -induction l'; simpl. -{ intros; apply sum_rel_cons; auto. } -intros. -inversion H; subst; clear H. -specialize (IHl' l'' a0 s0 H3). -replace (a0 + (a + s0)) with (a + (a0 + s0)) by nra. -apply sum_rel_cons; auto. + induction l'; simpl. + - intros l'' a s H. + apply sum_rel_cons; auto. + - intros l'' a0 s H. + inversion H; subst; clear H. + specialize (IHl' l'' a0 s0 H3). + replace (a0 + (a + s0)) with (a + (a0 + s0)) by nra. + apply sum_rel_cons; auto. Qed. -Lemma sum_rel_bound : - forall (l : list R) (rs a: R) - (Hrs : sum_rel_R l rs) - (Hin : forall x, In x l -> Rabs x <= a), - Rabs rs <= INR (size l) * a. -Proof. -induction l; intros. -{ inversion Hrs; subst; simpl; rewrite Rabs_R0; nra. } - inversion Hrs; subst. - eapply Rle_trans; [apply Rabs_triang|]. - eapply Rle_trans; [apply Rplus_le_compat; - [apply Hin; simpl; auto| apply IHl; - [ apply H2 | intros; apply Hin; simpl; auto ] ] | ]. - apply Req_le. replace (size (a :: l)) with (size l + 1)%nat by (simpl; lia). - rewrite plus_INR; simpl; nra. -Qed. - + +(** [sum_rel_R] is invariant under list permutation. *) + Lemma sum_rel_R_permute : - forall (l l0: list R) - (Hper: Permutation l l0) (rs: R) - (Hrs: sum_rel_R l rs), - sum_rel_R l0 rs. + forall (l l0 : list R) + (Hper : Permutation l l0) + (rs : R) + (Hrs : sum_rel_R l rs), + sum_rel_R l0 rs. Proof. -intros ?. -induction l. -{ intros; inversion Hrs; subst. -apply Permutation_nil in Hper; subst; simpl; auto. } -intros. -apply Permutation_sym in Hper. -pose proof Permutation_vs_cons_inv Hper as H. -destruct H as (l' & l'' & H); subst. -apply Permutation_sym in Hper. 
-pose proof (@Permutation_cons_app_inv R l l' l'' a Hper). -inversion Hrs; subst. fold sum_rel_R in H3. -specialize (IHl (l' ++ l'') H s H3). -clear Hrs. -apply sum_rel_R_app_cons; auto. + intros l. + induction l. + - intros l0 Hper rs Hrs. + inversion Hrs; subst. + apply Permutation_nil in Hper; subst; simpl; auto. + - intros l0 Hper rs Hrs. + apply Permutation_sym in Hper. + pose proof Permutation_vs_cons_inv Hper as Hinv. + destruct Hinv as (l' & l'' & Heq); subst. + apply Permutation_sym in Hper. + pose proof (@Permutation_cons_app_inv R l l' l'' a Hper) as Hperm'. + inversion Hrs; subst. + fold sum_rel_R in H2. + specialize (IHl (l' ++ l'') Hperm' s H2). + clear Hrs. + apply sum_rel_R_app_cons; auto. Qed. +(** [sum_rel_R] over mapped floating-point values is invariant under + permutation of the floating-point list. *) + Lemma sum_rel_R_permute_t : - forall (t: type) (l l0: list (ftype t)) - (Hper: Permutation l l0) (rs: R) - (Hrs: sum_rel_R (map FT2R l) rs), - sum_rel_R (map FT2R l0) rs. + forall (t : type) (l l0 : list (ftype t)) + (Hper : Permutation l l0) + (rs : R) + (Hrs : sum_rel_R (map FT2R l) rs), + sum_rel_R (map FT2R l0) rs. +Proof. + intros t l l0 Hper rs Hrs. + apply sum_rel_R_permute with (map FT2R l); auto. + apply Permutation_map; auto. +Qed. + +(** ** Uniform Bound on Sum Magnitude + + Given a pointwise bound << a >> on the absolute values of list elements, + the magnitude of the sum is bounded by [INR] << (size l) * a >>. *) + +Lemma sum_rel_bound : + forall (l : list R) (rs a : R) + (Hrs : sum_rel_R l rs) + (Hin : forall x, In x l -> Rabs x <= a), + Rabs rs <= INR (size l) * a. Proof. -intros; -apply sum_rel_R_permute with (map FT2R l); auto. -apply Permutation_map; auto. + induction l; intros rs a1 Hrs Hin. + - inversion Hrs; subst; simpl; rewrite Rabs_R0; nra. + - inversion Hrs; subst. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans. + + apply Rplus_le_compat. + * apply Hin; simpl; auto. 
+ * apply IHl; [apply H2 | intros x Hx; apply Hin; simpl; auto]. + + apply Req_le. + replace (size (a :: l)) with (size l + 1)%nat by (simpl; lia). + rewrite plus_INR; simpl; nra. Qed. +(** ** [sumR]: Functional Real Summation + + [sumR] is a computable fold-right definition of real summation. *) + Definition sumR := foldr Rplus 0. -Lemma sumRabs_pos x : -0 <= sumR (map Rabs x). +(** The sum of absolute values via [sumR] is non-negative. *) + +Lemma sumRabs_pos : + forall x, + 0 <= sumR (map Rabs x). Proof. -induction x; simpl; try nra. -apply Rplus_le_le_0_compat; [apply Rabs_pos | nra]. + induction x; simpl; try nra. + apply Rplus_le_le_0_compat; [apply Rabs_pos | nra]. Qed. -Lemma sumRabs_Rabs x : -Rabs (sumR (map Rabs x)) = sumR (map Rabs x). -Proof. rewrite Rabs_pos_eq; auto. apply sumRabs_pos. Qed. +(** The absolute value of [sumR] of absolute values equals itself. *) -Lemma sumR_mult x a : -sumR x * a = sumR (map (Rmult a) x). -Proof. induction x; simpl; nra. Qed. +Lemma sumRabs_Rabs : + forall x, + Rabs (sumR (map Rabs x)) = sumR (map Rabs x). +Proof. + intros x. + rewrite Rabs_pos_eq; auto. + apply sumRabs_pos. +Qed. -Lemma sumR_le_sumRabs x : -Rabs (sumR x) <= Rabs (sumR (map Rabs x)). +(** Scalar multiplication distributes over [sumR]. *) + +Lemma sumR_mult : + forall x a, + sumR x * a = sumR (map (Rmult a) x). +Proof. + induction x; simpl; intros a1. + - nra. + - rewrite <- IHx; nra. +Qed. + +(** The absolute value of a real sum is bounded by the sum of absolute + values. *) + +Lemma sumR_le_sumRabs : + forall x, + Rabs (sumR x) <= Rabs (sumR (map Rabs x)). Proof. -induction x; simpl; [nra | ]. -rewrite sumRabs_Rabs in IHx. -eapply Rle_trans. -2: rewrite Rabs_pos_eq. -apply Rabs_triang. -apply Rplus_le_compat_l; auto. -apply Rplus_le_le_0_compat; -[apply Rabs_pos| apply sumRabs_pos]. + induction x; simpl; [nra |]. + rewrite sumRabs_Rabs in IHx. + eapply Rle_trans. + 2: rewrite Rabs_pos_eq. + - apply Rabs_triang. + - apply Rplus_le_compat_l; auto. 
+ - apply Rplus_le_le_0_compat; + [apply Rabs_pos | apply sumRabs_pos]. Qed. -Lemma sumR_app_cons l' l'' a: -a + sumR (l' ++ l'') = sumR (l' ++ a :: l''). -Proof. induction l'; simpl; [nra | rewrite <- IHl'; nra]. Qed. +(** Inserting an element at an arbitrary position in a split list preserves + the [sumR] value (with that element added). *) + +Lemma sumR_app_cons : + forall l' l'' a, + a + sumR (l' ++ l'') = sumR (l' ++ a :: l''). +Proof. + induction l'; simpl. + - intros; nra. + - intros; rewrite <- IHl'; nra. +Qed. + +(** [sumR] is invariant under list permutation. *) Lemma sumR_permute : - forall x x0 (Hper: Permutation x x0) , - sumR x = sumR x0. + forall x x0 + (Hper : Permutation x x0), + sumR x = sumR x0. Proof. -intros ?. -induction x; intros. -{ apply Permutation_nil in Hper; subst; simpl; auto. } -apply Permutation_sym in Hper. -pose proof Permutation_vs_cons_inv Hper as H. -destruct H as (l' & l'' & H); subst. -apply Permutation_sym in Hper. -pose proof (@Permutation_cons_app_inv R x l' l'' a Hper). -specialize (IHx (l' ++ l'') H ). -simpl. rewrite IHx sumR_app_cons; auto. + intros x. + induction x; intros x0 Hper. + - apply Permutation_nil in Hper; subst; simpl; auto. + - apply Permutation_sym in Hper. + pose proof Permutation_vs_cons_inv Hper as Hinv. + destruct Hinv as (l' & l'' & Heq); subst. + apply Permutation_sym in Hper. + pose proof (@Permutation_cons_app_inv R x l' l'' a Hper) as Hperm'. + specialize (IHx (l' ++ l'') Hperm'). + simpl. + rewrite IHx sumR_app_cons; auto. Qed. -Lemma sumR_rev: forall l, sumR (rev l) = sumR l. +(** [sumR] is invariant under list reversal. *) + +Lemma sumR_rev : + forall l, + sumR (rev l) = sumR l. Proof. -move => l. -apply sumR_permute. -rewrite rev_list_rev. -apply Permutation_sym. -apply Permutation_rev. + move => l. + apply sumR_permute. + rewrite rev_list_rev. + apply Permutation_sym. + apply Permutation_rev. Qed. 
-Lemma sum_rel_bound' : - forall (t : type) (l : list (ftype t)) (rs a: R) - (Hrs : sum_rel_R (map FT2R l) rs) - (Hin : forall x, In x l -> Rabs (FT2R x) <= a), - Rabs rs <= INR (size l) * a. +(** ** Uniform Bounds on Floating-Point Sums via Real Arithmetic *) + +(** Bound on [sum_rel_R] over [FT2R]-mapped floating-point lists, given + a pointwise bound on [Rabs (FT2R x)]. *) + +Lemma sum_rel_bound' : + forall (t : type) (l : list (ftype t)) (rs a : R) + (Hrs : sum_rel_R (map FT2R l) rs) + (Hin : forall x, In x l -> Rabs (FT2R x) <= a), + Rabs rs <= INR (size l) * a. Proof. -induction l; intros. -{ inversion Hrs; subst; simpl; rewrite Rabs_R0; nra. } - inversion Hrs; subst. - eapply Rle_trans; [apply Rabs_triang|]. - eapply Rle_trans; [apply Rplus_le_compat; - [apply Hin; simpl; auto| apply IHl; - [ apply H2 | intros; apply Hin; simpl; auto ] ] | ]. - apply Req_le. replace (size (a :: l)) with (size l + 1)%nat by (simpl; lia). - rewrite plus_INR; simpl; nra. + induction l; intros rs a1 Hrs Hin. + - inversion Hrs; subst; simpl; rewrite Rabs_R0; nra. + - inversion Hrs; subst. + eapply Rle_trans; [apply Rabs_triang |]. + eapply Rle_trans. + + apply Rplus_le_compat. + * apply Hin; simpl; auto. + * apply IHl; [apply H2 | intros x Hx; apply Hin; simpl; auto]. + + apply Req_le. + replace (size (a :: l)) with (size l + 1)%nat by (simpl; lia). + rewrite plus_INR; simpl; nra. Qed. -Lemma sum_rel_bound'' : - forall (t : type) (l : list (ftype t)) (rs_abs a: R) - (Hrs : sum_rel_R (map Rabs (map FT2R l)) rs_abs) - (Hin : forall x, In x l -> Rabs (FT2R x) <= a), - rs_abs <= INR (size l) * a. +(** Bound on the sum of absolute values of [FT2R]-mapped floating-point + list elements, given a pointwise bound. *) +Lemma sum_rel_bound'' : + forall (t : type) (l : list (ftype t)) (rs_abs a : R) + (Hrs : sum_rel_R (map Rabs (map FT2R l)) rs_abs) + (Hin : forall x, In x l -> Rabs (FT2R x) <= a), + rs_abs <= INR (size l) * a. Proof. -induction l; intros. -{ inversion Hrs; subst; simpl. 
compute. nra. } - inversion Hrs; subst. - fold sum_rel_R in H2. - eapply Rle_trans; [apply Rplus_le_compat; - [apply Hin; simpl; auto| apply IHl; - [ apply H2 | intros; apply Hin; simpl; auto ] ] | ]. - apply Req_le. replace (size (a :: l)) with (size l + 1)%nat by (simpl; lia). - rewrite plus_INR; simpl; nra. -Qed. - -Lemma sum_rel_R_fold : forall l rs, - sum_rel_R l rs -> rs = sumR l. -Proof. -induction l. -intros; inversion H; simpl; auto. -intros; inversion H. -fold sum_rel_R in H3. -specialize (IHl s H3). -subst; simpl. -auto. + induction l; intros rs_abs ? Hrs Hin. + - inversion Hrs; subst; simpl; compute; nra. + - inversion Hrs; subst. + fold sum_rel_R in H2. + eapply Rle_trans. + + apply Rplus_le_compat. + * apply Hin; simpl; auto. + * apply IHl; [apply H2 | intros x Hx; apply Hin; simpl; auto]. + + apply Req_le. + replace (size (a :: l)) with (size l + 1)%nat by (simpl; lia). + rewrite plus_INR; simpl; nra. Qed. -Lemma sum_map_Rmult (l : list R) (s a: R): -sum_rel_R l s -> -sum_rel_R (map (Rmult a) l) (a * s). -Proof. -revert l s a. induction l. -{ intros. simpl. inversion H; subst; rewrite Rmult_0_r; auto. } -intros. inversion H. destruct l. -{ simpl. inversion H3; subst. rewrite Rplus_0_r. - apply sum_rel_R_single'. } -fold sum_rel_R in H3. specialize (IHl s0 a0 H3). -simpl. rewrite Rmult_plus_distr_l; apply sum_rel_cons. -fold sum_rel_R. simpl in IHl; auto. +(** [sum_rel_R] and [sumR] agree: a [sum_rel_R] derivation yields the same + value as [sumR]. *) + +Lemma sum_rel_R_fold : + forall l rs, + sum_rel_R l rs -> + rs = sumR l. +Proof. + induction l. + - intros rs H. + inversion H; simpl; auto. + - intros rs H. + inversion H; subst. + fold sum_rel_R in H3. + specialize (IHl s H3). + subst; simpl; auto. Qed. -Section WithSTD. -Context {NAN: FPCore.Nans} {t : type}. +(** Scalar multiplication distributes over [sum_rel_R]: if << l >> sums to << s >>, + then [map (Rmult a) l] sums to << a * s >>. 
*) + +Lemma sum_map_Rmult : + forall (l : list R) (s a : R), + sum_rel_R l s -> + sum_rel_R (map (Rmult a) l) (a * s). +Proof. + induction l; intros s ? H; simpl. + - inversion H; subst; rewrite Rmult_0_r; auto. + - inversion H; subst. + destruct l. + + simpl. inversion H3; subst. rewrite Rplus_0_r. + apply sum_rel_R_single'. + + fold sum_rel_R in H3. + specialize (IHl s0 a0 H3). + simpl. + rewrite Rmult_plus_distr_l. + apply sum_rel_cons. + fold sum_rel_R. + simpl in IHl; auto. +Qed. -Definition sum_rel_Ft := @sum_rel (ftype t) neg_zero (BPLUS ). +(** ** Floating-Point Summation Instances and Properties *) -Lemma sum_rel_Ft_single fs a: -Binary.is_finite fs = true -> -sum_rel_Ft [a] fs -> fs = a. +Section WithSTD. +Context {NAN : FPCore.Nans} {t : type}. + +(** [sum_rel_Ft] is [sum_rel] instantiated to floating-point addition with + default value << neg_zero >>. *) + +Definition sum_rel_Ft := @sum_rel (ftype t) neg_zero BPLUS. + +(** For a finite floating-point value << fs >>, a [sum_rel_Ft] derivation for a + singleton list determines the sum uniquely. *) + +Lemma sum_rel_Ft_single : + forall (fs a : ftype t), + Binary.is_finite fs = true -> + sum_rel_Ft [a] fs -> + fs = a. Proof. -move => FIN Hs. -move: FIN. -inversion Hs; subst. -inversion H2; subst. -rewrite /sum/BPLUS/BINOP - /neg_zero. -move => FIN. -destruct a; - try discriminate FIN => //; -destruct s => //. + move => fs a Hfin Hs. + inversion Hs; subst. + inversion H2; subst. + rewrite /sum /BPLUS /BINOP /neg_zero. + move: Hfin. + destruct a; + try discriminate Hfin => //; + destruct s => //. Qed. +(** For any floating-point list and [sum_rel_Ft] derivation, there exists a + corresponding real-valued sum under [sum_rel_R]. *) + Lemma sum_rel_R_exists : forall (l : list (ftype t)) (fs : ftype t) - (Hfs : sum_rel_Ft l fs), - exists rs, sum_rel_R (map FT2R l) rs. + (Hfs : sum_rel_Ft l fs), + exists rs, sum_rel_R (map FT2R l) rs. Proof. -intros ?. induction l. -{ simpl; exists 0. apply sum_rel_nil. } -intros. 
inversion Hfs; subst. -fold sum_rel_Ft in H2. -destruct (IHl s H2) as (rs & Hrs); clear IHl. -exists (FT2R a + rs); simpl. -apply sum_rel_cons; auto. + intros l. + induction l. + - simpl; intros fs Hfs. + exists 0. apply sum_rel_nil. + - intros fs Hfs. + inversion Hfs; subst. + fold sum_rel_Ft in H2. + destruct (IHl s H2) as (rs & Hrs); clear IHl. + exists (FT2R a + rs); simpl. + apply sum_rel_cons; auto. Qed. -Lemma sum_rel_R_abs_exists: +(** For any floating-point list and [sum_rel_Ft] derivation, there exists a + corresponding sum of absolute values under [sum_rel_R]. *) + +Lemma sum_rel_R_abs_exists : forall (l : list (ftype t)) (fs : ftype t) - (Hfs : sum_rel_Ft l fs), - exists rs, sum_rel_R (map Rabs (map FT2R l)) rs. + (Hfs : sum_rel_Ft l fs), + exists rs, sum_rel_R (map Rabs (map FT2R l)) rs. Proof. -intros ?. induction l. -{ simpl; exists 0. apply sum_rel_nil. } -intros. inversion Hfs; subst. -fold sum_rel_Ft in H2. -destruct (IHl s H2) as (rs & Hrs); clear IHl. -exists (Rabs (FT2R a) + rs); simpl. -apply sum_rel_cons; auto. + intros l. + induction l. + - simpl; intros fs Hfs. + exists 0. apply sum_rel_nil. + - intros fs Hfs. + inversion Hfs; subst. + fold sum_rel_Ft in H2. + destruct (IHl s H2) as (rs & Hrs); clear IHl. + exists (Rabs (FT2R a) + rs); simpl. + apply sum_rel_cons; auto. Qed. - + +(** If the result << fs >> of a [sum_rel_Ft] computation is finite, then every + element of the input list is also finite. *) + Lemma is_finite_in : - forall (l : list (ftype t)) fs, - sum_rel_Ft l fs -> - let e := @default_abs t in - let d := @default_rel t in - let ov := powerRZ 2 (femax t) in - Binary.is_finite fs = true -> - forall a, In a l -> Binary.is_finite a = true. + forall (l : list (ftype t)) (fs : ftype t), + sum_rel_Ft l fs -> + Binary.is_finite fs = true -> + forall a, In a l -> Binary.is_finite a = true. Proof. -induction l => //=. -move => fs H0 H1 s [Hs|Hs]; subst. -inversion H0; subst. -move : H1; rewrite /sum => H1. 
-clear - H1; destruct s,s0; try destruct s; try destruct s0; try discriminate H1; reflexivity. -inversion H0; clear H0; subst. -fold sum_rel_Ft in H4. -eapply IHl; try eassumption. -clear - H1; destruct a,s0; try destruct s; try destruct s0; try discriminate H1; reflexivity. + induction l => //=. + move => fs Hsum Hfin a1 [Heq | Hin]; subst. + - inversion Hsum; subst. + move: Hfin; rewrite /sum => Hfin. + destruct (BPLUS_finite_e a1 s); auto. + - inversion Hsum; clear Hsum; subst. + fold sum_rel_Ft in H2. + eapply IHl; try eassumption. + destruct (BPLUS_finite_e a s); auto. Qed. -Definition sumF := foldl(Basics.flip (@BPLUS _ t)) neg_zero. +(** [sumF] is a computable fold-left definition of floating-point summation, + accumulating with [BPLUS] from << neg_zero >>. *) + +Definition sumF := foldl (Basics.flip (@BPLUS _ t)) neg_zero. -Lemma sum_rel_Ft_fold : forall l fs, - sum_rel_Ft (rev l) fs -> fs = sumF l. +(** [sum_rel_Ft] over a reversed list agrees with [sumF] on the original list. *) + +Lemma sum_rel_Ft_fold : + forall l fs, + sum_rel_Ft (rev l) fs -> + fs = sumF l. Proof. -intros. -rewrite /sumF -(revK l) foldl_rev. -move :fs H. -induction (rev l). -intros; inversion H; simpl; auto. -intros; inversion H. -fold sum_rel_Ft in H3. -specialize (IHl0 s H3). -subst; simpl. -auto. + intros l fs H. + rewrite /sumF -(revK l) foldl_rev. + revert fs H. + induction (rev l). + - intros fs H; inversion H; simpl; auto. + - intros fs H. + inversion H; subst. + fold sum_rel_Ft in H3. + specialize (IHl0 s H3). + subst; simpl; auto. Qed. -(** subtract_loop is a variant on summation used in some implementations of Cholesky decomposition, - among other things. We should be able to prove an equivalence, of sorts, with sum_rel, - so that the accuracy theorem for sum_rel can apply here as well. *) -(* Definition subtract_loop: forall (c: ftype t) (al: list (ftype t)), ftype t := foldl BMINUS. 
*) - -Lemma subtract_loop_sumR: forall (c: ftype t) (al: list (ftype t)), - feq (foldl BMINUS c al) (sumF (c :: map BOPP al)). +(** ** Subtraction Loop + + The subtraction loop << foldl BMINUS c al >> (used in, e.g., Cholesky + decomposition) is floating-point equal to [sumF] << (c :: map BOPP al) >>, + i.e., summing << c >> followed by the negations of << al >>. This enables + accuracy theorems for [sum_rel] to transfer to subtraction-loop + implementations. *) + +Lemma subtract_loop_sumR : + forall (c : ftype t) (al : list (ftype t)), + feq (foldl BMINUS c al) (sumF (c :: map BOPP al)). Proof. -intros. -revert c; induction al; simpl; intros. -destruct c; try destruct s; reflexivity. -rewrite {}IHal /sumF -/(ftype t). -simpl. -set x := Basics.flip BPLUS neg_zero (BMINUS c a). -set y := Basics.flip BPLUS (Basics.flip BPLUS neg_zero c) (BOPP a). -assert (feq x y) by rewrite /x /y /Basics.flip !BPLUS_neg_zero BPLUS_comm MINUS_PLUS_BOPP //. -clearbody x; clearbody y. -revert x y H; induction al; simpl; intros; auto. -apply IHal; auto. -apply BPLUS_mor; auto. + intros. + revert c; induction al; simpl; intros. + - destruct c; try destruct s; reflexivity. + - rewrite {}IHal /sumF -/(ftype t). + simpl. + set x := Basics.flip BPLUS neg_zero (BMINUS c a). + set y := Basics.flip BPLUS (Basics.flip BPLUS neg_zero c) (BOPP a). + assert (Hfeq : feq x y) by + (rewrite /x /y /Basics.flip !BPLUS_neg_zero BPLUS_comm MINUS_PLUS_BOPP //; auto). + clearbody x; clearbody y. + revert x y Hfeq. + induction al; simpl; intros x y Hfeq; auto. + apply IHal; auto. + apply BPLUS_mor; auto. Qed. -Lemma sum_rel_Ft_exists: forall (l: list (ftype t)), exists s, sum_rel_Ft l s. +(** Every floating-point list has at least one [sum_rel_Ft] derivation. *) + +Lemma sum_rel_Ft_exists : + forall (l : list (ftype t)), + exists s, sum_rel_Ft l s. Proof. -unfold sum_rel_Ft. -induction l; simpl. -eexists; constructor. -destruct IHl as [s ?]. -eexists; constructor; eauto. + unfold sum_rel_Ft. 
+ induction l; simpl. + - eexists; constructor. + - destruct IHl as [s Hs]. + eexists; constructor; eauto. Qed. -Lemma subtract_loop_sum_any: forall (c: ftype t) (al: list (ftype t)), - exists s, feq (foldl BMINUS c al) s /\ sum_any (size al) (rev (c::map BOPP al)) s. +(** The subtraction loop << foldl BMINUS c al >> can be related to [sum_any], + establishing that it falls within the framework of any-order summation. + The resulting [sum_any] tree has height << size al >> and input list + << rev (c :: map BOPP al) >>. *) + +Lemma subtract_loop_sum_any : + forall (c : ftype t) (al : list (ftype t)), + exists s, + feq (foldl BMINUS c al) s /\ + sum_any (size al) (rev (c :: map BOPP al)) s. Proof. -intros. -assert (exists s: ftype t, sum_rel neg_zero BPLUS (rev (c::map BOPP al)) s /\ feq (foldl BMINUS c al) s). -- -destruct (sum_rel_Ft_exists (rev (cons c (map BOPP al)))) as [s ?]. -exists s; split; auto. -apply sum_rel_Ft_fold in H. -subst s. -apply subtract_loop_sumR. -- -destruct H as [s [? ?]]. -apply sum_rel_sum_any in H; [ | reflexivity]. -destruct H as [s' [? ?]]. -exists s'. -split. rewrite <- H; auto. -simpl in H1. -rewrite size_rev /= size_map in H1. -auto. + intros c al. + assert (Hexists : exists s : ftype t, + sum_rel neg_zero BPLUS (rev (c :: map BOPP al)) s /\ + feq (foldl BMINUS c al) s). + - destruct (sum_rel_Ft_exists (rev (c :: map BOPP al))) as [s Hs]. + exists s; split; auto. + apply sum_rel_Ft_fold in Hs. + subst s. + apply subtract_loop_sumR. + - destruct Hexists as [s [Hrel Hfeq]]. + apply sum_rel_sum_any in Hrel; [ | reflexivity]. + destruct Hrel as [s' [Hfeq' Hany]]. + exists s'. + split. + + rewrite <- Hfeq'; auto. + + simpl in Hany. + rewrite size_rev /= size_map in Hany. + auto. Qed. End WithSTD. 
\ No newline at end of file diff --git a/accuracy_proofs/vec_op_acc.v b/accuracy_proofs/vec_op_acc.v index 3d1f75d..c0e8b4b 100644 --- a/accuracy_proofs/vec_op_acc.v +++ b/accuracy_proofs/vec_op_acc.v @@ -1,225 +1,350 @@ +(** * Mixed Error Bounds for Floating-Point Matrix-Vector Operations + + This file establishes mixed error bounds for scalar-matrix multiplication, + entry-wise matrix addition, and general matrix-vector multiply-accumulate + (GEMV), computed in floating-point arithmetic. Each result decomposes + the floating-point output into its exact real counterpart plus structured + entry-wise error terms. + + ** Main Results + + - [Fscalemx_mixed_error]: Shows that the floating-point scalar-matrix + product can be expressed as an exact scaled sum where each matrix entry + carries a relative error bounded by the unit roundoff and an absolute + error bounded by the underflow threshold. + + - [Faddmx_mixed_error]: Shows that the floating-point entry-wise matrix + sum can be expressed as an exact sum of two componentwise-perturbed + matrices, where each perturbation is relative and bounded by the unit + roundoff. + + - [Smat_sumF_mixed_error]: Shows that the floating-point scaled matrix + sum can be expressed as a sum of two perturbed scaled matrices, by + composing [Fscalemx_mixed_error] and [Faddmx_mixed_error]. + + - [Smat_vec_mul_mixed_error]: Shows that the floating-point scaled + matrix-vector product can be expressed as an exact scaled result with + a forward error on the inner matrix-vector multiply and a mixed error + from the outer scalar multiplication. + + - [gemv_error]: Shows that the floating-point GEMV operation + %$s_1 A x + s_2 y$%#\(s_1 A x + s_2 y\)# can be expressed as an exact + result plus structured error matrices, combining the bounds from all + preceding lemmas. 
+ + ** Dependencies + + This file relies on: + - [preamble], [common]: basic setup and shared definitions + - [dotprod_model], [sum_model]: relational models of dot product and summation + - [dot_acc], [float_acc_lems]: accuracy lemmas + - [mv_mathcomp]: floating-point matrix/vector operations + - [gemv_acc]: forward error bound for floating-point matrix-vector multiplication +*) + From LAProof.accuracy_proofs Require Import preamble common - dotprod_model sum_model dot_acc float_acc_lems mv_mathcomp gemv_acc. + dotprod_model sum_model dot_acc float_acc_lems mv_mathcomp gemv_acc. Section WithNans. -Context {NAN: FPCore.Nans} {t : FPStdLib.type}. -Notation g := (@common.g t). +Context {NAN : FPCore.Nans} {t : FPStdLib.type}. + +Notation g := (@common.g t). Notation g1 := (@common.g1 t). -Lemma Fscalemx_mixed_error: - forall [m n] (a: ftype t) (v: 'M[ftype t]_(m,n)) - (Hfin: F.finitemx (F.scalemx a v)), - let vr:= map_mx FT2R v in - exists (e eta: 'M[R]_(m,n)), - map_mx FT2R (F.scalemx a v) = (scalemx (FT2R a) (vr + e) + eta)%Ri - /\ (forall i j, exists d, e i j = vr i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, Rabs (eta i j) <= @default_abs t). +(** ** Scalar-Matrix Multiplication: Mixed Error Bound *) + +(** [Fscalemx_mixed_error] shows that the floating-point scalar-matrix product + equals an exact scaled sum of perturbed matrix entries plus an absolute + residual, with each entry carrying a relative perturbation bounded by + the unit roundoff and an absolute error bounded by the underflow threshold. *) + +Lemma Fscalemx_mixed_error : + forall [m n] (a : ftype t) (v : 'M[ftype t]_(m, n)) + (Hfin : F.finitemx (F.scalemx a v)), + let vr := map_mx FT2R v in + exists (e eta : 'M[R]_(m, n)), + map_mx FT2R (F.scalemx a v) = (scalemx (FT2R a) (vr + e) + eta)%Ri + /\ (forall i j, exists d, + e i j = vr i j * d /\ Rabs d <= @default_rel t) + /\ (forall i j, Rabs (eta i j) <= @default_abs t). Proof. -intros. -unfold F.scalemx. 
-pose F (i: 'I_m) (j: 'I_n) (x: R*R) := - let '(e,eta) := x in - FT2R (@BMULT NAN _ a (v i j)) = FT2R a * (FT2R (v i j) + e) + eta /\ - (exists d:R, e = vr i j * d /\ Rabs d <= @default_rel t) /\ - Rabs eta <= @default_abs t. -assert (forall i j, exists e eta, F i j (e,eta)). { - intros i j. - subst F. simpl. - subst vr. - rewrite !mxE. - specialize (Hfin i j). rewrite mxE in Hfin. - set (x := fun_of_matrix v i j) in Hfin|-*. simpl in x. clearbody x. - destruct (BMULT_accurate a x) as (del & eps & HD & HE & HF & Heq). - by apply is_finite_BMULT_no_overflow. -rewrite {}Heq. -remember ((FT2R x) * del)%Re as d. - exists d, eps. - repeat split. - change (FT2R a * FT2R x * (1 + del) + eps = FT2R a * (FT2R x + d) + eps)%Re. - nra. - exists del; split; auto. - apply /RleP; auto. - apply /RleP; auto. -} -destruct (exists_mx F). -intros; destruct (H i j) as [e [eta H']]. exists (e,eta); auto. -exists (map_mx fst x), (map_mx snd x). -subst F. subst vr. -repeat split. -- -apply matrixP; intros i j; specialize (H0 i j); simpl in H0. -rewrite !mxE in H0|-*. -destruct (fun_of_matrix x i j); simpl. -destruct H0 as [? [? ?]]. -apply H0. -- -intros i j; specialize (H0 i j); simpl in H0. -rewrite !mxE in H0|-*. -destruct (fun_of_matrix x i j); simpl. -destruct H0 as [? [? ?]]. -auto. -- -intros i j; specialize (H0 i j); simpl in H0. -rewrite !mxE in H0|-*. -destruct (fun_of_matrix x i j); simpl. -destruct H0 as [? [? ?]]. -auto. + intros m n a v Hfin vr. + unfold F.scalemx. + (* Define a pointwise property F capturing the per-entry error decomposition. *) + pose F (i : 'I_m) (j : 'I_n) (x : R * R) := + let '(e, eta) := x in + FT2R (@BMULT NAN _ a (v i j)) = FT2R a * (FT2R (v i j) + e) + eta + /\ (exists d : R, e = vr i j * d /\ Rabs d <= @default_rel t) + /\ Rabs eta <= @default_abs t. + (* Establish the per-entry error decomposition using BMULT_accurate. *) + assert (Hentry : forall i j, exists e eta, F i j (e, eta)). { + intros i j. + subst F vr; simpl. + rewrite !mxE. 
+ specialize (Hfin i j); rewrite mxE in Hfin. + set (x := fun_of_matrix v i j) in Hfin |- *; clearbody x. + destruct (BMULT_accurate a x) as (del & eps & HD & HE & HF & Heq). + { by apply is_finite_BMULT_no_overflow. } + rewrite {}Heq. + (* Fold the relative error as d := FT2R x * del. *) + remember (FT2R x * del)%Re as d. + exists d, eps. + repeat split. + - (* Algebraic rearrangement: a*(x*(1+del)) + eps = a*(x+d) + eps *) + change (FT2R a * FT2R x * (1 + del) + eps + = FT2R a * (FT2R x + d) + eps)%Re. + nra. + - exists del; split; [auto | apply /RleP; auto]. + - apply /RleP; auto. + } + (* Lift the pointwise existentials to matrix existentials. *) + destruct (exists_mx F) as [x H0]. + { intros i j. + destruct (Hentry i j) as [e [eta H']]. + exists (e, eta); auto. } + exists (map_mx fst x), (map_mx snd x). + subst F vr. + repeat split. + - (* Prove the matrix equality entry-wise. *) + apply matrixP; intros i j. + specialize (H0 i j); simpl in H0. + rewrite !mxE in H0 |- *. + destruct (fun_of_matrix x i j) as [e eta]; simpl. + exact (proj1 H0). + - (* Prove the relative error bound on e. *) + intros i j. + specialize (H0 i j); simpl in H0. + rewrite !mxE in H0 |- *. + destruct (fun_of_matrix x i j); simpl. + exact (proj1 (proj2 H0)). + - (* Prove the absolute error bound on eta. *) + intros i j. + specialize (H0 i j); simpl in H0. + rewrite !mxE in H0 |- *. + destruct (fun_of_matrix x i j); simpl. + exact (proj2 (proj2 H0)). Qed. +(** ** Entry-wise Matrix Addition: Mixed Error Bound *) + +(** [Faddmx_mixed_error] shows that the floating-point entry-wise matrix sum + can be expressed as an exact sum of two componentwise-perturbed matrices, + where each perturbation is a relative error bounded by the unit roundoff. 
*) + Lemma Faddmx_mixed_error : - forall [m n] (A B: 'M[ftype t]_(m,n)) - (Hfin: F.finitemx (F.addmx A B)), - let Ar:= map_mx FT2R A in - let Br:= map_mx FT2R B in - exists (e1 e2 : 'M[R]_(m,n)), - map_mx FT2R (F.addmx A B) = ((Ar + e1) + (Br + e2))%Ri - /\ (forall i j, exists d, e1 i j = Ar i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, e2 i j = Br i j * d /\ Rabs d <= @default_rel t). + forall [m n] (A B : 'M[ftype t]_(m, n)) + (Hfin : F.finitemx (F.addmx A B)), + let Ar := map_mx FT2R A in + let Br := map_mx FT2R B in + exists (e1 e2 : 'M[R]_(m, n)), + map_mx FT2R (F.addmx A B) = ((Ar + e1) + (Br + e2))%Ri + /\ (forall i j, exists d, e1 i j = Ar i j * d /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, e2 i j = Br i j * d /\ Rabs d <= @default_rel t). Proof. -intros. -pose F (i: 'I_m) (j: 'I_n) (e12: R*R) := - let '(e1,e2) := e12 in - FT2R (@BPLUS NAN t (A i j) (B i j)) = ((Ar i j + e1) + (Br i j + e2)) - /\ (exists d, e1 = Ar i j * d /\ Rabs d <= @default_rel t) - /\ (exists d, e2 = Br i j * d /\ Rabs d <= @default_rel t). - -assert (forall i j, exists e1 e2, F i j (e1,e2)). { -subst F. -intros i j. simpl. specialize (Hfin i j). subst Ar. rewrite !mxE. rewrite mxE in Hfin. -set (a := A i j) in Hfin|-*. clearbody a. -set (b := B i j) in Hfin|-*. clearbody b. -destruct (BPLUS_finite_e _ _ Hfin) as [Ha Hb]. -destruct (BPLUS_accurate' _ _ Hfin) as (del & HD & Heq). -rewrite {}Heq. -exists (FT2R a * del), (FT2R b * del). -repeat split. -change ((FT2R a + FT2R b) * (1 + del) = - FT2R a + FT2R a * del + (FT2R b + FT2R b * del))%Re. -lra. -exists del; split; auto. -apply /RleP; auto. -exists del; split; auto. -apply /RleP; auto. -} -destruct (exists_mx F). -intros; destruct (H i j) as [e1 [e2 H']]. exists (e1,e2); auto. -exists (map_mx fst x), (map_mx snd x). -subst F. -repeat split. -- -apply matrixP; intros i j; specialize (H0 i j); simpl in H0. -rewrite !mxE in H0|-*. -destruct (fun_of_matrix x i j); simpl. -destruct H0 as [? [? ?]]. 
-apply H0. -- -intros i j; specialize (H0 i j); simpl in H0. -rewrite !mxE in H0|-*. -destruct (fun_of_matrix x i j); simpl. -destruct H0 as [? [? ?]]. -auto. -- -intros i j; specialize (H0 i j); simpl in H0. -rewrite !mxE in H0|-*. -destruct (fun_of_matrix x i j); simpl. -destruct H0 as [? [? ?]]. -auto. + intros m n A B Hfin Ar Br. + (* Define a pointwise property capturing the per-entry error decomposition. *) + pose F (i : 'I_m) (j : 'I_n) (e12 : R * R) := + let '(e1, e2) := e12 in + FT2R (@BPLUS NAN t (A i j) (B i j)) = ((Ar i j + e1) + (Br i j + e2)) + /\ (exists d, e1 = Ar i j * d /\ Rabs d <= @default_rel t) + /\ (exists d, e2 = Br i j * d /\ Rabs d <= @default_rel t). + (* Establish the per-entry error decomposition using BPLUS_accurate'. *) + assert (Hentry : forall i j, exists e1 e2, F i j (e1, e2)). { + intros i j. + subst F Ar Br; simpl. + specialize (Hfin i j); rewrite !mxE in Hfin |- *. + set (a := A i j) in Hfin |- *; clearbody a. + set (b := B i j) in Hfin |- *; clearbody b. + destruct (BPLUS_finite_e _ _ Hfin) as [Ha Hb]. + destruct (BPLUS_accurate' _ _ Hfin) as (del & HD & Heq). + rewrite {}Heq. + exists (FT2R a * del), (FT2R b * del). + repeat split. + - (* Both summands are perturbed by the same factor del. *) + change ((FT2R a + FT2R b) * (1 + del) + = FT2R a + FT2R a * del + (FT2R b + FT2R b * del))%Re. + lra. + - exists del; split; [auto | apply /RleP; auto]. + - exists del; split; [auto | apply /RleP; auto]. + } + (* Lift pointwise existentials to matrix existentials. *) + destruct (exists_mx F) as [x H0]. + { intros i j. + destruct (Hentry i j) as [e1 [e2 H']]. + exists (e1, e2); auto. } + exists (map_mx fst x), (map_mx snd x). + subst F. + repeat split. + - (* Matrix equality entry-wise. *) + apply matrixP; intros i j. + specialize (H0 i j); simpl in H0. + rewrite !mxE in H0 |- *. + destruct (fun_of_matrix x i j) as [e1 e2]; simpl. + exact (proj1 H0). + - (* Relative error bound on e1. *) + intros i j. + specialize (H0 i j); simpl in H0. 
+ rewrite !mxE in H0 |- *. + destruct (fun_of_matrix x i j); simpl. + exact (proj1 (proj2 H0)). + - (* Relative error bound on e2. *) + intros i j. + specialize (H0 i j); simpl in H0. + rewrite !mxE in H0 |- *. + destruct (fun_of_matrix x i j); simpl. + exact (proj2 (proj2 H0)). Qed. +(** ** Scaled Matrix Sum: Mixed Error Bound *) + +(** [Smat_sumF_mixed_error] shows that the floating-point scaled matrix sum + can be expressed as a sum of two perturbed scaled matrices, by composing + [Fscalemx_mixed_error] and [Faddmx_mixed_error]. Each scaling step + contributes a relative and an absolute error; the outer addition + contributes an additional relative error per summand. *) + Lemma Smat_sumF_mixed_error : - forall [m n] (u v: 'M[ftype t]_(m,n)) (a b : ftype t) + forall [m n] (u v : 'M[ftype t]_(m, n)) (a b : ftype t) (Hfin : F.finitemx (F.addmx (F.scalemx a u) (F.scalemx b v))), - let vr:= map_mx FT2R v in - let ur:= map_mx FT2R u in - exists (e1 e2 e3 e4 e5 e6: 'M[R]_(m,n)), - map_mx FT2R (F.addmx (F.scalemx a u) (F.scalemx b v)) = - ((scalemx (FT2R a) (ur + e1) + e2 + e3) + - (scalemx (FT2R b) (vr + e4) + e5 + e6))%Ri - /\ (forall i j, exists d, e1 i j = ur i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, e4 i j = vr i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, e3 i j = (scalemx (FT2R a) (ur + e1) + e2) i j * d /\ Rabs d <= @default_rel t)%Ri - /\ (forall i j, exists d, e6 i j = (scalemx (FT2R b) (vr + e4) + e5) i j * d /\ Rabs d <= @default_rel t)%Ri - /\ (forall i j, Rabs (e5 i j) <= @default_abs t) - /\ (forall i j, Rabs (e2 i j) <= @default_abs t). 
+ let vr := map_mx FT2R v in + let ur := map_mx FT2R u in + exists (e1 e2 e3 e4 e5 e6 : 'M[R]_(m, n)), + map_mx FT2R (F.addmx (F.scalemx a u) (F.scalemx b v)) = + ((scalemx (FT2R a) (ur + e1) + e2 + e3) + + (scalemx (FT2R b) (vr + e4) + e5 + e6))%Ri + /\ (forall i j, exists d, + e1 i j = ur i j * d /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + e4 i j = vr i j * d /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + e3 i j = (scalemx (FT2R a) (ur + e1) + e2) i j * d + /\ Rabs d <= @default_rel t)%Ri + /\ (forall i j, exists d, + e6 i j = (scalemx (FT2R b) (vr + e4) + e5) i j * d + /\ Rabs d <= @default_rel t)%Ri + /\ (forall i j, Rabs (e5 i j) <= @default_abs t) + /\ (forall i j, Rabs (e2 i j) <= @default_abs t). Proof. -intros. -simpl. -destruct (F.finitemx_addmx_e _ _ Hfin) as [HfinA HfinB]. -destruct (Faddmx_mixed_error _ _ Hfin) as (Du & Dv & Heq & HD). -rewrite {}Heq. -destruct (Fscalemx_mixed_error a u HfinA) as (ae & aeta & Heqa & Hea & Haeta). -destruct (Fscalemx_mixed_error b v HfinB) as [be [beta [Heqb [Heb Hbeta]]]]. -move :HD; rewrite {}Heqa {}Heqb => HD. -destruct HD as [HDu HDv]. -exists ae, aeta ,Du, be, beta, Dv. -repeat split => //. + intros m n u v a b Hfin vr ur. + simpl. + (* Decompose finiteness of the sum into finiteness of each scaled term. *) + destruct (F.finitemx_addmx_e _ _ Hfin) as [HfinA HfinB]. + (* Apply the addition mixed error bound to the outer sum. *) + destruct (Faddmx_mixed_error _ _ Hfin) as (Du & Dv & Heq & HDu & HDv). + rewrite {}Heq. + (* Apply the scalar multiplication mixed error bound to each term. *) + destruct (Fscalemx_mixed_error a u HfinA) as (ae & aeta & Heqa & Hea & Haeta). + destruct (Fscalemx_mixed_error b v HfinB) as (be & beta & Heqb & Heb & Hbeta). + (* Substitute the scale decompositions into the addition error terms. *) + move : HDu HDv; rewrite {}Heqa {}Heqb => HDu HDv. + exists ae, aeta, Du, be, beta, Dv. + repeat split => //. Qed. 
-Lemma Smat_vec_mul_mixed_error: - forall [m n] (b: ftype t) (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,1)) - (Hfin: F.finitemx (F.scalemx b (F.mulmx A B))), - exists (E : 'M[R]_(m,n)) (e eta1 eta2 : 'M[R]_(m,1)), +(** ** Scaled Matrix-Vector Product: Mixed Error Bound *) + +(** [Smat_vec_mul_mixed_error] shows that the floating-point scaled + matrix-vector product can be expressed as an exact scaled result with + a forward error on the inner matrix-vector multiply (bounded in terms of + %$g(n)$%#\(g(n)\)#) and a mixed error from the outer scalar multiplication. *) + +Lemma Smat_vec_mul_mixed_error : + forall [m n] (b : ftype t) (A : 'M[ftype t]_(m, n)) (B : 'M[ftype t]_(n, 1)) + (Hfin : F.finitemx (F.scalemx b (F.mulmx A B))), + exists (E : 'M[R]_(m, n)) (e eta1 eta2 : 'M[R]_(m, 1)), map_mx FT2R (F.scalemx b (F.mulmx A B)) = - (scalemx (FT2R b) ((map_mx FT2R A + E) *m (map_mx FT2R B) + eta1 + e) + eta2 )%Ri + (scalemx (FT2R b) + ((map_mx FT2R A + E) *m map_mx FT2R B + eta1 + e) + eta2)%Ri /\ (forall i j, Rabs (E i j) <= g n * Rabs (map_mx FT2R A i j)) /\ (forall i j, Rabs (eta2 i j) <= @default_abs t) - /\ (forall i j, exists d, e i j = FT2R (F.mulmx A B i j) * d /\ Rabs d <= @default_rel t) - /\ (forall i j, Rabs (eta1 i j) <= g1 n n). + /\ (forall i j, exists d, + e i j = FT2R (F.mulmx A B i j) * d /\ Rabs d <= @default_rel t) + /\ (forall i j, Rabs (eta1 i j) <= g1 n n). Proof. -intros. -destruct (Fscalemx_mixed_error _ _ Hfin) as (e & eta & Heq & Hea & Hetaa). -rewrite {}Heq in Hea|-*. -destruct (mat_vec_mul_mixed_error A B) - as (E & eta1 & Heq1 & H1). -apply (F.finitemx_scalemx_e _ _ Hfin). -rewrite {}Heq1. -destruct H1 as [H0 H1]. -exists E, e, eta1, eta; repeat split => //. -simpl. -move => i j. destruct (Hea i j) as [d H2]. -exists d. -rewrite !mxE. rewrite mxE in H2. -unfold F.mulmx in H2. -rewrite mxE in H2. -auto. + intros m n b A B Hfin. + (* Apply the scalar multiply mixed error bound to b * (A*x). 
*) + destruct (Fscalemx_mixed_error _ _ Hfin) as (e & eta & Heq & Hea & Hetaa). + rewrite {}Heq in Hea |- *. + (* Apply the forward error bound for floating-point matrix-vector multiply. *) + destruct (mat_vec_mul_mixed_error A B) as (E & eta1 & Heq1 & He1 & Heta1). + { apply (F.finitemx_scalemx_e _ _ Hfin). } + rewrite {}Heq1. + exists E, e, eta1, eta. + repeat split => //. + (* Rewrite the relative error on e in terms of the mulmx entry. *) + intros i j. + destruct (Hea i j) as [d H2]. + exists d. + rewrite !mxE; rewrite mxE in H2. + unfold F.mulmx in H2. + rewrite mxE in H2. + exact H2. Qed. -Lemma gemv_error: - forall [m n] (A: 'M[ftype t]_(m,n)) (x: 'cV[ftype t]_n) (y: 'cV[ftype t]_m) (s1 s2: ftype t) - (Hfin: F.finitemx (F.addmx (F.scalemx s1 (F.mulmx A x)) (F.scalemx s2 y))), +(** ** General Matrix-Vector Multiply-Accumulate (GEMV): Mixed Error Bound *) + +(** [gemv_error] is the central result of this file. It shows that the + floating-point GEMV operation %$s_1 A x + s_2 y$%#\(s_1 A x + s_2 y\)# + can be expressed as an exact result plus structured error matrices, + combining the bounds from [Smat_sumF_mixed_error] and + [mat_vec_mul_mixed_error]. 
*) + +Lemma gemv_error : + forall [m n] (A : 'M[ftype t]_(m, n)) (x : 'cV[ftype t]_n) + (y : 'cV[ftype t]_m) (s1 s2 : ftype t) + (Hfin : F.finitemx + (F.addmx (F.scalemx s1 (F.mulmx A x)) (F.scalemx s2 y))), exists e1 e2 e3 e4 e5 e6 e7 e8, - map_mx FT2R (F.addmx (F.scalemx s1 (F.mulmx A x)) (F.scalemx s2 y)) = - ((scalemx (FT2R s1) ((((map_mx FT2R A + e1) *m (map_mx FT2R x)) + e2) + e3) + e4) + e5) + - ((scalemx (FT2R s2) (map_mx FT2R y + e6) + e7) + e8) - /\ (forall i j, Rabs (e1 i j) <= g n * Rabs (map_mx FT2R A i j)) - /\ (forall i j, Rabs (e2 i j) <= g1 n n) - /\ (forall i j, exists d, e3 i j = (((map_mx FT2R A + e1) *m map_mx FT2R x) + e2)%Ri i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, e6 i j = map_mx FT2R y i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, e5 i j = (scalemx (FT2R s1) ((((map_mx FT2R A + e1) *m map_mx FT2R x) + e2) + e3) + e4) i j * d - /\ Rabs d <= @default_rel t) - /\ (forall i j, exists d, e8 i j = (scalemx (FT2R s2) (map_mx FT2R y + e6) + e7) i j * d /\ Rabs d <= @default_rel t) - /\ (forall i j, Rabs (e7 i j) <= @default_abs t) - /\ (forall i j, Rabs (e4 i j) <= @default_abs t). 
+ map_mx FT2R (F.addmx (F.scalemx s1 (F.mulmx A x)) (F.scalemx s2 y)) = + ((scalemx (FT2R s1) + ((((map_mx FT2R A + e1) *m map_mx FT2R x) + e2) + e3) + e4) + e5) + + ((scalemx (FT2R s2) (map_mx FT2R y + e6) + e7) + e8) + /\ (forall i j, Rabs (e1 i j) <= g n * Rabs (map_mx FT2R A i j)) + /\ (forall i j, Rabs (e2 i j) <= g1 n n) + /\ (forall i j, exists d, + e3 i j = + (((map_mx FT2R A + e1) *m map_mx FT2R x) + e2)%Ri i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + e6 i j = map_mx FT2R y i j * d /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + e5 i j = + (scalemx (FT2R s1) + ((((map_mx FT2R A + e1) *m map_mx FT2R x) + e2) + e3) + e4) i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, exists d, + e8 i j = + (scalemx (FT2R s2) (map_mx FT2R y + e6) + e7) i j * d + /\ Rabs d <= @default_rel t) + /\ (forall i j, Rabs (e7 i j) <= @default_abs t) + /\ (forall i j, Rabs (e4 i j) <= @default_abs t). Proof. -intros. -(* proof follows from previous bounds for axpby and mul *) -destruct (Smat_sumF_mixed_error (F.mulmx A x) y s1 s2) - as (e3 & e4 & e5 & e6 & e7 & e8 & Heq1 & H1) => //. -rewrite {}Heq1. -destruct (mat_vec_mul_mixed_error A x) - as (e1 & e2 & Heq2 & H2). -apply F.finitemx_addmx_e in Hfin; destruct Hfin as [Hfin _]. -apply (F.finitemx_scalemx_e _ _ Hfin). -rewrite {}Heq2 in H1|-*. -destruct H2 as (He1 & He2). -destruct H1 as (He3 & He6 & He5 & He4 & He7 & He8). -simpl in *. -exists e1, e2, e3, e4, e5, e6, e7, e8; repeat split => /= //. + intros m n A x y s1 s2 Hfin. + (* Decompose s1*(A*x) + s2*y into scaled-sum error components. + This yields error matrices e3..e8 covering the two scalings and + the outer vector addition. *) + destruct (Smat_sumF_mixed_error (F.mulmx A x) y s1 s2 Hfin) + as (e3 & e4 & e5 & e6 & e7 & e8 + & Heq1 & He3 & He6 & He5 & He8 & He7 & He4). + rewrite {}Heq1. + (* Extract finiteness of A*x from the finiteness of s1*(A*x) + s2*y, + which is needed to apply the matrix-vector multiply error bound. 
*) + have HfinAx : F.finitemx (F.mulmx A x). + { apply F.finitemx_addmx_e in Hfin. + destruct Hfin as [HfinL _]. + exact: (F.finitemx_scalemx_e _ _ HfinL). } + (* Apply the mixed error bound for floating-point matrix-vector multiply. + This yields error matrices e1 (componentwise relative on A) and + e2 (absolute, from accumulated dot product rounding). *) + destruct (mat_vec_mul_mixed_error A x HfinAx) + as (e1 & e2 & Heq2 & He1 & He2). + (* Rewrite the mat-vec result in He3, He5, and the goal simultaneously. + He5 references map_mx FT2R (F.mulmx A x) inside the scalemx term, + so it must be rewritten alongside He3 and the main goal. *) + rewrite {}Heq2 in He3 He5 |- *. + exists e1, e2, e3, e4, e5, e6, e7, e8. + repeat split => /= //. Qed. - -End WithNans. - - - - + +End WithNans. \ No newline at end of file diff --git a/accuracy_proofs/vecnorm_acc.v b/accuracy_proofs/vecnorm_acc.v index 75f6ac8..6719de9 100644 --- a/accuracy_proofs/vecnorm_acc.v +++ b/accuracy_proofs/vecnorm_acc.v @@ -1,63 +1,147 @@ -From LAProof.accuracy_proofs Require Import preamble common - dotprod_model sum_model - float_acc_lems dot_acc sum_acc. +(** * Floating-Point Vector Norm Accuracy Bounds -Section TwoNorm. -Context {NAN: FPCore.Nans} {t : type}. + This file establishes error bounds for floating-point computations of + the two-norm (Euclidean) and one-norm of a vector. -Definition two_normF (x: list (ftype t)) : R := sqrt (FT2R (dotprodF x x)). -Definition two_normR (x: list R) : R := sqrt (dotprodR x x). + ** Error Factors + + Throughout, the accumulated relative error factor is + %$g(n) = (1 + \mathbf{u})^n - 1$%#\(g(n) = (1 + \mathbf{u})^n - 1\)# and + the mixed absolute error factor is + %$g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))$%#\(g_1(n_1, n_2) = n_1 \cdot \eta \cdot (1 + g(n_2))\)#, + where %$\mathbf{u}$%#\(\mathbf{u}\)# is the unit roundoff and + %$\eta$%#\(\eta\)# is the underflow threshold for the given type. + Both are defined in [common]. 
+ + ** Main Results + + - [bfVNRM2]: Shows that the floating-point two-norm can be expressed as + the square root of a mixed-error dot product: one copy of the input + appears componentwise perturbed, and a small absolute residual accounts + for underflow. + + - [bfVNRM1]: Shows that the floating-point one-norm equals the exact + one-norm of a slightly perturbed input vector. + + ** Dependencies + + This file relies on: + - [preamble], [common]: basic setup and shared definitions + - [dotprod_model], [sum_model]: relational models of dot product and summation + - [float_acc_lems]: elementary floating-point accuracy lemmas + - [dot_acc], [sum_acc]: dot product and summation accuracy theorems +*) + +From LAProof.accuracy_proofs Require Import + preamble + common + dotprod_model + sum_model + float_acc_lems + dot_acc + sum_acc. + +(** * Two-Norm *) + +Section TwoNorm. + +Context {NAN : FPCore.Nans} {t : type}. + +(** The floating-point two-norm: the square root of the floating-point dot + product of a vector with itself, coerced to a real number. *) + +Definition two_normF (x : list (ftype t)) : R := + sqrt (FT2R (dotprodF x x)). + +(** The exact real-valued two-norm: the square root of the real dot product + of a vector with itself. *) + +Definition two_normR (x : list R) : R := + sqrt (dotprodR x x). Variable (x : list (ftype t)). -Notation xR := (map FT2R x). -Notation n:= (size x). -Hypothesis Hfin: Binary.is_finite (dotprodF x x) = true. -Notation g := (@g t). -Notation g1 := (@g1 t). +Notation xR := (map FT2R x). +Notation n := (size x). +Notation g := (@g t). +Notation g1 := (@g1 t). Notation neg_zero := (@common.neg_zero t). -(* two norm mixed error bound *) -Lemma bfVNRM2: +(** We assume the floating-point dot product [dotprodF x x] is finite, + which is necessary for [two_normF x] to be well-defined. *) +Hypothesis Hfin : Binary.is_finite (dotprodF x x) = true. 
+ +(** [bfVNRM2] expresses the floating-point two-norm as the square root of + a mixed-error dot product. One copy of the input appears componentwise + perturbed by a relative factor bounded by %$g(n)$%#\(g(n)\)#, and an + absolute residual bounded by %$g_1(n, n)$%#\(g_1(n,n)\)# accounts for + underflow. *) + +Lemma bfVNRM2 : exists (x' : list R) (eta : R), - two_normF x = sqrt (dotprodR x' xR + eta) /\ - (forall m, (m < n)%nat -> exists delta, - nth 0 x' m = FT2R (nth neg_zero x m) * (1 + delta) /\ Rabs delta <= g n) /\ + two_normF x = sqrt (dotprodR x' xR + eta) /\ + (forall m, (m < n)%nat -> + exists delta, + nth 0 x' m = FT2R (nth neg_zero x m) * (1 + delta) /\ + Rabs delta <= g n) /\ Rabs eta <= g1 n n. Proof. -destruct (dotprod_mixed_error x x Logic.eq_refl Hfin) - as (x' & eta & Hlen & Hrel & H1 & H2). -exists x', eta; repeat split; auto. -unfold two_normF, two_normR. -rewrite Hrel. f_equal; nra. + destruct (dotprod_mixed_error x x Logic.eq_refl Hfin) + as (x' & eta & Hlen & Hrel & H1 & H2). + exists x', eta. + repeat split; auto. + unfold two_normF, two_normR. + rewrite Hrel. + f_equal; nra. Qed. -End TwoNorm. +End TwoNorm. + +(** * One-Norm *) Section OneNorm. -Context {NAN: FPCore.Nans} {t : type}. -Definition one_normF (x: list (ftype t)) : R := FT2R (sumF x). -Definition one_normR (x: list R) : R := fold_right Rplus 0 x. +Context {NAN : FPCore.Nans} {t : type}. + +(** The floating-point one-norm: the real value of the floating-point + left-to-right accumulation of the vector entries. *) + +Definition one_normF (x : list (ftype t)) : R := + FT2R (sumF x). + +(** The exact real-valued one-norm: the sum of a list of real numbers. *) + +Definition one_normR (x : list R) : R := + fold_right Rplus 0 x. Variables (x : list (ftype t)). -Hypothesis Hfin: Binary.is_finite (sumF x) = true. -Notation xR := (map FT2R x). -Notation n:= (size x). -Notation g := (@g t). 
+(** We assume the floating-point sum [sumF x] is finite, which is necessary + for [one_normF x] to be well-defined. *) + +Hypothesis Hfin : Binary.is_finite (sumF x) = true. + +Notation xR := (map FT2R x). +Notation n := (size x). +Notation g := (@g t). Notation neg_zero := (@common.neg_zero t). -(* one norm backward error bound *) -Lemma bfVNRM1: - exists (x': list R), +(** [bfVNRM1] shows that the floating-point one-norm equals the exact + one-norm of a perturbed input vector, where each component is perturbed + by a relative factor bounded by %$g(n-1)$%#\(g(n-1)\)#. *) + +Lemma bfVNRM1 : + exists (x' : list R), one_normF x = one_normR x' /\ - (forall m, (m < n)%nat -> exists delta, - nth 0 x' m = FT2R (nth neg_zero x m) * (1 + delta) /\ Rabs delta <= g (n - 1)). + (forall m, (m < n)%nat -> + exists delta, + nth 0 x' m = FT2R (nth neg_zero x m) * (1 + delta) /\ + Rabs delta <= g (n - 1)). Proof. -destruct (bSUM x Hfin) as (x' & Hlen & Hrel & Hn). + destruct (bSUM x Hfin) as (x' & Hlen & Hrel & Hn). rewrite Hlen in Hn. -exists x'; repeat split; auto. + exists x'. + repeat split; auto. Qed. End OneNorm. \ No newline at end of file diff --git a/header.html b/header.html new file mode 100644 index 0000000..49cebaa --- /dev/null +++ b/header.html @@ -0,0 +1,23 @@ + + + + + +LAProof + + + + + +
+ + + +
diff --git a/html/index.html b/html/index.html index 455f0ed..def032f 100644 --- a/html/index.html +++ b/html/index.html @@ -109,4 +109,4 @@

Index of specifications and proofs

- + \ No newline at end of file From 08a8374e45eebe0d8d97d073682d088dd34e512f Mon Sep 17 00:00:00 2001 From: Andrew Appel Date: Tue, 10 Mar 2026 10:06:52 -0400 Subject: [PATCH 2/2] Makefile updates; doc edits; Rocq 9.1 / MathComp 2.5 compat 1. Move the reference copy of index.html out of html/ directory, because Makefile.coq's "make clean" deletes entire html directory. 2. A few minor fixes for compatibility with MathComp 2.5 and/or Rocq 9.1. All but one of these should be backward compatible. The change at line 326 of C/verif_densemat_cholesky.v is probably not backward compatible; this may need some work. 3. Minor edits to the comments in one or two files. --- C/verif_densemat_cholesky.v | 2 +- C/verif_sparse_byrows.v | 2 +- How-to-document.md | 6 ++++-- Makefile.coq.local | 7 ++++++- accuracy_proofs/common.v | 20 +++++++++++++++++++- accuracy_proofs/gemv_acc.v | 2 +- accuracy_proofs/libvalidsdp.v | 6 +++--- html/index.html => index.html | 0 8 files changed, 35 insertions(+), 10 deletions(-) rename html/index.html => index.html (100%) diff --git a/C/verif_densemat_cholesky.v b/C/verif_densemat_cholesky.v index 764e1f1..d23cfd6 100644 --- a/C/verif_densemat_cholesky.v +++ b/C/verif_densemat_cholesky.v @@ -323,7 +323,7 @@ forward_for_simple_bound (Z.of_nat n) (EX i:Z, set (al := seq.foldl _ _ _). subst bi. change (fstep i) with al. unfold forward_subst_step. - change ssralg.GRing.zero with (@ord0 O). + (* change ssralg.GRing.zero with (@ord0 O). *) (* MathComp 2.4 *) change @seq.map with @map. rewrite take_sublist. set (uu := BDIV _ _). diff --git a/C/verif_sparse_byrows.v b/C/verif_sparse_byrows.v index ccb76da..bcc165a 100644 --- a/C/verif_sparse_byrows.v +++ b/C/verif_sparse_byrows.v @@ -175,7 +175,7 @@ forward_for_simple_bound (Zlength mval) unfold matrix_rows; subst i. clear H10 H11 H12 H13 H9 H8 PNp PNv PNm H6. list_solve. - - Intro result. Exists result. + Intro r; Exists r. unfold matrix_rows in *. list_simplify. entailer!. 
unfold matrix_vector_mult in H9 |- *. diff --git a/How-to-document.md b/How-to-document.md index 7e3ed19..dfcea16 100644 --- a/How-to-document.md +++ b/How-to-document.md @@ -2,11 +2,13 @@ ## Rocq documentation (in .v files) -1. Edit html/index.html as you see fit. Do not edit any files of the form html/LAProof.*, as these are automatically generated by coqdoc. +0. Because Makefile.coq's "make clean" removes the entire html directory, we cannot store the reference copy of index.html there. Instead, index.html is in the root directory, and copied into the html directory by "make html2". + +1. Edit index.html as you see fit. Do not edit any files of the form html/LAProof.*, as these are automatically generated by coqdoc. 2. Edit the comments in any of the .v files, using [Coqdoc markup](https://rocq-prover.org/doc/V8.20.0/refman/using/tools/coqdoc.html). -3. `make html` creates all the Coqdoc output in the html/ directory. Browse and review these local files (by doing `open file` in your browser) and edit the .v files until this looks the way you want it. +3. `make html2` creates all the Coqdoc output in the html/ directory. Browse and review these local files (by doing `open file` in your browser) and edit the .v files until this looks the way you want it. 4. `make publish` sends all those html files (including index.html if you have edited it) to the Github pages site at https://verinum.org/LAProof/. The way it does this is by committing them to the special gh-pages branch of this repo, which contains _only_ a docs directory with those HTML files. After you do `make publish`, it will take several minutes before the changes appear at verinum.org/LAProof. 
diff --git a/Makefile.coq.local b/Makefile.coq.local index 79ac1ea..e8b7717 100644 --- a/Makefile.coq.local +++ b/Makefile.coq.local @@ -4,7 +4,12 @@ COQDOCEXTRAFLAGS= -g --no-lib-name --with-header header.html --index genindex -- accuracy: accuracy_proofs/export.vo C: C/verif_alloc.vo C/verif_sparse.vo C/verif_sparse_byrows.vo C/VSU_densemat.vo C/verif_build_csr.vo -publish: html cdocs +# Ugh. We can't just store the reference copy of index.html in html/index.html, because Makefile.coq's +# standard "make clean" removes the html directory entirely. So we gotta do this hack: +html2: html index.html + cp index.html html/index.html + +publish: html2 cdocs cd gh-pages; git submodule update cd gh-pages; rm -rf docs; mkdir docs cp html/* gh-pages/docs diff --git a/accuracy_proofs/common.v b/accuracy_proofs/common.v index 45c30e7..1ff4a36 100644 --- a/accuracy_proofs/common.v +++ b/accuracy_proofs/common.v @@ -7,6 +7,12 @@ The main concepts defined here are: + - _Floating-point type_: We use VCFloat's concept of a floating-point [type], + that contains a precision (mantissa size), max-exponent, and some properties + of them; all based on Flocq's underlying constructions. Examples are + [Tdouble] (64-bit IEEE double precision) and [Tsingle] (32-bit), and any other + (legal) combination of [fprec] and [femax] that one might need. + - _Floating-point rounding_: The function [rounded] captures round-to-nearest-even (RNE) rounding in radix-2 floating point, parameterized by a floating-point type << t >> that fixes the precision @@ -49,7 +55,9 @@ floating-point operations. Key properties include [g_pos], [le_g_Sn] (monotonicity), and the recurrence [one_plus_d_mul_g], which expresses how one additional rounding - step advances the bound. + step advances the bound. See also: Kellison et al., "LAProof: A Library of + Formal Proofs of Accuracy and Correctness of Linear Algebra Programs", + 2023, equation (4), where it is called "h". 
- _Mixed absolute/relative error accumulation factor_ [g1 n1 n2]: @@ -60,6 +68,7 @@ each amplified by up to << (1 + default_rel)^n2 >> subsequent multiplications. Numerous lemmas establish how [g1] grows as its arguments increase, supporting inductive error analyses. + See also: Kellison et al., "LAProof:...", equation (5). _Hint database_: All positivity, monotonicity, and ordering lemmas for [default_rel], [default_abs], [g], and [g1] are registered in @@ -108,6 +117,15 @@ Proof. destruct l; try congruence; compute; lia. Qed. +(** ** Parameterization by Nans and type + + Any IEEE floating-point implementation must instantiate, + - precision (mantissa size) and exponent size + - propagation rules for Not-a-Number + The Rocq types for these are, respectively, [type] and [FPCore.Nans]. LAProof's + accuracy proofs are, in general, parameterized to work with any instantiation of these. + We express that by Rocq's [Section] and [Context] commands. +*) Section WithType. Context {NAN : FPCore.Nans} {t : type}. diff --git a/accuracy_proofs/gemv_acc.v b/accuracy_proofs/gemv_acc.v index b55a64d..caa3bb5 100644 --- a/accuracy_proofs/gemv_acc.v +++ b/accuracy_proofs/gemv_acc.v @@ -237,7 +237,7 @@ Proof. - apply /RleP; auto with commonDB. - apply /RleP; auto with commonDB. - rewrite /normv. - apply bigmax_le => [| i _]. + apply @bigmax_le => [ | i _]. apply /RleP; auto with commonDB. auto. Qed. diff --git a/accuracy_proofs/libvalidsdp.v b/accuracy_proofs/libvalidsdp.v index 85c2535..7e1caf6 100644 --- a/accuracy_proofs/libvalidsdp.v +++ b/accuracy_proofs/libvalidsdp.v @@ -614,9 +614,9 @@ pose proof @cholesky.lemma_2_1 fspec fspec_eta_nonzero k a' b' (mkFS c) (mkFS bk repeat change (float_spec.FS_val (mkFS ?x)) with (FT2R x) in H|-*. rewrite LVSDP_ytilded_eq in H; auto. replace (\sum_i (float_spec.FS_val _ * _)) with (\sum_i (FT2R (fun_of_fin a i) * (FT2R (b i)))) in H. -2: apply eq_big; auto; [ move => x // | move => i _; rewrite /a' /b' !ffunE //]. 
+2: apply eq_big; auto; try solve [move => x //]; move => i _; rewrite /a' /b' !ffunE //. replace (\sum_i Rabs (float_spec.FS_val _ * _)) with (\sum_i Rabs (FT2R (fun_of_fin a i) * (FT2R (b i)))) in H. -2: apply eq_big; auto; [ move => x // | move => i _; rewrite /a' /b' !ffunE //]. +2: apply eq_big; auto; try solve [move => x //]; move => i _; rewrite /a' /b' !ffunE //. rewrite default_abs_eq default_rel_eq. apply H. Qed. @@ -701,7 +701,7 @@ clear H. destruct H0. change i with (nat_of_ord (Ordinal H)). rewrite nth_take. rewrite nth_ord_enum'. -simpl. lia. +first [simpl; lia | rewrite ltEord; simpl; lia]. (* compatibility Rocq 9.0, mathcomp 2.4 | 9.1,2.5*) simpl; lia. Qed. diff --git a/html/index.html b/index.html similarity index 100% rename from html/index.html rename to index.html