Skip to content

Commit eb93b5c

Browse files
Kontinuation authored and jiayuasu committed
[apacheGH-2565] Fix NULL handling for various aggregation functions in SedonaSpark (apache#2563)
1 parent e1525ff commit eb93b5c

2 files changed

Lines changed: 234 additions & 107 deletions

File tree

spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/AggregateFunctions.scala

Lines changed: 85 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -21,13 +21,14 @@ package org.apache.spark.sql.sedona_sql.expressions
2121
import org.apache.sedona.core.spatialRddTool.AdvancedStatCollector
2222
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
2323
import org.apache.sedona.common.Functions
24+
import org.apache.spark.sql.{Encoder, Encoders}
2425
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2526
import org.apache.spark.sql.expressions.Aggregator
2627
import org.apache.spark.sql.sedona_sql.utils.SparkCompatUtil
2728
import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}
2829
import org.apache.spark.sql.{Encoder, Row}
2930
import org.apache.spark.sql.sedona_sql.EncodersShim
30-
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
31+
import org.locationtech.jts.geom.{Coordinate, Envelope, Geometry, GeometryFactory}
3132
import org.locationtech.jts.operation.overlayng.OverlayNGRobust
3233

3334
import scala.collection.JavaConverters._
@@ -38,18 +39,7 @@ import scala.collection.mutable.ListBuffer
3839
*/
3940

