Commit b9c2587

installing required packages
1 parent 78f213b commit b9c2587

3 files changed

Lines changed: 545 additions & 23 deletions

scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml

Lines changed: 35 additions & 9 deletions
@@ -4,29 +4,55 @@ channels:
   - defaults
 dependencies:
   - python >=3.9,<3.10
-  - pyspark=3.3.2
+  - pyspark=3.5
   - pycodestyle
-  - scipy
+  # --- Core data libraries ---
+  - pandas
   - numpy
+  - scipy
+  - pyarrow
+  # --- Spark Connect protocol ---
   - grpcio
   - protobuf
+  # --- HTTP / networking ---
+  - requests
+  - urllib3
+  # --- File format support ---
+  - openpyxl
+  - xlrd
+  - pyyaml
+  - tabulate
+  # --- GCP access ---
+  - google-cloud-storage
+  - google-auth
+  - gcsfs
+  # --- Visualization ---
+  - matplotlib
+  - seaborn
+  - plotly
+  - plotnine
+  - altair
+  - vega_datasets
+  - hvplot
+  # --- SQL on pandas ---
   - pandasql
+  # --- ML ---
+  - scikit-learn
+  - xgboost
+  # --- IPython / kernel ---
   - ipython
   - ipykernel
   - jupyter_client
-  - hvplot
-  - plotnine
-  - seaborn
+  # --- Data connectors ---
   - intake
   - intake-parquet
   - intake-xarray
-  - altair
-  - vega_datasets
-  - plotly
+  # --- pip-only packages ---
   - pip
   - pip:
-      # works for regular pip packages
       - bkzep==0.6.1
+      - delta-spark==3.2.1
+  # --- R support ---
   - r-base=3
   - r-data.table
   - r-evaluate
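
The environment can be rebuilt with conda env create -f scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml. Below is a minimal sanity check for the packages this commit adds, runnable inside the rebuilt environment; note that some import names differ from the package names (scikit-learn imports as sklearn, delta-spark as delta):

# Minimal sanity check that this commit's new dependencies resolve.
# Run inside the rebuilt environment.
import importlib

for mod in ("pandas", "pyarrow", "requests", "openpyxl", "gcsfs",
            "matplotlib", "sklearn", "xgboost", "delta"):
    importlib.import_module(mod)
    print("%s: OK" % mod)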

spark-connect/src/main/resources/python/zeppelin_isparkconnect.py

Lines changed: 255 additions & 7 deletions
@@ -73,6 +73,163 @@ def _convert_java_rows(jdf):
     return [_convert_java_row(r, col_names) for r in jrows]
 
 
+# ---------------------------------------------------------------------------
+# Py4j / type-conversion helpers for createDataFrame and __getattr__
+# ---------------------------------------------------------------------------
+
+def _is_java_object(obj):
+    """Check if obj is a Py4j proxy."""
+    return hasattr(obj, '_get_object_id')
+
+
+def _is_java_dataset(obj):
+    """Check if a Py4j proxy represents a Spark Dataset."""
+    if not _is_java_object(obj):
+        return False
+    try:
+        return 'Dataset' in obj.getClass().getName()
+    except Exception:
+        return False
+
+
+_PYSPARK_TO_JAVA_TYPES = {
+    'StringType': 'StringType',
+    'IntegerType': 'IntegerType',
+    'LongType': 'LongType',
+    'DoubleType': 'DoubleType',
+    'FloatType': 'FloatType',
+    'BooleanType': 'BooleanType',
+    'ShortType': 'ShortType',
+    'ByteType': 'ByteType',
+    'DateType': 'DateType',
+    'TimestampType': 'TimestampType',
+    'BinaryType': 'BinaryType',
+    'NullType': 'NullType',
+}
+
+
+def _pyspark_type_to_java(dt):
+    """Convert a PySpark DataType instance to a Java DataType via Py4j gateway."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    type_name = type(dt).__name__
+    if type_name in _PYSPARK_TO_JAVA_TYPES:
+        return getattr(DataTypes, _PYSPARK_TO_JAVA_TYPES[type_name])
+    if type_name == 'DecimalType':
+        return DataTypes.createDecimalType(dt.precision, dt.scale)
+    if type_name == 'ArrayType':
+        return DataTypes.createArrayType(
+            _pyspark_type_to_java(dt.elementType), dt.containsNull)
+    if type_name == 'MapType':
+        return DataTypes.createMapType(
+            _pyspark_type_to_java(dt.keyType),
+            _pyspark_type_to_java(dt.valueType), dt.valueContainsNull)
+    if type_name == 'StructType':
+        return _pyspark_schema_to_java(dt)
+    return DataTypes.StringType
+
+
+def _pyspark_schema_to_java(pyspark_schema):
+    """Convert a PySpark StructType to a Java StructType."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    java_fields = gateway.jvm.java.util.ArrayList()
+    for field in pyspark_schema.fields:
+        jtype = _pyspark_type_to_java(field.dataType)
+        java_fields.add(DataTypes.createStructField(
+            field.name, jtype, getattr(field, 'nullable', True)))
+    return DataTypes.createStructType(java_fields)
+
+
+def _infer_java_type(value):
+    """Infer a Java DataType from a Python value."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    if value is None:
+        return DataTypes.StringType
+    if isinstance(value, bool):
+        return DataTypes.BooleanType
+    if isinstance(value, int):
+        return DataTypes.LongType if abs(value) > 2147483647 else DataTypes.IntegerType
+    if isinstance(value, float):
+        return DataTypes.DoubleType
+    return DataTypes.StringType
+
+
+def _resolve_schema(schema, data):
+    """Resolve any schema representation to a Java StructType."""
+    if schema is None:
+        return _infer_schema(data)
+    if _is_java_object(schema):
+        return schema
+    if hasattr(schema, 'fields') and not _is_java_object(schema):
+        return _pyspark_schema_to_java(schema)
+    if isinstance(schema, str):
+        try:
+            return gateway.jvm.org.apache.spark.sql.types.StructType.fromDDL(schema)
+        except Exception:
+            raise ValueError("Cannot parse DDL schema: %s" % schema)
+    if isinstance(schema, (list, tuple)) and schema and isinstance(schema[0], str):
+        return _schema_from_names(schema, data)
+    raise ValueError("Unsupported schema type: %s" % type(schema).__name__)
+
+
+def _infer_schema(data):
+    """Infer a Java StructType from the first element of the data."""
+    if not data:
+        raise ValueError("Cannot infer schema from empty data without a schema")
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    first = data[0]
+    if isinstance(first, Row):
+        names, values = list(first._fields), list(first)
+    elif isinstance(first, dict):
+        names, values = list(first.keys()), list(first.values())
+    elif isinstance(first, (list, tuple)):
+        names = ["_%d" % (i + 1) for i in range(len(first))]
+        values = list(first)
+    else:
+        names, values = ["value"], [first]
+    java_fields = gateway.jvm.java.util.ArrayList()
+    for i, name in enumerate(names):
+        java_fields.add(DataTypes.createStructField(
+            name, _infer_java_type(values[i] if i < len(values) else None), True))
+    return DataTypes.createStructType(java_fields)
+
+
+def _schema_from_names(col_names, data):
+    """Create a Java StructType from column name list, inferring types from data."""
+    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
+    first = data[0] if data else None
+    java_fields = gateway.jvm.java.util.ArrayList()
+    for i, name in enumerate(col_names):
+        jtype = DataTypes.StringType
+        if first is not None:
+            val = None
+            if isinstance(first, (list, tuple)) and i < len(first):
+                val = first[i]
+            elif isinstance(first, dict):
+                val = first.get(name)
+            elif isinstance(first, Row) and i < len(first):
+                val = first[i]
+            if val is not None:
+                jtype = _infer_java_type(val)
+        java_fields.add(DataTypes.createStructField(name, jtype, True))
+    return DataTypes.createStructType(java_fields)
+
+
+def _to_java_rows(data, col_names):
+    """Convert Python data (list of Row/dict/tuple/list) to Java ArrayList<Row>."""
+    RowFactory = gateway.jvm.org.apache.spark.sql.RowFactory
+    java_rows = gateway.jvm.java.util.ArrayList()
+    for item in data:
+        if isinstance(item, Row):
+            java_rows.add(RowFactory.create(*list(item)))
+        elif isinstance(item, dict):
+            java_rows.add(RowFactory.create(*[item.get(c) for c in col_names]))
+        elif isinstance(item, (list, tuple)):
+            java_rows.add(RowFactory.create(*list(item)))
+        else:
+            java_rows.add(RowFactory.create(item))
+    return java_rows
+
+
 class SparkConnectDataFrame(object):
     """Wrapper around a Java Dataset<Row> with production-safe data retrieval."""
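
For orientation, a minimal sketch of how these helpers compose. It is illustrative only, not standalone: it assumes it runs inside the interpreter process, where gateway (the Py4j gateway), Row, and the functions above are already in scope:

# Illustrative sketch only; assumes the interpreter process where `gateway`
# and the helpers above are defined. Shows the schema paths of _resolve_schema.
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

py_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
])
j1 = _resolve_schema(py_schema, None)               # PySpark -> Java StructType
j2 = _resolve_schema("id INT, name STRING", None)   # DDL string via fromDDL
j3 = _resolve_schema(None, [(1, "a"), (2, "b")])    # inferred; columns _1, _2
rows = _to_java_rows([(1, "a"), (2, "b")], ["_1", "_2"])  # ArrayList of Java Rows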

@@ -253,6 +410,72 @@ def summary(self, *statistics):
     def isEmpty(self):
         return self._jdf.isEmpty()
 
+    def repartition(self, numPartitions, *cols):
+        if cols:
+            return SparkConnectDataFrame(self._jdf.repartition(numPartitions, *cols))
+        return SparkConnectDataFrame(self._jdf.repartition(numPartitions))
+
+    def coalesce(self, numPartitions):
+        return SparkConnectDataFrame(self._jdf.coalesce(numPartitions))
+
+    def toDF(self, *cols):
+        return SparkConnectDataFrame(self._jdf.toDF(*cols))
+
+    def unionByName(self, other, allowMissingColumns=False):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(
+            self._jdf.unionByName(other_jdf, allowMissingColumns))
+
+    def crossJoin(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.crossJoin(other_jdf))
+
+    def subtract(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.subtract(other_jdf))
+
+    def intersect(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.intersect(other_jdf))
+
+    def exceptAll(self, other):
+        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
+        return SparkConnectDataFrame(self._jdf.exceptAll(other_jdf))
+
+    def sample(self, withReplacement=None, fraction=None, seed=None):
+        if withReplacement is None and fraction is None:
+            raise ValueError("fraction must be specified")
+        if isinstance(withReplacement, float) and fraction is None:
+            fraction = withReplacement
+            withReplacement = False
+        if withReplacement is None:
+            withReplacement = False
+        if seed is not None:
+            return SparkConnectDataFrame(
+                self._jdf.sample(withReplacement, fraction, seed))
+        return SparkConnectDataFrame(
+            self._jdf.sample(withReplacement, fraction))
+
+    def dropna(self, how="any", thresh=None, subset=None):
+        na = self._jdf.na()
+        if thresh is not None:
+            if subset:
+                return SparkConnectDataFrame(na.drop(thresh, subset))
+            return SparkConnectDataFrame(na.drop(thresh))
+        if subset:
+            return SparkConnectDataFrame(na.drop(how, subset))
+        return SparkConnectDataFrame(na.drop(how))
+
+    def fillna(self, value, subset=None):
+        na = self._jdf.na()
+        if subset:
+            return SparkConnectDataFrame(na.fill(value, subset))
+        return SparkConnectDataFrame(na.fill(value))
+
+    @property
+    def write(self):
+        return self._jdf.write()
+
     def __repr__(self):
         try:
             return "SparkConnectDataFrame[%s]" % ", ".join(
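
Because each new method re-wraps its result, transformations chain without leaking raw Py4j proxies. A hypothetical notebook paragraph, assuming spark is the session object the %spark-connect interpreter binds and other is another SparkConnectDataFrame:

# Hypothetical usage; `spark` and `other` are assumed to be provided by the
# interpreter. Every call below returns a SparkConnectDataFrame.
df = spark.createDataFrame([(1, "a"), (2, "b"), (3, None)], ["id", "label"])
sampled = df.sample(0.5, seed=42)           # positional-float form: fraction=0.5
merged = df.unionByName(other, allowMissingColumns=True)
cleaned = df.dropna(how="any").coalesce(1)  # dropna delegates to Dataset.na()
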
@@ -261,7 +484,15 @@ def __repr__(self):
             return "SparkConnectDataFrame[schema unavailable]"
 
     def __getattr__(self, name):
-        return getattr(self._jdf, name)
+        attr = getattr(self._jdf, name)
+        if not callable(attr):
+            return attr
+        def _method_wrapper(*args, **kwargs):
+            result = attr(*args, **kwargs)
+            if _is_java_dataset(result):
+                return SparkConnectDataFrame(result)
+            return result
+        return _method_wrapper
 
     def __iter__(self):
         """Safe iteration with default limit to prevent OOM."""
@@ -287,17 +518,34 @@ def read(self):
         return self._jsession.read()
 
     def createDataFrame(self, data, schema=None):
+        """Create a SparkConnectDataFrame from Python data.
+
+        Supports:
+        - data: list of Row, list of tuples, list of dicts, pandas DataFrame
+        - schema: PySpark StructType, list of column names, DDL string,
+          Java StructType (Py4j proxy), or None (infer from data)
+        """
         try:
             import pandas as pd
             if isinstance(data, pd.DataFrame):
-                warnings.warn(
-                    "createDataFrame from pandas goes through Py4j serialization. "
-                    "For large DataFrames, consider writing to a temp table instead.")
+                if schema is None:
+                    schema = list(data.columns)
+                data = data.values.tolist()
         except ImportError:
             pass
-        if schema:
-            return SparkConnectDataFrame(self._jsession.createDataFrame(data, schema))
-        return SparkConnectDataFrame(self._jsession.createDataFrame(data))
+
+        if _is_java_object(data):
+            if schema is None:
+                return SparkConnectDataFrame(self._jsession.createDataFrame(data))
+            java_schema = _resolve_schema(schema, None)
+            return SparkConnectDataFrame(
+                self._jsession.createDataFrame(data, java_schema))
+
+        java_schema = _resolve_schema(schema, data)
+        col_names = [f.name() for f in java_schema.fields()]
+        java_rows = _to_java_rows(data, col_names)
+        return SparkConnectDataFrame(
+            self._jsession.createDataFrame(java_rows, java_schema))
 
     def range(self, start, end=None, step=1, numPartitions=None):
         if end is None:
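
Taken together, the rewrite lets createDataFrame accept the combinations listed in its docstring. A sketch of each form, again assuming spark is the bound session (pandas comes from the env file above):

# Sketch of the accepted data/schema combinations.
import pandas as pd
from pyspark.sql.types import StructType, StructField, LongType, StringType

spark.createDataFrame([(1, "a")], ["id", "label"])          # names + type inference
spark.createDataFrame([{"id": 1, "label": "a"}])            # dicts, schema inferred
spark.createDataFrame([(1, "a")], "id INT, label STRING")   # DDL string
spark.createDataFrame([(1, "a")], StructType([              # PySpark StructType
    StructField("id", LongType(), True),
    StructField("label", StringType(), True)]))
spark.createDataFrame(pd.DataFrame({"id": [1], "label": ["a"]}))  # pandas path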
