diff --git a/src/server/list_family.cc b/src/server/list_family.cc index 2a02efd49..e415b70dc 100644 --- a/src/server/list_family.cc +++ b/src/server/list_family.cc @@ -399,13 +399,30 @@ void ListFamily::PushGeneric(ListDir dir, bool skip_notexists, CmdArgList args, } void ListFamily::PopGeneric(ListDir dir, const CmdArgList& args, ConnectionContext* cntx) { - std::string_view key = ArgS(args, 1); + string_view key = ArgS(args, 1); + int32_t count = 1; + bool return_arr = false; + + if (args.size() > 2) { + if (args.size() > 3) + return (*cntx)->SendError(kSyntaxErr); + + string_view count_s = ArgS(args, 2); + if (!absl::SimpleAtoi(count_s, &count)) { + return (*cntx)->SendError(kInvalidIntErr); + } + + if (count < 0) { + return (*cntx)->SendError(facade::kUintErr); + } + return_arr = true; + } auto cb = [&](Transaction* t, EngineShard* shard) { - return OpPop(OpArgs{shard, t->db_index()}, key, dir); + return OpPop(OpArgs{shard, t->db_index()}, key, dir, count); }; - OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); switch (result.status()) { case OpStatus::KEY_NOTFOUND: @@ -415,7 +432,19 @@ void ListFamily::PopGeneric(ListDir dir, const CmdArgList& args, ConnectionConte default:; } - return (*cntx)->SendBulkString(result.value()); + if (return_arr) { + if (result->empty()) { + (*cntx)->SendNullArray(); + } else { + (*cntx)->StartArray(result->size()); + for (const auto& k : *result) { + (*cntx)->SendBulkString(k); + } + } + } else { + DCHECK_EQ(1u, result->size()); + (*cntx)->SendBulkString(result->front()); + } } OpResult ListFamily::OpPush(const OpArgs& op_args, std::string_view key, ListDir dir, @@ -466,7 +495,8 @@ OpResult ListFamily::OpPush(const OpArgs& op_args, std::string_view ke return quicklistCount(ql); } -OpResult ListFamily::OpPop(const OpArgs& op_args, string_view key, ListDir dir) { +OpResult ListFamily::OpPop(const OpArgs& op_args, string_view key, ListDir dir, + uint32_t count) { auto& db_slice = op_args.shard->db_slice(); OpResult it_res = db_slice.Find(op_args.db_ind, key, OBJ_LIST); if (!it_res) @@ -476,7 +506,15 @@ OpResult ListFamily::OpPop(const OpArgs& op_args, string_view key, ListD quicklist* ql = GetQL(it->second); db_slice.PreUpdate(op_args.db_ind, it); - string res = ListPop(dir, ql); + StringVec res; + if (quicklistCount(ql) < count) { + count = quicklistCount(ql); + } + res.reserve(count); + + for (unsigned i = 0; i < count; ++i) { + res.push_back(ListPop(dir, ql)); + } db_slice.PostUpdate(op_args.db_ind, it); if (quicklistCount(ql) == 0) { @@ -661,10 +699,10 @@ using CI = CommandId; void ListFamily::Register(CommandRegistry* registry) { *registry << CI{"LPUSH", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(LPush) << CI{"LPUSHX", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(LPushX) - << CI{"LPOP", CO::WRITE | CO::FAST | CO::DENYOOM, 2, 1, 1, 1}.HFUNC(LPop) + << CI{"LPOP", CO::WRITE | CO::FAST | CO::DENYOOM, -2, 1, 1, 1}.HFUNC(LPop) << CI{"RPUSH", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(RPush) << CI{"RPUSHX", CO::WRITE | CO::FAST | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(RPushX) - << CI{"RPOP", CO::WRITE | CO::FAST | CO::DENYOOM, 2, 1, 1, 1}.HFUNC(RPop) + << CI{"RPOP", CO::WRITE | CO::FAST | CO::DENYOOM, -2, 1, 1, 1}.HFUNC(RPop) << CI{"BLPOP", CO::WRITE | CO::NOSCRIPT | CO::BLOCKING, -3, 1, -2, 1}.HFUNC(BLPop) << CI{"LLEN", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(LLen) << CI{"LINDEX", CO::READONLY, 3, 1, 1, 1}.HFUNC(LIndex) diff --git a/src/server/list_family.h b/src/server/list_family.h index 7b6262838..c3e48efaa 100644 --- a/src/server/list_family.h +++ b/src/server/list_family.h @@ -40,7 +40,10 @@ class ListFamily { static OpResult OpPush(const OpArgs& op_args, std::string_view key, ListDir dir, bool skip_notexist, absl::Span vals); - static OpResult OpPop(const OpArgs& op_args, std::string_view key, ListDir dir); + + static OpResult OpPop(const OpArgs& op_args, std::string_view key, ListDir dir, + uint32_t count); + static OpResult OpLen(const OpArgs& op_args, std::string_view key); static OpResult OpIndex(const OpArgs& op_args, std::string_view key, long index);