Arrow: simplifying expressions with guarantees

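An earlier post in this series introduced Arrow Expression: https://blog.mwish.me/2023/07/06/arrow-expression/

When Arrow reads Parquet, the Reader usually carries a Filter pushed down from the upper layer. Parquet's min-max stats are extracted into an Arrow Expression, which is then treated as a Guarantee. The Guarantee is handed back to the upper layer and used to simplify the Filter. If the Filter is bound to be false (that is, every row read from the RowGroup would be filtered out anyway), the RowGroup does not need to be read at all.

The following sections walk through this process step by step.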

Extracting the Guarantee

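  • At the Dataset schema level: https://blog.mwish.me/2023/07/23/Arrow-Dataset/#%E4%B8%8A%E5%B1%82%E6%93%8D%E4%BD%9C
  • In Fragment::partition_expression(), which extracts the expression on the partition columns
  • In ParquetFileFragment::EvaluateStatisticsAsExpression, an expression is extracted from each ColumnChunk of every RowGroup, and these per-column expressions are combined with and_kleene into the Guarantee for the whole RowGroup. ParquetFileFragment::TestRowGroups takes a pushed-down expression, runs SimplifyWithGuarantee on it against each RowGroup's Guarantee, and then eliminates the RowGroups whose simplified expression fails Expression::IsSatisfiable().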
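To make the flow above concrete, here is a minimal sketch under simplifying assumptions: integer columns, filters of the single shape cmp(field, literal), and nulls ignored. All names (ColumnStats, Filter, MayMatch, SelectRowGroups) are hypothetical and not Arrow's API; Arrow builds real Expressions and relies on SimplifyWithGuarantee and IsSatisfiable rather than this direct min/max check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy types, not Arrow's API: per-column min/max statistics of one
// row group (nulls are ignored in this sketch).
struct ColumnStats {
  std::int64_t min;
  std::int64_t max;
};

enum class Op { kLess, kLessEqual, kGreater, kGreaterEqual, kEqual };

// A pushed-down filter of the shape cmp(field, literal).
struct Filter {
  std::string field;
  Op op;
  std::int64_t value;
};

using RowGroupStats = std::map<std::string, ColumnStats>;

// Treat `min <= field && field <= max` as the row group's guarantee and ask whether
// any value allowed by it can satisfy the filter. false means the simplified filter
// is unsatisfiable, so the row group can be skipped.
bool MayMatch(const RowGroupStats& stats, const Filter& f) {
  auto it = stats.find(f.field);
  if (it == stats.end()) return true;  // no statistics: nothing can be pruned
  const ColumnStats& s = it->second;
  switch (f.op) {
    case Op::kLess: return s.min < f.value;
    case Op::kLessEqual: return s.min <= f.value;
    case Op::kGreater: return s.max > f.value;
    case Op::kGreaterEqual: return s.max >= f.value;
    case Op::kEqual: return s.min <= f.value && f.value <= s.max;
  }
  return true;
}

// Indices of the row groups that still have to be read.
std::vector<std::size_t> SelectRowGroups(const std::vector<RowGroupStats>& row_groups,
                                         const Filter& f) {
  std::vector<std::size_t> selected;
  for (std::size_t i = 0; i < row_groups.size(); ++i) {
    if (MayMatch(row_groups[i], f)) selected.push_back(i);
  }
  return selected;
}
```

For example, with row groups covering [0, 9], [10, 19] and [20, 29] on column a, the filter a > 15 keeps only the last two.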

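All of these were covered before, so I won't repeat them here. Expression::IsSatisfiable() is still implemented mostly with hand-written rules. One interesting detail is that null is treated as unsatisfiable, which makes sense: a filter that evaluates to null drops the row just like false does.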
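The null rule can be illustrated with a tiny three-valued boolean. This is a hypothetical sketch (Tri, KleeneAnd and IsSatisfiableLiteral are made-up names) that only covers filters already folded down to a literal, not Arrow's full IsSatisfiable:

```cpp
#include <cassert>
#include <optional>

// Hypothetical toy: a three-valued (Kleene) boolean where std::nullopt plays null.
using Tri = std::optional<bool>;

// Kleene AND, as and_kleene computes it: false wins over null, null wins over true.
Tri KleeneAnd(Tri a, Tri b) {
  if (a == false || b == false) return false;
  if (!a || !b) return std::nullopt;
  return true;
}

// Satisfiability of a filter that has been folded down to a literal: only `true`
// lets a row through, since a null filter result drops the row just like `false`.
bool IsSatisfiableLiteral(Tri value) { return value.has_value() && *value; }
```

So a guarantee that folds a filter to and_kleene(true, null) still lets the row group be skipped.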

SimplifyWithGuarantee

The signature is:

```cpp
/// Simplify an expression by replacing subexpressions based on a guarantee:
/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
/// used to remove redundant function calls from a filter expression or to replace a
/// reference to a constant-value field with a literal.
ARROW_EXPORT
Result<Expression> SimplifyWithGuarantee(Expression,
                                         const Expression& guaranteed_true_predicate);
```

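It simplifies the Expression based on the guaranteed_true_predicate expression. The implementation works as follows:

  1. GuaranteeConjunctionMembers extracts the members of the guarantee's and_kleene into a vector.
  2. Various ad hoc strategies are then tried:
    1. ExtractKnownFieldValues tries to extract every rule of the form a = value from the and-list into a map (see KnownFieldValues), then substitutes these equalities into the expression being optimized and simplifies it. The substitution is done by ReplaceFieldsWithKnownValues.
    2. For each member of the vector, Inequality::ExtractOne tries to extract an Inequality, which is then used to simplify the expression.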
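Step 1 can be sketched on a toy tree. This assumes a hypothetical Expr struct rather than arrow::compute::Expression, with leaves stored as plain labels; the idea is simply to flatten nested and_kleene calls into one flat list of members:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy expression node (not arrow::compute::Expression): a call has a
// function name and arguments; a leaf is just a label with no arguments.
struct Expr {
  std::string name;
  std::vector<Expr> args;
};

// Sketch of what GuaranteeConjunctionMembers achieves: recursively split nested
// and_kleene calls and collect the non-and members in left-to-right order.
void CollectConjunctionMembers(const Expr& e, std::vector<Expr>* out) {
  if (e.name == "and_kleene") {
    for (const Expr& arg : e.args) CollectConjunctionMembers(arg, out);
    return;
  }
  out->push_back(e);  // anything else (including or_kleene) is one member
}

std::vector<Expr> ConjunctionMembers(const Expr& guarantee) {
  std::vector<Expr> members;
  CollectConjunctionMembers(guarantee, &members);
  return members;
}
```

A guarantee that is not an and_kleene at all comes back as a single member, so the later passes can treat every guarantee uniformly.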

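Let's look at what happens in between.

This is where the expression utilities come in, namely the Canonicalize and Constant Folding passes mentioned before:

Concretely, after the original expression (not the guarantee) has been rewritten, it is canonicalized into the Call(FieldRef, Literal) shape, with calls brought into a canonical form, and then constant folding is applied to simplify it further.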
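A minimal sketch of the two passes on a single comparison, assuming a hypothetical toy representation (Operand, Cmp, Mirror, Canonicalize and FoldConstants are illustrative names; the real passes in Arrow cover many more cases):

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy: one side of a comparison is either a field or an int literal.
struct Operand {
  std::optional<std::string> field;  // engaged for a field reference
  int literal = 0;                   // used when `field` is empty
};

// `lhs op rhs` with op in {"less", "less_equal", "greater", "greater_equal", "equal"}.
struct Cmp {
  std::string op;
  Operand lhs;
  Operand rhs;
};

// Swapping the operands needs the mirrored operator: 3 < a is the same as a > 3.
std::string Mirror(const std::string& op) {
  if (op == "less") return "greater";
  if (op == "greater") return "less";
  if (op == "less_equal") return "greater_equal";
  if (op == "greater_equal") return "less_equal";
  return op;  // equal is symmetric
}

// Canonical shape assumed by the later passes: cmp(FieldRef, Literal).
Cmp Canonicalize(Cmp c) {
  if (!c.lhs.field && c.rhs.field) {
    std::swap(c.lhs, c.rhs);
    c.op = Mirror(c.op);
  }
  return c;
}

// Constant folding: a comparison between two literals becomes a boolean literal.
std::optional<bool> FoldConstants(const Cmp& c) {
  if (c.lhs.field || c.rhs.field) return std::nullopt;  // still refers to data
  const int l = c.lhs.literal;
  const int r = c.rhs.literal;
  if (c.op == "less") return l < r;
  if (c.op == "less_equal") return l <= r;
  if (c.op == "greater") return l > r;
  if (c.op == "greater_equal") return l >= r;
  if (c.op == "equal") return l == r;
  return std::nullopt;
}
```

Once a known value has replaced the field, the comparison has literals on both sides and folds to true or false, which is what makes a RowGroup prunable.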

KnownValues: equality guarantees

```cpp
/// \brief Extract an equality from an expression.
///
/// Recognizes expressions of the form:
/// equal(a, 2)
/// is_null(a)
std::optional<std::pair<FieldRef, Datum>> ExtractOneFieldValue(const Expression& guarantee);
```

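ExtractKnownFieldValues is still a fairly ad hoc implementation. Notably, a single or is enough to defeat it: ExtractOneFieldValue cannot extract a suitable equality from a member that contains an or. It is best understood as an implementation for specific scenarios (perhaps because the Guarantees produced by Parquet, partitioning and similar sources never take that form?).

Note one peculiarity: it extracts is_null as well, not just equal. is_valid, on the other hand, is handled by the logic further below.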

```cpp
// Conjunction members which are represented in known_values are erased from
// conjunction_members
Status ExtractKnownFieldValues(std::vector<Expression>* conjunction_members,
                               KnownFieldValues* known_values) {
  // filter out consumed conjunction members, leaving only unconsumed
  *conjunction_members = arrow::internal::FilterVector(
      std::move(*conjunction_members),
      [known_values](const Expression& guarantee) -> bool {
        if (auto known_value = ExtractOneFieldValue(guarantee)) {
          known_values->map.insert(std::move(*known_value));
          return false;
        }
        return true;
      });

  return Status::OK();
}
```
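The extraction step can be sketched as follows. This toy version is hypothetical: a guarantee member is reduced to a function name, a field and an int literal, and a known null is modelled as std::nullopt rather than a null Datum.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy guarantee member `function(field, literal)`; members this toy
// cannot model (such as an or_kleene) only carry their function name.
struct Member {
  std::string function;
  std::string field;
  int literal = 0;
};

// Known value of a field; std::nullopt stands for "known to be null".
using KnownValue = std::optional<int>;

// The two shapes ExtractOneFieldValue recognizes: equal(a, 2) and is_null(a).
std::optional<std::pair<std::string, KnownValue>> ExtractOne(const Member& m) {
  if (m.function == "equal") return std::make_pair(m.field, KnownValue{m.literal});
  if (m.function == "is_null") return std::make_pair(m.field, KnownValue{});
  return std::nullopt;  // anything else, an `or` included, is left for later passes
}

// Like ExtractKnownFieldValues: consumed members move into `known`, the others
// stay in `members` (in their original order).
void ExtractKnown(std::vector<Member>* members, std::map<std::string, KnownValue>* known) {
  std::vector<Member> rest;
  for (Member& m : *members) {
    if (auto kv = ExtractOne(m)) {
      known->insert(std::move(*kv));
    } else {
      rest.push_back(std::move(m));
    }
  }
  *members = std::move(rest);
}
```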

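ReplaceFieldsWithKnownValues then substitutes the known values into the user's expression. The substitution may need to insert a cast, for example when the field is dictionary-encoded or the known value's type differs from the field's type: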

```cpp
Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values,
                                                Expression expr) {
  if (!expr.IsBound()) {
    return Status::Invalid(
        "ReplaceFieldsWithKnownValues called on an unbound Expression");
  }

  return ModifyExpression(
      std::move(expr),
      [&known_values](Expression expr) -> Result<Expression> {
        if (auto ref = expr.field_ref()) {
          auto it = known_values.map.find(*ref);
          if (it != known_values.map.end()) {
            Datum lit = it->second;
            if (lit.type()->Equals(*expr.type())) return literal(std::move(lit));
            // type mismatch, try casting the known value to the correct type

            if (expr.type()->id() == Type::DICTIONARY &&
                lit.type()->id() != Type::DICTIONARY) {
              // the known value must be dictionary encoded

              const auto& dict_type = checked_cast<const DictionaryType&>(*expr.type());
              if (!lit.type()->Equals(dict_type.value_type())) {
                ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, dict_type.value_type()));
              }

              if (lit.is_scalar()) {
                ARROW_ASSIGN_OR_RAISE(auto dictionary,
                                      MakeArrayFromScalar(*lit.scalar(), 1));

                lit = Datum{DictionaryScalar::Make(MakeScalar<int32_t>(0),
                                                   std::move(dictionary))};
              }
            }

            ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()->GetSharedPtr()));
            return literal(std::move(lit));
          }
        }
        return expr;
      },
      [](Expression expr, ...) { return expr; });
}
```
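A toy version of the substitution, to show where a cast can fail. It assumes a made-up tree type with only int32/int64 tags, and CastInt is a hypothetical stand-in for a checked compute::Cast; the dictionary-encoding branch of the real code is not modelled:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression tree: a field reference, a typed integer literal,
// or a call with arguments. Types are just the tags "int32" and "int64".
struct Node {
  enum Kind { kField, kLiteral, kCall };
  Kind kind;
  std::string name;        // field name or function name
  std::string type;        // value type of the node
  std::int64_t value = 0;  // literal payload
  std::vector<Node> args;
};

struct Known {
  std::string type;
  std::int64_t value;
};

// Toy "safe cast": fails when the value does not fit the target type, in the
// spirit of a checked compute::Cast returning an error.
std::optional<std::int64_t> CastInt(std::int64_t v, const std::string& to) {
  if (to == "int32" && (v < std::numeric_limits<std::int32_t>::min() ||
                        v > std::numeric_limits<std::int32_t>::max())) {
    return std::nullopt;
  }
  return v;
}

// Replace each field that has a known value with a literal of the field's own
// type, casting first on a type mismatch; std::nullopt reports a failed cast.
std::optional<Node> ReplaceFields(const std::map<std::string, Known>& known, Node n) {
  if (n.kind == Node::kField) {
    auto it = known.find(n.name);
    if (it == known.end()) return n;
    std::int64_t v = it->second.value;
    if (it->second.type != n.type) {
      std::optional<std::int64_t> cast = CastInt(v, n.type);
      if (!cast) return std::nullopt;
      v = *cast;
    }
    return Node{Node::kLiteral, "", n.type, v, {}};
  }
  for (Node& arg : n.args) {
    std::optional<Node> replaced = ReplaceFields(known, std::move(arg));
    if (!replaced) return std::nullopt;
    arg = std::move(*replaced);
  }
  return n;
}
```

Keeping the field's type (rather than the known value's) is what keeps the rest of the already-bound expression type-correct after substitution.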

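The implementation of ModifyExpression

One interesting point is that it relies on ModifyExpression to rewrite every subexpression. I've copied that function verbatim, since it is very easy to follow.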

```cpp
/// Modify an Expression with pre-order and post-order visitation.
/// `pre` will be invoked on each Expression. `pre` will visit Calls before their
/// arguments, `post_call` will visit Calls (and no other Expressions) after their
/// arguments. Visitors should return the Identical expression to indicate no change; this
/// will prevent unnecessary construction in the common case where a modification is not
/// possible/necessary/...
///
/// If an argument was modified, `post_call` visits a reconstructed Call with the modified
/// arguments but also receives a pointer to the unmodified Expression as a second
/// argument. If no arguments were modified the unmodified Expression* will be nullptr.
template <typename PreVisit, typename PostVisitCall>
Result<Expression> ModifyExpression(Expression expr, const PreVisit& pre,
                                    const PostVisitCall& post_call) {
  ARROW_ASSIGN_OR_RAISE(expr, Result<Expression>(pre(std::move(expr))));

  auto call = expr.call();
  if (!call) return expr;

  bool at_least_one_modified = false;
  std::vector<Expression> modified_arguments;

  for (size_t i = 0; i < call->arguments.size(); ++i) {
    ARROW_ASSIGN_OR_RAISE(auto modified_argument,
                          ModifyExpression(call->arguments[i], pre, post_call));

    if (Identical(modified_argument, call->arguments[i])) {
      continue;
    }

    if (!at_least_one_modified) {
      modified_arguments = call->arguments;
      at_least_one_modified = true;
    }

    modified_arguments[i] = std::move(modified_argument);
  }

  if (at_least_one_modified) {
    // reconstruct the call expression with the modified arguments
    auto modified_call = *call;
    modified_call.arguments = std::move(modified_arguments);
    return post_call(Expression(std::move(modified_call)), &expr);
  }

  return post_call(std::move(expr), NULLPTR);
}
```
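The identity-preserving trick can be reproduced outside Arrow. This sketch is hypothetical: it uses shared_ptr pointer equality in place of Identical(), and Leaf, MakeCall and Modify are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy: immutable nodes behind shared_ptr, so "identical" can be
// checked by pointer equality, roughly like Arrow's Identical(). In this toy a
// node without arguments is a leaf, not a call.
struct Node;
using Ptr = std::shared_ptr<const Node>;
struct Node {
  std::string name;  // function name, or a label for a leaf
  std::vector<Ptr> args;
};

Ptr Leaf(std::string name) { return std::make_shared<const Node>(Node{std::move(name), {}}); }

Ptr MakeCall(std::string name, std::vector<Ptr> args) {
  return std::make_shared<const Node>(Node{std::move(name), std::move(args)});
}

// Same control flow as ModifyExpression: `pre` runs top-down, the arguments are
// rewritten recursively, and `post` runs on the way back up. A call is rebuilt only
// if some argument changed; `post` then also gets the original, otherwise nullptr.
template <typename Pre, typename Post>
Ptr Modify(Ptr expr, const Pre& pre, const Post& post) {
  expr = pre(std::move(expr));
  if (expr->args.empty()) return expr;

  bool changed = false;
  std::vector<Ptr> modified;
  for (std::size_t i = 0; i < expr->args.size(); ++i) {
    Ptr arg = Modify(expr->args[i], pre, post);
    if (arg == expr->args[i]) continue;  // identical, nothing to copy
    if (!changed) {
      modified = expr->args;  // copy lazily, only on the first change
      changed = true;
    }
    modified[i] = std::move(arg);
  }

  if (changed) return post(MakeCall(expr->name, std::move(modified)), &expr);
  return post(std::move(expr), nullptr);
}
```

Only the path from the root down to a changed node is rebuilt; untouched subtrees are returned as the very same pointers, which is the "common case" the doc comment above is optimizing for.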

is_valid: Nullability

This logic is very naive:

```cpp
for (const auto& guarantee : conjunction_members) {
  if (!guarantee.call()) continue;

  // ...

  if (guarantee.call()->function_name == "is_valid") {
    ARROW_ASSIGN_OR_RAISE(
        auto simplified,
        SimplifyIsValidGuarantee(std::move(expr), *CallNotNull(guarantee)));

    if (Identical(simplified, expr)) continue;

    expr = std::move(simplified);
    RETURN_NOT_OK(CanonicalizeAndFoldConstants());
  }
}
```

And the substitution itself:

```cpp
/// \brief Simplify an expression given a guarantee, if the guarantee
/// is is_valid().
Result<Expression> SimplifyIsValidGuarantee(Expression expr,
                                            const Expression::Call& guarantee) {
  if (guarantee.function_name != "is_valid") return expr;

  return ModifyExpression(
      std::move(expr), [](Expression expr) { return expr; },
      [&](Expression expr, ...) -> Result<Expression> {
        auto call = expr.call();
        if (!call) return expr;

        if (call->arguments[0] != guarantee.arguments[0]) return expr;

        if (call->function_name == "is_valid") return literal(true);

        if (call->function_name == "true_unless_null") return literal(true);

        if (call->function_name == "is_null") return literal(false);

        return expr;
      });
}
```

Note that this could in fact also handle…
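The same rewrite on a toy tree, as a hypothetical sketch (Expr and SimplifyGivenValid are made-up names, and leaves are plain labels):

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression: calls have a name and arguments; leaves are field
// names or the literals "true"/"false".
struct Expr {
  std::string name;
  std::vector<Expr> args;
};

// Post-order rewrite in the spirit of SimplifyIsValidGuarantee: once `field` is
// guaranteed valid, is_valid(field) and true_unless_null(field) become true and
// is_null(field) becomes false. Folding the surrounding and/or is left to a later
// constant-folding pass, as Arrow does with CanonicalizeAndFoldConstants.
Expr SimplifyGivenValid(const std::string& field, Expr e) {
  for (Expr& arg : e.args) arg = SimplifyGivenValid(field, std::move(arg));
  const bool on_field =
      e.args.size() == 1 && e.args[0].args.empty() && e.args[0].name == field;
  if (!on_field) return e;
  if (e.name == "is_valid" || e.name == "true_unless_null") return Expr{"true", {}};
  if (e.name == "is_null") return Expr{"false", {}};
  return e;
}
```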

Inequality: Bound optimization

Inequality is actually the more interesting part of this code. It deals with comparison expressions such as >= and <, as well as is_in expressions.

```cpp
// An inequality comparison which a target Expression is known to satisfy. If nullable,
// the target may evaluate to null in addition to values satisfying the comparison.
struct Inequality {
  // The inequality type
  Comparison::type cmp;
  // The LHS of the inequality
  const FieldRef& target;
  // The RHS of the inequality
  const Datum& bound;
  // Whether target can be null
  bool nullable;

  // Extract an Inequality if possible, derived from "less",
  // "greater", "less_equal", and "greater_equal" expressions,
  // possibly disjuncted with an "is_null" Expression.
  // cmp(a, 2)
  // cmp(a, 2) or is_null(a)
  static std::optional<Inequality> ExtractOne(const Expression& guarantee);
};

struct Comparison {
  enum type {
    NA = 0,
    EQUAL = 1,
    LESS = 2,
    GREATER = 4,
    NOT_EQUAL = LESS | GREATER,
    LESS_EQUAL = LESS | EQUAL,
    GREATER_EQUAL = GREATER | EQUAL,
  };
};
```
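The bit layout makes several operations cheap. The enum values below are copied from the struct above, while Order, Holds and Flip are hypothetical helpers written for illustration, not Arrow functions:

```cpp
#include <cassert>

// Same bit layout as the Comparison enum above: one bit per possible ordering.
enum Cmp : unsigned {
  NA = 0,
  EQUAL = 1,
  LESS = 2,
  GREATER = 4,
  NOT_EQUAL = LESS | GREATER,
  LESS_EQUAL = LESS | EQUAL,
  GREATER_EQUAL = GREATER | EQUAL,
};

// The single ordering bit that describes `lhs` relative to `rhs`.
Cmp Order(int lhs, int rhs) { return lhs < rhs ? LESS : lhs > rhs ? GREATER : EQUAL; }

// A comparison holds exactly when its mask contains the actual ordering bit.
bool Holds(Cmp cmp, int lhs, int rhs) { return (cmp & Order(lhs, rhs)) != 0; }

// Swapping operands swaps the LESS and GREATER bits and keeps EQUAL:
// (3 < a) is the same as (a > 3).
Cmp Flip(Cmp cmp) {
  unsigned out = cmp & EQUAL;
  if (cmp & LESS) out |= GREATER;
  if (cmp & GREATER) out |= LESS;
  return static_cast<Cmp>(out);
}
```

Encoding each relation as a set of orderings is what lets NOT_EQUAL, LESS_EQUAL and GREATER_EQUAL be expressed as plain unions of the three basic bits.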

Guarantee members are extracted into this Inequality form, and the Inequality is then used to simplify the Expression. The extraction code is relatively simple:

```cpp
static std::optional<Inequality> ExtractOneFromComparison(const Expression& guarantee) {
  auto call = guarantee.call();
  if (!call) return std::nullopt;

  if (auto cmp = Comparison::Get(call->function_name)) {
    // not_equal comparisons are not very usable as guarantees
    if (*cmp == Comparison::NOT_EQUAL) return std::nullopt;

    auto target = call->arguments[0].field_ref();
    if (!target) return std::nullopt;

    auto bound = call->arguments[1].literal();
    if (!bound) return std::nullopt;
    if (!bound->is_scalar()) return std::nullopt;

    return Inequality{*cmp, /*target=*/*target, *bound, /*nullable=*/false};
  }

  return std::nullopt;
}
```
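A toy version of the extraction, including the `cmp(a, lit) or is_null(a)` form mentioned in the comment on ExtractOne. The Expr/Ineq types and the Field/Lit/Call builders are hypothetical:

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression: a call, a field leaf, or an int literal leaf.
struct Expr {
  std::string name;  // function name or field name
  std::vector<Expr> args;
  std::optional<int> literal;  // engaged only for literal leaves
};

Expr Field(std::string name) { return Expr{std::move(name), {}, std::nullopt}; }
Expr Lit(int v) { return Expr{std::to_string(v), {}, v}; }
Expr Call(std::string name, std::vector<Expr> args) {
  return Expr{std::move(name), std::move(args), std::nullopt};
}

bool IsField(const Expr& e) { return e.args.empty() && !e.literal; }

// Toy counterpart of Inequality: `target cmp bound`, optionally "or null".
struct Ineq {
  std::string cmp;
  std::string target;
  int bound;
  bool nullable;
};

// cmp(field, literal) only; not_equal is rejected as in ExtractOneFromComparison.
std::optional<Ineq> FromComparison(const Expr& e) {
  const bool is_cmp = e.name == "equal" || e.name == "less" || e.name == "less_equal" ||
                      e.name == "greater" || e.name == "greater_equal";
  if (!is_cmp || e.args.size() != 2) return std::nullopt;
  if (!IsField(e.args[0]) || !e.args[1].literal) return std::nullopt;
  return Ineq{e.name, e.args[0].name, *e.args[1].literal, /*nullable=*/false};
}

// Also accepts `cmp(a, lit) or is_null(a)` on the same field, marking it nullable.
std::optional<Ineq> ExtractIneq(const Expr& e) {
  if (e.name == "or_kleene" && e.args.size() == 2) {
    std::optional<Ineq> out = FromComparison(e.args[0]);
    const Expr& rhs = e.args[1];
    if (!out || rhs.name != "is_null" || rhs.args.size() != 1) return std::nullopt;
    if (!IsField(rhs.args[0]) || rhs.args[0].name != out->target) return std::nullopt;
    out->nullable = true;
    return out;
  }
  return FromComparison(e);
}
```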

The simplification step then uses this Inequality to simplify the original expression.

Here Inequality::Simplify is called inside post_call to simplify each subexpression:

```cpp
for (const auto& guarantee : conjunction_members) {
  if (!guarantee.call()) continue;

  if (auto inequality = Inequality::ExtractOne(guarantee)) {
    ARROW_ASSIGN_OR_RAISE(auto simplified,
                          ModifyExpression(
                              std::move(expr), [](Expression expr) { return expr; },
                              [&](Expression expr, ...) -> Result<Expression> {
                                return inequality->Simplify(std::move(expr));
                              }));

    if (Identical(simplified, expr)) continue;

    expr = std::move(simplified);
    RETURN_NOT_OK(CanonicalizeAndFoldConstants());
  }
  // ... code that handles guarantees of the form is_valid(...)
}
```

Inequality::Simplify

It first uses the Inequality's nullable flag to try to simplify calls such as is_null and is_valid:

```cpp
if (call->function_name == "is_valid" || call->function_name == "is_null") {
  if (guarantee.nullable) return expr;
  const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
  if (!lhs.field_ref()) return expr;
  if (*lhs.field_ref() != guarantee.target) return expr;

  return call->function_name == "is_valid" ? literal(true) : literal(false);
}
```

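Next comes a recently added is_in optimization, which is quite interesting: it rewrites the expression as far as it can. For example, under the guarantee a > 10 && is_valid(a), the filter a in [1, 2, 11] is simplified to a in [11]. This PR ( https://github.com/apache/arrow/commit/44b72d5c2518b7dc70b67b588432fb06ea3896c7 ) contains the full implementation.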
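The core of the is_in rewrite can be sketched as filtering the value set against the inequality. This is a simplified, hypothetical model (Ineq, Satisfies and PruneValueSet are made-up names): it assumes a non-nullable integer target and ignores how the real implementation treats nulls in the value set.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy: a non-nullable guarantee `target cmp bound`.
struct Ineq {
  std::string cmp;  // "less", "less_equal", "greater", "greater_equal"
  int bound;
};

bool Satisfies(const Ineq& g, int v) {
  if (g.cmp == "less") return v < g.bound;
  if (g.cmp == "less_equal") return v <= g.bound;
  if (g.cmp == "greater") return v > g.bound;
  if (g.cmp == "greater_equal") return v >= g.bound;
  return true;  // unknown comparison: keep the value, i.e. do not simplify
}

// Drop is_in candidates that the guarantee rules out:
// with a > 10, `a in [1, 2, 11]` becomes `a in [11]`. An empty result means the
// is_in call can never be true for rows covered by the guarantee.
std::vector<int> PruneValueSet(const Ineq& g, const std::vector<int>& value_set) {
  std::vector<int> kept;
  for (int v : value_set) {
    if (Satisfies(g, v)) kept.push_back(v);
  }
  return kept;
}
```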

Finally, it works with the literal obtained after Comparison::StripOrderPreservingCasts; this part is fairly straightforward. Nullability comes into play again here: if the target is nullable, the result takes a somewhat odd form:

```cpp
/// The given expression simplifies to `value` if the inequality
/// target is not nullable. Otherwise, it simplifies to either a
/// call to true_unless_null or !true_unless_null.
Result<Expression> simplified_to(const Expression& bound_target, bool value) const {
  if (!nullable) return literal(value);

  ExecContext exec_context;

  // Data may be null, so comparison will yield `value` - or null IFF the data was null
  //
  // true_unless_null is cheap; it purely reuses the validity bitmap for the values
  // buffer. Inversion is less cheap but we expect that term never to be evaluated
  // since invert(true_unless_null(x)) is not satisfiable.
  Expression::Call call;
  call.function_name = "true_unless_null";
  call.arguments = {bound_target};
  ARROW_ASSIGN_OR_RAISE(
      auto true_unless_null,
      BindNonRecursive(std::move(call),
                       /*insert_implicit_casts=*/false, &exec_context));
  if (value) return true_unless_null;

  Expression::Call invert;
  invert.function_name = "invert";
  invert.arguments = {std::move(true_unless_null)};
  return BindNonRecursive(std::move(invert),
                          /*insert_implicit_casts=*/false, &exec_context);
}
```
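The overall decision can be sketched with interval reasoning over integers. This is a swapped-in technique for illustration, not how Arrow's Inequality::Simplify is written, and it assumes integer values whereas Arrow handles arbitrary ordered scalars. The function names and the string results are hypothetical:

```cpp
#include <cassert>
#include <climits>
#include <string>

// Hypothetical toy: over integers, every comparison against a literal is a closed
// interval, e.g. a > 10 is [11, LLONG_MAX].
struct Interval {
  long long lo;
  long long hi;
};

Interval ToInterval(const std::string& cmp, int v) {
  const long long x = v;
  if (cmp == "less") return {LLONG_MIN, x - 1};
  if (cmp == "less_equal") return {LLONG_MIN, x};
  if (cmp == "greater") return {x + 1, LLONG_MAX};
  if (cmp == "greater_equal") return {x, LLONG_MAX};
  return {x, x};  // "equal"
}

// Simplify `target filter_cmp literal` under the guarantee `target guarantee_cmp bound`.
// The returned string names what the expression would be rewritten to.
std::string SimplifyComparison(const std::string& guarantee_cmp, int bound, bool nullable,
                               const std::string& filter_cmp, int literal) {
  const Interval g = ToInterval(guarantee_cmp, bound);
  const Interval f = ToInterval(filter_cmp, literal);
  const bool disjoint = g.hi < f.lo || f.hi < g.lo;     // filter can never hold
  const bool contained = f.lo <= g.lo && g.hi <= f.hi;  // filter always holds
  if (!disjoint && !contained) return "unchanged";
  const bool value = contained;
  // Mirrors simplified_to(): a nullable target may still yield null, so the
  // result is true_unless_null(target) or its inversion instead of a literal.
  if (!nullable) return value ? "true" : "false";
  return value ? "true_unless_null(target)" : "invert(true_unless_null(target))";
}
```

With the guarantee a > 10, the filter a < 5 becomes false, a > 3 becomes true, and a < 20 is left alone; if the guarantee came from `a > 10 or is_null(a)`, the true/false results turn into the true_unless_null forms described in the comment above.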