/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import java.util.List;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.nnet.learning.BackPropagation;

/**
 * Resilient propagation (RPROP) learning rule built on top of {@link BackPropagation}.
 *
 * <p>Gradients are accumulated per weight over a whole batch; at the end of the batch
 * every weight is moved by a per-weight step size ("delta") whose magnitude depends only
 * on whether the sign of the accumulated gradient changed since the previous batch,
 * not on the magnitude of the gradient. This rule only runs in batch mode.
 */
public class ResilientPropagation extends BackPropagation {
    /** Factor applied to a weight's step size when its gradient changed sign. */
    private double decreaseFactor = 0.5;
    /** Factor applied to a weight's step size when its gradient kept its sign. */
    private double increaseFactor = 1.2;
    /** Step size assigned to every weight's training data when it is created. */
    private double initialDelta = 0.1;
    /** Upper bound for the per-weight step size. */
    private double maxDelta = 1.0;
    /** Lower bound for the per-weight step size. */
    private double minDelta = 1.0E-6;
    /** Values whose magnitude is below this threshold are treated as zero by {@link #sign(double)}. */
    private static final double ZERO_TOLERANCE = 1.0E-27;

    /**
     * Creates the learning rule and switches the underlying implementation to batch mode.
     */
    public ResilientPropagation() {
        // Call the superclass directly: the override in this class only validates its argument.
        super.setBatchMode(true);
    }

    /**
     * Returns the sign of a value, treating magnitudes below {@link #ZERO_TOLERANCE} as zero.
     *
     * @param value value to inspect
     * @return {@code 0} if {@code |value| < ZERO_TOLERANCE}, {@code 1} if positive, {@code -1} otherwise
     */
    private int sign(double value) {
        if (Math.abs(value) < ZERO_TOLERANCE) {
            return 0;
        }
        if (value > 0.0) {
            return 1;
        }
        return -1;
    }

