Acero Task Scheduler

A previous post, https://blog.mwish.me/2023/08/03/Arrow-Acero-Framework/, mentioned AsyncTaskScheduler. Acero actually has two task schedulers, coexisting for historical reasons:

  1. AsyncTaskScheduler: the asynchronous task scheduler. It fires off individual tasks; each submitted task must return a `Result<Future<>>`, and a task may recursively add sub-tasks. AsyncTaskScheduler provides finish and abort callbacks. It also offers throttled task submission (Throttle, used mainly by Scan and Sink) and AsyncTaskGroup, which submits tasks as a group that completes as a whole (much like the classic Go WaitGroup pattern). A minimal usage sketch follows this list.
  2. TaskScheduler: you might guess at first glance that this is the non-asynchronous task scheduler, but unfortunately it is not. This type is actually a caller of AsyncTaskScheduler (the design does not require that, but in the acero pipeline it is). What it really does is run task groups, with the following logic:
    1. Allows registering a TaskGroup
    2. Allows starting scheduling, or requesting ScheduleMore / ExecuteMore
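
To make the first interface concrete, here is a minimal sketch of driving an AsyncTaskScheduler, pieced together from the snippets quoted later in this post; the task body and the task name are invented for illustration:

#include "arrow/util/async_util.h"

// Minimal sketch: Make() runs the initial task; the scheduler stays alive while
// tasks keep adding sub-tasks, and the returned future completes once every
// task has finished (or after the abort callback has run on error).
void RunOneTask() {
  arrow::Future<> finished = arrow::util::AsyncTaskScheduler::Make(
      [](arrow::util::AsyncTaskScheduler* scheduler) {
        scheduler->AddSimpleTask(
            []() -> arrow::Result<arrow::Future<>> {
              return arrow::Future<>::MakeFinished();  // completes immediately
            },
            "example-task");
        return arrow::Status::OK();
      },
      [](const arrow::Status& st) {
        // Abort path: the first task error lands here; external producers
        // should be told to wind down early.
      });
  finished.Wait();
}
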
The `TaskScheduler` concept is a bit abstract, so let's approach it from the pipeline's point of view. As the previous post explained, Acero uses a push-based execution model, driven at the bottom by an `AsyncTaskScheduler`; see `ExecPlanImpl::StartProducing` and each node's `StartProducing()` interface, which drive a throttled `AsyncTaskScheduler`.

So what is TaskScheduler for? It exists for one thing: parallelism inside an ExecNode. Pipeline breakers such as aggregation or hash join use TaskGroups for intra-node parallel scheduling. When such a node initializes, it may fix a partition count (e.g. the number of hash-join partitions); TaskGroup uses that count to decide the group size and schedules group by group. Finishing one group can trigger the next stage. For example: the foreground AsyncTaskScheduler drives the hash build without going through TaskScheduler, and once the build completes, the probe group is started through TaskScheduler using the partition count.
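
A hypothetical sketch of that hand-off, using the real QueryContext::RegisterTaskGroup / StartTaskGroup methods quoted below; HashJoinLikeNode, ProbePartition, OnProbeDone, and num_partitions are invented placeholders:

#include "arrow/acero/query_context.h"

// Sketch only: a pipeline-breaker node registers its probe phase as a task
// group at init time, then fans it out once the build phase has finished.
struct HashJoinLikeNode {
  arrow::acero::QueryContext* ctx;
  int probe_group_id = -1;
  int64_t num_partitions = 32;  // decided at init time

  void Init() {
    probe_group_id = ctx->RegisterTaskGroup(
        [this](size_t thread_index, int64_t task_id) {
          return ProbePartition(thread_index, task_id);  // one task per partition
        },
        [this](size_t thread_index) { return OnProbeDone(thread_index); });
  }

  // Called when the build side (driven by plain AsyncTaskScheduler tasks)
  // completes: start the probe task group, sized by the partition count.
  arrow::Status OnBuildFinished() {
    return ctx->StartTaskGroup(probe_group_id, num_partitions);
  }

  arrow::Status ProbePartition(size_t thread_index, int64_t partition) {
    return arrow::Status::OK();  // probe one partition of the hash table
  }
  arrow::Status OnProbeDone(size_t thread_index) { return arrow::Status::OK(); }
};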

A bird's-eye view of acero::QueryContext

QueryContext holds the following two members; the former is initialized in QueryContext::Init:

arrow::util::AsyncTaskScheduler* async_scheduler_ = NULLPTR;
std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();

Status QueryContext::Init(util::AsyncTaskScheduler* scheduler) {
  async_scheduler_ = scheduler;
  return Status::OK();
}

class ARROW_ACERO_EXPORT QueryContext {
 public:
  TaskScheduler* scheduler() { return task_scheduler_.get(); }
  arrow::util::AsyncTaskScheduler* async_scheduler() { return async_scheduler_; }
};

Below are the QueryContext-level task-scheduling helpers; these are fairly simple and self-explanatory:

void QueryContext::ScheduleTask(std::function<Status()> fn, std::string_view name) {
  ::arrow::internal::Executor* exec = executor();
  // Adds a task which submits fn to the executor and tracks its progress. If we're
  // already stopping then the task is ignored and fn is not executed.
  async_scheduler_->AddSimpleTask(
      [exec, fn = std::move(fn)]() mutable { return exec->Submit(std::move(fn)); },
      name);
}

void QueryContext::ScheduleTask(std::function<Status(size_t)> fn, std::string_view name) {
  std::function<Status()> indexed_fn = [this, fn]() {
    size_t thread_index = GetThreadIndex();
    return fn(thread_index);
  };
  ScheduleTask(std::move(indexed_fn), name);
}

void QueryContext::ScheduleIOTask(std::function<Status()> fn, std::string_view name) {
  async_scheduler_->AddSimpleTask(
      [this, fn]() { return io_context_.executor()->Submit(std::move(fn)); }, name);
}

int QueryContext::RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
                                    std::function<Status(size_t)> on_finished) {
  return task_scheduler_->RegisterTaskGroup(std::move(task), std::move(on_finished));
}

Status QueryContext::StartTaskGroup(int task_group_id, int64_t num_tasks) {
  return task_scheduler_->StartTaskGroup(GetThreadIndex(), task_group_id, num_tasks);
}

While we're here: an interesting type inside QueryContext is ThreadIndexer, which maps the thread ids of both IO threads and CPU threads onto indices starting from 0; acero uses these indices to maintain per-thread execution state internally.
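
The idea is simple enough to sketch in a few lines (a simplified stand-in, not Arrow's actual implementation, which also enforces a capacity limit):

#include <mutex>
#include <thread>
#include <unordered_map>

// Simplified stand-in for ThreadIndexer: the first call from any thread
// (IO or CPU) assigns it the next dense index, starting from 0.
class ThreadIndexerSketch {
 public:
  size_t operator()() {
    std::lock_guard<std::mutex> lock(mutex_);
    auto [it, inserted] =
        id_to_index_.try_emplace(std::this_thread::get_id(), id_to_index_.size());
    return it->second;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::thread::id, size_t> id_to_index_;
};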

Let me paste a chunk of code, at the risk of being accused of padding the word count: https://github.com/apache/arrow/blob/ec3d2839d102cc44d368c154cfa756eb946909ee/cpp/src/arrow/acero/exec_plan.cc#L128-L186

  1. Use arrow::util::AsyncTaskScheduler::Make to create an AsyncTaskScheduler in place, then use it to initialize the QueryContext's AsyncTaskScheduler pointer.
  2. Initialize every ExecNode. During initialization, each node gets the QueryContext and registers its TaskGroups with the TaskScheduler, e.g.: https://github.com/apache/arrow/blob/ec3d2839d102cc44d368c154cfa756eb946909ee/cpp/src/arrow/acero/hash_join_node.cc#L931-L982
  3. TaskScheduler::RegisterEnd() is called here to mark that all ExecNodes have finished registering, followed by StartScheduling. You will notice that StartScheduling does not even start driving the leaf nodes' StartProducing. What it really does is register the concrete scheduling callback, ctx->ScheduleTask (i.e. AsyncTaskScheduler::AddSimpleTask), and set the maximum number of concurrent tasks (set somewhat above the core count; it looks like an arbitrary empirical value).
  4. Drive ExecNode::StartProducing. Different kinds of ExecNode behave differently here:
    1. Nodes that only consume data from other nodes do nothing: https://github.com/apache/arrow/blob/ec3d2839d102cc44d368c154cfa756eb946909ee/cpp/src/arrow/acero/aggregate_internal.h#L196
    2. Nodes like joins may need to prepare some resources: https://github.com/apache/arrow/blob/ec3d2839d102cc44d368c154cfa756eb946909ee/cpp/src/arrow/acero/asof_join_node.cc#L1509
    3. ScanNode / SourceNode must get data production going. Scan uses plan_->query_context()->async_scheduler()->AddSimpleTask to add the scan context: https://github.com/apache/arrow/blob/ec3d2839d102cc44d368c154cfa756eb946909ee/cpp/src/arrow/acero/source_node.cc#L165 and https://github.com/apache/arrow/blob/ec3d2839d102cc44d368c154cfa756eb946909ee/cpp/src/arrow/dataset/scan_node.cc#L439-L451
Future<> scheduler_finished = arrow::util::AsyncTaskScheduler::Make(
    [this](arrow::util::AsyncTaskScheduler* async_scheduler) {
      QueryContext* ctx = query_context();
      RETURN_NOT_OK(ctx->Init(async_scheduler));

      for (auto& n : nodes_) {
        RETURN_NOT_OK(n->Init());
      }

      ctx->scheduler()->RegisterEnd();
      int num_threads = 1;
      bool sync_execution = true;
      if (auto executor = query_context()->exec_context()->executor()) {
        num_threads = executor->GetCapacity();
        sync_execution = false;
      }
      RETURN_NOT_OK(ctx->scheduler()->StartScheduling(
          0 /* thread_index */,
          [ctx](std::function<Status(size_t)> fn) -> Status {
            // TODO(weston) add names to synchronous scheduler so we can use something
            // better than sync-scheduler-task here
            ctx->ScheduleTask(std::move(fn), "sync-scheduler-task");
            return Status::OK();
          },
          /*concurrent_tasks=*/2 * num_threads, sync_execution));

      // producers precede consumers
      sorted_nodes_ = TopoSort();

      Status st = Status::OK();

      using rev_it = std::reverse_iterator<NodeVector::iterator>;
      for (rev_it it(sorted_nodes_.end()), end(sorted_nodes_.begin()); it != end;
           ++it) {
        auto node = *it;

        st = node->StartProducing();
        if (!st.ok()) {
          // Stop nodes that successfully started, in reverse order
          bool expected = false;
          if (stopped_.compare_exchange_strong(expected, true)) {
            StopProducingImpl(it.base(), sorted_nodes_.end());
          }
          return st;
        }
      }
      return st;
    },
    [this](const Status& st) {
      // If an error occurs we call StopProducing. The scheduler will already have
      // stopped scheduling new tasks at this point. However, any nodes that are
      // dealing with external tasks will need to trigger those external tasks to end
      // early.
      StopProducing();
    });
scheduler_finished.AddCallback([this](const Status& st) {
  if (st.ok()) {
    if (stopped_.load()) {
      finished_.MarkFinished(Status::Cancelled("Plan was cancelled early."));
    } else {
      finished_.MarkFinished();
    }
  } else {
    finished_.MarkFinished(st);
  }
});

The above is the whole scan startup flow from the Node / QueryContext point of view.

AsyncTaskScheduler

AsyncTaskScheduler "schedules" the execution of tasks, but it contains no thread-pool logic of its own. Its responsibilities are:

  1. Maintaining the StopToken and maybe_error_. Its internal error state is called an "abort", which can come from either scenario below. On abort, the first thread to hit the error sets maybe_error_ and then invokes abort_callback_:
    1. The StopToken signals that execution has been stopped
    2. Any task fails during execution
  2. Maintaining running_tasks_. Tasks can themselves submit more running tasks; when running_tasks_ drops to 0, everything is truly finished and downstream can be notified, which happens via the internal Future<> finished_. A sketch of this bookkeeping follows the list.
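
A simplified sketch of that bookkeeping, with assumed member names; this is not the actual Arrow implementation, which runs the accounting under a finer-grained state machine:

#include <atomic>
#include <functional>
#include <mutex>

#include "arrow/util/future.h"

// Sketch only: the first failing task records the error and fires the abort
// callback; the last task to finish completes finished_.
struct SchedulerStateSketch {
  std::mutex mutex_;
  bool aborted_ = false;
  arrow::Status maybe_error_ = arrow::Status::OK();
  std::atomic<int> running_tasks_{0};
  arrow::Future<> finished_ = arrow::Future<>::Make();
  std::function<void(const arrow::Status&)> abort_callback_;

  void OnTaskFinished(const arrow::Status& st) {
    bool should_abort = false;
    {
      std::lock_guard<std::mutex> lk(mutex_);
      if (!st.ok() && !aborted_) {
        aborted_ = true;
        maybe_error_ = st;  // the first error wins
        should_abort = true;
      }
    }
    if (should_abort) {
      abort_callback_(st);  // ask external task providers to wind down early
    }
    if (running_tasks_.fetch_sub(1) == 1) {
      finished_.MarkFinished(maybe_error_);  // nothing in flight: notify downstream
    }
  }
};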

As mentioned above, AsyncTaskScheduler just executes its Tasks directly; the actual dispatch onto an IO thread or a CPU thread is wrapped inside the QueryContext methods shown earlier.

One interesting consequence of this design is that the AsyncTaskScheduler must always have some task being scheduled while the query is running. Let's trace such a flow through code; the following comes from ScanNode:

Status StartProducing() override {
  NoteStartProducing(ToStringExtra());
  batches_throttle_ = util::ThrottledAsyncTaskScheduler::Make(
      plan_->query_context()->async_scheduler(), options_.target_bytes_readahead + 1);
  plan_->query_context()->async_scheduler()->AddSimpleTask(
      [this] {
        return GetFragments(options_.dataset.get(), options_.filter)
            .Then([this](const AsyncGenerator<std::shared_ptr<Fragment>>& frag_gen) {
              ScanFragments(frag_gen);
            });
      },
      "ScanNode::ListDataset::GetFragments"sv);
  return Status::OK();
}

void ScanFragments(const AsyncGenerator<std::shared_ptr<Fragment>>& frag_gen) {
  std::shared_ptr<util::AsyncTaskScheduler> fragment_tasks =
      util::MakeThrottledAsyncTaskGroup(
          plan_->query_context()->async_scheduler(), options_.fragment_readahead + 1,
          /*queue=*/nullptr,
          [this]() { return output_->InputFinished(this, num_batches_.load()); });
  fragment_tasks->AddAsyncGenerator<std::shared_ptr<Fragment>>(
      frag_gen,
      [this, fragment_tasks =
                 std::move(fragment_tasks)](const std::shared_ptr<Fragment>& fragment) {
        fragment_tasks->AddTask(std::make_unique<ListFragmentTask>(this, fragment));
        return Status::OK();
      },
      "ScanNode::ListDataset::Next");
}

struct ListFragmentTask : util::AsyncTaskScheduler::Task {
  ListFragmentTask(ScanNode* node, std::shared_ptr<Fragment> fragment)
      : node(node), fragment(std::move(fragment)) {
    name_ = "ScanNode::ListFragment::" + this->fragment->ToString();
  }

  Result<Future<>> operator()() override {
    return fragment
        ->InspectFragment(node->options_.format_options,
                          node->plan_->query_context()->exec_context())
        .Then([this](const std::shared_ptr<InspectedFragment>& inspected_fragment) {
          return BeginScan(inspected_fragment);
        });
  }

  /// ...
};

A quick summary of the logic above:

  1. On startup, a ThrottledAsyncTaskScheduler is set up to apply some rate limiting to the scan
  2. GetFragments is kicked off as a generator; the interface works like a listing, returning the Fragments that match the Dataset and the predicate
  3. Downstream of that is ScanFragments, which accepts an AsyncGenerator<std::shared_ptr<Fragment>> and submits a ListFragmentTask for each shared_ptr<Fragment> to produce the actual scan

Finally, once the chain of calls scans out a batch, a callback hands it over to a CPU thread, which schedules the downstream node:

Status HandleBatch(const std::shared_ptr<RecordBatch>& batch) {
  ARROW_ASSIGN_OR_RAISE(
      compute::ExecBatch evolved_batch,
      scan_->fragment_evolution->EvolveBatch(
          batch, node_->options_.columns, *scan_->scan_request.fragment_selection));
  compute::ExecBatch with_known_values = AddKnownValues(std::move(evolved_batch));
  node_->plan_->query_context()->ScheduleTask(
      [node = node_, output_batch = std::move(with_known_values)] {
        return node->output_->InputReceived(node, output_batch);
      },
      "ScanNode::ProcessMorsel");
  return Status::OK();
}

AddAsyncGenerator

This code is genuinely interesting, so it gets a section of its own:

  1. The upstream side is a generator; whatever the generator produces is submitted asynchronously to a visitor
  2. A State struct holds the overall state
  3. SubmitTask serves as the initial Task; inside its operator(), it loops to execute as many tasks as possible synchronously, preventing overly deep recursive callbacks
  4. If the generator would block, a SubmitTaskCallback is built; when invoked, the callback in turn submits another SubmitTask, continuing the execution recursively

I find the ideas in (3) and (4) quite interesting and worth borrowing. That said, the code still feels more complicated than it needs to be…

template <typename T>
bool AsyncTaskScheduler::AddAsyncGenerator(std::function<Future<T>()> generator,
                                           std::function<Status(const T&)> visitor,
                                           std::string_view name) {
  struct State {
    State(std::function<Future<T>()> generator, std::function<Status(const T&)> visitor,
          std::unique_ptr<AsyncTaskGroup> task_group, std::string_view name)
        : generator(std::move(generator)),
          visitor(std::move(visitor)),
          task_group(std::move(task_group)),
          name(name) {}
    std::function<Future<T>()> generator;
    std::function<Status(const T&)> visitor;
    std::unique_ptr<AsyncTaskGroup> task_group;
    std::string_view name;
  };
  struct SubmitTask : public Task {
    explicit SubmitTask(std::unique_ptr<State> state_holder)
        : state_holder(std::move(state_holder)) {}

    struct SubmitTaskCallback {
      SubmitTaskCallback(std::unique_ptr<State> state_holder, Future<> task_completion)
          : state_holder(std::move(state_holder)),
            task_completion(std::move(task_completion)) {}
      void operator()(const Result<T>& maybe_item) {
        if (!maybe_item.ok()) {
          task_completion.MarkFinished(maybe_item.status());
          return;
        }
        const auto& item = *maybe_item;
        if (IsIterationEnd(item)) {
          task_completion.MarkFinished();
          return;
        }
        Status visit_st = state_holder->visitor(item);
        if (!visit_st.ok()) {
          task_completion.MarkFinished(std::move(visit_st));
          return;
        }
        state_holder->task_group->AddTask(
            std::make_unique<SubmitTask>(std::move(state_holder)));
        task_completion.MarkFinished();
      }
      std::unique_ptr<State> state_holder;
      Future<> task_completion;
    };

    Result<Future<>> operator()() {
      Future<> task = Future<>::Make();
      // Consume as many items as we can (those that are already finished)
      // synchronously to avoid recursion / stack overflow.
      while (true) {
        Future<T> next = state_holder->generator();
        if (next.TryAddCallback(
                [&] { return SubmitTaskCallback(std::move(state_holder), task); })) {
          return task;
        }
        ARROW_ASSIGN_OR_RAISE(T item, next.result());
        if (IsIterationEnd(item)) {
          task.MarkFinished();
          return task;
        }
        ARROW_RETURN_NOT_OK(state_holder->visitor(item));
      }
    }

    std::string_view name() const { return state_holder->name; }

    std::unique_ptr<State> state_holder;
  };
  std::unique_ptr<AsyncTaskGroup> task_group =
      AsyncTaskGroup::Make(this, [] { return Status::OK(); });
  AsyncTaskGroup* task_group_view = task_group.get();
  std::unique_ptr<State> state_holder = std::make_unique<State>(
      std::move(generator), std::move(visitor), std::move(task_group), name);
  task_group_view->AddTask(std::make_unique<SubmitTask>(std::move(state_holder)));
  return true;
}

ThrottledAsyncTaskScheduler

ThrottledAsyncTaskScheduler wraps an existing AsyncTaskScheduler; an ExecNode typically uses it internally to rate-limit one particular kind of operation. It consists of:

  1. Queue: a queue buffering the backlog of tasks; FIFO by default, but other queue types can be defined
  2. Throttle: the rate limiter itself, which also supports Pause and Resume. Each task has a cost, 1 by default; users of ThrottledAsyncTaskScheduler can define their own. A usage sketch follows this list.
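
A usage sketch modeled on the ScanNode snippet earlier; the cost limit of 8 and the task body are arbitrary, and `parent` stands for an existing scheduler such as query_context()->async_scheduler():

#include "arrow/util/async_util.h"

void ThrottleExample(arrow::util::AsyncTaskScheduler* parent) {
  // Allow at most 8 units of task cost in flight on top of `parent`.
  std::shared_ptr<arrow::util::ThrottledAsyncTaskScheduler> throttled =
      arrow::util::ThrottledAsyncTaskScheduler::Make(parent, /*max_concurrent_cost=*/8);
  throttled->AddSimpleTask(
      []() -> arrow::Result<arrow::Future<>> {
        return arrow::Future<>::MakeFinished();  // default cost of 1
      },
      "throttled-task");
  throttled->Pause();   // stop handing queued tasks to the parent scheduler
  throttled->Resume();  // pick the queued tasks back up
}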

The semantics of Future's TryAddCallback

Let's consider a few flavors of Future semantics:

  1. folly has no shared future, only SharedPromise; its Future is likewise a single Future, produced by chaining then() and the like
  2. The Future Arrow Acero uses is more like SharedPromise + SharedFuture, and it has two fun interfaces (see the sketch after this list):
    1. AddCallback always attaches the callback; if the future is already finished, the callback runs in place
    2. TryAddCallback does not. It is effectively an optimization for recursive scenarios, keeping problems that could be handled in a loop from turning into callback piled upon callback.
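
A tiny sketch of the pattern this enables, mirroring the AddAsyncGenerator loop quoted in the previous section (assumed usage, with an invented generator type):

#include <functional>

#include "arrow/util/future.h"

// Drain a generator of Future<int>: stay in a plain loop while futures complete
// synchronously, and attach a callback only when one would actually block.
void Drain(std::function<arrow::Future<int>()> next) {
  while (true) {
    arrow::Future<int> fut = next();
    if (fut.TryAddCallback([&next] {
          return [next](const arrow::Result<int>& res) {
            if (res.ok()) Drain(next);  // resume asynchronously once it completes
          };
        })) {
      return;  // the future wasn't finished; the callback owns the continuation now
    }
    if (!fut.result().ok()) return;  // already finished: handle inline and loop again
  }
}

And here is the header documentation for the two calls:
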
/// \brief Consumer API: Register a callback to run when this future completes
///
/// The callback should receive the result of the future (const Result<T>&)
/// For a void or statusy future this should be (const Status&)
///
/// There is no guarantee to the order in which callbacks will run. In
/// particular, callbacks added while the future is being marked complete
/// may be executed immediately, ahead of, or even the same time as, other
/// callbacks that have been previously added.
///
/// WARNING: callbacks may hold arbitrary references, including cyclic references.
/// Since callbacks will only be destroyed after they are invoked, this can lead to
/// memory leaks if a Future is never marked finished (abandoned):
///
/// {
///   auto fut = Future<>::Make();
///   fut.AddCallback([fut]() {});
/// }
///
/// In this example `fut` falls out of scope but is not destroyed because it holds a
/// cyclic reference to itself through the callback.
template <typename OnComplete, typename Callback = WrapOnComplete<OnComplete>>
void AddCallback(OnComplete on_complete,
                 CallbackOptions opts = CallbackOptions::Defaults()) const {
  // We know impl_ will not be dangling when invoking callbacks because at least one
  // thread will be waiting for MarkFinished to return. Thus it's safe to keep a
  // weak reference to impl_ here
  impl_->AddCallback(Callback{std::move(on_complete)}, opts);
}

/// \brief Overload of AddCallback that will return false instead of running
/// synchronously
///
/// This overload will guarantee the callback is never run synchronously. If the future
/// is already finished then it will simply return false. This can be useful to avoid
/// stack overflow in a situation where you have recursive Futures. For an example
/// see the Loop function
///
/// Takes in a callback factory function to allow moving callbacks (the factory function
/// will only be called if the callback can successfully be added)
///
/// Returns true if a callback was actually added and false if the callback failed
/// to add because the future was marked complete.
template <typename CallbackFactory,
          typename OnComplete = detail::result_of_t<CallbackFactory()>,
          typename Callback = WrapOnComplete<OnComplete>>
bool TryAddCallback(CallbackFactory callback_factory,
                    CallbackOptions opts = CallbackOptions::Defaults()) const {
  return impl_->TryAddCallback([&]() { return Callback{callback_factory()}; }, opts);
}

With that in mind, we can look at the throttling logic:

  1. AddTask simply uses AddCallback on the backoff future, aiming to schedule more tasks later (ContinueTasks)
  2. SubmitTask has to avoid recursing into ContinueTasks, which is the interesting part; that is where TryAddCallback comes in
  3. Once a Status is not OK, there is no need to manage the remaining throttle state, because the ThrottleImpl destructor will clean up the blocked futures.
bool AddTask(std::unique_ptr<Task> task) override {
  std::unique_lock lk(mutex_);
  // If the queue isn't empty then don't even try and acquire the throttle
  // We can safely assume it is either blocked or in the middle of trying to
  // alert a queued task.
  if (!queue_->Empty()) {
    queue_->Push(std::move(task));
    return true;
  }
  int latched_cost = std::min(task->cost(), throttle_->Capacity());
  std::optional<Future<>> maybe_backoff = throttle_->TryAcquire(latched_cost);
  if (maybe_backoff) {
    queue_->Push(std::move(task));
    lk.unlock();
    maybe_backoff->AddCallback(
        [weak_self = std::weak_ptr<ThrottledAsyncTaskSchedulerImpl>(
             shared_from_this())](const Status& st) {
          if (st.ok()) {
            if (auto self = weak_self.lock()) {
              self->ContinueTasks();
            }
          }
        });
    return true;
  } else {
    lk.unlock();
    return SubmitTask(std::move(task), latched_cost, /*in_continue=*/false);
  }
}

bool SubmitTask(std::unique_ptr<Task> task, int latched_cost, bool in_continue) {
  // Wrap the task with a wrapper that runs it and then checks to see if there are any
  // queued tasks
  std::string_view name = task->name();
  return target_->AddSimpleTask(
      [latched_cost, in_continue, inner_task = std::move(task),
       self = shared_from_this()]() mutable -> Result<Future<>> {
        ARROW_ASSIGN_OR_RAISE(Future<> inner_fut, (*inner_task)());
        if (!inner_fut.TryAddCallback([&] {
              return [latched_cost, self = std::move(self)](const Status& st) -> void {
                if (st.ok()) {
                  self->throttle_->Release(latched_cost);
                  self->ContinueTasks();
                }
              };
            })) {
          // If the task is already finished then don't run ContinueTasks
          // if we are already running it so we can avoid stack overflow
          self->throttle_->Release(latched_cost);
          if (!in_continue) {
            self->ContinueTasks();
          }
        }
        return inner_fut;
      },
      name);
}

void ContinueTasks() {
  std::unique_lock lk(mutex_);
  while (!queue_->Empty()) {
    int next_cost = std::min(queue_->Peek().cost(), throttle_->Capacity());
    std::optional<Future<>> maybe_backoff = throttle_->TryAcquire(next_cost);
    if (maybe_backoff) {
      lk.unlock();
      if (!maybe_backoff->TryAddCallback([&] {
            return [self = shared_from_this()](const Status& st) {
              if (st.ok()) {
                self->ContinueTasks();
              }
            };
          })) {
        if (!maybe_backoff->status().ok()) {
          return;
        }
        lk.lock();
        continue;
      }
      return;
    } else {
      std::unique_ptr<Task> next_task = queue_->Pop();
      lk.unlock();
      if (!SubmitTask(std::move(next_task), next_cost, /*in_continue=*/true)) {
        return;
      }
      lk.lock();
    }
  }
}

AsyncTaskGroup

It provides a group plus a group-finish callback. The code is simple enough to understand at a glance:

class AsyncTaskGroupImpl : public AsyncTaskGroup {
 public:
  AsyncTaskGroupImpl(AsyncTaskScheduler* target, FnOnce<Status()> finish_cb)
      : target_(target), state_(std::make_shared<State>(std::move(finish_cb))) {}

  ~AsyncTaskGroupImpl() {
    if (--state_->task_count == 0) {
      Status st = std::move(state_->finish_cb)();
      if (!st.ok()) {
        // We can't return an invalid status from the destructor so we schedule a dummy
        // failing task
        target_->AddSimpleTask([st = std::move(st)]() { return st; },
                               "failed_task_reporter"sv);
      }
    }
  }

  bool AddTask(std::unique_ptr<Task> task) override {
    state_->task_count++;
    struct WrapperTask : public Task {
      WrapperTask(std::unique_ptr<Task> target, std::shared_ptr<State> state)
          : target(std::move(target)), state(std::move(state)) {}
      Result<Future<>> operator()() override {
        ARROW_ASSIGN_OR_RAISE(Future<> inner_fut, (*target)());
        return inner_fut.Then([state = std::move(state)]() {
          if (--state->task_count == 0) {
            return std::move(state->finish_cb)();
          }
          return Status::OK();
        });
      }
      int cost() const override { return target->cost(); }
      std::string_view name() const override { return target->name(); }
      std::unique_ptr<Task> target;
      std::shared_ptr<State> state;
    };
    return target_->AddTask(std::make_unique<WrapperTask>(std::move(task), state_));
  }

 private:
  struct State {
    explicit State(FnOnce<Status()> finish_cb)
        : task_count(1), finish_cb(std::move(finish_cb)) {}
    std::atomic<int> task_count;
    FnOnce<Status()> finish_cb;
  };
  AsyncTaskScheduler* target_;
  std::shared_ptr<State> state_;
};
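
A hypothetical usage sketch, assuming AsyncTaskGroup inherits the AddSimpleTask convenience wrapper from AsyncTaskScheduler: the finish callback runs only after the group handle has been dropped and every task added through it has finished.

#include "arrow/util/async_util.h"

arrow::Status RunGroup(arrow::util::AsyncTaskScheduler* scheduler) {
  std::unique_ptr<arrow::util::AsyncTaskGroup> group =
      arrow::util::AsyncTaskGroup::Make(scheduler, [] {
        // Everything in the group is done, e.g. forward InputFinished downstream.
        return arrow::Status::OK();
      });
  group->AddSimpleTask(
      []() -> arrow::Result<arrow::Future<>> {
        return arrow::Future<>::MakeFinished();
      },
      "group-member");
  group.reset();  // drop the creator's reference (task_count starts at 1)
  return arrow::Status::OK();
}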

TaskScheduler

TaskScheduler is a task-group scheduler written with a fair amount of lock-free code. Not every part is lock-free, but the code is still quite involved; it seems to have been made deliberately complex for performance, though I personally suspect its concurrency is not necessarily all that high. First, TaskScheduler's callbacks and interface:

  1. The group's execution function, TaskImpl, exposes (thread_id, task_id) parameters; there is also a TaskGroupContinuationImpl function, which keeps a thread_id parameter
  2. ScheduleImpl is essentially a… well, search this article for the call site; it forwards to our dear QueryContext::ScheduleTask
  3. I have read the code for you: ExecuteMore has no external callers at all, heh
  4. When external registration completes, StartScheduling is called, but nothing happens at that point. When needed, for example once the hash build finishes and the probe pipeline is invoked, the corresponding tasks get started.
// Used for asynchronous execution of operations that can be broken into
// a fixed number of symmetric tasks that can be executed concurrently.
//
// Implements priorities between multiple such operations, called task groups.
//
// Allows to specify the maximum number of in-flight tasks at any moment.
//
// Also allows for executing next pending tasks immediately using a caller thread.
//
class ARROW_ACERO_EXPORT TaskScheduler {
 public:
  using TaskImpl = std::function<Status(size_t, int64_t)>;
  using TaskGroupContinuationImpl = std::function<Status(size_t)>;
  using ScheduleImpl = std::function<Status(TaskGroupContinuationImpl)>;
  using AbortContinuationImpl = std::function<void()>;

  virtual ~TaskScheduler() = default;

  // Order in which task groups are registered represents priorities of their tasks
  // (the first group has the highest priority).
  //
  // Returns task group identifier that is used to request operations on the task group.
  virtual int RegisterTaskGroup(TaskImpl task_impl,
                                TaskGroupContinuationImpl cont_impl) = 0;
  virtual void RegisterEnd() = 0;
  // total_num_tasks may be zero, in which case task group continuation will be executed
  // immediately
  virtual Status StartTaskGroup(size_t thread_id, int group_id,
                                int64_t total_num_tasks) = 0;
  // Execute given number of tasks immediately using caller thread
  virtual Status ExecuteMore(size_t thread_id, int num_tasks_to_execute,
                             bool execute_all) = 0;
  // Begin scheduling tasks using provided callback and
  // the limit on the number of in-flight tasks at any moment.
  //
  // Scheduling will continue as long as there are waiting tasks.
  //
  // It will automatically resume whenever new task group gets started.
  virtual Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
                                 int num_concurrent_tasks, bool use_sync_execution) = 0;
  // Abort scheduling and execution.
  // Used in case of being notified about unrecoverable error for the entire query.
  virtual void Abort(AbortContinuationImpl impl) = 0;

  static std::unique_ptr<TaskScheduler> Make();
};

So what does the internal logic look like? A while back, 馒头总 walked me through this patch: https://github.com/apache/arrow/pull/45268

I'll be lazy and just paste the members directly:

// Task group state transitions progress one way.
// Seeing an old version of the state by a thread is a valid situation.
//
enum class TaskGroupState : int {
  NOT_READY,
  READY,
  ALL_TASKS_STARTED,
  ALL_TASKS_FINISHED
};

struct TaskGroup {
  TaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl)
      : task_impl_(std::move(task_impl)),
        cont_impl_(std::move(cont_impl)),
        state_(TaskGroupState::NOT_READY),
        num_tasks_present_(0) {
    num_tasks_started_.value.store(0);
    num_tasks_finished_.value.store(0);
  }
  TaskGroup(const TaskGroup& src)
      : task_impl_(src.task_impl_),
        cont_impl_(src.cont_impl_),
        state_(TaskGroupState::NOT_READY),
        num_tasks_present_(0) {
    ARROW_DCHECK(src.state_ == TaskGroupState::NOT_READY);
    num_tasks_started_.value.store(0);
    num_tasks_finished_.value.store(0);
  }
  TaskImpl task_impl_;
  TaskGroupContinuationImpl cont_impl_;

  TaskGroupState state_;
  int64_t num_tasks_present_;

  AtomicWithPadding<int64_t> num_tasks_started_;
  AtomicWithPadding<int64_t> num_tasks_finished_;
};

// Maximum number of tasks allowed to run concurrently
int num_concurrent_tasks_;
// Scheduling function; scheduled work is forwarded to the AsyncTaskScheduler
ScheduleImpl schedule_impl_;
// Callback set when abort is requested externally
AbortContinuationImpl abort_cont_impl_;

// All registered task groups
std::vector<TaskGroup> task_groups_;
bool register_finished_;
std::mutex mutex_;  // Mutex protecting task_groups_ (state_ and num_tasks_present_
                    // fields) and register_finished_ flag

AtomicWithPadding<bool> aborted_;
// At most `num_concurrent_tasks_`; the remaining available concurrency
AtomicWithPadding<int> num_tasks_to_schedule_;
// If a task group adds tasks it's possible for a thread inside
// ScheduleMore to miss this fact. This serves as a flag to
// notify the scheduling thread that it might need to make
// another pass through the scheduler
//
// i.e. whether a new StartTaskGroup arrived between passes, guarding
// against lost scheduling opportunities
AtomicWithPadding<bool> tasks_added_recently_;
};

Basically, TaskScheduler executes by looping through: fetch the number of available task slots -> find a task group with pending tasks -> schedule those tasks.
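
To close, a very rough sketch of that loop, reusing the aliases and AtomicWithPadding from the listing above; PickTask and RunTask are invented helpers standing in for the real group scan and task execution, and the real ScheduleMore does all of this mostly lock-free and far more carefully:

#include <optional>
#include <utility>

// Sketch of one scheduling pass over a pared-down version of the state above.
struct TaskSchedulerSketch {
  AtomicWithPadding<int> num_tasks_to_schedule_;
  arrow::acero::TaskScheduler::ScheduleImpl schedule_impl_;

  std::optional<std::pair<int, int64_t>> PickTask();          // next (group, task)
  arrow::Status RunTask(size_t thread_id, int g, int64_t t);  // runs the TaskImpl

  arrow::Status ScheduleMore() {
    int slots = num_tasks_to_schedule_.value.exchange(0);  // grab free task slots
    while (slots > 0) {
      std::optional<std::pair<int, int64_t>> picked = PickTask();
      if (!picked) break;  // no group has unstarted tasks right now
      int group_id = picked->first;
      int64_t task_id = picked->second;
      --slots;
      // Hand the task to schedule_impl_, which forwards to
      // QueryContext::ScheduleTask and thus to the AsyncTaskScheduler.
      ARROW_RETURN_NOT_OK(
          schedule_impl_([this, group_id, task_id](size_t thread_index) {
            return RunTask(thread_index, group_id, task_id);
          }));
    }
    // Return any unused slots so a later StartTaskGroup can consume them.
    num_tasks_to_schedule_.value.fetch_add(slots);
    return arrow::Status::OK();
  }
};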