diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 9fe874d..9b27c8a 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -1,5 +1,20 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# gRPC +# must be included separately, since we need to load transitive deps of grpc. +http_archive( + name = "com_github_grpc_grpc", + strip_prefix = "grpc-1.55.0", + urls = [ + "https://github.com/grpc/grpc/archive/refs/tags/v1.55.0.zip", + ], +) +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +grpc_deps() +load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") +grpc_extra_deps() + # rules_proto defines abstract rules for building Protocol Buffers. # https://github.com/bazelbuild/rules_proto http_archive( @@ -47,8 +62,6 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe go_rules_dependencies() -go_register_toolchains(version = "1.19.3") - # Install gtest. # https://github.com/google/googletest http_archive( diff --git a/dpf/distributed_point_function.h b/dpf/distributed_point_function.h index f7e2c3e..3595a56 100644 --- a/dpf/distributed_point_function.h +++ b/dpf/distributed_point_function.h @@ -40,8 +40,6 @@ #include "absl/types/span.h" #include "dpf/aes_128_fixed_key_hash.h" #include "dpf/distributed_point_function.pb.h" -#include "dpf/internal/evaluate_prg_hwy.h" -#include "dpf/internal/maybe_deref_span.h" #include "dpf/internal/proto_validator.h" #include "dpf/internal/value_type_helpers.h" #include "hwy/aligned_allocator.h" @@ -232,17 +230,9 @@ class DistributedPointFunction { // Returns FAILED_PRECONDITION if `RegisterValueType` has not been called // for all types in the `DpfParameters` passed at construction. - // Legacy interface for absl::uint128, which doesn't require explicitly - // converting to absl::Span. - absl::StatusOr> GenerateKeysIncremental( - absl::uint128 alpha, const std::vector& beta) { - return GenerateKeysIncremental(alpha, absl::MakeConstSpan(beta)); - } - - // Templated version when all value types are equal. - template + // Overload for simple integers. absl::StatusOr> GenerateKeysIncremental( - absl::uint128 alpha, absl::Span beta) { + absl::uint128 alpha, absl::Span beta) { std::vector values(beta.size()); for (int i = 0; i < static_cast(beta.size()); ++i) { absl::StatusOr value = ToValue(beta[i]); @@ -377,35 +367,6 @@ class DistributedPointFunction { &ctx); } - // Evaluates a span of DPF keys. The i-th key is evaluated at - // evaluation_points[i]. After each hierarchy level, calls `op` on the output - // at that hierarchy level. `op` must be callable with the following - // signature: - // - // op(int hierarchy_level, absl::Span values) - // - // It should return a value that is implicitly convertible to `bool`. - // - // This method is intended for use cases similar to - // - // absl::StatusOr> EvaluateAt( - // int hierarchy_level, absl::Span evaluation_points, - // EvaluationContext& ctx) - // - // but without the overhead of EvaluationContext. Instead, all operations on - // intermediate values, and obtaining the final result, should be done via - // `op`. - // - // Return absl::OkStatus() after successfully evaluating `op` on the last - // hierarchy level, or as soon as `op` returns `false`. Returns - // INVALID_ARGUMENT in case any `key` is malformed, or if any of the - // `evaluation_points` are out of range. - template - absl::Status EvaluateAndApply( - dpf_internal::MaybeDerefSpan, - absl::Span evaluation_points, Fn op, - int evaluation_points_rightshift = 0) const; - // Returns the DpfParameters of this DPF. inline absl::Span parameters() const { return parameters_; @@ -573,13 +534,6 @@ class DistributedPointFunction { absl::flat_hash_map& value_correction_functions); - // For the given `key` and `hierarchy_level`, returns the value correction - // words as an array of integers, where the size of the array matches the - // number of batched elements per block. - template - absl::StatusOr()>> - GetValueCorrectionAsArray(const DpfKey& key, int hierarchy_level) const; - // Joint implementation of the two variants of `EvaluateAt`. If `ctx != // NULL`, `key` must point to `ctx->key()`, and `*ctx` will be updated with // the partial evaluations at this `hierarchy_level`. @@ -636,6 +590,8 @@ class DistributedPointFunction { // correct values for it anyway. absl::flat_hash_map value_correction_functions_; + + friend class KeyGenerationProtocol; }; //========================// @@ -728,7 +684,7 @@ absl::StatusOr> DistributedPointFunction::EvaluateUntil( int previous_log_domain_size = 0; int previous_hierarchy_level = ctx.previous_hierarchy_level(); if (!prefixes.empty()) { - DCHECK_GE(ctx.previous_hierarchy_level(), 0); + DCHECK(ctx.previous_hierarchy_level() >= 0); previous_log_domain_size = parameters_[previous_hierarchy_level].log_domain_size(); for (absl::uint128 prefix : prefixes) { @@ -864,7 +820,7 @@ absl::StatusOr> DistributedPointFunction::EvaluateUntil( // Compute the number of outputs we will have. For each prefix, we will have a // full expansion from the previous heirarchy level to the current heirarchy // level. - DCHECK_LT(log_domain_size - previous_log_domain_size, 63); + DCHECK(log_domain_size - previous_log_domain_size < 63); int64_t outputs_per_prefix = int64_t{1} << (log_domain_size - previous_log_domain_size); @@ -890,26 +846,6 @@ absl::StatusOr> DistributedPointFunction::EvaluateUntil( } } -template -absl::StatusOr()>> -DistributedPointFunction::GetValueCorrectionAsArray(const DpfKey& key, - int hierarchy_level) const { - // Get output correction word from `key`. - const ::google::protobuf::RepeatedPtrField* value_correction = nullptr; - if (hierarchy_level < static_cast(parameters_.size()) - 1) { - value_correction = - &(key.correction_words(hierarchy_to_tree_[hierarchy_level]) - .value_correction()); - } else { - // Last level value correction is stored in an extra proto field, since we - // have one less correction word than tree levels. - value_correction = &(key.last_level_value_correction()); - } - - // Split output correction into elements of type T, and return it. - return dpf_internal::ValuesToArray(*value_correction); -} - template absl::StatusOr> DistributedPointFunction::EvaluateAtImpl( const DpfKey& key, int hierarchy_level, @@ -954,11 +890,31 @@ absl::StatusOr> DistributedPointFunction::EvaluateAtImpl( return std::vector{}; // Nothing to do. } + // Get output correction word from `key`. + constexpr int elements_per_block = dpf_internal::ElementsPerBlock(); + const ::google::protobuf::RepeatedPtrField* value_correction = nullptr; + if (hierarchy_level < static_cast(parameters_.size()) - 1) { + value_correction = + &(key.correction_words(hierarchy_to_tree_[hierarchy_level]) + .value_correction()); + } else { + // Last level value correction is stored in an extra proto field, since we + // have one less correction word than tree levels. + value_correction = &(key.last_level_value_correction()); + } + + // Split output correction into elements of type T, and save it in + // correction_ints. + absl::StatusOr> correction_ints = + dpf_internal::ValuesToArray(*value_correction); + if (!correction_ints.ok()) { + return correction_ints.status(); + } + // Split up evaluation_points into tree indices and block indices, if we're // operating on a packed type. Otherwise set `tree_indices` to // `evaluation_points`. hwy::AlignedFreeUniquePtr maybe_recomputed_tree_indices; - constexpr int elements_per_block = dpf_internal::ElementsPerBlock(); absl::Span tree_indices; if (elements_per_block > 1) { maybe_recomputed_tree_indices = @@ -1026,22 +982,16 @@ absl::StatusOr> DistributedPointFunction::EvaluateAtImpl( } DCHECK(static_cast(seeds.size()) == num_evaluation_points); - // Hash `seeds`. + // Hash DPF evaluations. absl::StatusOr> hashed_expansion = HashExpandedSeeds(hierarchy_level, seeds); if (!hashed_expansion.ok()) { return hashed_expansion.status(); } - // Get value correction words. - absl::StatusOr> correction_ints = - GetValueCorrectionAsArray(key, hierarchy_level); - if (!correction_ints.ok()) { - return correction_ints.status(); - } - // Perform value correction. - std::vector result(num_evaluation_points); + std::vector result; + result.reserve(num_evaluation_points); const int blocks_needed = blocks_needed_[hierarchy_level]; for (int64_t i = 0; i < num_evaluation_points; ++i) { std::array current_elements = @@ -1053,7 +1003,7 @@ absl::StatusOr> DistributedPointFunction::EvaluateAtImpl( if (elements_per_block > 1) { block_index = DomainToBlockIndex(evaluation_points[i], hierarchy_level); } - result[i] = current_elements[block_index]; + result.push_back(current_elements[block_index]); if (selected_partial_evaluations->control_bits[i]) { result[i] += (*correction_ints)[block_index]; } @@ -1069,134 +1019,6 @@ absl::StatusOr> DistributedPointFunction::EvaluateAtImpl( return result; } -template -absl::Status DistributedPointFunction::EvaluateAndApply( - dpf_internal::MaybeDerefSpan keys, - absl::Span evaluation_points, Fn op, - int evaluation_points_rightshift) const { - if (evaluation_points.size() != keys.size()) { - return absl::InvalidArgumentError( - "`keys.size()` != `evaluation_points.size()`"); - } - for (int i = 0; i < keys.size(); ++i) { - absl::Status status = proto_validator_->ValidateDpfKey(keys[i]); - if (!status.ok()) return status; - } - - const int64_t num_keys = keys.size(); - const int num_hierarchy_levels = parameters_.size(); - DpfExpansion eval; - eval.control_bits.resize(num_keys); - eval.seeds = hwy::AllocateAligned(num_keys); - if (eval.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - absl::Span seeds(eval.seeds.get(), num_keys); - absl::Span control_bits(eval.control_bits); - hwy::AlignedFreeUniquePtr correction_seeds; - BitVector correction_control_bits_left, correction_control_bits_right; - std::vector values(num_keys); - - // Initialize seeds and control bits. - for (int64_t i = 0; i < num_keys; ++i) { - seeds[i] = absl::MakeUint128(keys[i].seed().high(), keys[i].seed().low()); - control_bits[i] = keys[i].party(); - } - - int start_level = 0; - int stop_level = hierarchy_to_tree_[0]; - for (int hierarchy_level = 0; hierarchy_level < num_hierarchy_levels; - ++hierarchy_level) { - if (hierarchy_level > 0) { - start_level = stop_level; - stop_level = hierarchy_to_tree_[hierarchy_level]; - } - - // Compute index shifts for the current level. - const int domain_index_rightshift = - evaluation_points_rightshift + parameters_.back().log_domain_size() - - parameters_[hierarchy_level].log_domain_size(); - const int tree_index_rightshift = evaluation_points_rightshift + - parameters_.back().log_domain_size() - - hierarchy_to_tree_[hierarchy_level]; - - int num_tree_levels = stop_level - start_level; - if (num_tree_levels > 0) { - correction_seeds = - hwy::AllocateAligned(num_tree_levels * num_keys); - if (correction_seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - correction_control_bits_left.resize(num_tree_levels * num_keys); - correction_control_bits_right.resize(num_tree_levels * num_keys); - for (int i = 0; i < num_tree_levels; ++i) { - for (int64_t j = 0; j < num_keys; ++j) { - const int64_t index = i * num_keys + j; - const CorrectionWord& cw = keys[j].correction_words(start_level + i); - correction_seeds[index] = - absl::MakeUint128(cw.seed().high(), cw.seed().low()); - correction_control_bits_left[index] = cw.control_left(); - correction_control_bits_right[index] = cw.control_right(); - } - } - - // Evaluate the current hierarchy level for all keys. - absl::Status status = dpf_internal::EvaluateSeeds( - seeds.size(), num_tree_levels, num_tree_levels * num_keys, - seeds.data(), control_bits.data(), evaluation_points.data(), - tree_index_rightshift, correction_seeds.get(), - correction_control_bits_left.data(), - correction_control_bits_right.data(), prg_left_, prg_right_, - seeds.data(), control_bits.data()); - if (!status.ok()) { - return status; - } - } - - // Hash `seeds`. - absl::StatusOr> - hashed_expansion = HashExpandedSeeds(hierarchy_level, seeds); - if (!hashed_expansion.ok()) { - return hashed_expansion.status(); - } - - // Compute value correction for the current level. - constexpr int elements_per_block = dpf_internal::ElementsPerBlock(); - const int blocks_needed = blocks_needed_[hierarchy_level]; - for (int64_t i = 0; i < num_keys; ++i) { - std::array current_elements = - dpf_internal::ConvertBytesToArrayOf(absl::string_view( - reinterpret_cast(hashed_expansion->get() + - i * blocks_needed), - blocks_needed * sizeof(absl::uint128))); - absl::StatusOr> correction_ints = - GetValueCorrectionAsArray(keys[i], hierarchy_level); - if (!correction_ints.ok()) { - return correction_ints.status(); - } - int block_index = 0; - if (elements_per_block > 1 && domain_index_rightshift < 128) { - block_index = DomainToBlockIndex( - evaluation_points[i] >> domain_index_rightshift, hierarchy_level); - } - values[i] = current_elements[block_index]; - if (control_bits[i]) { - values[i] += (*correction_ints)[block_index]; - } - if (keys[i].party() == 1) { - values[i] = -values[i]; - } - } - - // Call the callback with the values at the current level, and return if the - // result is `false`. - if (!op(values)) { - break; - } - } - return absl::OkStatus(); -} - } // namespace distributed_point_functions #endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_ diff --git a/dpf/key_generation_protocol/BUILD b/dpf/key_generation_protocol/BUILD new file mode 100644 index 0000000..ffc398e --- /dev/null +++ b/dpf/key_generation_protocol/BUILD @@ -0,0 +1,150 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_proto//proto:defs.bzl", "proto_library") + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "key_generation_protocol", + srcs = ["key_generation_protocol.cc"], + hdrs = ["key_generation_protocol.h"], + deps = [ + ":key_generation_protocol_cc_proto", + "//dcf/fss_gates/prng:basic_rng", + "//dpf:distributed_point_function", + "//dpf:distributed_point_function_cc_proto", + "//dpf:status_macros", + "//dpf/internal:evaluate_prg_hwy", + "//dpf/internal:get_hwy_mode", + "//dpf/internal:proto_validator", + "//dpf/internal:value_type_helpers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "key_generation_protocol_test", + srcs = ["key_generation_protocol_test.cc"], + deps = [ + ":key_generation_protocol", + "//dcf/fss_gates/prng:basic_rng", + "//dpf:distributed_point_function_cc_proto", + "//dpf/internal:evaluate_prg_hwy", + "//dpf/internal:get_hwy_mode", + "//dpf/internal:proto_validator", + "//dpf/internal:status_matchers", + "//dpf/internal:value_type_helpers", + "@com_github_google_googletest//:gtest_main", + ], +) + +cc_proto_library( + name = "key_generation_protocol_cc_proto", + deps = [":key_generation_protocol_proto"], +) + +proto_library( + name = "key_generation_protocol_proto", + srcs = ["key_generation_protocol.proto"], + deps = ["//dpf:distributed_point_function_proto"], +) + +proto_library( + name = "key_generation_protocol_rpc_proto", + srcs = [":key_generation_protocol_rpc.proto"], + deps = [ + ":key_generation_protocol_proto", + ], +) +cc_proto_library( + name = "key_generation_protocol_rpc_cc_proto", + deps = [ + ":key_generation_protocol_rpc_proto", + ] +) +cc_grpc_library( + name = "key_generation_protocol_rpc_grpc_proto", + srcs = [":key_generation_protocol_rpc_proto"], + grpc_only = True, + deps = [":key_generation_protocol_rpc_cc_proto"], +) + + +cc_library( + name = "key_generation_protocol_rpc_impl", + srcs = ["key_generation_protocol_rpc_impl.cc"], + hdrs = [ + "key_generation_protocol_rpc_impl.h", + ], + deps = [ + ":key_generation_protocol", + ":key_generation_protocol_cc_proto", + ":key_generation_protocol_rpc_grpc_proto", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_binary( + name = "key_generation_protocol_server", + srcs = ["key_generation_protocol_server.cc"], + deps = [ + ":key_generation_protocol", + ":key_generation_protocol_rpc_impl", + ":key_generation_protocol_rpc_grpc_proto", + "@com_google_absl//absl/base", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_binary( + name = "key_generation_protocol_client", + srcs = ["key_generation_protocol_client.cc"], + deps = [ + ":key_generation_protocol", + ":key_generation_protocol_cc_proto", + ":key_generation_protocol_rpc_grpc_proto", + "@com_google_absl//absl/base", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_github_grpc_grpc//:grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + + diff --git a/dpf/key_generation_protocol/key_generation_protocol.cc b/dpf/key_generation_protocol/key_generation_protocol.cc new file mode 100644 index 0000000..535f113 --- /dev/null +++ b/dpf/key_generation_protocol/key_generation_protocol.cc @@ -0,0 +1,832 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dpf/key_generation_protocol/key_generation_protocol.h" + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "dpf/distributed_point_function.h" +#include "dpf/status_macros.h" +#include "dcf/fss_gates/prng/basic_rng.h" +#include "dpf/internal/evaluate_prg_hwy.h" +#include "dpf/internal/get_hwy_mode.h" +#include "dpf/internal/proto_validator.h" +#include "dpf/internal/value_type_helpers.h" +#include "dpf/status_macros.h" + + +namespace distributed_point_functions { + +KeyGenerationProtocol::KeyGenerationProtocol( + std::unique_ptr dpf) + : dpf_(std::move(dpf)) {} + +absl::StatusOr> +KeyGenerationProtocol::Create(absl::Span parameters) { +// if (party != 0 && party != 1) { +// return absl::InvalidArgumentError("`party` must be 0 or 1"); +// } + DPF_ASSIGN_OR_RETURN(auto dpf, + DistributedPointFunction::CreateIncremental(parameters)); + + +// uint64_t levels = parameters.back().log_domain_size(); +// uint64_t levels = 63; + + return absl::WrapUnique(new KeyGenerationProtocol(std::move(dpf))); +} + + absl::StatusOr> + KeyGenerationProtocol::PerformKeyGenerationPrecomputation(){ + + KeyGenerationPreprocessing preproc_party0, preproc_party1; + + + int n = dpf_->parameters_.size(); + + for(int i = 0; i < n; i++){ + IdpfLevelCorrelation ipdfcorr_party0, ipdfcorr_party1; + + // Generating correlations for first mux + std::pair mux_1; + + DPF_ASSIGN_OR_RETURN(mux_1, + KeyGenerationProtocol::genMuxCorrelation()); + + ipdfcorr_party0.mux_1 = mux_1.first; + ipdfcorr_party1.mux_1 = mux_1.second; + + + // Generating correlations for second mux + std::pair mux_2; + + DPF_ASSIGN_OR_RETURN(mux_2, + KeyGenerationProtocol::genMuxCorrelation()); + + ipdfcorr_party0.mux_2 = mux_2.first; + ipdfcorr_party1.mux_2 = mux_2.second; + + // Generating bit beaver triples + + const absl::string_view kSampleSeed = absl::string_view(); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + absl::uint128 a_temp, b_temp, a_0_temp, b_0_temp, c_0_temp; + bool a, b, c, a_0, b_0, c_0, a_1, b_1, c_1; + DPF_ASSIGN_OR_RETURN(a_temp, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(b_temp, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(a_0_temp, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(b_0_temp, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(c_0_temp, rng->Rand128()); + a = (a_temp & 1) ? 1 : 0; + b = (b_temp & 1) ? 1 : 0; + c = a & b; + a_0 = (a_0_temp & 1) ? 1 : 0; + b_0 = (b_0_temp & 1) ? 1 : 0; + c_0 = (c_0_temp & 1) ? 1 : 0; + a_1 = a ^ a_0; + b_1 = b ^ b_0; + c_1 = c ^ c_0; + + ipdfcorr_party0.bit_triple = {a, a_0, b_0, c_0}; + ipdfcorr_party1.bit_triple = {b, a_1, b_1, c_1}; + + // Populating the idpf correlation for this level + + preproc_party0.level_corr.push_back(ipdfcorr_party0); + preproc_party1.level_corr.push_back(ipdfcorr_party1); + } + + return std::make_pair(std::move(preproc_party0), std::move(preproc_party1)); + +// return absl::UnimplementedError(""); + } + + absl::StatusOr KeyGenerationProtocol::Initialize( + int partyid, + const absl::uint128& alpha_shares, + const std::vector& beta_shares, + const KeyGenerationPreprocessing& keygen_preproc){ + + // We are assuming that number of parameters = number of levels + levels = dpf_->parameters_.size(); + + + // Check validity of beta. + if (beta_shares.size() != dpf_->parameters_.size()) { + return absl::InvalidArgumentError( + "`beta` has to have the same size as `parameters` passed at " + "construction"); + } + for (int i = 0; i < static_cast(dpf_->parameters_.size()); ++i) { + absl::Status status = dpf_->proto_validator_->ValidateValue(beta_shares[i], i); + if (!status.ok()) { + return status; + } + } + + // Sampling root seed + + std::vector seeds; + + const absl::string_view kSampleSeed = absl::string_view(); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + absl::uint128 seed; + DPF_ASSIGN_OR_RETURN(seed, rng->Rand128()); + + seeds.push_back(seed); + + + // Sampling root control bit share + + std::vector shares_of_control_bits; + + shares_of_control_bits.push_back(partyid); + + + // Populating the root seed in DPF key + DpfKey key; + + // Setting root seed in DPF key +// RAND_bytes(reinterpret_cast(&seeds), sizeof(absl::uint128)); + key.mutable_seed()->set_high(absl::Uint128High64(seed)); + key.mutable_seed()->set_low(absl::Uint128Low64(seed)); + + + // Setting party id in DPF key + key.set_party(partyid); + + ProtocolState state; + + state.tree_level = 0; + state.seeds = seeds; + state.shares_of_control_bits = shares_of_control_bits; + state.key = key; + state.alpha_shares = alpha_shares; + state.beta_shares = beta_shares; + state.keygen_preproc = keygen_preproc; + + return state; + +// return absl::UnimplementedError(""); + } + + + absl::StatusOr + KeyGenerationProtocol::ComputeSeedCorrectionOtReceiverMessage( + int partyid, + ProtocolState& state) const{ + + // Prepare OT receiver message using state.alpha_shares + // and mux_1 correlation + + SeedCorrectionOtReceiverMessage msg_ot_recv; + + bool alpha_level_share = state.alpha_shares & (1 << (levels - state.tree_level - 1)) ? 1 : 0; + bool rot_masked_alpha_level_share = alpha_level_share ^ + state.keygen_preproc.level_corr[state.tree_level].mux_1.rot_receiver_choice_bit; + + msg_ot_recv.set_choice_bit_mask(rot_masked_alpha_level_share); + + + return msg_ot_recv; + } + + absl::StatusOr + KeyGenerationProtocol::ComputeSeedCorrectionOtSenderMessage(int partyid, + const SeedCorrectionOtReceiverMessage& seed_ot_receiver_message, + ProtocolState& state) const{ + + absl::uint128 seed_left_cumulative_xor = 0, seed_right_cumulative_xor = 0; + bool control_left_cumulative_xor = 0, control_right_cumulative_xor = 0; + + std::vector expanded_seeds_left; + expanded_seeds_left.resize(state.seeds.size()); + + std::vector expanded_seeds_right; + expanded_seeds_right.resize(state.seeds.size()); + + // Line 3: Expanding all the left children at next level. + DPF_RETURN_IF_ERROR( + dpf_->prg_left_.Evaluate(state.seeds, + absl::MakeSpan(expanded_seeds_left))); + + // Line 3: Expanding all the right children at next level. + DPF_RETURN_IF_ERROR( + dpf_->prg_right_.Evaluate(state.seeds, + absl::MakeSpan(expanded_seeds_right))); + + + // Line 4 : Cumulative XOR of all the left seeds, left control bits, + // right seeds, right control bits at the next level. + for(int i = 0; i < state.seeds.size(); i++){ + bool control_left = dpf_internal::ExtractAndClearLowestBit( + expanded_seeds_left[i]); + bool control_right = dpf_internal::ExtractAndClearLowestBit( + expanded_seeds_right[i]); + seed_left_cumulative_xor ^= expanded_seeds_left[i]; + seed_right_cumulative_xor ^= expanded_seeds_right[i]; + control_left_cumulative_xor ^= control_left; + control_right_cumulative_xor ^= control_right; + + // Populating new (uncorrected) seeds. + state.uncorrected_seeds.push_back(expanded_seeds_left[i]); + state.uncorrected_seeds.push_back(expanded_seeds_right[i]); + state.shares_of_uncorrected_control_bits.push_back(control_left); + state.shares_of_uncorrected_control_bits.push_back(control_right); + } + + // Storing cumulative seeds and control bits in the state. + state.seed_left_cumulative = seed_left_cumulative_xor; + state.seed_right_cumulative = seed_right_cumulative_xor; + state.control_left_cumulative = control_left_cumulative_xor; + state.control_right_cumulative = control_right_cumulative_xor; + + + // Sampling randomness for Mux 1 mask + const absl::string_view kSampleSeed = absl::string_view(); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + DPF_ASSIGN_OR_RETURN(absl::uint128 r, rng->Rand128()); + + state.mux_1_randomness = r; + // Preparing OT sender messages using r, seed_left_cumulative_xor, + // seed_right_cumulative_xor and mux_1 correlation + + SeedCorrectionOtSenderMessage mux_sender_msg; + + absl::uint128 mux_sender_first_msg_without_rot_mask, + mux_sender_second_msg_without_rot_mask; + + // We want to perform MUX2((s^R, s^L), \alpha_l) meaning that the output + // should be s^R be \alpha_L = 0 and s^L otherwise. + + // Preparing the unmasked OT sender messages. + bool alpha_level_share = + state.alpha_shares & (1 << (levels - state.tree_level)) ? 1 : 0; + + if(alpha_level_share == false){ + mux_sender_first_msg_without_rot_mask = r ^ seed_right_cumulative_xor; + mux_sender_second_msg_without_rot_mask = r ^ seed_left_cumulative_xor; + } + else{ + mux_sender_first_msg_without_rot_mask = r ^ seed_left_cumulative_xor; + mux_sender_second_msg_without_rot_mask = r ^ seed_right_cumulative_xor; + } + + // Swap the ROT masks if seed_ot_receiver_message = 1 + absl::uint128 rot_sender_first_msg, rot_sender_second_msg; + if(seed_ot_receiver_message.choice_bit_mask()){ + rot_sender_first_msg = + state.keygen_preproc.level_corr[state.tree_level].mux_1.rot_sender_second_string; + rot_sender_second_msg = + state.keygen_preproc.level_corr[state.tree_level].mux_1.rot_sender_first_string; + } + else{ + rot_sender_first_msg = + state.keygen_preproc.level_corr[state.tree_level].mux_1.rot_sender_first_string; + rot_sender_second_msg = + state.keygen_preproc.level_corr[state.tree_level].mux_1.rot_sender_second_string; + } + + // Preparing the masked OT sender msgs. + absl::uint128 mux_sender_first_msg_with_rot_mask, + mux_sender_second_msg_with_rot_mask; + + mux_sender_first_msg_with_rot_mask = mux_sender_first_msg_without_rot_mask ^ + rot_sender_first_msg; + + mux_sender_second_msg_with_rot_mask = mux_sender_second_msg_without_rot_mask ^ + rot_sender_second_msg; + + mux_sender_msg.mutable_masked_message_one()->set_high( + absl::Uint128High64(mux_sender_first_msg_with_rot_mask)); + + mux_sender_msg.mutable_masked_message_one()->set_low( + absl::Uint128Low64(mux_sender_first_msg_with_rot_mask)); + + mux_sender_msg.mutable_masked_message_two()->set_high( + absl::Uint128High64(mux_sender_second_msg_with_rot_mask)); + + mux_sender_msg.mutable_masked_message_two()->set_low( + absl::Uint128Low64(mux_sender_second_msg_with_rot_mask)); + + return mux_sender_msg; + } + + absl::StatusOr + KeyGenerationProtocol::ComputeSeedCorrectionOpening( + int partyid, + const SeedCorrectionOtSenderMessage& seed_ot_sender_message, + ProtocolState& state) const{ + + SeedCorrectionShare opening_msg; + + bool alpha_level_share = + state.alpha_shares & (1 << (levels - state.tree_level)) ? 1 : 0; + + // Compute mux output + + absl::uint128 mux_output; + + // Retrieve the correct OT msg + + absl::uint128 ot_output; + + absl::uint128 sender_string_one, sender_string_two; + + sender_string_one = absl::MakeUint128( + seed_ot_sender_message.masked_message_one().high(), + seed_ot_sender_message.masked_message_one().low()); + + sender_string_two = absl::MakeUint128( + seed_ot_sender_message.masked_message_two().high(), + seed_ot_sender_message.masked_message_two().low()); + + + if(alpha_level_share == false){ + ot_output = sender_string_one ^ state.keygen_preproc. + level_corr[state.tree_level].mux_1.rot_receiver_string; + } + else{ + ot_output = sender_string_two ^ state.keygen_preproc. + level_corr[state.tree_level].mux_1.rot_receiver_string; + } + + // Add the randomness of mux to compute the output : Step 6 + // in cryptflow2 mux protocol + + mux_output = ot_output ^ state.mux_1_randomness; + + opening_msg.mutable_seed()->set_high( + absl::Uint128High64(mux_output)); + + opening_msg.mutable_seed()->set_low( + absl::Uint128Low64(mux_output)); + + state.mux_1_output = mux_output; + + + // Computing shares of left and right contol bit correction. + bool control_left_correction, control_right_correction; + + // Step 5 + control_left_correction = state.control_left_cumulative ^ alpha_level_share ^ partyid; + + // Step 5 + control_right_correction = state.control_right_cumulative ^ alpha_level_share; + + state.control_left_correction = control_left_correction; + state.control_right_correction = control_right_correction; + + opening_msg.set_control_bit_left(control_left_correction); + opening_msg.set_control_bit_right(control_right_correction); + + + return opening_msg; + } + + absl::StatusOr KeyGenerationProtocol::ApplySeedCorrectionShare + (int partyid, + const SeedCorrectionShare& seed_correction_share, + ProtocolState& state) const{ + + using T = uint64_t; + + absl::uint128 reconstructed_seed_correction; + bool reconstructed_control_left_correction, + reconstructed_control_right_correction; + + absl::uint128 seed_correction_other_party_share = + absl::MakeUint128( + seed_correction_share.seed().high(), + seed_correction_share.seed().low()); + + + // Step 5 : Opening correction seed, left and right correction control bits + reconstructed_seed_correction = + state.mux_1_output ^ seed_correction_other_party_share; + + reconstructed_control_left_correction = + state.control_left_correction ^ seed_correction_share.control_bit_left(); + + reconstructed_control_right_correction = + state.control_right_correction ^ seed_correction_share.control_bit_right(); + + // Adding reconstructed_seed_correction, reconstructed_control_left_correction + // reconstructed_control_right_correction to the DPF key +// CorrectionWord* correction_word = state.keys.add_correction_words(); + + // Storing reconstructed correction seed, left and right correction control bits + // in the state. Will be used later for populating the DPF key + state.reconstructed_seed_correction = reconstructed_seed_correction; + + state.reconstructed_control_left_correction = reconstructed_control_left_correction; + + state.reconstructed_control_right_correction = reconstructed_control_right_correction; + + + // TODO: Implement steps 6 - 10 + + uint64_t n = state.seeds.size(); + + + for(uint64_t i = 0; i < n; i++) { + uint64_t left_index = i << 1; + uint64_t right_index = left_index + 1; + + bool control_bit_parent = state.shares_of_control_bits[i]; + + // Perform correction of left and right seeds and control bits - Step 6 and Step 7 + if (control_bit_parent) { + state.uncorrected_seeds[left_index] ^= reconstructed_seed_correction; + state.uncorrected_seeds[right_index] ^= reconstructed_seed_correction; + state.shares_of_uncorrected_control_bits[left_index] = + (state.shares_of_uncorrected_control_bits[left_index] ^ reconstructed_control_left_correction); + state.shares_of_uncorrected_control_bits[right_index] = ( + state.shares_of_uncorrected_control_bits[right_index] ^ reconstructed_control_right_correction); + } + + } + + + + // TODO : Remove the template hardcoding to uint64_t + + // Issue: this initialization will depend on the Value type + Value cumulative_word = ValueZero(); + + absl::uint128 cumulative_control_sum = partyid; + + + + std::vector seed_after_convert, value_seed_after_convert; + + seed_after_convert.resize(state.uncorrected_seeds.size()); + value_seed_after_convert.resize(state.uncorrected_seeds.size()); + + DPF_RETURN_IF_ERROR( + dpf_->prg_left_.Evaluate(state.uncorrected_seeds, + absl::MakeSpan(seed_after_convert))); + + DPF_RETURN_IF_ERROR( + dpf_->prg_value_.Evaluate(state.uncorrected_seeds, + absl::MakeSpan(value_seed_after_convert))); + + state.uncorrected_seeds = seed_after_convert; + + + for(int i = 0; i < value_seed_after_convert.size(); i++){ + + // Temporary hack for converting absl::uint128 into + // required integer type (e.g. uint64_t) + T out_value_temp = static_cast(value_seed_after_convert[i]); + + Value value = ToValue(out_value_temp); + + // Line 9 : Adding words + DPF_ASSIGN_OR_RETURN(cumulative_word, + ValueAdd(cumulative_word, value)); + + // Line 10: Adding control bits + if(partyid == 0){ + cumulative_control_sum += + (state.shares_of_uncorrected_control_bits[i]) ? 1 : 0; + } + else{ + // TODO : Check the -1 operation because we are operating over uint + cumulative_control_sum += + (state.shares_of_uncorrected_control_bits[i]) ? -1 : 0; + } + + } + + + state.cumulative_word = cumulative_word; + + // Line 10 : Setting tau_zero and tau_one to be the LSB and second LSB respectively. + state.tau_zero = cumulative_control_sum & (1) ? 1 : 0; + state.tau_one = cumulative_control_sum & (1 << 1) ? 1 : 0; + + bool masked_tau_msg = + state.tau_zero ^ state.keygen_preproc.level_corr[state.tree_level].bit_triple.mask; + + state.masked_tau_zero = masked_tau_msg; + + MaskedTau round4msg; + + round4msg.set_masked_tau_zero(masked_tau_msg); + + return round4msg; + } + + absl::StatusOr + KeyGenerationProtocol::ComputeValueCorrectionOtReceiverMessage( + int partyid, + const MaskedTau& masked_tau, + ProtocolState& state) const{ + + // Compute share of t* using masked tau msg + + bool masked_tau_zero_party_0, masked_tau_zero_party_1; + + if(partyid == 0){ + masked_tau_zero_party_0 = state.masked_tau_zero; + masked_tau_zero_party_1 = masked_tau.masked_tau_zero(); + } + else{ + masked_tau_zero_party_1 = state.masked_tau_zero; + masked_tau_zero_party_0 = masked_tau.masked_tau_zero(); + } + + + bool share_of_product; + + + if(partyid == 0){ + share_of_product = (masked_tau_zero_party_0 & state.keygen_preproc.level_corr[state.tree_level].bit_triple.b) + ^ (masked_tau_zero_party_1 & state.keygen_preproc.level_corr[state.tree_level].bit_triple.a) + ^ state.keygen_preproc.level_corr[state.tree_level].bit_triple.c; + } + else{ + share_of_product = (masked_tau_zero_party_0 & masked_tau_zero_party_1) + ^ (masked_tau_zero_party_0 & state.keygen_preproc.level_corr[state.tree_level].bit_triple.b) + ^ (masked_tau_zero_party_1 & state.keygen_preproc.level_corr[state.tree_level].bit_triple.a) + ^ state.keygen_preproc.level_corr[state.tree_level].bit_triple.c; + } + + + + bool share_of_t_star; + + if(partyid == 0){ + share_of_t_star = state.tau_one ^ share_of_product; + } + else{ + share_of_t_star = 1 ^ state.tau_one ^ share_of_product; + } + + state.share_of_t_star = share_of_t_star; + + // Compute MUX 2 round 1 msg + + ValueCorrectionOtReceiverMessage round5msg; + + bool masked_choice_bit = state.share_of_t_star ^ + state.keygen_preproc.level_corr[state.tree_level].mux_2.rot_receiver_choice_bit; + + round5msg.set_choice_bit_mask(masked_choice_bit); + + return round5msg; + } + + absl::StatusOr + KeyGenerationProtocol::ComputeValueCorrectionOtSenderMessage( + int partyid, + const ValueCorrectionOtReceiverMessage& value_ot_receiver_message, + ProtocolState& state) const{ + + using T = uint64_t; + + // Construct OT sender msg + + Value share_of_W0_CW, share_of_W1_CW; + + Value beta_share = state.beta_shares[state.tree_level]; + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 0" << std::endl; + + if(partyid == 0){ + DPF_ASSIGN_OR_RETURN(share_of_W0_CW, + ValueSub( + beta_share, + state.cumulative_word)); + } + else{ +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 1" << std::endl; + DPF_ASSIGN_OR_RETURN(share_of_W0_CW, + ValueAdd( + beta_share, + state.cumulative_word)); + } + + if(partyid == 0){ + DPF_ASSIGN_OR_RETURN(share_of_W1_CW, + ValueSub( + state.cumulative_word, + beta_share)); + } + else{ +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 2" << std::endl; + DPF_ASSIGN_OR_RETURN(Value share_of_W1_CW_temp, + ValueAdd( + beta_share, + state.cumulative_word)); +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 3" << std::endl; + DPF_ASSIGN_OR_RETURN(share_of_W1_CW, + ValueNegate(share_of_W1_CW_temp)); + } + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 4" << std::endl; + + // Swap ROT mask depending on the receiver msg + absl::uint128 rot_sender_mask_first, rot_sender_mask_second; + + if(value_ot_receiver_message.choice_bit_mask()){ + rot_sender_mask_first = state.keygen_preproc.level_corr[state.tree_level].mux_2.rot_sender_second_string; + rot_sender_mask_second = state.keygen_preproc.level_corr[state.tree_level].mux_2.rot_sender_first_string; + } + else{ + rot_sender_mask_first = state.keygen_preproc.level_corr[state.tree_level].mux_2.rot_sender_first_string; + rot_sender_mask_second = state.keygen_preproc.level_corr[state.tree_level].mux_2.rot_sender_second_string; + } + + // Convert ROT masks into Value type +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 5" << std::endl; + Value rot_sender_mask_first_value, rot_sender_mask_second_value; + + DPF_ASSIGN_OR_RETURN(rot_sender_mask_first_value, + ConvertRandToVal(rot_sender_mask_first)); + + DPF_ASSIGN_OR_RETURN(rot_sender_mask_second_value, + ConvertRandToVal(rot_sender_mask_second)); + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 6" << std::endl; + + // Sample mux 2 randomness + const absl::string_view kSampleSeed = absl::string_view(); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + DPF_ASSIGN_OR_RETURN(state.mux_2_randomness, rng->Rand128()); + + DPF_ASSIGN_OR_RETURN(Value random_value_mask, + ConvertRandToVal(state.mux_2_randomness)); + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 7" << std::endl; + + + // Generate mux sender msg depending on own share of t* + + Value masked_ot_1, masked_ot_2, masked_ot_1_tmp, masked_ot_2_tmp; + + + if(state.share_of_t_star == false){ +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 7.1" << std::endl; + DPF_ASSIGN_OR_RETURN(masked_ot_1_tmp, + ValueAdd(share_of_W0_CW, + rot_sender_mask_first_value)); +//std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 7.2" << std::endl; + + + DPF_ASSIGN_OR_RETURN(masked_ot_2_tmp, + ValueAdd(share_of_W1_CW, + rot_sender_mask_second_value)); +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 7.25" << std::endl; + } + else{ +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 7.3" << std::endl; + + DPF_ASSIGN_OR_RETURN(masked_ot_1_tmp, + ValueAdd(share_of_W1_CW, + rot_sender_mask_first_value)); + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 7.4" << std::endl; + DPF_ASSIGN_OR_RETURN(masked_ot_2_tmp, + ValueAdd(share_of_W0_CW, + rot_sender_mask_second_value)); + } + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 8" << std::endl; + + DPF_ASSIGN_OR_RETURN(masked_ot_1, + ValueSub(masked_ot_1_tmp, + random_value_mask)); + + DPF_ASSIGN_OR_RETURN(masked_ot_2, + ValueSub(masked_ot_2_tmp, + random_value_mask)); + +// std::cout << "ComputeValueCorrectionOtSenderMessage Checkpoint 9" << std::endl; + + + ValueCorrectionOtSenderMessage round6msg; + + *(round6msg.mutable_masked_message_one()) = masked_ot_1; + + *(round6msg.mutable_masked_message_two()) = masked_ot_2; + + return round6msg; + } + + + absl::StatusOr + KeyGenerationProtocol::ComputeValueCorrectionOtShare( + int partyid, + const ValueCorrectionOtSenderMessage& value_ot_sender_message, + ProtocolState& state) const{ + + using T = uint64_t; + + ValueCorrectionShare value_corr_share; + + + // Decode mux message + + + // Retrieve OT message + Value ot_msg; + + if(state.share_of_t_star) + ot_msg = value_ot_sender_message.masked_message_two(); + else + ot_msg = value_ot_sender_message.masked_message_one(); + + // Convert ROT masks into Value type + Value rot_receiver_value; + + DPF_ASSIGN_OR_RETURN(rot_receiver_value, + ConvertRandToVal( + state.keygen_preproc.level_corr[state.tree_level].mux_2.rot_receiver_string)); + + DPF_ASSIGN_OR_RETURN(ot_msg, + ValueSub(ot_msg, + rot_receiver_value)); + + Value mux_output, mux_2_randomness_value; + + DPF_ASSIGN_OR_RETURN(mux_2_randomness_value, + ConvertRandToVal( + state.mux_2_randomness)); + + DPF_ASSIGN_OR_RETURN(mux_output, + ValueAdd(ot_msg, + mux_2_randomness_value)); + + + *(value_corr_share.mutable_value()) = mux_output; + + state.correction_value_share = mux_output; + + return value_corr_share; + } + + absl::StatusOr KeyGenerationProtocol::ApplyValueCorrectionShare( + int partyid, + const ValueCorrectionShare& value_correction_share, + ProtocolState& state) const{ + + using T = uint64_t; + + // Reconstruct the correction word + Value correction_value_other_party_share = value_correction_share.value(); + + Value correction_value; + + DPF_ASSIGN_OR_RETURN(correction_value, + ValueAdd(correction_value_other_party_share, + state.correction_value_share)); + + + state.seeds = state.uncorrected_seeds; + state.uncorrected_seeds.clear(); + state.uncorrected_seeds.reserve(2 * state.seeds.size()); + + state.shares_of_control_bits = state.shares_of_uncorrected_control_bits; + state.shares_of_uncorrected_control_bits.clear(); + + + CorrectionWord* correction_word = state.key.add_correction_words(); + + *(correction_word->add_value_correction()) = correction_value; + + correction_word->set_control_left(state.reconstructed_control_left_correction); + + correction_word->set_control_right(state.reconstructed_control_right_correction); + + correction_word->mutable_seed()->set_high(absl::Uint128High64(state.reconstructed_seed_correction)); + correction_word->mutable_seed()->set_low(absl::Uint128Low64(state.reconstructed_seed_correction)); + + + state.tree_level += 1; + + // Todo : Clear aux state variables + return 0; + } + +} // namespace distributed_point_functions diff --git a/dpf/key_generation_protocol/key_generation_protocol.h b/dpf/key_generation_protocol/key_generation_protocol.h new file mode 100644 index 0000000..8cead46 --- /dev/null +++ b/dpf/key_generation_protocol/key_generation_protocol.h @@ -0,0 +1,402 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_KEY_GENERATION_PROTOCOL_KEY_GENERATION_PROTOCOL_H_ +#define DISTRIBUTED_POINT_FUNCTIONS_DPF_KEY_GENERATION_PROTOCOL_KEY_GENERATION_PROTOCOL_H_ + +#include + +#include "absl/numeric/int128.h" +#include "absl/status/statusor.h" +#include "dpf/distributed_point_function.h" +#include "dpf/distributed_point_function.pb.h" +#include "dpf/key_generation_protocol/key_generation_protocol.pb.h" +#include "dcf/fss_gates/prng/basic_rng.h" +#include "dpf/internal/evaluate_prg_hwy.h" +#include "dpf/internal/get_hwy_mode.h" +#include "dpf/internal/proto_validator.h" +#include "dpf/internal/value_type_helpers.h" + +namespace distributed_point_functions { + +// A two-party protocol for generating a DPF key. +// For each level of the DPF evaluation tree, the following messages are +// exchanged between the parties. We refer to the corresponding lines in +// Algorithm 8 of https://eprint.iacr.org/2022/866.pdf. +// +// 1. Perform two parallel OTs to obtain shares of s_{CW} (Step 5) +// 2. Exchange shares of s_{CW}, t^L_{CW}, and t^R_{CW} (Step 5) +// 3. Perform two parallel OTs to obtain shares of W_{CW} (Step 11) +// 4. Exchange shares of W_{CW}. +// +// These steps correspond to the following functions in this class: +// +// 1a. ComputeSeedCorrectionOtReceiverMessage +// 1b. ComputeSeedCorrectionOtSenderMessage +// 2. ComputeSeedCorrectionShare +// 3a. ComputeValueCorrectionOtReceiverMessage +// 3b. ComputeValueCorrectionOtSenderMessage +// 4. ComputeValueCorrectionShare +// +// Each of these methods takes the other party's message from the previous +// round, as well as a ProtocolState message containing the party's local state. +// It updates the state and returns the computed message or a Status indicating +// any errors. +// +// NOTE: We may want to compute the value correction first, as done in +// DistributedPointFunction::GenerateIncremental. + +struct BitBeaverTriple { + bool mask; + bool a; + bool b; + bool c; +}; + +// Mux involves 2 parallel OTs (let's call OT_A and OT_B) +// Each party acts as OT sender in one OT +// and OT receiver in the other OT. + +struct MuxCorrelation{ + + // One OT + absl::uint128 rot_sender_first_string, rot_sender_second_string; + + // Other OT + bool rot_receiver_choice_bit; + absl::uint128 rot_receiver_string; +}; + +struct IdpfLevelCorrelation{ + MuxCorrelation mux_1, mux_2; + BitBeaverTriple bit_triple; +}; + + +struct KeyGenerationPreprocessing{ + // i^th element of this vector will contain the correlation needed to + // perform i^th level Doerner Shelat. + std::vector level_corr; +}; + +struct ProtocolState{ + + + // Round 2 state + + // Uncorrected seeds at the next level (left to right) + // - twice the length of seeds + std::vector uncorrected_seeds; + + // Uncorrected control bits at the next level (left to right) + // - twice the length of shares_of_control_bits + std::vector shares_of_uncorrected_control_bits; + + // Mux 1 randomness mask + absl::uint128 mux_1_randomness; + + + // Cumulative left seed, right seed, left control bit, + // and right control bit [Obtained in Step 4] + absl::uint128 seed_left_cumulative, seed_right_cumulative; + bool control_left_cumulative, control_right_cumulative; + + // Round 3 state + + absl::uint128 mux_1_output; + + bool control_left_correction, control_right_correction; + + + // Round 4 state + + absl::uint128 reconstructed_seed_correction; + + bool reconstructed_control_left_correction, + reconstructed_control_right_correction; + + bool masked_tau_zero; + + bool tau_zero, tau_one; + + Value cumulative_word; + + + // Round 5 state + + bool share_of_t_star; + + + // Round 6 state + + // Mux 2 randomness mask + absl::uint128 mux_2_randomness; + + // Round 7 state + + // share of correction value + + Value correction_value_share; + + + // global state variables + + // DPF key + DpfKey key; + + absl::uint128 alpha_shares; + std::vector beta_shares; + KeyGenerationPreprocessing keygen_preproc; + + uint64_t tree_level; + // Add more local state variables here. + + // Seeds at the current level (left to right) + std::vector seeds; + + // Control bits at the current level (left to right) + std::vector shares_of_control_bits; +}; + + +class KeyGenerationProtocol { + public: + + uint64_t levels; + + // Creates a new instance of the key generation protocol for a DPF with the + // given parameters. Party must be 0 or 1. + static absl::StatusOr> Create( + absl::Span parameters); + + // Performs precomputation stage of Key Generation protocol and returns a pair of + // KeyGenerationPreprocessing - one for each party. + absl::StatusOr> + PerformKeyGenerationPrecomputation(); + + // Create ProtocolState given shares of alphas and betas. + absl::StatusOr Initialize(int partyid, + const absl::uint128& alpha_shares, + const std::vector& beta_shares, + const KeyGenerationPreprocessing& keygen_preproc); + + // Receiver OT message for the MUX in Step 5. Just takes the state as input. + absl::StatusOr + ComputeSeedCorrectionOtReceiverMessage(int partyid, ProtocolState& state) const; + + // Computes the sender OT message given the receiver message and the state. + absl::StatusOr + ComputeSeedCorrectionOtSenderMessage(int partyid, + const SeedCorrectionOtReceiverMessage& seed_ot_receiver_message, + ProtocolState& state) const; + + // Computes the share of the seed correction word given the sender OT message + // and the state. + absl::StatusOr ComputeSeedCorrectionOpening(int partyid, + const SeedCorrectionOtSenderMessage& seed_ot_sender_message, + ProtocolState& state) const; + + // Updates the state with the other party's seed correction share + // and generate tau mult msg + absl::StatusOr ApplySeedCorrectionShare(int partyid, + const SeedCorrectionShare& seed_correction_share, + ProtocolState& state) const; + + + // Computes the OT receiver message for the MUX gate in Step 11 given the + // state. + absl::StatusOr + ComputeValueCorrectionOtReceiverMessage(int partyid, + const MaskedTau& masked_tau, + ProtocolState& state) const; + + // Computes the OT sender message in Step 11 given the receiver message and + // the state. + absl::StatusOr + ComputeValueCorrectionOtSenderMessage(int partyid, + const ValueCorrectionOtReceiverMessage& value_ot_receiver_message, + ProtocolState& state) const; + + // Computes the value correction share given the OT sender message and the + // state. + absl::StatusOr ComputeValueCorrectionOtShare(int partyid, + const ValueCorrectionOtSenderMessage& value_ot_sender_message, + ProtocolState& state) const; + + // Updates the state with the other party's value correction share. + absl::StatusOr ApplyValueCorrectionShare(int partyid, + const ValueCorrectionShare& value_correction_share, + ProtocolState& state) const; + + // Finalizes the protocol after all tree levels have been computed and returns + // the generated DpfKey. + absl::StatusOr Finalize(int partyid, ProtocolState& state) const; + + template + Value ValueZero() const{ + T zero = 0; + Value value = ToValue(zero); + return value; + } + + template + absl::StatusOr ValueAdd(const Value& value1, const Value& value2) const{ + DPF_ASSIGN_OR_RETURN(T v1, FromValue(value1)); + DPF_ASSIGN_OR_RETURN(T v2, FromValue(value2)); + T v3 = v1 + v2; + Value value3 = ToValue(v3); + return value3; + } + + template + absl::StatusOr ValueNegate(const Value& value) const{ + DPF_ASSIGN_OR_RETURN(T v, FromValue(value)); + T v_neg = -v; + Value value_neg = ToValue(v_neg); + return value_neg; + } + + template + absl::StatusOr ValueSub(const Value& value1, const Value& value2) const{ + DPF_ASSIGN_OR_RETURN(T v1, FromValue(value1)); + DPF_ASSIGN_OR_RETURN(T v2, FromValue(value2)); + T v3 = v1 - v2; + Value value3 = ToValue(v3); + return value3; + } + + // Expands seed s into a new seed and Value +// template +// absl::StatusOr> Convert(const absl::uint128 s) const{ +// +// std::vector in_seed, out_seed, out_value; +// in_seed.push_back(s); +// +// out_seed.resize(1); +// out_value.resize(1); +// +// DPF_RETURN_IF_ERROR( +// dpf_->prg_left_.Evaluate(in_seed, +// absl::MakeSpan(out_seed))); +// +// +// DPF_RETURN_IF_ERROR( +// dpf_->prg_value_.Evaluate(in_seed, +// absl::MakeSpan(out_value))); +// +// // Temporary hack for converting absl::uint128 into +// // required integer type (e.g. uint64_t) +// T out_value_temp = static_cast(out_value[0]); +// +// Value value = ToValue(out_value_temp); +// +// return std::make_pair(out_seed[0], value); +// } + + + // Helper method for converting randomness into Value type + template + absl::StatusOr ConvertRandToVal(const absl::uint128 s) const{ + +// std::vector in_seed, out_value; +// in_seed.push_back(s); +// +// out_value.resize(1); +// +// DPF_RETURN_IF_ERROR( +// dpf_->prg_value_.Evaluate(in_seed, +// absl::MakeSpan(out_value))); + + // Temporary hack for converting absl::uint128 into + // required integer type (e.g. uint64_t) + +// T out_value_temp = static_cast(out_value[0]); + + + + + T out_value_temp = static_cast(s); + + Value value = ToValue(out_value_temp); + + return value; + } + + + std::unique_ptr dpf_; + + + +private: + explicit KeyGenerationProtocol(std::unique_ptr dpf); + + + + // Number of leaves = 2 ^ levels. + + + absl::StatusOr> genMuxCorrelation(){ + + MuxCorrelation mux_corr_party0, mux_corr_party1; + + const absl::string_view kSampleSeed = absl::string_view(); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + absl::uint128 r_0, r_1, b, r_b; + DPF_ASSIGN_OR_RETURN(r_0, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(r_1, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(b, rng->Rand128()); + b = b & 1; + if (b == 0) r_b = r_0; + else r_b = r_1; + + absl::uint128 s_0, s_1, c, s_c; + DPF_ASSIGN_OR_RETURN(s_0, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(s_1, rng->Rand128()); + DPF_ASSIGN_OR_RETURN(c, rng->Rand128()); + c = c & 1; + if (c == 0) s_c = s_0; + else s_c = s_1; + + mux_corr_party0.rot_sender_first_string = r_0; + mux_corr_party0.rot_sender_second_string = r_1; + mux_corr_party0.rot_receiver_choice_bit = (c ? 1 : 0); + mux_corr_party0.rot_receiver_string = s_c; + + mux_corr_party1.rot_sender_first_string = s_0; + mux_corr_party1.rot_sender_second_string = s_1; + mux_corr_party1.rot_receiver_choice_bit = (b ? 1 : 0); + mux_corr_party1.rot_receiver_string = r_b; + + + return std::make_pair(mux_corr_party0, mux_corr_party1); + + } + + absl::uint128 BlockToUint128(Block x){ + absl::uint128 y = absl::MakeUint128(x.high(),x.low()); + return y; + } + + + + +}; + +} // namespace distributed_point_functions + +#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_KEY_GENERATION_PROTOCOL_KEY_GENERATION_PROTOCOL_H_ diff --git a/dpf/key_generation_protocol/key_generation_protocol.proto b/dpf/key_generation_protocol/key_generation_protocol.proto new file mode 100644 index 0000000..d78f016 --- /dev/null +++ b/dpf/key_generation_protocol/key_generation_protocol.proto @@ -0,0 +1,54 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package distributed_point_functions; + +import "dpf/distributed_point_function.proto"; + +// For faster allocations of sub-messages. +option cc_enable_arenas = true; + +message SeedCorrectionOtSenderMessage { + Block masked_message_one = 1; + Block masked_message_two = 2; +} + +message SeedCorrectionOtReceiverMessage { + bool choice_bit_mask = 1; +} + +message SeedCorrectionShare { + Block seed = 1; + bool control_bit_left = 2; + bool control_bit_right = 3; +} + +message MaskedTau{ + bool masked_tau_zero = 1; +} + +message ValueCorrectionOtSenderMessage { + Value masked_message_one = 1; + Value masked_message_two = 2; +} + +message ValueCorrectionOtReceiverMessage { + bool choice_bit_mask = 1; +} + +message ValueCorrectionShare { + Value value = 1; +} diff --git a/dpf/key_generation_protocol/key_generation_protocol_client.cc b/dpf/key_generation_protocol/key_generation_protocol_client.cc new file mode 100644 index 0000000..ae9584f --- /dev/null +++ b/dpf/key_generation_protocol/key_generation_protocol_client.cc @@ -0,0 +1,465 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include // NOLINT +#define GLOG_NO_ABBREVIATED_SEVERITIES +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/memory/memory.h" +#include "include/grpc/grpc_security_constants.h" +#include "include/grpcpp/grpcpp.h" +#include "dpf/key_generation_protocol/key_generation_protocol.h" +#include "dpf/key_generation_protocol/key_generation_protocol.pb.h" +#include "dpf/key_generation_protocol/key_generation_protocol_rpc.grpc.pb.h" +#include "dpf/key_generation_protocol/key_generation_protocol_rpc.pb.h" +#include "include/grpcpp/security/server_credentials.h" +#include "include/grpcpp/server_builder.h" +#include "include/grpcpp/server_context.h" +#include "include/grpcpp/support/status.h" +#include "absl/status/status.h" +#include "dpf/status_macros.h" +#include "absl/strings/string_view.h" + +ABSL_FLAG(std::string, port, "0.0.0.0:10501", + "Port on which to contact server"); +ABSL_FLAG(size_t, num_levels, 20, + "The number of levels for the DPF, also how many iterations to execute."); + +namespace distributed_point_functions { + +absl::Status ExecuteProtocol() { + // Setup + size_t levels = absl::GetFlag(FLAGS_num_levels); + + // Generate parameters for KeyGenProtocol. + std::vector parameters; + parameters.reserve(levels); + + for (size_t i = 0; i < levels; i++){ + parameters.push_back(DpfParameters()); + parameters[i].set_log_domain_size(i + 1); + parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(64); + } + + std::unique_ptr keygen; + + DPF_ASSIGN_OR_RETURN(keygen, + KeyGenerationProtocol::Create(parameters)); + + std::pair preproc; + + DPF_ASSIGN_OR_RETURN(preproc, + keygen->PerformKeyGenerationPrecomputation()); + + + absl::uint128 alpha = 23; + + // Generating shares of alpha for Party 0 and Party 1 + absl::uint128 alpha_share_party0, alpha_share_party1; + const absl::string_view kSampleSeed = absl::string_view("abcdefg"); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + + DPF_ASSIGN_OR_RETURN(alpha_share_party0, rng->Rand128()); + alpha_share_party1 = alpha ^ alpha_share_party0; + + // Generating shares of beta for Party 0 and Party 1 + std::vector beta; + + + for(size_t i = 0; i < levels; i++){ + Value beta_i; + beta_i.mutable_integer()->set_value_uint64(42); + beta.push_back(beta_i); + } + + std::vector beta_shares_party0, beta_shares_party1; + + for (size_t i = 0; i < beta.size(); i++){ + DPF_ASSIGN_OR_RETURN(absl::uint128 beta_share_party0_seed, rng->Rand128()); + + Value value0 = ToValue(static_cast(beta_share_party0_seed)); + + DPF_ASSIGN_OR_RETURN(Value value1, + keygen->ValueSub(beta[i], value0)); + + beta_shares_party0.push_back(value0); + beta_shares_party1.push_back(value1); + } + + // Consider grpc::SslServerCredentials if not running locally. + std::cout << "Client: Creating server stub..." << std::endl; + grpc::ChannelArguments ch_args; + ch_args.SetMaxReceiveMessageSize(-1); // consider limiting max message size + std::unique_ptr stub = + KeyGenerationProtocolRpc::NewStub(::grpc::CreateCustomChannel( + absl::GetFlag(FLAGS_port), grpc::InsecureChannelCredentials(), ch_args)); + std::cout << "Client: Starting KeyGenerationProtocol " << std::endl; + double pzero_time = 0; + double pone_time_incl_comm = 0; + double end_to_end_time = 0; + auto start = std::chrono::high_resolution_clock::now(); + auto client_start = start; + auto client_end = start; + auto server_start = start; + auto server_end = start; + + ::grpc::Status grpc_status; + DPF_ASSIGN_OR_RETURN(ProtocolState protocol_state_party_0, + keygen->Initialize(1, + alpha_share_party0, + beta_shares_party0, + preproc.first)); + grpc::CompletionQueue cq; + + // Initiate server work. + std::cout << "Client: Starting protocol" << std::endl; + + uint64_t cq_index=1; + void* got_tag; + bool ok = false; + + for(size_t i = 0; i < levels; i++) { + std::cout << "Client: Starting iteration " << i + 1 << std::endl; + + // Run Round 1 + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context0; + KeyGenerationProtocolClientMessage client_message_0; + *client_message_0.mutable_start_message() = StartMessage(); + KeyGenerationProtocolServerMessage server_message_1; + std::unique_ptr > rpc(stub->AsyncHandle(&client_context0, client_message_0, &cq)); + rpc->Finish(&server_message_1, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_1; + DPF_ASSIGN_OR_RETURN(*client_message_1.mutable_client_round_1_message(), + keygen->ComputeSeedCorrectionOtReceiverMessage( + 0, + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + + server_start = std::chrono::high_resolution_clock::now(); + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 1 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + // Run Round 2 + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context1; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_2; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context1, client_message_1, &cq)); + rpc->Finish(&server_message_2, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_2; + + DPF_ASSIGN_OR_RETURN(*client_message_2.mutable_client_round_2_message(), + keygen->ComputeSeedCorrectionOtSenderMessage( + 0, + server_message_1.server_round_1_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 2 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + // Run Round 3 + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context2; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_3; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context2, client_message_2, &cq)); + rpc->Finish(&server_message_3, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_3; + + DPF_ASSIGN_OR_RETURN(*client_message_3.mutable_client_round_3_message(), + keygen->ComputeSeedCorrectionOpening( + 0, + server_message_2.server_round_2_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 3 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + + // Run Round 4 + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context3; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_4; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context3, client_message_3, &cq)); + rpc->Finish(&server_message_4, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_4; + + DPF_ASSIGN_OR_RETURN(*client_message_4.mutable_client_round_4_message(), + keygen->ApplySeedCorrectionShare( + 0, + server_message_3.server_round_3_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 4 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + // Run Round 5 + + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context4; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_5; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context4, client_message_4, &cq)); + rpc->Finish(&server_message_5, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_5; + + DPF_ASSIGN_OR_RETURN(*client_message_5.mutable_client_round_5_message(), + keygen->ComputeValueCorrectionOtReceiverMessage( + 0, + server_message_4.server_round_4_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 5 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + + // Run Round 6 + + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context5; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_6; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context5, client_message_5, &cq)); + rpc->Finish(&server_message_6, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_6; + + DPF_ASSIGN_OR_RETURN(*client_message_6.mutable_client_round_6_message(), + keygen->ComputeValueCorrectionOtSenderMessage( + 0, + server_message_5.server_round_5_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 6 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + // Run Round 7 + + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context6; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_7; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context6, client_message_6, &cq)); + rpc->Finish(&server_message_7, &grpc_status, (void*)cq_index); + + KeyGenerationProtocolClientMessage client_message_7; + + + DPF_ASSIGN_OR_RETURN(*client_message_7.mutable_client_round_7_message(), + keygen->ComputeValueCorrectionOtShare( + 0, + server_message_6.server_round_6_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on message round 7 of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + // Compute result + client_start = std::chrono::high_resolution_clock::now(); + ::grpc::ClientContext client_context7; + cq_index++; + ok=false; + KeyGenerationProtocolServerMessage server_message_end; + rpc = std::unique_ptr>(stub->AsyncHandle(&client_context7, client_message_7, &cq)); + rpc->Finish(&server_message_end, &grpc_status, (void*)cq_index); + + DPF_ASSIGN_OR_RETURN(int x, + keygen->ApplyValueCorrectionShare( + 0, + server_message_7.server_round_7_message(), + protocol_state_party_0)); + + client_end = std::chrono::high_resolution_clock::now(); + pzero_time += (std::chrono::duration_cast( + client_end - client_start).count())/ 1e6; + server_start = std::chrono::high_resolution_clock::now(); + + GPR_ASSERT(cq.Next(&got_tag, &ok)); + GPR_ASSERT(got_tag == (void*) cq_index); + GPR_ASSERT(ok); + + if (!grpc_status.ok()) { + std::cerr << "Client: Failed on end message of level " << i+1 << " with status " << + grpc_status.error_code() << " error_message: " << + grpc_status.error_message() << std::endl; + return absl::UnknownError(""); + } + server_end = std::chrono::high_resolution_clock::now(); + pone_time_incl_comm += + (std::chrono::duration_cast(server_end - + server_start).count())/ 1e6; + + } + + auto end = std::chrono::high_resolution_clock::now(); + + // Add in preprocessing phase. For the online phase, since the initial round for client and server can be done at the same time + end_to_end_time = (std::chrono::duration_cast( + end-start).count()) + / 1e6; + // Print results + std::cout << "Completed run" << std::endl << "num_levels=" + << levels << std::endl + << "Client time total (s) =" << pzero_time <has_start_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_1_message(), + keygen_->ComputeSeedCorrectionOtReceiverMessage( + 1, + protocol_state_party_1_)); + } else if(request->has_client_round_1_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_2_message(), + keygen_->ComputeSeedCorrectionOtSenderMessage( + 1, + request->client_round_1_message(), + protocol_state_party_1_)); + } else if(request->has_client_round_2_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_3_message(), + keygen_->ComputeSeedCorrectionOpening( + 1, + request->client_round_2_message(), + protocol_state_party_1_)); + } else if(request->has_client_round_3_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_4_message(), + keygen_->ApplySeedCorrectionShare( + 1, + request->client_round_3_message(), + protocol_state_party_1_)); + } else if(request->has_client_round_4_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_5_message(), + keygen_->ComputeValueCorrectionOtReceiverMessage( + 1, + request->client_round_4_message(), + protocol_state_party_1_)); + } else if(request->has_client_round_5_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_6_message(), + keygen_->ComputeValueCorrectionOtSenderMessage( + 1, + request->client_round_5_message(), + protocol_state_party_1_)); + } else if(request->has_client_round_6_message()) { + DPF_ASSIGN_OR_RETURN(*response->mutable_server_round_7_message(), + keygen_->ComputeValueCorrectionOtShare( + 1, + request->client_round_6_message(), + protocol_state_party_1_)); + } else if(request->has_client_round_7_message()) { + + int x; + DPF_ASSIGN_OR_RETURN(x, + keygen_->ApplyValueCorrectionShare( + 1, + request->client_round_7_message(), + protocol_state_party_1_)); + *response->mutable_end_message() = EndMessage(); + + // Last message needs to update the current level + std::cout << "Server: completed iteration " << current_level_+1 + << std::endl; + current_level_++; + } else { + return absl::InvalidArgumentError(absl::StrCat("KeyGenerationProtocolServer server" + " received an unrecognized message, with case ", + request->client_message_oneof_case())); + } + + total_client_message_size_ += request->ByteSizeLong(); + total_server_message_size_ += response->ByteSizeLong(); + + if(current_level_ == num_levels_) { + std::cout << "Server completed." < +#include + +namespace distributed_point_functions { + +// Implements the Gradient Descent RPC-handling Server. +class KeyGenerationProtocolRpcImpl : public KeyGenerationProtocolRpc::Service { + public: + KeyGenerationProtocolRpcImpl( + std::unique_ptr keygen, size_t num_levels, + absl::uint128 alpha_share_party_1, + std::vector beta_shares_party_1, + KeyGenerationPreprocessing preproc_party_1 + ): keygen_(std::move(keygen)), current_level_(0), num_levels_(num_levels), + alpha_share_party_1_(std::move(alpha_share_party_1)), + beta_shares_party_1_(std::move(beta_shares_party_1)), + preproc_party_1_(std::move(preproc_party_1)) { + protocol_state_party_1_ = keygen_->Initialize(1, + alpha_share_party_1_, + beta_shares_party_1_, + preproc_party_1_).value(); + + } + + // Executes a round of the protocol. + ::grpc::Status Handle(::grpc::ServerContext* context, + const KeyGenerationProtocolClientMessage* request, + KeyGenerationProtocolServerMessage* response) override; + + size_t current_level() { + return current_level_; + } + + private: + // Internal version of Handle, that returns a non-grpc Status. + absl::Status HandleInternal(::grpc::ServerContext* context, + const KeyGenerationProtocolClientMessage* request, + KeyGenerationProtocolServerMessage* response); + + std::unique_ptr keygen_; + + volatile size_t current_level_ = 0; + const size_t num_levels_; + + absl::uint128 alpha_share_party_1_; + std::vector beta_shares_party_1_; + KeyGenerationPreprocessing preproc_party_1_; + ProtocolState protocol_state_party_1_; + + + size_t total_client_message_size_ = 0; + size_t total_server_message_size_ = 0; +}; + +} // namespace distributed_point_functions + +#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_KEY_GENERATION_PROTOCOL_KEY_GENERATION_PROTOCOL_RPC_IMPL_H_ diff --git a/dpf/key_generation_protocol/key_generation_protocol_server.cc b/dpf/key_generation_protocol/key_generation_protocol_server.cc new file mode 100644 index 0000000..e624cca --- /dev/null +++ b/dpf/key_generation_protocol/key_generation_protocol_server.cc @@ -0,0 +1,148 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include // NOLINT +#define GLOG_NO_ABBREVIATED_SEVERITIES +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/memory/memory.h" +#include "include/grpc/grpc_security_constants.h" +#include "include/grpcpp/grpcpp.h" +#include "dpf/key_generation_protocol/key_generation_protocol.h" +#include "dpf/key_generation_protocol/key_generation_protocol_rpc_impl.h" +#include "dpf/key_generation_protocol/key_generation_protocol.pb.h" +#include "dpf/key_generation_protocol/key_generation_protocol_rpc.grpc.pb.h" +#include "dpf/key_generation_protocol/key_generation_protocol_rpc.pb.h" +#include "include/grpcpp/security/server_credentials.h" +#include "include/grpcpp/server_builder.h" +#include "include/grpcpp/server_context.h" +#include "include/grpcpp/support/status.h" +#include "absl/status/status.h" +#include "dpf/status_macros.h" +#include "absl/strings/string_view.h" + +ABSL_FLAG(std::string, port, "0.0.0.0:10501", "Port on which to listen"); +ABSL_FLAG(size_t, num_levels, 20, + "The number of levels for the DPF, also how many iterations to execute."); + +namespace distributed_point_functions { + +absl::Status RunServer() { + std::cout << "Server: starting... " << std::endl; + + size_t levels = absl::GetFlag(FLAGS_num_levels); + + // Generate parameters for KeyGenProtocol. + std::vector parameters; + parameters.reserve(levels); + + for (size_t i = 0; i < levels; i++){ + parameters.push_back(DpfParameters()); + parameters[i].set_log_domain_size(i + 1); + parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(64); + } + + std::unique_ptr keygen; + + DPF_ASSIGN_OR_RETURN(keygen, + KeyGenerationProtocol::Create(parameters)); + + std::pair preproc; + + DPF_ASSIGN_OR_RETURN(preproc, + keygen->PerformKeyGenerationPrecomputation()); + + + absl::uint128 alpha = 23; + + // Generating shares of alpha for Party 0 and Party 1 + absl::uint128 alpha_share_party0, alpha_share_party1; + const absl::string_view kSampleSeed = absl::string_view("abcdefg"); + DPF_ASSIGN_OR_RETURN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + + DPF_ASSIGN_OR_RETURN(alpha_share_party0, rng->Rand128()); + alpha_share_party1 = alpha ^ alpha_share_party0; + + // Generating shares of beta for Party 0 and Party 1 + std::vector beta; + + + for(size_t i = 0; i < levels; i++){ + Value beta_i; + beta_i.mutable_integer()->set_value_uint64(42); + beta.push_back(beta_i); + } + + std::vector beta_shares_party0, beta_shares_party1; + + for (size_t i = 0; i < beta.size(); i++){ + DPF_ASSIGN_OR_RETURN(absl::uint128 beta_share_party0_seed, rng->Rand128()); + + Value value0 = ToValue(static_cast(beta_share_party0_seed)); + + DPF_ASSIGN_OR_RETURN(Value value1, + keygen->ValueSub(beta[i], value0)); + + beta_shares_party0.push_back(value0); + beta_shares_party1.push_back(value1); + } + + + // Initialize the service + std::unique_ptr service = std::make_unique( + std::move(keygen), levels, std::move(alpha_share_party1), + std::move(beta_shares_party1), std::move(preproc.second) + ); + ::grpc::ServerBuilder builder; + // Consider grpc::SslServerCredentials if not running locally. + builder.AddListeningPort(absl::GetFlag(FLAGS_port), + grpc::InsecureServerCredentials()); + builder.SetMaxReceiveMessageSize(INT_MAX); // consider limiting max message size + builder.RegisterService(service.get()); + std::unique_ptr<::grpc::Server> grpc_server(builder.BuildAndStart()); + // Run the server on a background thread. + + std::thread grpc_server_thread( + [](::grpc::Server* grpc_server_ptr) { + std::cout << "Server: listening on " << absl::GetFlag(FLAGS_port) + << std::endl; + grpc_server_ptr->Wait(); + }, + grpc_server.get()); + while (service->current_level() < levels) { + } + // Shut down server. + grpc_server->Shutdown(); + grpc_server_thread.join(); + std::cout << "Server completed protocol and shut down." << std::endl; + + return absl::OkStatus(); +} + +} // namespace distributed_point_functions + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + auto status = ::distributed_point_functions::RunServer(); + if(!status.ok()){ + std::cerr << "Server failed: " << status.message(); + return 1; + } + return 0; +} diff --git a/dpf/key_generation_protocol/key_generation_protocol_test.cc b/dpf/key_generation_protocol/key_generation_protocol_test.cc new file mode 100644 index 0000000..023cc02 --- /dev/null +++ b/dpf/key_generation_protocol/key_generation_protocol_test.cc @@ -0,0 +1,294 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dpf/key_generation_protocol/key_generation_protocol.h" + +#include "dpf/distributed_point_function.pb.h" +#include "dpf/internal/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace distributed_point_functions { +namespace { + +using dpf_internal::IsOkAndHolds; +using dpf_internal::StatusIs; +using ::testing::HasSubstr; +using ::testing::NotNull; + +class KeyGenerationProtocolTest : public testing::Test { + protected: + void SetUp() override { + + levels = 20; + // There will be 2^levels number of leaves in the DPF tree + + parameters_.resize(levels); + + for (int i = 0; i < levels; i++){ + parameters_[i].set_log_domain_size(i + 1); + parameters_[i].mutable_value_type()->mutable_integer()->set_bitsize(64); + } +// parameters_[0].set_log_domain_size(5); +// parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(64); +// parameters_[1].set_log_domain_size(10); +// parameters_[1].mutable_value_type()->mutable_integer()->set_bitsize(64); + } + std::vector parameters_; + int levels; + using T = uint64_t; + + absl::uint128 BlockToUint128(Block x){ + absl::uint128 y = absl::MakeUint128(x.high(),x.low()); + return y; + } + void displayDpfKey(DpfKey key, int levels){ + std::cout << "Party : " << key.party() << std::endl; + std::cout << "Root seed : " << BlockToUint128(key.seed()) << std::endl; + + for(int i = 0; i < levels; i++){ + std::cout << "Level : " << i << std::endl; + std::cout << "seed correction : " << BlockToUint128(key.correction_words(i).seed()) << std::endl; + std::cout << "left control correction : " << key.correction_words(i).control_left() << std::endl; + std::cout << "right control correction : " << key.correction_words(i).control_right() << std::endl; + T v = *(FromValue(key.correction_words(i).value_correction(0))); + std::cout << "Value correction : " << v << std::endl; + } + } +}; + + +// +//TEST_F(KeyGenerationProtocolTest, CreateSucceeds) { +// +// EXPECT_THAT(KeyGenerationProtocol::Create(parameters_), +// IsOkAndHolds(NotNull())); +//} + +//TEST_F(KeyGenerationProtocolTest, CreateFailsIfPartyIsNot0Or1) { +// constexpr int party = 2; +// +// EXPECT_THAT(KeyGenerationProtocol::Create(parameters_, party), +// StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("party"))); +//} + +TEST_F(KeyGenerationProtocolTest, EndToEndSucceeds) { + + std::unique_ptr keygen; + + DPF_ASSERT_OK_AND_ASSIGN(keygen, + KeyGenerationProtocol::Create(parameters_)); + + std::pair preproc; + + DPF_ASSERT_OK_AND_ASSIGN(preproc, + keygen->PerformKeyGenerationPrecomputation()); + + absl::uint128 alpha = 23; + + + // Generating shares of alpha for Party 0 and Party 1 + + absl::uint128 alpha_share_party0, alpha_share_party1; + + const absl::string_view kSampleSeed = absl::string_view(); + DPF_ASSERT_OK_AND_ASSIGN( + auto rng, distributed_point_functions::BasicRng::Create(kSampleSeed)); + + DPF_ASSERT_OK_AND_ASSIGN(alpha_share_party0, rng->Rand128()); + + alpha_share_party1 = alpha ^ alpha_share_party0; + + std::cout << "Party 0 alpha share : " << alpha_share_party0 << std::endl; + + std::cout << "Party 1 alpha share : " << alpha_share_party1 << std::endl; + + // Generating shares of beta for Party 0 and Party 1 + std::vector beta; + + for(int i = 0; i < levels; i++){ + Value beta_i; +// beta_i.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64(42); + beta_i.mutable_integer()->set_value_uint64(42); + beta.push_back(beta_i); + } + + std::vector beta_shares_party0, beta_shares_party1; + + for (int i = 0; i < beta.size(); i++){ + DPF_ASSERT_OK_AND_ASSIGN(absl::uint128 beta_share_party0_seed, rng->Rand128()); + + Value value0 = ToValue(static_cast(beta_share_party0_seed)); + + DPF_ASSERT_OK_AND_ASSIGN(Value value1, + keygen->ValueSub(beta[i], value0)); + + beta_shares_party0.push_back(value0); + + beta_shares_party1.push_back(value1); + + // TODO: Look at ValueCorrection implementation in iDPF +// keygen->dpf_->ValueCorrectionFunction func; +// +// DPF_ASSERT_OK_AND_ASSIGN( +// ValueCorrectionFunction func, +// GetValueCorrectionFunction(parameters_[hierarchy_level])); + } + + // Running KeyGen Initialization + + ProtocolState state_party0, state_party1; + + DPF_ASSERT_OK_AND_ASSIGN(state_party0, + keygen->Initialize(0, + alpha_share_party0, + beta_shares_party0, + preproc.first)); + + DPF_ASSERT_OK_AND_ASSIGN(state_party1, + keygen->Initialize(1, + alpha_share_party1, + beta_shares_party1, + preproc.second)); + + // Running KeyGen 2PC offline phase for each level + + for(int i = 0; i < levels; i++) { + + SeedCorrectionOtReceiverMessage round1_party0, round1_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round1_party0, + keygen->ComputeSeedCorrectionOtReceiverMessage( + 0, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(round1_party1, + keygen->ComputeSeedCorrectionOtReceiverMessage( + 1, + state_party1)); + + SeedCorrectionOtSenderMessage round2_party0, round2_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round2_party0, + keygen->ComputeSeedCorrectionOtSenderMessage( + 0, + round1_party1, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(round2_party1, + keygen->ComputeSeedCorrectionOtSenderMessage( + 1, + round1_party0, + state_party1)); + + SeedCorrectionShare round3_party0, round3_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round3_party0, + keygen->ComputeSeedCorrectionOpening( + 0, + round2_party1, + state_party0)); + + + DPF_ASSERT_OK_AND_ASSIGN(round3_party1, + keygen->ComputeSeedCorrectionOpening( + 1, + round2_party0, + state_party1)); + + MaskedTau round4_party0, round4_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round4_party0, + keygen->ApplySeedCorrectionShare( + 0, + round3_party1, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(round4_party1, + keygen->ApplySeedCorrectionShare( + 1, + round3_party0, + state_party1)); + + ValueCorrectionOtReceiverMessage round5_party0, round5_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round5_party0, + keygen->ComputeValueCorrectionOtReceiverMessage( + 0, + round4_party1, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(round5_party1, + keygen->ComputeValueCorrectionOtReceiverMessage( + 1, + round4_party0, + state_party1)); + + ValueCorrectionOtSenderMessage round6_party0, round6_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round6_party0, + keygen->ComputeValueCorrectionOtSenderMessage( + 0, + round5_party1, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(round6_party1, + keygen->ComputeValueCorrectionOtSenderMessage( + 0, + round5_party0, + state_party1)); + + + ValueCorrectionShare round7_party0, round7_party1; + + DPF_ASSERT_OK_AND_ASSIGN(round7_party0, + keygen->ComputeValueCorrectionOtShare( + 0, + round6_party1, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(round7_party1, + keygen->ComputeValueCorrectionOtShare( + 1, + round6_party0, + state_party1)); + + int x; + + DPF_ASSERT_OK_AND_ASSIGN(x, + keygen->ApplyValueCorrectionShare( + 0, + round7_party1, + state_party0)); + + DPF_ASSERT_OK_AND_ASSIGN(x, + keygen->ApplyValueCorrectionShare( + 1, + round7_party0, + state_party1)); + } + + std::cout << "\n\n"; + + displayDpfKey(state_party0.key, levels); + + std::cout << "\n\n"; + + displayDpfKey(state_party1.key, levels); + + + } +} // namespace +} // namespace distributed_point_functions +