diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8aa5612d00c80..6c48da22aa8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -640,7 +640,9 @@ class Analyzer( Seq(ResolveUpdateEventTimeWatermarkColumn) ++ extendedResolutionRules ++ Seq(NameStreamingSources) : _*), - Batch("Remove TempResolvedColumn", Once, RemoveTempResolvedColumn), + Batch("Remove analysis-only markers", Once, + RemoveTempResolvedColumn, + RemoveInputTypeMarkers), Batch("Post-Hoc Resolution", Once, Seq(ResolveCommandsWithIfExists) ++ postHocResolutionRules: _*), @@ -4446,6 +4448,40 @@ object EliminateUnions extends Rule[LogicalPlan] { } } +/** + * Removes the analysis-only input-type markers ([[ImplicitCastInput]] / [[TypeCheckInput]]) that a + * [[DelegateFunction]] inserts to drive or check implicit cast. Once type coercion has run they + * have served their purpose, so we strip them at the end of analysis, leaving a clean `definition` + * in the [[DelegateExpression]]. Like [[RemoveTempResolvedColumn]], this just unwraps a marker to + * its child; it is not load-bearing -- a `DelegateExpression` is correct with or without markers. + * + * Only a *resolved* marker is unwrapped (its child's type satisfied the implicit-cast / type-check + * contract). An unresolved marker -- a `TypeCheckInput` whose `implicitCast = false` check failed, + * or an `ImplicitCastInput` whose argument could not be cast -- is left in place so its + * `ExpectsInputTypes` failure stays visible to `CheckAnalysis`. Unwrapping it unconditionally would + * expose a resolved child of the wrong type and let analysis silently accept a mismatched argument. 
+ */ +object RemoveInputTypeMarkers extends Rule[LogicalPlan] { + /** + * The single unwrap rule shared by [[apply]] and [[removeMarkers]], so the plan-level and the + * expression-level entry points cannot drift apart. Only a *resolved* marker is replaced by its + * child; an unresolved marker matches neither case and is therefore left in place. + */ + private val unwrapResolvedMarker: PartialFunction[Expression, Expression] = { + case marker: ImplicitCastInput if marker.resolved => marker.child + case marker: TypeCheckInput if marker.resolved => marker.child + } + + /** Plan-level entry point: applies the shared unwrap rule to the expressions of `plan`. */ + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressionsWithPruning(_.containsPattern(INPUT_TYPE_MARKER))( + unwrapResolvedMarker) + + /** + * Expression-level unwrap, for callers that have no rule batch to run this rule in -- notably the + * single-pass resolver, which builds `DelegateFunction`s (inserting markers) but does not execute + * the fixed-point analyzer's batches. Apply it once type coercion has cast the marker children. + * Like [[apply]], only resolved markers are unwrapped; a failed type check is left for the + * analyzer to report. + */ + def removeMarkers(expression: Expression): Expression = + expression.transformUpWithPruning(_.containsPattern(INPUT_TYPE_MARKER))( + unwrapResolvedMarker) +} + /** * Cleans up unnecessary Aliases inside the plan. 
Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a8654491a2697..925b03c8150d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -661,7 +661,7 @@ object FunctionRegistry { expression[Substring]("substr", true), expression[Substring]("substring"), expression[Left]("left"), - expression[Right]("right"), + expressionBuilder("right", Right), expression[SubstringIndex]("substring_index"), expression[StringTranslate]("translate"), expression[StringTrim]("trim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala index dd70963a79841..83d62b4326d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala @@ -22,12 +22,14 @@ import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{ FunctionResolution, + RemoveInputTypeMarkers, UnresolvedFunction, UnresolvedSeed } import org.apache.spark.sql.catalyst.expressions.{ BinaryArithmetic, Collate, + DelegateExpression, Expression, ExpressionWithRandomSeed, InheritAnalysisRules, @@ -200,6 +202,21 @@ class FunctionResolver( case windowFunction: WindowFunction if (expressionResolutionContext.windowFunctionNestednessLevel != 1) => throwWindowFunctionWithoutOverClause(windowFunction) + case delegateExpression: DelegateExpression => + // 
`DelegateFunction.build` produces a freshly-built `definition` subtree -- like + // `InheritAnalysisRules`' `replacement` above -- so resolve its children recursively; + // this reaches the analysis-only input-type markers ([[ImplicitCastInput]] / + // [[TypeCheckInput]]) buried inside and lets coercion cast their children. The fixed-point + // analyzer then strips the markers in `RemoveInputTypeMarkers`, but single-pass has no such + // batch, so strip them here once coercion has run. + val resolvedDelegateExpression = + withResolvedChildren(delegateExpression, expressionResolver.resolve _) + RemoveInputTypeMarkers.removeMarkers( + coerceExpressionTypes( + expression = resolvedDelegateExpression, + expressionTreeTraversal = traversals.current + ) + ) case other => coerceExpressionTypes(expression = other, expressionTreeTraversal = traversals.current) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index 3ba0121dbb64e..4045321055111 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -558,7 +558,8 @@ class ResolverGuard( _: TryValidateUTF8 | _: StringReplace | _: Overlay | _: StringTranslate | _: FindInSet | _: String2TrimExpression | _: StringTrimBoth | _: StringInstr | _: SubstringIndex | _: StringLocate | _: StringLPad | _: BinaryPad | _: StringRPad | _: FormatString | - _: InitCap | _: StringRepeat | _: StringSpace | _: Substring | _: Right | _: Left | + _: InitCap | _: StringRepeat | _: StringSpace | _: Substring | _: DelegateExpression | + _: Left | _: Length | _: BitLength | _: OctetLength | _: Levenshtein | _: SoundEx | _: Ascii | _: Chr | _: Base64 | _: UnBase64 | _: Decode | _: StringDecode | _: Encode | _: ToBinary | _: FormatNumber | _: 
Sentences | _: StringSplitSQL | _: SplitPart | _: Empty2Null | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DelegateExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DelegateExpression.scala new file mode 100644 index 0000000000000..166e282a81a6d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DelegateExpression.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.trees.TreePattern.{DELEGATE_EXPRESSION, INPUT_TYPE_MARKER, TreePattern} +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType} + +/** + * A transparent, named delegate over a `definition` expression -- a LOGICAL-phase construct. + * + * `DelegateExpression` lets a high-level function (e.g. `right(a, b)`) stay readable in the + * analyzed and optimized logical plan, and lets optimizer rules introduce such nodes (e.g. + * `multi_get_json_object`), without hand-written `eval`/`doGenCode`. Every behavior delegates to + * `definition`, a real child fully visible to the analyzer and optimizer. + * + * `name`/`inputs` are purely informational (EXPLAIN/SQL): nothing enforces that `definition` + * matches what they claim, so the wrapper is never exposed to physical planning or external + * systems. + * `LowerDelegateExpression` strips it to `definition` in `QueryExecution.createSparkPlan` -- the + * single entry point to the planner, used by both the main query and AQE re-planning -- so the + * planner and every physical consumer (join-key extraction, data source pushdown, columnar rules, + * codegen) sees the real executed expression. (Data source V2 pushdown runs earlier, in the logical + * optimizer, so it unfolds the wrapper directly in `V2ExpressionBuilder`.) 
The wrapper survives the + * logical optimizer, so the optimized plan stays readable and optimizer rules can introduce these + * nodes; `eval`/`doGenCode` still delegate, as a safety net if a delegate ever reaches execution. + * + * Note: because the strip runs before planning, a `DelegateExpression` created by a *physical* rule + * (after `createSparkPlan`) is not stripped and may reach an external system un-lowered. That is + * acceptable -- like any other expression the system does not recognize, it simply falls back, and + * `eval`/`doGenCode` keep it correct within Spark. Analysis- and optimizer-inserted nodes (the + * common case) are always stripped, so physical-rule insertion is the only uncovered path. + * + * By the same token, a *logical* consumer that pattern-matches a specific shape does not see + * through the wrapper to its `definition`: e.g. `ExtractEquiJoinKeys`, `PartitionPruning`, + * `InjectRuntimeFilter` and `StreamingJoinHelper` all run before `createSparkPlan` and match a bare + * `EqualTo`/`And`/..., so a delegate whose `definition` is a join equality would not be recognized + * as an equi-join key (losing DPP / runtime filters / equi-key stats) until it is lowered. Today no + * delegate produces a boolean/predicate `definition` used as a join condition (`right` returns a + * string; `multi_get_json_object` returns a struct), so this is latent, not an active gap; a future + * predicate-shaped delegate meant to participate in such rewrites must either lower earlier or + * teach those consumers to look through the wrapper. 
+ */ +case class DelegateExpression( + name: String, + inputs: Seq[Expression], + definition: Expression) + extends Expression with UnaryLike[Expression] { + + override def child: Expression = definition + override def dataType: DataType = definition.dataType + override def nullable: Boolean = definition.nullable + override def foldable: Boolean = definition.foldable + // Delegate `nullIntolerant` too (it is not derived from children, unlike `throwable`), so that + // null-intolerance optimizations -- `IsNotNull`-constraint inference in + // `QueryPlanConstraints.scanNullIntolerantAttribute` and `NullPropagation`'s `IsNull`/`IsNotNull` + // simplifications -- still fire while the wrapper is in the logical plan (e.g. for the + // `multi_get_json_object` delegate, whose `Invoke` definition is null-intolerant). + override def nullIntolerant: Boolean = definition.nullIntolerant + override lazy val deterministic: Boolean = definition.deterministic + override lazy val canonicalized: Expression = definition.canonicalized + + override def eval(input: InternalRow): Any = definition.eval(input) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + definition.genCode(ctx) + + final override val nodePatterns: Seq[TreePattern] = Seq(DELEGATE_EXPRESSION) + override protected def withNewChildInternal(newChild: Expression): DelegateExpression = + copy(definition = newChild) + + override def prettyName: String = name + override def sql: String = s"$name(${inputs.map(_.sql).mkString(", ")})" + override def toString: String = s"$name(${inputs.mkString(", ")})" +} + +/** + * Common behavior of the two analysis-only input-type markers. A marker is removed at the end of + * analysis by [[org.apache.spark.sql.catalyst.analysis.RemoveInputTypeMarkers]]; it survives only + * when its type check failed, in which case `CheckAnalysis` -- walking bottom-up -- reports the + * marker (the deepest failing node) before it reaches the enclosing [[DelegateExpression]]. 
+ * + * A marker is an analysis-only implementation detail, so it must not surface in the user-facing + * error. It therefore reports as if the check ran on the high-level delegate call itself, matching + * what the pre-delegate function produced: + * - `checkInputDataTypes` reports the true argument position (`argIndex`) rather than the + * marker's only child (which `ExpectsInputTypes` would always label the "first" parameter), and + * - `sql` renders the delegate call (`callSql`) rather than the internal `implicitcastinput(...)` + * / `typecheckinput(...)` name, so `sqlExpr`/`toSQLExpr` stay attributed to e.g. `right(...)`. + * + * `funcName`/`argIndex`/`callSql` are supplied by [[DelegateFunction.build]]. The concrete markers + * default them to `""`, `0` and `""` for direct construction (e.g. tests), where the marker never + * reaches error reporting. + */ +// A pure self-typed trait (not extending `UnaryExpression`/`ExpectsInputTypes` itself) so it can be +// mixed in LAST in the linearization -- that is what makes its `checkInputDataTypes` override win +// over the `ExpectsInputTypes` one. If it extended those directly, the concrete-type mix-in order +// needed to win the override would violate the superclass constraint. +trait InputTypeMarker extends Unevaluable { + self: UnaryExpression with ExpectsInputTypes => + def expectedType: AbstractDataType + def funcName: String + def argIndex: Int + def callSql: String + + override def inputTypes: Seq[AbstractDataType] = Seq(expectedType) + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable + override lazy val canonicalized: Expression = child.canonicalized + final override val nodePatterns: Seq[TreePattern] = Seq(INPUT_TYPE_MARKER) + + override def checkInputDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(children, inputTypes) match { + // The base check labels the marker's only child index 0 ("first"); relabel it with the real + // argument position within the delegate call. 
+ case m: DataTypeMismatch if m.messageParameters.contains("paramIndex") => + m.copy(messageParameters = + m.messageParameters + ("paramIndex" -> ExpectsInputTypes.ordinalNumber(argIndex))) + case other => other + } + } + + override def sql: String = if (callSql.nonEmpty) callSql else s"$prettyName(${child.sql})" +} + +/** + * Analysis-only marker that requests an implicit cast of `child` to `expectedType`: it declares the + * expected type so the standard `TypeCoercion` rule casts the child, then is removed at the end of + * analysis. It never reaches execution, hence [[Unevaluable]]. Modeled on + * [[org.apache.spark.sql.catalyst.analysis.TempResolvedColumn]]. + */ +case class ImplicitCastInput( + child: Expression, + expectedType: AbstractDataType, + funcName: String = "", + argIndex: Int = 0, + callSql: String = "") + extends UnaryExpression with ImplicitCastInputTypes with InputTypeMarker { + override protected def withNewChildInternal(newChild: Expression): ImplicitCastInput = + copy(child = newChild) +} + +/** + * Analysis-only marker that requires `child` to already match `expectedType` (no cast is inserted), + * failing analysis otherwise. Removed at the end of analysis like [[ImplicitCastInput]]. + */ +case class TypeCheckInput( + child: Expression, + expectedType: AbstractDataType, + funcName: String = "", + argIndex: Int = 0, + callSql: String = "") + extends UnaryExpression with ExpectsInputTypes with InputTypeMarker { + override protected def withNewChildInternal(newChild: Expression): TypeCheckInput = + copy(child = newChild) +} + +/** + * The per-function object each built-in function defines (e.g. `object Right extends + * DelegateFunction`). 
It is just an [[ExpressionBuilder]] -- registered with the ordinary + * `expressionBuilder(...)`, with its `@ExpressionDescription` annotation read off the object as + * usual -- specialized for the delegate pattern: replace the `InheritAnalysisRules` ceremony with + * one `lower` method plus a couple of flags. `apply` is the direct-construction entry point. + * + * Input-type contract, covering all three cases (applied per argument): + * - `inputTypes` empty (or `AnyDataType` for a position): accept any type (no check, no cast). + * - `inputTypes` set, `implicitCast = true` (default): implicit-cast each arg to its type. + * - `inputTypes` set, `implicitCast = false` : type-check each arg, no cast. + */ +trait DelegateFunction extends ExpressionBuilder { + def name: String + def inputTypes: Seq[AbstractDataType] = Nil + def implicitCast: Boolean = true + + /** Lower the function into the expression it delegates to. */ + def lower(args: Seq[Expression]): Expression + + /** + * ExpressionBuilder contract: invoked by the registry during function resolution. ONLY this + * (analysis-time) path inserts the input-type markers, because the analyzer's `TypeCoercion` + * casts them and `RemoveInputTypeMarkers` strips them afterwards. + */ + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + // `inputTypes` carries one entry per argument position (`AnyDataType` for an accept-any-type + // position), so when it is set its length is the function's arity. Validate it here so a + // wrong-arity call fails with the structured WRONG_NUM_ARGS error rather than an + // IndexOutOfBounds from `lower` (too few args) or a silently-ignored extra argument (too many). + // An empty `inputTypes` marks a variadic function whose `lower` accepts any argument count. 
+ if (inputTypes.nonEmpty && expressions.length != inputTypes.length) { + throw QueryCompilationErrors.wrongNumArgsError( + funcName, Seq(inputTypes.length), expressions.length) + } + // The pretty rendering of the high-level call (e.g. `right('abcd', array(1))`), so a marker + // that survives a failed type check reports the error against the delegate call -- exactly as + // the pre-delegate function did -- rather than leaking the internal marker name. Matches the + // `toPrettySQL`-based `sqlExpr` the fixed-point analyzer produced before this migration. + val callSql = s"$name(${expressions.map(toPrettySQL(_)).mkString(", ")})" + val args = expressions.zipWithIndex.map { case (e, i) => + val expected = if (i < inputTypes.length) inputTypes(i) else AnyDataType + expected match { + case AnyDataType => e + case t if implicitCast => ImplicitCastInput(e, t, name, i, callSql) + case t => TypeCheckInput(e, t, name, i, callSql) + } + } + DelegateExpression(name, expressions, lower(args)) + } + + /** + * Direct construction for use anywhere, including optimizer rules. Unlike [[build]] this inserts + * NO input-type markers -- there is no analyzer pass left to coerce or strip them -- so callers + * must pass arguments that are already resolved and of the expected types, exactly as when + * constructing any other expression (`Add`, `Substring`, ...) after analysis. The arity (when + * `inputTypes` is declared, mirroring [[build]]) and the resolved precondition are asserted so + * misuse fails loudly here rather than later, e.g. as an out-of-range index inside `lower` or a + * silently ignored extra argument. + */ + final def apply(inputs: Expression*): DelegateExpression = { + // An empty `inputTypes` marks a variadic function, so only a declared arity is enforced. + require(inputTypes.isEmpty || inputs.length == inputTypes.length, + s"$name: DelegateFunction.apply expects ${inputTypes.length} argument(s) but got " + + s"${inputs.length}") + require(inputs.forall(_.resolved), + s"$name: arguments to DelegateFunction.apply must be resolved; use it after analysis " + + "(e.g. 
in optimizer rules) with already-typed arguments, or register the function and " + + "let the analyzer build it") + DelegateExpression(name, inputs, lower(inputs)) + } + + def unapply(e: Expression): Option[Seq[Expression]] = e match { + case d: DelegateExpression if d.name == name => Some(d.inputs) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index b82ed39824f9d..e79df8d7c6148 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -161,70 +161,63 @@ object GetJsonObject { } /** - * Extracts multiple simple named paths from a JSON string in one parse. This is an internal - * expression used to share sibling [[GetJsonObject]] expressions; unsupported and - * prefix-conflicting JSON paths remain as independent GetJsonObject expressions. + * Builds the internal expression that extracts multiple simple named paths from a JSON string in + * one parse, used to share sibling [[GetJsonObject]] expressions; unsupported and + * prefix-conflicting paths remain as independent `GetJsonObject` expressions. + * + * It is inserted by `OptimizeCsvJsonExprs` (after analysis, so its inputs are resolved), and is the + * optimizer-constructed showcase for [[DelegateExpression]]: instead of hand-written + * eval/doGenCode, it builds a typed delegate directly -- the high-level call + * `multi_get_json_object(json, p1, .., pn)` stays visible via `inputs`, while the `definition` + * delegates evaluation to [[MultiGetJsonObjectEvaluator]] through an `Invoke`. No rewrite step: the + * delegate runs as-is. 
*/ -case class MultiGetJsonObject( - json: Expression, - fallbackPaths: Seq[String]) - extends UnaryExpression - with ExpectsInputTypes { - - // OptimizeCsvJsonExprs caps shared path depth to keep evaluator recursion stack-safe. - require(fallbackPaths.nonEmpty) - - override def child: Expression = json - - override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeWithCollation(supportsTrimCollation = true)) - - override lazy val dataType: DataType = StructType(fallbackPaths.indices.map { index => - StructField(s"_$index", StringType, nullable = true) - }) - - override def nullable: Boolean = true - - // This internal unary expression always returns null when its JSON child is null. - override def nullIntolerant: Boolean = true - - override def prettyName: String = "multi_get_json_object" - - final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT) - - @transient - private lazy val namedPaths = fallbackPaths.map { path => - GetJsonObject.simpleNamedPath(UTF8String.fromString(path)).getOrElse { - throw new IllegalArgumentException(s"Unsupported shared JSON path: $path") +object MultiGetJsonObject { + val name: String = "multi_get_json_object" + + def apply(json: Expression, fallbackPaths: Seq[String]): DelegateExpression = { + // OptimizeCsvJsonExprs caps shared path depth to keep evaluator recursion stack-safe. + require(fallbackPaths.nonEmpty) + val resultType = StructType(fallbackPaths.indices.map { index => + StructField(s"_$index", StringType, nullable = true) + }) + val namedPaths = fallbackPaths.map { path => + GetJsonObject.simpleNamedPath(UTF8String.fromString(path)).getOrElse { + throw new IllegalArgumentException(s"Unsupported shared JSON path: $path") + } } + val evaluator = + MultiGetJsonObjectEvaluator(fallbackPaths.map(UTF8String.fromString), namedPaths) + // `propagateNull = true` reproduces the old null-intolerant behavior: null json -> null result. 
+ val definition = Invoke( + Literal.create(evaluator, ObjectType(classOf[MultiGetJsonObjectEvaluator])), + "evaluate", + resultType, + Seq(json), + Seq(json.dataType), + propagateNull = true, + returnNullable = true) + // `inputs` keeps the high-level call visible: the json plus one string literal per path. + val pathInputs = fallbackPaths.map(p => Literal(UTF8String.fromString(p), StringType)) + DelegateExpression(name, json +: pathInputs, definition) } - @transient - private lazy val evaluator = MultiGetJsonObjectEvaluator( - fallbackPaths.map(UTF8String.fromString), - namedPaths) - - override def eval(input: InternalRow): Any = { - evaluator.evaluate(json.eval(input).asInstanceOf[UTF8String]) + /** + * Recovers `(json, fallbackPaths)` from a delegate produced by `apply`. Returns `None` for any + * other expression, including a same-named delegate that carries no inputs, so the extractor + * stays safe to use in a pattern match; a non-literal path input is still reported as an + * `IllegalStateException`. + */ + def unapply(e: Expression): Option[(Expression, Seq[String])] = e match { + case d: DelegateExpression if d.name == name && d.inputs.nonEmpty => + val paths = d.inputs.tail.map { + case Literal(p: UTF8String, _: StringType) => p.toString + case other => throw new IllegalStateException(s"Unexpected path input: $other") + } + Some((d.inputs.head, paths)) + case _ => None } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val refEvaluator = ctx.addReferenceObj("evaluator", evaluator) - val jsonEval = json.genCode(ctx) - val resultType = CodeGenerator.javaType(dataType) - ev.copy(code = code""" - |${jsonEval.code} - |boolean ${ev.isNull} = ${jsonEval.isNull}; - |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${ev.isNull}) { - | ${ev.value} = ($resultType) $refEvaluator.evaluate(${jsonEval.value}); - | ${ev.isNull} = ${ev.value} == null; - |} - |""".stripMargin) - } + def isInstance(e: Expression): Boolean = unapply(e).isDefined - override protected def withNewChildInternal(newChild: Expression): MultiGetJsonObject = - copy(json = newChild) + def pathsOf(e: Expression): Seq[String] = unapply(e) match { + case Some((_, paths)) => paths + case None => throw new 
IllegalArgumentException(s"Not a multi_get_json_object: $e") + } } // scalastyle:off line.size.limit line.contains.tab diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 66c4a39ce8233..a18336dcb19a5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2452,29 +2452,34 @@ case class Substring(str: Expression, pos: Expression, len: Expression) since = "2.3.0", group = "string_funcs") // scalastyle:on line.size.limit -case class Right(str: Expression, len: Expression) extends RuntimeReplaceable - with ImplicitCastInputTypes with BinaryLike[Expression] { - - override lazy val replacement: Expression = If( - IsNull(str), - Literal(null, str.dataType), - If( - LessThanOrEqual(len, Literal(0)), - Literal(UTF8String.EMPTY_UTF8, str.dataType), - new Substring(str, UnaryMinus(len, failOnError = false)) - ) - ) +object Right extends DelegateFunction { + override val name: String = "right" override def inputTypes: Seq[AbstractDataType] = - Seq( - StringTypeWithCollation(supportsTrimCollation = true), - IntegerType - ) - override def left: Expression = str - override def right: Expression = len - override protected def withNewChildrenInternal( - newLeft: Expression, newRight: Expression): Expression = { - copy(str = newLeft, len = newRight) + Seq(StringTypeWithCollation(supportsTrimCollation = true), IntegerType) + + // At build time `str` is the not-yet-coerced argument (wrapped in an `ImplicitCastInput` marker + // that delegates `dataType` to its child), so `str.dataType` is the *input* type, which is not + // necessarily a string yet -- e.g. `right(12345, 2)` has an `IntegerType` child the implicit cast + // will turn into a string. 
Use it for the null/empty branch literals only when it is already a + // string-family type, so a CHAR(N)/VARCHAR(N) result (under + // `spark.sql.preserveCharVarcharTypeInfo`) or a non-default collation is preserved through the + // `If` branch unification; otherwise fall back to plain `StringType`, the type the implicit cast + // produces. Typing a UTF8String literal with a non-string type would be invalid. + override def lower(args: Seq[Expression]): Expression = { + val str = args(0) + val len = args(1) + val litType = str.dataType match { + case _: StringType | _: CharType | _: VarcharType => str.dataType + case _ => StringType + } + If( + IsNull(str), + Literal(null, litType), + If( + LessThanOrEqual(len, Literal(0)), + Literal(UTF8String.EMPTY_UTF8, litType), + new Substring(str, UnaryMinus(len, failOnError = false)))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala index ff471cd6f00f8..a307814e25fcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala @@ -50,7 +50,7 @@ object NormalizePlan extends PredicateHelper { */ def normalizeExpressions(plan: LogicalPlan): LogicalPlan = { val withNormalizedRuntimeReplaceable = normalizeRuntimeReplaceable(plan) - withNormalizedRuntimeReplaceable.transformAllExpressions { + lazy val rule: PartialFunction[Expression, Expression] = { case subqueryExpression: SubqueryExpression => val normalizedPlan = normalizeExpressions(subqueryExpression.plan) subqueryExpression.withNewPlan(normalizedPlan) @@ -60,7 +60,16 @@ object NormalizePlan extends PredicateHelper { commonExpressionRef.copy(id = new CommonExpressionId(id = 0)) case expressionWithRandomSeed: ExpressionWithRandomSeed => expressionWithRandomSeed.withNewSeed(0) + case d: DelegateExpression => + // 
`inputs` are display-only metadata, not children, so `transformAllExpressions` never + // reaches them -- yet a `Rand` seed or `CommonExpressionId` there is just as + // run-dependent as in the `definition` child. Normalize them explicitly with the same + // rule (e.g. `right(rand(), 1)` under the Hybrid Analyzer, whose fixed-point and + // single-pass runs pick different seeds, would otherwise fail structural comparison). + // `definition` is reached by the surrounding traversal. + d.copy(inputs = d.inputs.map(_.transform(rule))) } + withNormalizedRuntimeReplaceable.transformAllExpressions(rule) } /** @@ -86,7 +95,15 @@ object NormalizePlan extends PredicateHelper { * we must normalize them to check if two different queries are identical. */ def normalizeExprIds(plan: LogicalPlan): LogicalPlan = { - plan.transformAllExpressions { + // Defined as a named rule (rather than inline) so it can also be applied to a + // `DelegateExpression`'s `inputs`, which are display-only metadata -- not children -- and so + // are never reached by `transformAllExpressions`. Normalizing them explicitly keeps the + // informational call deterministic across runs (e.g. `right(g#0, g#0)` in EXPLAIN), since expr + // ids come from a process-global counter; the `definition` child is reached by the normal + // traversal. 
+ lazy val rule: PartialFunction[Expression, Expression] = { + case d: DelegateExpression => + d.copy(inputs = d.inputs.map(_.transform(rule))) case s: ScalarSubquery => s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0)) case s: LateralSubquery => @@ -114,6 +131,7 @@ object NormalizePlan extends PredicateHelper { case a: FunctionTableSubqueryArgumentExpression => a.copy(plan = normalizeExprIds(a.plan), exprId = ExprId(0)) } + plan.transformAllExpressions(rule) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 173df28e2b248..ded9196c4544d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -49,6 +49,7 @@ object TreePattern extends Enumeration { val CREATE_NAMED_STRUCT: Value = Value val CURRENT_LIKE: Value = Value val DATETIME: Value = Value + val DELEGATE_EXPRESSION: Value = Value val DYNAMIC_PRUNING_EXPRESSION: Value = Value val DYNAMIC_PRUNING_SUBQUERY: Value = Value val EXISTS_SUBQUERY = Value @@ -64,6 +65,7 @@ object TreePattern extends Enumeration { val IN: Value = Value val IN_SUBQUERY: Value = Value val INSET: Value = Value + val INPUT_TYPE_MARKER: Value = Value val INVOKE: Value = Value val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index e40a6e6df88b5..d386762fd2c4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -94,6 +94,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L 
private def generateExpression( expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match { case literal: Literal => Some(translateLiteral(literal)) + // DelegateExpression is a Spark-internal wrapper; push down its real definition instead. + case d: DelegateExpression => generateExpression(d.definition, isPredicate) case _ if expr.contextIndependentFoldable && SQLConf.get.getConf(SQLConf.DATA_SOURCE_V2_EXPR_FOLDING) => // If the expression is context independent foldable, we can convert it to a literal. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 5dc3962821a03..83a8eae901016 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -127,6 +127,25 @@ package object util extends Logging { ), dataType = r.dataType ) + case d: DelegateExpression => + // `inputs` are display-only metadata, not children, so `transform` never descends into them. + // Render the high-level call with each input prettified (qualifiers stripped, string literals + // unquoted, ...) so generated column names match the pre-delegate function, e.g. the column + // name for `right(c7, 2)` stays `right(c7, 2)` rather than `right(spark_catalog....c7, 2)`. + // Like the `InheritAnalysisRules` branch above, pre-trim `TempResolvedColumn` from the + // inputs when requested: `usePrettyExpression` has no marker case, so `toPrettySQL` alone + // would leave `tempresolvedcolumn(...)` wrapping the child (e.g. an aggregate/HAVING alias + // would become `count(right(tempresolvedcolumn(v), 1))` instead of `count(right(v, 1))`). 
+ val prettyInputs = if (shouldTrimTempResolvedColumn) { + d.inputs.map(trimTempResolvedColumn) + } else { + d.inputs + } + PrettyAttribute( + name = s"${d.name}(${prettyInputs + .map(i => toPrettySQL(i, shouldTrimTempResolvedColumn)).mkString(", ")})", + dataType = d.dataType + ) case c: Cast if !c.containsTag(Cast.USER_SPECIFIED_CAST) => PrettyAttribute(usePrettyExpression(c.child, shouldTrimTempResolvedColumn).sql, c.dataType) case p: PythonFuncExpression => PrettyPythonUDF(p.name, p.dataType, p.children) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DelegateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DelegateExpressionSuite.scala new file mode 100644 index 0000000000000..e3ecff61578bb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DelegateExpressionSuite.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{RemoveInputTypeMarkers, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Validates [[DelegateExpression]] transparency (eval + codegen, via `checkEvaluation` which runs + * both paths) and that [[DelegateFunction]] supports all three input-type contracts. + */ +class DelegateExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + // ---- transparency: every behavior delegates to `definition` ---- + + test("delegates eval and codegen to its definition (foldable)") { + val expr = DelegateExpression("inc", Seq(Literal(10)), Add(Literal(10), Literal(1))) + checkEvaluation(expr, 11) + } + + test("delegates eval and codegen with a non-foldable input") { + val ref = BoundReference(0, IntegerType, nullable = true) + val expr = DelegateExpression("inc", Seq(ref), Add(ref, Literal(1))) + checkEvaluation(expr, 11, InternalRow(10)) + checkEvaluation(expr, null, InternalRow(null)) + } + + test("delegates type/nullability/foldability/determinism and canonicalizes to its definition") { + val ref = BoundReference(0, IntegerType, nullable = true) + val expr = DelegateExpression("inc", Seq(ref), Add(ref, Literal(1))) + assert(expr.dataType == IntegerType) + assert(expr.nullable) + assert(!expr.foldable) + assert(expr.deterministic) + assert(expr.canonicalized == Add(ref, Literal(1)).canonicalized) + assert(DelegateExpression("inc", Seq(Literal(10)), Add(Literal(10), Literal(1))).foldable) + } + + // ---- input-type contracts ---- + + private 
object CastFn extends DelegateFunction { + override val name = "castfn" + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def implicitCast: Boolean = true + override def lower(args: Seq[Expression]): Expression = args.head + } + + private object CheckFn extends DelegateFunction { + override val name = "checkfn" + override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) + override def implicitCast: Boolean = false + override def lower(args: Seq[Expression]): Expression = args.head + } + + private object AnyFn extends DelegateFunction { + override val name = "anyfn" + // no inputTypes -> accepts any type + override def lower(args: Seq[Expression]): Expression = args.head + } + + // The input-type contracts are the analysis-time `build` path, which inserts the markers. + private def buildDef(fn: DelegateFunction, args: Expression*): Expression = + fn.build(fn.name, args).asInstanceOf[DelegateExpression].definition + + test("contract 1 (implicit cast): args wrapped in ImplicitCastInput (cast happens via " + + "the standard coercion rule)") { + val shim = buildDef(CastFn, Literal(1)).asInstanceOf[ImplicitCastInput] + assert(shim.expectedType == StringType) + // The shim IS an ImplicitCastInputTypes node, so TypeCoercion will cast its child. + assert(shim.isInstanceOf[ImplicitCastInputTypes]) + } + + test("contract 2 (type check only): args wrapped in TypeCheckInput, mismatch rejected, " + + "no cast") { + val ok = buildDef(CheckFn, Literal(1)).asInstanceOf[TypeCheckInput] + assert(ok.checkInputDataTypes().isSuccess) + // A Long is NOT cast down to Int -- it is rejected. 
+ val bad = buildDef(CheckFn, Literal(1L)).asInstanceOf[TypeCheckInput] + assert(bad.checkInputDataTypes().isFailure) + assert(!bad.isInstanceOf[ImplicitCastInputTypes]) + } + + test("contract 3 (any type): no inputTypes -> no shim, arg passed through unchanged") { + assert(buildDef(AnyFn, Literal(1L)) == Literal(1L)) + } + + test("nullIntolerant is delegated to the definition") { + // An Invoke with the default propagateNull = true is null-intolerant; a bare Literal is not. + val invoke = Invoke(Literal("x"), "toString", StringType) + assert(invoke.nullIntolerant) + assert(DelegateExpression("f", Seq(Literal("x")), invoke).nullIntolerant, + "the wrapper should report its null-intolerant definition's null-intolerance") + assert(!DelegateExpression("g", Seq(Literal(1)), Literal(1)).nullIntolerant) + } + + test("build validates argument count against the inputTypes arity") { + // CheckFn declares one typed input, so build rejects any other arity with WRONG_NUM_ARGS rather + // than indexing past the args (too few) or silently ignoring extras (too many). + Seq(Seq.empty[Expression], Seq(Literal(1), Literal(2))).foreach { args => + val e = intercept[AnalysisException](CheckFn.build(CheckFn.name, args)) + assert(e.getCondition == "WRONG_NUM_ARGS.WITHOUT_SUGGESTION") + } + // AnyFn has no inputTypes -> it is variadic and `lower` owns the arg handling, so no arity + // check. + assert(AnyFn.build(AnyFn.name, Seq(Literal(1), Literal(2))).isInstanceOf[DelegateExpression]) + } + + test("a surviving marker reports transparently: delegate-call sql and the real argument index") { + // A marker only survives when its type check failed, and `CheckAnalysis` (bottom-up) reports it + // before the enclosing delegate. It must therefore look like the high-level call, not an + // internal shim: its `sql` is the delegate call, and its type-check error carries the true + // argument position rather than the marker's only-child index 0 ("first"). 
+ val marker = ImplicitCastInput( + Literal(1L), StringType, funcName = "castfn", argIndex = 1, callSql = "castfn(x, 1)") + assert(marker.sql == "castfn(x, 1)") + marker.checkInputDataTypes() match { + case m: DataTypeMismatch => + assert(m.messageParameters("paramIndex") == "second", + s"expected the real argument index, got ${m.messageParameters("paramIndex")}") + case other => fail(s"expected a DataTypeMismatch, got $other") + } + // With no supplied context (direct construction), `sql` falls back to the plain node rendering. + assert( + ImplicitCastInput(Literal(1L), StringType).sql == s"implicitcastinput(${Literal(1L).sql})") + } + + test("RemoveInputTypeMarkers keeps a failed type-check marker for CheckAnalysis to report") { + // A resolved marker has served its purpose and is unwrapped to its child ... + val okDelegate = CheckFn.build(CheckFn.name, Seq(Literal(1))) + assert(!RemoveInputTypeMarkers.removeMarkers(okDelegate).exists(_.isInstanceOf[TypeCheckInput]), + "a resolved TypeCheckInput should be unwrapped") + // ... but a type-mismatched (unresolved) marker is left in place, so its ExpectsInputTypes + // failure stays visible to CheckAnalysis instead of exposing a resolved child of a wrong type. 
+ val badDelegate = CheckFn.build(CheckFn.name, Seq(Literal(1L))) + val cleaned = RemoveInputTypeMarkers.removeMarkers(badDelegate) + assert(cleaned.exists(_.isInstanceOf[TypeCheckInput]), + s"a failed TypeCheckInput must be preserved for CheckAnalysis, got $cleaned") + } + + private object MixedFn extends DelegateFunction { + override val name = "mixedfn" + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, AnyDataType) + override def lower(args: Seq[Expression]): Expression = CreateArray(args) + } + + test("input-type contract is per argument: AnyDataType position opts out of shimming") { + val args = buildDef(MixedFn, Literal(1), Literal(2)).asInstanceOf[CreateArray].children + assert(args(0).isInstanceOf[ImplicitCastInput]) // StringType -> shimmed + assert(args(1) == Literal(2)) // AnyDataType -> raw + } + + test("apply (direct construction) inserts no markers; args must already be typed") { + // Unlike `build`, `apply` is construct-anywhere and never produces input-type markers. 
+ assert(CastFn(Literal("x")).definition == Literal("x")) + assert( + MixedFn(Literal("s"), Literal(2)).definition == CreateArray(Seq(Literal("s"), Literal(2)))) + } + + test("apply rejects unresolved arguments") { + intercept[IllegalArgumentException](CastFn(UnresolvedAttribute("x"))) + } + + // ---- definition is a real child (the safety property the whole design rests on) ---- + + test("transform reaches into the definition and withNewChildren replaces it") { + val ref = BoundReference(0, IntegerType, nullable = true) + val expr = DelegateExpression("inc", Seq(ref), Add(ref, Literal(1))) + // tree traversal descends into `definition` + val bumped = expr.transform { case Literal(1, IntegerType) => Literal(2) } + assert(bumped == DelegateExpression("inc", Seq(ref), Add(ref, Literal(2)))) + // withNewChildren swaps the single child, which is the definition + val replaced = expr.withNewChildren(Seq(Literal(99))).asInstanceOf[DelegateExpression] + assert(replaced.definition == Literal(99)) + assert(replaced.inputs == Seq(ref)) // inputs are metadata, untouched + } + + test("references come from the definition, not from inputs") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + // `b` appears only as display metadata; the real child references `a` + val expr = DelegateExpression("f", Seq(b), Add(a, Literal(1))) + assert(expr.references == AttributeSet(a)) + } + + test("sql and prettyName reflect the high-level call") { + val expr = DelegateExpression("myfunc", Seq(Literal(1), Literal("x")), Literal(0)) + assert(expr.prettyName == "myfunc") + assert(expr.sql == "myfunc(1, 'x')") + } + + test("DelegateFunction.unapply round-trips apply") { + assert(CastFn.unapply(CastFn(Literal("x"))).contains(Seq(Literal("x")))) + assert(AnyFn.unapply(Literal(1)).isEmpty) + } + + // ---- input-type markers are transient, unevaluable, and transparent in type ---- + + test("input-type markers are Unevaluable and delegate 
type/nullability to their child") { + val marker = ImplicitCastInput(BoundReference(0, IntegerType, nullable = true), StringType) + assert(marker.dataType == IntegerType) // delegates to child until coercion casts it + assert(marker.nullable) + intercept[Exception](marker.eval(InternalRow(1))) + } + + // ---- MultiGetJsonObject: the optimizer-constructed delegate ---- + + test("MultiGetJsonObject builds a typed delegate over an Invoke and round-trips paths") { + val e = MultiGetJsonObject(Literal("{}"), Seq("$.a", "$.b")) + assert(e.name == "multi_get_json_object") + assert(e.definition.isInstanceOf[Invoke]) + assert(e.dataType == StructType(Seq( + StructField("_0", StringType), StructField("_1", StringType)))) + assert(MultiGetJsonObject.unapply(e).map(_._2).contains(Seq("$.a", "$.b"))) + } + + test("MultiGetJsonObject evaluates by delegating through the Invoke") { + val e = MultiGetJsonObject(Literal("""{"a":1,"b":2}"""), Seq("$.a", "$.b")) + val row = e.eval(null).asInstanceOf[InternalRow] + assert(row.getUTF8String(0) == UTF8String.fromString("1")) + assert(row.getUTF8String(1) == UTF8String.fromString("2")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index cb551b772ef31..00aa3fc0775ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -207,10 +207,10 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { optimized match { case Project(projectList, Project(innerProjectList, _: LocalRelation)) => val sharedAlias = innerProjectList.collectFirst { - case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + case alias @ Alias(MultiGetJsonObject(_, _), "_shared_json_paths") => alias 
}.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) - val shared = sharedAlias.child.asInstanceOf[MultiGetJsonObject] - assert(shared.fallbackPaths == Seq("$.b", "$['a']")) + val shared = sharedAlias.child + assert(MultiGetJsonObject.pathsOf(shared) == Seq("$.b", "$['a']")) val sharedAttr = sharedAlias.toAttribute val extractedFields = projectList.flatMap(_.collect { @@ -254,10 +254,10 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { optimized match { case Project(projectList, Project(innerProjectList, _: LocalRelation)) => val sharedAlias = innerProjectList.collectFirst { - case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + case alias @ Alias(MultiGetJsonObject(_, _), "_shared_json_paths") => alias }.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) - val shared = sharedAlias.child.asInstanceOf[MultiGetJsonObject] - assert(shared.fallbackPaths == Seq("$.a.b", "$['a']['c.d']", "$.e", "$.f.b")) + val shared = sharedAlias.child + assert(MultiGetJsonObject.pathsOf(shared) == Seq("$.a.b", "$['a']['c.d']", "$.e", "$.f.b")) val sharedAttr = sharedAlias.toAttribute val extractedFields = projectList.flatMap(_.collect { @@ -286,11 +286,11 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val shared = optimized.collect { case Project(projectList, _) => projectList.collectFirst { - case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + case alias @ Alias(MultiGetJsonObject(_, _), "_shared_json_paths") => alias } }.flatten.headOption.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) - .child.asInstanceOf[MultiGetJsonObject] - assert(shared.fallbackPaths == Seq("$.a", "$.c.d", "$.e")) + .child + assert(MultiGetJsonObject.pathsOf(shared) == Seq("$.a", "$.c.d", "$.e")) val remainingPaths = optimized.expressions.flatMap(_.collect { case GetJsonObject(_, Literal(path: UTF8String, StringType)) => path.toString @@ -308,11 +308,11 @@ 
class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val shared = optimized.collect { case Project(projectList, _) => projectList.collectFirst { - case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + case alias @ Alias(MultiGetJsonObject(_, _), "_shared_json_paths") => alias } }.flatten.headOption.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) - .child.asInstanceOf[MultiGetJsonObject] - assert(shared.fallbackPaths == Seq("$.a.b", "$.a.c", "$.d")) + .child + assert(MultiGetJsonObject.pathsOf(shared) == Seq("$.a.b", "$.a.c", "$.d")) assert(optimized.expressions.exists(_.exists { case GetJsonObject(_, Literal(path: UTF8String, StringType)) => path.toString == "$.a" case _ => false @@ -336,7 +336,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { def assertSharedPaths(optimized: LogicalPlan): Unit = { val sharedPaths = optimized.collect { case Project(projectList, _) => projectList.collect { - case Alias(shared: MultiGetJsonObject, "_shared_json_paths") => shared.fallbackPaths + case Alias(MultiGetJsonObject(_, paths), "_shared_json_paths") => paths } }.flatten assert(sharedPaths == expectedSharedPaths) @@ -356,13 +356,13 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val shared = optimized.collect { case Project(projectList, _) => projectList.collectFirst { - case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + case alias @ Alias(MultiGetJsonObject(_, _), "_shared_json_paths") => alias } }.flatten.headOption.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) - .child.asInstanceOf[MultiGetJsonObject] - assert(shared.fallbackPaths.length == pathCount) - assert(shared.fallbackPaths.head == "$.field_0") - assert(shared.fallbackPaths.last == s"$$.field_${pathCount - 1}") + .child + assert(MultiGetJsonObject.pathsOf(shared).length == pathCount) + assert(MultiGetJsonObject.pathsOf(shared).head == "$.field_0") + 
assert(MultiGetJsonObject.pathsOf(shared).last == s"$$.field_${pathCount - 1}") } test("SPARK-47670: shared get_json_object paths survive project collapsing") { @@ -373,7 +373,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val optimized = OptimizerWithCollapseProject.execute(query.analyze) assert(optimized.exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) assert(optimized.collect { case _: Project => true }.length == 2) } @@ -386,7 +386,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { GetJsonObject($"json", Literal("$.a"))).as("a"), GetJsonObject($"json", Literal("$.b")).as("b")) assert(!Optimizer.execute(guardedQuery.analyze).exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) val lowerProject = testRelation2.select( @@ -394,7 +394,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { GetJsonObject($"json", Literal("$.b")).as("b")) val prunedQuery = lowerProject.select(lowerProject.output.head) assert(!OptimizerWithColumnPruning.execute(prunedQuery.analyze).exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) } @@ -419,7 +419,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { Pmod(Cast(GetJsonObject($"json", Literal("$.a")), IntegerType), Literal(0)).as("a"), GetJsonObject($"json", Literal("$.b")).as("b")) assert(!Optimizer.execute(query.analyze).exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala index 633b826a36949..63efa00a4d156 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala @@ -366,6 +366,25 @@ class NormalizePlanSuite extends SparkFunSuite with SQLConfHelper { assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan)) } + test("Normalize DelegateExpression display inputs (random seed and common-expr id)") { + // A DelegateExpression's `inputs` are display-only metadata, not children, so + // `transformAllExpressions` never reaches them. They still carry run-dependent state (a `Rand` + // seed, a `CommonExpressionId`) that must be normalized, or two structurally-identical + // delegates (e.g. the fixed-point vs single-pass renderings of `right(rand(), 1)`) compare + // unequal and trip a false HYBRID_ANALYZER_LOGICAL_PLAN_COMPARISON_MISMATCH. 
+ def delegateWith(seed: Long, id: CommonExpressionId): DelegateExpression = + DelegateExpression( + "f", + inputs = Seq(rand(seed), CommonExpressionRef(id, IntegerType, nullable = false)), + definition = Literal(1)) + + val baseline = LocalRelation().select(delegateWith(11L, new CommonExpressionId)) + val test = LocalRelation().select(delegateWith(22L, new CommonExpressionId)) + + assert(baseline != test) + assert(NormalizePlan(baseline) == NormalizePlan(test)) + } + test("Normalize UnionLoopRef IDs") { val col1 = $"col1".int val col2 = col1.newInstance() diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_right.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_right.explain index f8413c9deb725..efbdb104b432d 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_right.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_right.explain @@ -1,2 +1,2 @@ -Project [if (isnull(g#0)) null else if ((cast(g#0 as int) <= 0)) else substring(g#0, -cast(g#0 as int), 2147483647) AS right(g, g)#0] +Project [right(g#0, g#0) AS right(g, g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LowerDelegateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LowerDelegateExpression.scala new file mode 100644 index 0000000000000..4436a78a2f8d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LowerDelegateExpression.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.{DelegateExpression, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.DELEGATE_EXPRESSION + +/** + * Strips every [[DelegateExpression]] down to its `definition`. Run on the optimized logical plan + * in [[QueryExecution.createSparkPlan]] -- the single entry point to the planner, used by both the + * main query and AQE re-planning -- so the planner and every physical consumer (join-key + * extraction, V1 / cached-batch pushdown, columnar rules, codegen) see the real executed + * expression rather than the informational wrapper. Data source V2 pushdown runs earlier, in the + * logical optimizer, so it unfolds the wrapper separately in `V2ExpressionBuilder`. The wrapper + * remains in the optimized logical plan for EXPLAIN. + */ +object LowerDelegateExpression extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + plan.transformAllExpressionsWithPruning(_.containsPattern(DELEGATE_EXPRESSION)) { + case d: DelegateExpression => lower(d) + } + + // `definition` can itself be a [[DelegateExpression]] -- a delegate whose lowered form composes + // another delegate function. `transformDown` does not re-apply the rule to the replacement it + // just produced, so unwrap the chain here to guarantee no wrapper survives.
Delegates nested + // deeper (as children of `definition`) are handled by the surrounding tree traversal. + @scala.annotation.tailrec + private def lower(d: DelegateExpression): Expression = d.definition match { + case inner: DelegateExpression => lower(inner) + case other => other + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index e37588faf8fb2..ea96da8aef412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -817,7 +817,10 @@ object QueryExecution { plan: LogicalPlan): SparkPlan = { // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. - planner.plan(ReturnAnswer(plan)).next() + // Strip DelegateExpression to its definition right before planning, so the planner and every + // physical consumer (pushdown, columnar rules, codegen) see the real executed expression. The + // wrapper is purely informational and stays in the optimized logical plan for EXPLAIN.
+ planner.plan(ReturnAnswer(LowerDelegateExpression(plan))).next() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 9a483076ff567..97c851c08fc12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -814,7 +814,18 @@ case class AdaptiveSparkPlanExec( try { logicalPlan.invalidateStatsCache() val optimized = optimizer.execute(logicalPlan) - val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next() + // Strip DelegateExpression once here and reuse the SAME lowered tree for both planning and + // the returned logical plan. `createSparkPlan` is the single place that strips + // DelegateExpression before planning; feeding it the already-lowered tree makes its internal + // strip a no-op (the same instance is returned when no delegate remains), so the re-planned + // stage sees the real executed expression AND the physical plan's `logicalLink` targets point + // at `lowered`'s own nodes. 
Returning the unlowered `optimized` instead would leave those + // links pointing at lowered copies absent from the returned tree, so + // `replaceWithQueryStagesInLogicalPlan` could not match a completed stage back by reference + // equality and would lose that stage's statistics. + val lowered = LowerDelegateExpression(optimized) + val sparkPlan = QueryExecution.createSparkPlan( + context.session.sessionState.planner, lowered) val newPlan = applyPhysicalRules( applyQueryPostPlannerStrategyRules(sparkPlan), preprocessingRules ++ queryStagePreparationRules, @@ -833,7 +844,7 @@ case class AdaptiveSparkPlanExec( case _ => newPlan } - Some((finalPlan, optimized)) + Some((finalPlan, lowered)) } catch { case e: InvalidAQEPlanException[_] => logOnLevel(log"Re-optimize - ${MDC(ERROR, e.getMessage())}:\n" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 9c745461b1709..63150063622ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -930,21 +930,33 @@ private[sql] object DataSourceV2Strategy extends Logging { protected[sql] def rebuildExpressionFromFilter( predicate: Predicate, translatedFilterToExpr: mutable.HashMap[Predicate, Expression]): Expression = { - predicate match { - case and: V2And => - expressions.And( - rebuildExpressionFromFilter(and.left(), translatedFilterToExpr), - rebuildExpressionFromFilter(and.right(), translatedFilterToExpr)) - case or: V2Or => - expressions.Or( - rebuildExpressionFromFilter(or.left(), translatedFilterToExpr), - rebuildExpressionFromFilter(or.right(), translatedFilterToExpr)) - case not: V2Not => - expressions.Not(rebuildExpressionFromFilter(not.child(), translatedFilterToExpr)) - case _ => - 
translatedFilterToExpr.getOrElse(predicate, - throw SparkException.internalError( - "Failed to rebuild Expression for filter: " + predicate)) + // Prefer an exact map entry before structurally descending. A normal compound predicate is + // decomposed at translation time (only its leaves are mapped), so this lookup misses and we + // fall through to the structural cases below -- behavior unchanged. But a `DelegateExpression` + // whose `definition` is a compound (e.g. `And(a > 1, b < 2)`) is translated as a single leaf, + // so the whole `V2And`/`V2Or`/`V2Not` is mapped back to the delegate; matching the structural + // case first would then recurse into the synthetic children, which have no map entries, and + // throw. Checking the map first restores the original delegate directly. This is + // granularity-correct at every level of descent, so a delegate nested inside an ordinary + // compound is covered too. + translatedFilterToExpr.get(predicate) match { + case Some(expr) => expr + case None => + predicate match { + case and: V2And => + expressions.And( + rebuildExpressionFromFilter(and.left(), translatedFilterToExpr), + rebuildExpressionFromFilter(and.right(), translatedFilterToExpr)) + case or: V2Or => + expressions.Or( + rebuildExpressionFromFilter(or.left(), translatedFilterToExpr), + rebuildExpressionFromFilter(or.right(), translatedFilterToExpr)) + case not: V2Not => + expressions.Not(rebuildExpressionFromFilter(not.child(), translatedFilterToExpr)) + case _ => + throw SparkException.internalError( + "Failed to rebuild Expression for filter: " + predicate) + } } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/string-functions.sql.out index 15e6f3ada2668..932d5893888cb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/string-functions.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/string-functions.sql.out @@ -89,21 +89,21 @@ Project [left(abcd, -2) AS left(abcd, -2)#x, left(abcd, 0) AS left(abcd, 0)#x, l -- !query select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) -- !query analysis -Project [right(abcd, 2) AS right(abcd, 2)#x, right(abcd, 5) AS right(abcd, 5)#x, right(abcd, cast(2 as int)) AS right(abcd, 2)#x, right(abcd, cast(null as int)) AS right(abcd, NULL)#x] +Project [right(abcd, 2) AS right(abcd, 2)#x, right(abcd, 5) AS right(abcd, 5)#x, right(abcd, 2) AS right(abcd, 2)#x, right(abcd, null) AS right(abcd, NULL)#x] +- OneRowRelation -- !query select right(null, -2) -- !query analysis -Project [right(cast(null as string), -2) AS right(NULL, -2)#x] +Project [right(null, -2) AS right(NULL, -2)#x] +- OneRowRelation -- !query select right("abcd", -2), right("abcd", 0), right("abcd", 'a') -- !query analysis -Project [right(abcd, -2) AS right(abcd, -2)#x, right(abcd, 0) AS right(abcd, 0)#x, right(abcd, cast(a as int)) AS right(abcd, a)#x] +Project [right(abcd, -2) AS right(abcd, -2)#x, right(abcd, 0) AS right(abcd, 0)#x, right(abcd, a) AS right(abcd, a)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/text.sql.out index ef7b7a4180ba1..462c91244a036 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/text.sql.out @@ -134,7 +134,7 @@ Project [reverse(abcde) AS reverse(abcde)#x] select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i -- !query analysis Sort [i#xL ASC NULLS FIRST], true -+- Project [i#xL, left(ahoj, cast(i#xL as int)) AS left(ahoj, i)#x, right(ahoj, cast(i#xL as int)) AS right(ahoj, i)#x] ++- Project [i#xL, left(ahoj, cast(i#xL as int)) AS left(ahoj, i)#x, right(ahoj, 
i#xL) AS right(ahoj, i)#x] +- SubqueryAlias t +- Project [id#xL AS i#xL] +- Range (-5, 6, step=1) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out index 15e6f3ada2668..932d5893888cb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/string-functions.sql.out @@ -89,21 +89,21 @@ Project [left(abcd, -2) AS left(abcd, -2)#x, left(abcd, 0) AS left(abcd, 0)#x, l -- !query select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) -- !query analysis -Project [right(abcd, 2) AS right(abcd, 2)#x, right(abcd, 5) AS right(abcd, 5)#x, right(abcd, cast(2 as int)) AS right(abcd, 2)#x, right(abcd, cast(null as int)) AS right(abcd, NULL)#x] +Project [right(abcd, 2) AS right(abcd, 2)#x, right(abcd, 5) AS right(abcd, 5)#x, right(abcd, 2) AS right(abcd, 2)#x, right(abcd, null) AS right(abcd, NULL)#x] +- OneRowRelation -- !query select right(null, -2) -- !query analysis -Project [right(cast(null as string), -2) AS right(NULL, -2)#x] +Project [right(null, -2) AS right(NULL, -2)#x] +- OneRowRelation -- !query select right("abcd", -2), right("abcd", 0), right("abcd", 'a') -- !query analysis -Project [right(abcd, -2) AS right(abcd, -2)#x, right(abcd, 0) AS right(abcd, 0)#x, right(abcd, cast(a as int)) AS right(abcd, a)#x] +Project [right(abcd, -2) AS right(abcd, -2)#x, right(abcd, 0) AS right(abcd, 0)#x, right(abcd, a) AS right(abcd, a)#x] +- OneRowRelation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DelegateExpressionQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DelegateExpressionQuerySuite.scala new file mode 100644 index 0000000000000..483eeb71690e4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DelegateExpressionQuerySuite.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.analysis.resolver.ResolverRunner +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, DelegateExpression, ImplicitCastInput, Literal, MultiGetJsonObject, TypeCheckInput} +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.execution.{LowerDelegateExpression, WholeStageCodegenExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +/** + * End-to-end proof for the delegate-expression redesign: `right()` is built as a + * [[DelegateExpression]] -- a logical-phase wrapper that stays readable in the optimized plan and + * is lowered to its real definition (by `LowerDelegateExpression`) before physical execution, so + * the planner, pushdown, columnar rules and codegen see the actual executed expression. 
+ */ +class DelegateExpressionQuerySuite extends QueryTest with SharedSparkSession { + + test("right() is a DelegateExpression in the optimized plan, lowered before execution") { + val df = spark.range(0, 3).selectExpr("right(concat('row', cast(id as string)), 1) as r") + checkAnswer(df, Seq(Row("0"), Row("1"), Row("2"))) + + // Readable in the optimized logical plan ... + assert(df.queryExecution.optimizedPlan.exists( + _.expressions.exists(_.exists(_.isInstanceOf[DelegateExpression]))), + s"expected a DelegateExpression in the optimized plan:\n${df.queryExecution.optimizedPlan}") + // ... but lowered away before physical execution, so engines see the real expression. + val executed = df.queryExecution.executedPlan + assert(!executed.exists(_.expressions.exists(_.exists(_.isInstanceOf[DelegateExpression]))), + s"DelegateExpression should be lowered before execution:\n$executed") + assert(executed.exists(_.isInstanceOf[WholeStageCodegenExec]), + s"expected whole-stage codegen in the executed plan:\n$executed") + } + + test("right() implicit-casts a non-string arg via the standard coercion rule (no extra step)") { + // The old plain-form `right` was ImplicitCastInputTypes; the delegate form preserves this by + // wrapping the arg in an ImplicitCastInput shim that the standard TypeCoercion rule handles. + checkAnswer(spark.sql("SELECT right(12345, 2)"), Row("45")) + } + + test("internal input shims are stripped at the end of analysis") { + val df = spark.range(0, 3).selectExpr("right(concat('row', cast(id as string)), 2) as r") + val analyzed = df.queryExecution.analyzed + // The high-level delegate remains in the plan ... + assert(analyzed.exists(_.expressions.exists(_.exists(_.isInstanceOf[DelegateExpression])))) + // ... but the internal coercion shims are gone (they were inserted, then stripped). 
+ assert(!analyzed.exists(_.expressions.exists(_.exists(e => + e.isInstanceOf[ImplicitCastInput] || e.isInstanceOf[TypeCheckInput]))), + s"input shims should be stripped after analysis:\n$analyzed") + } + + test("right() produces identical results with whole-stage codegen on and off") { + Seq("true", "false").foreach { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + checkAnswer( + spark.range(0, 3).selectExpr("right(concat('row', cast(id as string)), 2) as r"), + Seq(Row("w0"), Row("w1"), Row("w2"))) + } + } + } + + test("optimizer-inserted MultiGetJsonObject is a delegate in the optimized plan, lowered " + + "before execution") { + import testImplicits._ + withSQLConf(SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> "true") { + val df = Seq("""{"a":1,"b":2}""").toDF("j") + .selectExpr("get_json_object(j, '$.a') as a", "get_json_object(j, '$.b') as b") + checkAnswer(df, Row("1", "2")) + + // The two sibling get_json_object calls were shared into one delegate, readable in the + // optimized plan ... + assert(df.queryExecution.optimizedPlan.exists( + _.expressions.exists(_.exists(MultiGetJsonObject.isInstance))), + s"expected a multi_get_json_object delegate in the optimized plan") + // ... and lowered to its Invoke definition before execution. + val executed = df.queryExecution.executedPlan + assert(!executed.exists(_.expressions.exists(_.exists(MultiGetJsonObject.isInstance))), + s"delegate should be lowered before execution:\n$executed") + assert(executed.exists(_.isInstanceOf[WholeStageCodegenExec])) + } + } + + test("right() resolves cleanly under the single-pass resolver (input-type markers stripped)") { + // The single-pass resolver builds DelegateFunctions through the same registry path (inserting + // the input-type markers) but has no fixed-point batch to strip them; FunctionResolver must + // remove them after coercion, else the Unevaluable markers would reach execution. 
We assert at + // the analyzed-plan level (where the fix lives): single-pass does not yet support the + // DeserializeToObject operator a typed `collect`/`checkAnswer` introduces, so the right() + // execution results stay covered by the fixed-point tests above. + withSQLConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true") { + // 12345 (int) exercises the ImplicitCastInput path: it must be cast to string. + val analyzed = spark.sql("SELECT right(12345, 2) AS r").queryExecution.analyzed + // single-pass actually ran (there is no fallback when the conf is on) ... + assert(analyzed.getTagValue(ResolverRunner.SINGLE_PASS_ANALYSIS_MARKER).contains(true), + s"expected single-pass analysis to run:\n$analyzed") + assert(analyzed.exists(_.expressions.exists(_.exists(_.isInstanceOf[DelegateExpression]))), + s"expected the right() delegate in the analyzed plan:\n$analyzed") + // ... the DelegateFunction's input-type markers were stripped ... + assert(!analyzed.exists(_.expressions.exists(_.exists(e => + e.isInstanceOf[ImplicitCastInput] || e.isInstanceOf[TypeCheckInput]))), + s"input shims should be stripped under single-pass:\n$analyzed") + // ... and the implicit cast the marker drove still applies (the marker was removed, not the + // Cast). + assert(analyzed.exists(_.expressions.exists(_.exists(_.isInstanceOf[Cast]))), + s"expected the implicit Cast to survive marker removal:\n$analyzed") + } + } + + test("LowerDelegateExpression fully unwraps a directly-nested delegate-of-delegate") { + // A delegate whose `definition` is itself a delegate (e.g. one delegate function composing + // another). transformDown does not re-apply the rule to the replacement it produces, so the + // rule must unwrap the chain itself -- otherwise the inner wrapper would reach the planner. 
+ val inner = DelegateExpression("inner", Seq(Literal(1)), Literal(1)) + val outer = DelegateExpression("outer", Seq(Literal(1)), inner) + val lowered = LowerDelegateExpression(Project(Seq(Alias(outer, "c")()), OneRowRelation())) + assert(!lowered.exists(_.expressions.exists(_.exists(_.isInstanceOf[DelegateExpression]))), + s"nested delegates should be fully lowered:\n$lowered") + } + + test("right() preserves the input column's collation in its output type") { + // `Right.lower` is said to build the null/empty `If` branches as plain StringType literals, + // with type coercion re-unifying them to the column's collation (literals are weakest). + // NOTE(review): the CHAR/VARCHAR test below says `str.dataType` is used instead -- confirm. + val df = spark.sql("SELECT right('Hello' COLLATE UTF8_LCASE, 3) AS r") + assert(df.schema("r").dataType === StringType("UTF8_LCASE"), + s"right() should preserve the UTF8_LCASE collation, got ${df.schema("r").dataType}") + checkAnswer(df, Row("llo")) + } + + test("right() preserves the input CHAR/VARCHAR type with preserveCharVarcharTypeInfo") { + // `Right.lower` types its null/empty `If` branch literals with `str.dataType` (the resolved + // input type the marker delegates), so the result keeps CHAR(N) instead of being widened to + // plain string when type coercion unifies the branches. + withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") { + checkAnswer(spark.sql("SELECT typeof(right(CAST('abc' AS CHAR(5)), 2)) AS t"), Row("char(5)")) + } + } + + test("right() rejects a wrong number of arguments with WRONG_NUM_ARGS") { + // `DelegateFunction.build` validates arity before lowering, so too few/too many arguments fail + // with the structured error rather than an IndexOutOfBounds or a silently ignored extra arg.
+ Seq("SELECT right('abcd')", "SELECT right('abcd', 1, 99)").foreach { q => + val e = intercept[AnalysisException](spark.sql(q)) + assert(e.getCondition == "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + s"unexpected error condition for `$q`: ${e.getCondition}") + } + } + + test("a delegate over a HAVING aggregate gets a clean generated name (TempResolvedColumn " + + "trimmed)") { + // SPARK-52385-style: in an aggregate/HAVING, `v` is wrapped in a `TempResolvedColumn` while it + // resolves against the grouping input. That wrapper rides in the delegate's display-only + // `inputs`, which the pretty-printer's `transform` never rewrites; without an explicit trim the + // generated column name would leak `right(tempresolvedcolumn(v), 1)` instead of `right(v, 1)`. + import testImplicits._ + withTempView("hav") { + Seq((1, 3), (1, 5)).toDF("k", "v").createOrReplaceTempView("hav") + val df = spark.sql("SELECT max(right(v, 1)) FROM hav HAVING max(right(v, 1)) IS NOT NULL") + val name = df.schema.fields.head.name + assert(!name.contains("tempresolvedcolumn"), + s"generated name should not leak the temp-resolution marker, got: $name") + assert(name == "max(right(v, 1))", s"unexpected generated name: $name") + } + } + + test("an un-castable argument is reported against the delegate call, not the internal marker") { + // A failed implicit cast leaves the `ImplicitCastInput` marker in the tree so `CheckAnalysis` + // (walking bottom-up) can reject it. The marker is an analysis-only detail, so it reports as if + // the check ran on the high-level `right(...)` call: the `sqlExpr` stays `right('abc', ...)` + // (not `implicitcastinput(...)`) and `paramIndex` is the real argument position (`second`), + // matching the pre-delegate `right` -- no user-facing change. 
+ checkError( + exception = intercept[AnalysisException](spark.sql("SELECT right('abc', array(1))")), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"right(abc, array(1))\"", + "paramIndex" -> "second", + "requiredType" -> "\"INT\"", + "inputSql" -> "\"array(1)\"", + "inputType" -> "\"ARRAY<INT>\""), + queryContext = Array(ExpectedContext( + fragment = "right('abc', array(1))", start = 7, stop = 28))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 396b86144f4da..8a103f2f52236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -27,8 +27,10 @@ import com.fasterxml.jackson.core.StreamReadConstraints import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, Literal, MultiGetJsonObject} +import org.apache.spark.sql.catalyst.expressions.{Expression, JsonToStructs, Literal, MultiGetJsonObject} import org.apache.spark.sql.catalyst.expressions.Cast._ +import org.apache.spark.sql.catalyst.expressions.json.MultiGetJsonObjectEvaluator +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.util.TimestampNanosTestUtils import org.apache.spark.sql.catalyst.util.TimestampNanosTestUtils.foreachNanosPrecision import org.apache.spark.sql.execution.{InputAdapter, SparkPlan, WholeStageCodegenExec} @@ -147,7 +149,7 @@ class JsonFunctionsSuite extends SharedSparkSession { get_json_object($"json", "$.d")) if (jsonOptimization && sharedParsing) { assert(query.queryExecution.optimizedPlan.exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) } rows =
query.collect().toSeq @@ -231,7 +233,7 @@ class JsonFunctionsSuite extends SharedSparkSession { if (sharedParsingEnabled) { assert(query.queryExecution.optimizedPlan.exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) } rows = query.collect().toSeq @@ -257,7 +259,7 @@ class JsonFunctionsSuite extends SharedSparkSession { get_json_object($"json", "$.b")) if (jsonOptimization) { assert(query.queryExecution.optimizedPlan.exists { plan => - plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + plan.expressions.exists(_.exists(MultiGetJsonObject.isInstance(_))) }) } rows = query.collect().toSeq @@ -310,10 +312,19 @@ class JsonFunctionsSuite extends SharedSparkSession { get_json_object($"json", "$.a.y")) checkAnswer(df, Row("1", "2")) + // The shared MultiGetJsonObject delegate is lowered before execution, so in the executed plan + // the shared extraction is its definition: an Invoke of MultiGetJsonObjectEvaluator. 
+ def isSharedEval(e: Expression): Boolean = e.exists { + case i: Invoke if i.functionName == "evaluate" => + i.targetObject match { + case Literal(_: MultiGetJsonObjectEvaluator, _) => true + case _ => false + } + case _ => false + } def containsSharedExtraction(plan: SparkPlan): Boolean = plan match { case _: InputAdapter => false - case other - if other.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) => true + case other if other.expressions.exists(isSharedEval) => true case other => other.children.exists(containsSharedExtraction) } assert(df.queryExecution.executedPlan.exists { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala index 0301c1d0f5baa..7a7a33914d72d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.mutable + import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -325,6 +327,32 @@ class DataSourceV2StrategySuite extends SharedSparkSession { Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType))))) } + test("SPARK-57512: V2 pushdown sees through a DelegateExpression wrapper") { + val pred = GreaterThan($"cint".int, Literal(1)) + val delegate = DelegateExpression("wrap", Seq($"cint".int, Literal(1)), pred) + // The wrapper is unfolded to its definition, so it pushes down exactly like the bare predicate. 
+ assert(DataSourceV2Strategy.translateFilterV2(delegate).isDefined) + assert(DataSourceV2Strategy.translateFilterV2(delegate) == + DataSourceV2Strategy.translateFilterV2(pred)) + } + + test("SPARK-57512: a compound-definition DelegateExpression round-trips through filter rebuild") { + val a = $"cint".int + val b = $"`c.int`".int + // The definition is a compound predicate, so the wrapper translates to a structural V2And. + val definition = And(GreaterThan(a, Literal(1)), LessThan(b, Literal(2))) + val delegate = DelegateExpression("wrap", Seq(a, b), definition) + val map = mutable.HashMap.empty[Predicate, Expression] + val translated = DataSourceV2Strategy.translateFilterV2WithMapping(delegate, Some(map)) + assert(translated.isDefined, "the compound delegate should translate via its definition") + // The whole V2And is mapped back to the delegate (it was translated as a single leaf). + // Rebuilding must restore the delegate via the exact map entry, not descend into the synthetic + // children that have no map entries -- descending would throw + // "Failed to rebuild Expression for filter". + val rebuilt = DataSourceV2Strategy.rebuildExpressionFromFilter(translated.get, map) + assert(rebuilt == delegate, s"expected the original delegate, got $rebuilt") + } + test("inability to convert unknown expressions and predicates") { val unknownExpr = new GeneralScalarExpression("UNKNOWN", Array()) assert(V2ExpressionUtils.toCatalyst(unknownExpr).isEmpty)