    /**
     * Runs the superclass start hook and then attaches a fresh
     * {@link ResilientWeightTrainingtData} to the weight of every input connection
     * of every neuron in the network.
     */
    @Override
    protected void onStart() {
        super.onStart();
        for (Layer layer : this.neuralNetwork.getLayers()) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    connection.getWeight().setTrainingData(new ResilientWeightTrainingtData());
                }
            }
        }
    }

    /**
     * Accumulates the gradient contribution of the current pattern for every input
     * connection of the given neuron. Connections whose input is exactly zero are skipped
     * because their contribution would be zero.
     *
     * @param neuron neuron whose input weights receive the gradient contribution
     */
    @Override
    public void calculateWeightChanges(Neuron neuron) {
        for (Connection connection : neuron.getInputConnections()) {
            double input = connection.getInput();
            if (input == 0.0) {
                continue;
            }
            double neuronError = neuron.getDelta();
            Weight weight = connection.getWeight();
            ResilientWeightTrainingtData weightData =
                    (ResilientWeightTrainingtData) weight.getTrainingData();
            // Summed across patterns; reset to zero after each batch update.
            weightData.gradient += -neuronError * input;
        }
    }

    /**
     * Applies the resilient update to every input weight, walking the layers from the
     * last one down to index 1.
     */
    @Override
    protected void doBatchWeightsUpdate() {
        List<Layer> layers = this.neuralNetwork.getLayers();
        // Layer 0 is not visited (presumably the input layer, which has no trainable inputs).
        for (int i = this.neuralNetwork.getLayersCount() - 1; i > 0; --i) {
            for (Neuron neuron : layers.get(i).getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    this.resillientWeightUpdate(connection.getWeight());
                }
            }
        }
    }

    /**
     * Performs one resilient update step on a single weight.
     *
     * <ul>
     *   <li>Gradient kept its sign: the step size is multiplied by the increase factor
     *       (capped at the maximum) and the weight moves by that step in the direction of
     *       the accumulated gradient's sign.</li>
     *   <li>Gradient changed sign: the step size is multiplied by the decrease factor
     *       (floored at the minimum), the previous weight change is reverted, and the stored
     *       gradient is cleared so the next batch takes the "no change" branch.</li>
     *   <li>Product of gradients is (near) zero: the step size is kept and the weight moves
     *       by it in the direction of the accumulated gradient's sign.</li>
     * </ul>
     *
     * The computed change is added to the weight's value, after which the accumulated
     * gradient is moved to {@code previousGradient} and reset for the next batch.
     *
     * @param weight weight to update; its training data must be a {@link ResilientWeightTrainingtData}
     */
    protected void resillientWeightUpdate(Weight weight) {
        ResilientWeightTrainingtData weightData =
                (ResilientWeightTrainingtData) weight.getTrainingData();
        int gradientSignChange = this.sign(weightData.previousGradient * weightData.gradient);
        double weightChange;
        if (gradientSignChange > 0) {
            double delta = Math.min(weightData.previousDelta * this.increaseFactor, this.maxDelta);
            weightChange = (double) this.sign(weightData.gradient) * delta;
            weightData.previousDelta = delta;
        } else if (gradientSignChange < 0) {
            double delta = Math.max(weightData.previousDelta * this.decreaseFactor, this.minDelta);
            // The last step overshot: undo it.
            weightChange = -weightData.previousWeightChange;
            // Clearing the gradient makes previousGradient zero below, so the next batch
            // does not shrink the step size a second time for the same overshoot.
            weightData.gradient = 0.0;
            weightData.previousDelta = delta;
        } else {
            double delta = weightData.previousDelta;
            weightChange = (double) this.sign(weightData.gradient) * delta;
        }
        // Actually move the weight; without this the computed step is only recorded, never applied.
        weight.value += weightChange;
        weightData.previousWeightChange = weightChange;
        weightData.previousGradient = weightData.gradient;
        weightData.gradient = 0.0;
    }

    /**
     * @return factor applied to the step size when the gradient changes sign
     */
    public double getDecreaseFactor() {
        return this.decreaseFactor;
    }

    /**
     * @param decreaseFactor factor applied to the step size when the gradient changes sign
     */
    public void setDecreaseFactor(double decreaseFactor) {
        this.decreaseFactor = decreaseFactor;
    }

    /**
     * @return factor applied to the step size when the gradient keeps its sign
     */
    public double getIncreaseFactor() {
        return this.increaseFactor;
    }

    /**
     * @param increaseFactor factor applied to the step size when the gradient keeps its sign
     */
    public void setIncreaseFactor(double increaseFactor) {
        this.increaseFactor = increaseFactor;
    }

    /**
     * @return step size given to newly created weight training data
     */
    public double getInitialDelta() {
        return this.initialDelta;
    }

    /**
     * @param initialDelta step size given to weight training data created after this call
     */
    public void setInitialDelta(double initialDelta) {
        this.initialDelta = initialDelta;
    }

    /**
     * @return upper bound for the per-weight step size
     */
    public double getMaxDelta() {
        return this.maxDelta;
    }

    /**
     * @param maxDelta upper bound for the per-weight step size
     */
    public void setMaxDelta(double maxDelta) {
        this.maxDelta = maxDelta;
    }

    /**
     * @return lower bound for the per-weight step size
     */
    public double getMinDelta() {
        return this.minDelta;
    }

    /**
     * @param minDelta lower bound for the per-weight step size
     */
    public void setMinDelta(double minDelta) {
        this.minDelta = minDelta;
    }

    /**
     * Rejects any attempt to leave batch mode; passing {@code true} is accepted and does nothing.
     *
     * @param batchMode must be {@code true}
     * @throws IllegalStateException if {@code batchMode} is {@code false}
     */
    @Override
    public void setBatchMode(boolean batchMode) {
        if (!batchMode) {
            throw new IllegalStateException("Resilient propagation runs only in batch mode!");
        }
    }

    /**
     * Per-weight state kept between batches by the resilient update.
     */
    public class ResilientWeightTrainingtData {
        /** Gradient accumulated over the current batch. */
        public double gradient;
        /** Accumulated gradient as it stood at the end of the previous batch update. */
        public double previousGradient;
        /** Change applied to the weight by the previous batch update. */
        public double previousWeightChange;
        /** Step size used by the previous batch update. */
        public double previousDelta;

        /**
         * Creates the state with the step size set to the enclosing rule's current initial delta.
         */
        public ResilientWeightTrainingtData() {
            this.previousDelta = ResilientPropagation.this.initialDelta;
        }
    }
}

