/*
 * Decompiled with CFR 0.152.
 */
package marytts.unitselection.select;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureProcessorManager;
import marytts.features.FeatureVector;
import marytts.features.TargetFeatureComputer;
import marytts.server.MaryProperties;
import marytts.signalproc.display.Histogram;
import marytts.unitselection.data.FeatureFileReader;
import marytts.unitselection.data.Unit;
import marytts.unitselection.select.Target;
import marytts.unitselection.select.TargetCostFunction;
import marytts.unitselection.weightingfunctions.WeightFunc;
import marytts.unitselection.weightingfunctions.WeightFunctionManager;
import marytts.util.MaryUtils;

/**
 * Target cost function backed by a {@link FeatureFileReader}.
 *
 * <p>The unit feature vectors and the feature definition come from the feature file. The feature
 * definition supplies feature names, per-feature weights and weighting-function names. The weights
 * can optionally be overridden from a separate weights stream.
 *
 * <p>The cost of a target/unit pair is the sum of per-feature weighted costs. Only features whose
 * weight is strictly positive contribute:
 * <ul>
 * <li>Byte-valued features with a similarity matrix contribute
 * {@code weight * getSimilarity(index, unitValue, targetValue)}.</li>
 * <li>Other byte-valued features, and all short-valued features, contribute {@code weight} when
 * the target and unit values differ, and nothing otherwise.</li>
 * <li>Continuous features contribute {@code weight * weightFunction.cost(targetValue, unitValue)}.
 * They are skipped when either value is NaN.</li>
 * </ul>
 *
 * <p>If the Mary property {@code debug.show.cost.graph} is set, the weighted cost of each feature
 * is accumulated. The accumulated costs are shown as a histogram of the average cost per cost
 * computation.
 */
public class FFRTargetCostFunction
implements TargetCostFunction {
    /** Weighting functions for the continuous features, indexed from the first continuous feature. */
    protected WeightFunc[] weightFunction;
    /** Computes target feature vectors for the features named in {@link #featureDefinition}. */
    protected TargetFeatureComputer targetFeatureComputer;
    /** Feature vectors of all units, indexed by {@code Unit.index}. */
    protected FeatureVector[] featureVectors;
    /** Feature definition holding feature names, weights and weighting-function names. */
    protected FeatureDefinition featureDefinition;
    /** For each feature index, whether its weight is strictly positive; other features are skipped. */
    protected boolean[] weightsNonZero;
    /** Whether per-feature weighted costs are accumulated for the debug cost graph. */
    protected boolean debugShowCostGraph = false;
    /** Accumulated weighted cost per feature index; only allocated when the debug graph is enabled. */
    protected double[] cumulWeightedCosts = null;
    /** Number of calls to {@link #cost(Target, Unit, FeatureDefinition, WeightFunc[])} so far. */
    protected int nCostComputations = 0;

    @Override
    public double cost(Target target, Unit unit) {
        return this.cost(target, unit, this.featureDefinition, this.weightFunction);
    }

    /**
     * Computes the weighted target cost between a target and a unit.
     *
     * @param target the target; its feature vector must already have been computed
     * @param unit the candidate unit
     * @param weights feature definition whose {@code getFeatureWeights()} supplies the feature weights
     * @param weightFunctions weighting functions for the continuous features
     * @return the sum of the weighted per-feature costs
     */
    protected double cost(Target target, Unit unit, FeatureDefinition weights, WeightFunc[] weightFunctions) {
        ++this.nCostComputations;
        FeatureVector targetFeatures = target.getFeatureVector();
        assert targetFeatures != null : "Target " + target + " does not have pre-computed feature vector";
        FeatureVector unitFeatures = this.featureVectors[unit.index];
        int nBytes = targetFeatures.byteValuedDiscreteFeatures.length;
        int nShorts = targetFeatures.shortValuedDiscreteFeatures.length;
        int nFloats = targetFeatures.continuousFeatures.length;
        assert nBytes == unitFeatures.byteValuedDiscreteFeatures.length;
        assert nShorts == unitFeatures.shortValuedDiscreteFeatures.length;
        assert nFloats == unitFeatures.continuousFeatures.length;
        float[] weightVector = weights.getFeatureWeights();
        int nDiscrete = nBytes + nShorts;
        int nFeatures = nDiscrete + nFloats;
        double cost = 0.0;
        // Feature indices run over byte, then short, then continuous features.
        for (int i = 0; i < nFeatures; i++) {
            cost += this.weightedFeatureCost(i, targetFeatures, unitFeatures, weightVector, weightFunctions, nBytes, nDiscrete);
        }
        return cost;
    }

    /**
     * Computes the weighted cost of a single named feature between a target and a unit.
     *
     * @param target the target; its feature vector must already have been computed
     * @param unit the candidate unit
     * @param featureName name of the feature to compare
     * @return the weighted cost of that feature alone
     * @throws IllegalArgumentException if the feature is not in the feature definition
     */
    public double featureCost(Target target, Unit unit, String featureName) {
        return this.featureCost(target, unit, featureName, this.featureDefinition, this.weightFunction);
    }

    /**
     * Computes the weighted cost of a single named feature, using the given weights and
     * weighting functions.
     *
     * @param target the target; its feature vector must already have been computed
     * @param unit the candidate unit
     * @param featureName name of the feature to compare
     * @param weights feature definition whose {@code getFeatureWeights()} supplies the feature weights
     * @param weightFunctions weighting functions for the continuous features
     * @return the weighted cost of that feature alone
     * @throws IllegalArgumentException if the feature is not in the feature definition
     */
    protected double featureCost(Target target, Unit unit, String featureName, FeatureDefinition weights, WeightFunc[] weightFunctions) {
        if (!this.featureDefinition.hasFeature(featureName)) {
            throw new IllegalArgumentException("this feature does not exists in feature definition");
        }
        FeatureVector targetFeatures = target.getFeatureVector();
        assert targetFeatures != null : "Target " + target + " does not have pre-computed feature vector";
        FeatureVector unitFeatures = this.featureVectors[unit.index];
        int nBytes = targetFeatures.byteValuedDiscreteFeatures.length;
        int nShorts = targetFeatures.shortValuedDiscreteFeatures.length;
        int nFloats = targetFeatures.continuousFeatures.length;
        assert nBytes == unitFeatures.byteValuedDiscreteFeatures.length;
        assert nShorts == unitFeatures.shortValuedDiscreteFeatures.length;
        assert nFloats == unitFeatures.continuousFeatures.length;
        int featureIndex = this.featureDefinition.getFeatureIndex(featureName);
        float[] weightVector = weights.getFeatureWeights();
        return this.weightedFeatureCost(featureIndex, targetFeatures, unitFeatures, weightVector, weightFunctions, nBytes, nBytes + nShorts);
    }

    /**
     * Computes the weighted cost contribution of one feature. When the debug cost graph is
     * enabled, the contribution is also added to {@link #cumulWeightedCosts}.
     *
     * @param featureIndex index of the feature in the feature definition
     * @param targetFeatures feature vector of the target
     * @param unitFeatures feature vector of the unit
     * @param weightVector per-feature weights
     * @param weightFunctions weighting functions for the continuous features
     * @param nBytes number of byte-valued features (they occupy indices {@code [0, nBytes)})
     * @param nDiscrete number of byte- plus short-valued features
     *     (continuous features start at this index)
     * @return the weighted cost, or 0 if the weight is not positive, the discrete values are equal,
     *     or a continuous value is NaN
     */
    private double weightedFeatureCost(int featureIndex, FeatureVector targetFeatures, FeatureVector unitFeatures,
            float[] weightVector, WeightFunc[] weightFunctions, int nBytes, int nDiscrete) {
        if (!this.weightsNonZero[featureIndex]) {
            return 0.0;
        }
        float weight = weightVector[featureIndex];
        double featureCost;
        if (featureIndex < nBytes) {
            byte targetValue = targetFeatures.byteValuedDiscreteFeatures[featureIndex];
            byte unitValue = unitFeatures.byteValuedDiscreteFeatures[featureIndex];
            if (this.featureDefinition.hasSimilarityMatrix(featureIndex)) {
                float similarity = this.featureDefinition.getSimilarity(featureIndex, unitValue, targetValue);
                // The product is computed in float precision, then widened to double.
                featureCost = similarity * weight;
            } else if (targetValue != unitValue) {
                featureCost = weight;
            } else {
                return 0.0;
            }
        } else if (featureIndex < nDiscrete) {
            int shortIndex = featureIndex - nBytes;
            if (targetFeatures.shortValuedDiscreteFeatures[shortIndex] == unitFeatures.shortValuedDiscreteFeatures[shortIndex]) {
                return 0.0;
            }
            featureCost = weight;
        } else {
            int continuousIndex = featureIndex - nDiscrete;
            float targetValue = targetFeatures.continuousFeatures[continuousIndex];
            float unitValue = unitFeatures.continuousFeatures[continuousIndex];
            // Missing (NaN) continuous values contribute no cost.
            if (Float.isNaN(targetValue) || Float.isNaN(unitValue)) {
                return 0.0;
            }
            double functionCost = weightFunctions[continuousIndex].cost(targetValue, unitValue);
            featureCost = (double) weight * functionCost;
        }
        if (this.debugShowCostGraph) {
            this.cumulWeightedCosts[featureIndex] += featureCost;
        }
        return featureCost;
    }

    @Override
    public void load(String featureFileName, InputStream weightsStream, FeatureProcessorManager featProc) throws IOException, MaryConfigurationException {
        FeatureFileReader ffr = FeatureFileReader.getFeatureFileReader(featureFileName);
        this.load(ffr, weightsStream, featProc);
    }

    /**
     * Initialises this cost function from a feature file reader.
     *
     * <p>If {@code weightsStream} is non-null, it is read as UTF-8 and replaces the feature file's
     * feature definition. This method does not close the stream.
     *
     * @param ffr source of the feature definition and unit feature vectors
     * @param weightsStream optional stream with an overriding feature definition (weights); may be null
     * @param featProc feature processors used to compute target features
     * @throws IOException if the weights cannot be read or their feature definition is
     *     incompatible with the feature file
     */
    @Override
    public void load(FeatureFileReader ffr, InputStream weightsStream, FeatureProcessorManager featProc) throws IOException {
        this.featureDefinition = ffr.getFeatureDefinition();
        this.featureVectors = ffr.getFeatureVectors();
        if (weightsStream != null) {
            MaryUtils.getLogger("TargetCostFeatures").debug((Object)"Overwriting target cost weights from file");
            FeatureDefinition newWeights = new FeatureDefinition(new BufferedReader(new InputStreamReader(weightsStream, StandardCharsets.UTF_8)), true);
            if (!newWeights.featureEquals(this.featureDefinition)) {
                throw new IOException("Weights file: feature definition incompatible with feature file");
            }
            this.featureDefinition = newWeights;
        }
        this.weightFunction = new WeightFunc[this.featureDefinition.getNumberOfContinuousFeatures()];
        WeightFunctionManager wfm = new WeightFunctionManager();
        int nDiscreteFeatures = this.featureDefinition.getNumberOfByteFeatures() + this.featureDefinition.getNumberOfShortFeatures();
        for (int i = 0; i < this.weightFunction.length; i++) {
            String weightFunctionName = this.featureDefinition.getWeightFunctionName(nDiscreteFeatures + i);
            // An empty weighting-function name falls back to the "linear" function.
            this.weightFunction[i] = "".equals(weightFunctionName) ? wfm.getWeightFunction("linear") : wfm.getWeightFunction(weightFunctionName);
        }
        this.targetFeatureComputer = new TargetFeatureComputer(featProc, this.featureDefinition.getFeatureNames());
        this.rememberWhichWeightsAreNonZero();
        if (MaryProperties.getBoolean("debug.show.cost.graph")) {
            this.debugShowCostGraph = true;
            this.cumulWeightedCosts = new double[this.featureDefinition.getNumberOfFeatures()];
            TargetCostReporter reporter = new TargetCostReporter(this.cumulWeightedCosts);
            reporter.showInJFrame("Average weighted target costs", false, false);
            reporter.start();
        }
    }

    /** Fills {@link #weightsNonZero}: a feature is used only if its weight is strictly positive. */
    protected void rememberWhichWeightsAreNonZero() {
        int n = this.featureDefinition.getNumberOfFeatures();
        this.weightsNonZero = new boolean[n];
        for (int i = 0; i < n; i++) {
            this.weightsNonZero[i] = this.featureDefinition.getWeight(i) > 0.0f;
        }
    }

    /** Computes the target's feature vector and stores it on the target. */
    @Override
    public void computeTargetFeatures(Target target) {
        FeatureVector fv = this.targetFeatureComputer.computeFeatureVector(target);
        target.setFeatureVector(fv);
    }

    /** Returns the stored feature vector of the given unit. */
    @Override
    public FeatureVector getFeatureVector(Unit unit) {
        return this.featureVectors[unit.index];
    }

    /**
     * Returns the value of a unit's feature as a string. Discrete features are mapped through the
     * feature definition; continuous features are formatted with {@code String.valueOf(float)}.
     */
    @Override
    public String getFeature(Unit unit, String featureName) {
        int featureIndex = this.featureDefinition.getFeatureIndex(featureName);
        if (this.featureDefinition.isByteFeature(featureIndex)) {
            byte value = this.featureVectors[unit.index].getByteFeature(featureIndex);
            return this.featureDefinition.getFeatureValueAsString(featureIndex, value);
        }
        if (this.featureDefinition.isShortFeature(featureIndex)) {
            short value = this.featureVectors[unit.index].getShortFeature(featureIndex);
            return this.featureDefinition.getFeatureValueAsString(featureIndex, value);
        }
        float value = this.featureVectors[unit.index].getContinuousFeature(featureIndex);
        return String.valueOf(value);
    }

    @Override
    public FeatureDefinition getFeatureDefinition() {
        return this.featureDefinition;
    }

    @Override
    public FeatureVector[] getFeatureVectors() {
        return this.featureVectors;
    }

    /**
     * Histogram that displays, for each feature, the accumulated weighted cost divided by the
     * number of cost computations performed so far.
     */
    public class TargetCostReporter
    extends Histogram {
        private double[] data;
        private int lastN;

        /**
         * @param data accumulated per-feature costs; the array is read live on every refresh
         */
        public TargetCostReporter(double[] data) {
            super(0.0, 1.0, data);
            this.lastN = 0;
            this.data = data;
        }

        /** Starts a background thread that refreshes the graph every 500 ms while it is visible. */
        public void start() {
            new Thread(){

                @Override
                public void run() {
                    while (TargetCostReporter.this.isVisible()) {
                        try {
                            Thread.sleep(500L);
                        }
                        catch (InterruptedException e) {
                            // Honour the interruption: restore the flag and stop refreshing.
                            Thread.currentThread().interrupt();
                            return;
                        }
                        TargetCostReporter.this.updateGraph();
                    }
                }
            }.start();
        }

        /** Recomputes the averages and repaints, unless no new cost computations happened. */
        protected void updateGraph() {
            // Read the shared counter once so the change check and the divisor agree.
            int n = FFRTargetCostFunction.this.nCostComputations;
            if (n == this.lastN) {
                return;
            }
            this.lastN = n;
            double[] newCosts = new double[this.data.length];
            for (int i = 0; i < newCosts.length; i++) {
                newCosts[i] = this.data[i] / (double) n;
            }
            this.updateData(0.0, 1.0, newCosts);
            this.repaint();
        }
    }
}

