
The goal is to train a cross-validated XGBoostClassificationModel with the Spark ML API, leveraging Spark parallelism and its convenient hyper-parameter tuning API, while stopping early if the eval-set metric does not improve after n iterations.

Issue

Early stopping and eval_sets do not work when using a CrossValidator, whereas they work fine when fitting a simple XGBoostClassifier.

Config:

  • Xgboost4j-0.90
  • Xgboost4j-spark-0.90
  • Spark-2.4.3

Details

    Training an XGBoostClassifier with eval_sets works fine:

    val xgbParam = Map(
      "seed" -> 42L,
      "objective" -> "binary:logistic",
      "eta" -> 1.0,
      "missing" -> -1.0,
      "num_round" -> 200,
      "max_depth" -> 12,
      "gamma" -> 0,
      "alpha" -> 0.9,
      "lambda" -> 1.0,
      "subsample" -> 1.0,
      "colsample_bytree" -> 0.8,
      "eval_sets" -> Map("testset" -> testset),
      "nthread" -> nthread,
      "num_workers" -> 1
    )
    val classifier: XGBoostClassifier = new XGBoostClassifier(xgbParam).
      setProbabilityCol("probabilities").
      setFeaturesCol("features"). // hard-coded because it is a field of LabeledPairOfItemWithFeatures (dataset)
      setLabelCol("label").
      setNumEarlyStoppingRounds(2).
      setMaximizeEvaluationMetrics(true).
      setEvalMetric("auc")

    After fitting, the test error can be found in the summary (despite the warning about train_test_ratio!).

    scala>  val m = classifier.fit(trainset)
    19/09/24 03:20:27 WARN XGBoostSpark: train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly pass a training and multiple evaluation datasets by passing 'eval_sets' and 'eval_set_names'
    Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.0.14.213, DMLC_TRACKER_PORT=9091, DMLC_NUM_WORKER=1}
    19/09/24 03:20:27 WARN XGBoostSpark: train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly pass a training and multiple evaluation datasets by passing 'eval_sets' and 'eval_set_names'
    19/09/24 03:20:27 WARN XGBoostSpark: Unable to read total number of alive cores from REST API.Health Check will be ignored.
    java.net.ConnectException: Connection refused (Connection refused)
            at java.net.PlainSocketImpl.socketConnect(Native Method)
            at java.net.AbstractPlainSocketImpl.doConnect(AbstractPlainSocketImpl.java:350)
            at java.net.AbstractPlainSocketImpl.connectToAddress(AbstractPlainSocketImpl.java:206)
            at java.net.AbstractPlainSocketImpl.connect(AbstractPlainSocketImpl.java:188)
            at java.net.SocksSocketImpl.connect(SocksSocketImpl.java:392)
            at java.net.Socket.connect(Socket.java:589)
            at java.net.Socket.connect(Socket.java:538)
            at sun.net.NetworkClient.doConnect(NetworkClient.java:180)
            at sun.net.www.http.HttpClient.openServer(HttpClient.java:463)
            at sun.net.www.http.HttpClient.openServer(HttpClient.java:558)
            at sun.net.www.http.HttpClient.<init>(HttpClient.java:242)
            at sun.net.www.http.HttpClient.New(HttpClient.java:339)
            at sun.net.www.http.HttpClient.New(HttpClient.java:357)
            at sun.net.www.protocol.http.HttpURLConnection.getNewHttpClient(HttpURLConnection.java:1226)
            at sun.net.www.protocol.http.HttpURLConnection.plainConnect0(HttpURLConnection.java:1162)
            at sun.net.www.protocol.http.HttpURLConnection.plainConnect(HttpURLConnection.java:1056)
            at sun.net.www.protocol.http.HttpURLConnection.connect(HttpURLConnection.java:990)
            at sun.net.www.protocol.http.HttpURLConnection.getInputStream0(HttpURLConnection.java:1570)
            at sun.net.www.protocol.http.HttpURLConnection.getInputStream(HttpURLConnection.java:1498)
            at java.net.URL.openStream(URL.java:1057)
            at org.codehaus.jackson.JsonFactory._optimizedStreamFromURL(JsonFactory.java:935)
            at org.codehaus.jackson.JsonFactory.createJsonParser(JsonFactory.java:530)
            at org.codehaus.jackson.map.ObjectMapper.readTree(ObjectMapper.java:1590)
            at org.apache.spark.SparkParallelismTracker.org$apache$spark$SparkParallelismTracker$$numAliveCores(SparkParallelismTracker.scala:54)
            at org.apache.spark.SparkParallelismTracker$$anonfun$execute$1.apply$mcZ$sp(SparkParallelismTracker.scala:103)
            at org.apache.spark.SparkParallelismTracker$$anonfun$1.apply$mcV$sp(SparkParallelismTracker.scala:72)
            at org.apache.spark.SparkParallelismTracker$$anonfun$1.apply(SparkParallelismTracker.scala:72)
            at org.apache.spark.SparkParallelismTracker$$anonfun$1.apply(SparkParallelismTracker.scala:72)
            at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
            at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
            at scala.concurrent.impl.ExecutionContextImpl$AdaptedForkJoinTask.exec(ExecutionContextImpl.scala:121)
            at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260)
            at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339)
            at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)
            at scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)
    scala> m.summary.validationObjectiveHistory
    res58: Seq[(String, Array[Float])] = ArrayBuffer((testset,Array(0.861355, 0.874199, 0.877361, 0.879096, 0.87939, 0.878349, 0.879181, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...
    scala> m.summary.trainObjectiveHistory
    res59: Array[Float] = Array(0.873627, 0.899988, 0.911056, 0.919834, 0.927977, 0.936445, 0.939293, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...

    Note that training stopped after the 7th iteration, as expected.
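
    The stopping point can also be read off programmatically by counting the non-zero entries of the history on the model m fitted above, e.g.:

    m.summary.trainObjectiveHistory.filter(_ > 0.0).length // 7 non-zero rounds here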

    However, when using a CrossValidator:

    val evaluator = new BinaryClassificationEvaluator().
      setMetricName("areaUnderROC").
      setRawPredictionCol("probabilities").
      setLabelCol("label")
    val paramGrid = new ParamGridBuilder().
        addGrid(classifier.maxDepth, Array(10)).
        addGrid(classifier.eta, Array(0.01,0.03)).
        addGrid(classifier.lambda, Array(1.0)).
        build()
    val pipeline = new Pipeline().setStages(Array(classifier))
    val cv = new CrossValidator().
        setEstimator(pipeline).
        setEvaluator(evaluator).
        setEstimatorParamMaps(paramGrid).
        setNumFolds(5).
        setParallelism(2)
    // Fitting the model
    val cvModel = cv.fit(trainset)
    cvModel.avgMetrics

    After fitting and extracting the best model, the test error is missing from the summary, and the model did not stop early as it did in the simpler case.

    scala> val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(0).asInstanceOf[XGBoostClassificationModel]
    scala> bestModel.summary.validationObjectiveHistory
    res63: Seq[(String, Array[Float])] = ArrayBuffer()
    scala> bestModel.summary.trainObjectiveHistory
    res64: Array[Float] = Array(0.860007, 0.867117, 0.868701, 0.87272, 0.872711, 0.873824, 0.875132, 0.876343, 0.877166, 0.877516, 0.877863, 0.878515, 0.878807, 0.879681, 0.880481, 0.881074, 0.881616, 0.881843, 0.88226, 0.882432, 0.882846, 0.883575, 0.883963, 0.884313, 0.884672, 0.885035, 0.885229, 0.885641, 0.886056, 0.886481, 0.886975, 0.887251, 0.887573, 0.887877, 0.888422, 0.88864, 0.888901, 0.889309, 0.889529, 0.890081, 0.890419, 0.890664, 0.890913, 0.891132, 0.891503, 0.89168, 0.891954, 0.892234, 0.892456, 0.892693, 0.893012, 0.893268, 0.893707, 0.894003, 0.894187, 0.89441, 0.89481, 0.895115, 0.895422, 0.895757, 0.896122, 0.896326, 0.896599, 0.896974, 0.897242, 0.897566, 0.898031, 0.898379, 0.898652, 0.899014, 0.899408, 0.899801, 0.900234, 0.900636, 0.901094, 0.901349, 0.901828, 0.902...
    scala> bestModel.summary.trainObjectiveHistory.filter(_>0.0).length
    res65: Int = 200

    Guess

    My 2 cents guess: eval_sets is not a native Spark parameter and can neither be serialized nor copied when used inside a CrossValidator(). As a consequence, it seems impossible to use the Spark ML API and early stopping / eval_sets at the same time.
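
    To see where this would bite, here is a simplified view of what CrossValidator does for each fold and parameter combination (a sketch reusing the pipeline, paramGrid and trainset defined above):

    // CrossValidator never fits the original estimator directly; it fits copies.
    // Roughly: est.fit(foldTrainingData, paramMap) is est.copy(paramMap).fit(foldTrainingData),
    // so if the DataFrame-valued eval_sets does not survive copy(), as guessed above,
    // it is gone before training even starts.
    val foldModel = pipeline.copy(paramGrid(0)).fit(trainset)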

    My 1 cent suggestion for a workaround: specify the path of the test set rather than the DataFrame, so that all the parameters can be serialized in a standard way.
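
    For illustration only, a sketch of what that could look like (the path-based setter below is purely hypothetical and does not exist in xgboost4j-spark 0.90; the point is that a String path is an ordinary parameter that Spark ML can copy and serialize, while the DataFrames are rebuilt at fit() time, assuming a SparkSession named spark):

    // hypothetical user-facing side: reference the eval set by path, not by DataFrame
    // classifier.setEvalSetPaths(Map("testset" -> "hdfs:///path/to/testset.parquet"))

    // inside fit(), the estimator would then rebuild the eval DataFrames from the paths
    // (evalSetPaths is the hypothetical Map[String, String] parameter above)
    val evalSets = evalSetPaths.map { case (name, path) => name -> spark.read.parquet(path) }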

    If there is already a good way to circumvent this issue, please let me know.

    Thanks

    @borisclemencon you can use trainTestRatio when creating the XGBoostClassifier, and then in the grid search it will print the test metrics:
    val xgbClassifier = new XGBoostClassifier(paramMap)
      .setFeaturesCol("features_vector")
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setEvalSets(wctchVals)
      .setTrainTestRatio(0.7)
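
    A minimal sketch of the step in between, so that the bestModel used below is defined (assumptions: trainset from the earlier snippets is still in scope, and the param grid and evaluator are rebuilt against this xgbClassifier instance and its default probability column):

    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderROC")
      .setRawPredictionCol("probability") // xgbClassifier keeps the default column name here
      .setLabelCol("label")
    val paramGrid = new ParamGridBuilder()
      .addGrid(xgbClassifier.maxDepth, Array(10))
      .addGrid(xgbClassifier.eta, Array(0.01, 0.03))
      .build()
    val cv = new CrossValidator()
      .setEstimator(new Pipeline().setStages(Array(xgbClassifier)))
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(5)
    // per-round metrics on the trainTestRatio split end up in each model's summary
    val cvModel = cv.fit(trainset)
    val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel]
      .stages(0).asInstanceOf[XGBoostClassificationModel]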

    for (test <- bestModel.summary.validationObjectiveHistory) {
      println("val name: " + test._1)
      println(test._2.mkString(","))
      for (test_val <- test._2) {
        println(test_val)
      }
    }

    Hello Linlin,

    thanks for your answer. Your solution allows me to retrieve my validation error and to actually enforce early stopping while training a PipelineModel, which was the goal of this issue. To go a bit further, it would still be nice to:

    be able to specify a particular validation set that is not sampled from the training set, as is possible when training an XGBoostClassifier alone. In cases like mine, I cannot guarantee strict independence between the instances of my training set (because of the expensive sampling method), but I do have another fold that is independent by construction, and it would be nice to be able to use it.

    suppress the annoying warning WARN XGBoostSpark: train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly pass a training and multiple evaluation datasets by passing 'eval_sets' and 'eval_set_names' until we find a better solution (a possible way to mute it in the meantime is sketched after this list).

    send a warning or an error when broken APIs are used. The API is error-prone when used with MLlib Pipelines: for instance, setEvalSets() is exposed, yet it does not work in practice and fails silently, without raising an error or even a warning.
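
    Regarding the second point, one possible way to mute just that message in the meantime (a workaround sketch, assuming the log4j 1.x backend shipped with Spark 2.4.3 and that the logger name matches the XGBoostSpark tag shown in the warning):

    import org.apache.log4j.{Level, Logger}
    // raise the threshold of the xgboost4j-spark logger so its WARN messages are dropped
    Logger.getLogger("XGBoostSpark").setLevel(Level.ERROR)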

    Thanks again for the trick