/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.onnxruntime.engine.OrtNDManager;
import ai.djl.onnxruntime.engine.OrtSymbolBlock;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Utils;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@code OrtModel} is the ONNX Runtime implementation of {@link BaseModel}.
 *
 * <p>Loading creates an {@link OrtSession} from either a {@code .onnx} file or a raw byte stream
 * and wraps it in an {@link OrtSymbolBlock}. Dynamically constructed blocks are not supported: a
 * model whose block is already set cannot be loaded.
 */
public class OrtModel extends BaseModel {

    private static final Logger logger = LoggerFactory.getLogger(OrtModel.class);

    private static final String ONNX_EXTENSION = ".onnx";

    private final OrtEnvironment env;

    /** Default session options owned (and closed) by this model. */
    private final OrtSession.SessionOptions sessionOptions;

    /**
     * Constructs a new {@code OrtModel}.
     *
     * @param name the model name
     * @param manager the manager for the model's resources; presumably an {@code OrtNDManager},
     *     since loading casts it to one — TODO confirm with callers
     * @param env the ONNX Runtime environment used to create sessions
     */
    OrtModel(String name, NDManager manager, OrtEnvironment env) {
        super(name);
        this.manager = manager;
        this.manager.setName("ortModel");
        this.env = env;
        this.dataType = DataType.FLOAT32;
        this.sessionOptions = new OrtSession.SessionOptions();
    }

    /**
     * Loads an ONNX model from a file or a directory.
     *
     * <p>If the model directory set from {@code modelPath} is itself a regular file, that file is
     * loaded. Otherwise the model file is searched for inside the directory:
     *
     * <ul>
     *   <li>by {@code prefix} when it is non-null;
     *   <li>otherwise by the model name, then the directory name, then {@code model.onnx}.
     * </ul>
     *
     * Each candidate is tried as given and then with {@code .onnx} appended.
     *
     * @param modelPath the model file or the directory containing it
     * @param prefix the model file name (with or without {@code .onnx}), or {@code null}
     * @param options session options (see {@link #getSessionOptions(Map)}); may be {@code null}
     * @throws FileNotFoundException if no model file can be found
     * @throws IOException if the model directory cannot be accessed
     * @throws MalformedModelException if ONNX Runtime fails to create a session
     * @throws UnsupportedOperationException if a block has already been set on this model
     */
    public void load(Path modelPath, String prefix, Map<String, ?> options)
            throws IOException, MalformedModelException {
        setModelDir(modelPath);
        wasLoaded = true;
        if (block != null) {
            throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
        }
        Path modelFile =
                prefix != null
                        ? findModelFile(prefix)
                        : findModelFile(modelName, modelDir.toFile().getName(), "model.onnx");
        if (modelFile == null) {
            throw new FileNotFoundException(".onnx file not found in: " + modelPath);
        }
        try {
            OrtSession.SessionOptions ortOptions = getSessionOptions(options);
            OrtSession session = env.createSession(modelFile.toString(), ortOptions);
            block = new OrtSymbolBlock(session, (OrtNDManager) manager);
        } catch (OrtException e) {
            throw new MalformedModelException("ONNX Model cannot be loaded", e);
        }
    }

    /**
     * Loads an ONNX model from the bytes of an input stream.
     *
     * @param is the stream holding the serialized ONNX model; it is read fully
     * @param options session options (see {@link #getSessionOptions(Map)}); may be {@code null}
     * @throws IOException if the stream cannot be read or the temporary directory cannot be
     *     created
     * @throws MalformedModelException if ONNX Runtime fails to create a session
     * @throws UnsupportedOperationException if a block has already been set on this model
     */
    public void load(InputStream is, Map<String, ?> options)
            throws IOException, MalformedModelException {
        if (block != null) {
            throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
        }
        // The model comes from bytes, so there is no real model directory. An empty temporary
        // directory (deleted at JVM exit) is used instead — presumably so code expecting a
        // non-null model directory keeps working. TODO confirm.
        modelDir = Files.createTempDirectory("ort-model");
        modelDir.toFile().deleteOnExit();
        try {
            byte[] buf = Utils.toByteArray(is);
            OrtSession.SessionOptions ortOptions = getSessionOptions(options);
            OrtSession session = env.createSession(buf, ortOptions);
            block = new OrtSymbolBlock(session, (OrtNDManager) manager);
        } catch (OrtException e) {
            throw new MalformedModelException("ONNX Model cannot be loaded", e);
        }
    }

