@@ -597,6 +597,9 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
597597)
598598
599599
600+ _MODEL_CHUNK_SIZE = 4096 * 1024
601+
602+
600603class _SparkXGBEstimator (Estimator , _SparkXGBParams , MLReadable , MLWritable ):
601604 _input_kwargs : Dict [str , Any ]
602605
@@ -1091,25 +1094,27 @@ def _train_booster(
10911094 context .barrier ()
10921095
10931096 if context .partitionId () == 0 :
1094- yield pd .DataFrame (
1095- data = {
1096- "config" : [booster .save_config ()],
1097- "booster" : [booster .save_raw ("json" ).decode ("utf-8" )],
1098- }
1099- )
1097+ config = booster .save_config ()
1098+ yield pd .DataFrame ({"data" : [config ]})
1099+ booster_json = booster .save_raw ("json" ).decode ("utf-8" )
1100+
1101+ for offset in range (0 , len (booster_json ), _MODEL_CHUNK_SIZE ):
1102+ booster_chunk = booster_json [offset : offset + _MODEL_CHUNK_SIZE ]
1103+ yield pd .DataFrame ({"data" : [booster_chunk ]})
11001104
11011105 def _run_job () -> Tuple [str , str ]:
11021106 rdd = (
11031107 dataset .mapInPandas (
11041108 _train_booster , # type: ignore
1105- schema = "config string, booster string" ,
1109+ schema = "data string" ,
11061110 )
11071111 .rdd .barrier ()
11081112 .mapPartitions (lambda x : x )
11091113 )
11101114 rdd_with_resource = self ._try_stage_level_scheduling (rdd )
1111- ret = rdd_with_resource .collect ()[0 ]
1112- return ret [0 ], ret [1 ]
1115+ ret = rdd_with_resource .collect ()
1116+ data = [v [0 ] for v in ret ]
1117+ return data [0 ], "" .join (data [1 :])
11131118
11141119 get_logger (_LOG_TAG ).info (
11151120 "Running xgboost-%s on %s workers with"
@@ -1690,7 +1695,12 @@ def saveImpl(self, path: str) -> None:
16901695 _SparkXGBSharedReadWrite .saveMetadata (self .instance , path , self .sc , self .logger )
16911696 model_save_path = os .path .join (path , "model" )
16921697 booster = xgb_model .get_booster ().save_raw ("json" ).decode ("utf-8" )
1693- _get_spark_session ().sparkContext .parallelize ([booster ], 1 ).saveAsTextFile (
1698+ booster_chunks = []
1699+
1700+ for offset in range (0 , len (booster ), _MODEL_CHUNK_SIZE ):
1701+ booster_chunks .append (booster [offset : offset + _MODEL_CHUNK_SIZE ])
1702+
1703+ _get_spark_session ().sparkContext .parallelize (booster_chunks , 1 ).saveAsTextFile (
16941704 model_save_path
16951705 )
16961706
@@ -1721,8 +1731,8 @@ def load(self, path: str) -> "_SparkXGBModel":
17211731 )
17221732 model_load_path = os .path .join (path , "model" )
17231733
1724- ser_xgb_model = (
1725- _get_spark_session ().sparkContext .textFile (model_load_path ).collect ()[ 0 ]
1734+ ser_xgb_model = "" . join (
1735+ _get_spark_session ().sparkContext .textFile (model_load_path ).collect ()
17261736 )
17271737
17281738 def create_xgb_model () -> "XGBModel" :
0 commit comments