Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions be/src/exprs/lambda_function/lambda_function_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LambdaFunctionFactory;

void register_function_array_map(LambdaFunctionFactory& factory);
void register_function_array_filter(LambdaFunctionFactory& factory);
void register_function_array_sort(LambdaFunctionFactory& factory);

class LambdaFunctionFactory {
using Creator = std::function<LambdaFunctionPtr()>;
Expand Down Expand Up @@ -62,6 +63,7 @@ class LambdaFunctionFactory {
std::call_once(oc, []() {
register_function_array_map(instance);
register_function_array_filter(instance);
register_function_array_sort(instance);
});
return instance;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DictGet;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DictGetMany;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
Expand Down Expand Up @@ -584,6 +585,61 @@ public Expr visitArrayMap(ArrayMap arrayMap, PlanTranslatorContext context) {
return functionCallExpr;
}

@Override
public Expr visitArraySort(ArraySort arraySort, PlanTranslatorContext context) {
if (!(arraySort.child(0) instanceof Lambda)) {
return visitScalarFunction(arraySort, context);
}
Lambda lambda = (Lambda) arraySort.child(0);
List<Expr> arguments = new ArrayList<>(arraySort.children().size());
arguments.add(null);

// Construct the first column
ArrayItemReference arrayItemReference = lambda.getLambdaArgument(0);
String argName = arrayItemReference.getName();
Expr expr = arrayItemReference.getArrayExpression().accept(this, context);
arguments.add(expr);
ColumnRefExpr column = new ColumnRefExpr();
column.setName(argName);
column.setColumnId(0);
column.setNullable(true);
column.setType(((ArrayType) expr.getType()).getItemType());
context.addExprIdColumnRefPair(arrayItemReference.getExprId(), column);

// the second column here will not be used; it's just a placeholder.
arrayItemReference = lambda.getLambdaArgument(1);
ColumnRefExpr column2 = new ColumnRefExpr(column);
column2.setColumnId(1);
context.addExprIdColumnRefPair(arrayItemReference.getExprId(), column2);

List<Type> argTypes = arraySort.getArguments().stream()
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.collect(Collectors.toList());
// two slots are same, we only need one
lambda.getLambdaArguments().stream().skip(1)
.map(ArrayItemReference::getArrayExpression)
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.forEach(argTypes::add);
NullableMode nullableMode = arraySort.nullable()
? NullableMode.ALWAYS_NULLABLE
: NullableMode.ALWAYS_NOT_NULLABLE;
Type itemType = ((ArrayType) arguments.get(1).getType()).getItemType();
org.apache.doris.catalog.Function catalogFunction = new Function(
new FunctionName(arraySort.getName()), argTypes,
ArrayType.create(itemType, true),
true, true, nullableMode);

// create catalog FunctionCallExpr without analyze again
Expr lambdaBody = visitLambda(lambda, context);
arguments.set(0, lambdaBody);
LambdaFunctionCallExpr functionCallExpr = new LambdaFunctionCallExpr(catalogFunction,
new FunctionParams(false, arguments));
functionCallExpr.setNullableFromNereids(arraySort.nullable());
return functionCallExpr;
}

@Override
public Expr visitDictGet(DictGet dictGet, PlanTranslatorContext context) {
List<Expr> arguments = dictGet.getArguments().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ private UnboundFunction processHighOrderFunction(UnboundFunction unboundFunction
// bindLambdaFunction
Lambda lambda = (Lambda) unboundFunction.children().get(0);
Expression lambdaFunction = lambda.getLambdaFunction();
List<ArrayItemReference> arrayItemReferences = lambda.makeArguments(subChildren);
List<ArrayItemReference> arrayItemReferences = lambda.makeArguments(unboundFunction.getName(), subChildren);

List<Slot> boundedSlots = arrayItemReferences.stream()
.map(ArrayItemReference::toSlot)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAny;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSplit;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySplit;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
Expand Down Expand Up @@ -279,6 +280,17 @@ public Void visitArrayMap(ArrayMap arrayMap, CollectorContext context) {
return visit(arrayMap, context);
}

@Override
public Void visitArraySort(ArraySort arraySort, CollectorContext context) {
// ARRAY_SORT(lambda, <arr>)

Expression argument = arraySort.getArgument(0);
if ((argument instanceof Lambda)) {
return collectArrayPathInLambda((Lambda) argument, context);
}
return visit(arraySort, context);
}

@Override
public Void visitArrayCount(ArrayCount arrayCount, CollectorContext context) {
// ARRAY_COUNT(<lambda>, <arr>[, ... ])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.LambdaType;
import org.apache.doris.nereids.types.coercion.AnyDataType;

import com.google.common.base.Preconditions;
Expand All @@ -40,7 +42,8 @@ public class ArraySort extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX))
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX)),
FunctionSignature.ret(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX)).args(LambdaType.INSTANCE)
);

/**
Expand Down Expand Up @@ -77,7 +80,35 @@ public void checkLegalityBeforeTypeCoercion() {
@Override
public ArraySort withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new ArraySort(getFunctionParams(children));
return new ArraySort(children.get(0));
}

@Override
public DataType getDataType() {
if (children.get(0) instanceof Lambda) {
Lambda lambda = (Lambda) children.get(0);
ArrayItemReference argRef = lambda.getLambdaArguments().get(0);
Expression arrayExpr = argRef.getArrayExpression();
ArrayType arrayType = (ArrayType) arrayExpr.getDataType();
return ArrayType.of(arrayType.getItemType(), true);
} else if (children.get(0).getDataType() instanceof ArrayType) {
Expression arrayExpr = children.get(0);
ArrayType arrayType = (ArrayType) arrayExpr.getDataType();
return ArrayType.of(arrayType.getItemType(), true);
} else {
throw new AnalysisException("The first arg of array_sort must be lambda or array");
}
}

@Override
public boolean nullable() {
if (children.get(0) instanceof Lambda) {
return ((Lambda) children.get(0)).getLambdaArguments().stream()
.map(ArrayItemReference::getArrayExpression)
.anyMatch(Expression::nullable);
} else {
return child(0).nullable();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,23 @@ public Lambda(List<String> argumentNames, List<Expression> children) {

/**
* make slot according array expression
* @param functionName function name
* @param arrays array expression
* @return item slots of array expression
*/
public ImmutableList<ArrayItemReference> makeArguments(List<Expression> arrays) {
public ImmutableList<ArrayItemReference> makeArguments(String functionName, List<Expression> arrays) {
Builder<ArrayItemReference> builder = new ImmutableList.Builder<>();
if (arrays.size() != argumentNames.size()) {
// In the lambda expression of array_sort, x and y point to the same slot.
if (functionName.equalsIgnoreCase("array_sort") && arrays.size() == 1 && argumentNames.size() == 2) {
Expression array = arrays.get(0);
if (!(array.getDataType() instanceof ArrayType)) {
throw new AnalysisException(String.format("lambda argument must be array but is %s", array));
}
builder.add(new ArrayItemReference(argumentNames.get(0), array));
builder.add(new ArrayItemReference(argumentNames.get(1), array));
return builder.build();
}
throw new AnalysisException(String.format("lambda %s arguments' size is not equal parameters' size",
toSql()));
}
Expand Down
Loading
Loading