diff --git a/src/dict.c b/src/dict.c index 2cf9d4839..6844e6c8e 100644 --- a/src/dict.c +++ b/src/dict.c @@ -739,6 +739,30 @@ unsigned int dictGetSomeKeys(dict *d, dictEntry **des, unsigned int count) { return stored; } +/* This is like dictGetRandomKey() from the POV of the API, but will do more + * work to ensure a better distribution of the returned element. + * + * This function improves the distribution because the dictGetRandomKey() + * problem is that it selects a random bucket, then it selects a random + * element from the chain in the bucket. However elements being in different + * chain lengths will have different probabilities of being reported. With + * this function instead what we do is to consider a "linear" range of the table + * that may be constituted of N buckets with chains of different lengths + * appearing one after the other. Then we report a random element in the range. + * In this way we smooth away the problem of different chain lenghts. */ +#define GETFAIR_NUM_ENTRIES 20 +dictEntry *dictGetFairRandomKey(dict *d) { + dictEntry *entries[GETFAIR_NUM_ENTRIES]; + unsigned int count = dictGetSomeKeys(d,entries,GETFAIR_NUM_ENTRIES); + /* Note that dictGetSomeKeys() may return zero elements in an unlucky + * run() even if there are actually elements inside the hash table. So + * when we get zero, we call the true dictGetRandomKey() that will always + * yeld the element if the hash table has at least one. */ + if (count == 0) return dictGetRandomKey(d); + unsigned int idx = rand() % count; + return entries[idx]; +} + /* Function to reverse bits. Algorithm from: * http://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel */ static unsigned long rev(unsigned long v) { diff --git a/src/dict.h b/src/dict.h index 62018cc44..dec60f637 100644 --- a/src/dict.h +++ b/src/dict.h @@ -166,6 +166,7 @@ dictIterator *dictGetSafeIterator(dict *d); dictEntry *dictNext(dictIterator *iter); void dictReleaseIterator(dictIterator *iter); dictEntry *dictGetRandomKey(dict *d); +dictEntry *dictGetFairRandomKey(dict *d); unsigned int dictGetSomeKeys(dict *d, dictEntry **des, unsigned int count); void dictGetStats(char *buf, size_t bufsize, dict *d); uint64_t dictGenHashFunction(const void *key, int len); diff --git a/src/t_set.c b/src/t_set.c index 61013dbcd..290a83e6d 100644 --- a/src/t_set.c +++ b/src/t_set.c @@ -207,7 +207,7 @@ sds setTypeNextObject(setTypeIterator *si) { * used field with values which are easy to trap if misused. */ int setTypeRandomElement(robj *setobj, sds *sdsele, int64_t *llele) { if (setobj->encoding == OBJ_ENCODING_HT) { - dictEntry *de = dictGetRandomKey(setobj->ptr); + dictEntry *de = dictGetFairRandomKey(setobj->ptr); *sdsele = dictGetKey(de); *llele = -123456789; /* Not needed. Defensive. */ } else if (setobj->encoding == OBJ_ENCODING_INTSET) {