diff --git a/src/server/common.cc b/src/server/common.cc index 44c72c296..ac6e8af94 100644 --- a/src/server/common.cc +++ b/src/server/common.cc @@ -427,4 +427,35 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) { return os << GlobalStateName(state); } +NonUniquePicksGenerator::NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) { + CHECK_GT(max_range, RandomPick(0)); +} + +RandomPick NonUniquePicksGenerator::Generate() { + return absl::Uniform(bitgen_, 0u, max_range_); +} + +UniquePicksGenerator::UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range) + : remaining_picks_count_(picks_count), picked_indexes_(picks_count) { + CHECK_GE(max_range, picks_count); + current_random_limit_ = max_range - picks_count; +} + +RandomPick UniquePicksGenerator::Generate() { + DCHECK_GT(remaining_picks_count_, 0u); + + remaining_picks_count_--; + + const RandomPick max_index = current_random_limit_++; + const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u); + + const bool random_index_is_picked = picked_indexes_.emplace(random_index).second; + if (random_index_is_picked) { + return random_index; + } + + picked_indexes_.insert(max_index); + return max_index; +} + } // namespace dfly diff --git a/src/server/common.h b/src/server/common.h index 563ed7119..ab4988507 100644 --- a/src/server/common.h +++ b/src/server/common.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -309,4 +310,47 @@ struct MemoryBytesFlag { bool AbslParseFlag(std::string_view in, dfly::MemoryBytesFlag* flag, std::string* err); std::string AbslUnparseFlag(const dfly::MemoryBytesFlag& flag); +using RandomPick = std::uint32_t; + +class PicksGenerator { + public: + virtual RandomPick Generate() = 0; + virtual ~PicksGenerator() = default; +}; + +class NonUniquePicksGenerator : public PicksGenerator { + public: + /* The generated value will be within the closed-open interval [0, max_range) */ + NonUniquePicksGenerator(RandomPick max_range); + + RandomPick Generate() override; + + private: + const RandomPick max_range_; + absl::BitGen bitgen_{}; +}; + +/* + * Generates unique index in O(1). + * + * picks_count specifies the number of random indexes to be generated. + * In other words, this is the number of times the Generate() function is called. + * + * The class uses Robert Floyd's sampling algorithm + * https://dl.acm.org/doi/pdf/10.1145/30401.315746 + * */ +class UniquePicksGenerator : public PicksGenerator { + public: + /* The generated value will be within the closed-open interval [0, max_range) */ + UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range); + + RandomPick Generate() override; + + private: + RandomPick current_random_limit_; + std::uint32_t remaining_picks_count_; + std::unordered_set picked_indexes_; + absl::BitGen bitgen_{}; +}; + } // namespace dfly diff --git a/src/server/set_family.cc b/src/server/set_family.cc index 5e046a00f..561cd32f7 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -284,22 +284,62 @@ void InterStrSet(const DbContext& db_context, const vector& vec, String } } -StringVec PopStrSet(const DbContext& db_context, unsigned count, const SetType& st) { - StringVec result; +StringVec RandMemberStrSet(const DbContext& db_context, const CompactObj& co, + PicksGenerator& generator, std::size_t picks_count) { + CHECK(IsDenseEncoding(co)); - if (true) { - StringSet* ss = (StringSet*)st.first; - ss->set_time(MemberTimeSeconds(db_context.time_now_ms)); - - // TODO: this loop is inefficient because Pop searches again and again an occupied bucket. - for (unsigned i = 0; i < count && !ss->Empty(); ++i) { - result.push_back(ss->Pop().value()); - } + std::unordered_map times_index_is_picked; + for (std::size_t i = 0; i < picks_count; i++) { + times_index_is_picked[generator.Generate()]++; } + StringVec result; + result.reserve(picks_count); + + StringSet* ss = static_cast(co.RObjPtr()); + ss->set_time(MemberTimeSeconds(db_context.time_now_ms)); + + std::uint32_t ss_entry_index = 0; + container_utils::IterateSet( + co, [&result, ×_index_is_picked, &ss_entry_index](container_utils::ContainerEntry ce) { + auto it = times_index_is_picked.find(ss_entry_index++); + if (it != times_index_is_picked.end()) { + std::uint32_t t = it->second; + while (t--) { + result.emplace_back(ce.ToString()); + } + } + return true; + }); + + /* Equal elements in the result are always successive. So, it is necessary to shuffle them */ + absl::BitGen gen; + std::shuffle(result.begin(), result.end(), gen); + return result; } +StringVec RandMemberSet(const DbContext& db_context, const CompactObj& co, + PicksGenerator& generator, std::size_t picks_count) { + if (co.Encoding() == kEncodingIntSet) { + intset* is = static_cast(co.RObjPtr()); + + StringVec result; + result.reserve(picks_count); + + for (std::size_t i = 0; i < picks_count; i++) { + const std::size_t picked_index = generator.Generate(); + + int64_t value = 0; + CHECK_GT(intsetGet(is, picked_index, &value), std::uint8_t(0)); + + result.push_back(absl::StrCat(value)); + } + return result; + } + return RandMemberStrSet(db_context, co, generator, picks_count); +} + vector ToVec(absl::flat_hash_set&& set) { vector result(set.size()); size_t i = 0; @@ -819,69 +859,92 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f return result; } -// count - how many elements to pop. -OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count) { - auto& db_slice = op_args.shard->db_slice(); - auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET); +OpResult OpRandMember(const OpArgs& op_args, std::string_view key, int count) { + auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET); if (!find_res) return find_res.status(); - StringVec result; - if (count == 0) - return result; + const CompactObj& co = find_res.value()->second; - auto it = find_res->it; - size_t slen = it->second.Size(); + const std::uint32_t size = co.Size(); + const bool picks_are_unique = count >= 0; + const std::uint32_t picks_count = + picks_are_unique ? std::min(static_cast(count), size) : std::abs(count); + + auto generator = [picks_are_unique, picks_count, size]() -> std::unique_ptr { + if (picks_are_unique) { + return std::make_unique(picks_count, size); + } else { + return std::make_unique(size); + } + }(); + + return RandMemberSet(op_args.db_cntx, co, *generator, picks_count); +} + +// count - how many elements to pop. +OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count) { + auto& db_cntx = op_args.db_cntx; + auto& db_slice = op_args.shard->db_slice(); + auto find_res = db_slice.FindMutable(db_cntx, key, OBJ_SET); + if (!find_res) { + return find_res.status(); + } + + CompactObj& co = find_res->it->second; + + const std::uint32_t size = co.Size(); + const std::uint32_t picks_count = std::min(count, size); /* CASE 1: * The number of requested elements is greater than or equal to * the number of elements inside the set: simply return the whole set. */ - if (count >= slen) { - PrimeValue& pv = it->second; - if (IsDenseEncoding(pv)) { - StringSet* ss = (StringSet*)pv.RObjPtr(); + if (count >= size) { + if (IsDenseEncoding(co)) { + StringSet* ss = (StringSet*)co.RObjPtr(); ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); } - container_utils::IterateSet(it->second, [&result](container_utils::ContainerEntry ce) { + StringVec result; + result.reserve(picks_count); + + container_utils::IterateSet(co, [&result](container_utils::ContainerEntry ce) { result.push_back(ce.ToString()); return true; }); // Delete the set as it is now empty find_res->post_updater.Run(); - CHECK(db_slice.Del(op_args.db_cntx.db_index, it)); + CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res->it)); // Replicate as DEL. if (op_args.shard->journal()) { RecordJournal(op_args, "DEL"sv, ArgSlice{key}); } - } else { - SetType st{it->second.RObjPtr(), it->second.Encoding()}; - if (st.second == kEncodingIntSet) { - intset* is = (intset*)st.first; - int64_t val = 0; - - // copy last count values. - for (uint32_t i = slen - count; i < slen; ++i) { - intsetGet(is, i, &val); - result.push_back(absl::StrCat(val)); - } - - is = intsetTrimTail(is, count); // now remove last count items - it->second.SetRObjPtr(is); - } else { - result = PopStrSet(op_args.db_cntx, count, st); - } - - // Replicate as SREM with removed keys, because SPOP is not deterministic. - if (op_args.shard->journal()) { - vector mapped(result.size() + 1); - mapped[0] = key; - std::copy(result.begin(), result.end(), mapped.begin() + 1); - RecordJournal(op_args, "SREM"sv, mapped); - } + return result; } + + /* CASE 2: + * The number of requested elements is less than the number of elements inside the set. + * In this case, we need to select random members from the set and then remove them. */ + UniquePicksGenerator generator{picks_count, size}; + + // Select random members + StringVec result = RandMemberSet(db_cntx, co, generator, picks_count); + + // Remove selected members + std::vector members_to_remove{result.begin(), result.end()}; + bool is_empty = RemoveSet(db_cntx, members_to_remove, &co).second; + find_res->post_updater.Run(); + + CHECK(!is_empty); + + // Replicate as SREM with removed keys, because SPOP is not deterministic. + if (op_args.shard->journal()) { + members_to_remove.insert(members_to_remove.begin(), key); + RecordJournal(op_args, "SREM"sv, members_to_remove); + } + return result; } @@ -1204,41 +1267,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) { if (auto err = parser.Error(); err) return cntx->SendError(err->MakeReply()); - const unsigned ucount = std::abs(count); - const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult { - StringVec result; - auto find_res = shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET); - if (!find_res) { - return find_res.status(); - } - - const PrimeValue& pv = find_res.value()->second; - if (IsDenseEncoding(pv)) { - StringSet* ss = (StringSet*)pv.RObjPtr(); - ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); - } - - container_utils::IterateSet(find_res.value()->second, - [&result, ucount](container_utils::ContainerEntry ce) { - if (result.size() < ucount) { - result.push_back(ce.ToString()); - return true; - } - return false; - }); - return result; + return OpRandMember(t->GetOpArgs(shard), key, count); }; OpResult result = cntx->transaction->ScheduleSingleHopT(cb); auto* rb = static_cast(cntx->reply_builder()); if (result) { - if (count < 0 && !result->empty()) { - for (auto i = result->size(); i < ucount; ++i) { - // we can return duplicate elements, so first is OK - result->push_back(result->front()); - } - } rb->SendStringArr(*result, RedisReplyBuilder::SET); } else if (result.status() == OpStatus::KEY_NOTFOUND) { if (is_count) { diff --git a/src/server/set_family_test.cc b/src/server/set_family_test.cc index 024d2f9be..6328891f5 100644 --- a/src/server/set_family_test.cc +++ b/src/server/set_family_test.cc @@ -21,6 +21,20 @@ class SetFamilyTest : public BaseFamilyTest { protected: }; +MATCHER_P(ConsistsOfMatcher, elements, "") { + auto vec = arg.GetVec(); + for (const auto& x : vec) { + if (elements.find(x.GetString()) == elements.end()) { + return false; + } + } + return true; +} + +auto ConsistsOf(std::initializer_list elements) { + return ConsistsOfMatcher(std::unordered_set{elements}); +} + TEST_F(SetFamilyTest, SAdd) { auto resp = Run({"sadd", "x", "1", "2", "3"}); EXPECT_THAT(resp, IntArg(3)); @@ -159,34 +173,107 @@ TEST_F(SetFamilyTest, SPop) { } TEST_F(SetFamilyTest, SRandMember) { - auto resp = Run({"sadd", "x", "1", "2", "3"}); - resp = Run({"SRandMember", "x"}); + // Test IntSet + Run({"sadd", "x", "1", "2", "3"}); + + // Test if count > 0 (IntSet) + auto resp = Run({"SRandMember", "x"}); ASSERT_THAT(resp, ArgType(RespExpr::STRING)); - EXPECT_THAT(resp, "1"); + EXPECT_THAT(resp, AnyOf("1", "2", "3")); + + resp = Run({"SRandMember", "x", "1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("1", "2", "3")); resp = Run({"SRandMember", "x", "2"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2")); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), IsSubsetOf({"1", "2", "3"})); - resp = Run({"SRandMember", "x", "0"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_EQ(resp.GetVec().size(), 0); - - resp = Run({"SRandMember", "k"}); - ASSERT_THAT(resp, ArgType(RespExpr::NIL)); - - resp = Run({"SRandMember", "k", "2"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_EQ(resp.GetVec().size(), 0); - - resp = Run({"SRandMember", "x", "-5"}); - ASSERT_THAT(resp, ArrLen(5)); - EXPECT_THAT(resp.GetVec(), ElementsAre("1", "2", "3", "1", "1")); - - resp = Run({"SRandMember", "x", "5"}); + resp = Run({"SRandMember", "x", "3"}); ASSERT_THAT(resp, ArrLen(3)); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3")); + // Test if count is larger than the size of the IntSet + resp = Run({"SRandMember", "x", "25"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3")); + + // Test if count < 0 (IntSet) + resp = Run({"SRandMember", "x", "-1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("1", "2", "3")); + + resp = Run({"SRandMember", "x", "-2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"})); + + resp = Run({"SRandMember", "x", "-3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"})); + + // Test if count < 0, but the absolute value is larger than the size of the IntSet + resp = Run({"SRandMember", "x", "-25"}); + ASSERT_THAT(resp, ArrLen(25)); + EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"})); + + // Test StrSet + Run({"sadd", "y", "a", "b", "c"}); + + // Test if count > 0 (StrSet) + resp = Run({"SRandMember", "y"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); + + resp = Run({"SRandMember", "y", "1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); + + resp = Run({"SRandMember", "y", "2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"})); + + resp = Run({"SRandMember", "y", "3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c")); + + // Test if count is larger than the size of the StrSet + resp = Run({"SRandMember", "y", "25"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c")); + + // Test if count < 0 (StrSet) + resp = Run({"SRandMember", "y", "-1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); + + resp = Run({"SRandMember", "y", "-2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"})); + + resp = Run({"SRandMember", "y", "-3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"})); + + // Test if count < 0, but the absolute value is larger than the size of the StrSet + resp = Run({"SRandMember", "y", "-25"}); + ASSERT_THAT(resp, ArrLen(25)); + EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"})); + + // Test if count is 0 + ASSERT_THAT(Run({"SRandMember", "x", "0"}), ArrLen(0)); + + // Test if set is empty + EXPECT_THAT(Run({"SAdd", "empty::set", "1"}), IntArg(1)); + EXPECT_THAT(Run({"SRem", "empty::set", "1"}), IntArg(1)); + ASSERT_THAT(Run({"SRandMember", "empty::set", "0"}), ArrLen(0)); + ASSERT_THAT(Run({"SRandMember", "empty::set", "3"}), ArrLen(0)); + ASSERT_THAT(Run({"SRandMember", "empty::set", "-4"}), ArrLen(0)); + + // Test if key does not exist + ASSERT_THAT(Run({"SRandMember", "unknown::set"}), ArgType(RespExpr::NIL)); + ASSERT_THAT(Run({"SRandMember", "unknown::set", "0"}), ArrLen(0)); + + // Test wrong arguments resp = Run({"SRandMember", "x", "5", "3"}); EXPECT_THAT(resp, ErrArg("wrong number of arguments")); } diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 1ab3443aa..c174aaec1 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -222,70 +222,6 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)}; } -using RandomPick = std::uint32_t; - -class PicksGenerator { - public: - virtual RandomPick Generate() = 0; - virtual ~PicksGenerator() = default; -}; - -class NonUniquePicksGenerator : public PicksGenerator { - public: - NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) { - CHECK_GT(max_range, RandomPick(0)); - } - - RandomPick Generate() override { - return absl::Uniform(bitgen_, 0u, max_range_); - } - - private: - const RandomPick max_range_; - absl::BitGen bitgen_{}; -}; - -/* - * Generates unique index in O(1). - * - * picks_count specifies the number of random indexes to be generated. - * In other words, this is the number of times the Generate() function is called. - * - * The class uses Robert Floyd's sampling algorithm - * https://dl.acm.org/doi/pdf/10.1145/30401.315746 - * */ -class UniquePicksGenerator : public PicksGenerator { - public: - UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range) - : remaining_picks_count_(picks_count), picked_indexes_(picks_count) { - CHECK_GE(max_range, picks_count); - current_random_limit_ = max_range - picks_count; - } - - RandomPick Generate() override { - DCHECK_GT(remaining_picks_count_, 0u); - - remaining_picks_count_--; - - const RandomPick max_index = current_random_limit_++; - const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u); - - const bool random_index_is_picked = picked_indexes_.emplace(random_index).second; - if (random_index_is_picked) { - return random_index; - } - - picked_indexes_.insert(max_index); - return max_index; - } - - private: - RandomPick current_random_limit_; - std::uint32_t remaining_picks_count_; - std::unordered_set picked_indexes_; - absl::BitGen bitgen_{}; -}; - bool ScoreToLongLat(const std::optional& val, double* xy) { if (!val.has_value()) return false;