/*
 * Decompiled with CFR 0.152.
 */
package com.st.stellar.components.ai.commands;

import com.st.stellar.ai.aiComponent.AIComponent;
import com.st.stellar.ai.aiComponent.CompressionType;
import com.st.stellar.ai.aiComponent.ModeType;
import com.st.stellar.ai.aiComponent.Network;
import com.st.stellar.ai.aiComponent.NetworkType;
import com.st.stellar.ai.aiComponent.OptimizationType;
import com.st.stellar.ai.aiComponent.impl.AIComponentImpl;
import com.st.stellar.ai.utils.Utils;
import com.st.stellar.components.ai.commands.AICommand;
import com.st.stellar.components.ai.commands.AINetParam;
import com.st.stellar.components.ai.generator.AICodeGenerator;
import com.st.stellar.components.ai.generator.AIGenerator;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.eclipse.core.resources.IProject;
import org.eclipse.emf.common.util.BasicMonitor;
import org.eclipse.emf.common.util.EList;
import org.eclipse.emf.common.util.Monitor;
import org.eclipse.emf.ecore.EObject;

/**
 * Lazily created, process-wide holder of the AI tool location, the current
 * application model and its project, used to run {@link AICommand}s over the
 * networks of that model.
 *
 * <p>The instance is created on first call to {@link #getInstance()} without any
 * synchronization, and the collected input/output sizes are kept in static lists;
 * nothing in this class makes it safe for concurrent use.
 */
public class AIParam {
    /** Suffix of the per-network JSON report read from the {@code cfg} folder. */
    private static final String C_INFO_SUFFIX = "_c_info.json";
    /** Name of the validation header rewritten by {@link #aiMNetworkSet}. */
    private static final String VALIDATE_HEADER = "app_stellar-studio-ai.h";
    /** Name of the scratch file the rewritten validation header is first written to. */
    private static final String VALIDATE_HEADER_TMP = "app_stellar-studio-ai_temp.h";
    /** Exact line at which scanning of a flagged section of the JSON report stops. */
    private static final String JSON_SECTION_END = "    },";

    private static AIParam _instance = null;
    /** Per-index maximum of the input sizes collected since the last non-version command started. */
    private static final List<Integer> inputSizes = new ArrayList<Integer>();
    /** Per-index maximum of the output sizes collected since the last non-version command started. */
    private static final List<Integer> outputSizes = new ArrayList<Integer>();

    private final String _aiTool;
    private String _workingDir;
    private AIComponentImpl _appModel;
    private IProject _project;

