fix(set): fix random in SRANDMEMBER and SPOP commands (#3022)

Signed-off-by: Stepan Bagritsevich <sbagritsevich@quantumbrains.com>
This commit is contained in:
Stepan Bagritsevich 2024-05-13 11:08:01 +04:00 committed by GitHub
parent 4cd142d42c
commit d3a585113f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 297 additions and 164 deletions

View File

@ -427,4 +427,35 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) {
return os << GlobalStateName(state);
}
NonUniquePicksGenerator::NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
CHECK_GT(max_range, RandomPick(0));
}
RandomPick NonUniquePicksGenerator::Generate() {
return absl::Uniform(bitgen_, 0u, max_range_);
}
UniquePicksGenerator::UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
: remaining_picks_count_(picks_count), picked_indexes_(picks_count) {
CHECK_GE(max_range, picks_count);
current_random_limit_ = max_range - picks_count;
}
RandomPick UniquePicksGenerator::Generate() {
DCHECK_GT(remaining_picks_count_, 0u);
remaining_picks_count_--;
const RandomPick max_index = current_random_limit_++;
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);
const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
if (random_index_is_picked) {
return random_index;
}
picked_indexes_.insert(max_index);
return max_index;
}
} // namespace dfly

View File

@ -4,6 +4,7 @@
#pragma once
#include <absl/random/random.h>
#include <absl/strings/ascii.h>
#include <absl/strings/str_cat.h>
#include <absl/types/span.h>
@ -309,4 +310,47 @@ struct MemoryBytesFlag {
bool AbslParseFlag(std::string_view in, dfly::MemoryBytesFlag* flag, std::string* err);
std::string AbslUnparseFlag(const dfly::MemoryBytesFlag& flag);
using RandomPick = std::uint32_t;
class PicksGenerator {
public:
virtual RandomPick Generate() = 0;
virtual ~PicksGenerator() = default;
};
class NonUniquePicksGenerator : public PicksGenerator {
public:
/* The generated value will be within the closed-open interval [0, max_range) */
NonUniquePicksGenerator(RandomPick max_range);
RandomPick Generate() override;
private:
const RandomPick max_range_;
absl::BitGen bitgen_{};
};
/*
* Generates unique index in O(1).
*
* picks_count specifies the number of random indexes to be generated.
* In other words, this is the number of times the Generate() function is called.
*
* The class uses Robert Floyd's sampling algorithm
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
* */
class UniquePicksGenerator : public PicksGenerator {
public:
/* The generated value will be within the closed-open interval [0, max_range) */
UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range);
RandomPick Generate() override;
private:
RandomPick current_random_limit_;
std::uint32_t remaining_picks_count_;
std::unordered_set<RandomPick> picked_indexes_;
absl::BitGen bitgen_{};
};
} // namespace dfly

View File

