package com.odianyun.horse.spark.salesprediction;

import com.odianyun.horse.api.common.DateUtil;
import com.odianyun.horse.spark.common.EnvUtil$;
import com.odianyun.horse.spark.common.SparkSessionBuilder$;
import com.odianyun.horse.spark.ds.DataSetRequest;
import com.odianyun.horse.spark.ml.algorithm.xgboost.XGBoostSpark$;
import com.odianyun.horse.store.hdfs.HDFSUtil;
import java.util.Date;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ListBuffer;
import scala.collection.mutable.ListBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;

/* compiled from: XGBoostSalesPrediction.scala */
/* loaded from: input_file:com/odianyun/horse/spark/salesprediction/XGBoostSalesPrediction$.class */
/**
 * Singleton holder (decompiled from a Scala {@code object}) that trains an XGBoost
 * regression model on historical sales rows and uses the most recent stored model to
 * score prediction-plan rows.
 *
 * <p>The SQL templates and the model directory template contain {@code #...#}
 * placeholders that are filled in with literal (non-regex) substitution.
 */
public final class XGBoostSalesPrediction$ {
    /**
     * The single instance. Initialised directly here; assigning a {@code static final}
     * field from the constructor (as the decompiler emitted) is not valid Java.
     */
    public static final XGBoostSalesPrediction$ MODULE$ = new XGBoostSalesPrediction$();

    /** SQL template for prediction input; placeholders: {@code #env#}, {@code #dt#}. */
    private final String predictDataSql;
    /** SQL template for training input; placeholders: {@code #env#}, {@code #start_dt#}, {@code #end_dt#}. */
    private final String trainDataSql;
    /** Directory template under which models are stored; placeholder: {@code #env#}. */
    private final String modelHadoopPath;

    /** @return the SQL template used to select the rows to score */
    public String predictDataSql() {
        return this.predictDataSql;
    }

    /** @return the SQL template used to select the training rows */
    public String trainDataSql() {
        return this.trainDataSql;
    }

    /** @return the model directory template (still containing {@code #env#}) */
    public String modelHadoopPath() {
        return this.modelHadoopPath;
    }

    /**
     * Scores the prediction-plan rows of {@code dataSetRequest.getStartDate()} with the
     * model found under the environment's model directory.
     *
     * @param dataSetRequest supplies the environment ({@code env()}) and the partition
     *                       date ({@code getStartDate()})
     * @return the mapped prediction rows, or an empty RDD when the model directory does
     *         not exist
     */
    public RDD<Row> predict(DataSetRequest dataSetRequest) {
        SparkSession spark = SparkSessionBuilder$.MODULE$.build(getClass().getSimpleName());
        String env = dataSetRequest.env();

        // Literal replace: replaceAll would interpret '$' and '\' in the substituted values.
        String query = predictDataSql()
                .replace("#dt#", dataSetRequest.getStartDate())
                .replace("#env#", env);
        RDD<Tuple2<String, double[]>> features = spark.sql(query).rdd()
                .map(new XGBoostSalesPrediction$$anonfun$1(), ClassTag$.MODULE$.apply(Tuple2.class));

        String defaultFs = spark.sparkContext().hadoopConfiguration().get("fs.defaultFS");
        String modelDir = modelDirFor(env);
        if (!Predef$.MODULE$.Boolean2boolean(HDFSUtil.isExists(modelDir))) {
            // No model directory for this environment: nothing to predict with.
            return spark.sparkContext().makeRDD(Nil$.MODULE$, ClassTag$.MODULE$.apply(Row.class));
        }

        // NOTE(review): getLastestChildrenPathName presumably returns the name of the newest
        // child of modelDir; a null result would yield a path ending in "null" - confirm.
        String modelPath = modelDir + new HDFSUtil(defaultFs).getLastestChildrenPathName(modelDir);
        return XGBoostSpark$.MODULE$.regressionPredict(features, 0.0d, modelPath)
                .map(new XGBoostSalesPrediction$$anonfun$2(), ClassTag$.MODULE$.apply(Row.class));
    }

