Goal: train a cross-validated XGBoostClassificationModel through the Spark ML API, in order to leverage Spark parallelism and its convenient hyper-parameter tuning API, while stopping early when the eval-set metric has not improved for n rounds.
After fitting a plain XGBoostClassifier, the test error can be found in the summary (despite the warning about train_test_ratio!).
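For context, the classifier in the transcript below was configured roughly like this. This is only a sketch of the wiring (parameter values, and the `trainset`/`testset` DataFrame names, are illustrative), assuming an xgboost4j-spark 0.9x API where `setEvalSets` and `setNumEarlyStoppingRounds` are available:

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Sketch: XGBoostClassifier with an explicit eval set and early stopping.
// `trainset` and `testset` are assumed to be pre-built DataFrames.
val classifier = new XGBoostClassifier(Map(
  "objective"   -> "binary:logistic",
  "eval_metric" -> "auc",
  "num_workers" -> 1
))
  .setNumRound(200)                       // upper bound on boosting rounds
  .setNumEarlyStoppingRounds(5)           // stop if the eval metric stalls
  .setMaximizeEvaluationMetrics(true)     // AUC: higher is better
  .setEvalSets(Map("testset" -> testset)) // held-out set, not sampled from trainset
```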
scala> val m = classifier.fit(trainset)
19/09/24 03:20:27 WARN XGBoostSpark: train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly pass a training and multiple evaluation datasets by passing 'eval_sets' and 'eval_set_names'
Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.0.14.213, DMLC_TRACKER_PORT=9091, DMLC_NUM_WORKER=1}
19/09/24 03:20:27 WARN XGBoostSpark: train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly pass a training and multiple evaluation datasets by passing 'eval_sets' and 'eval_set_names'
19/09/24 03:20:27 WARN XGBoostSpark: Unable to read total number of alive cores from REST API.Health Check will be ignored.
java.net.ConnectException: Connection refused (Connection refused)
at java.net.PlainSocketImpl.socketConnect(Native Method)
at java.net.AbstractPlainSocketImpl.doConnect(AbstractPlainSocketImpl.java:350)
at java.net.AbstractPlainSocketImpl.connectToAddress(AbstractPlainSocketImpl.java:206)
at java.net.AbstractPlainSocketImpl.connect(AbstractPlainSocketImpl.java:188)
at java.net.SocksSocketImpl.connect(SocksSocketImpl.java:392)
at java.net.Socket.connect(Socket.java:589)
at java.net.Socket.connect(Socket.java:538)
at sun.net.NetworkClient.doConnect(NetworkClient.java:180)
at sun.net.www.http.HttpClient.openServer(HttpClient.java:463)
at sun.net.www.http.HttpClient.openServer(HttpClient.java:558)
at sun.net.www.http.HttpClient.<init>(HttpClient.java:242)
at sun.net.www.http.HttpClient.New(HttpClient.java:339)
at sun.net.www.http.HttpClient.New(HttpClient.java:357)
at sun.net.www.protocol.http.HttpURLConnection.getNewHttpClient(HttpURLConnection.java:1226)
at sun.net.www.protocol.http.HttpURLConnection.plainConnect0(HttpURLConnection.java:1162)
at sun.net.www.protocol.http.HttpURLConnection.plainConnect(HttpURLConnection.java:1056)
at sun.net.www.protocol.http.HttpURLConnection.connect(HttpURLConnection.java:990)
at sun.net.www.protocol.http.HttpURLConnection.getInputStream0(HttpURLConnection.java:1570)
at sun.net.www.protocol.http.HttpURLConnection.getInputStream(HttpURLConnection.java:1498)
at java.net.URL.openStream(URL.java:1057)
at org.codehaus.jackson.JsonFactory._optimizedStreamFromURL(JsonFactory.java:935)
at org.codehaus.jackson.JsonFactory.createJsonParser(JsonFactory.java:530)
at org.codehaus.jackson.map.ObjectMapper.readTree(ObjectMapper.java:1590)
at org.apache.spark.SparkParallelismTracker.org$apache$spark$SparkParallelismTracker$$numAliveCores(SparkParallelismTracker.scala:54)
at org.apache.spark.SparkParallelismTracker$$anonfun$execute$1.apply$mcZ$sp(SparkParallelismTracker.scala:103)
at org.apache.spark.SparkParallelismTracker$$anonfun$1.apply$mcV$sp(SparkParallelismTracker.scala:72)
at org.apache.spark.SparkParallelismTracker$$anonfun$1.apply(SparkParallelismTracker.scala:72)
at org.apache.spark.SparkParallelismTracker$$anonfun$1.apply(SparkParallelismTracker.scala:72)
at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
at scala.concurrent.impl.ExecutionContextImpl$AdaptedForkJoinTask.exec(ExecutionContextImpl.scala:121)
at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260)
at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339)
at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)
at scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)
scala> m.summary.validationObjectiveHistory
res58: Seq[(String, Array[Float])] = ArrayBuffer((testset,Array(0.861355, 0.874199, 0.877361, 0.879096, 0.87939, 0.878349, 0.879181, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...
scala> m.summary.trainObjectiveHistory
res59: Array[Float] = Array(0.873627, 0.899988, 0.911056, 0.919834, 0.927977, 0.936445, 0.939293, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...
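A side note on reading these arrays: the history appears to be pre-allocated to the full number of rounds and left at 0.0 past the round where training stopped, so the number of rounds actually completed can be recovered by trimming the trailing zeros. A small helper sketch (it assumes the metric itself is never exactly 0.0, which holds for the AUC-like values above):

```scala
// Rounds actually trained before early stopping: length of the
// leading non-zero prefix of the objective history.
def completedRounds(history: Array[Float]): Int =
  history.takeWhile(_ > 0.0f).length

// The meaningful part of the history, with the 0.0 padding dropped.
def trimmedHistory(history: Array[Float]): Array[Float] =
  history.take(completedRounds(history))
```

On the train history above, `completedRounds` would return 7, matching the 7 non-zero entries before the padding.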
After fitting the CrossValidator and extracting the best model, however, the test error is missing from the summary, and the model did not stop early as it does in the simpler case.
scala> val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(0).asInstanceOf[XGBoostClassificationModel]
scala> bestModel.summary.validationObjectiveHistory
res63: Seq[(String, Array[Float])] = ArrayBuffer()
scala> bestModel.summary.trainObjectiveHistory
res64: Array[Float] = Array(0.860007, 0.867117, 0.868701, 0.87272, 0.872711, 0.873824, 0.875132, 0.876343, 0.877166, 0.877516, 0.877863, 0.878515, 0.878807, 0.879681, 0.880481, 0.881074, 0.881616, 0.881843, 0.88226, 0.882432, 0.882846, 0.883575, 0.883963, 0.884313, 0.884672, 0.885035, 0.885229, 0.885641, 0.886056, 0.886481, 0.886975, 0.887251, 0.887573, 0.887877, 0.888422, 0.88864, 0.888901, 0.889309, 0.889529, 0.890081, 0.890419, 0.890664, 0.890913, 0.891132, 0.891503, 0.89168, 0.891954, 0.892234, 0.892456, 0.892693, 0.893012, 0.893268, 0.893707, 0.894003, 0.894187, 0.89441, 0.89481, 0.895115, 0.895422, 0.895757, 0.896122, 0.896326, 0.896599, 0.896974, 0.897242, 0.897566, 0.898031, 0.898379, 0.898652, 0.899014, 0.899408, 0.899801, 0.900234, 0.900636, 0.901094, 0.901349, 0.901828, 0.902...
scala> bestModel.summary.trainObjectiveHistory.filter(_>0.0).length
res65: Int = 200
My 2-cents guess: eval_sets is not a native Spark ML parameter, so it can be neither serialized nor copied when the estimator is wrapped in a CrossValidator(). The conclusion seems to be that the Spark ML tuning API and early stopping / eval_sets cannot be used simultaneously.
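For reference, the setup that exhibits the problem looks roughly like this (a sketch; the stage list, evaluator, and param grid are illustrative, and `classifier`/`trainset` are assumed to be defined as above):

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Same classifier as above, with eval_sets and early stopping configured.
val pipeline = new Pipeline().setStages(Array(classifier))

val paramGrid = new ParamGridBuilder()
  .addGrid(classifier.maxDepth, Array(4, 6))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// CrossValidator.fit copies the estimator for each fold/param combination;
// the guess above is that the eval_sets DataFrames do not survive that copy,
// so the best model trains for the full 200 rounds with no early stopping.
val cvModel = cv.fit(trainset)
```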
My 1-cent suggestion for a workaround: specify the path of the test set rather than the DataFrame itself, so that all parameters can be serialized in the standard way.
If there is already a good way to circumvent this issue, please let me know.
Thanks for your answer. Your solution lets me retrieve my validation error and actually enforce early stopping while training a PipelineModel, which was the goal of this issue. To go a bit further, it would still be nice to:
be able to specify a particular validation set that is not sampled from the training set, as is possible when training an XGBoostClassifier alone. In some cases like mine, I cannot guarantee strict independence between the instances of my training set (because of the expensive sampling method), but I do have another fold that is independent by construction; it would be nice to be able to use it.
emit a warning or an error when a broken API is used. The current API is error-prone with MLlib Pipelines: for instance, setEvalSets() is exposed, yet in practice it does not work inside a pipeline and fails silently, without any error or even a warning.