Arrow Expression

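As the name suggests, an Expression is the representation of a computation in Arrow. Expressions can be built from Substrait, either as part of a Plan or as standalone expressions evaluated on a single node.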

At a Glance

In Arrow, an Expression takes one of the following forms:

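  • Call
    • Wraps a Function plus its arguments; comes in bound and unbound variants
  • Parameter (A reference to a single (potentially nested) field of the input Datum.)
    • Wraps a Field of the Arrow schema or input; also comes in bound and unbound variants
    • Can be constructed via FieldRef and friends. Users mostly just call field_ref, but underneath it is still a Parameter
  • Literal: an ordinary literal, including null. It is actually implemented by Datum (which we covered when discussing Compute)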

These forms show up in the Expression class in a rather interesting way:

/// An unbound expression which maps a single Datum to another Datum.
/// An expression is one of
/// - A literal Datum.
/// - A reference to a single (potentially nested) field of the input Datum.
/// - A call to a compute function, with arguments specified by other Expressions.
class ARROW_EXPORT Expression {
 public:
  bool is_valid() const { return impl_ != NULLPTR; }

  /// Access a Call or return nullptr if this expression is not a call
  const Call* call() const;
  /// Access a Datum or return nullptr if this expression is not a literal
  const Datum* literal() const;
  /// Access a FieldRef or return nullptr if this expression is not a field_ref
  const FieldRef* field_ref() const;

  struct Parameter {
    FieldRef ref;

    // post-bind properties
    TypeHolder type;
    ::arrow::internal::SmallVector<int, 2> indices;
  };
  const Parameter* parameter() const;

 private:
  using Impl = std::variant<Datum, Parameter, Call>;
  std::shared_ptr<Impl> impl_;
};

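The usage pattern is interesting and resembles an enum (a tagged union). Each accessor returns nullptr unless the expression holds that alternative, so callers check it like this: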

auto call = expr->call();
if (call == nullptr) return;
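To make the enum-like pattern concrete, here is a minimal, self-contained sketch that mimics it with `std::variant` and `std::get_if`. The types `LiteralValue`, `FieldParam`, `MiniCall` and `MiniExpr` are hypothetical stand-ins, not Arrow's classes. The only point is that each accessor returns nullptr unless the variant holds that alternative.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins for Datum, Expression::Parameter and Expression::Call.
struct LiteralValue { long long value; };
struct FieldParam { std::string name; };
struct MiniExpr;
struct MiniCall {
  std::string function_name;
  std::vector<MiniExpr> arguments;
};

struct MiniExpr {
  // Same shape as Arrow's std::variant<Datum, Parameter, Call> behind a shared_ptr.
  using Impl = std::variant<LiteralValue, FieldParam, MiniCall>;
  std::shared_ptr<const Impl> impl;

  // Each accessor returns nullptr unless the expression holds that alternative
  // (std::get_if also returns nullptr when impl itself is null).
  const LiteralValue* literal() const { return std::get_if<LiteralValue>(impl.get()); }
  const FieldParam* parameter() const { return std::get_if<FieldParam>(impl.get()); }
  const MiniCall* call() const { return std::get_if<MiniCall>(impl.get()); }
};

MiniExpr MakeLiteral(long long v) {
  return {std::make_shared<MiniExpr::Impl>(LiteralValue{v})};
}
MiniExpr MakeFieldRef(std::string name) {
  return {std::make_shared<MiniExpr::Impl>(FieldParam{std::move(name)})};
}
MiniExpr MakeCall(std::string fn, std::vector<MiniExpr> args) {
  return {std::make_shared<MiniExpr::Impl>(MiniCall{std::move(fn), std::move(args)})};
}

// Enum-like dispatch: try each accessor in turn.
std::string Describe(const MiniExpr& e) {
  if (const LiteralValue* lit = e.literal()) return std::to_string(lit->value);
  if (const FieldParam* param = e.parameter()) return param->name;
  const MiniCall* call = e.call();
  if (call == nullptr) return "<invalid>";  // impl is null, cf. is_valid()
  std::string out = call->function_name + "(";
  for (std::size_t i = 0; i < call->arguments.size(); ++i) {
    if (i > 0) out += ", ";
    out += Describe(call->arguments[i]);
  }
  return out + ")";
}
```

For example, `Describe(MakeCall("add", {MakeFieldRef("a"), MakeLiteral(1)}))` renders `add(a, 1)`.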

We already covered Datum, which backs literals, in an earlier post (oops).

An Expression also has its own type. In addition, every Expression provides hash and Equals methods so that expressions can be compared; we will see them used later:

std::string ToString() const;
bool Equals(const Expression& other) const;
size_t hash() const;
struct Hash {
  size_t operator()(const Expression& expr) const { return expr.hash(); }
};
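Here is a rough sketch of what structural Equals/hash buy you. It uses a hypothetical `Node` type rather than Arrow's Expression. Equal trees hash equally, so they can be deduplicated in a `std::unordered_set`. The hash-combining constant is the common boost-style one, chosen here only for illustration. Arrow additionally caches a Call's hash in `Call::hash`, while this sketch recomputes it on every call.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical simplified expression: a leaf (field or literal text) or a call.
struct Node {
  std::string label;       // function name, field name, or literal text
  std::vector<Node> args;  // empty for leaves

  // Structural equality, in the spirit of Expression::Equals.
  bool Equals(const Node& other) const {
    if (label != other.label || args.size() != other.args.size()) return false;
    for (std::size_t i = 0; i < args.size(); ++i) {
      if (!args[i].Equals(other.args[i])) return false;
    }
    return true;
  }

  // Structural hash: combine the label's hash with each argument's hash.
  std::size_t hash() const {
    std::size_t h = std::hash<std::string>{}(label);
    for (const Node& a : args) h ^= a.hash() + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }

  // Functors for unordered containers, like Expression::Hash.
  struct Hash {
    std::size_t operator()(const Node& n) const { return n.hash(); }
  };
  struct Equal {
    bool operator()(const Node& a, const Node& b) const { return a.Equals(b); }
  };
};

using ExprSet = std::unordered_set<Node, Node::Hash, Node::Equal>;

Node Leaf(std::string label) { return Node{std::move(label), {}}; }
Node CallNode(std::string fn, std::vector<Node> args) {
  return Node{std::move(fn), std::move(args)};
}
```

Inserting `add(a, 1)` twice and `add(1, a)` once leaves two entries. Argument order matters here because no canonicalization has been applied.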

There are also a few special properties worth calling out:

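  1. IsBound:
    1. Covered later; it is a bit involved.
  2. IsScalar (the method is IsScalarExpression(); note that here, values of complex types and the like also count as scalars)
    1. For a Datum: whether the Datum holds a Scalar (it can also hold an Array, Table, RecordBatch, ChunkedArray, and so on)
    2. For a FieldRef: always true, anything goes!
    3. For a Call: all arguments are scalar, and the function kind is SCALAR.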
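The IsScalar rules above can be sketched as a small recursive function. Everything here is hypothetical: `Expr`, `DatumKind`, `FunctionKind` and the `kFunctionKinds` table. In Arrow, the function kind comes from the bound Function, or from the default registry when the call is unbound. As far as I can tell from the source, unknown functions are conservatively treated as non-scalar, and the sketch does the same.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical subset of what a Datum can hold.
enum class DatumKind { kScalar, kArray };
// Hypothetical mirror of the compute function kinds.
enum class FunctionKind { kScalar, kVector, kScalarAggregate };

struct Expr {
  enum class Tag { kLiteral, kFieldRef, kCall } tag;
  DatumKind datum_kind = DatumKind::kScalar;  // literals only
  std::string name;                           // field name or function name
  std::vector<Expr> args;                     // calls only
};

// Hypothetical registry contents, for illustration only.
const std::map<std::string, FunctionKind> kFunctionKinds = {
    {"add", FunctionKind::kScalar},
    {"cumulative_sum", FunctionKind::kVector},
    {"sum", FunctionKind::kScalarAggregate},
};

bool IsScalarExpression(const Expr& e) {
  switch (e.tag) {
    case Expr::Tag::kLiteral:
      // Scalars of nested types (struct, list, ...) still count as scalar.
      return e.datum_kind == DatumKind::kScalar;
    case Expr::Tag::kFieldRef:
      return true;  // a field reference is always accepted
    case Expr::Tag::kCall: {
      for (const Expr& arg : e.args) {
        if (!IsScalarExpression(arg)) return false;
      }
      auto it = kFunctionKinds.find(e.name);
      // Unknown functions are conservatively treated as non-scalar.
      return it != kFunctionKinds.end() && it->second == FunctionKind::kScalar;
    }
  }
  return false;
}

Expr MakeLiteral(DatumKind kind) { return Expr{Expr::Tag::kLiteral, kind, "", {}}; }
Expr MakeFieldRef(std::string name) {
  return Expr{Expr::Tag::kFieldRef, DatumKind::kScalar, std::move(name), {}};
}
Expr MakeCall(std::string fn, std::vector<Expr> args) {
  return Expr{Expr::Tag::kCall, DatumKind::kScalar, std::move(fn), std::move(args)};
}
```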

FieldRef and Parameter

FieldRef is a rather odd thing. It represents a reference to a Field, and it can itself take several forms:

/// Unlike FieldPath (which exclusively uses indices of child fields), FieldRef may
/// reference a field by name. It is intended to replace parameters like `int field_index`
/// and `const std::string& field_name`; it can be implicitly constructed from either a
/// field index or a name.
///
/// Nested fields can be referenced as well. Given
///     schema({field("a", struct_({field("n", null())})), field("b", int32())})
///
/// the following all indicate the nested field named "n":
///     FieldRef ref1(0, 0);
///     FieldRef ref2("a", 0);
///     FieldRef ref3("a", "n");
///     FieldRef ref4(0, "n");
///     ARROW_ASSIGN_OR_RAISE(FieldRef ref5,
///                           FieldRef::FromDotPath(".a[0]"));

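The paths it can express are shown above, and fields can be referred to by name. Internally it would really be best to use FieldPath, but the implementation cobbles something together instead. When binding, the user passes in a Schema and the expression is bound against that Schema. After Bind, the reference is still the original "a.b.c", now with the resolved type (and indices) attached.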

Note the post-bind properties: the type is only filled in after Bind (and it is a TypeHolder, heh):

struct Parameter {
  FieldRef ref;

  // post-bind properties
  TypeHolder type;
  ::arrow::internal::SmallVector<int, 2> indices;
};
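Here is a minimal sketch of binding a by-name reference against a schema. It assumes a hypothetical `Field` tree and a hypothetical `BindByName` helper, neither of which is Arrow's API. The original names are kept, and Bind only fills in the resolved type and the child indices, matching the post-bind `type` and `indices` members. Arrow reports an error on ambiguous matches; this sketch simply takes the first field with a matching name.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical mini schema node: name, type name, and children for struct types.
struct Field {
  std::string name;
  std::string type;
  std::vector<Field> children;
};

// Mirrors Expression::Parameter: the by-name ref plus post-bind properties.
struct BoundParameter {
  std::vector<std::string> ref;  // kept exactly as the user wrote it
  std::string type;              // post-bind
  std::vector<int> indices;      // post-bind
};

// Resolve a by-name path such as {"a", "n"} against a schema.
// Returns std::nullopt if any step does not match.
std::optional<BoundParameter> BindByName(const std::vector<Field>& schema,
                                         const std::vector<std::string>& names) {
  BoundParameter out{names, "", {}};
  const std::vector<Field>* level = &schema;
  const Field* found = nullptr;
  for (const std::string& name : names) {
    found = nullptr;
    for (std::size_t i = 0; i < level->size(); ++i) {
      if ((*level)[i].name == name) {  // first match wins in this sketch
        found = &(*level)[i];
        out.indices.push_back(static_cast<int>(i));
        break;
      }
    }
    if (found == nullptr) return std::nullopt;
    level = &found->children;  // descend into the struct's children
  }
  if (found == nullptr) return std::nullopt;  // empty reference
  out.type = found->type;
  return out;
}
```

With the schema from the FieldRef doc comment, binding `{"a", "n"}` yields type `null` and indices `{0, 0}`, the same path that `FieldRef ref1(0, 0)` spells out directly.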

Call

A Call represents a function call.

A Call carries the following fields:

struct Call {
  std::string function_name;
  std::vector<Expression> arguments;
  std::shared_ptr<FunctionOptions> options;
  // Cached hash value
  size_t hash;

  // post-Bind properties:
  std::shared_ptr<Function> function;
  const Kernel* kernel = NULLPTR;
  std::shared_ptr<KernelState> kernel_state;
  TypeHolder type;

  void ComputeHash();
};

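Whether the `function` member is set tells you whether the Call has been bound. The meaning is fairly clear: Call carries most of an Expression's substance and is arguably the most important of the three forms. Operators such as + and - map to functions like add (introduced in the earlier compute module post), and the Expression layer organizes them like this: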

Call(function_name="add", arguments={field_ref(FieldRef("a", "b")), literal(1000)})

After Bind, the call is attached to the Function and to the kernel that will execute it.
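To illustrate that "bound" means the function pointer is set, here is a toy sketch with a hypothetical name-keyed registry. `Function`, `Kernel`, `Registry`, `MiniCall`, `Bind` and `Execute` are all made up for illustration, and a kernel is reduced to a binary int64 function.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a kernel: a binary function on int64 values.
using Kernel = std::function<long long(long long, long long)>;

// Hypothetical stand-in for a compute function: a name plus one kernel.
struct Function {
  std::string name;
  Kernel kernel;
};

// Hypothetical registry keyed by function name.
const std::map<std::string, std::shared_ptr<Function>>& Registry() {
  static const std::map<std::string, std::shared_ptr<Function>> registry = {
      {"add", std::make_shared<Function>(
                  Function{"add", [](long long a, long long b) { return a + b; }})},
      {"multiply", std::make_shared<Function>(
                       Function{"multiply", [](long long a, long long b) { return a * b; }})},
  };
  return registry;
}

struct MiniCall {
  std::string function_name;
  std::vector<long long> arguments;  // literal arguments only, to keep it small
  // post-Bind property: stays null until Bind succeeds.
  std::shared_ptr<Function> function;

  bool IsBound() const { return function != nullptr; }
};

// Look the function up by name and attach it to the call.
bool Bind(MiniCall* call) {
  auto it = Registry().find(call->function_name);
  if (it == Registry().end()) return false;  // unknown function: stays unbound
  call->function = it->second;
  return true;
}

// Precondition: the call is bound and has two arguments.
long long Execute(const MiniCall& call) {
  return call.function->kernel(call.arguments.at(0), call.arguments.at(1));
}
```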

Convenience builders such as and / or are also provided for users, which is quite neat:

ARROW_EXPORT Expression project(std::vector<Expression> values,
                                std::vector<std::string> names);

ARROW_EXPORT Expression equal(Expression lhs, Expression rhs);

ARROW_EXPORT Expression not_equal(Expression lhs, Expression rhs);

ARROW_EXPORT Expression less(Expression lhs, Expression rhs);

ARROW_EXPORT Expression less_equal(Expression lhs, Expression rhs);

ARROW_EXPORT Expression greater(Expression lhs, Expression rhs);

ARROW_EXPORT Expression greater_equal(Expression lhs, Expression rhs);

ARROW_EXPORT Expression is_null(Expression lhs, bool nan_is_null = false);

ARROW_EXPORT Expression is_valid(Expression lhs);

ARROW_EXPORT Expression and_(Expression lhs, Expression rhs);
ARROW_EXPORT Expression and_(const std::vector<Expression>&);
ARROW_EXPORT Expression or_(Expression lhs, Expression rhs);
ARROW_EXPORT Expression or_(const std::vector<Expression>&);
ARROW_EXPORT Expression not_(Expression operand);

The most interesting one is project, which produces a value of a new struct type.
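The vector overloads of and_ / or_ can be read as a left fold over the binary forms. The sketch below rests on two assumptions drawn from my reading of the Arrow source. First, the binary builders produce calls to `and_kleene` / `or_kleene`. Second, an empty list yields `literal(true)` / `literal(false)`. The string-backed `E` type is hypothetical and only renders the shape of the tree.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical string-backed expression, just to show the shape of the tree.
struct E {
  std::string repr;
};

E literal(bool b) { return E{b ? "true" : "false"}; }
E field_ref(const std::string& name) { return E{name}; }

// Binary forms (assumed to map to "and_kleene" / "or_kleene" calls).
E and_(const E& lhs, const E& rhs) {
  return E{"and_kleene(" + lhs.repr + ", " + rhs.repr + ")"};
}
E or_(const E& lhs, const E& rhs) {
  return E{"or_kleene(" + lhs.repr + ", " + rhs.repr + ")"};
}

// Vector forms: fold left over the binary form, starting from the identity
// element when the list is empty (true for AND, false for OR).
E and_(const std::vector<E>& operands) {
  if (operands.empty()) return literal(true);
  E folded = operands.front();
  for (std::size_t i = 1; i < operands.size(); ++i) folded = and_(folded, operands[i]);
  return folded;
}
E or_(const std::vector<E>& operands) {
  if (operands.empty()) return literal(false);
  E folded = operands.front();
  for (std::size_t i = 1; i < operands.size(); ++i) folded = or_(folded, operands[i]);
  return folded;
}
```

For three operands a, b, c, `and_` produces `and_kleene(and_kleene(a, b), c)`, a left-leaning chain of the kind that Canonicalize later flattens and reorders.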

Bind

Binding an Expression is fairly involved. Roughly, it does the following:

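  1. Bind all child expressions first, then enter BindNonRecursive
  2. Collect the types of all arguments
  3. Look up the Function and Kernel by name and argument types, trying an exact match first (`DispatchExact`)
    1. If an exact match is found, use it
    2. Otherwise, if insert_implicit_casts is enabled, try adjusting the argument types
      1. For each Literal, find the smallest type that can represent it (e.g. Datum(int32(8)) can become Datum(int8(8)))
      2. Try DispatchBest; then, for each argument whose type differs from what the kernel expects:
        1. If it is a literal, cast the literal directly to the target type
        2. If it is a field_ref or a call, wrap it in an inserted Cast
  4. Initialize the KernelContext and KernelState, and resolve the call's output type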
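The dispatch steps above can be sketched on a toy type system. The sketch rests on three assumptions: only four integer types exist, a hypothetical `add` has one kernel per type, and a simplified `DispatchBest` promotes every argument to the widest type. Arrow's real promotion rules are richer. The example shows why literals are down-cast first: `add(a: int8, 8)` then binds to the int8 kernel without casting the column.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical integer types, declared in order of increasing width.
enum class Type { kInt8, kInt16, kInt32, kInt64 };

std::string Name(Type t) {
  switch (t) {
    case Type::kInt8: return "int8";
    case Type::kInt16: return "int16";
    case Type::kInt32: return "int32";
    case Type::kInt64: return "int64";
  }
  return "?";
}

// An argument is either a field reference or an integer literal.
struct Arg {
  bool is_literal;
  long long value;   // literals only
  std::string repr;  // field name, possibly wrapped in a cast
  Type type;
};

// Hypothetical kernel table: add(T, T) exists for every T.
const std::vector<Type> kAddKernels = {Type::kInt8, Type::kInt16, Type::kInt32, Type::kInt64};

std::optional<Type> DispatchExact(const std::vector<Type>& types) {
  for (Type k : kAddKernels) {
    if (std::all_of(types.begin(), types.end(), [k](Type t) { return t == k; })) return k;
  }
  return std::nullopt;
}

// Simplified DispatchBest: up-cast to the first kernel at least as wide as every argument.
std::optional<Type> DispatchBest(const std::vector<Type>& types) {
  if (types.empty()) return std::nullopt;
  Type widest = *std::max_element(types.begin(), types.end());
  for (Type k : kAddKernels) {
    if (k >= widest) return k;
  }
  return std::nullopt;
}

Type SmallestTypeFor(long long v) {
  if (v >= -128 && v <= 127) return Type::kInt8;
  if (v >= -32768 && v <= 32767) return Type::kInt16;
  if (v >= -2147483648LL && v <= 2147483647LL) return Type::kInt32;
  return Type::kInt64;
}

// Returns a rendering of the bound call, or std::nullopt if no kernel fits.
std::optional<std::string> BindAdd(std::vector<Arg> args, bool insert_implicit_casts) {
  std::vector<Type> types;
  for (const Arg& a : args) types.push_back(a.type);

  std::optional<Type> kernel = DispatchExact(types);
  if (!kernel) {
    if (!insert_implicit_casts) return std::nullopt;
    // Down-cast literals first, because DispatchBest prefers up-casting.
    for (std::size_t i = 0; i < args.size(); ++i) {
      if (args[i].is_literal) types[i] = SmallestTypeFor(args[i].value);
    }
    kernel = DispatchBest(types);
    if (!kernel) return std::nullopt;
    for (Arg& a : args) {
      if (a.type == *kernel) continue;  // compare against the original type
      // Literals are simply re-typed; anything else gets wrapped in a cast.
      if (!a.is_literal) a.repr = "cast(" + a.repr + ", " + Name(*kernel) + ")";
      a.type = *kernel;
    }
  }

  std::string out = "add<" + Name(*kernel) + ">(";
  for (std::size_t i = 0; i < args.size(); ++i) {
    if (i > 0) out += ", ";
    out += args[i].is_literal ? std::to_string(args[i].value) + ":" + Name(args[i].type)
                              : args[i].repr;
  }
  return out + ")";
}
```

With a field `a: int8` and the int32 literal 8, the literal is shrunk to int8 and the call binds as `add<int8>(a, 8:int8)`. Adding two fields of types int8 and int32 instead yields `add<int32>(cast(a, int32), b)`.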

Optimization

A large part of the expression logic is about transforming expressions. Here is a set of related APIs:

/// Weak canonicalization which establishes guarantees for subsequent passes. Even
/// equivalent Expressions may result in different canonicalized expressions.
/// TODO this could be a strong canonicalization
ARROW_EXPORT
Result<Expression> Canonicalize(Expression, ExecContext* = NULLPTR);

/// Simplify Expressions based on literal arguments (for example, add(null, x) will always
/// be null so replace the call with a null literal). Includes early evaluation of all
/// calls whose arguments are entirely literal.
ARROW_EXPORT
Result<Expression> FoldConstants(Expression);

/// Simplify Expressions by replacing with known values of the fields which it references.
ARROW_EXPORT
Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values,
                                                Expression);

/// Simplify an expression by replacing subexpressions based on a guarantee:
/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
/// used to remove redundant function calls from a filter expression or to replace a
/// reference to a constant-value field with a literal.
ARROW_EXPORT
Result<Expression> SimplifyWithGuarantee(Expression,
                                         const Expression& guaranteed_true_predicate);

/// Replace all named field refs (e.g. "x" or "x.y") with field paths (e.g. [0] or [1,3])
///
/// This isn't usually needed and does not offer any simplification by itself. However,
/// it can be useful to normalize an expression to paths to make it simpler to work with.
ARROW_EXPORT Result<Expression> RemoveNamedRefs(Expression expression);
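To make FoldConstants concrete, here is a sketch on a hypothetical tree with only field references, integer literals and `add`; `std::nullopt` stands in for a null literal. It folds bottom-up. A call whose arguments are all literals is evaluated early, and `add(null, x)` collapses to null, mirroring the doc comment above. `Expr`, `MakeLit`, `MakeField`, `MakeAdd` and this `FoldConstants` are illustrative, not Arrow's code.

```cpp
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical expression: a literal (std::nullopt = null literal),
// a field reference, or an "add" call.
struct Expr {
  enum class Kind { kLiteral, kField, kAdd } kind;
  std::optional<long long> value;  // kLiteral only
  std::string field;               // kField only
  std::vector<Expr> args;          // kAdd only (two entries)
};

Expr MakeLit(std::optional<long long> v) { return {Expr::Kind::kLiteral, v, "", {}}; }
Expr MakeField(std::string name) { return {Expr::Kind::kField, std::nullopt, std::move(name), {}}; }
Expr MakeAdd(Expr l, Expr r) {
  return {Expr::Kind::kAdd, std::nullopt, "", {std::move(l), std::move(r)}};
}

bool IsNullLiteral(const Expr& e) { return e.kind == Expr::Kind::kLiteral && !e.value; }

Expr FoldConstants(const Expr& e) {
  if (e.kind != Expr::Kind::kAdd) return e;
  Expr l = FoldConstants(e.args[0]);  // fold children first (bottom-up)
  Expr r = FoldConstants(e.args[1]);
  // add(null, x) and add(x, null) are always null.
  if (IsNullLiteral(l) || IsNullLiteral(r)) return MakeLit(std::nullopt);
  // All arguments are literals: evaluate the call early.
  if (l.kind == Expr::Kind::kLiteral && r.kind == Expr::Kind::kLiteral) {
    return MakeLit(*l.value + *r.value);
  }
  return MakeAdd(std::move(l), std::move(r));
}

std::string ToString(const Expr& e) {
  switch (e.kind) {
    case Expr::Kind::kLiteral: return e.value ? std::to_string(*e.value) : "null";
    case Expr::Kind::kField: return e.field;
    case Expr::Kind::kAdd:
      return "add(" + ToString(e.args[0]) + ", " + ToString(e.args[1]) + ")";
  }
  return "";
}
```

For example, `add(x, add(2, 3))` folds to `add(x, 5)`, and `add(null, x)` folds to `null`.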

Tools

arrow/compute/util.h and similar places provide a set of solid utilities; the most representative one is a tree visitor:

/// Modify an Expression with pre-order and post-order visitation.
/// `pre` will be invoked on each Expression. `pre` will visit Calls before their
/// arguments, `post_call` will visit Calls (and no other Expressions) after their
/// arguments. Visitors should return the Identical expression to indicate no change; this
/// will prevent unnecessary construction in the common case where a modification is not
/// possible/necessary/...
///
/// If an argument was modified, `post_call` visits a reconstructed Call with the modified
/// arguments but also receives a pointer to the unmodified Expression as a second
/// argument. If no arguments were modified the unmodified Expression* will be nullptr.
template <typename PreVisit, typename PostVisitCall>
Result<Expression> ModifyExpression(Expression expr, const PreVisit& pre,
                                    const PostVisitCall& post_call);

It can be used to walk the whole tree and rewrite it.
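Here is a simplified re-implementation of the ModifyExpression contract. It is a sketch of the idea, not Arrow's code. A hypothetical immutable `Node` is held by `shared_ptr`, so returning the identical pointer means "unchanged". `pre` runs before a node's arguments, `post_call` runs on calls afterwards, and a call is rebuilt only if some argument actually changed.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical immutable expression node; "same shared_ptr" means "unchanged".
struct Node;
using NodePtr = std::shared_ptr<const Node>;
struct Node {
  std::string label;          // function name for calls, text for leaves
  bool is_call;
  std::vector<NodePtr> args;  // only used by calls
};

NodePtr MakeLeaf(std::string label) {
  return std::make_shared<Node>(Node{std::move(label), false, {}});
}
NodePtr MakeCall(std::string fn, std::vector<NodePtr> args) {
  return std::make_shared<Node>(Node{std::move(fn), true, std::move(args)});
}

using PreVisit = std::function<NodePtr(const NodePtr&)>;
// The second argument is the unmodified call, or nullptr if no argument changed.
using PostVisitCall = std::function<NodePtr(const NodePtr&, const Node*)>;

NodePtr Modify(const NodePtr& expr, const PreVisit& pre, const PostVisitCall& post_call) {
  NodePtr visited = pre(expr);  // pre-order: sees every node before its arguments
  if (!visited->is_call) return visited;

  std::vector<NodePtr> new_args;
  bool modified = false;
  for (const NodePtr& arg : visited->args) {
    NodePtr new_arg = Modify(arg, pre, post_call);
    if (new_arg != arg) modified = true;
    new_args.push_back(std::move(new_arg));
  }
  // Only rebuild the call when an argument actually changed.
  if (!modified) return post_call(visited, nullptr);
  NodePtr rebuilt = MakeCall(visited->label, std::move(new_args));
  return post_call(rebuilt, visited.get());
}
```

A visitor that renames field `a` to `b` rebuilds `add(a, 1)` once. Running it again over the result changes nothing, so the very same pointer comes back and no allocation happens.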

Canonicalize

Canonicalize tries to rewrite equivalent expressions into a similar-looking normal form.

E.g. handling a chain of the same commutative operation:

((a + b) + 2) + 3

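When it finds the same operation chained together, it reorders the operands into a form that is convenient for constant folding. There are also rules about the relative positions of field_refs, literals, and null literals.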
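Here is a sketch of the chain canonicalization, based on my reading of Arrow's Canonicalize. It flattens the chain and stable-sorts the operands: null literals come first, then other literals, then everything else. It then folds the operands back up left-associatively. `Node`, `Flatten`, `Priority` and `CanonicalizeChain` are hypothetical, and the caller is assumed to apply this only to associative, commutative functions.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical node: a leaf (field, literal, null literal) or a binary call.
struct Node {
  enum class Kind { kField, kLiteral, kNullLiteral, kCall } kind;
  std::string text;        // field name, literal text, or function name
  std::vector<Node> args;  // two entries for calls
};

Node MakeField(std::string n) { return {Node::Kind::kField, std::move(n), {}}; }
Node MakeLit(std::string v) { return {Node::Kind::kLiteral, std::move(v), {}}; }
Node MakeNullLit() { return {Node::Kind::kNullLiteral, "null", {}}; }
Node MakeCall(std::string fn, Node l, Node r) {
  return {Node::Kind::kCall, std::move(fn), {std::move(l), std::move(r)}};
}

// Collect the operands of a chain of the same function.
void Flatten(const Node& n, const std::string& fn, std::vector<Node>* fringe) {
  if (n.kind == Node::Kind::kCall && n.text == fn) {
    for (const Node& a : n.args) Flatten(a, fn, fringe);
  } else {
    fringe->push_back(n);
  }
}

int Priority(const Node& n) {
  if (n.kind == Node::Kind::kNullLiteral) return 0;  // null literals first
  if (n.kind == Node::Kind::kLiteral) return 1;      // then other literals
  return 2;                                          // then everything else
}

Node CanonicalizeChain(const Node& root) {
  if (root.kind != Node::Kind::kCall) return root;
  std::vector<Node> fringe;
  Flatten(root, root.text, &fringe);
  // stable_sort keeps the original relative order within each priority class.
  std::stable_sort(fringe.begin(), fringe.end(),
                   [](const Node& l, const Node& r) { return Priority(l) < Priority(r); });
  Node folded = fringe[0];
  for (std::size_t i = 1; i < fringe.size(); ++i) folded = MakeCall(root.text, folded, fringe[i]);
  return folded;
}

std::string ToString(const Node& n) {
  if (n.kind != Node::Kind::kCall) return n.text;
  return "(" + ToString(n.args[0]) + " " + n.text + " " + ToString(n.args[1]) + ")";
}
```

Applied to `((a + b) + 2) + 3`, this yields `((2 + 3) + a) + b`. The two literals now form an all-literal call that constant folding can evaluate.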

E.g. handling comparisons:

a > 3
3 < a

These are rewritten into the same form, with the literal moved to the right-hand side (3 < a becomes a > 3).
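Here is a minimal sketch of the comparison rule. It assumes a hypothetical `Comparison` struct that uses Arrow's comparison function names (`less`, `greater`, and so on). If the literal is on the left and the field is on the right, the sketch swaps the operands and flips the operator.

```cpp
#include <map>
#include <string>
#include <utility>

// Hypothetical comparison node: op(lhs, rhs), each side a field or a literal.
struct Operand {
  bool is_literal;
  std::string text;
};
struct Comparison {
  std::string op;
  Operand lhs;
  Operand rhs;
};

// The operator to use after swapping operands: a < b  <=>  b > a.
std::string Flipped(const std::string& op) {
  static const std::map<std::string, std::string> flip = {
      {"less", "greater"},           {"greater", "less"},
      {"less_equal", "greater_equal"}, {"greater_equal", "less_equal"},
      {"equal", "equal"},            {"not_equal", "not_equal"}};
  return flip.at(op);
}

// Ensure the literal ends up on the right-hand side of the comparison.
Comparison CanonicalizeComparison(Comparison c) {
  if (c.lhs.is_literal && !c.rhs.is_literal) {
    std::swap(c.lhs, c.rhs);
    c.op = Flipped(c.op);
  }
  return c;
}
```

Both `less(3, a)` and `greater(a, 3)` end up as `greater(a, 3)`, so later passes only need to match one shape.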

Others

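There is also constant folding, extracting known key-value pairs (field = value) from guarantees, and so on. Honestly, this expression framework is fairly trivial: it is fine as an introduction, but for anything more advanced, look elsewhere.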
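Here is a toy sketch of extracting known values and using them. It works on a hypothetical flat list of conjuncts rather than a real Expression, and it keeps only `equal(field, literal)` facts. Arrow has a similar extraction helper (ExtractKnownFieldValues, if I recall the name correctly) that feeds ReplaceFieldsWithKnownValues. Substituting `year = 2024` into the filter `year > 2020` leaves only literals, which constant folding can then evaluate. `Conjunct`, `Operand`, `ExtractKnownValues`, `ReplaceWithKnown` and `TryFold` are illustrative names.

```cpp
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical conjunct of a guarantee: op(field, literal), e.g. equal(year, 2024).
struct Conjunct {
  std::string op;
  std::string field;
  long long value;
};

// Collect field = literal facts from a conjunction (all conjuncts are ANDed).
std::map<std::string, long long> ExtractKnownValues(const std::vector<Conjunct>& guarantee) {
  std::map<std::string, long long> known;
  for (const Conjunct& c : guarantee) {
    if (c.op == "equal") known[c.field] = c.value;  // other ops pin no single value
  }
  return known;
}

// Hypothetical operand: a field reference or an integer literal.
struct Operand {
  bool is_literal;
  std::string field;
  long long value;
};

// Replace a field reference by its known value, if there is one.
Operand ReplaceWithKnown(const std::map<std::string, long long>& known, Operand o) {
  if (o.is_literal) return o;
  auto it = known.find(o.field);
  if (it == known.end()) return o;
  return Operand{true, "", it->second};
}

// Fold op(lhs, rhs) when both sides are literals; otherwise leave it unknown.
std::optional<bool> TryFold(const std::string& op, const Operand& l, const Operand& r) {
  if (!l.is_literal || !r.is_literal) return std::nullopt;
  if (op == "greater") return l.value > r.value;
  if (op == "equal") return l.value == r.value;
  return std::nullopt;
}
```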