diff --git a/src/facade/conn_context.h b/src/facade/conn_context.h index ddbce31d5..70f82e403 100644 --- a/src/facade/conn_context.h +++ b/src/facade/conn_context.h @@ -36,17 +36,10 @@ class ConnectionContext { return protocol_; } - SinkReplyBuilder* reply_builder() { + SinkReplyBuilder* reply_builder_old() { return rbuilder_.get(); } - // Allows receiving the output data from the commands called from scripts. - SinkReplyBuilder* Inject(SinkReplyBuilder* new_i) { - SinkReplyBuilder* res = rbuilder_.release(); - rbuilder_.reset(new_i); - return res; - } - virtual size_t UsedMemory() const; // connection state / properties. diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 85929aa4c..995e4b6e0 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -725,7 +725,7 @@ void Connection::HandleRequests() { // down and return with an error accordingly. if (http_res && socket_->IsOpen()) { cc_.reset(service_->CreateContext(socket_.get(), this)); - reply_builder_ = cc_->reply_builder(); + reply_builder_ = cc_->reply_builder_old(); if (*http_res) { VLOG(1) << "HTTP1.1 identified"; @@ -811,7 +811,7 @@ std::pair Connection::GetClientInfoBeforeAfterTid() co string_view phase_name = PHASE_NAMES[phase_]; if (cc_) { - DCHECK(cc_->reply_builder() && reply_builder_); + DCHECK(reply_builder_); string cc_info = service_->GetContextInfo(cc_.get()).Format(); if (reply_builder_->IsSendActive()) phase_name = "send"; diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index dcb627828..6cdd66f1a 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -21,6 +21,17 @@ namespace dfly { using namespace std; using namespace facade; +static void SendSubscriptionChangedResponse(string_view action, std::optional topic, + unsigned count, RedisReplyBuilder* rb) { + rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH); + rb->SendBulkString(action); + if (topic.has_value()) + rb->SendBulkString(topic.value()); + else + rb->SendNull(); + rb->SendLong(count); +} + StoredCmd::StoredCmd(const CommandId* cid, ArgSlice args, facade::ReplyMode mode) : cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} { size_t total_size = 0; @@ -98,8 +109,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own } } -ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx, - facade::CapturingReplyBuilder* crb) +ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx) : facade::ConnectionContext(nullptr, nullptr), transaction{tx} { if (owner) { acl_commands = owner->acl_commands; @@ -115,8 +125,6 @@ ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction conn_state.db_index = owner->conn_state.db_index; conn_state.squashing_info = {owner}; } - auto* prev_reply_builder = Inject(crb); - CHECK_EQ(prev_reply_builder, nullptr); } void ConnectionContext::ChangeMonitor(bool start) { @@ -137,61 +145,13 @@ void ConnectionContext::ChangeMonitor(bool start) { EnableMonitoring(start); } -vector ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply, - ConnectionContext* conn) { - vector result(to_reply ? args.size() : 0, 0); - - auto& conn_state = conn->conn_state; - if (!to_add && !conn_state.subscribe_info) - return result; - - if (!conn_state.subscribe_info) { - DCHECK(to_add); - - conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo); - conn->subscriptions++; - } - - auto& sinfo = *conn->conn_state.subscribe_info.get(); - auto& local_store = pattern ? sinfo.patterns : sinfo.channels; - - int32_t tid = util::ProactorBase::me()->GetPoolIndex(); - DCHECK_GE(tid, 0); - - ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)}; - - // Gather all the channels we need to subscribe to / remove. - size_t i = 0; - for (string_view channel : args) { - if (to_add && local_store.emplace(channel).second) - csu.Record(channel); - else if (!to_add && local_store.erase(channel) > 0) - csu.Record(channel); - - if (to_reply) - result[i++] = sinfo.SubscriptionCount(); - } - - csu.Apply(); - - // Important to reset conn_state.subscribe_info only after all references to it were - // removed. - if (!to_add && conn_state.subscribe_info->IsEmpty()) { - conn_state.subscribe_info.reset(); - DCHECK_GE(conn->subscriptions, 1u); - conn->subscriptions--; - } - - return result; -} - -void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) { - vector result = ChangeSubscriptions(false, args, to_add, to_reply, this); +void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args, + facade::RedisReplyBuilder* rb) { + vector result = ChangeSubscriptions(args, false, to_add, to_reply); if (to_reply) { for (size_t i = 0; i < result.size(); ++i) { const char* action[2] = {"unsubscribe", "subscribe"}; - auto rb = static_cast(reply_builder()); rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH); rb->SendBulkString(action[to_add]); rb->SendBulkString(ArgS(args, i)); @@ -200,53 +160,41 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis } } -void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args) { - vector result = ChangeSubscriptions(true, args, to_add, to_reply, this); +void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args, + facade::RedisReplyBuilder* rb) { + vector result = ChangeSubscriptions(args, true, to_add, to_reply); if (to_reply) { const char* action[2] = {"punsubscribe", "psubscribe"}; if (result.size() == 0) { - return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0); + return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0, rb); } for (size_t i = 0; i < result.size(); ++i) { - SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i]); + SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i], rb); } } } -void ConnectionContext::UnsubscribeAll(bool to_reply) { +void ConnectionContext::UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) { if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->channels.empty())) { - return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0); + return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0, rb); } StringVec channels(conn_state.subscribe_info->channels.begin(), conn_state.subscribe_info->channels.end()); CmdArgVec arg_vec(channels.begin(), channels.end()); - ChangeSubscription(false, to_reply, CmdArgList{arg_vec}); + ChangeSubscription(false, to_reply, CmdArgList{arg_vec}, rb); } -void ConnectionContext::PUnsubscribeAll(bool to_reply) { +void ConnectionContext::PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) { if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->patterns.empty())) { - return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0); + return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0, rb); } StringVec patterns(conn_state.subscribe_info->patterns.begin(), conn_state.subscribe_info->patterns.end()); CmdArgVec arg_vec(patterns.begin(), patterns.end()); - ChangePSubscription(false, to_reply, CmdArgList{arg_vec}); -} - -void ConnectionContext::SendSubscriptionChangedResponse(string_view action, - std::optional topic, - unsigned count) { - auto rb = static_cast(reply_builder()); - rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH); - rb->SendBulkString(action); - if (topic.has_value()) - rb->SendBulkString(topic.value()); - else - rb->SendNull(); - rb->SendLong(count); + ChangePSubscription(false, to_reply, CmdArgList{arg_vec}, rb); } size_t ConnectionState::ExecInfo::UsedMemory() const { @@ -269,6 +217,53 @@ size_t ConnectionContext::UsedMemory() const { return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state); } +vector ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern, + bool to_add, bool to_reply) { + vector result(to_reply ? channels.size() : 0, 0); + + if (!to_add && !conn_state.subscribe_info) + return result; + + if (!conn_state.subscribe_info) { + DCHECK(to_add); + + conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo); + subscriptions++; + } + + auto& sinfo = *conn_state.subscribe_info.get(); + auto& local_store = pattern ? sinfo.patterns : sinfo.channels; + + int32_t tid = util::ProactorBase::me()->GetPoolIndex(); + DCHECK_GE(tid, 0); + + ChannelStoreUpdater csu{pattern, to_add, this, uint32_t(tid)}; + + // Gather all the channels we need to subscribe to / remove. + size_t i = 0; + for (string_view channel : channels) { + if (to_add && local_store.emplace(channel).second) + csu.Record(channel); + else if (!to_add && local_store.erase(channel) > 0) + csu.Record(channel); + + if (to_reply) + result[i++] = sinfo.SubscriptionCount(); + } + + csu.Apply(); + + // Important to reset conn_state.subscribe_info only after all references to it were + // removed. + if (!to_add && conn_state.subscribe_info->IsEmpty()) { + conn_state.subscribe_info.reset(); + DCHECK_GE(subscriptions, 1u); + subscriptions--; + } + + return result; +} + void ConnectionState::ExecInfo::Clear() { DCHECK(!preborrowed_interpreter); // Must have been released properly state = EXEC_INACTIVE; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index c53bbaf21..dc8c50442 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -267,9 +267,7 @@ struct ConnectionState { class ConnectionContext : public facade::ConnectionContext { public: ConnectionContext(::io::Sink* stream, facade::Connection* owner, dfly::acl::UserCredentials cred); - - ConnectionContext(const ConnectionContext* owner, Transaction* tx, - facade::CapturingReplyBuilder* crb); + ConnectionContext(const ConnectionContext* owner, Transaction* tx); struct DebugInfo { uint32_t shards_count = 0; @@ -292,10 +290,13 @@ class ConnectionContext : public facade::ConnectionContext { return conn_state.db_index; } - void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args); - void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args); - void UnsubscribeAll(bool to_reply); - void PUnsubscribeAll(bool to_reply); + void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args, + facade::RedisReplyBuilder* rb); + + void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args, + facade::RedisReplyBuilder* rb); + void UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb); + void PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb); void ChangeMonitor(bool start); // either start or stop monitor on a given connection size_t UsedMemory() const override; @@ -317,8 +318,8 @@ class ConnectionContext : public facade::ConnectionContext { monitor = enable; } - void SendSubscriptionChangedResponse(std::string_view action, - std::optional topic, unsigned count); + std::vector ChangeSubscriptions(CmdArgList channels, bool pattern, bool to_add, + bool to_reply); }; } // namespace dfly diff --git a/src/server/debugcmd.cc b/src/server/debugcmd.cc index a7b513295..3eb382011 100644 --- a/src/server/debugcmd.cc +++ b/src/server/debugcmd.cc @@ -148,7 +148,7 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool absl::InlinedVector args_view; facade::CapturingReplyBuilder crb; - ConnectionContext local_cntx{cntx, stub_tx.get(), &crb}; + ConnectionContext local_cntx{cntx, stub_tx.get()}; absl::InsecureBitGen gen; for (unsigned i = 0; i < batch.sz; ++i) { @@ -175,7 +175,6 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool sf->service().InvokeCmd(cid, args_span, &crb, &local_cntx); } - local_cntx.Inject(nullptr); local_tx->UnlockMulti(); } diff --git a/src/server/journal/executor.cc b/src/server/journal/executor.cc index 6e91e1d4c..5f11e2891 100644 --- a/src/server/journal/executor.cc +++ b/src/server/journal/executor.cc @@ -38,9 +38,7 @@ template journal::ParsedEntry::CmdData BuildFromParts(Ts... par } // namespace JournalExecutor::JournalExecutor(Service* service) - : service_{service}, - reply_builder_{facade::ReplyMode::NONE}, - conn_context_{nullptr, nullptr, &reply_builder_} { + : service_{service}, reply_builder_{facade::ReplyMode::NONE}, conn_context_{nullptr, nullptr} { conn_context_.is_replicating = true; conn_context_.journal_emulated = true; conn_context_.skip_acl_validation = true; @@ -48,7 +46,6 @@ JournalExecutor::JournalExecutor(Service* service) } JournalExecutor::~JournalExecutor() { - conn_context_.Inject(nullptr); } void JournalExecutor::Execute(DbIndex dbid, absl::Span cmds) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index df619404c..dd091d12b 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1803,13 +1803,6 @@ void Service::Unwatch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builde return builder->SendOk(); } -template void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) { - SinkReplyBuilder* old_rrb = nullptr; - old_rrb = cntx->Inject(crb); - f(); - cntx->Inject(old_rrb); -} - optional Service::FlushEvalAsyncCmds(ConnectionContext* cntx, bool force) { auto& info = cntx->conn_state.script_info; @@ -1825,9 +1818,7 @@ optional Service::FlushEvalAsyncCmds(ConnectionC tx->MultiSwitchCmd(eval_cid); CapturingReplyBuilder crb{ReplyMode::ONLY_ERR}; - WithReplies(&crb, cntx, [&] { - MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), &crb, cntx, this, true, true); - }); + MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), &crb, cntx, this, true, true); info->async_cmds_heap_mem = 0; info->async_cmds.clear(); @@ -1842,9 +1833,6 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca) DVLOG(2) << "CallFromScript " << ca.args[0]; InterpreterReplier replier(ca.translator); - facade::SinkReplyBuilder* orig = cntx->Inject(&replier); - absl::Cleanup clean = [orig, cntx] { cntx->Inject(orig); }; - optional findcmd_err; if (ca.async) { @@ -2364,7 +2352,8 @@ void Service::Subscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil if (cluster::IsClusterEnabled()) { return builder->SendError("SUBSCRIBE is not supported in cluster mode yet"); } - cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args)); + cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args), + static_cast(builder)); } void Service::Unsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, @@ -2373,9 +2362,9 @@ void Service::Unsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu return builder->SendError("UNSUBSCRIBE is not supported in cluster mode yet"); } if (args.size() == 0) { - cntx->UnsubscribeAll(true); + cntx->UnsubscribeAll(true, static_cast(builder)); } else { - cntx->ChangeSubscription(false, true, args); + cntx->ChangeSubscription(false, true, args, static_cast(builder)); } } @@ -2384,7 +2373,7 @@ void Service::PSubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* bui if (cluster::IsClusterEnabled()) { return builder->SendError("PSUBSCRIBE is not supported in cluster mode yet"); } - cntx->ChangePSubscription(true, true, args); + cntx->ChangePSubscription(true, true, args, static_cast(builder)); } void Service::PUnsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder, @@ -2393,9 +2382,9 @@ void Service::PUnsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* b return builder->SendError("PUNSUBSCRIBE is not supported in cluster mode yet"); } if (args.size() == 0) { - cntx->PUnsubscribeAll(true); + cntx->PUnsubscribeAll(true, static_cast(builder)); } else { - cntx->ChangePSubscription(false, true, args); + cntx->ChangePSubscription(false, true, args, static_cast(builder)); } } @@ -2653,12 +2642,12 @@ void Service::OnConnectionClose(facade::ConnectionContext* cntx) { if (conn_state.subscribe_info) { // Clean-ups related to PUBSUB if (!conn_state.subscribe_info->channels.empty()) { - server_cntx->UnsubscribeAll(false); + server_cntx->UnsubscribeAll(false, nullptr); } if (conn_state.subscribe_info) { DCHECK(!conn_state.subscribe_info->patterns.empty()); - server_cntx->PUnsubscribeAll(false); + server_cntx->PUnsubscribeAll(false, nullptr); } DCHECK(!conn_state.subscribe_info); diff --git a/src/server/multi_command_squasher.cc b/src/server/multi_command_squasher.cc index dfda107f3..e62f3092e 100644 --- a/src/server/multi_command_squasher.cc +++ b/src/server/multi_command_squasher.cc @@ -143,7 +143,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard auto* local_tx = sinfo.local_tx.get(); facade::CapturingReplyBuilder crb; - ConnectionContext local_cntx{cntx_, local_tx, &crb}; + ConnectionContext local_cntx{cntx_, local_tx}; if (cntx_->conn()) { local_cntx.skip_acl_validation = cntx_->conn()->IsPrivileged(); } @@ -178,10 +178,6 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard CheckConnStateClean(local_state); } - // ConnectionContext deletes the reply builder upon destruction, so - // remove our local pointer from it. - local_cntx.Inject(nullptr); - reverse(sinfo.replies.begin(), sinfo.replies.end()); return OpStatus::OK; } diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index c01944d13..fa99a64f8 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -2856,15 +2856,12 @@ void RdbLoader::LoadScriptFromAux(string&& body) { void RdbLoader::LoadSearchIndexDefFromAux(string&& def) { facade::CapturingReplyBuilder crb{}; - ConnectionContext cntx{nullptr, nullptr, &crb}; + ConnectionContext cntx{nullptr, nullptr}; cntx.is_replicating = true; cntx.journal_emulated = true; cntx.skip_acl_validation = true; cntx.ns = &namespaces.GetDefaultNamespace(); - // Avoid deleting local crb - absl::Cleanup cntx_clean = [&cntx] { cntx.Inject(nullptr); }; - uint32_t consumed = 0; facade::RespVec resp_vec; facade::RedisParser parser; diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index 955290583..d86d8404f 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -390,7 +390,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) { DCHECK(context->transaction == nullptr) << id; - service_->DispatchCommand(CmdArgList{args}, context->reply_builder(), context); + service_->DispatchCommand(CmdArgList{args}, context->reply_builder_old(), context); DCHECK(context->transaction == nullptr); @@ -433,7 +433,8 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va DCHECK(context->transaction == nullptr); - service_->DispatchMC(cmd, value, static_cast(context->reply_builder()), context); + service_->DispatchMC(cmd, value, static_cast(context->reply_builder_old()), + context); DCHECK(context->transaction == nullptr); @@ -452,8 +453,8 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp auto* context = conn->cmd_cntx(); - service_->DispatchMC(cmd, string_view{}, static_cast(context->reply_builder()), - context); + service_->DispatchMC(cmd, string_view{}, + static_cast(context->reply_builder_old()), context); return conn->SplitLines(); } @@ -479,8 +480,8 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_listcmd_cntx(); - service_->DispatchMC(cmd, string_view{}, static_cast(context->reply_builder()), - context); + service_->DispatchMC(cmd, string_view{}, + static_cast(context->reply_builder_old()), context); return conn->SplitLines(); }