4041
trait TraitSTAggregateExec {
41-
val initialGeometry: Geometry = {
42-
// dummy value for initial value(polygon but )
43-
// any other value is ok.
44-
val coordinates: Array[Coordinate] = new Array[Coordinate](5)
45-
coordinates(0) = new Coordinate(-999999999, -999999999)
46-
coordinates(1) = new Coordinate(-999999999, -999999999)
47-
coordinates(2) = new Coordinate(-999999999, -999999999)
48-
coordinates(3) = new Coordinate(-999999999, -999999999)
49-
coordinates(4) = coordinates(0)
50-
val geometryFactory = new GeometryFactory()
51-
geometryFactory.createPolygon(coordinates)
52-
}
42+
val initialGeometry: Geometry = null
5343
val serde = ExpressionEncoder[Geometry]()
5444

5545
def zero: Geometry = initialGeometry
@@ -68,7 +58,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
6858
val bufferSerde = ExpressionEncoder[ListBuffer[Geometry]]()
6959

7060
override def reduce(buffer: ListBuffer[Geometry], input: Geometry): ListBuffer[Geometry] = {
71-
buffer += input
61+
if (input != null) {
62+
buffer += input
63+
}
7264
if (buffer.size >= bufferSize) {
7365
// Perform the union when buffer size is reached
7466
val unionGeometry = OverlayNGRobust.union(buffer.asJava)
@@ -92,6 +84,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
9284
}
9385

9486
override def finish(reduction: ListBuffer[Geometry]): Geometry = {
87+
if (reduction.isEmpty) {
88+
return null
89+
}
9590
OverlayNGRobust.union(reduction.asJava)
9691
}
9792

@@ -103,81 +98,76 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
10398
}
10499

105100
/**
106-
* Return the envelope boundary of the entire column
101+
* A helper class to store envelope boundary during aggregation. We use this custom case class
102+
* instead of JTS Envelope to work with the Spark Encoder.
107103
*/
108-
private[apache] class ST_Envelope_Aggr
109-
extends Aggregator[Geometry, Geometry, Geometry]
110-
with TraitSTAggregateExec {
104+
case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {
105+
def isNull: Boolean = minX > maxX
111106

112-
def reduce(buffer: Geometry, input: Geometry): Geometry = {
113-
val accumulateEnvelope = buffer.getEnvelopeInternal
114-
val newEnvelope = input.getEnvelopeInternal
115-
val coordinates: Array[Coordinate] = new Array[Coordinate](5)
116-
var minX = 0.0
117-
var minY = 0.0
118-
var maxX = 0.0
119-
var maxY = 0.0
120-
if (accumulateEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
121-
// Found the accumulateEnvelope is the initial value
122-
minX = newEnvelope.getMinX
123-
minY = newEnvelope.getMinY
124-
maxX = newEnvelope.getMaxX
125-
maxY = newEnvelope.getMaxY
126-
} else if (newEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
127-
minX = accumulateEnvelope.getMinX
128-
minY = accumulateEnvelope.getMinY
129-
maxX = accumulateEnvelope.getMaxX
130-
maxY = accumulateEnvelope.getMaxY
107+
def toEnvelope: Envelope = {
108+
if (isNull) {
109+
new Envelope()
131110
} else {
132-
minX = Math.min(accumulateEnvelope.getMinX, newEnvelope.getMinX)
133-
minY = Math.min(accumulateEnvelope.getMinY, newEnvelope.getMinY)
134-
maxX = Math.max(accumulateEnvelope.getMaxX, newEnvelope.getMaxX)
135-
maxY = Math.max(accumulateEnvelope.getMaxY, newEnvelope.getMaxY)
111+
new Envelope(minX, maxX, minY, maxY)
136112
}
137-
coordinates(0) = new Coordinate(minX, minY)
138-
coordinates(1) = new Coordinate(minX, maxY)
139-
coordinates(2) = new Coordinate(maxX, maxY)
140-
coordinates(3) = new Coordinate(maxX, minY)
141-
coordinates(4) = coordinates(0)
142-
val geometryFactory = new GeometryFactory()
143-
geometryFactory.createPolygon(coordinates)
144-
145113
}
146114

147-
def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
148-
val leftEnvelope = buffer1.getEnvelopeInternal
149-
val rightEnvelope = buffer2.getEnvelopeInternal
150-
val coordinates: Array[Coordinate] = new Array[Coordinate](5)
151-
var minX = 0.0
152-
var minY = 0.0
153-
var maxX = 0.0
154-
var maxY = 0.0
155-
if (leftEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
156-
minX = rightEnvelope.getMinX
157-
minY = rightEnvelope.getMinY
158-
maxX = rightEnvelope.getMaxX
159-
maxY = rightEnvelope.getMaxY
160-
} else if (rightEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
161-
minX = leftEnvelope.getMinX
162-
minY = leftEnvelope.getMinY
163-
maxX = leftEnvelope.getMaxX
164-
maxY = leftEnvelope.getMaxY
115+
def merge(other: EnvelopeBuffer): EnvelopeBuffer = {
116+
if (this.isNull) {
117+
other
118+
} else if (other.isNull) {
119+
this
165120
} else {
166-
minX = Math.min(leftEnvelope.getMinX, rightEnvelope.getMinX)
167-
minY = Math.min(leftEnvelope.getMinY, rightEnvelope.getMinY)
168-
maxX = Math.max(leftEnvelope.getMaxX, rightEnvelope.getMaxX)
169-
maxY = Math.max(leftEnvelope.getMaxY, rightEnvelope.getMaxY)
121+
EnvelopeBuffer(
122+
math.min(this.minX, other.minX),
123+
math.max(this.maxX, other.maxX),
124+
math.min(this.minY, other.minY),
125+
math.max(this.maxY, other.maxY))
170126
}
127+
}
128+
}
171129

172-
coordinates(0) = new Coordinate(minX, minY)
173-
coordinates(1) = new Coordinate(minX, maxY)
174-
coordinates(2) = new Coordinate(maxX, maxY)
175-
coordinates(3) = new Coordinate(maxX, minY)
176-
coordinates(4) = coordinates(0)
177-
val geometryFactory = new GeometryFactory()
178-
geometryFactory.createPolygon(coordinates)
130+
/**
131+
* Return the envelope boundary of the entire column
132+
*/
133+
private[apache] class ST_Envelope_Aggr
134+
extends Aggregator[Geometry, Option[EnvelopeBuffer], Geometry] {
135+
136+
val serde = ExpressionEncoder[Geometry]()
137+
138+
def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
139+
if (input == null) return buffer
140+
val env = input.getEnvelopeInternal
141+
val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
142+
buffer match {
143+
case Some(b) => Some(b.merge(envBuffer))
144+
case None => Some(envBuffer)
145+
}
146+
}
147+
148+
def merge(
149+
buffer1: Option[EnvelopeBuffer],
150+
buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
151+
(buffer1, buffer2) match {
152+
case (Some(b1), Some(b2)) => Some(b1.merge(b2))
153+
case (Some(_), None) => buffer1
154+
case (None, Some(_)) => buffer2
155+
case (None, None) => None
156+
}
157+
}
158+
159+
def finish(reduction: Option[EnvelopeBuffer]): Geometry = {
160+
reduction match {
161+
case Some(b) => new GeometryFactory().toGeometry(b.toEnvelope)
162+
case None => null
163+
}
179164
}
180165

166+
def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]
167+
168+
def outputEncoder: ExpressionEncoder[Geometry] = serde
169+
170+
def zero: Option[EnvelopeBuffer] = None
181171
}
182172

183173
/**
@@ -187,16 +177,26 @@ private[apache] class ST_Intersection_Aggr
187177
extends Aggregator[Geometry, Geometry, Geometry]
188178
with TraitSTAggregateExec {
189179
def reduce(buffer: Geometry, input: Geometry): Geometry = {
190-
if (buffer.isEmpty) input
191-
else if (buffer.equalsExact(initialGeometry)) input
192-
else buffer.intersection(input)
180+
if (input == null) {
181+
return buffer
182+
}
183+
if (buffer == null) {
184+
return input
185+
}
186+
buffer.intersection(input)
193187
}
194188

195189
def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
196-
if (buffer1.equalsExact(initialGeometry)) buffer2
197-
else if (buffer2.equalsExact(initialGeometry)) buffer1
198-
else buffer1.intersection(buffer2)
190+
if (buffer1 == null) {
191+
return buffer2
192+
}
193+
if (buffer2 == null) {
194+
return buffer1
195+
}
196+
buffer1.intersection(buffer2)
199197
}
198+
199+
override def finish(out: Geometry): Geometry = out
200200
}
201201

202202
/**
@@ -225,7 +225,7 @@ private[apache] class ST_Collect_Agg
225225

226226
override def finish(reduction: ListBuffer[Geometry]): Geometry = {
227227
if (reduction.isEmpty) {
228-
new GeometryFactory().createGeometryCollection()
228+
null
229229
} else {
230230
Functions.createMultiGeometry(reduction.toArray)
231231
}
@@ -303,25 +303,3 @@ class ST_Analyze_Aggr extends Aggregator[Geometry, AdvancedStatCollector, Row] {
303303
StructField("mean_envelope_height", DoubleType, nullable = false),
304304
StructField("mean_envelope_area", DoubleType, nullable = false))))
305305
}
306-
307-
override def merge(
308-
buffer1: ListBuffer[Geometry],
309-
buffer2: ListBuffer[Geometry]): ListBuffer[Geometry] = {
310-
buffer1 ++= buffer2
311-
buffer1
312-
}
313-
314-
override def finish(reduction: ListBuffer[Geometry]): Geometry = {
315-
if (reduction.isEmpty) {
316-
new GeometryFactory().createGeometryCollection()
317-
} else {
318-
Functions.createMultiGeometry(reduction.toArray)
319-
}
320-
}
321-
322-
def bufferEncoder: ExpressionEncoder[ListBuffer[Geometry]] = bufferSerde
323-
324-
def outputEncoder: ExpressionEncoder[Geometry] = serde
325-
326-
override def zero: ListBuffer[Geometry] = ListBuffer.empty
327-
}

0 commit comments

Comments
 (0)