@@ -73,6 +73,163 @@ def _convert_java_rows(jdf):
     return [_convert_java_row(r, col_names) for r in jrows]


+# ---------------------------------------------------------------------------
+# Py4j / type-conversion helpers for createDataFrame and __getattr__
+# ---------------------------------------------------------------------------
+
+def _is_java_object(obj):
+    """Check if obj is a Py4j proxy."""
+    return hasattr(obj, '_get_object_id')
+
+
+def _is_java_dataset(obj):
+    """Check if a Py4j proxy represents a Spark Dataset."""
+    if not _is_java_object(obj):
+        return False
+    try:
+        return 'Dataset' in obj.getClass().getName()
+    except Exception:
+        return False
+
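+# Illustrative behavior of the two checks above ('df' and 'spark' are
+# hypothetical wrapper instances; '_get_object_id' is the attribute Py4j
+# sets on every JavaObject proxy):
+#
+#     _is_java_object("plain string")     # False
+#     _is_java_object(df._jdf)            # True for any Py4j proxy
+#     _is_java_dataset(spark._jsession)   # False: a SparkSession, not a Dataset
+#     _is_java_dataset(df._jdf)           # True for a Dataset<Row> proxy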
+
+_PYSPARK_TO_JAVA_TYPES = {
+    'StringType': 'StringType',
+    'IntegerType': 'IntegerType',
+    'LongType': 'LongType',
+    'DoubleType': 'DoubleType',
+    'FloatType': 'FloatType',
+    'BooleanType': 'BooleanType',
+    'ShortType': 'ShortType',
+    'ByteType': 'ByteType',
+    'DateType': 'DateType',
+    'TimestampType': 'TimestampType',
+    'BinaryType': 'BinaryType',
+    'NullType': 'NullType',
+}
+
+
+def _pyspark_type_to_java(dt):
+    """Convert a PySpark DataType instance to a Java DataType via Py4j gateway."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    type_name = type(dt).__name__
+    if type_name in _PYSPARK_TO_JAVA_TYPES:
+        return getattr(DataTypes, _PYSPARK_TO_JAVA_TYPES[type_name])
+    if type_name == 'DecimalType':
+        return DataTypes.createDecimalType(dt.precision, dt.scale)
+    if type_name == 'ArrayType':
+        return DataTypes.createArrayType(
+            _pyspark_type_to_java(dt.elementType), dt.containsNull)
+    if type_name == 'MapType':
+        return DataTypes.createMapType(
+            _pyspark_type_to_java(dt.keyType),
+            _pyspark_type_to_java(dt.valueType), dt.valueContainsNull)
+    if type_name == 'StructType':
+        return _pyspark_schema_to_java(dt)
+    # Unrecognized types deliberately fall back to StringType rather than failing.
+    return DataTypes.StringType
+
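+# Illustrative conversions (assumes the module-level `gateway` is connected;
+# the pyspark.sql.types import is needed only for the example):
+#
+#     from pyspark.sql.types import ArrayType, DecimalType
+#     _pyspark_type_to_java(DecimalType(10, 2))
+#     # -> Java DecimalType(10,2)
+#     _pyspark_type_to_java(ArrayType(DecimalType(10, 2), False))
+#     # -> Java ArrayType(DecimalType(10,2), containsNull=false)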
+
+def _pyspark_schema_to_java(pyspark_schema):
+    """Convert a PySpark StructType to a Java StructType."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    java_fields = gateway.jvm.java.util.ArrayList()
+    for field in pyspark_schema.fields:
+        jtype = _pyspark_type_to_java(field.dataType)
+        java_fields.add(DataTypes.createStructField(
+            field.name, jtype, getattr(field, 'nullable', True)))
+    return DataTypes.createStructType(java_fields)
+
+
+def _infer_java_type(value):
+    """Infer a Java DataType from a Python value."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    if value is None:
+        return DataTypes.StringType
+    if isinstance(value, bool):
+        return DataTypes.BooleanType
+    if isinstance(value, int):
+        return DataTypes.LongType if abs(value) > 2147483647 else DataTypes.IntegerType
+    if isinstance(value, float):
+        return DataTypes.DoubleType
+    return DataTypes.StringType
+
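+# The bool test deliberately precedes the int test: isinstance(True, int) is
+# True in Python, so the reverse order would map booleans to IntegerType.
+# Illustrative results:
+#
+#     _infer_java_type(True)     # BooleanType
+#     _infer_java_type(42)       # IntegerType
+#     _infer_java_type(2**40)    # LongType (outside the 32-bit int range)
+#     _infer_java_type(None)     # StringType fallback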
+
+def _resolve_schema(schema, data):
+    """Resolve any schema representation to a Java StructType."""
+    if schema is None:
+        return _infer_schema(data)
+    if _is_java_object(schema):
+        return schema  # already a Java StructType proxy
+    if hasattr(schema, 'fields'):
+        return _pyspark_schema_to_java(schema)
+    if isinstance(schema, str):
+        try:
+            return gateway.jvm.org.apache.spark.sql.types.StructType.fromDDL(schema)
+        except Exception:
+            raise ValueError("Cannot parse DDL schema: %s" % schema)
+    if isinstance(schema, (list, tuple)) and schema and isinstance(schema[0], str):
+        return _schema_from_names(schema, data)
+    raise ValueError("Unsupported schema type: %s" % type(schema).__name__)
+
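+# The accepted schema shapes, side by side (illustrative):
+#
+#     _resolve_schema(None, data)                    # infer from data[0]
+#     _resolve_schema(StructType([...]), data)       # PySpark StructType
+#     _resolve_schema("id INT, name STRING", data)   # DDL string
+#     _resolve_schema(["id", "name"], data)          # names, types from data
+#     _resolve_schema(java_struct_type, data)        # Py4j proxy passes through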
+
+def _infer_schema(data):
+    """Infer a Java StructType from the first element of the data."""
+    if not data:
+        raise ValueError("Cannot infer schema from empty data without a schema")
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    first = data[0]
+    if isinstance(first, Row):
+        names, values = list(first._fields), list(first)
+    elif isinstance(first, dict):
+        names, values = list(first.keys()), list(first.values())
+    elif isinstance(first, (list, tuple)):
+        names = ["_%d" % (i + 1) for i in range(len(first))]
+        values = list(first)
+    else:
+        names, values = ["value"], [first]
+    java_fields = gateway.jvm.java.util.ArrayList()
+    for i, name in enumerate(names):
+        java_fields.add(DataTypes.createStructField(
+            name, _infer_java_type(values[i] if i < len(values) else None), True))
+    return DataTypes.createStructType(java_fields)
+
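+# Illustrative inference from the first element only (later rows are not
+# inspected, so mixed-type columns keep the first row's type):
+#
+#     _infer_schema([{"id": 1, "name": "a"}])   # struct<id:int,name:string>
+#     _infer_schema([(1, "a")])                 # struct<_1:int,_2:string>
+#     _infer_schema([42])                       # struct<value:int>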
+
+def _schema_from_names(col_names, data):
+    """Create a Java StructType from column name list, inferring types from data."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    first = data[0] if data else None
+    java_fields = gateway.jvm.java.util.ArrayList()
+    for i, name in enumerate(col_names):
+        jtype = DataTypes.StringType
+        if first is not None:
+            val = None
+            if isinstance(first, (list, tuple)) and i < len(first):
+                val = first[i]
+            elif isinstance(first, dict):
+                val = first.get(name)
+            elif isinstance(first, Row) and i < len(first):
+                val = first[i]
+            if val is not None:
+                jtype = _infer_java_type(val)
+        java_fields.add(DataTypes.createStructField(name, jtype, True))
+    return DataTypes.createStructType(java_fields)
+
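+# Illustrative: names plus the first data row drive per-column inference;
+# columns absent from the first row default to StringType:
+#
+#     _schema_from_names(["id", "name"], [(1, "a")])
+#     # -> struct<id:int,name:string>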
+
+def _to_java_rows(data, col_names):
+    """Convert Python data (list of Row/dict/tuple/list) to Java ArrayList<Row>."""
+    RowFactory = gateway.jvm.org.apache.spark.sql.RowFactory
+    java_rows = gateway.jvm.java.util.ArrayList()
+    for item in data:
+        if isinstance(item, Row):
+            java_rows.add(RowFactory.create(*list(item)))
+        elif isinstance(item, dict):
+            java_rows.add(RowFactory.create(*[item.get(c) for c in col_names]))
+        elif isinstance(item, (list, tuple)):
+            java_rows.add(RowFactory.create(*list(item)))
+        else:
+            java_rows.add(RowFactory.create(item))
+    return java_rows
+
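+# Row shapes may be mixed within a single call (illustrative); dicts are
+# reordered to match col_names, Rows and tuples are taken positionally:
+#
+#     _to_java_rows([Row(id=1, name="a"),
+#                    {"name": "b", "id": 2},
+#                    (3, "c")], ["id", "name"])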
+
+
 class SparkConnectDataFrame(object):
     """Wrapper around a Java Dataset<Row> with production-safe data retrieval."""

@@ -253,6 +410,72 @@ def summary(self, *statistics):
     def isEmpty(self):
         return self._jdf.isEmpty()

+    def repartition(self, numPartitions, *cols):
+        if cols:
+            return SparkConnectDataFrame(self._jdf.repartition(numPartitions, *cols))
+        return SparkConnectDataFrame(self._jdf.repartition(numPartitions))
+
+    def coalesce(self, numPartitions):
+        return SparkConnectDataFrame(self._jdf.coalesce(numPartitions))
+
+    def toDF(self, *cols):
+        return SparkConnectDataFrame(self._jdf.toDF(*cols))
+
+    def unionByName(self, other, allowMissingColumns=False):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(
+            self._jdf.unionByName(other_jdf, allowMissingColumns))
+
+    def crossJoin(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.crossJoin(other_jdf))
+
+    def subtract(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.subtract(other_jdf))
+
+    def intersect(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.intersect(other_jdf))
+
+    def exceptAll(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.exceptAll(other_jdf))
+
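+    # These set-style wrappers accept either another SparkConnectDataFrame or
+    # a raw Java Dataset proxy (illustrative; `other_df` is hypothetical):
+    #
+    #     df.unionByName(other_df, allowMissingColumns=True)
+    #     df.subtract(other_df._jdf)   # raw Py4j proxy also works
+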
+    def sample(self, withReplacement=None, fraction=None, seed=None):
+        if withReplacement is None and fraction is None:
+            raise ValueError("fraction must be specified")
+        # PySpark-style shorthand: sample(fraction) implies no replacement.
+        if isinstance(withReplacement, float) and fraction is None:
+            fraction = withReplacement
+            withReplacement = False
+        if withReplacement is None:
+            withReplacement = False
+        if seed is not None:
+            return SparkConnectDataFrame(
+                self._jdf.sample(withReplacement, fraction, seed))
+        return SparkConnectDataFrame(
+            self._jdf.sample(withReplacement, fraction))
+
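+    # Both PySpark-style call shapes are supported (illustrative):
+    #
+    #     df.sample(0.1)             # fraction only, sampling without replacement
+    #     df.sample(True, 0.1, 42)   # withReplacement, fraction, seed
+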
+    def dropna(self, how="any", thresh=None, subset=None):
+        na = self._jdf.na()
+        if thresh is not None:
+            if subset:
+                return SparkConnectDataFrame(na.drop(thresh, subset))
+            return SparkConnectDataFrame(na.drop(thresh))
+        if subset:
+            return SparkConnectDataFrame(na.drop(how, subset))
+        return SparkConnectDataFrame(na.drop(how))
+
+    def fillna(self, value, subset=None):
+        na = self._jdf.na()
+        if subset:
+            return SparkConnectDataFrame(na.fill(value, subset))
+        return SparkConnectDataFrame(na.fill(value))
+
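+    # Illustrative usage; note that a Python list passed as `subset` relies on
+    # the gateway converting it to a collection the Java na functions accept:
+    #
+    #     df.dropna(how="all", subset=["id", "name"])
+    #     df.dropna(thresh=2)
+    #     df.fillna(0, subset=["id"])
+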
+    @property
+    def write(self):
+        return self._jdf.write()
+
     def __repr__(self):
         try:
             return "SparkConnectDataFrame[%s]" % ", ".join(
@@ -261,7 +484,15 @@ def __repr__(self):
             return "SparkConnectDataFrame[schema unavailable]"

     def __getattr__(self, name):
-        return getattr(self._jdf, name)
+        attr = getattr(self._jdf, name)
+        if not callable(attr):
+            return attr
+        def _method_wrapper(*args, **kwargs):
+            result = attr(*args, **kwargs)
+            # Re-wrap Java Dataset results so chained calls stay in this API.
+            if _is_java_dataset(result):
+                return SparkConnectDataFrame(result)
+            return result
+        return _method_wrapper
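+
+    # With the wrapper above, Dataset methods not defined on this class still
+    # come back wrapped (illustrative):
+    #
+    #     df.distinct()   # Java distinct() -> SparkConnectDataFrame
+    #     df.count()      # non-Dataset results pass through unchanged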

     def __iter__(self):
         """Safe iteration with default limit to prevent OOM."""
@@ -287,17 +518,34 @@ def read(self):
         return self._jsession.read()

     def createDataFrame(self, data, schema=None):
+        """Create a SparkConnectDataFrame from Python data.
+
+        Supports:
+        - data: list of Row, list of tuples, list of dicts, pandas DataFrame
+        - schema: PySpark StructType, list of column names, DDL string,
+          Java StructType (Py4j proxy), or None (infer from data)
+        """
         try:
             import pandas as pd
             if isinstance(data, pd.DataFrame):
-                warnings.warn(
-                    "createDataFrame from pandas goes through Py4j serialization. "
-                    "For large DataFrames, consider writing to a temp table instead.")
+                # Flatten pandas input to plain Python rows for the generic path.
+                if schema is None:
+                    schema = list(data.columns)
+                data = data.values.tolist()
         except ImportError:
             pass
-        if schema:
-            return SparkConnectDataFrame(self._jsession.createDataFrame(data, schema))
-        return SparkConnectDataFrame(self._jsession.createDataFrame(data))
+
+        if _is_java_object(data):
+            # Already a Java Dataset/collection; hand it to the JVM directly.
+            if schema is None:
+                return SparkConnectDataFrame(self._jsession.createDataFrame(data))
+            java_schema = _resolve_schema(schema, None)
+            return SparkConnectDataFrame(
+                self._jsession.createDataFrame(data, java_schema))
+
+        java_schema = _resolve_schema(schema, data)
+        col_names = [f.name() for f in java_schema.fields()]
+        java_rows = _to_java_rows(data, col_names)
+        return SparkConnectDataFrame(
+            self._jsession.createDataFrame(java_rows, java_schema))
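+
+    # Illustrative inputs createDataFrame now accepts (`spark` is a session
+    # wrapper instance, `pdf` a pandas DataFrame; both hypothetical here):
+    #
+    #     spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])
+    #     spark.createDataFrame([{"id": 1, "name": "a"}])
+    #     spark.createDataFrame(data, "id INT, name STRING")
+    #     spark.createDataFrame(pdf)   # pandas columns become the schema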

     def range(self, start, end=None, step=1, numPartitions=None):
         if end is None: