/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.core;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.NeuralNetworkEvent;
import org.neuroph.core.events.NeuralNetworkEventListener;
import org.neuroph.core.exceptions.NeurophException;
import org.neuroph.core.exceptions.VectorSizeMismatchException;
import org.neuroph.core.learning.IterativeLearning;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.plugins.PluginBase;
import org.neuroph.util.random.RangeRandomizer;
import org.neuroph.util.random.WeightsRandomizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A neural network built from an ordered list of {@link Layer}s, with designated input and
 * output neurons, an optional learning rule of type {@code L}, plugins keyed by their class,
 * and (non-serialized) event listeners.
 *
 * <p>Instances are {@link Serializable}; see {@link #save(String)}, {@link #load(String)},
 * {@link #load(InputStream)} and {@link #createFromFile(File)}.
 *
 * @param <L> the type of learning rule used to train this network
 */
public class NeuralNetwork<L extends LearningRule>
implements Serializable {
    private static final long serialVersionUID = 7L;
    private NeuralNetworkType type;
    private List<Layer> layers;
    private L learningRule;
    /**
     * Array returned by {@link #getOutput()}; reused between calls and re-allocated there only
     * when its length no longer matches the number of output neurons.
     */
    protected double[] outputBuffer;
    private List<Neuron> inputNeurons;
    private List<Neuron> outputNeurons;
    private Map<Class<?>, PluginBase> plugins;
    private String label = "";
    // Listeners are not serialized; readObject() re-creates an empty list after deserialization.
    private transient List<NeuralNetworkEventListener> listeners = new ArrayList<NeuralNetworkEventListener>();
    // NOTE(review): this field is neither static nor transient, so it is part of the serialized
    // form. It is intentionally left unchanged to keep serialVersionUID 7 files compatible.
    private final Logger LOGGER = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Creates an empty network with no layers, no input/output neurons and no plugins. */
    public NeuralNetwork() {
        this.layers = new ArrayList<Layer>();
        this.inputNeurons = new ArrayList<Neuron>();
        this.outputNeurons = new ArrayList<Neuron>();
        this.plugins = new HashMap<Class<?>, PluginBase>();
    }

    /**
     * Appends a layer to this network, makes this network its parent and fires a
     * {@code LAYER_ADDED} event.
     *
     * @param layer the layer to add
     * @throws IllegalArgumentException if {@code layer} is null
     */
    public void addLayer(Layer layer) {
        if (layer == null) {
            throw new IllegalArgumentException("Layer cant be null!");
        }
        layer.setParentNetwork(this);
        this.layers.add(layer);
        this.fireNetworkEvent(new NeuralNetworkEvent(layer, NeuralNetworkEvent.Type.LAYER_ADDED));
    }

    /**
     * Inserts a layer at the given position, makes this network its parent and fires a
     * {@code LAYER_ADDED} event.
     *
     * @param index position at which to insert the layer
     * @param layer the layer to add
     * @throws IllegalArgumentException if {@code layer} is null or {@code index} is negative
     * @throws IndexOutOfBoundsException if {@code index} is greater than the layer count
     */
    public void addLayer(int index, Layer layer) {
        if (layer == null) {
            throw new IllegalArgumentException("Layer cant be null!");
        }
        if (index < 0) {
            throw new IllegalArgumentException("Layer index cannot be negative: " + index);
        }
        layer.setParentNetwork(this);
        this.layers.add(index, layer);
        this.fireNetworkEvent(new NeuralNetworkEvent(layer, NeuralNetworkEvent.Type.LAYER_ADDED));
    }

    /**
     * Removes the given layer and fires a {@code LAYER_REMOVED} event.
     *
     * @param layer the layer to remove
     * @throws RuntimeException if the layer is not part of this network
     */
    public void removeLayer(Layer layer) {
        if (!this.layers.remove(layer)) {
            throw new RuntimeException("Layer not in Neural n/w");
        }
        this.fireNetworkEvent(new NeuralNetworkEvent(layer, NeuralNetworkEvent.Type.LAYER_REMOVED));
    }

    /**
     * Removes the layer at the given position and fires a {@code LAYER_REMOVED} event.
     *
     * @param index position of the layer to remove
     * @throws IndexOutOfBoundsException if {@code index} is out of range
     */
    public void removeLayerAt(int index) {
        Layer removed = this.layers.remove(index);
        this.fireNetworkEvent(new NeuralNetworkEvent(removed, NeuralNetworkEvent.Type.LAYER_REMOVED));
    }

    /** Returns a read-only view of this network's layers, in order. */
    public List<Layer> getLayers() {
        return Collections.unmodifiableList(this.layers);
    }

    /** Returns the layer at the given position. */
    public Layer getLayerAt(int index) {
        return this.layers.get(index);
    }

    /** Returns the position of the given layer, or -1 if it is not part of this network. */
    public int indexOf(Layer layer) {
        return this.layers.indexOf(layer);
    }

    /** Returns the number of layers in this network. */
    public int getLayersCount() {
        return this.layers.size();
    }

    /**
     * Sets the input of each input neuron from the corresponding vector element.
     *
     * @param inputVector one value per input neuron, in input-neuron order
     * @throws VectorSizeMismatchException if the vector length differs from the input count
     */
    public void setInput(double ... inputVector) throws VectorSizeMismatchException {
        if (inputVector.length != this.inputNeurons.size()) {
            throw new VectorSizeMismatchException("Input vector size does not match network input dimension!");
        }
        for (int i = 0; i < inputVector.length; ++i) {
            this.inputNeurons.get(i).setInput(inputVector[i]);
        }
    }

    /**
     * Collects the current output of every output neuron.
     *
     * <p>The returned array is {@link #outputBuffer} and is reused between calls, so callers that
     * need to keep a result should copy it. The buffer is (re)allocated when it is missing or its
     * length no longer matches the output neuron count. This can happen when no output neurons
     * were ever set, or when neurons were added through {@link #getOutputNeurons()}.
     *
     * @return the output values, one per output neuron
     */
    public double[] getOutput() {
        int count = this.outputNeurons.size();
        if (this.outputBuffer == null || this.outputBuffer.length != count) {
            this.outputBuffer = new double[count];
        }
        int i = 0;
        for (Neuron neuron : this.outputNeurons) {
            this.outputBuffer[i++] = neuron.getOutput();
        }
        return this.outputBuffer;
    }

    /** Calculates every layer in order, then fires a {@code CALCULATED} event. */
    public void calculate() {
        this.layers.forEach(layer -> layer.calculate());
        this.fireNetworkEvent(new NeuralNetworkEvent(this, NeuralNetworkEvent.CALCULATED));
    }

    /** Resets every layer of this network. */
    public void reset() {
        this.layers.forEach(layer -> layer.reset());
    }

    /**
     * Trains this network on the given data set using the configured learning rule.
     *
     * @param trainingSet the data to learn from
     * @throws IllegalArgumentException if {@code trainingSet} is null
     * @throws IllegalStateException if no learning rule has been set
     */
    public void learn(DataSet trainingSet) {
        if (trainingSet == null) {
            throw new IllegalArgumentException("Training set is null!");
        }
        if (this.learningRule == null) {
            throw new IllegalStateException("Learning rule is not set!");
        }
        this.learningRule.learn(trainingSet);
    }

    /** Asks the learning rule to stop learning; does nothing if no learning rule is set. */
    public void stopLearning() {
        if (this.learningRule != null) {
            this.learningRule.stopLearning();
        }
    }

    /** Pauses learning if the learning rule is an {@link IterativeLearning}; otherwise does nothing. */
    public void pauseLearning() {
        if (this.learningRule instanceof IterativeLearning) {
            ((IterativeLearning)this.learningRule).pause();
        }
    }

    /** Resumes learning if the learning rule is an {@link IterativeLearning}; otherwise does nothing. */
    public void resumeLearning() {
        if (this.learningRule instanceof IterativeLearning) {
            ((IterativeLearning)this.learningRule).resume();
        }
    }

    /** Randomizes all weights using a default-constructed {@link WeightsRandomizer}. */
    public void randomizeWeights() {
        this.randomizeWeights(new WeightsRandomizer());
    }

    /** Randomizes all weights using a {@link RangeRandomizer} over the given bounds. */
    public void randomizeWeights(double minWeight, double maxWeight) {
        this.randomizeWeights(new RangeRandomizer(minWeight, maxWeight));
    }

    /** Randomizes all weights using a {@link WeightsRandomizer} backed by the given generator. */
    public void randomizeWeights(Random random) {
        this.randomizeWeights(new WeightsRandomizer(random));
    }

    /** Randomizes all weights of this network with the given randomizer. */
    public void randomizeWeights(WeightsRandomizer randomizer) {
        randomizer.randomize(this);
    }

    /** Returns the network type, or null if it was never set. */
    public NeuralNetworkType getNetworkType() {
        return this.type;
    }

    /** Sets the network type. */
    public void setNetworkType(NeuralNetworkType type) {
        this.type = type;
    }

    /** Returns the internal (mutable) list of input neurons. */
    public List<Neuron> getInputNeurons() {
        return this.inputNeurons;
    }

    /** Returns the number of input neurons. */
    public int getInputsCount() {
        return this.inputNeurons.size();
    }

    /**
     * Appends the given neurons to the input neurons. Despite the name, existing input neurons
     * are kept.
     */
    public void setInputNeurons(List<Neuron> inputNeurons) {
        this.inputNeurons.addAll(inputNeurons);
    }

    /** Returns the internal (mutable) list of output neurons. */
    public List<Neuron> getOutputNeurons() {
        return this.outputNeurons;
    }

    /** Returns the number of output neurons. */
    public int getOutputsCount() {
        return this.outputNeurons.size();
    }

    /**
     * Appends the given neurons to the output neurons and resizes the output buffer to match.
     * Despite the name, existing output neurons are kept.
     */
    public void setOutputNeurons(List<Neuron> outputNeurons) {
        this.outputNeurons.addAll(outputNeurons);
        // Size from the field, not the (shadowing) parameter: the list is appended to, so a second
        // call must still yield a buffer large enough for every output neuron.
        this.outputBuffer = new double[this.outputNeurons.size()];
    }

    /**
     * Assigns labels to the output neurons in order.
     *
     * @param labels at least {@link #getOutputsCount()} labels; extra entries are ignored
     */
    public void setOutputLabels(String[] labels) {
        for (int i = 0; i < this.outputNeurons.size(); ++i) {
            this.outputNeurons.get(i).setLabel(labels[i]);
        }
    }

    /** Returns the labels of the output neurons, in order. */
    public String[] getOutputLabels() {
        String[] labels = new String[this.outputNeurons.size()];
        for (int i = 0; i < labels.length; ++i) {
            labels[i] = this.outputNeurons.get(i).getLabel();
        }
        return labels;
    }

    /** Returns the learning rule, or null if none has been set. */
    public L getLearningRule() {
        return this.learningRule;
    }

    /**
     * Sets the learning rule and binds it to this network.
     *
     * @throws IllegalArgumentException if {@code learningRule} is null
     */
    public void setLearningRule(L learningRule) {
        if (learningRule == null) {
            throw new IllegalArgumentException("Learning rule can't be null!");
        }
        learningRule.setNeuralNetwork(this);
        this.learningRule = learningRule;
    }

    /**
     * Returns the weights of all input connections, ordered by layer, then neuron, then
     * connection.
     */
    public Double[] getWeights() {
        List<Double> weights = new ArrayList<Double>();
        for (Layer layer : this.layers) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection conn : neuron.getInputConnections()) {
                    weights.add(conn.getWeight().getValue());
                }
            }
        }
        return weights.toArray(new Double[0]);
    }

    /**
     * Sets connection weights in the same order produced by {@link #getWeights()}.
     *
     * <p>Entries beyond the number of connections are ignored. A shorter array throws
     * {@link ArrayIndexOutOfBoundsException} after the earlier weights have already been updated.
     */
    public void setWeights(double[] weights) {
        int i = 0;
        for (Layer layer : this.layers) {
            for (Neuron neuron : layer.getNeurons()) {
                for (Connection conn : neuron.getInputConnections()) {
                    conn.getWeight().setValue(weights[i++]);
                }
            }
        }
    }

    /** Returns true if this network has no layers. */
    public boolean isEmpty() {
        return this.layers.isEmpty();
    }

    /** Adds an input connection from {@code fromNeuron} to {@code toNeuron} with the given weight. */
    public void createConnection(Neuron fromNeuron, Neuron toNeuron, double weightVal) {
        toNeuron.addInputConnection(fromNeuron, weightVal);
    }

    /** Returns the label (which defaults to the empty string), or the default form if it is null. */
    @Override
    public String toString() {
        if (this.label != null) {
            return this.label;
        }
        return super.toString();
    }

    /**
     * Serializes this network to the given file.
     *
     * <p>Both the file stream and the object stream are closed even when opening the object
     * stream fails. A failure while closing (which may mean data did not reach the file) is
     * reported rather than ignored.
     *
     * @param filePath destination file path
     * @throws NeurophException if the network cannot be written
     */
    public void save(String filePath) {
        try (FileOutputStream fileOut = new FileOutputStream(new File(filePath));
             ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(fileOut))) {
            out.writeObject(this);
            out.flush();
        }
        catch (IOException ioe) {
            throw new NeurophException("Could not write neural network to file!", ioe);
        }
    }

    /**
     * Deserializes a network from the file at the given path.
     *
     * @param filePath path of a file written by {@link #save(String)}
     * @return the deserialized network
     * @throws NeurophException if the file is missing, unreadable or references unknown classes
     */
    public static NeuralNetwork load(String filePath) {
        return NeuralNetwork.createFromFile(new File(filePath));
    }

    /**
     * Deserializes a network from the given stream. The stream is closed afterwards, whether or
     * not reading succeeds.
     *
     * @param inputStream stream containing a serialized network
     * @return the deserialized network
     * @throws NeurophException if the stream cannot be read or references unknown classes
     */
    public static NeuralNetwork load(InputStream inputStream) {
        return NeuralNetwork.readAndClose(inputStream);
    }

    // Re-creates the transient listener list, which default deserialization leaves null.
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.listeners = new ArrayList<NeuralNetworkEventListener>();
    }

    /**
     * Deserializes a network from the given file.
     *
     * @param file a file written by {@link #save(String)}
     * @return the deserialized network
     * @throws NeurophException if the file is missing, unreadable or references unknown classes
     */
    public static NeuralNetwork createFromFile(File file) {
        InputStream source;
        try {
            if (!file.exists()) {
                throw new FileNotFoundException("Cannot find file: " + file);
            }
            source = new FileInputStream(file);
        }
        catch (IOException ioe) {
            throw new NeurophException("Could not read neural network file!", ioe);
        }
        return NeuralNetwork.readAndClose(source);
    }

    /** Deserializes a network from the file at the given path; see {@link #createFromFile(File)}. */
    public static NeuralNetwork createFromFile(String filePath) {
        return NeuralNetwork.createFromFile(new File(filePath));
    }

    /**
     * Reads one serialized network from {@code source}, then always closes it.
     *
     * <p>The {@code ObjectInputStream} constructor reads the stream header and may throw (for
     * example on a file that is not a serialized object). In that case the raw {@code source} is
     * closed directly, so the underlying stream is not leaked.
     */
    private static NeuralNetwork readAndClose(InputStream source) {
        ObjectInputStream in = null;
        try {
            in = new ObjectInputStream(new BufferedInputStream(source));
            return (NeuralNetwork)in.readObject();
        }
        catch (IOException ioe) {
            throw new NeurophException("Could not read neural network file!", ioe);
        }
        catch (ClassNotFoundException cnfe) {
            throw new NeurophException("Class not found while trying to read neural network from file!", cnfe);
        }
        finally {
            NeuralNetwork.closeQuietly(in != null ? in : source);
        }
    }

    // Best-effort close for input streams. The network has already been read, or a primary error
    // is already propagating, so a close failure carries no useful information.
    private static void closeQuietly(InputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        }
        catch (IOException ignored) {
            // intentionally ignored: see method comment
        }
    }

    /**
     * Registers a plugin, makes this network its parent and replaces any plugin of the same
     * class.
     */
    public void addPlugin(PluginBase plugin) {
        plugin.setParentNetwork(this);
        this.plugins.put(plugin.getClass(), plugin);
    }

    /**
     * Returns the plugin registered under exactly the given class, or null if there is none.
     *
     * @param pluginClass the plugin's concrete class
     */
    public <T extends PluginBase> T getPlugin(Class<T> pluginClass) {
        return pluginClass.cast(this.plugins.get(pluginClass));
    }

    /** Removes the plugin registered under the given class, if any. */
    public void removePlugin(Class<?> pluginClass) {
        this.plugins.remove(pluginClass);
    }

    /** Returns this network's label. */
    public String getLabel() {
        return this.label;
    }

    /** Sets this network's label. */
    public void setLabel(String label) {
        this.label = label;
    }

    /**
     * Registers a listener for network events.
     *
     * @throws IllegalArgumentException if {@code listener} is null
     */
    public synchronized void addListener(NeuralNetworkEventListener listener) {
        if (listener == null) {
            throw new IllegalArgumentException("listener is null!");
        }
        this.listeners.add(listener);
    }

    /**
     * Unregisters a listener.
     *
     * @throws IllegalArgumentException if {@code listener} is null
     */
    public synchronized void removeListener(NeuralNetworkEventListener listener) {
        if (listener == null) {
            throw new IllegalArgumentException("listener is null!");
        }
        this.listeners.remove(listener);
    }

    /**
     * Delivers the event to every listener registered at the time of the call.
     *
     * <p>Dispatch iterates over a snapshot taken under the lock, so listeners may add or remove
     * listeners (including themselves) from within their handler without causing
     * {@link java.util.ConcurrentModificationException}. Listener code also does not run while
     * this network's monitor is held.
     */
    public void fireNetworkEvent(NeuralNetworkEvent evt) {
        List<NeuralNetworkEventListener> snapshot;
        synchronized (this) {
            snapshot = new ArrayList<NeuralNetworkEventListener>(this.listeners);
        }
        for (NeuralNetworkEventListener listener : snapshot) {
            listener.handleNeuralNetworkEvent(evt);
        }
    }
}

