Arrow Compute: CallFunction and Add

Concepts

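Arrow ships runtime semantics for its types and can execute computations over them. Its callers sit on top: a Python script can call Arrow Compute directly, or a Substrait plan can be executed through Acero.

Roughly speaking, Arrow's computation stack consists of a Compute layer and an Acero layer:

  1. Compute: the unit of execution. Datum, Expression, Function, and function-level execution (Kernel, FunctionExecutor, KernelExecutor)
  2. Acero: the push-based execution model, ExecBatch, and scheduling (it may use the dataset-layer API to access data)

Datum: a general-purpose type that can wrap {scalar, array, chunked array, record batch, table}. The inputs and output of an operation usually have the shape vector<Datum> -> Datum.

Function: made up of one or more Kernels. Each Kernel carries an operation together with its input/output types. Functions are registered by name in the FunctionRegistry. During type matching a Function can pick the best-matching kernel, and it supports implicit casts that convert the arguments to the types that kernel expects.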

With that background in place, here is a trivial example:

auto i64_3 = std::make_shared<arrow::Int64Scalar>(3);
arrow::Datum incremented_datum;

ARROW_ASSIGN_OR_RAISE(incremented_datum,
arrow::compute::CallFunction("add", {int64_array_a, i64_3}));
std::shared_ptr<::arrow::Array> incremented_array = std::move(incremented_datum).make_array();
std::cout << incremented_array->ToString() << std::endl;
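
  1. The inputs are an Array and a Scalar, both of type Int64, and the output is an Int64Array
  2. It calls arrow::compute::CallFunction, which looks the function up by its name, "add"
  3. The return type is a Datum, from which the Array can be extracted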

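In fact, replacing the Int64Scalar with an Int32Scalar also works, because the system supports implicit casts internally. Even dictionary-encoded input can be converted to its plain value type. Arrow Compute specifies a set of type-compatibility rules for this: https://arrow.apache.org/docs/cpp/compute.html#common-numeric-type

Here is my reading of these rules. Arrow seems to implement them as a simple type mapping, without any more elaborate computation: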

  • Numeric
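    • Numeric types allow implicit casts
      • Within one category (signed integer / unsigned integer / floating point), the common type is max(T1, T2), i.e. the wider of the two
      • For (unsigned, signed), the common type is larger-int(max(T1, T2)), a signed integer wide enough for both. As a special case, when one of the types is 64-bit the output type is int64 (rather than, say, decimal). This means that for (int64, uint64), uint64 values greater than i64::max do not fit in the common type, so the implicit cast can fail on them
    • Arrow has Decimal128 and Decimal256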
  • Temporal: Date types (Date32, Date64), Time types (Time32, Time64), Timestamp, Duration, Interval.
  • “Binary-like”: Binary, LargeBinary, sometimes also FixedSizeBinary.
  • “String-like”: String, LargeString.
    • There are no length-limited string types such as char or varchar. Also, the String / LargeString split is really about the physical layout (32-bit vs. 64-bit offsets)
  • “List-like”: List, LargeList, sometimes also FixedSizeList.
  • “Nested”: List-likes (including FixedSizeList), Struct, Union, and related types like Map.

(What amuses me is that this list does not include dictionary or run-end-encoded (REE) types. {Dictionary, REE} allow an implicit cast to ...)
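To make the numeric rule concrete, here is a minimal sketch of how I read it, written as a toy model in plain C++. `Kind`, `NumType` and `CommonNumericSketch` are hypothetical names of my own, not Arrow's API, and the "keep the signed type if it is already wider than the unsigned one" branch is my assumption based on the linked documentation rather than something stated above.

```cpp
#include <algorithm>
#include <cassert>

// Toy stand-in for a numeric type: a category plus a bit width.
enum class Kind { kSigned, kUnsigned, kFloat };

struct NumType {
  Kind kind;
  int bits;  // 8, 16, 32 or 64 for integers; 32 or 64 for floats
};

inline bool operator==(const NumType& a, const NumType& b) {
  return a.kind == b.kind && a.bits == b.bits;
}

// Hypothetical helper (not Arrow's CommonNumeric): common type of two inputs.
inline NumType CommonNumericSketch(const NumType& a, const NumType& b) {
  assert(a.bits > 0 && b.bits > 0);
  const bool a_float = a.kind == Kind::kFloat;
  const bool b_float = b.kind == Kind::kFloat;
  if (a_float || b_float) {
    // Widest floating-point type among the inputs; integer widths are ignored.
    const int bits = std::max(a_float ? a.bits : 0, b_float ? b.bits : 0);
    return NumType{Kind::kFloat, bits};
  }
  if (a.kind == b.kind) {
    // Same signedness: the wider type wins.
    return NumType{a.kind, std::max(a.bits, b.bits)};
  }
  // Mixed signedness: the result is signed.
  const NumType& s = (a.kind == Kind::kSigned) ? a : b;
  const NumType& u = (a.kind == Kind::kSigned) ? b : a;
  // Assumption: a signed type that is already wider than the unsigned one is kept.
  if (s.bits > u.bits) return s;
  // Otherwise widen past the unsigned type, capped at 64 bits (so uint64 -> int64).
  return NumType{Kind::kSigned, std::min(2 * u.bits, 64)};
}
```

In this model (uint32, int32) maps to int64, while (uint64, int16) is capped at int64, which is exactly the case where large uint64 values no longer fit.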

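Note that grouped aggregations cannot be invoked through CallFunction.

Functions come in several kinds. Roughly:

  1. Aggregators: take a (chunked) array or a scalar and produce a single scalar output value
  2. Grouped Aggregations: similar to GROUP BY (col1, col2…). These use hash aggregation (from a quick scan, sort-merge aggregation does not seem to be supported). The operator looks roughly like this:

HashAggregate
grouper(hash)
hash_agg(agg)

  3. Scalar Functions: "scalar" does not mean that the function is F(Scalar) -> Scalar. It means the function is element-wise: each row of the inputs produces one row of the output. Examples are “add”, “multiply”, and even “abs”. There are a great many of these, but this post does not cover anything outside the framework itself.
    1. Interestingly, nullary functions such as the random number generator are also classified here.
  4. Vector Functions: the output depends on the input as a whole (it is effectively stateful), for example sort, cumulative sum, and filter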

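These pieces are organized into the following layers:

  • ExecNode (in Acero)
  • Function (the different kinds of Function)
  • Kernel (different implementations of the same Function)

Below we read through how the code for the last two is organized.

Starting from CallFunction and “add”

Functions are invoked through CallFunction, which has the following overloads: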

ARROW_EXPORT
Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
const FunctionOptions* options, ExecContext* ctx = NULLPTR);

ARROW_EXPORT
Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
ExecContext* ctx = NULLPTR);

ARROW_EXPORT
Result<Datum> CallFunction(const std::string& func_name, const ExecBatch& batch,
const FunctionOptions* options, ExecContext* ctx = NULLPTR);

ARROW_EXPORT
Result<Datum> CallFunction(const std::string& func_name, const ExecBatch& batch,
ExecContext* ctx = NULLPTR);

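Besides what was introduced earlier, a few parameters here are worth a look:

  1. FunctionOptions: a somewhat odd construct
  2. ExecBatch: mostly handed down by callers such as Acero, which prefer to execute in batches
  3. ExecContext: the context for an execution, which contains:
    1. MemoryPool
    2. CPU info (SIMD instructions and so on)
    3. CPU Executor
    4. ChunkSize (batch size)
    5. Function Registry

As a first-time user, we can read CallFunction top-down: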

Result<Datum> CallFunction(const std::string& func_name, const std::vector<Datum>& args,
const FunctionOptions* options, ExecContext* ctx) {
if (ctx == nullptr) {
ctx = default_exec_context();
}
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<const Function> func,
ctx->func_registry()->GetFunction(func_name));
return func->Execute(args, options, ctx);
}

FunctionRegistry

The implementation of FunctionRegistry resembles nested naming scopes:

class FunctionRegistry::FunctionRegistryImpl {
public:

private:
FunctionRegistryImpl* parent_;
std::mutex lock_;
std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_;
};
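
  1. The members are protected by lock_
  2. There is a parent relationship (as far as I can tell, Arrow does not use it internally; it is presumably there for users who register UDFs and the like)
  3. Functions and aliases can be added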
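The three points above can be sketched with a toy registry. This is a minimal illustration in plain C++, not Arrow's FunctionRegistry: `ToyRegistry` and `ToyFn` are hypothetical names, the "function" is just a callable on two integers, and refusing to shadow a name that already exists in a parent scope is a choice I made for the toy.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a compute function.
using ToyFn = std::function<long(long, long)>;

class ToyRegistry {
 public:
  explicit ToyRegistry(ToyRegistry* parent = nullptr) : parent_(parent) {}

  // Returns false if the name is already visible from this scope.
  bool AddFunction(const std::string& name, ToyFn fn) {
    assert(static_cast<bool>(fn));
    if (Lookup(name) != nullptr) return false;
    std::lock_guard<std::mutex> guard(lock_);
    return name_to_function_.emplace(name, std::make_shared<ToyFn>(std::move(fn))).second;
  }

  // An alias is a second name pointing at the same function object.
  bool AddAlias(const std::string& alias, const std::string& target) {
    std::shared_ptr<ToyFn> fn = Lookup(target);
    if (fn == nullptr || Lookup(alias) != nullptr) return false;
    std::lock_guard<std::mutex> guard(lock_);
    return name_to_function_.emplace(alias, std::move(fn)).second;
  }

  // Search this scope first, then fall back to the parent scope.
  std::shared_ptr<ToyFn> Lookup(const std::string& name) {
    {
      std::lock_guard<std::mutex> guard(lock_);
      auto it = name_to_function_.find(name);
      if (it != name_to_function_.end()) return it->second;
    }
    if (parent_ != nullptr) return parent_->Lookup(name);
    return nullptr;
  }

 private:
  ToyRegistry* parent_;
  std::mutex lock_;
  std::unordered_map<std::string, std::shared_ptr<ToyFn>> name_to_function_;
};
```

A child registry created with a pointer to the built-in one sees every built-in name, while names added to the child stay invisible to the parent, which is what makes the layout convenient for user-defined functions.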

So how does the registry get populated? The built-in registry is created by registering functions category by category:

static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
auto registry = FunctionRegistry::Make();

// Register core kernels
RegisterScalarCast(registry.get());
RegisterVectorHash(registry.get());
RegisterVectorSelection(registry.get());

RegisterScalarOptions(registry.get());
RegisterVectorOptions(registry.get());
RegisterAggregateOptions(registry.get());

#ifdef ARROW_COMPUTE
// Register additional kernels

// Scalar functions
RegisterScalarArithmetic(registry.get());
RegisterScalarBoolean(registry.get());
RegisterScalarComparison(registry.get());
RegisterScalarIfElse(registry.get());
RegisterScalarNested(registry.get());
RegisterScalarRandom(registry.get()); // Nullary
RegisterScalarRoundArithmetic(registry.get());
RegisterScalarSetLookup(registry.get());
RegisterScalarStringAscii(registry.get());
RegisterScalarStringUtf8(registry.get());
RegisterScalarTemporalBinary(registry.get());
RegisterScalarTemporalUnary(registry.get());
RegisterScalarValidity(registry.get());

// Vector functions
RegisterVectorArraySort(registry.get());
RegisterVectorCumulativeSum(registry.get());
RegisterVectorNested(registry.get());
RegisterVectorRank(registry.get());
RegisterVectorReplace(registry.get());
RegisterVectorSelectK(registry.get());
RegisterVectorSort(registry.get());
RegisterVectorRunEndEncode(registry.get());
RegisterVectorRunEndDecode(registry.get());

// Aggregate functions
RegisterHashAggregateBasic(registry.get());
RegisterScalarAggregateBasic(registry.get());
RegisterScalarAggregateMode(registry.get());
RegisterScalarAggregateQuantile(registry.get());
RegisterScalarAggregateTDigest(registry.get());
RegisterScalarAggregateVariance(registry.get());
#endif

return registry;
}

} // namespace internal

FunctionRegistry* GetFunctionRegistry() {
static auto g_registry = internal::CreateBuiltInRegistry();
return g_registry.get();
}

Take one example, the registration of the cast function:

  1. A MetaFunction is essentially a forwarder (loosely speaking)
  2. A FunctionOptionsType is registered, which binds the options type to its data members

static auto kCastOptionsType = GetFunctionOptionsType<CastOptions>(
arrow::internal::DataMember("to_type", &CastOptions::to_type),
arrow::internal::DataMember("allow_int_overflow", &CastOptions::allow_int_overflow),
arrow::internal::DataMember("allow_time_truncate", &CastOptions::allow_time_truncate),
arrow::internal::DataMember("allow_time_overflow", &CastOptions::allow_time_overflow),
arrow::internal::DataMember("allow_decimal_truncate",
&CastOptions::allow_decimal_truncate),
arrow::internal::DataMember("allow_float_truncate",
&CastOptions::allow_float_truncate),
arrow::internal::DataMember("allow_invalid_utf8", &CastOptions::allow_invalid_utf8));
} // namespace

void RegisterScalarCast(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::make_shared<CastMetaFunction>()));
DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType));
}

You will probably find this FunctionOptionsType confusing, and so do I. It looks like they hand-rolled a small reflection library:

template <typename Options, typename... Properties>
const FunctionOptionsType* GetFunctionOptionsType(const Properties&... properties) {
static const class OptionsType : public GenericOptionsType {
public:
explicit OptionsType(const arrow::internal::PropertyTuple<Properties...> properties)
: properties_(properties) {}

const char* type_name() const override { return Options::kTypeName; }

std::string Stringify(const FunctionOptions& options) const override {
const auto& self = checked_cast<const Options&>(options);
return StringifyImpl<Options>(self, properties_).Finish();
}
bool Compare(const FunctionOptions& options,
const FunctionOptions& other) const override {
const auto& lhs = checked_cast<const Options&>(options);
const auto& rhs = checked_cast<const Options&>(other);
return CompareImpl<Options>(lhs, rhs, properties_).equal_;
}
Status ToStructScalar(const FunctionOptions& options,
std::vector<std::string>* field_names,
std::vector<std::shared_ptr<Scalar>>* values) const override {
const auto& self = checked_cast<const Options&>(options);
RETURN_NOT_OK(
ToStructScalarImpl<Options>(self, properties_, field_names, values).status_);
return Status::OK();
}
Result<std::unique_ptr<FunctionOptions>> FromStructScalar(
const StructScalar& scalar) const override {
auto options = std::make_unique<Options>();
RETURN_NOT_OK(
FromStructScalarImpl<Options>(options.get(), scalar, properties_).status_);
return std::move(options);
}
std::unique_ptr<FunctionOptions> Copy(const FunctionOptions& options) const override {
auto out = std::make_unique<Options>();
CopyImpl<Options>(out.get(), checked_cast<const Options&>(options), properties_);
return std::move(out);
}

private:
const arrow::internal::PropertyTuple<Properties...> properties_;
} instance(arrow::internal::MakeProperties(properties...));
return &instance;
}
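
To see why a list of (name, pointer-to-member) pairs is enough to get Stringify and Compare for free, here is a minimal sketch of the same idea in plain C++17. It is a toy, not Arrow's implementation: `ToyDataMember`, `StringifyOptions`, `CompareOptions` and `ToyCastOptions` (including its `max_digits` field) are hypothetical names of my own.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <tuple>

// One "property": a field name plus a pointer to the data member.
template <typename Class, typename Value>
struct ToyDataMember {
  const char* name;
  Value Class::*ptr;
};

template <typename Class, typename Value>
ToyDataMember<Class, Value> MakeToyDataMember(const char* name, Value Class::*ptr) {
  assert(name != nullptr);
  return ToyDataMember<Class, Value>{name, ptr};
}

// Visit every property and print "name=value", separated by ", ".
template <typename Options, typename... Props>
std::string StringifyOptions(const Options& opts, const std::tuple<Props...>& props) {
  std::ostringstream out;
  bool first = true;
  std::apply(
      [&](const auto&... p) {
        ((out << (first ? "" : ", ") << p.name << "=" << opts.*(p.ptr), first = false), ...);
      },
      props);
  return out.str();
}

// Two option objects are equal if every listed member compares equal.
template <typename Options, typename... Props>
bool CompareOptions(const Options& lhs, const Options& rhs,
                    const std::tuple<Props...>& props) {
  return std::apply(
      [&](const auto&... p) { return ((lhs.*(p.ptr) == rhs.*(p.ptr)) && ...); }, props);
}

// Hypothetical options struct, loosely modelled on the CastOptions excerpt.
struct ToyCastOptions {
  bool allow_int_overflow = false;
  int max_digits = 0;  // made-up field, only here to have a second type
};

inline auto ToyCastOptionsProps() {
  return std::make_tuple(
      MakeToyDataMember("allow_int_overflow", &ToyCastOptions::allow_int_overflow),
      MakeToyDataMember("max_digits", &ToyCastOptions::max_digits));
}
```

The generic code never names a concrete field. Adding a member to the options struct only requires one more entry in the property tuple, which seems to be the point of the design in the excerpt.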

Function

Next, the Function part, using “add” as the example:

namespace {

Result<Datum> ExecuteInternal(const Function& func, std::vector<Datum> args,
int64_t passed_length, const FunctionOptions* options,
ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(auto inputs, internal::GetFunctionArgumentTypes(args));
ARROW_ASSIGN_OR_RAISE(auto func_exec, func.GetBestExecutor(inputs));
ARROW_RETURN_NOT_OK(func_exec->Init(options, ctx));
return func_exec->Execute(args, passed_length);
}

} // namespace

Result<Datum> Function::Execute(const std::vector<Datum>& args,
const FunctionOptions* options, ExecContext* ctx) const {
return ExecuteInternal(*this, args, /*passed_length=*/-1, options, ctx);
}

Result<Datum> Function::Execute(const ExecBatch& batch, const FunctionOptions* options,
ExecContext* ctx) const {
return ExecuteInternal(*this, batch.values, batch.length, options, ctx);
}

Inside Function there is a dispatch step:

  1. Create the executor, which is a KernelExecutor
  2. Find the best Kernel (this happens inside GetBestExecutor)
  3. Wrap both in a FunctionExecutor and return it

We cover Kernel first and the executors afterwards.

Result<std::shared_ptr<FunctionExecutor>> Function::GetBestExecutor(
std::vector<TypeHolder> inputs) const {
std::unique_ptr<detail::KernelExecutor> executor;
if (kind() == Function::SCALAR) {
executor = detail::KernelExecutor::MakeScalar();
} else if (kind() == Function::VECTOR) {
executor = detail::KernelExecutor::MakeVector();
} else if (kind() == Function::SCALAR_AGGREGATE) {
executor = detail::KernelExecutor::MakeScalarAggregate();
} else {
return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions");
}

ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, DispatchBest(&inputs));

return std::make_shared<detail::FunctionExecutorImpl>(std::move(inputs), kernel,
std::move(executor), *this);
}

Kernel executors and how the kernel logic is generated

First, the Kernel object. Its type is slightly involved:

/// \brief Common initializer function for all kernel types.
using KernelInit = std::function<Result<std::unique_ptr<KernelState>>(
KernelContext*, const KernelInitArgs&)>;

/// \brief Base type for kernels. Contains the function signature and
/// optionally the state initialization function, along with some common
/// attributes
struct ARROW_EXPORT Kernel {
Kernel() = default;

Kernel(std::shared_ptr<KernelSignature> sig, KernelInit init)
: signature(std::move(sig)), init(std::move(init)) {}

Kernel(std::vector<InputType> in_types, OutputType out_type, KernelInit init)
: Kernel(KernelSignature::Make(std::move(in_types), std::move(out_type)),
std::move(init)) {}

/// \brief The "signature" of the kernel containing the InputType input
/// argument validators and OutputType output type resolver.
std::shared_ptr<KernelSignature> signature;

/// \brief Create a new KernelState for invocations of this kernel, e.g. to
/// set up any options or state relevant for execution.
KernelInit init;

/// \brief Create a vector of new KernelState for invocations of this kernel.
static Status InitAll(KernelContext*, const KernelInitArgs&,
std::vector<std::unique_ptr<KernelState>>*);

/// \brief Indicates whether execution can benefit from parallelization
/// (splitting large chunks into smaller chunks and using multiple
/// threads). Some kernels may not support parallel execution at
/// all. Synchronization and concurrency-related issues are currently the
/// responsibility of the Kernel's implementation.
bool parallelizable = true;

/// \brief Indicates the level of SIMD instruction support in the host CPU is
/// required to use the function. The intention is for functions to be able to
/// contain multiple kernels with the same signature but different levels of SIMD,
/// so that the most optimized kernel supported on a host's processor can be chosen.
SimdLevel::type simd_level = SimdLevel::NONE;

// Additional kernel-specific data
std::shared_ptr<KernelState> data;
};

The example is add, a binary scalar function:

auto add = MakeArithmeticFunction<Add>("add", add_doc);
AddDecimalBinaryKernels<Add>("add", add.get());

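This invokes the corresponding KernelGenerator, which is a body of template code. I find it convenient to call but honestly a bit hard to read. To start with, there is Add and a checked variant, AddChecked. They differ in whether overflow is checked.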

struct Add {
template <typename T, typename Arg0, typename Arg1>
static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
Status*) {
return left + right;
}

template <typename T, typename Arg0, typename Arg1>
static constexpr enable_if_unsigned_integer_value<T> Call(KernelContext*, Arg0 left,
Arg1 right, Status*) {
return left + right;
}

template <typename T, typename Arg0, typename Arg1>
static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg0 left,
Arg1 right, Status*) {
return arrow::internal::SafeSignedAdd(left, right);
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return left + right;
}
};

struct AddChecked {
template <typename T, typename Arg0, typename Arg1>
static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
T result = 0;
if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) {
*st = Status::Invalid("overflow");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
Status*) {
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
return left + right;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return left + right;
}
};
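
The difference between the two functors can be sketched without Arrow. This is a minimal standalone version under my own assumptions: `WrappingAddSketch` and `CheckedAddSketch` are hypothetical names, the wrapping variant goes through the unsigned type (which is how I understand SafeSignedAdd to avoid undefined behaviour), and the checked variant uses a portable range test and returns an empty optional where AddChecked sets a Status.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <type_traits>

// Wrapping add for signed integers: add in the unsigned domain, where overflow
// is well defined, then convert back. The conversion back is modular in C++20
// and on mainstream compilers before that.
template <typename T>
T WrappingAddSketch(T left, T right) {
  static_assert(std::is_integral<T>::value && std::is_signed<T>::value,
                "signed integers only");
  using U = typename std::make_unsigned<T>::type;
  return static_cast<T>(static_cast<U>(static_cast<U>(left) + static_cast<U>(right)));
}

// Checked add: report overflow instead of wrapping.
template <typename T>
std::optional<T> CheckedAddSketch(T left, T right) {
  static_assert(std::is_integral<T>::value, "integers only");
  constexpr T kMax = std::numeric_limits<T>::max();
  constexpr T kMin = std::numeric_limits<T>::min();
  if ((right > 0 && left > kMax - right) || (right < 0 && left < kMin - right)) {
    return std::nullopt;  // the true sum is outside the range of T
  }
  return static_cast<T>(left + right);
}

// Usage: the two variants agree whenever the sum fits.
inline void AddSketchDemo() {
  assert(WrappingAddSketch<std::int32_t>(1, 2) == 3);
  assert(CheckedAddSketch<std::int32_t>(1, 2).value() == 3);
}
```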

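One thing to note: Arrow decides around the call that the output is NULL if any input is NULL, but Add still runs on the undefined values behind null slots, because adding them costs very little. Now look at MakeArithmeticFunction<Add>("add", add_doc):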

// A FunctionDoc is passed in as well
const FunctionDoc add_doc{"Add the arguments element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"add_checked\" if you want overflow\n"
"to return an error."),
{"x", "y"}};

And then the function itself:

template <typename Op, typename FunctionImpl = ArithmeticFunction>
std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name,
FunctionDoc doc) {
auto func = std::make_shared<FunctionImpl>(name, Arity::Binary(), std::move(doc));
for (const auto& ty : NumericTypes()) {
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Op>(ty);
DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
}
AddNullExec(func.get());
return func;
}

What does NumericTypes contain?

// Numeric types
Extend(g_int_types, &g_numeric_types);
Extend(g_floating_types, &g_numeric_types);

// Non-parametric, non-nested types. This also DOES NOT include
//
// * Decimal
// * Fixed Size Binary
// * Time32
// * Time64
// * Timestamp
g_primitive_types = {null(), boolean(), date32(), date64()};
Extend(g_numeric_types, &g_primitive_types);

Next, ArithmeticExecFromOp<ScalarBinaryEqualTypes, Op>(ty) produces an ArrayKernelExec. ArithmeticExecFromOp uses the ScalarBinaryEqualTypes generator to map the runtime type id onto a statically instantiated template:

using ArrayKernelExec = Status (*)(KernelContext*, const ExecSpan&, ExecResult*);

template <template <typename...> class KernelGenerator, typename Op,
typename KernelType = ArrayKernelExec, typename... Args>
KernelType ArithmeticExecFromOp(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::INT8:
return KernelGenerator<Int8Type, Int8Type, Op, Args...>::Exec;
case Type::UINT8:
return KernelGenerator<UInt8Type, UInt8Type, Op, Args...>::Exec;
case Type::INT16:
return KernelGenerator<Int16Type, Int16Type, Op, Args...>::Exec;
case Type::UINT16:
return KernelGenerator<UInt16Type, UInt16Type, Op, Args...>::Exec;
case Type::INT32:
return KernelGenerator<Int32Type, Int32Type, Op, Args...>::Exec;
case Type::UINT32:
return KernelGenerator<UInt32Type, UInt32Type, Op, Args...>::Exec;
case Type::DURATION:
case Type::INT64:
case Type::TIMESTAMP:
return KernelGenerator<Int64Type, Int64Type, Op, Args...>::Exec;
case Type::UINT64:
return KernelGenerator<UInt64Type, UInt64Type, Op, Args...>::Exec;
case Type::FLOAT:
return KernelGenerator<FloatType, FloatType, Op, Args...>::Exec;
case Type::DOUBLE:
return KernelGenerator<DoubleType, DoubleType, Op, Args...>::Exec;
default:
DCHECK(false);
return FailFunctor<KernelType>::Exec;
}
}
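
The pattern in ArithmeticExecFromOp (a runtime switch that returns a pointer to a statically instantiated template) can be reproduced in a few lines. The following is a minimal sketch with made-up names (`ToyTypeId`, `ToyExec`, `ToyBinaryKernel`, `ToyExecFromOp`). It works on raw buffers rather than Arrow's ExecSpan, and it does nothing about integer overflow.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

enum class ToyTypeId { kInt32, kInt64, kDouble };

// Type-erased entry point, playing the role of ArrayKernelExec.
using ToyExec = void (*)(const void* lhs, const void* rhs, void* out, std::size_t n);

struct ToyAdd {
  template <typename T>
  static T Call(T a, T b) {
    return a + b;
  }
};

// The "generator": one concrete Exec per element type and operation.
template <typename T, typename Op>
struct ToyBinaryKernel {
  static void Exec(const void* lhs, const void* rhs, void* out, std::size_t n) {
    assert(n == 0 || (lhs != nullptr && rhs != nullptr && out != nullptr));
    const T* l = static_cast<const T*>(lhs);
    const T* r = static_cast<const T*>(rhs);
    T* o = static_cast<T*>(out);
    for (std::size_t i = 0; i < n; ++i) {
      o[i] = Op::template Call<T>(l[i], r[i]);
    }
  }
};

// Runtime type id -> function pointer of the matching template instantiation.
template <typename Op>
ToyExec ToyExecFromOp(ToyTypeId id) {
  switch (id) {
    case ToyTypeId::kInt32:
      return ToyBinaryKernel<std::int32_t, Op>::Exec;
    case ToyTypeId::kInt64:
      return ToyBinaryKernel<std::int64_t, Op>::Exec;
    case ToyTypeId::kDouble:
      return ToyBinaryKernel<double, Op>::Exec;
  }
  return nullptr;  // unknown id
}
```

The switch runs once, when the kernel is registered. After that the hot loop is an ordinary monomorphic function, so the per-element work involves no dynamic type checks.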

Now the generator:

// A kernel exec generator for binary kernels where both input types are the
// same
template <typename OutType, typename ArgType, typename Op>
using ScalarBinaryEqualTypes = ScalarBinary<OutType, ArgType, ArgType, Op>;

Inside, it:

  1. Wraps the input and output arrays in generic iterators
  2. Calls Op::Call on each pair of values
  3. Uses a generic writer to write the output values

// A kernel exec generator for binary functions that addresses both array and
// scalar inputs and dispatches input iteration and output writing to other
// templates
//
// This template executes the operator even on the data behind null values,
// therefore it is generally only suitable for operators that are safe to apply
// even on the null slot values.
//
// The "Op" functor should have the form
//
// struct Op {
// template <typename OutValue, typename Arg0Value, typename Arg1Value>
// static OutValue Call(KernelContext* ctx, Arg0Value arg0, Arg1Value arg1, Status* st)
// {
// // implementation
// // NOTE: "status" should only populated with errors,
// // leave it unmodified to indicate Status::OK()
// }
// };
template <typename OutType, typename Arg0Type, typename Arg1Type, typename Op>
struct ScalarBinary {
using OutValue = typename GetOutputType<OutType>::T;
using Arg0Value = typename GetViewType<Arg0Type>::T;
using Arg1Value = typename GetViewType<Arg1Type>::T;

static Status ArrayArray(KernelContext* ctx, const ArraySpan& arg0,
const ArraySpan& arg1, ExecResult* out);

static Status ArrayScalar(KernelContext* ctx, const ArraySpan& arg0, const Scalar& arg1,
ExecResult* out);

static Status ScalarArray(KernelContext* ctx, const Scalar& arg0, const ArraySpan& arg1,
ExecResult* out);

static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
if (batch[0].is_array()) {
if (batch[1].is_array()) {
return ArrayArray(ctx, batch[0].array, batch[1].array, out);
} else {
return ArrayScalar(ctx, batch[0].array, *batch[1].scalar, out);
}
} else {
if (batch[1].is_array()) {
return ScalarArray(ctx, *batch[0].scalar, batch[1].array, out);
} else {
DCHECK(false);
return Status::Invalid("Should be unreachable");
}
}
}
};
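
The shape dispatch in Exec (array/array, array/scalar, scalar/array) is easy to model with std::variant. This is a minimal sketch, not Arrow's ScalarBinary: `ToyValue`, `ToyScalarBinary` and `ToySubtract` are hypothetical names, values are plain int64, and nulls are ignored. Subtraction is used so that the order of the operands is visible in the result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>

using ToyArray = std::vector<std::int64_t>;
using ToyScalar = std::int64_t;
using ToyValue = std::variant<ToyArray, ToyScalar>;

struct ToySubtract {
  static std::int64_t Call(std::int64_t a, std::int64_t b) { return a - b; }
};

template <typename Op>
struct ToyScalarBinary {
  static ToyArray ArrayArray(const ToyArray& a, const ToyArray& b) {
    assert(a.size() == b.size());
    ToyArray out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) out[i] = Op::Call(a[i], b[i]);
    return out;
  }

  // The scalar is broadcast against every element of the array.
  static ToyArray ArrayScalar(const ToyArray& a, ToyScalar b) {
    ToyArray out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) out[i] = Op::Call(a[i], b);
    return out;
  }

  static ToyArray ScalarArray(ToyScalar a, const ToyArray& b) {
    ToyArray out(b.size());
    for (std::size_t i = 0; i < b.size(); ++i) out[i] = Op::Call(a, b[i]);
    return out;
  }

  // Pick the loop that matches the shapes of the two arguments.
  static ToyArray Exec(const ToyValue& lhs, const ToyValue& rhs) {
    const ToyArray* l = std::get_if<ToyArray>(&lhs);
    const ToyArray* r = std::get_if<ToyArray>(&rhs);
    if (l != nullptr && r != nullptr) return ArrayArray(*l, *r);
    if (l != nullptr) return ArrayScalar(*l, std::get<ToyScalar>(rhs));
    if (r != nullptr) return ScalarArray(std::get<ToyScalar>(lhs), *r);
    // As in the excerpt, scalar/scalar is not expected to reach the kernel.
    assert(false && "scalar/scalar input");
    return {};
  }
};
```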

Finally, func->AddKernel(/*input types*/{ty, ty}, /*output type*/ty, /*exec callback*/exec) adds the kernel to the function.

Now back to the earlier question: what is a Kernel?

struct Kernel;

/// \brief Arguments to pass to an KernelInit function. A struct is used to help
/// avoid API breakage should the arguments passed need to be expanded.
struct KernelInitArgs {
/// \brief A pointer to the kernel being initialized. The init function may
/// depend on the kernel's KernelSignature or other data contained there.
const Kernel* kernel;

/// \brief The types of the input arguments that the kernel is
/// about to be executed against.
const std::vector<TypeHolder>& inputs;

/// \brief Opaque options specific to this kernel. May be nullptr for functions
/// that do not require options.
const FunctionOptions* options;
};

/// \brief Common initializer function for all kernel types.
using KernelInit = std::function<Result<std::unique_ptr<KernelState>>(
KernelContext*, const KernelInitArgs&)>;

/// \brief Base type for kernels. Contains the function signature and
/// optionally the state initialization function, along with some common
/// attributes
struct ARROW_EXPORT Kernel {
Kernel() = default;

Kernel(std::shared_ptr<KernelSignature> sig, KernelInit init)
: signature(std::move(sig)), init(std::move(init)) {}

Kernel(std::vector<InputType> in_types, OutputType out_type, KernelInit init)
: Kernel(KernelSignature::Make(std::move(in_types), std::move(out_type)),
std::move(init)) {}

/// \brief The "signature" of the kernel containing the InputType input
/// argument validators and OutputType output type resolver.
std::shared_ptr<KernelSignature> signature;

/// \brief Create a new KernelState for invocations of this kernel, e.g. to
/// set up any options or state relevant for execution.
KernelInit init;

/// \brief Create a vector of new KernelState for invocations of this kernel.
static Status InitAll(KernelContext*, const KernelInitArgs&,
std::vector<std::unique_ptr<KernelState>>*);

bool parallelizable = true;
SimdLevel::type simd_level = SimdLevel::NONE;

// Additional kernel-specific data
std::shared_ptr<KernelState> data;
};

Kernel itself is a slightly odd set of APIs. It holds a KernelState, and callers drive it through its KernelSignature and KernelInit. Now look at ScalarKernel:

using ArrayKernelExec = Status (*)(KernelContext*, const ExecSpan&, ExecResult*);

/// \brief Kernel data structure for implementations of ScalarFunction. In
/// addition to the members found in Kernel, contains the null handling
/// and memory pre-allocation preferences.
struct ARROW_EXPORT ScalarKernel : public Kernel {
ScalarKernel() = default;

ScalarKernel(std::shared_ptr<KernelSignature> sig, ArrayKernelExec exec,
KernelInit init = NULLPTR)
: Kernel(std::move(sig), init), exec(exec) {}

ScalarKernel(std::vector<InputType> in_types, OutputType out_type, ArrayKernelExec exec,
KernelInit init = NULLPTR)
: Kernel(std::move(in_types), std::move(out_type), std::move(init)), exec(exec) {}

/// \brief Perform a single invocation of this kernel. Depending on the
/// implementation, it may only write into preallocated memory, while in some
/// cases it will allocate its own memory. Any required state is managed
/// through the KernelContext.
ArrayKernelExec exec;

/// \brief Writing execution results into larger contiguous allocations
/// requires that the kernel be able to write into sliced output ArrayData*,
/// including sliced output validity bitmaps. Some kernel implementations may
/// not be able to do this, so setting this to false disables this
/// functionality.
bool can_write_into_slices = true;

// For scalar functions preallocated data and intersecting arg validity
// bitmaps is a reasonable default
NullHandling::type null_handling = NullHandling::INTERSECTION;
MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE;
};

The key member here is exec. Now the execution side:

class ARROW_EXPORT KernelExecutor {
public:
virtual ~KernelExecutor() = default;

/// The Kernel's `init` method must be called and any KernelState set in the
/// KernelContext *before* KernelExecutor::Init is called. This is to facilitate
/// the case where init may be expensive and does not need to be called again for
/// each execution of the kernel, for example the same lookup table can be re-used
/// for all scanned batches in a dataset filter.
virtual Status Init(KernelContext*, KernelInitArgs) = 0;

// TODO(wesm): per ARROW-16819, adding ExecBatch variant so that a batch
// length can be passed in for scalar functions; will have to return and
// clean a bunch of things up
virtual Status Execute(const ExecBatch& batch, ExecListener* listener) = 0;

virtual Datum WrapResults(const std::vector<Datum>& args,
const std::vector<Datum>& outputs) = 0;

/// \brief Check the actual result type against the resolved output type
virtual Status CheckResultType(const Datum& out, const char* function_name) = 0;

static std::unique_ptr<KernelExecutor> MakeScalar();
static std::unique_ptr<KernelExecutor> MakeVector();
static std::unique_ptr<KernelExecutor> MakeScalarAggregate();
};

In the end, ScalarExecutor::Execute invokes that exec to do the actual work:

Status ExecuteNonSpans(ExecListener* listener) {
// ARROW-16756: Kernel is going to allocate some memory and so
// for the time being we pass in an empty or partially-filled
// shared_ptr<ArrayData> or shared_ptr<Scalar> to be populated
// by the kernel.
//
// We will eventually delete the Scalar output path per
// ARROW-16757.
ExecSpan input;
ExecResult output;
while (span_iterator_.Next(&input)) {
ARROW_ASSIGN_OR_RAISE(output.value, PrepareOutput(input.length));
DCHECK(output.is_array_data());

ArrayData* out_arr = output.array_data().get();
if (output_type_.type->id() == Type::NA) {
out_arr->null_count = out_arr->length;
} else if (kernel_->null_handling == NullHandling::INTERSECTION) {
RETURN_NOT_OK(PropagateNulls(kernel_ctx_, input, out_arr));
} else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) {
out_arr->null_count = 0;
}

RETURN_NOT_OK(kernel_->exec(kernel_ctx_, input, &output));

// Output type didn't change
DCHECK(output.is_array_data());

// Emit a result for each chunk
RETURN_NOT_OK(EmitResult(std::move(output.array_data()), listener));
}
return Status::OK();
}
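
The NullHandling::INTERSECTION branch, combined with the earlier remark that Add also runs on the values behind null slots, can be sketched as two separate passes. This is a toy model under my own assumptions: `ToyColumn`, `IntersectValidity` and `AddWithNullIntersection` are hypothetical names, and validity is a vector<bool> rather than Arrow's packed bitmap.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct ToyColumn {
  std::vector<std::int64_t> values;
  std::vector<bool> valid;  // one flag per slot
};

// Pass 1: an output slot is valid only if every input slot is valid.
inline std::vector<bool> IntersectValidity(const ToyColumn& a, const ToyColumn& b) {
  assert(a.valid.size() == b.valid.size());
  std::vector<bool> out(a.valid.size());
  for (std::size_t i = 0; i < out.size(); ++i) out[i] = a.valid[i] && b.valid[i];
  return out;
}

// Pass 2: run the operation on every slot without looking at validity.
inline ToyColumn AddWithNullIntersection(const ToyColumn& a, const ToyColumn& b) {
  assert(a.values.size() == b.values.size());
  assert(a.values.size() == a.valid.size() && b.values.size() == b.valid.size());
  ToyColumn out;
  out.valid = IntersectValidity(a, b);
  out.values.resize(a.values.size());
  for (std::size_t i = 0; i < a.values.size(); ++i) {
    // Unsigned arithmetic, so arbitrary data behind a null slot cannot
    // trigger signed-overflow undefined behaviour.
    out.values[i] = static_cast<std::int64_t>(static_cast<std::uint64_t>(a.values[i]) +
                                              static_cast<std::uint64_t>(b.values[i]));
  }
  return out;
}
```

Keeping the validity pass out of the arithmetic loop leaves the loop branch-free. The price is that values behind null slots get computed and then ignored.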

Kernel type matching

Another important piece is type matching for kernels, which uses Matches to decide whether a kernel accepts the given types. Back to ExecuteInternal:

Result<Datum> ExecuteInternal(const Function& func, std::vector<Datum> args,
int64_t passed_length, const FunctionOptions* options,
ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(auto inputs, internal::GetFunctionArgumentTypes(args));
ARROW_ASSIGN_OR_RAISE(auto func_exec, func.GetBestExecutor(inputs));
ARROW_RETURN_NOT_OK(func_exec->Init(options, ctx));
return func_exec->Execute(args, passed_length);
}

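This function:

  1. Calls GetBestExecutor, which uses Function::DispatchBest to find the Kernel and then creates a FunctionExecutor. By default the dispatch goes through DispatchExact, which requires the input types to match exactly. Depending on the function, some casts may be inserted
  2. Calls func_exec->Init to initialize whatever is needed
  3. Calls func_exec->Execute

The signature of DispatchBest is: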

/// \brief Return a best-match kernel that can execute the function given the argument
/// types, after implicit casts are applied.
///
/// \param[in,out] values Argument types. An element may be modified to
/// indicate that the returned kernel only approximately matches the input
/// value descriptors; callers are responsible for casting inputs to the type
/// required by the kernel.
virtual Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* values) const;

Again using add as the example:

struct ArithmeticFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
RETURN_NOT_OK(CheckArity(types->size()));

RETURN_NOT_OK(CheckDecimals(types));

using arrow::compute::detail::DispatchExactImpl;
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;

EnsureDictionaryDecoded(types);

// Only promote types for binary functions
if (types->size() == 2) {
ReplaceNullWithOtherType(types);
TimeUnit::type finest_unit;
if (CommonTemporalResolution(types->data(), types->size(), &finest_unit)) {
ReplaceTemporalTypes(finest_unit, types);
} else {
if (TypeHolder type = CommonNumeric(*types)) {
ReplaceTypes(type, types);
}
}
}

if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *types);
}
};
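
Stripped of decimals, dictionaries and temporal types, the control flow of DispatchBest is "try an exact match, promote the argument types, try again". Here is a minimal sketch of that flow only. `ToyType`, `ToyKernel` and `ToyArithmeticFunction` are hypothetical names, the types are limited to signed integers so the common type is simply the wider one, and a failed dispatch returns nullptr instead of a Status.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Signed integer types only, ordered by width.
enum class ToyType { kInt8 = 0, kInt16, kInt32, kInt64 };

struct ToyKernel {
  std::vector<ToyType> in_types;
  std::string label;
};

struct ToyArithmeticFunction {
  std::vector<ToyKernel> kernels;

  // Exact dispatch: the kernel signature must equal the argument types.
  const ToyKernel* DispatchExact(const std::vector<ToyType>& types) const {
    for (const ToyKernel& k : kernels) {
      if (k.in_types == types) return &k;
    }
    return nullptr;
  }

  // Best dispatch: `types` is in/out. On return it holds the types the chosen
  // kernel expects, and the caller is responsible for casting the inputs.
  const ToyKernel* DispatchBest(std::vector<ToyType>* types) const {
    assert(types != nullptr);
    if (const ToyKernel* k = DispatchExact(*types)) return k;
    if (types->size() == 2) {
      // Promote both arguments to the wider of the two types.
      const ToyType common = std::max((*types)[0], (*types)[1]);
      (*types)[0] = common;
      (*types)[1] = common;
    }
    return DispatchExact(*types);
  }
};
```

The important detail is that the argument-type vector is rewritten in place. That is how the caller learns which implicit casts it has to apply before running the kernel.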

This is where all the required context gets set up. The kernel selection itself relies on:

/// \brief An type-checking interface to permit customizable validation rules
/// for use with InputType and KernelSignature. This is for scenarios where the
/// acceptance is not an exact type instance, such as a TIMESTAMP type for a
/// specific TimeUnit, but permitting any time zone.
struct ARROW_EXPORT TypeMatcher {
virtual ~TypeMatcher() = default;

/// \brief Return true if this matcher accepts the data type.
virtual bool Matches(const DataType& type) const = 0;

/// \brief A human-interpretable string representation of what the type
/// matcher checks for, usable when printing KernelSignature or formatting
/// error messages.
virtual std::string ToString() const = 0;

/// \brief Return true if this TypeMatcher contains the same matching rule as
/// the other. Currently depends on RTTI.
virtual bool Equals(const TypeMatcher& other) const = 0;
};

OutputType likewise allows the output type to be computed from the inputs:

/// \brief Container to capture both exact and input-dependent output types.
class ARROW_EXPORT OutputType {
public:
/// \brief An enum indicating whether the value type is an invariant fixed
/// value or one that's computed by a kernel-defined resolver function.
enum ResolveKind { FIXED, COMPUTED };

/// Type resolution function. Given input types, return output type. This
/// function MAY may use the kernel state to decide the output type based on
/// the FunctionOptions.
///
/// This function SHOULD _not_ be used to check for arity, that is to be
/// performed one or more layers above.
using Resolver = Result<TypeHolder> (*)(KernelContext*, const std::vector<TypeHolder>&);

The Init logic is shown below. For add there is not much context to set up.

Status KernelInit(const FunctionOptions* options) {
RETURN_NOT_OK(CheckOptions(func, options));
if (options == NULLPTR) {
options = func.default_options();
}
if (kernel->init) {
ARROW_ASSIGN_OR_RAISE(state,
kernel->init(&kernel_ctx, {kernel, in_types, options}));
kernel_ctx.SetState(state.get());
}

RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, in_types, options}));
this->options = options;
inited = true;
return Status::OK();
}

Status Init(const FunctionOptions* options, ExecContext* exec_ctx) override {
if (exec_ctx == NULLPTR) {
exec_ctx = default_exec_context();
}
kernel_ctx = KernelContext{exec_ctx, kernel};
return KernelInit(options);
}

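Summary

This post walked through the simplest execution path, that of add. Given my current understanding and the space available, it does not go into much detail. I will expand it once I am more familiar with the runtime code.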