/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.libsvm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.common.libsvm.KernelType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.common.libsvm.SVMType;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/**
 * Base class for trainers which fit models using LibSVM.
 * <p>
 * Subclasses supply the output-specific pieces: how to turn a {@link Dataset} into LibSVM
 * nodes/targets ({@link #extractData}), how to run LibSVM ({@link #trainModels}) and how to wrap
 * the resulting {@code svm_model}s ({@link #createModel}).
 * <p>
 * Thread-safety: the RNG and invocation counter are only mutated while holding this trainer's
 * monitor, and the call to {@link #trainModels} is serialised across all instances by locking on
 * {@code LibSVMTrainer.class}.
 *
 * @param <T> The output type.
 */
public abstract class LibSVMTrainer<T extends Output<T>>
implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(LibSVMTrainer.class.getName());

    /** Passed as {@code invocationCount} to mean "do not reset the counter, just increment it". */
    private static final int NO_INVOCATION_RESET = -1;

    /** The LibSVM parameter struct, rebuilt from the configurable fields in {@link #postConfig()}. */
    protected svm_parameter parameters;
    @Config(mandatory=true, description="Type of SVM algorithm.")
    protected SVMType<T> svmType;
    @Config(description="Type of Kernel.")
    private KernelType kernelType = KernelType.LINEAR;
    @Config(description="Polynomial degree.")
    private int degree = 3;
    @Config(description="Width of the RBF kernel, or scalar on sigmoid kernel.")
    private double gamma = 0.0;
    @Config(description="Polynomial coefficient or shift in sigmoid kernel.")
    private double coef0 = 0.0;
    @Config(description="nu value in NU SVM.")
    private double nu = 0.5;
    @Config(description="Internal cache size, most of the time should be left at default.")
    private double cache_size = 500.0;
    @Config(description="Cost parameter for incorrect predictions.")
    private double cost = 1.0;
    @Config(description="Tolerance of the termination criterion.")
    private double eps = 0.001;
    @Config(description="Epsilon in EPSILON_SVR.")
    private double p = 0.001;
    @Config(description="Regularise the weight parameters.")
    private boolean shrinking = true;
    @Config(description="Generate probability estimates.")
    private boolean probability = false;
    @Config(description="RNG seed.")
    private long seed = 12345L;
    /** Source of per-invocation RNGs; guarded by {@code this}. */
    private SplittableRandom rng;
    /** Number of completed (or replayed) train invocations; guarded by {@code this}. */
    private int trainInvocationCounter = 0;

    /**
     * For the configuration system; fields are populated reflectively and {@link #postConfig()}
     * builds the LibSVM parameters.
     */
    protected LibSVMTrainer() {
    }

    /**
     * Constructs a trainer from an existing parameter object.
     *
     * @param parameters The SVM parameters; its {@code svm_parameter} struct is used directly.
     * @param seed The RNG seed.
     */
    protected LibSVMTrainer(SVMParameters<T> parameters, long seed) {
        this.parameters = parameters.getParameters();
        this.svmType = parameters.getSvmType();
        this.kernelType = parameters.getKernelType();
        this.degree = this.parameters.degree;
        this.gamma = parameters.getGamma();
        this.coef0 = this.parameters.coef0;
        this.nu = this.parameters.nu;
        this.cache_size = this.parameters.cache_size;
        this.cost = this.parameters.C;
        this.eps = this.parameters.eps;
        this.p = this.parameters.p;
        this.shrinking = this.parameters.shrinking == 1;
        this.probability = this.parameters.probability == 1;
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Builds the {@code svm_parameter} struct from the configured fields and resets the RNG.
     */
    @Override
    public void postConfig() {
        this.parameters = new svm_parameter();
        this.parameters.svm_type = this.svmType.getNativeType();
        this.parameters.kernel_type = this.kernelType.getNativeType();
        this.parameters.degree = this.degree;
        this.parameters.gamma = this.gamma;
        this.parameters.coef0 = this.coef0;
        this.parameters.nu = this.nu;
        this.parameters.cache_size = this.cache_size;
        this.parameters.C = this.cost;
        this.parameters.eps = this.eps;
        this.parameters.p = this.p;
        // LibSVM encodes booleans as 0/1 ints.
        this.parameters.shrinking = this.shrinking ? 1 : 0;
        this.parameters.probability = this.probability ? 1 : 0;
        this.rng = new SplittableRandom(this.seed);
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();
        buffer.append("LibSVMTrainer(");
        buffer.append("svm_params=");
        buffer.append(SVMParameters.svmParamsToString(this.parameters));
        buffer.append(",seed=");
        buffer.append(this.seed);
        buffer.append(")");
        return buffer.toString();
    }

    @Override
    public LibSVMModel<T> train(Dataset<T> examples) {
        return train(examples, Collections.emptyMap());
    }

    @Override
    public LibSVMModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, NO_INVOCATION_RESET);
    }

    /**
     * Trains a LibSVM model on the supplied dataset.
     *
     * @param examples The training data; must not contain unknown outputs.
     * @param runProvenance Extra provenance recorded in the model.
     * @param invocationCount If not {@code -1}, the invocation counter (and RNG) is reset to this
     *                        value before training.
     * @return The trained model.
     * @throws IllegalArgumentException If the dataset contains unknown outputs, or
     *                                  {@code invocationCount} is negative and not {@code -1}.
     */
    @Override
    public LibSVMModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();

        // Take a private RNG and snapshot the provenance atomically with the counter increment,
        // so concurrent train calls on the same trainer get distinct, reproducible RNG streams.
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized (this) {
            if (invocationCount != NO_INVOCATION_RESET) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        svm_parameter curParams = setupParameters(outputIDInfo);
        Pair<svm_node[][], double[][]> data = extractData(examples, outputIDInfo, featureIDMap);

        // Only the LibSVM training call is serialised across all instances.
        // NOTE(review): presumably guards process-global LibSVM state (e.g. its static RNG) —
        // confirm against trainModels implementations. Model creation is deliberately kept
        // outside this global lock so subclass code never runs while holding it.
        List<svm_model> models;
        synchronized (LibSVMTrainer.class) {
            models = trainModels(curParams, featureIDMap.size() + 1, data.getA(), data.getB(), localRNG);
        }

        ModelProvenance provenance = new ModelProvenance(LibSVMModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return createModel(provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Wraps the trained LibSVM models in a Tribuo model.
     *
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param models The trained LibSVM models.
     * @return The Tribuo model.
     */
    protected abstract LibSVMModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<svm_model> models);

    /**
     * Runs LibSVM training.
     *
     * @param curParams A private copy of the LibSVM parameters for this invocation.
     * @param numFeatures The number of feature ids plus one.
     * @param features The per-example LibSVM nodes.
     * @param outputs The per-example targets (layout defined by the subclass).
     * @param localRNG The RNG for this invocation.
     * @return The trained LibSVM models.
     */
    protected abstract List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom localRNG);

    /**
     * Converts a dataset into LibSVM nodes and targets.
     *
     * @param data The dataset.
     * @param outputInfo The output domain.
     * @param featureMap The feature domain.
     * @return The nodes and the targets.
     */
    protected abstract Pair<svm_node[][], double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);

    /**
     * Produces the LibSVM parameters for a single training run. The default returns a copy of
     * {@link #parameters} so per-run changes do not leak into the trainer.
     *
     * @param info The output domain (unused by the default implementation).
     * @return A parameter struct for this run.
     */
    protected svm_parameter setupParameters(ImmutableOutputInfo<T> info) {
        return SVMParameters.copyParameters(this.parameters);
    }

    @Override
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Resets the RNG to the seed and replays {@code invocationCount} splits, so the next train
     * call behaves as if it were invocation number {@code invocationCount}.
     *
     * @param invocationCount The invocation count to restore; must be non-negative.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    @Override
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        // The split result is discarded; only the advance of rng's internal state matters.
        for (this.trainInvocationCounter = 0; this.trainInvocationCounter < invocationCount; this.trainInvocationCounter++) {
            this.rng.split();
        }
    }

    /**
     * Converts an example into an array of LibSVM nodes sorted by feature id.
     * <p>
     * Features unknown to {@code featureIDMap} (id {@code -1}) are dropped. Repeated ids are
     * merged by summing their values.
     *
     * @param example The example to convert.
     * @param featureIDMap The feature domain used to look up ids.
     * @param features A scratch buffer, cleared before use; may be {@code null}.
     * @param <T> The output type.
     * @return The nodes, in ascending id order.
     */
    public static <T extends Output<T>> svm_node[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<svm_node> features) {
        if (features == null) {
            features = new ArrayList<>();
        }
        features.clear();
        int prevIdx = -1;
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            double value = f.getValue();
            if (id > prevIdx) {
                // Fast path: ids arriving in ascending order are appended.
                prevIdx = id;
                svm_node node = new svm_node();
                node.index = id;
                node.value = value;
                features.add(node);
            } else if (id > -1) {
                // Out-of-order or repeated id: insert in sorted position, or merge into the existing node.
                int collisionIdx = Util.binarySearch(features, id, n -> n.index);
                if (collisionIdx < 0) {
                    collisionIdx = -(collisionIdx + 1);
                    svm_node node = new svm_node();
                    node.index = id;
                    node.value = value;
                    features.add(collisionIdx, node);
                } else {
                    features.get(collisionIdx).value += value;
                }
            }
        }
        return features.toArray(new svm_node[0]);
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

