/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.StreamUtil;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.clustering.kmeans.KMeansModel;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/**
 * Trains a {@link KMeansModel} by alternating two steps.
 * <ul>
 *   <li>An E step assigns every example to its nearest centroid under the configured
 *       {@link Distance}.</li>
 *   <li>An M step recomputes each centroid as the weighted mean of the examples assigned
 *       to it.</li>
 * </ul>
 * Training stops after {@code iterations} rounds, or earlier if an E step changes no
 * assignment.
 * <p>
 * When {@code numThreads > 1} both steps run on a dedicated {@link ForkJoinPool}, which is
 * shut down when training finishes. The trainer's RNG and invocation counter are guarded
 * by this object's monitor. Each call to {@code train} works on a private split of that RNG.
 */
public class KMeansTrainer
implements Trainer<ClusterID> {
    private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());
    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();

    @Config(mandatory=true, description="Number of centroids (i.e., the \"k\" in k-means).")
    private int centroids;
    @Config(mandatory=true, description="The number of iterations to run.")
    private int iterations;
    @Config(mandatory=true, description="The distance function to use.")
    private Distance distanceType;
    @Config(description="The centroid initialisation method to use.")
    private Initialisation initialisationType = Initialisation.RANDOM;
    @Config(description="The number of threads to use for training.")
    private int numThreads = 1;
    @Config(mandatory=true, description="The seed to use for the RNG.")
    private long seed;

    // Root RNG; every train call takes one split() of it. Guarded by "this".
    private SplittableRandom rng;
    // Number of train calls so far (i.e. splits taken from rng). Guarded by "this".
    private int trainInvocationCounter;

    /**
     * No-arg constructor, used when the fields are populated by configuration injection
     * followed by {@link #postConfig()}.
     */
    private KMeansTrainer() {
    }

    /**
     * Constructs a K-Means trainer that uses {@link Initialisation#RANDOM} centroid
     * initialisation.
     *
     * @param centroids    The number of centroids ("k").
     * @param iterations   The maximum number of E/M iterations.
     * @param distanceType The distance function.
     * @param numThreads   The number of threads; values above 1 enable parallel training.
     * @param seed         The RNG seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this(centroids, iterations, distanceType, Initialisation.RANDOM, numThreads, seed);
    }

    /**
     * Constructs a K-Means trainer.
     *
     * @param centroids          The number of centroids ("k").
     * @param iterations         The maximum number of E/M iterations.
     * @param distanceType       The distance function.
     * @param initialisationType The centroid initialisation method.
     * @param numThreads         The number of threads; values above 1 enable parallel training.
     * @param seed               The RNG seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initialisation initialisationType, int numThreads, long seed) {
        this.centroids = centroids;
        this.iterations = iterations;
        this.distanceType = distanceType;
        this.initialisationType = initialisationType;
        this.numThreads = numThreads;
        this.seed = seed;
        this.postConfig();
    }

    /**
     * (Re)creates the root RNG from {@code seed}.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains a model and increments the invocation count.
     *
     * @param examples      The training data.
     * @param runProvenance Extra provenance recorded in the model.
     * @return A trained K-Means model.
     */
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
        // -1 means "don't reset the RNG, just use the next split".
        return this.train(examples, runProvenance, -1);
    }

    /**
     * Trains a model.
     *
     * @param examples        The training data.
     * @param runProvenance   Extra provenance recorded in the model.
     * @param invocationCount If not -1, first resets the RNG as if {@code train} had already
     *                        been called this many times (see {@link #setInvocationCount(int)}).
     * @return A trained K-Means model.
     * @throws IllegalArgumentException If {@code invocationCount} is negative and not -1, or if
     *                                  PLUSPLUS initialisation asks for more centroids than
     *                                  there are examples.
     * @throws IllegalStateException    If an example has no centroid at a distance below
     *                                  +Infinity (e.g. every distance is NaN). In parallel mode
     *                                  this arrives wrapped in a RuntimeException.
     */
    public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        // Take the RNG split and snapshot provenance atomically, so concurrent calls
        // each get a distinct split and a matching invocation count.
        synchronized (this) {
            if (invocationCount != -1) {
                this.setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }

        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        int[] oldCentre = new int[examples.size()];
        SGDVector[] data = new SGDVector[examples.size()];
        double[] weights = new double[examples.size()];
        int n = 0;
        for (Example<ClusterID> example : examples) {
            weights[n] = example.getWeight();
            // Examples containing every feature are stored densely, the rest sparsely.
            if (example.size() == featureMap.size()) {
                data[n] = DenseVector.createDenseVector(example, featureMap, false);
            } else {
                data[n] = SparseVector.createSparseVector(example, featureMap, false);
            }
            oldCentre[n] = -1; // -1 = not yet assigned to any cluster
            ++n;
        }

        final DenseVector[] centroidVectors;
        switch (this.initialisationType) {
            case RANDOM:
                centroidVectors = initialiseRandomCentroids(this.centroids, featureMap, localRNG);
                break;
            case PLUSPLUS:
                centroidVectors = initialisePlusPlusCentroids(this.centroids, data, localRNG, this.distanceType);
                break;
            default:
                throw new IllegalStateException("Unknown initialisation " + this.initialisationType);
        }

        Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
        boolean parallel = this.numThreads > 1;
        for (int i = 0; i < this.centroids; ++i) {
            // The E step appends from several threads when running in parallel.
            clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
        }

        AtomicInteger changeCounter = new AtomicInteger(0);
        Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
            int clusterID = closestCentroid(centroidVectors, e.vector, this.distanceType);
            if (clusterID == -1) {
                throw new IllegalStateException("Could not assign example " + e.idx + " to any of the "
                        + centroidVectors.length + " centroids, no distance was below +Infinity (distance type "
                        + this.distanceType + ").");
            }
            clusterAssignments.get(clusterID).add(e.idx);
            // Each example index is handled by exactly one task, so this write is not contended.
            if (oldCentre[e.idx] != clusterID) {
                oldCentre[e.idx] = clusterID;
                changeCounter.incrementAndGet();
            }
        };

        boolean converged = false;
        ForkJoinPool fjp = null;
        try {
            if (parallel) {
                fjp = System.getSecurityManager() == null
                        ? new ForkJoinPool(this.numThreads)
                        : new ForkJoinPool(this.numThreads, THREAD_FACTORY, null, false);
            }
            for (int i = 0; i < this.iterations && !converged; ++i) {
                logger.log(Level.FINE, "Beginning iteration " + i);
                changeCounter.set(0);
                for (List<Integer> assigned : clusterAssignments.values()) {
                    assigned.clear();
                }

                // E step: assign every example to its nearest centroid.
                Stream<SGDVector> vecStream = Arrays.stream(data);
                Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
                Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
                if (fjp != null) {
                    Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
                    runInPool(fjp, () -> parallelZipStream.forEach(eStepFunc));
                } else {
                    zipStream.forEach(eStepFunc);
                }
                logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " examples updated.");

                // M step: move each centroid to the weighted mean of its members.
                this.mStep(fjp, centroidVectors, clusterAssignments, data, weights);
                logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated.");

                if (changeCounter.get() == 0) {
                    converged = true;
                    logger.log(Level.INFO, "K-Means converged at iteration " + i);
                }
            }
        } finally {
            if (fjp != null) {
                fjp.shutdown();
            }
        }

        Map<Integer, MutableLong> counts = new HashMap<>();
        for (Map.Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
            counts.put(e.getKey(), new MutableLong((long) e.getValue().size()));
        }
        ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);
        ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(),
                examples.getProvenance(), trainerProvenance, runProvenance);
        return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, this.distanceType);
    }

    /**
     * Trains a model with empty run provenance.
     *
     * @param dataset The training data.
     * @return A trained K-Means model.
     */
    public KMeansModel train(Dataset<ClusterID> dataset) {
        return this.train(dataset, Collections.emptyMap());
    }

    /**
     * Returns the number of times {@code train} has been invoked.
     * This method is synchronized because the counter is only written under the lock.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Resets the RNG to the state it would have after {@code invocationCount} calls to
     * {@code train}. It reseeds from {@code seed}, then discards that many splits.
     *
     * @param invocationCount The target invocation count.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (this.trainInvocationCounter = 0; this.trainInvocationCounter < invocationCount; ++this.trainInvocationCounter) {
            // Each train call consumes one split; replay them to advance the RNG state.
            this.rng.split();
        }
    }

    /**
     * Builds {@code centroids} dense centroids. Each coordinate is drawn via the feature's
     * {@code uniformSample} with the supplied RNG.
     */
    private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableFeatureMap featureMap, SplittableRandom rng) {
        DenseVector[] centroidVectors = new DenseVector[centroids];
        int numFeatures = featureMap.size();
        for (int i = 0; i < centroids; ++i) {
            double[] newCentroid = new double[numFeatures];
            for (int j = 0; j < numFeatures; ++j) {
                newCentroid[j] = featureMap.get(j).uniformSample(rng);
            }
            centroidVectors[i] = DenseVector.createDenseVector(newCentroid);
        }
        return centroidVectors;
    }

    /**
     * k-means++ initialisation. The first centroid is a uniformly random example. Each
     * later centroid is an example sampled with probability proportional to the squared
     * distance to its nearest existing centroid.
     * <p>
     * The squared distances may not be usable as weights, because they sum to zero (every
     * example coincides with a chosen centroid) or the sum is not finite. In that case the
     * next centroid is drawn uniformly from the data instead.
     *
     * @throws IllegalArgumentException If {@code centroids} exceeds the number of examples.
     */
    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVector[] data, SplittableRandom rng, Distance distanceType) {
        if (centroids > data.length) {
            throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
        }
        double[] minDistancePerVector = new double[data.length];
        Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);
        double[] squaredMinDistance = new double[data.length];
        double[] probabilities = new double[data.length];
        DenseVector[] centroidVectors = new DenseVector[centroids];
        centroidVectors[0] = getRandomCentroidFromData(data, rng);
        for (int i = 1; i < centroids; ++i) {
            DenseVector prevCentroid = centroidVectors[i - 1];
            // Only the newest centroid can lower an example's nearest-centroid distance.
            double total = 0.0;
            for (int j = 0; j < data.length; ++j) {
                double tempDistance = getDistance(prevCentroid, data[j], distanceType);
                minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
                squaredMinDistance[j] = minDistancePerVector[j] * minDistancePerVector[j];
                total += squaredMinDistance[j];
            }
            if (total > 0.0 && !Double.isInfinite(total)) {
                for (int j = 0; j < probabilities.length; ++j) {
                    probabilities[j] = squaredMinDistance[j] / total;
                }
                double[] cdf = Util.generateCDF(probabilities);
                int idx = Util.sampleFromCDF(cdf, rng);
                centroidVectors[i] = DenseVector.createDenseVector(data[idx].toArray());
            } else {
                // D^2 weighting is undefined here (all zero, NaN or overflowed); fall back to uniform.
                centroidVectors[i] = getRandomCentroidFromData(data, rng);
            }
        }
        return centroidVectors;
    }

    /**
     * Returns a dense copy of a uniformly chosen example.
     */
    private static DenseVector getRandomCentroidFromData(SGDVector[] data, SplittableRandom rng) {
        int randIdx = rng.nextInt(data.length);
        return DenseVector.createDenseVector(data[randIdx].toArray());
    }

    /**
     * Returns the index of the centroid with the smallest distance to {@code vector}.
     * Ties go to the lowest index. Returns -1 if no distance is strictly below +Infinity,
     * for example when every distance is NaN or there are no centroids.
     */
    private static int closestCentroid(DenseVector[] centroidVectors, SGDVector vector, Distance distanceType) {
        double minDist = Double.POSITIVE_INFINITY;
        int clusterID = -1;
        for (int j = 0; j < centroidVectors.length; ++j) {
            double distance = getDistance(centroidVectors[j], vector, distanceType);
            if (distance < minDist) {
                minDist = distance;
                clusterID = j;
            }
        }
        return clusterID;
    }

    /**
     * Computes the distance between a centroid and an example using the selected metric.
     */
    private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
        switch (distanceType) {
            case EUCLIDEAN:
                return cluster.euclideanDistance(vector);
            case COSINE:
                return cluster.cosineDistance(vector);
            case L1:
                return cluster.l1Distance(vector);
            default:
                throw new IllegalStateException("Unknown distance " + distanceType);
        }
    }

    /**
     * Submits {@code task} to {@code fjp} and blocks until it finishes.
     * If the wait is interrupted, the thread's interrupt flag is restored before rethrowing.
     *
     * @throws RuntimeException If the task fails or the wait is interrupted.
     */
    private static void runInPool(ForkJoinPool fjp, Runnable task) {
        try {
            fjp.submit(task).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Parallel execution failed", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Parallel execution failed", e);
        }
    }

    /**
     * M step. Overwrites each centroid in place with the weighted mean of its assigned
     * examples. A centroid with no assigned weight is left as the zero vector.
     *
     * @param fjp                The pool to run on, or null to run serially.
     * @param centroidVectors    The centroids, updated in place.
     * @param clusterAssignments Map from centroid index to assigned example indices.
     * @param data               The examples.
     * @param weights            The example weights, indexed like {@code data}.
     */
    protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SGDVector[] data, double[] weights) {
        Consumer<Map.Entry<Integer, List<Integer>>> mStepFunc = (Map.Entry<Integer, List<Integer>> e) -> {
            DenseVector newCentroid = centroidVectors[e.getKey()];
            newCentroid.fill(0.0);
            double weightSum = 0.0;
            for (Integer idx : e.getValue()) {
                double weight = weights[idx];
                newCentroid.intersectAndAddInPlace(data[idx], f -> f * weight);
                weightSum += weight;
            }
            if (weightSum != 0.0) {
                newCentroid.scaleInPlace(1.0 / weightSum);
            }
        };
        Stream<Map.Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
        if (fjp != null) {
            Stream<Map.Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
            runInPool(fjp, () -> parallelMStream.forEach(mStepFunc));
        } else {
            mStream.forEach(mStepFunc);
        }
    }

    @Override
    public String toString() {
        return "KMeansTrainer(centroids=" + this.centroids + ",distanceType=" + this.distanceType + ",seed=" + this.seed + ",numThreads=" + this.numThreads + ", initialisationType=" + this.initialisationType + ")";
    }

    /**
     * Returns a provenance snapshot of this trainer's configuration.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Creates pool worker threads inside a privileged block. It is used only when a
     * SecurityManager is installed.
     */
    private static final class CustomForkJoinWorkerThreadFactory
    implements ForkJoinPool.ForkJoinWorkerThreadFactory {
        private CustomForkJoinWorkerThreadFactory() {
        }

        @Override
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            // The cast selects the PrivilegedAction overload; without it the call is ambiguous.
            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool){});
        }
    }

    /**
     * Pairs an example index with its feature vector for the zipped E-step stream.
     */
    static class IntAndVector {
        final int idx;
        final SGDVector vector;

        public IntAndVector(int idx, SGDVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }

    /**
     * Centroid initialisation strategies.
     */
    public static enum Initialisation {
        /** Each coordinate is drawn via the feature's {@code uniformSample}. */
        RANDOM,
        /** k-means++: centroids are examples, sampled by squared nearest-centroid distance. */
        PLUSPLUS;

    }

    /**
     * Distance functions; each maps to the same-named {@code DenseVector} distance method.
     */
    public static enum Distance {
        EUCLIDEAN,
        COSINE,
        L1;

    }
}

