@@ -21,13 +21,14 @@ package org.apache.spark.sql.sedona_sql.expressions
2121import org .apache .sedona .core .spatialRddTool .AdvancedStatCollector
2222import org .apache .spark .sql .catalyst .encoders .{ExpressionEncoder , RowEncoder }
2323import org .apache .sedona .common .Functions
24+ import org .apache .spark .sql .{Encoder , Encoders }
2425import org .apache .spark .sql .catalyst .encoders .ExpressionEncoder
2526import org .apache .spark .sql .expressions .Aggregator
2627import org .apache .spark .sql .sedona_sql .utils .SparkCompatUtil
2728import org .apache .spark .sql .types .{DoubleType , LongType , StructField , StructType }
2829import org .apache .spark .sql .{Encoder , Row }
2930import org .apache .spark .sql .sedona_sql .EncodersShim
30- import org .locationtech .jts .geom .{Coordinate , Geometry , GeometryFactory }
31+ import org .locationtech .jts .geom .{Coordinate , Envelope , Geometry , GeometryFactory }
3132import org .locationtech .jts .operation .overlayng .OverlayNGRobust
3233
3334import scala .collection .JavaConverters ._
@@ -38,18 +39,7 @@ import scala.collection.mutable.ListBuffer
3839 */
3940
4041trait TraitSTAggregateExec {
41- val initialGeometry : Geometry = {
42- // dummy value for initial value(polygon but )
43- // any other value is ok.
44- val coordinates : Array [Coordinate ] = new Array [Coordinate ](5 )
45- coordinates(0 ) = new Coordinate (- 999999999 , - 999999999 )
46- coordinates(1 ) = new Coordinate (- 999999999 , - 999999999 )
47- coordinates(2 ) = new Coordinate (- 999999999 , - 999999999 )
48- coordinates(3 ) = new Coordinate (- 999999999 , - 999999999 )
49- coordinates(4 ) = coordinates(0 )
50- val geometryFactory = new GeometryFactory ()
51- geometryFactory.createPolygon(coordinates)
52- }
42+ val initialGeometry : Geometry = null
5343 val serde = ExpressionEncoder [Geometry ]()
5444
5545 def zero : Geometry = initialGeometry
@@ -68,7 +58,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
6858 val bufferSerde = ExpressionEncoder [ListBuffer [Geometry ]]()
6959
7060 override def reduce (buffer : ListBuffer [Geometry ], input : Geometry ): ListBuffer [Geometry ] = {
71- buffer += input
61+ if (input != null ) {
62+ buffer += input
63+ }
7264 if (buffer.size >= bufferSize) {
7365 // Perform the union when buffer size is reached
7466 val unionGeometry = OverlayNGRobust .union(buffer.asJava)
@@ -92,6 +84,9 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
9284 }
9385
/**
 * Produces the final aggregation result from the remaining buffered geometries.
 *
 * @param reduction
 *   the geometries still held in the buffer
 * @return
 *   null when the buffer holds no geometries, otherwise the result of OverlayNGRobust.union
 *   applied to the buffered geometries
 */
override def finish(reduction: ListBuffer[Geometry]): Geometry =
  if (reduction.isEmpty) null
  else OverlayNGRobust.union(reduction.asJava)
9792
@@ -103,81 +98,76 @@ private[apache] class ST_Union_Aggr(bufferSize: Int = 1000)
10398}
10499
/**
 * Holds the envelope boundary accumulated during aggregation. A custom case class of four
 * doubles is used instead of the JTS Envelope so that it works with the Spark Encoder.
 *
 * A buffer whose minX is greater than its maxX is considered null.
 */
case class EnvelopeBuffer(minX: Double, maxX: Double, minY: Double, maxY: Double) {

  /** Whether this buffer is null, i.e. its maxX lies below its minX. */
  def isNull: Boolean = maxX < minX

  /**
   * Converts this buffer to a JTS Envelope. A null buffer is converted to an Envelope created
   * with the no-argument constructor; otherwise the four bounds are passed through.
   */
  def toEnvelope: Envelope =
    if (isNull) new Envelope()
    else new Envelope(minX, maxX, minY, maxY)

  /**
   * Combines this buffer with another one.
   *
   * @param other
   *   the buffer to combine with
   * @return
   *   `other` when this buffer is null, this buffer when `other` is null, otherwise a new
   *   buffer spanning the smallest minimum and largest maximum on each axis
   */
  def merge(other: EnvelopeBuffer): EnvelopeBuffer =
    if (isNull) other
    else if (other.isNull) this
    else {
      val lowX = minX.min(other.minX)
      val highX = maxX.max(other.maxX)
      val lowY = minY.min(other.minY)
      val highY = maxY.max(other.maxY)
      EnvelopeBuffer(lowX, highX, lowY, highY)
    }
}
171129
172- coordinates(0 ) = new Coordinate (minX, minY)
173- coordinates(1 ) = new Coordinate (minX, maxY)
174- coordinates(2 ) = new Coordinate (maxX, maxY)
175- coordinates(3 ) = new Coordinate (maxX, minY)
176- coordinates(4 ) = coordinates(0 )
177- val geometryFactory = new GeometryFactory ()
178- geometryFactory.createPolygon(coordinates)
/**
 * Return the envelope boundary of the entire column.
 *
 * The aggregation buffer is an optional [[EnvelopeBuffer]]: it stays `None` until the first
 * non-null input geometry is seen, and the final result is null when it is still `None`.
 */
private[apache] class ST_Envelope_Aggr
    extends Aggregator[Geometry, Option[EnvelopeBuffer], Geometry] {

  val serde: ExpressionEncoder[Geometry] = ExpressionEncoder[Geometry]()

  /** The initial buffer: no envelope has been accumulated yet. */
  def zero: Option[EnvelopeBuffer] = None

  /**
   * Folds one input geometry into the buffer. A null input leaves the buffer unchanged;
   * otherwise the bounds of the input's internal envelope are merged into the buffer.
   */
  def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] =
    Option(input).fold(buffer) { geom =>
      val bounds = geom.getEnvelopeInternal
      val incoming =
        EnvelopeBuffer(bounds.getMinX, bounds.getMaxX, bounds.getMinY, bounds.getMaxY)
      Some(buffer.fold(incoming)(_.merge(incoming)))
    }

  /**
   * Combines two partial buffers. When both are defined their envelopes are merged; when only
   * one is defined it is returned as is; when neither is defined the result is `None`.
   */
  def merge(
      buffer1: Option[EnvelopeBuffer],
      buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] =
    (buffer1, buffer2) match {
      case (Some(left), Some(right)) => Some(left.merge(right))
      case _ => buffer1.orElse(buffer2)
    }

  /**
   * Turns the accumulated buffer into the output geometry via GeometryFactory.toGeometry, or
   * returns null when nothing was accumulated.
   */
  def finish(reduction: Option[EnvelopeBuffer]): Geometry =
    reduction.map(acc => new GeometryFactory().toGeometry(acc.toEnvelope)).orNull

  def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]

  def outputEncoder: ExpressionEncoder[Geometry] = serde
}
182172
183173/**
@@ -187,16 +177,26 @@ private[apache] class ST_Intersection_Aggr
187177 extends Aggregator [Geometry , Geometry , Geometry ]
188178 with TraitSTAggregateExec {
/**
 * Folds one input geometry into the running intersection.
 *
 * @param buffer
 *   the intersection accumulated so far; null when nothing has been accumulated
 * @param input
 *   the next geometry; a null input is ignored
 * @return
 *   `buffer` when `input` is null, `input` when `buffer` is null, otherwise the intersection
 *   of the two
 */