    /**
     * Resolves the ONNX model file to load.
     *
     * <p>If {@code modelDir} is itself a regular file, it is returned directly. As a side effect,
     * {@code modelDir} is replaced by the file's parent directory and {@code modelName} by the
     * file name without its {@code .onnx} extension. Otherwise each prefix is resolved against
     * {@code modelDir}, first as given and then with {@code .onnx} appended.
     *
     * @param prefixes candidate file names, tried in order; {@code null} entries are skipped
     * @return the first matching regular file, or {@code null} if none exists
     */
    private Path findModelFile(String... prefixes) {
        if (Files.isRegularFile(modelDir)) {
            Path file = modelDir;
            Path parent = file.getParent();
            // getParent() is null for a single-element relative path; fall back to the absolute
            // form so modelDir never becomes null.
            modelDir = parent != null ? parent : file.toAbsolutePath().getParent();
            String fileName = file.toFile().getName();
            modelName =
                    fileName.endsWith(ONNX_EXTENSION)
                            ? fileName.substring(0, fileName.length() - ONNX_EXTENSION.length())
                            : fileName;
            return file;
        }
        for (String prefix : prefixes) {
            if (prefix == null) {
                continue; // e.g. a model constructed without a name
            }
            Path modelFile = modelDir.resolve(prefix);
            if (Files.isRegularFile(modelFile)) {
                return modelFile;
            }
            if (!prefix.endsWith(ONNX_EXTENSION)) {
                Path withExtension = modelDir.resolve(prefix + ONNX_EXTENSION);
                if (Files.isRegularFile(withExtension)) {
                    return withExtension;
                }
            }
        }
        return null;
    }

    /**
     * Closes the model and then this model's default session options.
     *
     * <p>Session options supplied by the caller through the {@code sessionOptions} load option
     * are not closed here; presumably the caller owns them.
     */
    public void close() {
        super.close();
        try {
            sessionOptions.close();
        } catch (IllegalArgumentException ignored) {
            // Best-effort cleanup: failing to release the options must not stop the model from
            // closing. NOTE(review): confirm when ONNX Runtime raises this here.
        }
    }

