This commit is contained in:
Ivan Akimov 2023-07-18 14:32:39 -05:00
parent f2d11a023e
commit cf377536d0
7 changed files with 389 additions and 271 deletions

1
.clippy.toml Normal file
View File

@ -0,0 +1 @@
too-many-arguments-threshold = 10

3
.gitattributes vendored
View File

@ -1 +1,2 @@
**/blocklist.json binary **/blocklist.json binary
Cargo.lock binary

View File

@ -1,5 +1,7 @@
# [Sqids Rust](https://sqids.org/rust) # [Sqids Rust](https://sqids.org/rust)
[![Github Actions](https://img.shields.io/github/actions/workflow/status/sqids/sqids-rust/tests.yml)](https://github.com/sqids/sqids-rust/actions)
Sqids (pronounced "squids") is a small library that lets you generate YouTube-looking IDs from numbers. It's good for link shortening, fast & URL-safe ID generation and decoding back into numbers for quicker database lookups. Sqids (pronounced "squids") is a small library that lets you generate YouTube-looking IDs from numbers. It's good for link shortening, fast & URL-safe ID generation and decoding back into numbers for quicker database lookups.
## Getting started ## Getting started

2
release.toml Normal file
View File

@ -0,0 +1,2 @@
consolidate-commits = false
consolidate-pushes = true

13
rustfmt.toml Normal file
View File

@ -0,0 +1,13 @@
max_width = 100
comment_width = 100
hard_tabs = true
edition = "2021"
reorder_imports = true
imports_granularity = "Crate"
use_small_heuristics = "Max"
wrap_comments = true
binop_separator = "Back"
trailing_comma = "Vertical"
trailing_semicolon = true
use_field_init_shorthand = true
format_macro_bodies = true

View File

@ -1,345 +1,326 @@
use derive_more::Display; use derive_more::Display;
use std::collections::HashSet; use std::{collections::HashSet, result};
use std::result;
#[derive(Display, Debug)] #[derive(Display, Debug)]
pub enum Error { pub enum Error {
#[display(fmt = "Alphabet length must be at least 5")] #[display(fmt = "Alphabet length must be at least 5")]
AlphabetLength, AlphabetLength,
#[display(fmt = "Alphabet must contain unique characters")] #[display(fmt = "Alphabet must contain unique characters")]
AlphabetUniqueCharacters, AlphabetUniqueCharacters,
#[display(fmt = "Minimum length has to be between {min} and {max}")] #[display(fmt = "Minimum length has to be between {min} and {max}")]
MinLength { min: usize, max: usize }, MinLength { min: usize, max: usize },
#[display(fmt = "Encoding supports numbers between {min} and {max}")] #[display(fmt = "Encoding supports numbers between {min} and {max}")]
EncodingRange { min: u64, max: u64 }, EncodingRange { min: u64, max: u64 },
#[display(fmt = "Ran out of range checking against the blocklist")] #[display(fmt = "Ran out of range checking against the blocklist")]
BlocklistOutOfRange, BlocklistOutOfRange,
} }
pub type Result<T> = result::Result<T, Error>; pub type Result<T> = result::Result<T, Error>;
pub fn default_blocklist() -> HashSet<String> {
serde_json::from_str(include_str!("blocklist.json")).unwrap()
}
#[derive(Debug)] #[derive(Debug)]
pub struct Options { pub struct Options {
alphabet: String, alphabet: String,
min_length: usize, min_length: usize,
blocklist: HashSet<String>, blocklist: HashSet<String>,
} }
impl Options { impl Options {
pub fn new( pub fn new(
alphabet: Option<String>, alphabet: Option<String>,
min_length: Option<usize>, min_length: Option<usize>,
blocklist: Option<HashSet<String>>, blocklist: Option<HashSet<String>>,
) -> Self { ) -> Self {
let mut options = Options::default(); let mut options = Options::default();
if let Some(alphabet) = alphabet { if let Some(alphabet) = alphabet {
options.alphabet = alphabet; options.alphabet = alphabet;
} }
if let Some(min_length) = min_length { if let Some(min_length) = min_length {
options.min_length = min_length; options.min_length = min_length;
} }
if let Some(blocklist) = blocklist { if let Some(blocklist) = blocklist {
options.blocklist = blocklist; options.blocklist = blocklist;
} }
options options
} }
} }
impl Default for Options { impl Default for Options {
fn default() -> Self { fn default() -> Self {
Options { Options {
alphabet: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".to_string(), alphabet: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".to_string(),
min_length: 0, min_length: 0,
blocklist: serde_json::from_str(include_str!("blocklist.json")).unwrap(), blocklist: default_blocklist(),
} }
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Sqids { pub struct Sqids {
alphabet: Vec<char>, alphabet: Vec<char>,
min_length: usize, min_length: usize,
blocklist: HashSet<String>, blocklist: HashSet<String>,
} }
impl Sqids { impl Sqids {
pub fn new(options: Option<Options>) -> Result<Self> { pub fn new(options: Option<Options>) -> Result<Self> {
let options = options.unwrap_or_default(); let options = options.unwrap_or_default();
let alphabet: Vec<char> = options.alphabet.chars().collect(); let alphabet: Vec<char> = options.alphabet.chars().collect();
if alphabet.len() < 5 { if alphabet.len() < 5 {
return Err(Error::AlphabetLength); return Err(Error::AlphabetLength);
} }
let unique_chars: HashSet<char> = alphabet.iter().cloned().collect(); let unique_chars: HashSet<char> = alphabet.iter().cloned().collect();
if unique_chars.len() != alphabet.len() { if unique_chars.len() != alphabet.len() {
return Err(Error::AlphabetUniqueCharacters); return Err(Error::AlphabetUniqueCharacters);
} }
let filtered_blocklist: HashSet<String> = options let filtered_blocklist: HashSet<String> = options
.blocklist .blocklist
.iter() .iter()
.filter_map(|word| { .filter_map(|word| {
let word = word.to_lowercase(); let word = word.to_lowercase();
if word.len() >= 3 && word.chars().all(|c| alphabet.contains(&c)) { if word.len() >= 3 && word.chars().all(|c| alphabet.contains(&c)) {
Some(word) Some(word)
} else { } else {
None None
} }
}) })
.collect(); .collect();
let mut sqids = Sqids { let mut sqids =
alphabet, Sqids { alphabet, min_length: options.min_length, blocklist: filtered_blocklist };
min_length: options.min_length,
blocklist: filtered_blocklist,
};
if options.min_length < sqids.min_value() as usize if options.min_length < sqids.min_value() as usize ||
|| options.min_length > options.alphabet.len() options.min_length > options.alphabet.len()
{ {
return Err(Error::MinLength { return Err(Error::MinLength {
min: sqids.min_value() as usize, min: sqids.min_value() as usize,
max: options.alphabet.len(), max: options.alphabet.len(),
}); });
} }
sqids.alphabet = sqids.shuffle(&sqids.alphabet); sqids.alphabet = sqids.shuffle(&sqids.alphabet);
Ok(sqids) Ok(sqids)
} }
pub fn encode(&self, numbers: &[u64]) -> Result<String> { pub fn encode(&self, numbers: &[u64]) -> Result<String> {
if numbers.is_empty() { if numbers.is_empty() {
return Ok(String::new()); return Ok(String::new());
} }
let in_range_numbers: Vec<u64> = numbers let in_range_numbers: Vec<u64> = numbers
.iter() .iter()
.copied() .copied()
.filter(|&n| n >= self.min_value() && n <= self.max_value()) .filter(|&n| n >= self.min_value() && n <= self.max_value())
.collect(); .collect();
if in_range_numbers.len() != numbers.len() { if in_range_numbers.len() != numbers.len() {
return Err(Error::EncodingRange { return Err(Error::EncodingRange { min: self.min_value(), max: self.max_value() });
min: self.min_value(), }
max: self.max_value(),
});
}
self.encode_numbers(&in_range_numbers, false) self.encode_numbers(&in_range_numbers, false)
} }
pub fn decode(&self, id: &str) -> Vec<u64> { pub fn decode(&self, id: &str) -> Vec<u64> {
let mut ret = Vec::new(); let mut ret = Vec::new();
if id.is_empty() { if id.is_empty() {
return ret; return ret;
} }
let alphabet_chars: HashSet<char> = self.alphabet.iter().cloned().collect(); let alphabet_chars: HashSet<char> = self.alphabet.iter().cloned().collect();
if !id.chars().all(|c| alphabet_chars.contains(&c)) { if !id.chars().all(|c| alphabet_chars.contains(&c)) {
return ret; return ret;
} }
let prefix = id.chars().next().unwrap(); let prefix = id.chars().next().unwrap();
let offset = self.alphabet.iter().position(|&c| c == prefix).unwrap(); let offset = self.alphabet.iter().position(|&c| c == prefix).unwrap();
let mut alphabet: Vec<char> = self let mut alphabet: Vec<char> =
.alphabet self.alphabet.iter().cycle().skip(offset).take(self.alphabet.len()).copied().collect();
.iter()
.cycle()
.skip(offset)
.take(self.alphabet.len())
.copied()
.collect();
let partition = alphabet[1]; let partition = alphabet[1];
alphabet.remove(1); alphabet.remove(1);
alphabet.remove(0); alphabet.remove(0);
let mut id = id[1..].to_string(); let mut id = id[1..].to_string();
let partition_index = id.find(partition); let partition_index = id.find(partition);
if let Some(idx) = partition_index { if let Some(idx) = partition_index {
if idx > 0 && idx < id.len() - 1 { if idx > 0 && idx < id.len() - 1 {
id = id.split_off(idx + 1); id = id.split_off(idx + 1);
alphabet = self.shuffle(&alphabet); alphabet = self.shuffle(&alphabet);
} }
} }
while !id.is_empty() { while !id.is_empty() {
let separator = alphabet[alphabet.len() - 1]; let separator = alphabet[alphabet.len() - 1];
let chunks: Vec<&str> = id.split(separator).collect(); let chunks: Vec<&str> = id.split(separator).collect();
if !chunks.is_empty() { if !chunks.is_empty() {
let alphabet_without_separator: Vec<char> = let alphabet_without_separator: Vec<char> =
alphabet.iter().copied().take(alphabet.len() - 1).collect(); alphabet.iter().copied().take(alphabet.len() - 1).collect();
let num = self.to_number(chunks[0], &alphabet_without_separator); let num = self.to_number(chunks[0], &alphabet_without_separator);
ret.push(num); ret.push(num);
if chunks.len() > 1 { if chunks.len() > 1 {
alphabet = self.shuffle(&alphabet); alphabet = self.shuffle(&alphabet);
} }
} }
id = chunks[1..].join(&separator.to_string()); id = chunks[1..].join(&separator.to_string());
} }
ret ret
} }
pub fn min_value(&self) -> u64 { pub fn min_value(&self) -> u64 {
0 0
} }
pub fn max_value(&self) -> u64 { pub fn max_value(&self) -> u64 {
u64::MAX u64::MAX
} }
fn encode_numbers(&self, numbers: &[u64], partitioned: bool) -> Result<String> { fn encode_numbers(&self, numbers: &[u64], partitioned: bool) -> Result<String> {
let offset = numbers let offset = numbers.iter().enumerate().fold(numbers.len(), |a, (i, &v)| {
.iter() self.alphabet[v as usize % self.alphabet.len()] as usize + i + a
.enumerate() }) % self.alphabet.len();
.fold(numbers.len(), |a, (i, &v)| {
self.alphabet[v as usize % self.alphabet.len()] as usize + i + a
})
% self.alphabet.len();
let mut alphabet: Vec<char> = self let mut alphabet: Vec<char> =
.alphabet self.alphabet.iter().cycle().skip(offset).take(self.alphabet.len()).copied().collect();
.iter()
.cycle()
.skip(offset)
.take(self.alphabet.len())
.copied()
.collect();
let prefix = alphabet[0]; let prefix = alphabet[0];
let partition = alphabet[1]; let partition = alphabet[1];
alphabet.remove(1); alphabet.remove(1);
alphabet.remove(0); alphabet.remove(0);
let mut ret: Vec<String> = vec![prefix.to_string()]; let mut ret: Vec<String> = vec![prefix.to_string()];
for (i, &num) in numbers.iter().enumerate() { for (i, &num) in numbers.iter().enumerate() {
let alphabet_without_separator: Vec<char> = let alphabet_without_separator: Vec<char> =
alphabet.iter().copied().take(alphabet.len() - 1).collect(); alphabet.iter().copied().take(alphabet.len() - 1).collect();
ret.push(self.to_id(num, &alphabet_without_separator)); ret.push(self.to_id(num, &alphabet_without_separator));
if i < numbers.len() - 1 { if i < numbers.len() - 1 {
let separator = alphabet[alphabet.len() - 1]; let separator = alphabet[alphabet.len() - 1];
if partitioned && i == 0 { if partitioned && i == 0 {
ret.push(partition.to_string()); ret.push(partition.to_string());
} else { } else {
ret.push(separator.to_string()); ret.push(separator.to_string());
} }
alphabet = self.shuffle(&alphabet); alphabet = self.shuffle(&alphabet);
} }
} }
let mut id = ret.join(""); let mut id = ret.join("");
if self.min_length > id.len() { if self.min_length > id.len() {
if !partitioned { if !partitioned {
let mut new_numbers = vec![0]; let mut new_numbers = vec![0];
new_numbers.extend_from_slice(numbers); new_numbers.extend_from_slice(numbers);
id = self.encode_numbers(&new_numbers, true)?; id = self.encode_numbers(&new_numbers, true)?;
} }
if self.min_length > id.len() { if self.min_length > id.len() {
let mut new_id = id.clone(); let mut new_id = id.clone();
let alphabet_slice = &alphabet[..(self.min_length - id.len())]; let alphabet_slice = &alphabet[..(self.min_length - id.len())];
new_id.push_str(&alphabet_slice.iter().collect::<String>()); new_id.push_str(&alphabet_slice.iter().collect::<String>());
new_id.push_str(&id[1..]); new_id.push_str(&id[1..]);
id = new_id; id = new_id;
} }
} }
if self.is_blocked_id(&id) { if self.is_blocked_id(&id) {
let mut new_numbers; let mut new_numbers;
if partitioned { if partitioned {
if numbers[0] + 1 > self.max_value() { if numbers[0] + 1 > self.max_value() {
return Err(Error::BlocklistOutOfRange); return Err(Error::BlocklistOutOfRange);
} else { } else {
new_numbers = numbers.to_vec(); new_numbers = numbers.to_vec();
new_numbers[0] += 1; new_numbers[0] += 1;
} }
} else { } else {
new_numbers = vec![0]; new_numbers = vec![0];
new_numbers.extend_from_slice(numbers); new_numbers.extend_from_slice(numbers);
} }
id = self.encode_numbers(&new_numbers, true)?; id = self.encode_numbers(&new_numbers, true)?;
} }
Ok(id) Ok(id)
} }
fn to_id(&self, num: u64, alphabet: &[char]) -> String { fn to_id(&self, num: u64, alphabet: &[char]) -> String {
let mut id = Vec::new(); let mut id = Vec::new();
let mut result = num; let mut result = num;
loop { loop {
let idx = (result % alphabet.len() as u64) as usize; let idx = (result % alphabet.len() as u64) as usize;
id.insert(0, alphabet[idx]); id.insert(0, alphabet[idx]);
result /= alphabet.len() as u64; result /= alphabet.len() as u64;
if result == 0 { if result == 0 {
break; break;
} }
} }
id.into_iter().collect() id.into_iter().collect()
} }
fn to_number(&self, id: &str, alphabet: &[char]) -> u64 { fn to_number(&self, id: &str, alphabet: &[char]) -> u64 {
let mut result = 0; let mut result = 0;
for c in id.chars() { for c in id.chars() {
let idx = alphabet.iter().position(|&x| x == c).unwrap(); let idx = alphabet.iter().position(|&x| x == c).unwrap();
result = result * alphabet.len() as u64 + idx as u64; result = result * alphabet.len() as u64 + idx as u64;
} }
result result
} }
fn shuffle(&self, alphabet: &[char]) -> Vec<char> { fn shuffle(&self, alphabet: &[char]) -> Vec<char> {
let mut chars: Vec<char> = alphabet.to_vec(); let mut chars: Vec<char> = alphabet.to_vec();
for i in 0..(chars.len() - 1) { for i in 0..(chars.len() - 1) {
let j = chars.len() - 1 - i; let j = chars.len() - 1 - i;
let r = (i as u32 * j as u32 + chars[i] as u32 + chars[j] as u32) % chars.len() as u32; let r = (i as u32 * j as u32 + chars[i] as u32 + chars[j] as u32) % chars.len() as u32;
chars.swap(i, r as usize); chars.swap(i, r as usize);
} }
chars chars
} }
fn is_blocked_id(&self, id: &str) -> bool { fn is_blocked_id(&self, id: &str) -> bool {
let id = id.to_lowercase(); let id = id.to_lowercase();
for word in &self.blocklist { for word in &self.blocklist {
if word.len() <= id.len() { if word.len() <= id.len() {
if id.len() <= 3 || word.len() <= 3 { if id.len() <= 3 || word.len() <= 3 {
if id == *word { if id == *word {
return true; return true;
} }
} else if word.chars().any(|c| c.is_ascii_digit()) { } else if word.chars().any(|c| c.is_ascii_digit()) {
if id.starts_with(word) || id.ends_with(word) { if id.starts_with(word) || id.ends_with(word) {
return true; return true;
} }
} else if id.contains(word) { } else if id.contains(word) {
return true; return true;
} }
} }
} }
false false
} }
} }

View File

@ -2,10 +2,128 @@ use sqids::*;
#[test] #[test]
fn simple() { fn simple() {
let id = "8QRLaD".to_string(); let sqids = Sqids::new(None).unwrap();
let numbers = vec![1, 2, 3];
let sqids = Sqids::new(None).unwrap(); let numbers = vec![1, 2, 3];
assert_eq!(sqids.encode(&numbers).unwrap(), id); let id = "8QRLaD";
assert_eq!(sqids.decode(&id), numbers);
assert_eq!(sqids.encode(&numbers).unwrap(), id);
assert_eq!(sqids.decode(id), numbers);
}
#[test]
fn different_inputs() {
let sqids = Sqids::new(None).unwrap();
let numbers = vec![0, 0, 0, 1, 2, 3, 100, 1_000, 100_000, 1_000_000, sqids.max_value()];
assert_eq!(sqids.decode(&sqids.encode(&numbers).unwrap()), numbers);
}
#[test]
fn incremental_numbers() {
let sqids = Sqids::new(None).unwrap();
let ids = vec![
("bV", vec![0]),
("U9", vec![1]),
("g8", vec![2]),
("Ez", vec![3]),
("V8", vec![4]),
("ul", vec![5]),
("O3", vec![6]),
("AF", vec![7]),
("ph", vec![8]),
("n8", vec![9]),
];
for (id, numbers) in ids {
assert_eq!(sqids.encode(&numbers).unwrap(), id);
assert_eq!(sqids.decode(id), numbers);
}
}
#[test]
fn incremental_numbers_same_index_0() {
let sqids = Sqids::new(None).unwrap();
let ids = vec![
("SrIu", vec![0, 0]),
("nZqE", vec![0, 1]),
("tJyf", vec![0, 2]),
("e86S", vec![0, 3]),
("rtC7", vec![0, 4]),
("sQ8R", vec![0, 5]),
("uz2n", vec![0, 6]),
("7Td9", vec![0, 7]),
("3nWE", vec![0, 8]),
("mIxM", vec![0, 9]),
];
for (id, numbers) in ids {
assert_eq!(sqids.encode(&numbers).unwrap(), id);
assert_eq!(sqids.decode(id), numbers);
}
}
#[test]
fn incremental_numbers_same_index_1() {
let sqids = Sqids::new(None).unwrap();
let ids = vec![
("SrIu", vec![0, 0]),
("nbqh", vec![1, 0]),
("t4yj", vec![2, 0]),
("eQ6L", vec![3, 0]),
("r4Cc", vec![4, 0]),
("sL82", vec![5, 0]),
("uo2f", vec![6, 0]),
("7Zdq", vec![7, 0]),
("36Wf", vec![8, 0]),
("m4xT", vec![9, 0]),
];
for (id, numbers) in ids {
assert_eq!(sqids.encode(&numbers).unwrap(), id);
assert_eq!(sqids.decode(id), numbers);
}
}
#[test]
fn multi_input() {
let sqids = Sqids::new(None).unwrap();
let numbers: Vec<u64> = (0..100).collect();
let output = sqids.decode(&sqids.encode(&numbers).unwrap());
assert_eq!(numbers, output);
}
#[test]
fn encoding_no_numbers() {
let sqids = Sqids::new(None).unwrap();
assert_eq!(sqids.encode(&[]).unwrap(), "");
}
#[test]
fn decoding_empty_string() {
let sqids = Sqids::new(None).unwrap();
let numbers: Vec<u64> = vec![];
assert_eq!(sqids.decode(""), numbers);
}
#[test]
fn decoding_invalid_character() {
let sqids = Sqids::new(None).unwrap();
let numbers: Vec<u64> = vec![];
assert_eq!(sqids.decode("*"), numbers);
}
#[test]
#[should_panic]
fn encode_out_of_range_numbers() {
let sqids = Sqids::new(None).unwrap();
assert!(sqids.encode(&[sqids.min_value() - 1]).is_err());
assert!(sqids.encode(&[sqids.max_value() + 1]).is_err());
} }