@ -284,22 +284,62 @@ void InterStrSet(const DbContext& db_context, const vector<SetType>& vec, String
}
}
StringVec PopStrSet(const DbContext& db_context, unsigned count, const SetType& st) {
StringVec result;
StringVec RandMemberStrSet(const DbContext& db_context, const CompactObj& co,
PicksGenerator& generator, std::size_t picks_count) {
CHECK(IsDenseEncoding(co));
if (true) {
StringSet* ss = (StringSet*)st.first;
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
// TODO: this loop is inefficient because Pop searches again and again an occupied bucket.
for (unsigned i = 0; i < count && !ss->Empty(); ++i) {
result.push_back(ss->Pop().value());
}
std::unordered_map<RandomPick, std::uint32_t> times_index_is_picked;
for (std::size_t i = 0; i < picks_count; i++) {
times_index_is_picked[generator.Generate()]++;
}
StringVec result;
result.reserve(picks_count);
StringSet* ss = static_cast<StringSet*>(co.RObjPtr());
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
std::uint32_t ss_entry_index = 0;
container_utils::IterateSet(
co, [&result, &times_index_is_picked, &ss_entry_index](container_utils::ContainerEntry ce) {
auto it = times_index_is_picked.find(ss_entry_index++);
if (it != times_index_is_picked.end()) {
std::uint32_t t = it->second;
while (t--) {
result.emplace_back(ce.ToString());
}
}
return true;
});
/* Equal elements in the result are always successive. So, it is necessary to shuffle them */
absl::BitGen gen;
std::shuffle(result.begin(), result.end(), gen);
return result;
}
StringVec RandMemberSet(const DbContext& db_context, const CompactObj& co,
PicksGenerator& generator, std::size_t picks_count) {
if (co.Encoding() == kEncodingIntSet) {
intset* is = static_cast<intset*>(co.RObjPtr());
StringVec result;
result.reserve(picks_count);
for (std::size_t i = 0; i < picks_count; i++) {
const std::size_t picked_index = generator.Generate();
int64_t value = 0;
CHECK_GT(intsetGet(is, picked_index, &value), std::uint8_t(0));
result.push_back(absl::StrCat(value));
}
return result;
}
return RandMemberStrSet(db_context, co, generator, picks_count);
}
vector<string> ToVec(absl::flat_hash_set<string>&& set) {
vector<string> result(set.size());
size_t i = 0;
@ -819,69 +859,92 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
return result;
}
// count - how many elements to pop.
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
auto& db_slice = op_args.shard->db_slice();
auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
OpResult<StringVec> OpRandMember(const OpArgs& op_args, std::string_view key, int count) {
auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
if (!find_res)
return find_res.status();
StringVec result;
if (count == 0)
return result;
const CompactObj& co = find_res.value()->second;
auto it = find_res->it;
size_t slen = it->second.Size();
const std::uint32_t size = co.Size();
const bool picks_are_unique = count >= 0;
const std::uint32_t picks_count =
picks_are_unique ? std::min(static_cast<std::uint32_t>(count), size) : std::abs(count);
auto generator = [picks_are_unique, picks_count, size]() -> std::unique_ptr<PicksGenerator> {
if (picks_are_unique) {
return std::make_unique<UniquePicksGenerator>(picks_count, size);
} else {
return std::make_unique<NonUniquePicksGenerator>(size);
}
}();
return RandMemberSet(op_args.db_cntx, co, *generator, picks_count);
}
// count - how many elements to pop.
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
auto& db_cntx = op_args.db_cntx;
auto& db_slice = op_args.shard->db_slice();
auto find_res = db_slice.FindMutable(db_cntx, key, OBJ_SET);
if (!find_res) {
return find_res.status();
}
CompactObj& co = find_res->it->second;
const std::uint32_t size = co.Size();
const std::uint32_t picks_count = std::min(count, size);
/* CASE 1:
* The number of requested elements is greater than or equal to
* the number of elements inside the set: simply return the whole set. */
if (count >= slen) {
PrimeValue& pv = it->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
if (count >= size) {
if (IsDenseEncoding(co)) {
StringSet* ss = (StringSet*)co.RObjPtr();
ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms));
}
container_utils::IterateSet(it->second, [&result](container_utils::ContainerEntry ce) {
StringVec result;
result.reserve(picks_count);
container_utils::IterateSet(co, [&result](container_utils::ContainerEntry ce) {
result.push_back(ce.ToString());
return true;
});
// Delete the set as it is now empty
find_res->post_updater.Run();
CHECK(db_slice.Del(op_args.db_cntx.db_index, it));
CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res->it));
// Replicate as DEL.
if (op_args.shard->journal()) {
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
}
} else {
SetType st{it->second.RObjPtr(), it->second.Encoding()};
if (st.second == kEncodingIntSet) {
intset* is = (intset*)st.first;
int64_t val = 0;
// copy last count values.
for (uint32_t i = slen - count; i < slen; ++i) {
intsetGet(is, i, &val);
result.push_back(absl::StrCat(val));
}
is = intsetTrimTail(is, count); // now remove last count items
it->second.SetRObjPtr(is);
} else {
result = PopStrSet(op_args.db_cntx, count, st);
}
// Replicate as SREM with removed keys, because SPOP is not deterministic.
if (op_args.shard->journal()) {
vector<string_view> mapped(result.size() + 1);
mapped[0] = key;
std::copy(result.begin(), result.end(), mapped.begin() + 1);
RecordJournal(op_args, "SREM"sv, mapped);
}
return result;
}
/* CASE 2:
* The number of requested elements is less than the number of elements inside the set.
* In this case, we need to select random members from the set and then remove them. */
UniquePicksGenerator generator{picks_count, size};
// Select random members
StringVec result = RandMemberSet(db_cntx, co, generator, picks_count);
// Remove selected members
std::vector<std::string_view> members_to_remove{result.begin(), result.end()};
bool is_empty = RemoveSet(db_cntx, members_to_remove, &co).second;
find_res->post_updater.Run();
CHECK(!is_empty);
// Replicate as SREM with removed keys, because SPOP is not deterministic.
if (op_args.shard->journal()) {
members_to_remove.insert(members_to_remove.begin(), key);
RecordJournal(op_args, "SREM"sv, members_to_remove);
}
return result;
}
@ -1204,41 +1267,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) {
if (auto err = parser.Error(); err)
return cntx->SendError(err->MakeReply());
const unsigned ucount = std::abs(count);
const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
StringVec result;
auto find_res = shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
if (!find_res) {
return find_res.status();
}
const PrimeValue& pv = find_res.value()->second;
if (IsDenseEncoding(pv)) {
StringSet* ss = (StringSet*)pv.RObjPtr();
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));
}
container_utils::IterateSet(find_res.value()->second,
[&result, ucount](container_utils::ContainerEntry ce) {
if (result.size() < ucount) {
result.push_back(ce.ToString());
return true;
}
return false;
});
return result;
return OpRandMember(t->GetOpArgs(shard), key, count);
};
OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(cb);
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
if (result) {
if (count < 0 && !result->empty()) {
for (auto i = result->size(); i < ucount; ++i) {
// we can return duplicate elements, so first is OK
result->push_back(result->front());
}
}
rb->SendStringArr(*result, RedisReplyBuilder::SET);
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
if (is_count) {

View File

@ -21,6 +21,20 @@ class SetFamilyTest : public BaseFamilyTest {
protected:
};
MATCHER_P(ConsistsOfMatcher, elements, "") {
auto vec = arg.GetVec();
for (const auto& x : vec) {
if (elements.find(x.GetString()) == elements.end()) {
return false;
}
}
return true;
}
auto ConsistsOf(std::initializer_list<std::string> elements) {
return ConsistsOfMatcher(std::unordered_set<std::string>{elements});
}
TEST_F(SetFamilyTest, SAdd) {
auto resp = Run({"sadd", "x", "1", "2", "3"});
EXPECT_THAT(resp, IntArg(3));
@ -159,34 +173,107 @@ TEST_F(SetFamilyTest, SPop) {
}
TEST_F(SetFamilyTest, SRandMember) {
auto resp = Run({"sadd", "x", "1", "2", "3"});
resp = Run({"SRandMember", "x"});
// Test IntSet
Run({"sadd", "x", "1", "2", "3"});
// Test if count > 0 (IntSet)
auto resp = Run({"SRandMember", "x"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, "1");
EXPECT_THAT(resp, AnyOf("1", "2", "3"));
resp = Run({"SRandMember", "x", "1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("1", "2", "3"));
resp = Run({"SRandMember", "x", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2"));
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"1", "2", "3"}));
resp = Run({"SRandMember", "x", "0"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
resp = Run({"SRandMember", "k"});
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
resp = Run({"SRandMember", "k", "2"});
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
EXPECT_EQ(resp.GetVec().size(), 0);
resp = Run({"SRandMember", "x", "-5"});
ASSERT_THAT(resp, ArrLen(5));
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "2", "3", "1", "1"));
resp = Run({"SRandMember", "x", "5"});
resp = Run({"SRandMember", "x", "3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3"));
// Test if count is larger than the size of the IntSet
resp = Run({"SRandMember", "x", "25"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3"));
// Test if count < 0 (IntSet)
resp = Run({"SRandMember", "x", "-1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("1", "2", "3"));
resp = Run({"SRandMember", "x", "-2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"}));
resp = Run({"SRandMember", "x", "-3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"}));
// Test if count < 0, but the absolute value is larger than the size of the IntSet
resp = Run({"SRandMember", "x", "-25"});
ASSERT_THAT(resp, ArrLen(25));
EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"}));
// Test StrSet
Run({"sadd", "y", "a", "b", "c"});
// Test if count > 0 (StrSet)
resp = Run({"SRandMember", "y"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
resp = Run({"SRandMember", "y", "1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
resp = Run({"SRandMember", "y", "2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));
resp = Run({"SRandMember", "y", "3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));
// Test if count is larger than the size of the StrSet
resp = Run({"SRandMember", "y", "25"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));
// Test if count < 0 (StrSet)
resp = Run({"SRandMember", "y", "-1"});
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
resp = Run({"SRandMember", "y", "-2"});
ASSERT_THAT(resp, ArrLen(2));
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
resp = Run({"SRandMember", "y", "-3"});
ASSERT_THAT(resp, ArrLen(3));
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
// Test if count < 0, but the absolute value is larger than the size of the StrSet
resp = Run({"SRandMember", "y", "-25"});
ASSERT_THAT(resp, ArrLen(25));
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
// Test if count is 0
ASSERT_THAT(Run({"SRandMember", "x", "0"}), ArrLen(0));
// Test if set is empty
EXPECT_THAT(Run({"SAdd", "empty::set", "1"}), IntArg(1));
EXPECT_THAT(Run({"SRem", "empty::set", "1"}), IntArg(1));
ASSERT_THAT(Run({"SRandMember", "empty::set", "0"}), ArrLen(0));
ASSERT_THAT(Run({"SRandMember", "empty::set", "3"}), ArrLen(0));
ASSERT_THAT(Run({"SRandMember", "empty::set", "-4"}), ArrLen(0));
// Test if key does not exist
ASSERT_THAT(Run({"SRandMember", "unknown::set"}), ArgType(RespExpr::NIL));
ASSERT_THAT(Run({"SRandMember", "unknown::set", "0"}), ArrLen(0));
// Test wrong arguments
resp = Run({"SRandMember", "x", "5", "3"});
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
}

View File

@ -222,70 +222,6 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};
}
using RandomPick = std::uint32_t;
class PicksGenerator {
public:
virtual RandomPick Generate() = 0;
virtual ~PicksGenerator() = default;
};
class NonUniquePicksGenerator : public PicksGenerator {
public:
NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
CHECK_GT(max_range, RandomPick(0));
}
RandomPick Generate() override {
return absl::Uniform(bitgen_, 0u, max_range_);
}
private:
const RandomPick max_range_;
absl::BitGen bitgen_{};
};
/*
* Generates unique index in O(1).
*
* picks_count specifies the number of random indexes to be generated.
* In other words, this is the number of times the Generate() function is called.
*
* The class uses Robert Floyd's sampling algorithm
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
* */
class UniquePicksGenerator : public PicksGenerator {
public:
UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
: remaining_picks_count_(picks_count), picked_indexes_(picks_count) {
CHECK_GE(max_range, picks_count);
current_random_limit_ = max_range - picks_count;
}
RandomPick Generate() override {
DCHECK_GT(remaining_picks_count_, 0u);
remaining_picks_count_--;
const RandomPick max_index = current_random_limit_++;
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);
const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
if (random_index_is_picked) {
return random_index;
}
picked_indexes_.insert(max_index);
return max_index;
}
private:
RandomPick current_random_limit_;
std::uint32_t remaining_picks_count_;
std::unordered_set<RandomPick> picked_indexes_;
absl::BitGen bitgen_{};
};
bool ScoreToLongLat(const std::optional<double>& val, double* xy) {
if (!val.has_value())
return false;