    /**
     * Trains a regression model on the rows between the request's start and end dates
     * (inclusive, per the SQL template) and saves it under the environment's model
     * directory. Does nothing when the query yields two rows or fewer.
     *
     * @param dataSetRequest supplies the environment and the start/end dates
     */
    public void train(DataSetRequest dataSetRequest) {
        SparkSession spark = SparkSessionBuilder$.MODULE$.build(getClass().getSimpleName());
        String startDate = dataSetRequest.getStartDate();
        String endDate = dataSetRequest.getEndDate();

        // Literal replace: replaceAll would interpret '$' and '\' in the substituted values.
        String query = trainDataSql()
                .replace("#env#", dataSetRequest.env())
                .replace("#start_dt#", startDate)
                .replace("#end_dt#", endDate);
        Dataset<Row> trainingRows = spark.sql(query);

        // Training is skipped unless there are more than two rows.
        if (trainingRows.rdd().count() > 2) {
            RDD<LabeledPoint> samples = trainingRows.rdd()
                    .map(new XGBoostSalesPrediction$$anonfun$3(), ClassTag$.MODULE$.apply(LabeledPoint.class));
            String defaultFs = spark.sparkContext().hadoopConfiguration().get("fs.defaultFS");
            String modelDir = modelDirFor(dataSetRequest.env());

            // Model path: <modelDir><formatted current time>###<startDate>_<endDate>
            String modelPath = modelDir
                    + DateUtil.format(new Date(), DateUtil.YMD_HMS_FORMAT_underline)
                    + "###" + startDate + "_" + endDate;
            XGBoostSpark$.MODULE$.regressionTrainAndSaveHadoopPath(samples, modelPath);

            // NOTE(review): presumably prunes old models so that at most 20 children remain
            // under modelDir - confirm against HDFSUtil.deleteChildrenPath.
            new HDFSUtil(defaultFs).deleteChildrenPath(modelDir, 20);
        }
    }

    /**
     * Collects the elements of {@code seq} into a {@code double[]} through a buffer that is
     * populated by the mapping closure (name-mangled because the closure classes call it).
     *
     * @param seq the values to convert
     * @return the contents of the buffer as a primitive double array
     */
    public double[] com$odianyun$horse$spark$salesprediction$XGBoostSalesPrediction$$toArrayDouble(Seq<Object> seq) {
        ListBuffer buffer = ListBuffer$.MODULE$.apply(Nil$.MODULE$);
        // The closure receives the buffer; the mapped Seq itself is discarded.
        // NOTE(review): presumably the closure appends each element to the buffer - confirm.
        seq.map(new XGBoostSalesPrediction$$anonfun$com$odianyun$horse$spark$salesprediction$XGBoostSalesPrediction$$toArrayDouble$1(buffer), Seq$.MODULE$.canBuildFrom());
        return (double[]) buffer.toArray(ClassTag$.MODULE$.Double());
    }

    /**
     * Instance-level entry point: converts the command-line arguments into a
     * {@link DataSetRequest} and runs {@link #train(DataSetRequest)}.
     *
     * @param strArr raw command-line arguments
     */
    public void main(String[] strArr) {
        train(EnvUtil$.MODULE$.convert(strArr));
    }

    /** Fills the {@code #env#} placeholder of the model directory template. */
    private String modelDirFor(String env) {
        return modelHadoopPath().replace("#env#", env);
    }

    private XGBoostSalesPrediction$() {
        this.predictDataSql = new StringOps(Predef$.MODULE$.augmentString(
                "\n"
                + "        |select company_id,\n"
                + "        |mp_id,\n"
                + "        |predict_dt,\n"
                + "        |price,is_promotion,is_regular_product,positive_count,comments,color,size,price_elastic,is_new_product,\n"
                + "        |promotion_level,festival,weekday,month,season,weather\n"
                + "        |from dwd.dwd_mp_merchant_product_prediction_plan_inc\n"
                + "        |where env='#env#' and dt='#dt#' and mode_code='xgboost'\n"
                + "      ")).stripMargin();
        this.trainDataSql = new StringOps(Predef$.MODULE$.augmentString(
                "\n"
                + "      |select\n"
                + "      |sales_num,\n"
                + "      |price,is_promotion,is_regular_product,positive_count,comments,color,size,price_elastic,is_new_product,\n"
                + "      |promotion_level,festival,weekday,month,season,weather\n"
                + "      |from ads.ads_merchant_product_training_daily\n"
                + "      |where env='#env#' and dt>='#start_dt#' and dt<= '#end_dt#'\n"
                + "    ")).stripMargin();
        this.modelHadoopPath = "/user/mapred/env/#env#/ml/saleprediction/model/xgboost/";
    }
}
