/*
 * Decompiled with CFR 0.152.
 */
package sparklyr;

import java.util.concurrent.ConcurrentHashMap;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Row;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Range;
import scala.math.Ordering$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ScalaRunTime$;
import scala.util.Random;
import sparklyr.BoundedPriorityQueue;
import sparklyr.SamplingUtils;
import sparklyr.SamplingUtils$;
import sparklyr.SamplingUtils$$anonfun$3$;
import sparklyr.Utils$;

/**
 * Decompiled module class of the Scala object {@code sparklyr.SamplingUtils}.
 *
 * <p>Provides weighted random sampling of {@code RDD<Row>} data, both without replacement
 * ({@link #sampleWithoutReplacement}) and with replacement ({@link #sampleWithReplacement}).
 * Every row with a positive weight is assigned a random priority via
 * {@link #genSamplePriority}, and the samples that compare greatest are kept.
 * NOTE(review): this assumes {@code SamplingUtils.Sample} orders by priority and that
 * {@code BoundedPriorityQueue} retains its largest elements; both are defined elsewhere,
 * so confirm.
 */
public final class SamplingUtils$ {
    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static final SamplingUtils$ MODULE$;

    // Creates the singleton eagerly; the constructor stores itself into MODULE$.
    static {
        new SamplingUtils$();
    }

