chore: pass reply_builder explicitly to pubsub module (#4021)

Also, deprecate `reply_builder()` access method.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2024-11-04 19:20:12 +02:00 committed by GitHub
parent fb7ea6c827
commit 8d61a91200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 107 additions and 139 deletions

View File

@ -36,17 +36,10 @@ class ConnectionContext {
return protocol_;
}
SinkReplyBuilder* reply_builder() {
SinkReplyBuilder* reply_builder_old() {
return rbuilder_.get();
}
// Allows receiving the output data from the commands called from scripts.
SinkReplyBuilder* Inject(SinkReplyBuilder* new_i) {
SinkReplyBuilder* res = rbuilder_.release();
rbuilder_.reset(new_i);
return res;
}
virtual size_t UsedMemory() const;
// connection state / properties.

View File

@ -725,7 +725,7 @@ void Connection::HandleRequests() {
// down and return with an error accordingly.
if (http_res && socket_->IsOpen()) {
cc_.reset(service_->CreateContext(socket_.get(), this));
reply_builder_ = cc_->reply_builder();
reply_builder_ = cc_->reply_builder_old();
if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
@ -811,7 +811,7 @@ std::pair<std::string, std::string> Connection::GetClientInfoBeforeAfterTid() co
string_view phase_name = PHASE_NAMES[phase_];
if (cc_) {
DCHECK(cc_->reply_builder() && reply_builder_);
DCHECK(reply_builder_);
string cc_info = service_->GetContextInfo(cc_.get()).Format();
if (reply_builder_->IsSendActive())
phase_name = "send";

View File

@ -21,6 +21,17 @@ namespace dfly {
using namespace std;
using namespace facade;
static void SendSubscriptionChangedResponse(string_view action, std::optional<string_view> topic,
unsigned count, RedisReplyBuilder* rb) {
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
}
StoredCmd::StoredCmd(const CommandId* cid, ArgSlice args, facade::ReplyMode mode)
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
size_t total_size = 0;
@ -98,8 +109,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own
}
}
ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb)
ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx)
: facade::ConnectionContext(nullptr, nullptr), transaction{tx} {
if (owner) {
acl_commands = owner->acl_commands;
@ -115,8 +125,6 @@ ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction
conn_state.db_index = owner->conn_state.db_index;
conn_state.squashing_info = {owner};
}
auto* prev_reply_builder = Inject(crb);
CHECK_EQ(prev_reply_builder, nullptr);
}
void ConnectionContext::ChangeMonitor(bool start) {
@ -137,61 +145,13 @@ void ConnectionContext::ChangeMonitor(bool start) {
EnableMonitoring(start);
}
vector<unsigned> ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply,
ConnectionContext* conn) {
vector<unsigned> result(to_reply ? args.size() : 0, 0);
auto& conn_state = conn->conn_state;
if (!to_add && !conn_state.subscribe_info)
return result;
if (!conn_state.subscribe_info) {
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
conn->subscriptions++;
}
auto& sinfo = *conn->conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;
int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);
ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)};
// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : args) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);
if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}
csu.Apply();
// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(conn->subscriptions, 1u);
conn->subscriptions--;
}
return result;
}
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(false, args, to_add, to_reply, this);
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, false, to_add, to_reply);
if (to_reply) {
for (size_t i = 0; i < result.size(); ++i) {
const char* action[2] = {"unsubscribe", "subscribe"};
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action[to_add]);
rb->SendBulkString(ArgS(args, i));
@ -200,53 +160,41 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
}
void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(true, args, to_add, to_reply, this);
void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, true, to_add, to_reply);
if (to_reply) {
const char* action[2] = {"punsubscribe", "psubscribe"};
if (result.size() == 0) {
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0);
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0, rb);
}
for (size_t i = 0; i < result.size(); ++i) {
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i]);
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i], rb);
}
}
}
void ConnectionContext::UnsubscribeAll(bool to_reply) {
void ConnectionContext::UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->channels.empty())) {
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0, rb);
}
StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end());
ChangeSubscription(false, to_reply, CmdArgList{arg_vec});
ChangeSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}
void ConnectionContext::PUnsubscribeAll(bool to_reply) {
void ConnectionContext::PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->patterns.empty())) {
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0, rb);
}
StringVec patterns(conn_state.subscribe_info->patterns.begin(),
conn_state.subscribe_info->patterns.end());
CmdArgVec arg_vec(patterns.begin(), patterns.end());
ChangePSubscription(false, to_reply, CmdArgList{arg_vec});
}
void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
std::optional<string_view> topic,
unsigned count) {
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
ChangePSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}
size_t ConnectionState::ExecInfo::UsedMemory() const {
@ -269,6 +217,53 @@ size_t ConnectionContext::UsedMemory() const {
return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state);
}
vector<unsigned> ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern,
bool to_add, bool to_reply) {
vector<unsigned> result(to_reply ? channels.size() : 0, 0);
if (!to_add && !conn_state.subscribe_info)
return result;
if (!conn_state.subscribe_info) {
DCHECK(to_add);
conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
subscriptions++;
}
auto& sinfo = *conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;
int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);
ChannelStoreUpdater csu{pattern, to_add, this, uint32_t(tid)};
// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : channels) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);
if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}
csu.Apply();
// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(subscriptions, 1u);
subscriptions--;
}
return result;
}
void ConnectionState::ExecInfo::Clear() {
DCHECK(!preborrowed_interpreter); // Must have been released properly
state = EXEC_INACTIVE;

View File

@ -267,9 +267,7 @@ struct ConnectionState {
class ConnectionContext : public facade::ConnectionContext {
public:
ConnectionContext(::io::Sink* stream, facade::Connection* owner, dfly::acl::UserCredentials cred);
ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb);
ConnectionContext(const ConnectionContext* owner, Transaction* tx);
struct DebugInfo {
uint32_t shards_count = 0;
@ -292,10 +290,13 @@ class ConnectionContext : public facade::ConnectionContext {
return conn_state.db_index;
}
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args);
void UnsubscribeAll(bool to_reply);
void PUnsubscribeAll(bool to_reply);
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);
void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);
void UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void ChangeMonitor(bool start); // either start or stop monitor on a given connection
size_t UsedMemory() const override;
@ -317,8 +318,8 @@ class ConnectionContext : public facade::ConnectionContext {
monitor = enable;
}
void SendSubscriptionChangedResponse(std::string_view action,
std::optional<std::string_view> topic, unsigned count);
std::vector<unsigned> ChangeSubscriptions(CmdArgList channels, bool pattern, bool to_add,
bool to_reply);
};
} // namespace dfly

