mirror of
https://github.com/dragonflydb/dragonfly
synced 2024-11-21 23:19:53 +00:00
fix(set): fix random in SRANDMEMBER and SPOP commands (#3022)
Signed-off-by: Stepan Bagritsevich <sbagritsevich@quantumbrains.com>
This commit is contained in:
parent
4cd142d42c
commit
d3a585113f
@ -427,4 +427,35 @@ std::ostream& operator<<(std::ostream& os, const GlobalState& state) {
|
||||
return os << GlobalStateName(state);
|
||||
}
|
||||
|
||||
NonUniquePicksGenerator::NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
|
||||
CHECK_GT(max_range, RandomPick(0));
|
||||
}
|
||||
|
||||
RandomPick NonUniquePicksGenerator::Generate() {
|
||||
return absl::Uniform(bitgen_, 0u, max_range_);
|
||||
}
|
||||
|
||||
UniquePicksGenerator::UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
|
||||
: remaining_picks_count_(picks_count), picked_indexes_(picks_count) {
|
||||
CHECK_GE(max_range, picks_count);
|
||||
current_random_limit_ = max_range - picks_count;
|
||||
}
|
||||
|
||||
RandomPick UniquePicksGenerator::Generate() {
|
||||
DCHECK_GT(remaining_picks_count_, 0u);
|
||||
|
||||
remaining_picks_count_--;
|
||||
|
||||
const RandomPick max_index = current_random_limit_++;
|
||||
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);
|
||||
|
||||
const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
|
||||
if (random_index_is_picked) {
|
||||
return random_index;
|
||||
}
|
||||
|
||||
picked_indexes_.insert(max_index);
|
||||
return max_index;
|
||||
}
|
||||
|
||||
} // namespace dfly
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <absl/random/random.h>
|
||||
#include <absl/strings/ascii.h>
|
||||
#include <absl/strings/str_cat.h>
|
||||
#include <absl/types/span.h>
|
||||
@ -309,4 +310,47 @@ struct MemoryBytesFlag {
|
||||
bool AbslParseFlag(std::string_view in, dfly::MemoryBytesFlag* flag, std::string* err);
|
||||
std::string AbslUnparseFlag(const dfly::MemoryBytesFlag& flag);
|
||||
|
||||
using RandomPick = std::uint32_t;
|
||||
|
||||
class PicksGenerator {
|
||||
public:
|
||||
virtual RandomPick Generate() = 0;
|
||||
virtual ~PicksGenerator() = default;
|
||||
};
|
||||
|
||||
class NonUniquePicksGenerator : public PicksGenerator {
|
||||
public:
|
||||
/* The generated value will be within the closed-open interval [0, max_range) */
|
||||
NonUniquePicksGenerator(RandomPick max_range);
|
||||
|
||||
RandomPick Generate() override;
|
||||
|
||||
private:
|
||||
const RandomPick max_range_;
|
||||
absl::BitGen bitgen_{};
|
||||
};
|
||||
|
||||
/*
|
||||
* Generates unique index in O(1).
|
||||
*
|
||||
* picks_count specifies the number of random indexes to be generated.
|
||||
* In other words, this is the number of times the Generate() function is called.
|
||||
*
|
||||
* The class uses Robert Floyd's sampling algorithm
|
||||
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
|
||||
* */
|
||||
class UniquePicksGenerator : public PicksGenerator {
|
||||
public:
|
||||
/* The generated value will be within the closed-open interval [0, max_range) */
|
||||
UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range);
|
||||
|
||||
RandomPick Generate() override;
|
||||
|
||||
private:
|
||||
RandomPick current_random_limit_;
|
||||
std::uint32_t remaining_picks_count_;
|
||||
std::unordered_set<RandomPick> picked_indexes_;
|
||||
absl::BitGen bitgen_{};
|
||||
};
|
||||
|
||||
} // namespace dfly
|
||||
|
@ -284,22 +284,62 @@ void InterStrSet(const DbContext& db_context, const vector<SetType>& vec, String
|
||||
}
|
||||
}
|
||||
|
||||
StringVec PopStrSet(const DbContext& db_context, unsigned count, const SetType& st) {
|
||||
StringVec result;
|
||||
StringVec RandMemberStrSet(const DbContext& db_context, const CompactObj& co,
|
||||
PicksGenerator& generator, std::size_t picks_count) {
|
||||
CHECK(IsDenseEncoding(co));
|
||||
|
||||
if (true) {
|
||||
StringSet* ss = (StringSet*)st.first;
|
||||
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
|
||||
|
||||
// TODO: this loop is inefficient because Pop searches again and again an occupied bucket.
|
||||
for (unsigned i = 0; i < count && !ss->Empty(); ++i) {
|
||||
result.push_back(ss->Pop().value());
|
||||
}
|
||||
std::unordered_map<RandomPick, std::uint32_t> times_index_is_picked;
|
||||
for (std::size_t i = 0; i < picks_count; i++) {
|
||||
times_index_is_picked[generator.Generate()]++;
|
||||
}
|
||||
|
||||
StringVec result;
|
||||
result.reserve(picks_count);
|
||||
|
||||
StringSet* ss = static_cast<StringSet*>(co.RObjPtr());
|
||||
ss->set_time(MemberTimeSeconds(db_context.time_now_ms));
|
||||
|
||||
std::uint32_t ss_entry_index = 0;
|
||||
container_utils::IterateSet(
|
||||
co, [&result, ×_index_is_picked, &ss_entry_index](container_utils::ContainerEntry ce) {
|
||||
auto it = times_index_is_picked.find(ss_entry_index++);
|
||||
if (it != times_index_is_picked.end()) {
|
||||
std::uint32_t t = it->second;
|
||||
while (t--) {
|
||||
result.emplace_back(ce.ToString());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
/* Equal elements in the result are always successive. So, it is necessary to shuffle them */
|
||||
absl::BitGen gen;
|
||||
std::shuffle(result.begin(), result.end(), gen);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
StringVec RandMemberSet(const DbContext& db_context, const CompactObj& co,
|
||||
PicksGenerator& generator, std::size_t picks_count) {
|
||||
if (co.Encoding() == kEncodingIntSet) {
|
||||
intset* is = static_cast<intset*>(co.RObjPtr());
|
||||
|
||||
StringVec result;
|
||||
result.reserve(picks_count);
|
||||
|
||||
for (std::size_t i = 0; i < picks_count; i++) {
|
||||
const std::size_t picked_index = generator.Generate();
|
||||
|
||||
int64_t value = 0;
|
||||
CHECK_GT(intsetGet(is, picked_index, &value), std::uint8_t(0));
|
||||
|
||||
result.push_back(absl::StrCat(value));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return RandMemberStrSet(db_context, co, generator, picks_count);
|
||||
}
|
||||
|
||||
vector<string> ToVec(absl::flat_hash_set<string>&& set) {
|
||||
vector<string> result(set.size());
|
||||
size_t i = 0;
|
||||
@ -819,69 +859,92 @@ OpResult<StringVec> OpInter(const Transaction* t, EngineShard* es, bool remove_f
|
||||
return result;
|
||||
}
|
||||
|
||||
// count - how many elements to pop.
|
||||
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
auto find_res = db_slice.FindMutable(op_args.db_cntx, key, OBJ_SET);
|
||||
OpResult<StringVec> OpRandMember(const OpArgs& op_args, std::string_view key, int count) {
|
||||
auto find_res = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_SET);
|
||||
if (!find_res)
|
||||
return find_res.status();
|
||||
|
||||
StringVec result;
|
||||
if (count == 0)
|
||||
return result;
|
||||
const CompactObj& co = find_res.value()->second;
|
||||
|
||||
auto it = find_res->it;
|
||||
size_t slen = it->second.Size();
|
||||
const std::uint32_t size = co.Size();
|
||||
const bool picks_are_unique = count >= 0;
|
||||
const std::uint32_t picks_count =
|
||||
picks_are_unique ? std::min(static_cast<std::uint32_t>(count), size) : std::abs(count);
|
||||
|
||||
auto generator = [picks_are_unique, picks_count, size]() -> std::unique_ptr<PicksGenerator> {
|
||||
if (picks_are_unique) {
|
||||
return std::make_unique<UniquePicksGenerator>(picks_count, size);
|
||||
} else {
|
||||
return std::make_unique<NonUniquePicksGenerator>(size);
|
||||
}
|
||||
}();
|
||||
|
||||
return RandMemberSet(op_args.db_cntx, co, *generator, picks_count);
|
||||
}
|
||||
|
||||
// count - how many elements to pop.
|
||||
OpResult<StringVec> OpPop(const OpArgs& op_args, string_view key, unsigned count) {
|
||||
auto& db_cntx = op_args.db_cntx;
|
||||
auto& db_slice = op_args.shard->db_slice();
|
||||
auto find_res = db_slice.FindMutable(db_cntx, key, OBJ_SET);
|
||||
if (!find_res) {
|
||||
return find_res.status();
|
||||
}
|
||||
|
||||
CompactObj& co = find_res->it->second;
|
||||
|
||||
const std::uint32_t size = co.Size();
|
||||
const std::uint32_t picks_count = std::min(count, size);
|
||||
|
||||
/* CASE 1:
|
||||
* The number of requested elements is greater than or equal to
|
||||
* the number of elements inside the set: simply return the whole set. */
|
||||
if (count >= slen) {
|
||||
PrimeValue& pv = it->second;
|
||||
if (IsDenseEncoding(pv)) {
|
||||
StringSet* ss = (StringSet*)pv.RObjPtr();
|
||||
if (count >= size) {
|
||||
if (IsDenseEncoding(co)) {
|
||||
StringSet* ss = (StringSet*)co.RObjPtr();
|
||||
ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms));
|
||||
}
|
||||
|
||||
container_utils::IterateSet(it->second, [&result](container_utils::ContainerEntry ce) {
|
||||
StringVec result;
|
||||
result.reserve(picks_count);
|
||||
|
||||
container_utils::IterateSet(co, [&result](container_utils::ContainerEntry ce) {
|
||||
result.push_back(ce.ToString());
|
||||
return true;
|
||||
});
|
||||
|
||||
// Delete the set as it is now empty
|
||||
find_res->post_updater.Run();
|
||||
CHECK(db_slice.Del(op_args.db_cntx.db_index, it));
|
||||
CHECK(db_slice.Del(op_args.db_cntx.db_index, find_res->it));
|
||||
|
||||
// Replicate as DEL.
|
||||
if (op_args.shard->journal()) {
|
||||
RecordJournal(op_args, "DEL"sv, ArgSlice{key});
|
||||
}
|
||||
} else {
|
||||
SetType st{it->second.RObjPtr(), it->second.Encoding()};
|
||||
if (st.second == kEncodingIntSet) {
|
||||
intset* is = (intset*)st.first;
|
||||
int64_t val = 0;
|
||||
|
||||
// copy last count values.
|
||||
for (uint32_t i = slen - count; i < slen; ++i) {
|
||||
intsetGet(is, i, &val);
|
||||
result.push_back(absl::StrCat(val));
|
||||
}
|
||||
|
||||
is = intsetTrimTail(is, count); // now remove last count items
|
||||
it->second.SetRObjPtr(is);
|
||||
} else {
|
||||
result = PopStrSet(op_args.db_cntx, count, st);
|
||||
}
|
||||
|
||||
// Replicate as SREM with removed keys, because SPOP is not deterministic.
|
||||
if (op_args.shard->journal()) {
|
||||
vector<string_view> mapped(result.size() + 1);
|
||||
mapped[0] = key;
|
||||
std::copy(result.begin(), result.end(), mapped.begin() + 1);
|
||||
RecordJournal(op_args, "SREM"sv, mapped);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* CASE 2:
|
||||
* The number of requested elements is less than the number of elements inside the set.
|
||||
* In this case, we need to select random members from the set and then remove them. */
|
||||
UniquePicksGenerator generator{picks_count, size};
|
||||
|
||||
// Select random members
|
||||
StringVec result = RandMemberSet(db_cntx, co, generator, picks_count);
|
||||
|
||||
// Remove selected members
|
||||
std::vector<std::string_view> members_to_remove{result.begin(), result.end()};
|
||||
bool is_empty = RemoveSet(db_cntx, members_to_remove, &co).second;
|
||||
find_res->post_updater.Run();
|
||||
|
||||
CHECK(!is_empty);
|
||||
|
||||
// Replicate as SREM with removed keys, because SPOP is not deterministic.
|
||||
if (op_args.shard->journal()) {
|
||||
members_to_remove.insert(members_to_remove.begin(), key);
|
||||
RecordJournal(op_args, "SREM"sv, members_to_remove);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -1204,41 +1267,13 @@ void SRandMember(CmdArgList args, ConnectionContext* cntx) {
|
||||
if (auto err = parser.Error(); err)
|
||||
return cntx->SendError(err->MakeReply());
|
||||
|
||||
const unsigned ucount = std::abs(count);
|
||||
|
||||
const auto cb = [&](Transaction* t, EngineShard* shard) -> OpResult<StringVec> {
|
||||
StringVec result;
|
||||
auto find_res = shard->db_slice().FindReadOnly(t->GetDbContext(), key, OBJ_SET);
|
||||
if (!find_res) {
|
||||
return find_res.status();
|
||||
}
|
||||
|
||||
const PrimeValue& pv = find_res.value()->second;
|
||||
if (IsDenseEncoding(pv)) {
|
||||
StringSet* ss = (StringSet*)pv.RObjPtr();
|
||||
ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms));
|
||||
}
|
||||
|
||||
container_utils::IterateSet(find_res.value()->second,
|
||||
[&result, ucount](container_utils::ContainerEntry ce) {
|
||||
if (result.size() < ucount) {
|
||||
result.push_back(ce.ToString());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return result;
|
||||
return OpRandMember(t->GetOpArgs(shard), key, count);
|
||||
};
|
||||
|
||||
OpResult<StringVec> result = cntx->transaction->ScheduleSingleHopT(cb);
|
||||
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
|
||||
if (result) {
|
||||
if (count < 0 && !result->empty()) {
|
||||
for (auto i = result->size(); i < ucount; ++i) {
|
||||
// we can return duplicate elements, so first is OK
|
||||
result->push_back(result->front());
|
||||
}
|
||||
}
|
||||
rb->SendStringArr(*result, RedisReplyBuilder::SET);
|
||||
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
|
||||
if (is_count) {
|
||||
|
@ -21,6 +21,20 @@ class SetFamilyTest : public BaseFamilyTest {
|
||||
protected:
|
||||
};
|
||||
|
||||
MATCHER_P(ConsistsOfMatcher, elements, "") {
|
||||
auto vec = arg.GetVec();
|
||||
for (const auto& x : vec) {
|
||||
if (elements.find(x.GetString()) == elements.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
auto ConsistsOf(std::initializer_list<std::string> elements) {
|
||||
return ConsistsOfMatcher(std::unordered_set<std::string>{elements});
|
||||
}
|
||||
|
||||
TEST_F(SetFamilyTest, SAdd) {
|
||||
auto resp = Run({"sadd", "x", "1", "2", "3"});
|
||||
EXPECT_THAT(resp, IntArg(3));
|
||||
@ -159,34 +173,107 @@ TEST_F(SetFamilyTest, SPop) {
|
||||
}
|
||||
|
||||
TEST_F(SetFamilyTest, SRandMember) {
|
||||
auto resp = Run({"sadd", "x", "1", "2", "3"});
|
||||
resp = Run({"SRandMember", "x"});
|
||||
// Test IntSet
|
||||
Run({"sadd", "x", "1", "2", "3"});
|
||||
|
||||
// Test if count > 0 (IntSet)
|
||||
auto resp = Run({"SRandMember", "x"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
|
||||
EXPECT_THAT(resp, "1");
|
||||
EXPECT_THAT(resp, AnyOf("1", "2", "3"));
|
||||
|
||||
resp = Run({"SRandMember", "x", "1"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
|
||||
EXPECT_THAT(resp, AnyOf("1", "2", "3"));
|
||||
|
||||
resp = Run({"SRandMember", "x", "2"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
|
||||
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2"));
|
||||
ASSERT_THAT(resp, ArrLen(2));
|
||||
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"1", "2", "3"}));
|
||||
|
||||
resp = Run({"SRandMember", "x", "0"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
|
||||
EXPECT_EQ(resp.GetVec().size(), 0);
|
||||
|
||||
resp = Run({"SRandMember", "k"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
|
||||
|
||||
resp = Run({"SRandMember", "k", "2"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
|
||||
EXPECT_EQ(resp.GetVec().size(), 0);
|
||||
|
||||
resp = Run({"SRandMember", "x", "-5"});
|
||||
ASSERT_THAT(resp, ArrLen(5));
|
||||
EXPECT_THAT(resp.GetVec(), ElementsAre("1", "2", "3", "1", "1"));
|
||||
|
||||
resp = Run({"SRandMember", "x", "5"});
|
||||
resp = Run({"SRandMember", "x", "3"});
|
||||
ASSERT_THAT(resp, ArrLen(3));
|
||||
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3"));
|
||||
|
||||
// Test if count is larger than the size of the IntSet
|
||||
resp = Run({"SRandMember", "x", "25"});
|
||||
ASSERT_THAT(resp, ArrLen(3));
|
||||
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("1", "2", "3"));
|
||||
|
||||
// Test if count < 0 (IntSet)
|
||||
resp = Run({"SRandMember", "x", "-1"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
|
||||
EXPECT_THAT(resp, AnyOf("1", "2", "3"));
|
||||
|
||||
resp = Run({"SRandMember", "x", "-2"});
|
||||
ASSERT_THAT(resp, ArrLen(2));
|
||||
EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"}));
|
||||
|
||||
resp = Run({"SRandMember", "x", "-3"});
|
||||
ASSERT_THAT(resp, ArrLen(3));
|
||||
EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"}));
|
||||
|
||||
// Test if count < 0, but the absolute value is larger than the size of the IntSet
|
||||
resp = Run({"SRandMember", "x", "-25"});
|
||||
ASSERT_THAT(resp, ArrLen(25));
|
||||
EXPECT_THAT(resp, ConsistsOf({"1", "2", "3"}));
|
||||
|
||||
// Test StrSet
|
||||
Run({"sadd", "y", "a", "b", "c"});
|
||||
|
||||
// Test if count > 0 (StrSet)
|
||||
resp = Run({"SRandMember", "y"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
|
||||
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
|
||||
|
||||
resp = Run({"SRandMember", "y", "1"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
|
||||
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
|
||||
|
||||
resp = Run({"SRandMember", "y", "2"});
|
||||
ASSERT_THAT(resp, ArrLen(2));
|
||||
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));
|
||||
|
||||
resp = Run({"SRandMember", "y", "3"});
|
||||
ASSERT_THAT(resp, ArrLen(3));
|
||||
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));
|
||||
|
||||
// Test if count is larger than the size of the StrSet
|
||||
resp = Run({"SRandMember", "y", "25"});
|
||||
ASSERT_THAT(resp, ArrLen(3));
|
||||
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));
|
||||
|
||||
// Test if count < 0 (StrSet)
|
||||
resp = Run({"SRandMember", "y", "-1"});
|
||||
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
|
||||
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
|
||||
|
||||
resp = Run({"SRandMember", "y", "-2"});
|
||||
ASSERT_THAT(resp, ArrLen(2));
|
||||
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
|
||||
|
||||
resp = Run({"SRandMember", "y", "-3"});
|
||||
ASSERT_THAT(resp, ArrLen(3));
|
||||
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
|
||||
|
||||
// Test if count < 0, but the absolute value is larger than the size of the StrSet
|
||||
resp = Run({"SRandMember", "y", "-25"});
|
||||
ASSERT_THAT(resp, ArrLen(25));
|
||||
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
|
||||
|
||||
// Test if count is 0
|
||||
ASSERT_THAT(Run({"SRandMember", "x", "0"}), ArrLen(0));
|
||||
|
||||
// Test if set is empty
|
||||
EXPECT_THAT(Run({"SAdd", "empty::set", "1"}), IntArg(1));
|
||||
EXPECT_THAT(Run({"SRem", "empty::set", "1"}), IntArg(1));
|
||||
ASSERT_THAT(Run({"SRandMember", "empty::set", "0"}), ArrLen(0));
|
||||
ASSERT_THAT(Run({"SRandMember", "empty::set", "3"}), ArrLen(0));
|
||||
ASSERT_THAT(Run({"SRandMember", "empty::set", "-4"}), ArrLen(0));
|
||||
|
||||
// Test if key does not exist
|
||||
ASSERT_THAT(Run({"SRandMember", "unknown::set"}), ArgType(RespExpr::NIL));
|
||||
ASSERT_THAT(Run({"SRandMember", "unknown::set", "0"}), ArrLen(0));
|
||||
|
||||
// Test wrong arguments
|
||||
resp = Run({"SRandMember", "x", "5", "3"});
|
||||
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
|
||||
}
|
||||
|
@ -222,70 +222,6 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
|
||||
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};
|
||||
}
|
||||
|
||||
using RandomPick = std::uint32_t;
|
||||
|
||||
class PicksGenerator {
|
||||
public:
|
||||
virtual RandomPick Generate() = 0;
|
||||
virtual ~PicksGenerator() = default;
|
||||
};
|
||||
|
||||
class NonUniquePicksGenerator : public PicksGenerator {
|
||||
public:
|
||||
NonUniquePicksGenerator(RandomPick max_range) : max_range_(max_range) {
|
||||
CHECK_GT(max_range, RandomPick(0));
|
||||
}
|
||||
|
||||
RandomPick Generate() override {
|
||||
return absl::Uniform(bitgen_, 0u, max_range_);
|
||||
}
|
||||
|
||||
private:
|
||||
const RandomPick max_range_;
|
||||
absl::BitGen bitgen_{};
|
||||
};
|
||||
|
||||
/*
|
||||
* Generates unique index in O(1).
|
||||
*
|
||||
* picks_count specifies the number of random indexes to be generated.
|
||||
* In other words, this is the number of times the Generate() function is called.
|
||||
*
|
||||
* The class uses Robert Floyd's sampling algorithm
|
||||
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
|
||||
* */
|
||||
class UniquePicksGenerator : public PicksGenerator {
|
||||
public:
|
||||
UniquePicksGenerator(std::uint32_t picks_count, RandomPick max_range)
|
||||
: remaining_picks_count_(picks_count), picked_indexes_(picks_count) {
|
||||
CHECK_GE(max_range, picks_count);
|
||||
current_random_limit_ = max_range - picks_count;
|
||||
}
|
||||
|
||||
RandomPick Generate() override {
|
||||
DCHECK_GT(remaining_picks_count_, 0u);
|
||||
|
||||
remaining_picks_count_--;
|
||||
|
||||
const RandomPick max_index = current_random_limit_++;
|
||||
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);
|
||||
|
||||
const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
|
||||
if (random_index_is_picked) {
|
||||
return random_index;
|
||||
}
|
||||
|
||||
picked_indexes_.insert(max_index);
|
||||
return max_index;
|
||||
}
|
||||
|
||||
private:
|
||||
RandomPick current_random_limit_;
|
||||
std::uint32_t remaining_picks_count_;
|
||||
std::unordered_set<RandomPick> picked_indexes_;
|
||||
absl::BitGen bitgen_{};
|
||||
};
|
||||
|
||||
bool ScoreToLongLat(const std::optional<double>& val, double* xy) {
|
||||
if (!val.has_value())
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user