refactor: remove extra code from CmdArgParser (#3619)

* refactor: remove extra code from CmdArgParser
This commit is contained in:
Borys 2024-09-03 10:04:05 +03:00 committed by GitHub
parent 8fca7dd9f8
commit d40e9088ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 66 additions and 125 deletions

View File

@ -11,24 +11,6 @@
namespace facade {
CmdArgParser::CheckProxy::operator bool() const {
if (idx_ >= parser_->args_.size())
return false;
std::string_view arg = parser_->SafeSV(idx_);
if (!absl::EqualsIgnoreCase(arg, tag_))
return false;
if (idx_ + expect_tail_ >= parser_->args_.size()) {
parser_->Report(SHORT_OPT_TAIL, idx_);
return false;
}
parser_->cur_i_++;
return true;
}
void CmdArgParser::ExpectTag(std::string_view tag) {
if (cur_i_ >= args_.size()) {
Report(OUT_OF_BOUNDS, cur_i_);

View File

@ -19,29 +19,6 @@ namespace facade {
struct CmdArgParser {
enum ErrorType { OUT_OF_BOUNDS, SHORT_OPT_TAIL, INVALID_INT, INVALID_CASES, INVALID_NEXT };
struct CheckProxy {
explicit operator bool() const;
// Expect the tag to be followed by a number of arguments.
// Reports an error if the tag is matched but the condition is not met.
CheckProxy& ExpectTail(size_t tail) {
expect_tail_ = tail;
return *this;
}
private:
friend struct CmdArgParser;
CheckProxy(CmdArgParser* parser, std::string_view tag, size_t idx)
: parser_{parser}, tag_{tag}, idx_{idx} {
}
CmdArgParser* parser_;
std::string_view tag_;
size_t idx_;
size_t expect_tail_ = 0;
};
struct ErrorInfo {
ErrorType type;
size_t index;
@ -65,6 +42,7 @@ struct CmdArgParser {
template <class T = std::string_view, class... Ts> auto Next() {
if (cur_i_ + sizeof...(Ts) >= args_.size()) {
Report(OUT_OF_BOUNDS, cur_i_);
return std::conditional_t<sizeof...(Ts) == 0, T, std::tuple<T, Ts...>>();
}
if constexpr (sizeof...(Ts) == 0) {
@ -88,8 +66,11 @@ struct CmdArgParser {
// Consume next value
template <class... Cases> auto Switch(Cases&&... cases) {
if (cur_i_ >= args_.size())
if (cur_i_ >= args_.size()) {
Report(OUT_OF_BOUNDS, cur_i_);
return typename decltype(SwitchImpl(std::string_view(),
std::forward<Cases>(cases)...))::value_type{};
}
auto idx = cur_i_++;
auto res = SwitchImpl(SafeSV(idx), std::forward<Cases>(cases)...);
@ -101,13 +82,26 @@ struct CmdArgParser {
}
// Check if the next value if equal to a specific tag. If equal, its consumed.
CheckProxy Check(std::string_view tag) {
return CheckProxy(this, tag, cur_i_);
bool Check(std::string_view tag) {
if (cur_i_ >= args_.size())
return false;
std::string_view arg = SafeSV(cur_i_);
if (!absl::EqualsIgnoreCase(arg, tag))
return false;
cur_i_++;
return true;
}
// Skip specified number of arguments
CmdArgParser& Skip(size_t n) {
cur_i_ += n;
if (cur_i_ + n > args_.size()) {
Report(OUT_OF_BOUNDS, cur_i_);
} else {
cur_i_ += n;
}
return *this;
}

View File

@ -78,7 +78,7 @@ TEST_F(CmdArgParserTest, Check) {
EXPECT_TRUE(parser.Check("TAG"));
EXPECT_FALSE(parser.Check("NOT_TAG_2"));
EXPECT_TRUE(parser.Check("TAG_2").ExpectTail(1));
EXPECT_TRUE(parser.Check("TAG_2"));
}
TEST_F(CmdArgParserTest, NextStatement) {
@ -97,15 +97,15 @@ TEST_F(CmdArgParserTest, NextStatement) {
TEST_F(CmdArgParserTest, CheckTailFail) {
auto parser = Make({"TAG", "11", "22", "TAG", "33"});
EXPECT_TRUE(parser.Check("TAG").ExpectTail(2));
EXPECT_TRUE(parser.Check("TAG"));
parser.Skip(2);
EXPECT_FALSE(parser.Check("TAG").ExpectTail(2));
EXPECT_TRUE(parser.Check("TAG"));
parser.Next<int, int>();
auto err = parser.Error();
EXPECT_TRUE(err);
EXPECT_EQ(err->type, CmdArgParser::SHORT_OPT_TAIL);
EXPECT_EQ(err->index, 3);
EXPECT_EQ(err->index, 4);
}
TEST_F(CmdArgParserTest, Cases) {
@ -125,7 +125,7 @@ TEST_F(CmdArgParserTest, IgnoreCase) {
EXPECT_EQ(absl::implicit_cast<string_view>(parser.Next()), "hello"sv);
EXPECT_TRUE(parser.Check("MARKER"sv).ExpectTail(1));
EXPECT_TRUE(parser.Check("MARKER"sv));
parser.Skip(1);
EXPECT_EQ(absl::implicit_cast<string_view>(parser.Next()), "world"sv);

View File

@ -1843,15 +1843,15 @@ void JsonFamily::Get(CmdArgList args, ConnectionContext* cntx) {
vector<pair<string_view, WrappedJsonPath>> paths;
while (parser.HasNext()) {
if (parser.Check("SPACE").ExpectTail(1)) {
if (parser.Check("SPACE")) {
space = parser.Next();
continue;
}
if (parser.Check("NEWLINE").ExpectTail(1)) {
if (parser.Check("NEWLINE")) {
new_line = parser.Next();
continue;
}
if (parser.Check("INDENT").ExpectTail(1)) {
if (parser.Check("INDENT")) {
indent = parser.Next();
continue;
}

View File

@ -902,18 +902,18 @@ void ListFamily::LPos(CmdArgList args, ConnectionContext* cntx) {
bool skip_count = true;
while (parser.HasNext()) {
if (parser.Check("RANK").ExpectTail(1)) {
if (parser.Check("RANK")) {
rank = parser.Next<int>();
continue;
}
if (parser.Check("COUNT").ExpectTail(1)) {
if (parser.Check("COUNT")) {
count = parser.Next<uint32_t>();
skip_count = false;
continue;
}
if (parser.Check("MAXLEN").ExpectTail(1)) {
if (parser.Check("MAXLEN")) {
max_len = parser.Next<uint32_t>();
continue;
}

View File

@ -48,48 +48,29 @@ search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
search::SchemaField::VectorParams params{};
params.use_hnsw = parser->Switch("HNSW", true, "FLAT", false);
size_t num_args = parser->Next<size_t>();
const size_t num_args = parser->Next<size_t>();
for (size_t i = 0; i * 2 < num_args; i++) {
if (parser->Check("DIM").ExpectTail(1)) {
if (parser->Check("DIM")) {
params.dim = parser->Next<size_t>();
continue;
}
if (parser->Check("DISTANCE_METRIC").ExpectTail(1)) {
} else if (parser->Check("DISTANCE_METRIC")) {
params.sim = parser->Switch("L2", search::VectorSimilarity::L2, "COSINE",
search::VectorSimilarity::COSINE);
continue;
}
if (parser->Check("INITIAL_CAP").ExpectTail(1)) {
} else if (parser->Check("INITIAL_CAP")) {
params.capacity = parser->Next<size_t>();
continue;
}
if (parser->Check("M").ExpectTail(1)) {
} else if (parser->Check("M")) {
params.hnsw_m = parser->Next<size_t>();
continue;
}
if (parser->Check("EF_CONSTRUCTION").ExpectTail(1)) {
} else if (parser->Check("EF_CONSTRUCTION")) {
params.hnsw_ef_construction = parser->Next<size_t>();
continue;
}
if (parser->Check("EF_RUNTIME").ExpectTail(1)) {
} else if (parser->Check("EF_RUNTIME")) {
parser->Next<size_t>();
LOG(WARNING) << "EF_RUNTIME not supported";
continue;
}
if (parser->Check("EPSILON").ExpectTail(1)) {
} else if (parser->Check("EPSILON")) {
parser->Next<double>();
LOG(WARNING) << "EPSILON not supported";
continue;
} else {
parser->Skip(2);
}
parser->Skip(2);
}
return params;
@ -98,7 +79,7 @@ search::SchemaField::VectorParams ParseVectorParams(CmdArgParser* parser) {
search::SchemaField::TagParams ParseTagParams(CmdArgParser* parser) {
search::SchemaField::TagParams params{};
while (parser->HasNext()) {
if (parser->Check("SEPARATOR").ExpectTail(1)) {
if (parser->Check("SEPARATOR")) {
string_view separator = parser->Next();
params.separator = separator.front();
continue;
@ -111,7 +92,6 @@ search::SchemaField::TagParams ParseTagParams(CmdArgParser* parser) {
break;
}
return params;
}
@ -136,7 +116,7 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
}
// AS [alias]
if (parser.Check("AS").ExpectTail(1))
if (parser.Check("AS"))
field_alias = parser.Next();
// Determine type
@ -223,43 +203,28 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, ConnectionC
while (parser.HasNext()) {
// [LIMIT offset total]
if (parser.Check("LIMIT").ExpectTail(2)) {
if (parser.Check("LIMIT")) {
params.limit_offset = parser.Next<size_t>();
params.limit_total = parser.Next<size_t>();
continue;
}
// RETURN {num} [{ident} AS {name}...]
if (parser.Check("RETURN").ExpectTail(1)) {
} else if (parser.Check("RETURN")) {
// RETURN {num} [{ident} AS {name}...]
size_t num_fields = parser.Next<size_t>();
params.return_fields = SearchParams::FieldReturnList{};
while (params.return_fields->size() < num_fields) {
string_view ident = parser.Next();
string_view alias = parser.Check("AS").ExpectTail(1) ? parser.Next() : ident;
string_view alias = parser.Check("AS") ? parser.Next() : ident;
params.return_fields->emplace_back(ident, alias);
}
continue;
}
// NOCONTENT
if (parser.Check("NOCONTENT")) {
} else if (parser.Check("NOCONTENT")) { // NOCONTENT
params.return_fields = SearchParams::FieldReturnList{};
continue;
}
// [PARAMS num(ignored) name(ignored) knn_vector]
if (parser.Check("PARAMS").ExpectTail(1)) {
} else if (parser.Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
params.query_params = ParseQueryParams(&parser);
continue;
}
if (parser.Check("SORTBY").ExpectTail(1)) {
} else if (parser.Check("SORTBY")) {
params.sort_option = search::SortOption{string{parser.Next()}, bool(parser.Check("DESC"))};
continue;
} else {
// Unsupported parameters are ignored for now
parser.Skip(1);
}
// Unsupported parameters are ignored for now
parser.Skip(1);
}
if (auto err = parser.Error(); err) {
@ -285,7 +250,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
while (parser.HasNext()) {
// LOAD count field [field ...]
if (parser.Check("LOAD").ExpectTail(1)) {
if (parser.Check("LOAD")) {
params.load_fields.resize(parser.Next<size_t>());
for (string_view& field : params.load_fields)
field = parser.Next();
@ -293,13 +258,13 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
}
// GROUPBY nargs property [property ...]
if (parser.Check("GROUPBY").ExpectTail(1)) {
if (parser.Check("GROUPBY")) {
vector<string_view> fields(parser.Next<size_t>());
for (string_view& field : fields)
field = parser.Next();
vector<aggregate::Reducer> reducers;
while (parser.Check("REDUCE").ExpectTail(2)) {
while (parser.Check("REDUCE")) {
parser.ToUpper(); // uppercase for func_name
auto [func_name, nargs] = parser.Next<string_view, size_t>();
auto func = aggregate::FindReducerFunc(func_name);
@ -325,7 +290,7 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
}
// SORTBY nargs
if (parser.Check("SORTBY").ExpectTail(1)) {
if (parser.Check("SORTBY")) {
parser.ExpectTag("1");
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC"));
@ -335,14 +300,14 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
}
// LIMIT
if (parser.Check("LIMIT").ExpectTail(2)) {
if (parser.Check("LIMIT")) {
auto [offset, num] = parser.Next<size_t, size_t>();
params.steps.push_back(aggregate::MakeLimitStep(offset, num));
continue;
}
// PARAMS
if (parser.Check("PARAMS").ExpectTail(1)) {
if (parser.Check("PARAMS")) {
params.params = ParseQueryParams(&parser);
continue;
}
@ -469,13 +434,13 @@ void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
while (parser.HasNext()) {
// ON HASH | JSON
if (parser.Check("ON").ExpectTail(1)) {
if (parser.Check("ON")) {
index.type = parser.Switch("HASH"sv, DocIndex::HASH, "JSON"sv, DocIndex::JSON);
continue;
}
// PREFIX count prefix [prefix ...]
if (parser.Check("PREFIX").ExpectTail(2)) {
if (parser.Check("PREFIX")) {
if (size_t num = parser.Next<size_t>(); num != 1)
return cntx->SendError("Multiple prefixes are not supported");
index.prefix = string(parser.Next());

View File

@ -645,7 +645,7 @@ optional<ReplicaOfArgs> ReplicaOfArgs::FromCmdArgs(CmdArgList args, ConnectionCo
ReplicaOfArgs replicaof_args;
CmdArgParser parser(args);
if (parser.Check("NO").ExpectTail(1)) {
if (parser.Check("NO")) {
parser.ExpectTag("ONE");
replicaof_args.port = 0;
} else {

View File

@ -1887,7 +1887,7 @@ void SetId(string_view key, string_view gname, CmdArgList args, ConnectionContex
string_view id = parser.Next();
while (parser.HasNext()) {
if (parser.Check("ENTRIESREAD").ExpectTail(1)) {
if (parser.Check("ENTRIESREAD")) {
// TODO: to support ENTRIESREAD.
return cntx->SendError(kSyntaxErr);
} else {

View File

@ -799,7 +799,7 @@ void StringFamily::Set(CmdArgList args, ConnectionContext* cntx) {
}
tie(sparams.expire_after_ms, ignore) = expiry.Calculate(now_ms, true);
} else if (parser.Check("_MCFLAGS").ExpectTail(1)) {
} else if (parser.Check("_MCFLAGS")) {
sparams.memcache_flags = parser.Next<uint32_t>();
} else {
uint16_t flag = parser.Switch( //