/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import smile.math.Math;
import smile.regression.OnlineRegression;
import smile.regression.RegressionTrainer;

/**
 * Fully connected feed-forward neural network for regression with a single
 * linear output unit, trained by stochastic back-propagation with optional
 * momentum and weight decay.
 *
 * <p>Every layer keeps one extra "bias" slot at the end of its output vector
 * that is fixed to 1.0, so the bias weights are stored as the last column of
 * each weight matrix.
 *
 * <p>{@link #predict(double[])} and the {@code learn} methods write into the
 * per-layer output/error buffers of this instance.
 */
public class NeuralNetwork implements OnlineRegression<double[]> {
    private static final long serialVersionUID = 1L;

    /** Activation function applied to the hidden units (the output unit is linear). */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /** Number of input variables (units of the input layer). */
    private int p;
    /** All layers, from the input layer (index 0) to the output layer. */
    private Layer[] net;
    /** Shortcut to {@code net[0]}. */
    private Layer inputLayer;
    /** Shortcut to {@code net[net.length - 1]}. */
    private Layer outputLayer;
    /** Learning rate. */
    private double eta = 0.1;
    /** Momentum factor. */
    private double alpha = 0.0;
    /** Weight decay factor. */
    private double lambda = 0.0;

    /**
     * Creates a network with logistic sigmoid hidden units.
     *
     * @param numUnits the number of units in each layer, input layer first;
     *                 the last entry (output layer) must be 1.
     * @throws IllegalArgumentException if the layer sizes are invalid.
     */
    public NeuralNetwork(int ... numUnits) {
        this(ActivationFunction.LOGISTIC_SIGMOID, numUnits);
    }

    /**
     * Creates a network with the given hidden-unit activation function.
     *
     * @param activation the activation function of the hidden units.
     * @param numUnits   the number of units in each layer, input layer first;
     *                   the last entry (output layer) must be 1.
     * @throws IllegalArgumentException if the activation function is null or
     *                                  the layer sizes are invalid.
     */
    public NeuralNetwork(ActivationFunction activation, int ... numUnits) {
        // NOTE(review): these defaults pass 1.0E-4 as the momentum factor and
        // 0.9 as the weight decay factor. 0.9 is outside the range accepted by
        // setWeightDecay (0 to 0.1), so the two values look swapped. Left
        // unchanged because altering them changes training results; confirm
        // the intended defaults.
        this(activation, 1.0E-4, 0.9, numUnits);
    }

    /**
     * Creates a network and initializes its weights uniformly at random in
     * {@code [-r, r]} where {@code r = 1 / sqrt(units of the previous layer)}.
     *
     * @param activation the activation function of the hidden units.
     * @param alpha      the momentum factor (stored as given, not range-checked).
     * @param lambda     the weight decay factor (stored as given, not range-checked).
     * @param numUnits   the number of units in each layer, input layer first;
     *                   the last entry (output layer) must be 1.
     * @throws IllegalArgumentException if the activation function is null,
     *                                  fewer than two layers are given, a layer
     *                                  has fewer than one unit, or the output
     *                                  layer does not have exactly one unit.
     */
    public NeuralNetwork(ActivationFunction activation, double alpha, double lambda, int ... numUnits) {
        if (activation == null) {
            // A null activation would otherwise leave hidden units silently un-updated.
            throw new IllegalArgumentException("Invalid activation function: null");
        }
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }
        for (int l = 0; l < numLayers; ++l) {
            if (numUnits[l] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", l + 1, numUnits[l]));
            }
        }
        if (numUnits[numLayers - 1] != 1) {
            throw new IllegalArgumentException(String.format("Invalid number of units in output layer %d", numUnits[numLayers - 1]));
        }

        this.activationFunction = activation;
        this.alpha = alpha;
        this.lambda = lambda;
        this.p = numUnits[0];