View File

@ -148,7 +148,7 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool
absl::InlinedVector<string_view, 5> args_view;
facade::CapturingReplyBuilder crb;
ConnectionContext local_cntx{cntx, stub_tx.get(), &crb};
ConnectionContext local_cntx{cntx, stub_tx.get()};
absl::InsecureBitGen gen;
for (unsigned i = 0; i < batch.sz; ++i) {
@ -175,7 +175,6 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool
sf->service().InvokeCmd(cid, args_span, &crb, &local_cntx);
}
local_cntx.Inject(nullptr);
local_tx->UnlockMulti();
}

View File

@ -38,9 +38,7 @@ template <typename... Ts> journal::ParsedEntry::CmdData BuildFromParts(Ts... par
} // namespace
JournalExecutor::JournalExecutor(Service* service)
: service_{service},
reply_builder_{facade::ReplyMode::NONE},
conn_context_{nullptr, nullptr, &reply_builder_} {
: service_{service}, reply_builder_{facade::ReplyMode::NONE}, conn_context_{nullptr, nullptr} {
conn_context_.is_replicating = true;
conn_context_.journal_emulated = true;
conn_context_.skip_acl_validation = true;
@ -48,7 +46,6 @@ JournalExecutor::JournalExecutor(Service* service)
}
JournalExecutor::~JournalExecutor() {
conn_context_.Inject(nullptr);
}
void JournalExecutor::Execute(DbIndex dbid, absl::Span<journal::ParsedEntry::CmdData> cmds) {

View File

@ -1803,13 +1803,6 @@ void Service::Unwatch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builde
return builder->SendOk();
}
template <typename F> void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) {
SinkReplyBuilder* old_rrb = nullptr;
old_rrb = cntx->Inject(crb);
f();
cntx->Inject(old_rrb);
}
optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionContext* cntx,
bool force) {
auto& info = cntx->conn_state.script_info;
@ -1825,9 +1818,7 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
tx->MultiSwitchCmd(eval_cid);
CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
WithReplies(&crb, cntx, [&] {
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), &crb, cntx, this, true, true);
});
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), &crb, cntx, this, true, true);
info->async_cmds_heap_mem = 0;
info->async_cmds.clear();
@ -1842,9 +1833,6 @@ void Service::CallFromScript(ConnectionContext* cntx, Interpreter::CallArgs& ca)
DVLOG(2) << "CallFromScript " << ca.args[0];
InterpreterReplier replier(ca.translator);
facade::SinkReplyBuilder* orig = cntx->Inject(&replier);
absl::Cleanup clean = [orig, cntx] { cntx->Inject(orig); };
optional<ErrorReply> findcmd_err;
if (ca.async) {
@ -2364,7 +2352,8 @@ void Service::Subscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil
if (cluster::IsClusterEnabled()) {
return builder->SendError("SUBSCRIBE is not supported in cluster mode yet");
}
cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args));
cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args),
static_cast<RedisReplyBuilder*>(builder));
}
void Service::Unsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
@ -2373,9 +2362,9 @@ void Service::Unsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu
return builder->SendError("UNSUBSCRIBE is not supported in cluster mode yet");
}
if (args.size() == 0) {
cntx->UnsubscribeAll(true);
cntx->UnsubscribeAll(true, static_cast<RedisReplyBuilder*>(builder));
} else {
cntx->ChangeSubscription(false, true, args);
cntx->ChangeSubscription(false, true, args, static_cast<RedisReplyBuilder*>(builder));
}
}
@ -2384,7 +2373,7 @@ void Service::PSubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* bui
if (cluster::IsClusterEnabled()) {
return builder->SendError("PSUBSCRIBE is not supported in cluster mode yet");
}
cntx->ChangePSubscription(true, true, args);
cntx->ChangePSubscription(true, true, args, static_cast<RedisReplyBuilder*>(builder));
}
void Service::PUnsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
@ -2393,9 +2382,9 @@ void Service::PUnsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* b
return builder->SendError("PUNSUBSCRIBE is not supported in cluster mode yet");
}
if (args.size() == 0) {
cntx->PUnsubscribeAll(true);
cntx->PUnsubscribeAll(true, static_cast<RedisReplyBuilder*>(builder));
} else {
cntx->ChangePSubscription(false, true, args);
cntx->ChangePSubscription(false, true, args, static_cast<RedisReplyBuilder*>(builder));
}
}
@ -2653,12 +2642,12 @@ void Service::OnConnectionClose(facade::ConnectionContext* cntx) {
if (conn_state.subscribe_info) { // Clean-ups related to PUBSUB
if (!conn_state.subscribe_info->channels.empty()) {
server_cntx->UnsubscribeAll(false);
server_cntx->UnsubscribeAll(false, nullptr);
}
if (conn_state.subscribe_info) {
DCHECK(!conn_state.subscribe_info->patterns.empty());
server_cntx->PUnsubscribeAll(false);
server_cntx->PUnsubscribeAll(false, nullptr);
}
DCHECK(!conn_state.subscribe_info);

View File

@ -143,7 +143,7 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
auto* local_tx = sinfo.local_tx.get();
facade::CapturingReplyBuilder crb;
ConnectionContext local_cntx{cntx_, local_tx, &crb};
ConnectionContext local_cntx{cntx_, local_tx};
if (cntx_->conn()) {
local_cntx.skip_acl_validation = cntx_->conn()->IsPrivileged();
}
@ -178,10 +178,6 @@ OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard
CheckConnStateClean(local_state);
}
// ConnectionContext deletes the reply builder upon destruction, so
// remove our local pointer from it.
local_cntx.Inject(nullptr);
reverse(sinfo.replies.begin(), sinfo.replies.end());
return OpStatus::OK;
}

