/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import java.util.List;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.nnet.learning.BackPropagation;

/**
 * Backpropagation variant that adds a momentum term to every weight change:
 * {@code dw(t) = -learningRate * delta * input + momentum * dw(t-1)}.
 *
 * <p>The previous weight change of each weight is kept in a {@link MomentumTrainingData}
 * object attached to that weight as its training data (see {@link #onStart()}).
 */
public class MomentumBackpropagation extends BackPropagation {
    private static final long serialVersionUID = 1L;

    /** Hidden layers with at least this many neurons are processed with a parallel stream. */
    private static final int PARALLEL_LAYER_THRESHOLD = 100;

    /** Momentum coefficient applied to the previous weight change; defaults to 0.25. */
    protected double momentum = 0.25;

    /**
     * Computes the delta of every hidden neuron and the resulting weight changes, walking the
     * layers backwards from index {@code size() - 2} down to index 1.
     *
     * <p>The last layer (presumably the output layer, handled elsewhere) and layer 0 (the input
     * layer) are skipped. Layers are visited strictly in order; only neurons within one layer
     * may be processed concurrently.
     */
    @Override
    protected void calculateErrorAndUpdateHiddenNeurons() {
        List<Layer> layers = this.neuralNetwork.getLayers();
        for (int layerIdx = layers.size() - 2; layerIdx > 0; --layerIdx) {
            List<Neuron> layerNeurons = layers.get(layerIdx).getNeurons();
            if (layerNeurons.size() >= PARALLEL_LAYER_THRESHOLD) {
                // forEach is terminal and blocks until every neuron of this layer is done, so the
                // next (earlier) layer sees all deltas computed here.
                // NOTE(review): running calculateWeightChanges concurrently is only safe if no
                // Weight instance is shared between connections of different neurons — confirm.
                layerNeurons.parallelStream().forEach(this::computeHiddenNeuronDeltaAndWeightChanges);
            } else {
                for (Neuron neuron : layerNeurons) {
                    computeHiddenNeuronDeltaAndWeightChanges(neuron);
                }
            }
        }
    }

    /** Sets the neuron's delta from its hidden-neuron error, then computes its weight changes. */
    private void computeHiddenNeuronDeltaAndWeightChanges(Neuron neuron) {
        neuron.setDelta(this.calculateHiddenNeuronError(neuron));
        this.calculateWeightChanges(neuron);
    }

    /**
     * Computes the change of every input-connection weight of {@code neuron}, including the
     * momentum term.
     *
     * <p>In online mode the change replaces {@code weight.weightChange}; in batch mode it is
     * added to it. Connections whose input is exactly 0 are skipped entirely; note that this
     * also skips the momentum term for that weight on this call.
     *
     * @param neuron the neuron whose delta has already been set
     */
    @Override
    public void calculateWeightChanges(Neuron neuron) {
        // The delta is the same for every input connection of this neuron.
        final double neuronDelta = neuron.getDelta();
        for (Connection connection : neuron.getInputConnections()) {
            final double input = connection.getInput();
            if (input == 0.0) {
                continue;
            }
            Weight weight = connection.getWeight();
            MomentumTrainingData trainingData = momentumDataOf(weight);
            double weightChange = -this.learningRate * neuronDelta * input
                    + this.momentum * trainingData.previousWeightChange;
            // Remember the change just computed, so that the next call uses dw(t-1).
            // Storing weight.weightChange before it is overwritten would instead give a value
            // one step stale (or whatever the superclass left in that field).
            // NOTE(review): in batch mode this keeps the per-sample change, so momentum acts
            // between consecutive samples rather than between epochs — confirm intended.
            trainingData.previousWeightChange = weightChange;
            if (this.isBatchMode()) {
                weight.weightChange += weightChange;
            } else {
                weight.weightChange = weightChange;
            }
        }
    }

    /**
     * Returns the weight's {@link MomentumTrainingData}. If the weight carries none (for
     * example, onStart() was not run), a fresh one is attached and returned.
     */
    private static MomentumTrainingData momentumDataOf(Weight weight) {
        Object data = weight.getTrainingData();
        if (data instanceof MomentumTrainingData) {
            return (MomentumTrainingData) data;
        }
        MomentumTrainingData fresh = new MomentumTrainingData();
        weight.setTrainingData(fresh);
        return fresh;
    }

    /**
     * Returns the momentum coefficient.
     *
     * @return the current momentum
     */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Sets the momentum coefficient. No range check is performed.
     *
     * @param momentum the new momentum
     */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /**
     * Runs the superclass start hook. It then attaches a new {@link MomentumTrainingData}, whose
     * previous weight change starts at 0, to the weight of every input connection of every
     * neuron in every layer.
     */
    @Override
    protected void onStart() {
        super.onStart();
        for (Layer layer : this.neuralNetwork.getLayers()) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    connection.getWeight().setTrainingData(new MomentumTrainingData());
                }
            }
        }
    }

    /** Per-weight state for this algorithm: the weight change computed on the previous call. */
    public static class MomentumTrainingData {
        /** Weight change from the previous calculateWeightChanges call for this weight. */
        public double previousWeightChange;
    }
}

