/*
 * Decompiled with CFR 0.152.
 */
package marytts.modules;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.Locale;
import marytts.cart.CART;
import marytts.cart.DirectedGraph;
import marytts.cart.StringPredictionTree;
import marytts.cart.io.DirectedGraphReader;
import marytts.datatypes.MaryData;
import marytts.datatypes.MaryDataType;
import marytts.features.FeatureDefinition;
import marytts.features.FeatureProcessorManager;
import marytts.features.FeatureRegistry;
import marytts.features.TargetFeatureComputer;
import marytts.modules.InternalModule;
import marytts.modules.synthesis.Voice;
import marytts.server.MaryProperties;
import marytts.unitselection.select.Target;
import marytts.unitselection.select.UnitSelector;
import marytts.util.MaryRuntimeUtils;
import marytts.util.MaryUtils;
import marytts.util.dom.MaryDomUtils;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.traversal.NodeIterator;
import org.w3c.dom.traversal.TreeWalker;

/**
 * Module that converts an ALLOPHONES document into a DURATIONS document by writing a duration
 * onto every {@code ph} and {@code boundary} element inside each sentence ({@code s}).
 *
 * <p>Phone durations are predicted with a {@link DirectedGraph}: the voice's own duration graph if
 * the sentence's voice provides one, otherwise the graph loaded from the property
 * {@code <prefix>cart}. Pause durations come from an explicit {@code duration} attribute on the
 * boundary, else from an optional {@link StringPredictionTree} loaded from
 * {@code <prefix>pausetree} (with features listed in {@code <prefix>pausefeatures}), else a fixed
 * default.
 */
public class CARTDurationModeller extends InternalModule {
    /** Pause length in seconds used when no explicit or predicted value is available. */
    private static final float DEFAULT_PAUSE_SECONDS = 0.4f;
    /** Upper bound in seconds applied to pause lengths predicted by the pause tree. */
    private static final float MAX_PREDICTED_PAUSE_SECONDS = 2.0f;

    /** Module-level duration graph; {@code null} after startup if no {@code <prefix>cart} is configured. */
    protected DirectedGraph cart = new CART();
    /** Optional pause-duration predictor; {@code null} if no {@code <prefix>pausetree} is configured. */
    protected StringPredictionTree pausetree;
    /** Computes the feature vectors expected by {@link #cart}. */
    protected TargetFeatureComputer featureComputer;
    /** Computes the feature vectors expected by {@link #pausetree}. */
    protected TargetFeatureComputer pauseFeatureComputer;
    /** Property-name prefix, always ending in ".". */
    private final String propertyPrefix;
    private final FeatureProcessorManager featureProcessorManager;

    /**
     * Creates a modeller that uses the feature processor manager registered for {@code locale}.
     *
     * @param locale locale string, converted with {@link MaryUtils#string2locale(String)}
     * @param propertyPrefix prefix of the configuration properties; a trailing "." is appended if missing
     * @throws Exception propagated from looking up the feature processor manager
     */
    public CARTDurationModeller(String locale, String propertyPrefix) throws Exception {
        this(MaryUtils.string2locale(locale), propertyPrefix,
                FeatureRegistry.getFeatureProcessorManager(MaryUtils.string2locale(locale)));
    }

    /**
     * Creates a modeller whose feature processor manager is instantiated from a class description.
     *
     * @param locale locale string, converted with {@link MaryUtils#string2locale(String)}
     * @param propertyPrefix prefix of the configuration properties; a trailing "." is appended if missing
     * @param featprocClassInfo description passed to {@link MaryRuntimeUtils#instantiateObject};
     *        the resulting object must be a {@link FeatureProcessorManager}
     * @throws Exception propagated from instantiating the feature processor manager
     */
    public CARTDurationModeller(String locale, String propertyPrefix, String featprocClassInfo) throws Exception {
        this(MaryUtils.string2locale(locale), propertyPrefix,
                (FeatureProcessorManager) MaryRuntimeUtils.instantiateObject(featprocClassInfo));
    }

    /**
     * Creates a modeller with an explicit feature processor manager.
     *
     * @param locale locale handled by this module
     * @param propertyPrefix prefix of the configuration properties; a trailing "." is appended if missing
     * @param featureProcessorManager source of the feature processors used to build feature computers
     */
    protected CARTDurationModeller(Locale locale, String propertyPrefix, FeatureProcessorManager featureProcessorManager) {
        super("CARTDurationModeller", MaryDataType.ALLOPHONES, MaryDataType.DURATIONS, locale);
        this.propertyPrefix = propertyPrefix.endsWith(".") ? propertyPrefix : propertyPrefix + ".";
        this.featureProcessorManager = featureProcessorManager;
    }

