/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.trees.M5P;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MakeIndicator;

/**
 * Meta-classifier that performs classification with a regression scheme.
 * <p>
 * The nominal class is binarized: for every class value a copy of the base
 * regressor is trained on a numeric indicator target produced by a
 * {@link MakeIndicator} filter. At prediction time each per-class regression
 * output is clipped to [0, 1] and the resulting vector is normalized into a
 * class distribution.
 */
public class ClassificationViaRegression
extends SingleClassifierEnhancer
implements TechnicalInformationHandler,
WeightedInstancesHandler {
    static final long serialVersionUID = 4500023123618669859L;
    /** One trained regression model per class value; null until built. */
    private Classifier[] m_Classifiers;
    /** Indicator filter for each class value, parallel to {@code m_Classifiers}. */
    private MakeIndicator[] m_ClassFilters;

    /**
     * Creates the meta-classifier with an {@link M5P} model tree as the
     * default base regressor.
     */
    public ClassificationViaRegression() {
        this.m_Classifier = new M5P();
    }

    /**
     * Returns a short description of this classifier for display in the GUI.
     *
     * @return the description, followed by the bibliographic reference
     */
    public String globalInfo() {
        return "Class for doing classification using regression methods. Class is binarized and one regression model is built for each class value. For more information, see, for example\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the method.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "E. Frank and Y. Wang and S. Inglis and G. Holmes and I.H. Witten");
        result.setValue(TechnicalInformation.Field.YEAR, "1998");
        result.setValue(TechnicalInformation.Field.TITLE, "Using model trees for classification");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "32");
        result.setValue(TechnicalInformation.Field.NUMBER, "1");
        result.setValue(TechnicalInformation.Field.PAGES, "63-76");
        return result;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified name of {@link M5P}
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.M5P";
    }

    /**
     * Returns the base classifier's capabilities, restricted so that only a
     * nominal class is accepted (the binarized targets are handled internally).
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        return result;
    }

    /**
     * Trains one copy of the base regressor per class value on a numeric
     * indicator version of the class attribute.
     *
     * @param insts the training data; not modified (a copy is used)
     * @throws IllegalArgumentException if the data has non-uniform instance
     *         weights and the base classifier cannot handle weights
     * @throws Exception if the capabilities test, filtering or training fails
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        insts = new Instances(insts);
        insts.deleteWithMissingClass();
        if (!insts.allInstanceWeightsIdentical() && !(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new IllegalArgumentException("ClassificationViaRegression: training data has non-uniform instance weights and base classifier cannot handle instance weights");
        }
        int numClasses = insts.numClasses();
        this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, numClasses);
        this.m_ClassFilters = new MakeIndicator[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            MakeIndicator filter = new MakeIndicator();
            // MakeIndicator takes a 1-based attribute index string.
            filter.setAttributeIndex("" + (insts.classIndex() + 1));
            filter.setValueIndex(i);
            filter.setNumeric(true);
            filter.setInputFormat(insts);
            this.m_ClassFilters[i] = filter;
            Instances newInsts = Filter.useFilter(insts, filter);
            this.m_Classifiers[i].buildClassifier(newInsts);
        }
    }

    /**
     * Predicts a class distribution for one instance.
     * <p>
     * Each per-class regression output is clipped to [0, 1]; the vector is
     * then normalized unless it sums to zero, in which case it is returned as
     * all zeros. If any per-class prediction is missing, an all-zero
     * distribution is returned.
     *
     * @param inst the instance to classify
     * @return the (possibly all-zero) class distribution
     * @throws Exception if filtering or prediction fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        int numClasses = inst.numClasses();
        double[] probs = new double[numClasses];
        double sum = 0.0;
        for (int i = 0; i < numClasses; ++i) {
            this.m_ClassFilters[i].input(inst);
            this.m_ClassFilters[i].batchFinished();
            Instance newInst = this.m_ClassFilters[i].output();
            double pred = this.m_Classifiers[i].classifyInstance(newInst);
            if (Utils.isMissingValue(pred)) {
                // One unknown per-class score makes the whole distribution unknown.
                return new double[numClasses];
            }
            probs[i] = clipToUnitInterval(pred);
            sum += probs[i];
        }
        // Utils.normalize throws on a zero sum, so leave an all-zero vector as is.
        if (sum != 0.0) {
            Utils.normalize(probs, sum);
        }
        return probs;
    }

    /**
     * Reports whether batch prediction is more efficient than per-instance
     * prediction, which is the case only if the base classifier says so.
     *
     * @return true if the base classifier implements more efficient batch prediction
     */
    @Override
    public boolean implementsMoreEfficientBatchPrediction() {
        return this.m_Classifier instanceof BatchPredictor
            && ((BatchPredictor) this.m_Classifier).implementsMoreEfficientBatchPrediction();
    }

    /**
     * Predicts class distributions for a batch of instances.
     * <p>
     * When the trained models support batch prediction, each per-class model
     * is run once over the filtered batch. The per-row post-processing matches
     * {@link #distributionForInstance(Instance)}: outputs are clipped to
     * [0, 1], a row with any missing prediction becomes all zeros, and a row
     * summing to zero is left unnormalized. (Previously such rows made
     * {@code Utils.normalize} throw.) Otherwise this falls back to
     * per-instance prediction.
     *
     * @param insts the instances to classify
     * @return one class distribution per instance
     * @throws Exception if filtering or prediction fails
     */
    @Override
    public double[][] distributionsForInstances(Instances insts) throws Exception {
        // Dispatch on the trained models actually being cast, not on the template,
        // which may have been replaced via setClassifier after training.
        boolean batchCapable = this.m_Classifiers != null
            && this.m_Classifiers.length > 0
            && this.m_Classifiers[0] instanceof BatchPredictor;
        if (!batchCapable) {
            return super.distributionsForInstances(insts);
        }
        int numInstances = insts.numInstances();
        int numClasses = insts.numClasses();
        double[][] probs = new double[numInstances][numClasses];
        boolean[] missing = new boolean[numInstances];
        for (int c = 0; c < numClasses; ++c) {
            Instances filtered = Filter.useFilter(insts, this.m_ClassFilters[c]);
            double[][] preds = ((BatchPredictor) this.m_Classifiers[c]).distributionsForInstances(filtered);
            for (int j = 0; j < preds.length; ++j) {
                // A regression model yields a single-element "distribution": the predicted value.
                double pred = preds[j][0];
                if (Utils.isMissingValue(pred)) {
                    missing[j] = true;
                } else {
                    probs[j][c] = clipToUnitInterval(pred);
                }
            }
        }
        for (int j = 0; j < probs.length; ++j) {
            if (missing[j]) {
                probs[j] = new double[numClasses];
                continue;
            }
            double sum = 0.0;
            for (double p : probs[j]) {
                sum += p;
            }
            if (sum != 0.0) {
                Utils.normalize(probs[j], sum);
            }
        }
        return probs;
    }

    /**
     * Clips a non-missing regression output into [0, 1].
     *
     * @param value the raw prediction (must not be NaN)
     * @return the value bounded below by 0 and above by 1
     */
    private static double clipToUnitInterval(double value) {
        if (value > 1.0) {
            return 1.0;
        }
        if (value < 0.0) {
            return 0.0;
        }
        return value;
    }

    /**
     * Returns a textual description of all per-class regression models.
     *
     * @return the model description, or a notice if no model is built yet
     */
    @Override
    public String toString() {
        if (this.m_Classifiers == null) {
            return "Classification via Regression: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("Classification via Regression\n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append("Classifier for class with index ").append(i).append(":\n\n");
            text.append(this.m_Classifiers[i].toString()).append("\n\n");
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 15482 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the classifier options
     */
    public static void main(String[] argv) {
        ClassificationViaRegression.runClassifier(new ClassificationViaRegression(), argv);
    }
}

