package com.actelion.research.calc.regression.randomforest;

import com.actelion.research.calc.Matrix;
import com.actelion.research.calc.regression.ARegressionMethod;
import com.actelion.research.calc.regression.ParameterRegressionMethod;
import com.actelion.research.util.datamodel.ModelXYIndex;
import smile.regression.RandomForest;

/* loaded from: input_file:com/actelion/research/calc/regression/randomforest/RandomForestRegression.class */
/**
 * Regression method that fits a Smile {@link RandomForest} to an X/Y data set and
 * uses the fitted forest to predict Y values.
 *
 * <p>The forest is created by {@link #createModel(ModelXYIndex)}. Until that call has
 * succeeded the internal forest reference is {@code null}, so both {@code calculateYHat}
 * variants will fail with a {@link NullPointerException} if invoked earlier.
 *
 * <p>Instances are ordered by their parameter object, see {@link #compareTo(RandomForestRegression)}.
 */
public class RandomForestRegression extends ARegressionMethod<ParameterRandomForest> implements Comparable<RandomForestRegression> {
    /** Lower bound for the number of variables tried at each split, as far as the data has that many columns. */
    public static final int MIN_NUM_VAR_SPLIT = 3;

    /** Fitted forest; stays {@code null} until {@link #createModel(ModelXYIndex)} has built one. */
    private RandomForest forest;

    /**
     * Creates a regression method with a default-constructed {@link ParameterRandomForest}
     * and sets the JVM-wide system property {@code smile.threads} to {@code "1"}.
     *
     * <p>NOTE(review): the property is presumably read by Smile to restrict it to a single
     * worker thread - confirm against the Smile version in use. The other constructor does
     * not touch this property.
     */
    public RandomForestRegression() {
        setParameterRegressionMethod(new ParameterRandomForest());
        try {
            System.setProperty("smile.threads", "1");
        } catch (Exception e) {
            // Best effort: a failure to set the property (e.g. denied by a security manager)
            // must not prevent construction.
            e.printStackTrace();
        }
    }

    /**
     * Creates a regression method that uses the given parameters.
     *
     * @param parameterRandomForest parameters for the forest (number of trees, max nodes,
     *                              node size, mtry fraction)
     */
    public RandomForestRegression(ParameterRandomForest parameterRandomForest) {
        setParameterRegressionMethod(parameterRandomForest);
    }

    /**
     * Fits a new random forest to the given data and returns the predictions of the
     * fitted forest for the training rows.
     *
     * <p>The number of variables tried per split is the number of X columns multiplied by
     * {@code getFractionMTry()}, rounded, and then limited to the range
     * [{@link #MIN_NUM_VAR_SPLIT}, number of X columns]. The upper limit matters for data
     * with fewer than {@link #MIN_NUM_VAR_SPLIT} columns or a fraction above 1.0: without
     * it the forest would be asked to try more variables than exist.
     *
     * <p>Any exception raised while fitting or predicting is printed to standard error and
     * swallowed; {@code null} is returned in that case.
     *
     * @param modelXYIndex data set; column 0 of {@code Y} is used as the response
     * @return predictions for the rows of {@code modelXYIndex.X}, or {@code null} if
     *         fitting or predicting failed
     */
    @Override // com.actelion.research.calc.regression.ICalculateModel
    public Matrix createModel(ModelXYIndex modelXYIndex) {
        Matrix yHat = null;
        try {
            ParameterRandomForest parameter = getParameter();
            int numFeatures = modelXYIndex.X.cols();
            // Adding 0.5 before the truncating cast rounds the product to the nearest integer.
            int requestedMTry = (int) ((numFeatures * parameter.getFractionMTry()) + 0.5d);
            int mTry = clampMTry(requestedMTry, numFeatures);
            this.forest = new RandomForest(modelXYIndex.X.getArray(), modelXYIndex.Y.getColAsDouble(0), parameter.getNumberOfTrees(), parameter.getMaxNodes(), parameter.getNodeSize(), mTry);
            yHat = calculateYHat(modelXYIndex.X);
        } catch (Exception e) {
            // Deliberate best-effort behaviour: report the failure and hand back null.
            e.printStackTrace();
        }
        return yHat;
    }

    /**
     * Limits the number of variables per split to [{@link #MIN_NUM_VAR_SPLIT}, {@code numFeatures}].
     * The upper bound is applied last, so it wins when fewer than
     * {@link #MIN_NUM_VAR_SPLIT} columns are available.
     */
    private static int clampMTry(int requested, int numFeatures) {
        return Math.min(Math.max(requested, MIN_NUM_VAR_SPLIT), numFeatures);
    }

    /**
     * Predicts one value per row of the given matrix with the fitted forest.
     *
     * @param matrix input data, one sample per row
     * @return matrix built from the per-row predictions via {@code new Matrix(false, double[])}
     * @throws NullPointerException if no forest has been fitted yet
     */
    @Override // com.actelion.research.calc.regression.ICalculateYHat
    public Matrix calculateYHat(Matrix matrix) {
        int rows = matrix.rows();
        double[] predictions = new double[rows];
        for (int i = 0; i < rows; i++) {
            predictions[i] = this.forest.predict(matrix.getRow(i));
        }
        return new Matrix(false, predictions);
    }

    /**
     * Predicts the value for a single sample. The call into the forest is made while
     * holding this object's monitor.
     *
     * @param dArr feature vector of one sample
     * @return the forest's prediction for the sample
     * @throws NullPointerException if no forest has been fitted yet
     */
    @Override // com.actelion.research.calc.regression.ICalculateYHat
    public double calculateYHat(double[] dArr) {
        synchronized (this) {
            return this.forest.predict(dArr);
        }
    }

    /**
     * Orders regression methods by comparing their parameter objects.
     *
     * @param randomForestRegression the instance to compare with
     * @return the result of comparing this instance's parameters with the other's
     */
    @Override // java.lang.Comparable
    public int compareTo(RandomForestRegression randomForestRegression) {
        return getParameter().compareTo((ParameterRegressionMethod) randomForestRegression.getParameter());
    }
}