    /**
     * Loads the optional duration graph and the optional pause tree named by the configuration.
     *
     * <p>If {@code <prefix>pausetree} is set, {@code <prefix>pausefeatures} is required as well.
     *
     * @throws Exception if a configured file is missing or cannot be read or parsed
     */
    @Override
    public void startup() throws Exception {
        super.startup();

        String cartFilename = MaryProperties.getFilename(propertyPrefix + "cart");
        if (cartFilename != null) {
            cart = new DirectedGraphReader().load(new File(cartFilename).getAbsolutePath());
            featureComputer = FeatureRegistry.getTargetFeatureComputer(featureProcessorManager,
                    cart.getFeatureDefinition().getFeatureNames());
        } else {
            cart = null;
        }

        String pauseFilename = MaryProperties.getFilename(propertyPrefix + "pausetree");
        if (pauseFilename != null) {
            File pauseFile = new File(pauseFilename);
            File pauseFdFile = new File(MaryProperties.needFilename(propertyPrefix + "pausefeatures"));
            // Readers are closed as soon as each object has been built (previously they leaked).
            // NOTE(review): assumes both constructors consume their reader eagerly -- confirm.
            FeatureDefinition pauseFeatureDefinition;
            try (BufferedReader fdReader = new BufferedReader(new FileReader(pauseFdFile))) {
                pauseFeatureDefinition = new FeatureDefinition(fdReader, false);
            }
            pauseFeatureComputer = FeatureRegistry.getTargetFeatureComputer(featureProcessorManager,
                    pauseFeatureDefinition.getFeatureNames());
            try (BufferedReader treeReader = new BufferedReader(new FileReader(pauseFile))) {
                pausetree = new StringPredictionTree(treeReader, pauseFeatureDefinition);
            }
        } else {
            pausetree = null;
        }
    }

    /**
     * Writes durations onto all {@code ph} and {@code boundary} elements of every sentence.
     *
     * <p>Phones get {@code d} (milliseconds, truncated to int) and {@code end} (running sum of the
     * sentence's durations in seconds, pauses included). Boundaries get {@code duration}
     * (milliseconds, truncated to int).
     *
     * @param d input document of type ALLOPHONES
     * @return a DURATIONS data object wrapping the same (modified) document
     * @throws NullPointerException if neither the voice nor this module provides a duration graph
     * @throws Exception propagated from voice lookup or prediction
     */
    @Override
    public MaryData process(MaryData d) throws Exception {
        Document doc = d.getDocument();
        NodeIterator sentenceIt = MaryDomUtils.createNodeIterator((Node) doc, "s");
        Element sentence;
        while ((sentence = (Element) sentenceIt.nextNode()) != null) {
            // Resolve the voice: enclosing <voice> element, then the data's default voice,
            // then the default voice for the document language.
            Element voice = (Element) MaryDomUtils.getAncestor((Node) sentence, "voice");
            Voice maryVoice = Voice.getVoice(voice);
            if (maryVoice == null) {
                maryVoice = d.getDefaultVoice();
            }
            if (maryVoice == null) {
                Locale locale = MaryUtils.string2locale(doc.getDocumentElement().getAttribute("xml:lang"));
                maryVoice = Voice.getDefaultVoice(locale);
            }

            // A voice-specific duration graph takes precedence over the module-level one.
            DirectedGraph currentCart = cart;
            TargetFeatureComputer currentFeatureComputer = featureComputer;
            DirectedGraph voiceCart = maryVoice != null ? maryVoice.getDurationGraph() : null;
            if (voiceCart != null) {
                currentCart = voiceCart;
                logger.debug("Using voice duration graph");
                FeatureDefinition voiceFeatDef = voiceCart.getFeatureDefinition();
                currentFeatureComputer = FeatureRegistry.getTargetFeatureComputer(featureProcessorManager,
                        voiceFeatDef.getFeatureNames());
            }
            if (currentCart == null) {
                throw new NullPointerException("No cart for predicting duration");
            }

            float end = 0.0f;
            TreeWalker tw = MaryDomUtils.createTreeWalker(sentence, "ph", "boundary");
            Element previous = null;
            Element segmentOrBoundary;
            while ((segmentOrBoundary = (Element) tw.nextNode()) != null) {
                boolean isBoundary = segmentOrBoundary.getTagName().equals("boundary");
                float durInSeconds;
                if (isBoundary) {
                    durInSeconds = enterPauseDuration(segmentOrBoundary, previous, pausetree, pauseFeatureComputer);
                } else {
                    // The CART feature vector is only needed for phones, so it is computed here.
                    String phone = UnitSelector.getPhoneSymbol(segmentOrBoundary);
                    Target t = new Target(phone, segmentOrBoundary);
                    t.setFeatureVector(currentFeatureComputer.computeFeatureVector(t));
                    float[] dur = (float[]) currentCart.interpret(t);
                    assert dur != null : "Null duration";
                    assert dur.length == 2 : "Unexpected duration length: " + dur.length;
                    // Only dur[1] is used as the duration in seconds; dur[0] is ignored
                    // (possibly a spread/stddev -- NOTE(review): confirm).
                    durInSeconds = dur[1];
                }
                end += durInSeconds;
                int durInMillis = (int) (1000.0f * durInSeconds);
                if (isBoundary) {
                    segmentOrBoundary.setAttribute("duration", String.valueOf(durInMillis));
                } else {
                    segmentOrBoundary.setAttribute("d", String.valueOf(durInMillis));
                    segmentOrBoundary.setAttribute("end", String.valueOf(end));
                }
                previous = segmentOrBoundary;
            }
        }
        MaryData output = new MaryData(outputType(), d.getLocale());
        output.setDocument(doc);
        return output;
    }