    /**
     * Builds the ONNX Runtime session options for a load call.
     *
     * <p>Settings are applied in place to the returned instance. That instance is either this
     * model's default options or a caller-provided one. Option values are read through {@code
     * toString()}, so both {@code "4"} and {@code 4} are accepted. A {@code null} map behaves
     * like an empty map.
     *
     * <p>Recognized keys:
     *
     * <ul>
     *   <li>{@code sessionOptions}: an {@link OrtSession.SessionOptions} to configure instead of
     *       the model's default instance.
     *   <li>{@code interOpNumThreads}, {@code intraOpNumThreads}: integers passed to {@code
     *       setInterOpNumThreads} / {@code setIntraOpNumThreads}.
     *   <li>{@code executionMode}: name of an {@link OrtSession.SessionOptions.ExecutionMode}.
     *   <li>{@code optLevel}: name of an {@link OrtSession.SessionOptions.OptLevel}.
     *   <li>{@code memoryPatternOptimization}, {@code cpuArenaAllocator}: enabled when {@code
     *       "true"}.
     *   <li>{@code disablePerSessionThreads}: calls {@code disablePerSessionThreads()} when {@code
     *       "true"}.
     *   <li>{@code customOpLibrary}: path of a custom-op library to register. When absent, the
     *       ONNX Runtime extensions library is registered if it is on the classpath.
     *   <li>{@code profilerOutput}: passed to {@code enableProfiling}.
     *   <li>{@code ortDevice}: {@code TensorRT}, {@code ROCM} or {@code CoreML}. When absent and
     *       the manager's device is a GPU, CUDA is added.
     * </ul>
     *
     * @param options the load options, may be {@code null}
     * @return the configured session options
     * @throws OrtException if ONNX Runtime rejects a setting
     * @throws IllegalArgumentException if an option value is invalid, {@code ortDevice} is
     *     unknown, or TensorRT is requested without a GPU device
     */
    private OrtSession.SessionOptions getSessionOptions(Map<String, ?> options)
            throws OrtException {
        OrtSession.SessionOptions ortSession = sessionOptions;
        Object customOptions = options == null ? null : options.get("sessionOptions");
        if (customOptions != null) {
            ortSession = (OrtSession.SessionOptions) customOptions;
        }

        String interOpNumThreads = getStringOption(options, "interOpNumThreads");
        if (interOpNumThreads != null) {
            ortSession.setInterOpNumThreads(Integer.parseInt(interOpNumThreads));
        }
        String intraOpNumThreads = getStringOption(options, "intraOpNumThreads");
        if (intraOpNumThreads != null) {
            ortSession.setIntraOpNumThreads(Integer.parseInt(intraOpNumThreads));
        }
        String executionMode = getStringOption(options, "executionMode");
        if (executionMode != null) {
            ortSession.setExecutionMode(
                    OrtSession.SessionOptions.ExecutionMode.valueOf(executionMode));
        }
        String optLevel = getStringOption(options, "optLevel");
        if (optLevel != null) {
            ortSession.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.valueOf(optLevel));
        }
        if (Boolean.parseBoolean(getStringOption(options, "memoryPatternOptimization"))) {
            ortSession.setMemoryPatternOptimization(true);
        }
        if (Boolean.parseBoolean(getStringOption(options, "cpuArenaAllocator"))) {
            ortSession.setCPUArenaAllocator(true);
        }
        if (Boolean.parseBoolean(getStringOption(options, "disablePerSessionThreads"))) {
            ortSession.disablePerSessionThreads();
        }
        String customOpLibrary = getStringOption(options, "customOpLibrary");
        if (customOpLibrary == null) {
            customOpLibrary = getOrtxLibraryPath();
        }
        if (customOpLibrary != null) {
            ortSession.registerCustomOpLibrary(customOpLibrary);
        }
        String profilerOutput = getStringOption(options, "profilerOutput");
        if (profilerOutput != null) {
            ortSession.enableProfiling(profilerOutput);
        }
        addExecutionProvider(ortSession, getStringOption(options, "ortDevice"));
        return ortSession;
    }

    /**
     * Returns an option value as a string.
     *
     * @param options the option map, may be {@code null}
     * @param key the option key
     * @return the value's {@code toString()}, or {@code null} when the map or the value is absent
     */
    private static String getStringOption(Map<String, ?> options, String key) {
        if (options == null) {
            return null;
        }
        Object value = options.get(key);
        return value == null ? null : value.toString();
    }

    /**
     * Adds the execution provider selected by {@code ortDevice} to the session options.
     *
     * @param ortSession the options to configure
     * @param ortDevice {@code TensorRT}, {@code ROCM}, {@code CoreML}, or {@code null} to add
     *     CUDA when the manager's device is a GPU
     * @throws OrtException if ONNX Runtime cannot add the provider
     * @throws IllegalArgumentException if {@code ortDevice} is unknown, or TensorRT is requested
     *     without a GPU device
     */
    private void addExecutionProvider(OrtSession.SessionOptions ortSession, String ortDevice)
            throws OrtException {
        Device device = manager.getDevice();
        if (ortDevice == null) {
            if (device.isGpu()) {
                ortSession.addCUDA(device.getDeviceId());
            }
            return;
        }
        switch (ortDevice) {
            case "TensorRT":
                if (!device.isGpu()) {
                    throw new IllegalArgumentException("TensorRT required GPU device.");
                }
                ortSession.addTensorrt(device.getDeviceId());
                break;
            case "ROCM":
                ortSession.addROCM();
                break;
            case "CoreML":
                ortSession.addCoreML();
                break;
            default:
                throw new IllegalArgumentException("Invalid ortDevice: " + ortDevice);
        }
    }

    /**
     * Looks up the ONNX Runtime extensions library path reflectively.
     *
     * <p>The extensions package is an optional dependency. It is therefore accessed through the
     * context class loader, and any failure to find or invoke it is logged and treated as "not
     * available".
     *
     * @return the path returned by {@code OrtxPackage.getLibraryPath()}, or {@code null} if the
     *     package cannot be loaded or invoked
     */
    private String getOrtxLibraryPath() {
        ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
        try {
            Class<?> clazz = Class.forName("ai.onnxruntime.extensions.OrtxPackage", true, cl);
            Method method = clazz.getDeclaredMethod("getLibraryPath");
            return (String) method.invoke(null);
        } catch (ReflectiveOperationException | LinkageError | RuntimeException e) {
            logger.info("Onnx extension not found in classpath.");
            logger.trace("Failed to load onnx extension", e);
            return null;
        }
    }
}