        this.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; ++l) {
            Layer layer = new Layer();
            layer.units = numUnits[l];
            // One extra slot at the end holds the constant bias input.
            layer.output = new double[numUnits[l] + 1];
            layer.error = new double[numUnits[l] + 1];
            layer.output[numUnits[l]] = 1.0;
            this.net[l] = layer;
        }
        this.inputLayer = this.net[0];
        this.outputLayer = this.net[numLayers - 1];

        for (int l = 1; l < numLayers; ++l) {
            Layer layer = this.net[l];
            Layer previous = this.net[l - 1];
            layer.weight = new double[numUnits[l]][numUnits[l - 1] + 1];
            layer.delta = new double[numUnits[l]][numUnits[l - 1] + 1];
            double r = 1.0 / Math.sqrt(previous.units);
            for (int i = 0; i < layer.units; ++i) {
                for (int j = 0; j <= previous.units; ++j) {
                    layer.weight[i][j] = Math.random(-r, r);
                }
            }
        }
    }

    /** Creates an empty shell that {@link #clone()} fills in. */
    private NeuralNetwork() {
    }

    /**
     * Returns a deep copy of this network: hyper-parameters, layer buffers,
     * weights and momentum terms are all copied.
     *
     * @return an independent copy of this network.
     */
    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.activationFunction = this.activationFunction;
        copy.p = this.p;
        copy.eta = this.eta;
        copy.alpha = this.alpha;
        copy.lambda = this.lambda;

        int numLayers = this.net.length;
        copy.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; ++l) {
            Layer source = this.net[l];
            Layer target = new Layer();
            target.units = source.units;
            target.output = source.output.clone();
            target.error = source.error.clone();
            // The input layer has no incoming weights.
            if (l > 0) {
                target.weight = Math.clone(source.weight);
                target.delta = Math.clone(source.delta);
            }
            copy.net[l] = target;
        }
        copy.inputLayer = copy.net[0];
        copy.outputLayer = copy.net[numLayers - 1];
        return copy;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be positive.
     * @throws IllegalArgumentException if {@code eta <= 0}.
     */
    public void setLearningRate(double eta) {
        if (eta <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        this.eta = eta;
    }

    /**
     * Returns the learning rate.
     *
     * @return the learning rate.
     */
    public double getLearningRate() {
        return this.eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @param alpha the momentum factor; must be in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code alpha} is out of range.
     */
    public void setMomentum(double alpha) {
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Returns the momentum factor.
     *
     * @return the momentum factor.
     */
    public double getMomentum() {
        return this.alpha;
    }

    /**
     * Sets the weight decay factor.
     *
     * @param lambda the weight decay factor; must be in {@code [0, 0.1]}.
     * @throws IllegalArgumentException if {@code lambda} is out of range.
     */
    public void setWeightDecay(double lambda) {
        if (lambda < 0.0 || lambda > 0.1) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the weight decay factor.
     *
     * @return the weight decay factor.
     */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Returns the weight matrix of a layer. The returned array is the
     * network's internal storage, not a copy.
     *
     * @param layer the layer index; layer 0 (the input layer) has no weights
     *              and yields {@code null}.
     * @return the weights of the layer; row {@code i} holds the incoming
     *         weights of unit {@code i}, the last column being the bias.
     */
    public double[][] getWeight(int layer) {
        return this.net[layer].weight;
    }

    /**
     * Copies an input vector into the input layer.
     *
     * @param x the input vector.
     * @throws IllegalArgumentException if the vector length differs from the
     *                                  number of input units.
     */
    private void setInput(double[] x) {
        if (x.length != this.inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.inputLayer.units));
        }
        System.arraycopy(x, 0, this.inputLayer.output, 0, this.inputLayer.units);
    }

    /**
     * Applies the hidden-unit activation function.
     *
     * @param z the weighted input sum of a unit.
     * @return the activation value.
     */
    private double activate(double z) {
        switch (this.activationFunction) {
            case LOGISTIC_SIGMOID:
                return Math.logistic(z);
            case TANH:
                // tanh(z) expressed through the logistic function.
                return 2.0 * Math.logistic(2.0 * z) - 1.0;
            default:
                throw new IllegalStateException("Unsupported activation function: " + this.activationFunction);
        }
    }

    /**
     * Derivative of the hidden-unit activation function, expressed in terms
     * of the unit's output value.
     *
     * @param out the output of the unit.
     * @return the derivative of the activation at that output.
     */
    private double activationDerivative(double out) {
        switch (this.activationFunction) {
            case LOGISTIC_SIGMOID:
                return out * (1.0 - out);
            case TANH:
                return 1.0 - out * out;
            default:
                throw new IllegalStateException("Unsupported activation function: " + this.activationFunction);
        }
    }

    /**
     * Propagates signals from a lower layer to the layer above it.
     *
     * @param lower the layer providing the inputs.
     * @param upper the layer whose outputs are computed.
     */
    private void propagate(Layer lower, Layer upper) {
        boolean linear = upper == this.outputLayer;
        for (int i = 0; i < upper.units; ++i) {
            double[] w = upper.weight[i];
            double sum = 0.0;
            // j == lower.units is the bias slot (output fixed to 1.0).
            for (int j = 0; j <= lower.units; ++j) {
                sum += w[j] * lower.output[j];
            }
            // The output layer is linear; hidden layers use the activation function.
            upper.output[i] = linear ? sum : activate(sum);
        }
    }

    /** Propagates the current input through all layers. */
    private void propagate() {
        for (int l = 0; l < this.net.length - 1; ++l) {
            propagate(this.net[l], this.net[l + 1]);
        }
    }

    /**
     * Stores the residual of the output unit as its error term.
     *
     * @param y the desired output.
     * @return half of the squared residual.
     */
    private double computeOutputError(double y) {
        double residual = y - this.outputLayer.output[0];
        this.outputLayer.error[0] = residual;
        return 0.5 * residual * residual;
    }

    /**
     * Propagates error terms from an upper layer back to the layer below it.
     *
     * @param upper the layer whose error terms are already known.
     * @param lower the layer whose error terms are computed.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int i = 0; i <= lower.units; ++i) {
            double sum = 0.0;
            for (int j = 0; j < upper.units; ++j) {
                sum += upper.weight[j][i] * upper.error[j];
            }
            lower.error[i] = activationDerivative(lower.output[i]) * sum;
        }
    }

    /**
     * Back-propagates the output error through the hidden layers.
     *
     * <p>The input layer is skipped: its error terms are never read by
     * {@link #adjustWeights()}, which only uses the errors of layers that
     * have incoming weights.
     */
    private void backpropagate() {
        for (int l = this.net.length - 1; l > 1; --l) {
            backpropagate(this.net[l], this.net[l - 1]);
        }
    }

    /**
     * Updates all weights from the current outputs and error terms, applying
     * momentum and (to non-bias weights only) weight decay.
     */
    private void adjustWeights() {
        for (int l = 1; l < this.net.length; ++l) {
            Layer layer = this.net[l];
            Layer previous = this.net[l - 1];
            for (int i = 0; i < layer.units; ++i) {
                double err = layer.error[i];
                double[] w = layer.weight[i];
                double[] momentum = layer.delta[i];
                for (int j = 0; j <= previous.units; ++j) {
                    double out = previous.output[j];
                    double step = (1.0 - this.alpha) * this.eta * err * out + this.alpha * momentum[j];
                    momentum[j] = step;
                    w[j] += step;
                    // Weight decay is not applied to the bias weight (j == previous.units).
                    if (this.lambda != 0.0 && j < previous.units) {
                        w[j] *= 1.0 - this.eta * this.lambda;
                    }
                }
            }
        }
    }

    /**
     * Predicts the response for an input vector.
     *
     * @param x the input vector.
     * @return the value of the output unit.
     * @throws IllegalArgumentException if the vector length differs from the
     *                                  number of input units.
     */
    @Override
    public double predict(double[] x) {
        setInput(x);
        propagate();
        return this.outputLayer.output[0];
    }

    /**
     * Updates the network with one weighted training instance.
     *
     * @param x      the input vector.
     * @param y      the desired response.
     * @param weight a multiplier applied to the output error of this instance.
     * @return the weighted training error (half squared residual) measured
     *         before the weights are adjusted.
     * @throws IllegalArgumentException if the vector length differs from the
     *                                  number of input units.
     */
    public double learn(double[] x, double y, double weight) {
        setInput(x);
        propagate();
        double error = weight * computeOutputError(y);
        if (weight != 1.0) {
            this.outputLayer.error[0] *= weight;
        }
        backpropagate();
        adjustWeights();
        return error;
    }

    /**
     * Updates the network with one training instance of unit weight.
     *
     * @param x the input vector.
     * @param y the desired response.
     */
    @Override
    public void learn(double[] x, double y) {
        learn(x, y, 1.0);
    }

    /**
     * Runs one pass over the given samples, visiting them in the order
     * returned by {@code Math.permutate}.
     *
     * @param x the input vectors.
     * @param y the desired responses, one per input vector.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void learn(double[][] x, double[] y) {
        int n = x.length;
        if (n != y.length) {
            // Fail before touching any weight instead of part-way through the pass.
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", n, y.length));
        }
        int[] order = Math.permutate(n);
        for (int i = 0; i < n; ++i) {
            learn(x[order[i]], y[order[i]]);
        }
    }

    /**
     * Trainer that builds a {@link NeuralNetwork} and runs a fixed number of
     * learning passes over the training data.
     */
    public static class Trainer extends RegressionTrainer<double[]> {
        /** Activation function of the hidden units. */
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        /** Number of units in each layer, input layer first. */
        private int[] numUnits;
        /** Learning rate. */
        private double eta = 0.1;
        /** Momentum factor. */
        private double alpha = 0.0;
        /** Weight decay factor. */
        private double lambda = 0.0;
        /** Number of passes over the training data. */
        private int epochs = 25;

        /**
         * Creates a trainer for networks with logistic sigmoid hidden units.
         *
         * @param numUnits the number of units in each layer, input layer
         *                 first; the last entry (output layer) must be 1.
         * @throws IllegalArgumentException if the layer sizes are invalid.
         */
        public Trainer(int ... numUnits) {
            this(ActivationFunction.LOGISTIC_SIGMOID, numUnits);
        }

        /**
         * Creates a trainer for networks with the given activation function.
         *
         * @param activation the activation function of the hidden units.
         * @param numUnits   the number of units in each layer, input layer
         *                   first; the last entry (output layer) must be 1.
         * @throws IllegalArgumentException if the activation function is null
         *                                  or the layer sizes are invalid.
         */
        public Trainer(ActivationFunction activation, int ... numUnits) {
            if (activation == null) {
                throw new IllegalArgumentException("Invalid activation function: null");
            }
            int numLayers = numUnits.length;
            if (numLayers < 2) {
                throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
            }
            for (int l = 0; l < numLayers; ++l) {
                if (numUnits[l] < 1) {
                    throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", l + 1, numUnits[l]));
                }
            }
            if (numUnits[numLayers - 1] != 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units in output layer %d", numUnits[numLayers - 1]));
            }
            this.activationFunction = activation;
            // Defensive copy: later changes to the caller's array must not
            // alter the validated topology.
            this.numUnits = numUnits.clone();
        }

        /**
         * Sets the learning rate.
         *
         * @param eta the learning rate; must be positive.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code eta <= 0}.
         */
        public Trainer setLearningRate(double eta) {
            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            this.eta = eta;
            return this;
        }

        /**
         * Sets the momentum factor.
         *
         * @param alpha the momentum factor; must be in {@code [0, 1)}.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code alpha} is out of range.
         */
        public Trainer setMomentum(double alpha) {
            if (alpha < 0.0 || alpha >= 1.0) {
                throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
            }
            this.alpha = alpha;
            return this;
        }

        /**
         * Sets the weight decay factor.
         *
         * @param lambda the weight decay factor; must be in {@code [0, 0.1]}.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code lambda} is out of range.
         */
        public Trainer setWeightDecay(double lambda) {
            if (lambda < 0.0 || lambda > 0.1) {
                throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
            }
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the number of passes over the training data.
         *
         * @param epochs the number of epochs; must be at least 1.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code epochs < 1}.
         */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid numer of epochs of stochastic learning:" + epochs);
            }
            this.epochs = epochs;
            return this;
        }

        /**
         * Builds a new network with this trainer's settings and runs the
         * configured number of learning passes, printing a progress line to
         * standard output after each pass.
         *
         * @param x the input vectors.
         * @param y the desired responses, one per input vector.
         * @return the trained network.
         */
        public NeuralNetwork train(double[][] x, double[] y) {
            NeuralNetwork network = new NeuralNetwork(this.activationFunction, this.numUnits);
            // The setters replace the constructor's default momentum and weight decay.
            network.setLearningRate(this.eta);
            network.setMomentum(this.alpha);
            network.setWeightDecay(this.lambda);
            for (int epoch = 1; epoch <= this.epochs; ++epoch) {
                network.learn(x, y);
                System.out.println("Neural network learns epoch " + epoch);
            }
            return network;
        }
    }

    /**
     * State of a single layer. Declared static so that a layer does not carry
     * a hidden reference to the enclosing network.
     */
    private static class Layer implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Number of units in this layer, not counting the bias slot. */
        int units;
        /** Unit outputs; the last element is the bias slot fixed to 1.0. */
        double[] output;
        /** Error terms of the units; same length as {@code output}. */
        double[] error;
        /** Incoming weights, one row per unit; null for the input layer. */
        double[][] weight;
        /** Last weight changes, used for momentum; null for the input layer. */
        double[][] delta;

        private Layer() {
        }
    }

    /** Activation functions available for the hidden units. */
    public enum ActivationFunction {
        /** Logistic sigmoid. */
        LOGISTIC_SIGMOID,
        /** Hyperbolic tangent. */
        TANH
    }
}

