/*
 * Decompiled with CFR 0.152.
 */
package com.actelion.research.calc.regression.neuralnetwork;

import com.actelion.research.calc.Matrix;
import com.actelion.research.calc.regression.ARegressionMethod;
import com.actelion.research.calc.regression.ParameterRegressionMethod;
import com.actelion.research.calc.regression.neuralnetwork.ParameterNeuralNetwork;
import com.actelion.research.util.datamodel.ModelXYIndex;
import smile.regression.NeuralNetwork;

/**
 * Regression method backed by a Smile {@link NeuralNetwork} with a single output neuron.
 *
 * <p>The hidden-layer architecture and the activation function come from the
 * {@link ParameterNeuralNetwork} of this method. The input layer size is the number of columns of
 * X at training time, and the output layer always has exactly one neuron.
 *
 * <p>All reads and writes of the trained network happen while holding this instance's monitor.
 * NOTE(review): this is presumably because Smile's {@code NeuralNetwork.predict} is not safe for
 * concurrent use on one network instance. Confirm against the Smile version in use.
 */
public class NeuralNetworkRegression
extends ARegressionMethod<ParameterNeuralNetwork>
implements Comparable<NeuralNetworkRegression> {
    /**
     * The trained network.
     *
     * <p>It is {@code null} until {@link #createModel(ModelXYIndex)} succeeds. It is guarded by
     * {@code this}.
     */
    private NeuralNetwork neuralNetwork;

    /** Creates a regression method with default {@link ParameterNeuralNetwork} settings. */
    public NeuralNetworkRegression() {
        this(new ParameterNeuralNetwork());
    }

    /**
     * Creates a regression method with the given network parameters.
     *
     * @param parameterNeuralNetwork hidden-layer architecture and activation function to use
     */
    public NeuralNetworkRegression(ParameterNeuralNetwork parameterNeuralNetwork) {
        this.setParameterRegressionMethod(parameterNeuralNetwork);
        // Requests single-threaded Smile execution. Both constructors now set this, so the setting
        // no longer depends on which constructor was used.
        // NOTE(review): this is a JVM-wide system property. Smile presumably reads it once, when its
        // thread pool is first initialised, so it only takes effect before that point. Confirm.
        System.setProperty("smile.threads", "1");
    }

    /**
     * Trains a new network on the given data and returns its predictions for the training X.
     *
     * <p>This method is best-effort. Any failure is caught, its stack trace is printed, and
     * {@code null} is returned. Failures include a Y with more than one column and an exception
     * thrown during training. On failure the previously trained network, if any, is kept. A
     * partially trained network is never installed.
     *
     * @param modelXYIndex training data; {@code Y} must have exactly one column
     * @return the fitted values for {@code modelXYIndex.X} as a single-column matrix, or
     *     {@code null} if training failed
     */
    @Override
    public Matrix createModel(ModelXYIndex modelXYIndex) {
        Matrix yHat = null;
        try {
            ParameterNeuralNetwork parameter = (ParameterNeuralNetwork)this.getParameter();
            if (modelXYIndex.Y.cols() != 1) {
                throw new RuntimeException("Only one column for y is allowed!");
            }
            double[][] x = modelXYIndex.X.getArray();
            double[] y = modelXYIndex.Y.getColAsDouble(0);

            // Layer sizes: input = number of X columns, then the configured hidden layers,
            // then a single output neuron.
            int[] innerLayers = parameter.getArrInnerLayerArchitecture();
            int[] layers = new int[innerLayers.length + 2];
            layers[0] = modelXYIndex.X.cols();
            System.arraycopy(innerLayers, 0, layers, 1, innerLayers.length);
            layers[layers.length - 1] = 1;

            // Train a local instance first. The field is only replaced once learning has
            // succeeded, so a failed run cannot leave an untrained network behind.
            NeuralNetwork network = new NeuralNetwork(parameter.getActivationFunction(), layers);
            network.learn(x, y);
            synchronized (this) {
                this.neuralNetwork = network;
            }
            yHat = this.calculateYHat(modelXYIndex.X);
        }
        catch (Exception exception) {
            exception.printStackTrace();
        }
        return yHat;
    }

    /**
     * Predicts one value per row of {@code matrix} using the trained network.
     *
     * <p>This method is synchronized for the same reason as the single-row overload. Both share
     * the one network instance.
     *
     * @param matrix input rows; each row must have as many columns as the training X
     * @return the predictions as a matrix built with {@code new Matrix(false, values)}
     * @throws IllegalStateException if no model has been trained yet
     */
    @Override
    public synchronized Matrix calculateYHat(Matrix matrix) {
        NeuralNetwork network = this.requireModel();
        int rows = matrix.rows();
        double[] yHat = new double[rows];
        for (int i = 0; i < rows; ++i) {
            yHat[i] = network.predict(matrix.getRow(i));
        }
        return new Matrix(false, yHat);
    }

    /**
     * Predicts a single value for one input vector using the trained network.
     *
     * @param dArray input features; the length must match the number of columns of the training X
     * @return the predicted value
     * @throws IllegalStateException if no model has been trained yet
     */
    @Override
    public synchronized double calculateYHat(double[] dArray) {
        return this.requireModel().predict(dArray);
    }

    /**
     * Orders regression methods by their {@link ParameterNeuralNetwork} settings.
     *
     * @param neuralNetworkRegression the method to compare against
     * @return the result of comparing this method's parameters with the other's
     */
    @Override
    public int compareTo(NeuralNetworkRegression neuralNetworkRegression) {
        return ((ParameterNeuralNetwork)this.getParameter()).compareTo((ParameterRegressionMethod)neuralNetworkRegression.getParameter());
    }

    /**
     * Returns the trained network, or fails with a clear message if none exists yet.
     *
     * <p>Callers must hold the monitor of {@code this}.
     */
    private NeuralNetwork requireModel() {
        if (this.neuralNetwork == null) {
            throw new IllegalStateException("No trained neural network available; call createModel first.");
        }
        return this.neuralNetwork;
    }
}