    /**
     * Determines a pause duration in seconds for a {@code boundary} element.
     *
     * <p>Priority: a parsable {@code duration} attribute (milliseconds); otherwise a prediction from
     * {@code currentPauseTree} based on the preceding {@code ph} element, capped at
     * {@value #MAX_PREDICTED_PAUSE_SECONDS} s; otherwise {@value #DEFAULT_PAUSE_SECONDS} s.
     *
     * @param boundary the boundary element; must have tag name {@code boundary}
     * @param previous the element processed just before {@code boundary}, or {@code null}
     * @param currentPauseTree pause predictor, or {@code null} to skip prediction
     * @param currentPauseFeatureComputer feature computer matching {@code currentPauseTree}
     * @return the pause duration in seconds
     * @throws IllegalArgumentException if {@code boundary} is not a {@code boundary} element
     */
    private float enterPauseDuration(Element boundary, Element previous, StringPredictionTree currentPauseTree,
            TargetFeatureComputer currentPauseFeatureComputer) {
        if (!boundary.getTagName().equals("boundary")) {
            throw new IllegalArgumentException("cannot call enterPauseDuration for non-pause element");
        }
        if (boundary.hasAttribute("duration")) {
            String explicit = boundary.getAttribute("duration");
            try {
                return Float.parseFloat(explicit) * 0.001f;
            } catch (NumberFormatException ignored) {
                // Malformed explicit value: fall back to prediction/default instead of failing.
                logger.debug("Ignoring unparsable boundary duration '" + explicit + "'");
            }
        }
        float duration = DEFAULT_PAUSE_SECONDS;
        // Prediction is only possible after a phone and with a loaded pause tree.
        if (previous == null || !previous.getTagName().equals("ph") || currentPauseTree == null) {
            return duration;
        }
        assert currentPauseFeatureComputer != null;
        Target t = new Target(previous.getAttribute("p"), previous);
        t.setFeatureVector(currentPauseFeatureComputer.computeFeatureVector(t));
        String predicted = currentPauseTree.getMostProbableString(t);
        if (predicted == null || predicted.length() < 2) {
            // Previously this threw NPE / StringIndexOutOfBoundsException; use the default instead.
            logger.debug("Unusable pause prediction '" + predicted + "', using default");
            return duration;
        }
        // Drop the last two characters (presumably a unit suffix such as "ms").
        // NOTE(review): the remaining number is then used as seconds -- confirm the tree's value format.
        String numeric = predicted.substring(0, predicted.length() - 2);
        try {
            duration = Float.parseFloat(numeric);
        } catch (NumberFormatException ignored) {
            // Keep the default duration.
        }
        if (duration > MAX_PREDICTED_PAUSE_SECONDS) {
            logger.debug("Cutting long duration to 2 s -- was " + duration);
            duration = MAX_PREDICTED_PAUSE_SECONDS;
        }
        return duration;
    }
}

