diff --git a/banquet.cpp b/banquet.cpp index bfa7821..141cbb3 100644 --- a/banquet.cpp +++ b/banquet.cpp @@ -685,17 +685,28 @@ banquet_signature_t banquet_sign(const banquet_instance_t &instance, std::vector lagrange_polys_evaluated_at_Re_2m2(2 * instance.m2 + 1); + size_t size_m2 = precomputation_for_zero_to_m2[0].size(); + size_t size_2m2 = precomputation_for_zero_to_2m2[0].size(); + for (size_t repetition = 0; repetition < instance.num_rounds; repetition++) { c_shares[repetition].resize(instance.num_MPC_parties); a[repetition].resize(instance.m1); b[repetition].resize(instance.m1); + + std::vector precomputed_m2_eval_x_pow_n = + field::eval_precompute(R_es[repetition], size_m2); + std::vector precomputed_2m2_eval_x_pow_n = + field::eval_precompute(R_es[repetition], size_2m2); + for (size_t k = 0; k < instance.m2 + 1; k++) { lagrange_polys_evaluated_at_Re_m2[k] = - field::eval(precomputation_for_zero_to_m2[k], R_es[repetition]); + field::eval_fast(precomputation_for_zero_to_m2[k], + precomputed_m2_eval_x_pow_n, instance.lambda); } for (size_t k = 0; k < 2 * instance.m2 + 1; k++) { lagrange_polys_evaluated_at_Re_2m2[k] = - field::eval(precomputation_for_zero_to_2m2[k], R_es[repetition]); + field::eval_fast(precomputation_for_zero_to_2m2[k], + precomputed_2m2_eval_x_pow_n, instance.lambda); } for (size_t party = 0; party < instance.num_MPC_parties; party++) { @@ -1104,13 +1115,22 @@ bool banquet_verify(const banquet_instance_t &instance, const banquet_repetition_proof_t &proof = signature.proofs[repetition]; size_t missing_party = missing_parties[repetition]; + std::vector precomputed_m2_eval_x_pow_n = + field::eval_precompute(R_es[repetition], + precomputation_for_zero_to_m2[0].size()); + std::vector precomputed_2m2_eval_x_pow_n = + field::eval_precompute(R_es[repetition], + precomputation_for_zero_to_2m2[0].size()); + for (size_t k = 0; k < instance.m2 + 1; k++) { lagrange_polys_evaluated_at_Re_m2[k] = - field::eval(precomputation_for_zero_to_m2[k], 
R_es[repetition]); + field::eval_fast(precomputation_for_zero_to_m2[k], + precomputed_m2_eval_x_pow_n, instance.lambda); } for (size_t k = 0; k < 2 * instance.m2 + 1; k++) { lagrange_polys_evaluated_at_Re_2m2[k] = - field::eval(precomputation_for_zero_to_2m2[k], R_es[repetition]); + field::eval_fast(precomputation_for_zero_to_2m2[k], + precomputed_2m2_eval_x_pow_n, instance.lambda); } c_shares[repetition].resize(instance.num_MPC_parties); diff --git a/field.cpp b/field.cpp index ed860d9..6a2438f 100644 --- a/field.cpp +++ b/field.cpp @@ -32,6 +32,29 @@ inline __m128i clmul(uint64_t a, uint64_t b) { return _mm_clmulepi64_si128(_mm_set_epi64x(0, a), _mm_set_epi64x(0, b), 0); } +// actually this is slower than the above two :( +uint64_t reduce_GF2_16_clmul(const __m128i in) { + // modulus = x^16 + x^5 + x^3 + x + 1 + __m128i p = _mm_set_epi64x(0x0, 0x2B); + __m128i mask = _mm_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFF0000); + __m128i t; + + __m128i hi = _mm_srli_si128(in, 2); // extracting the in_hi + __m128i low = + _mm_xor_si128(_mm_or_si128(in, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // in_hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 4 + 16 -> Length after xor + + hi = _mm_srli_si128(t, 2); // extracting the t_hi + low = _mm_xor_si128(_mm_or_si128(t, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // t_hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 16 -> Length after xor + + return _mm_extract_epi64(t, 0); +} + uint64_t reduce_GF2_16(__m128i in) { // modulus = x^16 + x^5 + x^3 + x + 1 constexpr uint64_t lower_mask = 0xFFFFULL; @@ -45,6 +68,28 @@ uint64_t reduce_GF2_16(__m128i in) { return lower_mask & R_lower; } +// actually this is slower than the above two :( +uint64_t reduce_GF2_32_clmul(const __m128i in) { + // modulus = x^32 + x^7 + x^3 + x^2 + 1 + __m128i p = _mm_set_epi64x(0x0, 0x8d); + __m128i mask = _mm_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF00000000); + __m128i t; + + __m128i 
hi = _mm_srli_si128(in, 4); // extracting the in_hi + __m128i low = + _mm_xor_si128(_mm_or_si128(in, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // in_hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 4 + 32 -> Length after xor + + hi = _mm_srli_si128(t, 4); // extracting the t_hi + low = _mm_xor_si128(_mm_or_si128(t, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // t_hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 32 -> Length after xor + + return _mm_extract_epi64(t, 0); +} // actually a bit slower than naive version below __attribute__((unused)) uint64_t reduce_GF2_32_barret(__m128i in) { // modulus = x^32 + x^7 + x^3 + x^2 + 1 @@ -69,6 +114,28 @@ uint64_t reduce_GF2_32(__m128i in) { return lower_mask & R_lower; } +uint64_t reduce_GF2_40_clmul(const __m128i in) { + // modulus = x^40 + x^5 + x^4 + x^3 + 1 + __m128i p = _mm_set_epi64x(0x0, 0x39); + __m128i mask = _mm_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFF0000000000); + __m128i t; + + __m128i hi = _mm_srli_si128(in, 5); // extracting the hi + __m128i low = + _mm_xor_si128(_mm_or_si128(in, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 4 + 40 -> Length after xor + + hi = _mm_srli_si128(t, 5); // extracting the hi + low = _mm_xor_si128(_mm_or_si128(t, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 40 -> Length after xor + + return _mm_extract_epi64(t, 0); +} + uint64_t reduce_GF2_40(__m128i in) { // modulus = x^40 + x^5 + x^4 + x^3 + 1 constexpr uint64_t upper_mask = 0xFFFFULL; @@ -84,6 +151,29 @@ uint64_t reduce_GF2_40(__m128i in) { return lower_mask & R_lower; } +// actually this is slower than the above two :( +uint64_t reduce_GF2_48_clmul(const __m128i in) { + // modulus = x^48 + x^5 + x^3 + x^2 + 1 + __m128i p = _mm_set_epi64x(0x0, 0x2d); + __m128i mask = 
_mm_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFF000000000000); + __m128i t; + + __m128i hi = _mm_srli_si128(in, 6); // extracting the hi + __m128i low = + _mm_xor_si128(_mm_or_si128(in, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 4 + 48 -> Length after xor + + hi = _mm_srli_si128(t, 6); // extracting the hi + low = _mm_xor_si128(_mm_or_si128(t, mask), mask); // extracting the low + + t = _mm_clmulepi64_si128(hi, p, 0x00); // hi_low(0x00) * p + t = _mm_xor_si128(t, low); // 48 -> Length after xor + + return _mm_extract_epi64(t, 0); +} + uint64_t reduce_GF2_48(__m128i in) { // modulus = x^48 + x^5 + x^3 + x^2 + 1 constexpr uint64_t upper_mask = 0xFFFFFFFFULL; @@ -179,6 +269,8 @@ std::vector GF2E::to_bytes() const { return buffer; } +uint64_t GF2E::get_data() const { return this->data; } + void GF2E::from_bytes(uint8_t *in) { data = 0; memcpy((uint8_t *)(&data), in, byte_size); @@ -385,6 +477,41 @@ std::vector build_from_roots(const std::vector &roots) { poly[len] = GF2E(1); return poly; } +// normal eval precomputation +std::vector eval_precompute(const GF2E &point, size_t poly_size) { + std::vector out; + out.reserve(poly_size); + + GF2E temp = point; + out.push_back(temp); + for (size_t i = 1; i < poly_size; ++i) { + temp *= point; + out.push_back(temp); + } + return out; +} + +// normal optimized polynomial evaluation with precomputation optimization +GF2E eval_fast(const std::vector &poly, const std::vector &x_pow_n, + const size_t lambda) { + __m128i acc = _mm_set_epi64x(0, poly[0].get_data()); + for (size_t i = 1; i < poly.size(); ++i) { + acc = acc ^ clmul(poly[i].get_data(), x_pow_n[i - 1].get_data()); + } + + switch (lambda) { + case 2: + return GF2E(reduce_GF2_16_clmul(acc)); + case 4: + return GF2E(reduce_GF2_32_clmul(acc)); + case 5: + return GF2E(reduce_GF2_40_clmul(acc)); + case 6: + return GF2E(reduce_GF2_48_clmul(acc)); + default: + return GF2E(reduce_GF2_32_clmul(acc)); + } 
+} // horner eval GF2E eval(const std::vector &poly, const GF2E &point) { GF2E acc; diff --git a/field.h b/field.h index a489872..78c68d8 100644 --- a/field.h +++ b/field.h @@ -54,6 +54,8 @@ class GF2E { friend GF2E(::dot_product)(const std::vector &lhs, const std::vector &rhs); + + uint64_t get_data() const; }; const GF2E &lift_uint8_t(uint8_t value); @@ -66,6 +68,9 @@ std::vector interpolate_with_precomputation( const std::vector &y_values); std::vector build_from_roots(const std::vector &roots); +std::vector eval_precompute(const GF2E &point, size_t poly_size); +GF2E eval_fast(const std::vector &poly, const std::vector &x_pow_n, + const size_t lambda); GF2E eval(const std::vector &poly, const GF2E &point); } // namespace field diff --git a/tests/field_test.cpp b/tests/field_test.cpp index 1d86ba5..1dfae74 100644 --- a/tests/field_test.cpp +++ b/tests/field_test.cpp @@ -265,4 +265,23 @@ TEST_CASE("NTL interpolation == custom", "[field]") { REQUIRE(a_lag[i][j] == utils::ntl_to_custom(b_lag[i][j])); } } +} + +TEST_CASE("Fast poly eval == poly eval", "[field]") { + + field::GF2E::init_extension_field(banquet_instance_get(Banquet_L1_Param4)); + // field::GF2E a; + std::vector a(1); + a[0].set_coeff(31); + a[0].set_coeff(39); + + std::vector roots = field::get_first_n_field_elements(3); + std::vector poly = field::build_from_roots(roots); + field::GF2E eval = field::eval(poly, a[0]); + + std::vector precomp = field::eval_precompute(a[0], poly.size()); + field::GF2E eval1 = field::eval_fast( + poly, precomp, banquet_instance_get(Banquet_L1_Param4).lambda); + + REQUIRE(eval == eval1); } \ No newline at end of file