View File

@ -2856,15 +2856,12 @@ void RdbLoader::LoadScriptFromAux(string&& body) {
void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
facade::CapturingReplyBuilder crb{};
ConnectionContext cntx{nullptr, nullptr, &crb};
ConnectionContext cntx{nullptr, nullptr};
cntx.is_replicating = true;
cntx.journal_emulated = true;
cntx.skip_acl_validation = true;
cntx.ns = &namespaces.GetDefaultNamespace();
// Avoid deleting local crb
absl::Cleanup cntx_clean = [&cntx] { cntx.Inject(nullptr); };
uint32_t consumed = 0;
facade::RespVec resp_vec;
facade::RedisParser parser;

View File

@ -390,7 +390,7 @@ RespExpr BaseFamilyTest::Run(std::string_view id, ArgSlice slice) {
DCHECK(context->transaction == nullptr) << id;
service_->DispatchCommand(CmdArgList{args}, context->reply_builder(), context);
service_->DispatchCommand(CmdArgList{args}, context->reply_builder_old(), context);
DCHECK(context->transaction == nullptr);
@ -433,7 +433,8 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, string_view key, string_view va
DCHECK(context->transaction == nullptr);
service_->DispatchMC(cmd, value, static_cast<MCReplyBuilder*>(context->reply_builder()), context);
service_->DispatchMC(cmd, value, static_cast<MCReplyBuilder*>(context->reply_builder_old()),
context);
DCHECK(context->transaction == nullptr);
@ -452,8 +453,8 @@ auto BaseFamilyTest::RunMC(MP::CmdType cmd_type, std::string_view key) -> MCResp
auto* context = conn->cmd_cntx();
service_->DispatchMC(cmd, string_view{}, static_cast<MCReplyBuilder*>(context->reply_builder()),
context);
service_->DispatchMC(cmd, string_view{},
static_cast<MCReplyBuilder*>(context->reply_builder_old()), context);
return conn->SplitLines();
}
@ -479,8 +480,8 @@ auto BaseFamilyTest::GetMC(MP::CmdType cmd_type, std::initializer_list<std::stri
auto* context = conn->cmd_cntx();
service_->DispatchMC(cmd, string_view{}, static_cast<MCReplyBuilder*>(context->reply_builder()),
context);
service_->DispatchMC(cmd, string_view{},
static_cast<MCReplyBuilder*>(context->reply_builder_old()), context);
return conn->SplitLines();
}