/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Simple (single-attribute) weighted least-squares linear regression.
 *
 * <p>Training fits one regression line per non-class attribute and keeps the
 * attribute whose fit has the smallest weighted residual sum of squares. If
 * no attribute shows any variation, the model falls back to predicting the
 * mean class value.
 *
 * <p>Field names are part of the serialized form and must not be renamed.
 */
public class SimpleLinearRegression
extends Classifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = 1679336022895414137L;

    /** The attribute used by the model, or {@code null} if a constant is predicted. */
    private Attribute m_attribute;

    /** Index of the chosen attribute in the training data (0 if none was chosen). */
    private int m_attributeIndex;

    /** Slope of the fitted line (0 if no attribute was chosen). */
    private double m_slope;

    /** Intercept of the fitted line, or the class mean if no attribute was chosen. */
    private double m_intercept;

    /** When {@code true}, the "no useful attribute found" message is not printed. */
    private boolean m_suppressErrorMessage = false;

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Learns a simple linear regression model. Picks the attribute that results in the lowest squared error. Missing values are not allowed. Can only deal with numeric attributes.";
    }

    /**
     * Predicts the class value for the given instance.
     *
     * @param inst the instance to predict for
     * @return the intercept alone if no attribute was selected, otherwise
     *         {@code intercept + slope * value-of-selected-attribute}
     * @throws Exception if the selected attribute is missing in {@code inst}
     */
    @Override
    public double classifyInstance(Instance inst) throws Exception {
        if (m_attribute == null) {
            // Constant model: nothing to look up in the instance.
            return m_intercept;
        }
        final int attIndex = m_attribute.index();
        if (inst.isMissing(attIndex)) {
            throw new Exception("SimpleLinearRegression: No missing values!");
        }
        return m_intercept + m_slope * inst.value(attIndex);
    }

    /**
     * Returns the capabilities of this classifier: numeric and date
     * attributes, numeric and date class, and missing class values.
     *
     * @return the capabilities, starting from the superclass object with
     *         everything disabled and only the above re-enabled
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);

        // class
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the regression model from the given training data.
     *
     * <p>The data is copied and instances with a missing class are removed
     * from the copy; the caller's data set is not modified here. All
     * intermediate results are kept in local variables and the model fields
     * are assigned only once, after the search has finished, so a failure
     * during training does not leave a partially computed model behind.
     *
     * @param insts the training data
     * @throws Exception if the capabilities check rejects the data
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        getCapabilities().testWithFail(insts);

        // Work on a private copy so that deleting instances does not affect the caller.
        final Instances data = new Instances(insts);
        data.deleteWithMissingClass();

        final int classIndex = data.classIndex();
        final double yMean = data.meanOrMode(classIndex);

        int bestIndex = -1;
        double bestSlope = Double.NaN;
        double bestIntercept = Double.NaN;
        double bestSquaredError = Double.MAX_VALUE;

        for (int i = 0; i < data.numAttributes(); i++) {
            if (i == classIndex) {
                continue;
            }
            final double xMean = data.meanOrMode(i);

            // Weighted sums of (x - xMean)(y - yMean), (x - xMean)^2 and (y - yMean)^2.
            double sumWeightedXYDiff = 0.0;
            double sumWeightedXDiffSquared = 0.0;
            double sumWeightedYDiffSquared = 0.0;
            for (int j = 0; j < data.numInstances(); j++) {
                final Instance inst = data.instance(j);
                if (inst.isMissing(i) || inst.classIsMissing()) {
                    continue;
                }
                final double xDiff = inst.value(i) - xMean;
                final double yDiff = inst.classValue() - yMean;
                final double weightedXDiff = inst.weight() * xDiff;
                final double weightedYDiff = inst.weight() * yDiff;
                sumWeightedXYDiff += weightedXDiff * yDiff;
                sumWeightedXDiffSquared += weightedXDiff * xDiff;
                sumWeightedYDiffSquared += weightedYDiff * yDiff;
            }

            // An attribute without any spread cannot define a slope; skip it.
            if (sumWeightedXDiffSquared == 0.0) {
                continue;
            }

            final double slope = sumWeightedXYDiff / sumWeightedXDiffSquared;
            final double intercept = yMean - slope * xMean;

            // Residual sum of squares of the fitted line: Syy - slope * Sxy.
            final double squaredError = sumWeightedYDiffSquared - slope * sumWeightedXYDiff;

            // Strict comparison: on ties the attribute with the lower index wins.
            if (squaredError < bestSquaredError) {
                bestSquaredError = squaredError;
                bestIndex = i;
                bestSlope = slope;
                bestIntercept = intercept;
            }
        }

        // Commit the model in one place.
        if (bestIndex == -1) {
            if (!m_suppressErrorMessage) {
                System.err.println("----- no useful attribute found");
            }
            m_attribute = null;
            m_attributeIndex = 0;
            m_slope = 0.0;
            m_intercept = yMean;
        } else {
            m_attribute = data.attribute(bestIndex);
            m_attributeIndex = bestIndex;
            m_slope = bestSlope;
            m_intercept = bestIntercept;
        }
    }

    /**
     * Tells whether the last training run selected an attribute.
     *
     * @return {@code true} if the model uses an attribute, {@code false} if
     *         it predicts a constant
     */
    public boolean foundUsefulAttribute() {
        return m_attribute != null;
    }

    /**
     * Returns the index of the attribute used by the model.
     *
     * @return the attribute index; 0 when no attribute was selected, so check
     *         {@link #foundUsefulAttribute()} to tell the two cases apart
     */
    public int getAttributeIndex() {
        return m_attributeIndex;
    }

    /**
     * Returns the slope of the fitted line.
     *
     * @return the slope (0 when no attribute was selected)
     */
    public double getSlope() {
        return m_slope;
    }

    /**
     * Returns the intercept of the fitted line.
     *
     * @return the intercept (the class mean when no attribute was selected)
     */
    public double getIntercept() {
        return m_intercept;
    }

    /**
     * Turns the "no useful attribute found" message on standard error on or off.
     *
     * @param s {@code true} to suppress the message
     */
    public void setSuppressErrorMessage(boolean s) {
        m_suppressErrorMessage = s;
    }

    /**
     * Returns a textual description of the model.
     *
     * @return either the predicted constant or the regression equation, with
     *         slope and intercept formatted through
     *         {@code Utils.doubleToString(value, 2)}, followed by a newline
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (m_attribute == null) {
            text.append("Predicting constant ").append(m_intercept);
        } else {
            final String attName = m_attribute.name();
            text.append("Linear regression on ").append(attName).append("\n\n");
            text.append(Utils.doubleToString(m_slope, 2)).append(" * ").append(attName);
            // Print the sign separately so a negative intercept reads "- 1.5" rather than "+ -1.5".
            if (m_intercept > 0.0) {
                text.append(" + ").append(Utils.doubleToString(m_intercept, 2));
            } else {
                text.append(" - ").append(Utils.doubleToString(-m_intercept, 2));
            }
        }
        text.append("\n");
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the embedded revision keyword
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5523 $");
    }

    /**
     * Command-line entry point; hands a new instance and the arguments to
     * {@code runClassifier}.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        SimpleLinearRegression.runClassifier(new SimpleLinearRegression(), argv);
    }
}