    private AIParam() {
        String aiStaiComponentPath = Utils.getPath("com.st.stellar.ai.stai.win32", "");
        aiStaiComponentPath = Utils.normalizeSlashes(aiStaiComponentPath);
        this._aiTool = aiStaiComponentPath + "windows/stedgeai.exe";
        this._appModel = null;
        this._project = null;
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the single {@code AIParam} instance
     */
    public static AIParam getInstance() {
        if (_instance == null) {
            _instance = new AIParam();
        }
        return _instance;
    }

    /**
     * @return the path of the AI tool executable computed in the constructor
     */
    public String getAITool() {
        return this._aiTool;
    }

    /**
     * Runs the given command.
     *
     * <p>A command whose {@code toString()} contains {@code "VersionCommand"} is executed
     * once with an all-empty {@link AINetParam}. Any other command is executed once per
     * enabled network of the application model set through {@link #setAppModel}; the
     * collected input/output sizes are reset first.
     *
     * <p>The failure state is sticky: once any network fails, the remaining enabled
     * networks are no longer executed; they are only checked for a missing model path
     * (which is still reported on the model's error stream).
     *
     * @param aiCommand the command to execute
     * @return {@code true} if every executed step succeeded, {@code false} otherwise
     */
    public boolean launchCommand(AICommand aiCommand) {
        String commandName = aiCommand.toString();
        if (commandName.contains("VersionCommand")) {
            return aiCommand.execute(new AINetParam(null, null, null, null, null, false, false, false, false, false, false, null, false, null));
        }

        boolean isGenerate = commandName.contains("GenerateCommand");
        boolean isAnalyze = commandName.contains("AnalyzeCommand");
        boolean status = true;
        inputSizes.clear();
        outputSizes.clear();

        EList<Network> aiNetworkList = this._appModel.getNetworks();
        for (Network network : aiNetworkList) {
            if (!network.isEnabled()) {
                continue;
            }
            String modelPath = network.getFilePath();
            if (modelPath == null || modelPath.isEmpty()) {
                this._appModel.getErr().println(network.getName() + ": The network is enabled but the input model path has been not selected!!! Network processing stopped.");
                status = false;
            }
            // Deliberately not a break: later networks still get the path check above.
            if (!status) {
                continue;
            }
            if (!this.processNetwork(aiCommand, network, isGenerate, isAnalyze)) {
                status = false;
            }
        }
        return status;
    }

    /**
     * Validates the custom-layer settings of one network, executes the command on it and
     * performs the command-specific post-processing.
     *
     * @return {@code false} if validation, execution or post-processing failed
     */
    private boolean processNetwork(AICommand aiCommand, Network network, boolean isGenerate, boolean isAnalyze) {
        String name = network.getName();
        NetworkType type = network.getType();
        CompressionType compression = network.getCompression();
        OptimizationType optimization = network.getOptimization();
        String modelPath = network.getFilePath();
        boolean allocInputs = network.getAdvancedSettings().isAllocateInputs();
        boolean allocOutputs = network.getAdvancedSettings().isAllocateOutputs();
        boolean splitWeights = network.getAdvancedSettings().isSplitWeights();
        boolean allocActivations = network.getAdvancedSettings().isAllocateActivations();
        boolean allocStates = network.getAdvancedSettings().isAllocateStates();
        boolean classifier = network.getAdvancedSettings().isClassifier();
        String extraOptions = network.getAdvancedSettings().getExtraCommandLineOptions();
        boolean customLayerEnabled = network.getCustomLayerSettings().isEnable();
        String customLayerJson = network.getCustomLayerSettings().getCustomLayerJsonFile();

        if (customLayerEnabled) {
            if (type != NetworkType.KERAS) {
                this._appModel.getErr().println(name + ": The Custom Layer Support is enabled but the network type is not a Keras model!!! Network processing stopped.");
                return false;
            }
            if (customLayerJson == null || customLayerJson.isEmpty()) {
                this._appModel.getErr().println(name + ": The Custom Layer Support is enabled but the custom layer json file has been not selected!!! Network processing stopped.");
                return false;
            }
        }

        if (!aiCommand.execute(new AINetParam(name, type, compression, optimization, modelPath, allocInputs, allocOutputs, splitWeights, allocActivations, allocStates, classifier, extraOptions, customLayerEnabled, customLayerJson))) {
            return false;
        }

        if (isGenerate) {
            // Only the network selected for validation feeds the validation header.
            // Compared literally: the network name must not be interpreted as a regex.
            if (!this._appModel.getValidate().isEnabled()
                    || name == null
                    || !name.equals(this._appModel.getValidate().getNetworkToValidate())) {
                return true;
            }
            return this.aiMNetworkGet(name, type) && this.aiMNetworkSet(name, allocInputs, allocOutputs);
        }
        if (isAnalyze) {
            network.setMacc(this.aiMNetworkGetMacc(name));
        }
        return true;
    }

    /**
     * @return the project location string captured by {@link #setAppModel}
     */
    public String getWorkingDir() {
        return this._workingDir;
    }

    /**
     * Stores the application model and derives the project and working directory from it.
     *
     * @param model the application model to operate on
     */
    public void setAppModel(AIComponentImpl model) {
        this._appModel = model;
        this._project = Utils.getProject((EObject)model);
        this._workingDir = this._project.getLocation().toString();
    }

    /**
     * @return the project resolved from the current application model
     */
    public IProject getProject() {
        return this._project;
    }

    /**
     * @return the current application model, or {@code null} if none was set
     */
    public AIComponent getAppModel() {
        return this._appModel;
    }

    /** Builds {@code <workingDir>/<componentName>/cfg/<fileName>} with the platform separator. */
    private String cfgFilePath(String fileName) {
        String folderName = this.getAppModel().getName();
        return this.getWorkingDir() + File.separator + folderName + File.separator + "cfg" + File.separator + fileName;
    }

    /**
     * Sums the digits found on every {@code macc":} line of the network's JSON report.
     *
     * @param networkName name of the network whose report is read
     * @return {@code "<total> MACC"}, or {@code "-"} if the report is missing or unreadable
     */
    private String aiMNetworkGetMacc(String networkName) {
        String jsonFile = this.cfgFilePath(networkName + C_INFO_SUFFIX);
        if (!new File(jsonFile).exists()) {
            return "-";
        }
        // long: the sum of all matching lines can exceed the int range.
        long total = 0L;
        try (BufferedReader reader = new BufferedReader(new FileReader(jsonFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.contains("macc\":")) {
                    continue;
                }
                String digits = line.replaceAll("[^0-9]+", "");
                if (digits.isEmpty()) {
                    // No numeric value on this line; nothing to add.
                    continue;
                }
                total += Long.parseLong(digits, 10);
            }
        }
        catch (IOException | NumberFormatException e) {
            e.printStackTrace();
            return "-";
        }
        return Long.toString(total).concat(" MACC");
    }

    /**
     * Merges the input and output sizes of the network's JSON report into the shared
     * size lists, keeping the largest value seen at each index.
     *
     * @param networkName name of the network whose report is read
     * @param networkType unused
     * @return {@code true} only if the report exists, is readable and contains both an
     *         inputs and an outputs section
     */
    private boolean aiMNetworkGet(String networkName, NetworkType networkType) {
        String jsonFile = this.cfgFilePath(networkName + C_INFO_SUFFIX);
        if (!new File(jsonFile).exists()) {
            return false;
        }
        try {
            // The outputs pass is skipped when no inputs section was found.
            return AIParam.mergeSectionSizes(jsonFile, "STAI_FLAG_INPUTS", inputSizes)
                    && AIParam.mergeSectionSizes(jsonFile, "STAI_FLAG_OUTPUTS", outputSizes);
        }
        catch (IOException | NumberFormatException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Scans {@code jsonFile} for lines containing {@code sectionFlag}; within each such
     * section (up to the section-end line) every {@code "size_bytes":} value is merged
     * into {@code sizes} as a per-index maximum. The index runs across all matching
     * sections of the file.
     *
     * @return {@code true} if at least one line containing {@code sectionFlag} was found
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a {@code size_bytes} line holds no parsable integer
     */
    private static boolean mergeSectionSizes(String jsonFile, String sectionFlag, List<Integer> sizes) throws IOException {
        boolean found = false;
        int index = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(jsonFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.contains(sectionFlag)) {
                    continue;
                }
                found = true;
                while ((line = reader.readLine()) != null && !JSON_SECTION_END.equals(line)) {
                    if (!line.contains("\"size_bytes\":")) {
                        continue;
                    }
                    int size = Integer.parseInt(line.replaceAll("[^0-9]+", ""), 10);
                    if (index < sizes.size()) {
                        if (size > sizes.get(index)) {
                            sizes.set(index, size);
                        }
                    } else {
                        sizes.add(size);
                    }
                    ++index;
                }
            }
        }
        return found;
    }

    /**
     * Regenerates the validation header and inserts, after every line containing the
     * network-number marker, the size macros and data buffer definitions for the
     * collected inputs and outputs of {@code network}.
     *
     * @param network        name of the network the macros refer to
     * @param allocateinput  when {@code true} the input size macros are defined as {@code 1}
     *                       instead of referring to the network's own size macros
     * @param allocateoutput same as {@code allocateinput}, for the outputs
     * @return {@code true} if the header was rewritten and moved into place
     */
    private boolean aiMNetworkSet(String network, boolean allocateinput, boolean allocateoutput) {
        if (inputSizes.isEmpty() || outputSizes.isEmpty()) {
            return false;
        }
        File headerFile = new File(this.cfgFilePath(VALIDATE_HEADER));
        File tempFile = new File(this.cfgFilePath(VALIDATE_HEADER_TMP));
        if (!headerFile.exists()) {
            return false;
        }
        try {
            // Start from a freshly generated header.
            headerFile.delete();
            AICodeGenerator generator = new AICodeGenerator();
            AIComponent comp = this.getAppModel();
            String conf = generator.generateValidateAIHeader(comp);
            AIGenerator.generateText(comp.getName() + File.separator + "cfg", VALIDATE_HEADER, conf, (Monitor)new BasicMonitor());

            boolean legacy = this.getAppModel().getApi().getConfiguration() == ModeType.LEGACY;
            String marker = legacy ? "AI_MNETWORK_NUMBER" : "STAI_MNETWORK_NUMBER";
            // Macro names must not depend on the default locale (e.g. Turkish dotless i).
            String networkUpper = network.toUpperCase(java.util.Locale.ROOT);
            boolean foundMarker = false;

            try (BufferedReader reader = new BufferedReader(new FileReader(headerFile));
                 BufferedWriter writer = new BufferedWriter(new FileWriter(tempFile))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    AIParam.writeLine(writer, line);
                    if (!line.contains(marker)) {
                        continue;
                    }
                    foundMarker = true;
                    AIParam.writeDataSection(writer, legacy, networkUpper, "IN", "in", inputSizes.size(), allocateinput);
                    // Blank line separating the input definitions from the output ones.
                    writer.newLine();
                    AIParam.writeDataSection(writer, legacy, networkUpper, "OUT", "out", outputSizes.size(), allocateoutput);
                }
            }

            // Both streams are closed here, so the files can be deleted/renamed.
            headerFile.delete();
            if (!foundMarker) {
                // NOTE(review): the regenerated header is removed and the temp copy is
                // left behind, as before — confirm this is the intended failure state.
                return false;
            }
            return tempFile.renameTo(headerFile);
        }
        catch (IOException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Writes the size macros, the count macro and the {@code DEF_DATA_*} buffer
     * definitions for one direction.
     *
     * @param legacy       selects the {@code AI_}/{@code ai_i8} flavour instead of {@code STAI_}/{@code uint8_t}
     * @param networkUpper upper-cased network name used inside macro names
     * @param dirUpper     {@code "IN"} or {@code "OUT"}
     * @param dirLower     {@code "in"} or {@code "out"}
     * @param count        number of buffers to declare
     * @param allocated    when {@code true} each size macro is defined as {@code 1}
     */
    private static void writeDataSection(BufferedWriter writer, boolean legacy, String networkUpper, String dirUpper, String dirLower, int count, boolean allocated) throws IOException {
        String prefix = legacy ? "AI" : "STAI";
        String elementType = legacy ? "ai_i8" : "uint8_t";
        String pointerType = legacy ? "ai_i8*" : "stai_ptr";

        for (int n = 1; n <= count; ++n) {
            String value = allocated ? "1" : prefix + "_" + networkUpper + "_" + dirUpper + "_" + n + "_SIZE_BYTES";
            AIParam.writeLine(writer, "#define " + prefix + "_MNETWORK_" + dirUpper + "_" + n + "_SIZE_BYTES " + value);
        }
        AIParam.writeLine(writer, "#define " + prefix + "_MNETWORK_" + dirUpper + "_NUM " + prefix + "_" + networkUpper + "_" + dirUpper + "_NUM");
        AIParam.writeLine(writer, "#define DEF_DATA_" + dirUpper + " \\");
        for (int n = 1; n <= count; ++n) {
            AIParam.writeLine(writer, "  " + prefix + "_ALIGNED(32) static " + elementType + " data_" + dirLower + "_" + n + "[" + prefix + "_MNETWORK_" + dirUpper + "_" + n + "_SIZE_BYTES]; \\");
        }
        AIParam.writeLine(writer, "  " + pointerType + " data_" + dirLower + "s[] = { \\");
        for (int n = 1; n <= count; ++n) {
            AIParam.writeLine(writer, "    data_" + dirLower + "_" + n + ", \\");
        }
        AIParam.writeLine(writer, "  }; \\");
    }

    /** Writes {@code text} followed by the platform line separator. */
    private static void writeLine(BufferedWriter writer, String text) throws IOException {
        writer.write(text);
        writer.newLine();
    }
}

