/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;
import smile.regression.RegressionTree;
import smile.util.MulticoreExecutor;
import smile.util.SmileUtils;
import smile.validation.RMSE;
import smile.validation.RegressionMeasure;

public class RandomForest
implements Regression<double[]> {
    private static final long serialVersionUID = 1L;
    private List<RegressionTree> trees;
    private double error;
    private double[] importance;
    private double[] monotonicRegression;

    public RandomForest(double[][] dArray, double[] dArray2, int n) {
        this(null, dArray, dArray2, n);
    }

    public RandomForest(double[][] dArray, double[] dArray2, int n, int n2, int n3, int n4) {
        this(null, dArray, dArray2, n, n2, n3, n4);
    }

    public RandomForest(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n) {
        this(attributeArray, dArray, dArray2, n, 100);
    }

    public RandomForest(AttributeDataset attributeDataset, int n) {
        this(attributeDataset.attributes(), attributeDataset.x(), attributeDataset.y(), n);
    }

    public RandomForest(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n, int n2) {
        this(attributeArray, dArray, dArray2, n, n2, 5);
    }

    public RandomForest(AttributeDataset attributeDataset, int n, int n2) {
        this(attributeDataset.attributes(), attributeDataset.x(), attributeDataset.y(), n, n2);
    }

    public RandomForest(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n, int n2, int n3) {
        this(attributeArray, dArray, dArray2, n, n2, n3, dArray[0].length / 3);
    }

    public RandomForest(AttributeDataset attributeDataset, int n, int n2, int n3) {
        this(attributeDataset.attributes(), attributeDataset.x(), attributeDataset.y(), n, n2, n3);
    }

    public RandomForest(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n, int n2, int n3, int n4) {
        this(attributeArray, dArray, dArray2, n, n2, n3, n4, 1.0);
    }

    public RandomForest(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n, int n2, int n3, int n4, double d) {
        this(attributeArray, dArray, dArray2, n, n2, n3, n4, d, null);
    }

    public RandomForest(AttributeDataset attributeDataset, int n, int n2, int n3, int n4, double d, double[] dArray) {
        this(attributeDataset.attributes(), attributeDataset.x(), attributeDataset.y(), n, n2, n3, n4, d, dArray);
    }

    public RandomForest(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n, int n2, int n3, int n4, double d, double[] dArray3) {
        int n5;
        int n6;
        if (dArray.length != dArray2.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", dArray.length, dArray2.length));
        }
        if (n < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + n);
        }
        if (n4 < 1 || n4 > dArray[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + n4);
        }
        if (n3 < 2) {
            throw new IllegalArgumentException("Invalid minimum size of leaves: " + n3);
        }
        if (n2 < 2) {
            throw new IllegalArgumentException("Invalid maximum number of leaves: " + n2);
        }
        if (d <= 0.0 || d > 1.0) {
            throw new IllegalArgumentException("Invalid sampling rate: " + d);
        }
        if (attributeArray == null) {
            n6 = dArray[0].length;
            attributeArray = new Attribute[n6];
            for (int i = 0; i < n6; ++i) {
                attributeArray[i] = new NumericAttribute("V" + (i + 1));
            }
        }
        n6 = dArray.length;
        double[] dArray4 = new double[n6];
        int[] nArray = new int[n6];
        int[][] nArray2 = SmileUtils.sort(attributeArray, dArray);
        ArrayList<TrainingTask> arrayList = new ArrayList<TrainingTask>();
        for (int i = 0; i < n; ++i) {
            arrayList.add(new TrainingTask(attributeArray, dArray, dArray2, n2, n3, n4, d, nArray2, dArray4, nArray, dArray3));
        }
        try {
            this.trees = MulticoreExecutor.run(arrayList);
        }
        catch (Exception exception) {
            exception.printStackTrace();
            this.trees = new ArrayList<RegressionTree>(n);
            for (n5 = 0; n5 < n; ++n5) {
                this.trees.add(((TrainingTask)arrayList.get(n5)).call());
            }
        }
        int n7 = 0;
        for (n5 = 0; n5 < n6; ++n5) {
            if (nArray[n5] <= 0) continue;
            ++n7;
            double d2 = dArray4[n5] / (double)nArray[n5];
            this.error += Math.sqr(d2 - dArray2[n5]);
        }
        if (n7 > 0) {
            this.error = Math.sqrt(this.error / (double)n7);
        }
        this.importance = RandomForest.calculateImportance(this.trees, attributeArray.length);
    }

    public RandomForest merge(RandomForest randomForest) {
        if (this.importance.length != randomForest.importance.length) {
            throw new IllegalArgumentException("RandomForest have different sizes of feature vectors");
        }
        ArrayList<RegressionTree> arrayList = new ArrayList<RegressionTree>();
        arrayList.addAll(this.trees);
        arrayList.addAll(randomForest.trees);
        double d = (this.error * (double)this.trees.size() + randomForest.error * (double)randomForest.trees.size()) / (double)(this.trees.size() + randomForest.trees.size());
        double[] dArray = RandomForest.calculateImportance(arrayList, this.importance.length);
        return new RandomForest(arrayList, d, dArray);
    }

    private RandomForest(List<RegressionTree> list, double d, double[] dArray) {
        this.trees = list;
        this.error = d;
        this.importance = dArray;
    }

    private static double[] calculateImportance(List<RegressionTree> list, int n) {
        double[] dArray = new double[n];
        for (RegressionTree regressionTree : list) {
            double[] dArray2 = regressionTree.importance();
            for (int i = 0; i < dArray2.length; ++i) {
                int n2 = i;
                dArray[n2] = dArray[n2] + dArray2[i];
            }
        }
        return dArray;
    }

    public double error() {
        return this.error;
    }

    public double[] importance() {
        return this.importance;
    }

    public int size() {
        return this.trees.size();
    }

    public void trim(int n) {
        if (n > this.trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (n <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + n);
        }
        ArrayList<RegressionTree> arrayList = new ArrayList<RegressionTree>(n);
        for (int i = 0; i < n; ++i) {
            arrayList.add(this.trees.get(i));
        }
        this.trees = arrayList;
    }

    @Override
    public double predict(double[] dArray) {
        double d = 0.0;
        for (RegressionTree regressionTree : this.trees) {
            d += regressionTree.predict(dArray);
        }
        return d / (double)this.trees.size();
    }

    public double[] test(double[][] dArray, double[] dArray2) {
        int n = this.trees.size();
        double[] dArray3 = new double[n];
        int n2 = dArray.length;
        double[] dArray4 = new double[n2];
        double[] dArray5 = new double[n2];
        RMSE rMSE = new RMSE();
        int n3 = 0;
        int n4 = 1;
        while (n3 < n) {
            for (int i = 0; i < n2; ++i) {
                int n5 = i;
                dArray4[n5] = dArray4[n5] + this.trees.get(n3).predict(dArray[i]);
                dArray5[i] = dArray4[i] / (double)n4;
            }
            dArray3[n3] = rMSE.measure(dArray2, dArray5);
            ++n3;
            ++n4;
        }
        return dArray3;
    }

    public double[][] test(double[][] dArray, double[] dArray2, RegressionMeasure[] regressionMeasureArray) {
        int n = this.trees.size();
        int n2 = regressionMeasureArray.length;
        double[][] dArray3 = new double[n][n2];
        int n3 = dArray.length;
        double[] dArray4 = new double[n3];
        double[] dArray5 = new double[n3];
        int n4 = 0;
        int n5 = 1;
        while (n4 < n) {
            int n6;
            for (n6 = 0; n6 < n3; ++n6) {
                int n7 = n6;
                dArray4[n7] = dArray4[n7] + this.trees.get(n4).predict(dArray[n6]);
                dArray5[n6] = dArray4[n6] / (double)n5;
            }
            for (n6 = 0; n6 < n2; ++n6) {
                dArray3[n4][n6] = regressionMeasureArray[n6].measure(dArray2, dArray5);
            }
            ++n4;
            ++n5;
        }
        return dArray3;
    }

    public RegressionTree[] getTrees() {
        return this.trees.toArray(new RegressionTree[this.trees.size()]);
    }

    static class TrainingTask
    implements Callable<RegressionTree> {
        Attribute[] attributes;
        double[][] x;
        double[] y;
        int[][] order;
        int mtry;
        int nodeSize = 5;
        int maxNodes = 100;
        double subsample = 1.0;
        final double[] monotonicRegression;
        double[] prediction;
        int[] oob;

        TrainingTask(Attribute[] attributeArray, double[][] dArray, double[] dArray2, int n, int n2, int n3, double d, int[][] nArray, double[] dArray3, int[] nArray2, double[] dArray4) {
            this.attributes = attributeArray;
            this.monotonicRegression = dArray4;
            this.x = dArray;
            this.y = dArray2;
            this.order = nArray;
            this.mtry = n3;
            this.nodeSize = n2;
            this.maxNodes = n;
            this.subsample = d;
            this.prediction = dArray3;
            this.oob = nArray2;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public RegressionTree call() {
            int n;
            int n2 = this.x.length;
            int[] nArray = new int[n2];
            if (this.subsample == 1.0) {
                for (int i = 0; i < n2; ++i) {
                    int n3 = n = Math.randomInt(n2);
                    nArray[n3] = nArray[n3] + 1;
                }
            } else {
                int[] nArray2 = new int[n2];
                for (n = 0; n < n2; ++n) {
                    nArray2[n] = n;
                }
                Math.permutate(nArray2);
                n = (int)Math.round((double)n2 * this.subsample);
                for (int i = 0; i < n; ++i) {
                    int n4 = nArray2[i];
                    nArray[n4] = nArray[n4] + 1;
                }
            }
            RegressionTree regressionTree = new RegressionTree(this.attributes, this.x, this.y, this.maxNodes, this.nodeSize, this.mtry, this.order, nArray, null, this.monotonicRegression);
            for (n = 0; n < n2; ++n) {
                if (nArray[n] != 0) continue;
                double d = regressionTree.predict(this.x[n]);
                double[] dArray = this.x[n];
                synchronized (dArray) {
                    int n5 = n;
                    this.prediction[n5] = this.prediction[n5] + d;
                    int n6 = n;
                    this.oob[n6] = this.oob[n6] + 1;
                    continue;
                }
            }
            return regressionTree;
        }
    }

    public static class Trainer
    extends RegressionTrainer<double[]> {
        private int ntrees = 500;
        private int mtry = -1;
        private int nodeSize = 5;
        private int maxNodes = 100;
        private double subsample = 1.0;

        public Trainer(int n) {
            if (n < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + n);
            }
            this.ntrees = n;
        }

        public Trainer(Attribute[] attributeArray, int n) {
            super(attributeArray);
            if (n < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + n);
            }
            this.ntrees = n;
        }

        public Trainer setNumTrees(int n) {
            if (n < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + n);
            }
            this.ntrees = n;
            return this;
        }

        public Trainer setNumRandomFeatures(int n) {
            if (n < 1) {
                throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + n);
            }
            this.mtry = n;
            return this;
        }

        public Trainer setMaxNodes(int n) {
            if (n < 2) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + n);
            }
            this.maxNodes = n;
            return this;
        }

        public Trainer setNodeSize(int n) {
            if (n < 1) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + n);
            }
            this.nodeSize = n;
            return this;
        }

        public Trainer setSamplingRates(double d) {
            if (d <= 0.0 || d > 1.0) {
                throw new IllegalArgumentException("Invalid sampling rate: " + d);
            }
            this.subsample = d;
            return this;
        }

        public RandomForest train(double[][] dArray, double[] dArray2) {
            return new RandomForest(this.attributes, dArray, dArray2, this.ntrees, this.maxNodes, this.nodeSize, this.mtry, this.subsample);
        }
    }
}