def reduce(buffer: Geometry, input: Geometry): Geometry =
  if (input == null) buffer
  else if (buffer == null) input
  else buffer.intersection(input)
194188
/**
 * Combines two partial intersection results.
 *
 * @param buffer1
 *   the first partial result; may be null
 * @param buffer2
 *   the second partial result; may be null
 * @return
 *   `buffer2` when `buffer1` is null, `buffer1` when `buffer2` is null, otherwise the
 *   intersection of the two
 */
def merge(buffer1: Geometry, buffer2: Geometry): Geometry =
  if (buffer1 == null) buffer2
  else if (buffer2 == null) buffer1
  else buffer1.intersection(buffer2)
198+
/** Returns the accumulated geometry unchanged as the final result. */
override def finish(out: Geometry): Geometry = {
  out
}
200200}
201201
202202/**
@@ -225,7 +225,7 @@ private[apache] class ST_Collect_Agg
225225
226226 override def finish (reduction : ListBuffer [Geometry ]): Geometry = {
227227 if (reduction.isEmpty) {
228- new GeometryFactory ().createGeometryCollection()
228+ null
229229 } else {
230230 Functions .createMultiGeometry(reduction.toArray)
231231 }
@@ -303,25 +303,3 @@ class ST_Analyze_Aggr extends Aggregator[Geometry, AdvancedStatCollector, Row] {
303303 StructField (" mean_envelope_height" , DoubleType , nullable = false ),
304304 StructField (" mean_envelope_area" , DoubleType , nullable = false ))))
305305}
306-
307- override def merge (
308- buffer1 : ListBuffer [Geometry ],
309- buffer2 : ListBuffer [Geometry ]): ListBuffer [Geometry ] = {
310- buffer1 ++= buffer2
311- buffer1
312- }
313-
314- override def finish (reduction : ListBuffer [Geometry ]): Geometry = {
315- if (reduction.isEmpty) {
316- new GeometryFactory ().createGeometryCollection()
317- } else {
318- Functions .createMultiGeometry(reduction.toArray)
319- }
320- }
321-
322- def bufferEncoder : ExpressionEncoder [ListBuffer [Geometry ]] = bufferSerde
323-
324- def outputEncoder : ExpressionEncoder [Geometry ] = serde
325-
326- override def zero : ListBuffer [Geometry ] = ListBuffer .empty
327- }
0 commit comments