package com.actelion.research.calc.regression.neuralnetwork;

import com.actelion.research.calc.Matrix;
import com.actelion.research.calc.regression.ARegressionMethod;
import com.actelion.research.calc.regression.ParameterRegressionMethod;
import com.actelion.research.util.datamodel.ModelXYIndex;
import smile.regression.NeuralNetwork;

/* loaded from: input_file:com/actelion/research/calc/regression/neuralnetwork/NeuralNetworkRegression.class */
/**
 * Regression method backed by a Smile {@link NeuralNetwork} with a single output unit.
 *
 * <p>The network layout is built in {@link #createModel(ModelXYIndex)}:
 * <ul>
 *   <li>the input layer has one unit per column of X;</li>
 *   <li>it is followed by the inner layers from {@link ParameterNeuralNetwork#getArrInnerLayerArchitecture()};</li>
 *   <li>it ends with exactly one output unit, since Y must have a single column.</li>
 * </ul>
 *
 * <p>All calls to the underlying network's {@code predict} are serialized on {@code this}.
 */
public class NeuralNetworkRegression extends ARegressionMethod<ParameterNeuralNetwork> implements Comparable<NeuralNetworkRegression> {

    /** The trained network; {@code null} until {@link #createModel(ModelXYIndex)} has completed. */
    private NeuralNetwork neuralNetwork;

    /**
     * Creates a regression with a default-constructed {@link ParameterNeuralNetwork}.
     *
     * <p>Side effect: sets the JVM-wide system property {@code smile.threads} to {@code "1"}.
     */
    public NeuralNetworkRegression() {
        setParameterRegressionMethod(new ParameterNeuralNetwork());
        // NOTE(review): JVM-global side effect, presumably to keep Smile single-threaded.
        // The parameterized constructor below does NOT set it - confirm whether that is intended.
        System.setProperty("smile.threads", "1");
    }

    /**
     * Creates a regression using the given parameters.
     *
     * @param parameterNeuralNetwork network configuration (activation function, inner layer sizes)
     */
    public NeuralNetworkRegression(ParameterNeuralNetwork parameterNeuralNetwork) {
        setParameterRegressionMethod(parameterNeuralNetwork);
    }

    /**
     * Builds and trains a new network on the given data and returns its fitted values for X.
     *
     * <p>The layer sizes are {@code [X.cols(), inner layers..., 1]}. Training is done with a single
     * call to Smile's {@code NeuralNetwork.learn(double[][], double[])}. The trained network only
     * replaces the current model after {@code learn} returns, so a failed training keeps the
     * previous model (if any).
     *
     * @param modelXYIndex training data; {@code Y} must have exactly one column
     * @return the predictions of the newly trained network for every row of {@code modelXYIndex.X}
     * @throws RuntimeException if {@code Y} does not have exactly one column, or if reading the
     *     parameters fails
     */
    @Override
    public Matrix createModel(ModelXYIndex modelXYIndex) {
        // Let parameter failures propagate. Previously they were printed and swallowed, which
        // only deferred the failure to a later NullPointerException.
        ParameterNeuralNetwork parameter = getParameter();

        if (modelXYIndex.Y.cols() != 1) {
            throw new RuntimeException("Only one column for y is allowed!");
        }

        int[] innerLayers = parameter.getArrInnerLayerArchitecture();
        int[] layerSizes = new int[innerLayers.length + 2];
        layerSizes[0] = modelXYIndex.X.cols();                              // input layer
        System.arraycopy(innerLayers, 0, layerSizes, 1, innerLayers.length); // hidden layers
        layerSizes[layerSizes.length - 1] = 1;                               // single output unit

        NeuralNetwork network = new NeuralNetwork(parameter.getActivationFunction(), layerSizes);
        network.learn(modelXYIndex.X.getArray(), modelXYIndex.Y.getColAsDouble(0));

        // Publish under the same lock used for prediction.
        synchronized (this) {
            this.neuralNetwork = network;
        }
        return calculateYHat(modelXYIndex.X);
    }

    /**
     * Predicts one value per row of the given matrix.
     *
     * @param matrix input rows, one sample per row, with the same column count as the training X
     * @return a matrix built as {@code new Matrix(false, yHat)}, holding the predictions in row order
     * @throws IllegalStateException if {@link #createModel(ModelXYIndex)} has not been called yet
     */
    @Override
    public Matrix calculateYHat(Matrix matrix) {
        int rows = matrix.rows();
        double[] yHat = new double[rows];
        for (int i = 0; i < rows; i++) {
            yHat[i] = predictRow(matrix.getRow(i));
        }
        return new Matrix(false, yHat);
    }

    /**
     * Predicts the value for a single sample.
     *
     * @param dArr one input sample, with the same length as a training X row
     * @return the network's prediction
     * @throws IllegalStateException if {@link #createModel(ModelXYIndex)} has not been called yet
     */
    @Override
    public double calculateYHat(double[] dArr) {
        return predictRow(dArr);
    }

    /**
     * Runs one prediction through the trained network while holding the lock on {@code this}.
     *
     * <p>Both {@code calculateYHat} overloads use this helper, so they share one locking policy.
     * Previously only the single-row overload synchronized.
     */
    private synchronized double predictRow(double[] row) {
        if (neuralNetwork == null) {
            throw new IllegalStateException("No model available: call createModel() before calculateYHat().");
        }
        return neuralNetwork.predict(row);
    }

    /**
     * Orders regressions by their parameter objects, delegating to the parameter's
     * {@code compareTo}.
     *
     * <p>NOTE(review): consistency of this ordering with {@code equals} is not visible here.
     *
     * @param neuralNetworkRegression the regression to compare against
     * @return the result of comparing this regression's parameters with the other's
     */
    @Override
    public int compareTo(NeuralNetworkRegression neuralNetworkRegression) {
        return getParameter().compareTo((ParameterRegressionMethod) neuralNetworkRegression.getParameter());
    }
}