    /**
     * Draws a weighted sample of {@code k} rows without replacement.
     *
     * <p>Each row with a weight {@code > 0} is given a random priority and offered to a
     * {@code BoundedPriorityQueue} of capacity {@code k}. The per-partition queues are merged
     * by {@code rdd.aggregate}. The surviving rows are then held on the driver and turned back
     * into an RDD with {@code sc.parallelize}. Rows whose weight is {@code <= 0} are never sampled.
     *
     * @param rdd          the rows to sample from
     * @param weightColumn name of the column holding each row's weight; {@code null} or empty
     *                     means every row has weight 1.0
     * @param k            number of rows to sample; 0 yields an empty RDD
     * @param seed         base seed; the current partition id is added to it for each partition
     * @return an RDD of the sampled rows (presumably at most {@code k}, assuming the queue's
     *         capacity bounds its size; TODO confirm)
     */
    public RDD<Row> sampleWithoutReplacement(RDD<Row> rdd, String weightColumn, int k, long seed) {
        RDD rDD;
        SparkContext sc = rdd.context();
        if (0 == k) {
            // Nothing to sample: return an empty RDD without touching the input.
            rDD = sc.emptyRDD(ClassTag$.MODULE$.apply(Row.class));
        } else {
            // Cache of PRNGs keyed by per-partition seed. It is captured by the seqOp closure
            // below. NOTE(review): presumably each task works on its own deserialized copy of
            // this map; confirm.
            ConcurrentHashMap prngState = new ConcurrentHashMap();
            // Zero value: an empty queue of capacity k using Sample's natural ordering
            // (presumably Sample is Ordered; confirm). seqOp adds rows; combOp merges queues.
            BoundedPriorityQueue samples = (BoundedPriorityQueue)rdd.aggregate(new BoundedPriorityQueue(k, Ordering$.MODULE$.ordered((Function1)Predef$.MODULE$.$conforms())), (Function2)new Serializable(weightColumn, seed, prngState){
                public static final long serialVersionUID = 0L;
                private final String weightColumn$1;
                private final long seed$1;
                private final ConcurrentHashMap prngState$1;

                // seqOp: offer one row to the partition's queue and return the same queue.
                public final BoundedPriorityQueue<SamplingUtils.Sample> apply(BoundedPriorityQueue<SamplingUtils.Sample> pq, Row row) {
                    // Assigned on both branches but never read (decompiler artifact).
                    BoxedUnit boxedUnit;
                    double weight = SamplingUtils$.MODULE$.extractWeightValue(row, this.weightColumn$1);
                    // Rows with non-positive weight are skipped entirely.
                    if (weight > 0.0) {
                        // Seed differs per partition so partitions draw different streams.
                        long sampleSeed = this.seed$1 + (long)TaskContext$.MODULE$.getPartitionId();
                        // Reuse the Random for this seed, creating it on first use.
                        // NOTE(review): SamplingUtils.PRNG is defined elsewhere; presumably a
                        // mapping function that builds a Random from the seed; confirm.
                        Random random = this.prngState$1.computeIfAbsent(BoxesRunTime.boxToLong((long)sampleSeed), new SamplingUtils.PRNG());
                        SamplingUtils.Sample sample = new SamplingUtils.Sample(SamplingUtils$.MODULE$.genSamplePriority(weight, random), row);
                        boxedUnit = pq.$plus$eq((Object)sample);
                    } else {
                        boxedUnit = BoxedUnit.UNIT;
                    }
                    return pq;
                }
                {
                    this.weightColumn$1 = weightColumn$1;
                    this.seed$1 = seed$1;
                    this.prngState$1 = prngState$1;
                }
            }, (Function2)new Serializable(){
                public static final long serialVersionUID = 0L;

                // combOp: merge pq2 into pq1 in place and return pq1.
                public final BoundedPriorityQueue<SamplingUtils.Sample> apply(BoundedPriorityQueue<SamplingUtils.Sample> pq1, BoundedPriorityQueue<SamplingUtils.Sample> pq2) {
                    pq1.$plus$plus$eq((TraversableOnce)pq2);
                    return pq1;
                }
            }, ClassTag$.MODULE$.apply(BoundedPriorityQueue.class));
            // Strip the priorities and re-distribute the sampled rows as a new RDD.
            rDD = sc.parallelize((Seq)samples.toSeq().map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                // Extract the row carried by a sample.
                public final Row apply(SamplingUtils.Sample x) {
                    return x.row();
                }
            }, Seq$.MODULE$.canBuildFrom()), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(Row.class));
        }
        return rDD;
    }

    /**
     * Draws a weighted sample of {@code k} rows with replacement.
     *
     * <p>Each partition keeps an array of {@code k} independent sample slots. For every row
     * with a weight {@code > 0}, a fresh random priority is drawn for each slot. The row
     * replaces the slot's current sample when that sample compares less than the new one.
     * The per-partition arrays are then reduced slot-by-slot, keeping the greater sample in
     * each slot. The resulting rows are turned back into an RDD with {@code sc.parallelize}.
     *
     * <p>NOTE(review): slots start with priority negative infinity and a {@code null} row.
     * If no input row has a positive weight, the returned RDD contains {@code k} null rows.
     * Confirm whether this is intended.
     *
     * @param rdd          the rows to sample from
     * @param weightColumn name of the column holding each row's weight; {@code null} or empty
     *                     means every row has weight 1.0
     * @param k            number of rows to sample; 0 yields an empty RDD
     * @param seed         base seed; the partition index is added to it for each partition
     * @return an RDD of {@code k} rows (rows may repeat), or an empty RDD when {@code k} is 0
     *         or {@code rdd} has no partitions
     */
    public RDD<Row> sampleWithReplacement(RDD<Row> rdd, String weightColumn, int k, long seed) {
        RDD mapRDDs;
        SparkContext sc = rdd.context();
        // The expression below works in four steps:
        //   1. k == 0 -> empty RDD.
        //   2. Otherwise build one k-slot Sample array per partition (mapRDDs).
        //   3. If mapRDDs has no partitions -> empty RDD. This presumably guards reduce,
        //      which fails on an empty RDD.
        //   4. Otherwise reduce the arrays slot-wise, map the slots to rows, and parallelize.
        return 0 == k ? sc.emptyRDD(ClassTag$.MODULE$.apply(Row.class)) : (0 == (mapRDDs = rdd.mapPartitionsWithIndex((Function2)new Serializable(weightColumn, k, seed){
            public static final long serialVersionUID = 0L;
            public final String weightColumn$2;
            public final int k$1;
            private final long seed$2;

            // Per partition: fill k slots from this partition's rows and emit the slot array.
            public final Iterator<SamplingUtils.Sample[]> apply(int index, Iterator<Row> iter) {
                // One PRNG per partition, seeded with the base seed plus the partition index.
                Random random = new Random(this.seed$2 + (long)index);
                // k empty slots: lowest possible priority, no row.
                SamplingUtils.Sample[] samples = (SamplingUtils.Sample[])Array$.MODULE$.fill(this.k$1, (Function0)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final SamplingUtils.Sample apply() {
                        return new SamplingUtils.Sample(Double.NEGATIVE_INFINITY, null);
                    }
                }, ClassTag$.MODULE$.apply(SamplingUtils.Sample.class));
                iter.foreach((Function1)new Serializable(this, random, samples){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ anonfun.3 $outer;
                    public final Random random$1;
                    public final SamplingUtils.Sample[] samples$1;

                    // Offer one row to every slot.
                    public final void apply(Row row) {
                        DoubleRef weight = DoubleRef.create((double)SamplingUtils$.MODULE$.extractWeightValue(row, this.$outer.weightColumn$2));
                        // Rows with non-positive weight are skipped entirely.
                        if (weight.elem > 0.0) {
                            Serializable serializable = new Serializable(this, weight, row){
                                public static final long serialVersionUID = 0L;
                                private final /* synthetic */ anonfun$3$$anonfun$apply$3 $outer;
                                private final DoubleRef weight$1;
                                private final Row row$1;

                                public final void apply(int idx) {
                                    this.apply$mcVI$sp(idx);
                                }

                                // Draw an independent priority for this row in slot idx. The
                                // row wins the slot if the slot's current sample is less.
                                public void apply$mcVI$sp(int idx) {
                                    SamplingUtils.Sample replacement = new SamplingUtils.Sample(SamplingUtils$.MODULE$.genSamplePriority(this.weight$1.elem, this.$outer.random$1), this.row$1);
                                    if (this.$outer.samples$1[idx].$less(replacement)) {
                                        this.$outer.samples$1[idx] = replacement;
                                    }
                                }
                                {
                                    if ($outer == null) {
                                        throw null;
                                    }
                                    this.$outer = $outer;
                                    this.weight$1 = weight$1;
                                    this.row$1 = row$1;
                                }
                            };
                            // Iterate idx over Range(0, k), applying the slot update to each.
                            Range range = scala.package$.MODULE$.Range().apply(0, this.$outer.k$1);
                            if (!range.isEmpty()) {
                                int n = range.start();
                                while (true) {
                                    serializable.apply$mcVI$sp(n);
                                    if (n == range.lastElement()) break;
                                    n += range.step();
                                }
                            }
                        }
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                        this.random$1 = random$1;
                        this.samples$1 = samples$1;
                    }
                });
                // Each partition contributes exactly one element: its slot array.
                return scala.package$.MODULE$.Iterator().single((Object)samples);
            }
            {
                this.weightColumn$2 = weightColumn$2;
                this.k$1 = k$1;
                this.seed$2 = seed$2;
            }
        }, rdd.mapPartitionsWithIndex$default$2(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(SamplingUtils.Sample.class)))).partitions().length ? sc.emptyRDD(ClassTag$.MODULE$.apply(Row.class)) : sc.parallelize((Seq)Predef$.MODULE$.refArrayOps((Object[])mapRDDs.reduce((Function2)new Serializable(k){
            public static final long serialVersionUID = 0L;
            private final int k$1;

            // Merge two slot arrays: s1 keeps the greater sample in each slot. s1 is mutated
            // in place and returned.
            public final SamplingUtils.Sample[] apply(SamplingUtils.Sample[] s1, SamplingUtils.Sample[] s2) {
                Serializable serializable = new Serializable(this, s1, s2){
                    public static final long serialVersionUID = 0L;
                    private final SamplingUtils.Sample[] s1$1;
                    private final SamplingUtils.Sample[] s2$1;

                    public final void apply(int idx) {
                        this.apply$mcVI$sp(idx);
                    }

                    public void apply$mcVI$sp(int idx) {
                        if (this.s1$1[idx].$less(this.s2$1[idx])) {
                            this.s1$1[idx] = this.s2$1[idx];
                        }
                    }
                    {
                        this.s1$1 = s1$1;
                        this.s2$1 = s2$1;
                    }
                };
                // Iterate idx over Range(0, k), merging each slot.
                Range range = scala.package$.MODULE$.Range().apply(0, this.k$1);
                if (!range.isEmpty()) {
                    int n = range.start();
                    while (true) {
                        serializable.apply$mcVI$sp(n);
                        if (n == range.lastElement()) break;
                        n += range.step();
                    }
                }
                return s1;
            }
            {
                this.k$1 = k$1;
            }
        })).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            // Extract the row carried by each final slot.
            public final Row apply(SamplingUtils.Sample x) {
                return x.row();
            }
        }, Array$.MODULE$.fallbackCanBuildFrom(Predef.DummyImplicit$.MODULE$.dummyImplicit())), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(Row.class)));
    }

    /**
     * Computes a random sampling priority for a row with the given weight.
     *
     * <p>Returns {@code log(u) / weight} with {@code u = random.nextDouble()}, i.e. the log of
     * {@code u^(1/weight)}. This appears to be the log-space form of the Efraimidis-Spirakis
     * weighted-reservoir key. Because {@code log(u)} is negative for {@code u < 1}, a larger
     * weight yields a priority closer to 0 (higher). The result is non-positive, and is
     * negative infinity if {@code nextDouble()} returns exactly 0.0.
     *
     * @param weight the row's weight; callers in this class only pass values {@code > 0}
     * @param random the PRNG to draw {@code u} from
     * @return the sampling priority
     */
    public double genSamplePriority(double weight, Random random) {
        return package$.MODULE$.log(random.nextDouble()) / weight;
    }

    /**
     * Reads the sampling weight of a row.
     *
     * @param row          the row to read from
     * @param weightColumn name of the weight column; when {@code null} or empty, every row is
     *                     treated as having weight 1.0
     * @return 1.0 when no weight column is given. Otherwise, the column's value converted by
     *         {@code Utils.asDouble}. NOTE(review): presumably {@code row.fieldIndex} throws
     *         if the column does not exist; confirm against Spark's Row API.
     */
    public double extractWeightValue(Row row, String weightColumn) {
        return weightColumn == null || weightColumn.isEmpty() ? 1.0 : Utils$.MODULE$.asDouble(row.get(row.fieldIndex(weightColumn)));
    }

    // Private constructor; publishes this instance as the singleton MODULE$.
    // It is invoked once, from the static initializer.
    private SamplingUtils$() {
        MODULE$ = this;
    }
}

