package org.nd4j.onnxruntime.runner;

import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.Generated;
import onnx.Onnx;
import org.apache.commons.io.FileUtils;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.CharPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.onnxruntime.Env;
import org.bytedeco.onnxruntime.LongVector;
import org.bytedeco.onnxruntime.MemoryInfo;
import org.bytedeco.onnxruntime.OrtAllocator;
import org.bytedeco.onnxruntime.RunOptions;
import org.bytedeco.onnxruntime.Session;
import org.bytedeco.onnxruntime.SessionOptions;
import org.bytedeco.onnxruntime.Value;
import org.bytedeco.onnxruntime.ValueVector;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.onnxruntime.runner.enums.ONNXType;
import org.nd4j.onnxruntime.util.ONNXUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/onnxruntime/runner/OnnxRuntimeRunner.class */
/**
 * Runs an ONNX model through the ONNX Runtime JavaCPP bindings, exchanging data with callers
 * as ND4J {@link INDArray} / {@link SDValue} instances.
 *
 * <p>An instance created with a {@code null} model URI holds no {@link Session}; calling
 * {@link #exec(Map)} or {@link #execValues(Map)} on it fails with an
 * {@link IllegalStateException}.
 *
 * <p>NOTE(review): the shared {@link Env} is created lazily without synchronization, so
 * constructing runners concurrently from several threads may race — confirm whether callers
 * rely on single-threaded construction.
 */
public class OnnxRuntimeRunner implements Closeable {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(OnnxRuntimeRunner.class);
    private Session session;
    private RunOptions runOptions;
    private MemoryInfo memoryInfo;
    private OrtAllocator allocator;
    private SessionOptions sessionOptions;
    // Shared by every runner; created by the first constructor call and never released here.
    private static Env env;
    private Pointer bp;
    private Onnx.ModelProto modelProto;
    private final List<Onnx.TensorProto> initializers = new ArrayList<>();
    private final List<Onnx.ValueInfoProto> inputs = new ArrayList<>();
    // Set once the native handles have been released so that close() is idempotent.
    private boolean closed;

    @Generated
    /* loaded from: input_file:org/nd4j/onnxruntime/runner/OnnxRuntimeRunner$OnnxRuntimeRunnerBuilder.class */
    public static class OnnxRuntimeRunnerBuilder {

        @Generated
        private String modelUri;

        @Generated
        OnnxRuntimeRunnerBuilder() {
        }

        @Generated
        public OnnxRuntimeRunnerBuilder modelUri(String str) {
            this.modelUri = str;
            return this;
        }

        @Generated
        public OnnxRuntimeRunner build() {
            return new OnnxRuntimeRunner(this.modelUri);
        }

        @Generated
        public String toString() {
            return "OnnxRuntimeRunner.OnnxRuntimeRunnerBuilder(modelUri=" + this.modelUri + ")";
        }
    }

    /**
     * Creates a runner and, when a model path is given, loads the model twice: once into an
     * ONNX Runtime {@link Session} and once as an {@link Onnx.ModelProto} whose graph
     * initializers and inputs are copied into this runner.
     *
     * @param str path of the ONNX model file, or {@code null} to create a runner with no
     *            session
     * @throws IllegalStateException if the model file cannot be read or parsed as a
     *                               {@code ModelProto}
     */
    public OnnxRuntimeRunner(String str) {
        if (env == null) {
            env = new Env(ONNXUtils.getOnnxLogLevelFromLogger(log), new BytePointer("nd4j-serving-onnx-session-" + UUID.randomUUID()));
            env.retainReference();
        }
        this.sessionOptions = new SessionOptions();
        // NOTE(review): 2 / 0 are raw native enum values (presumably the "extended" graph
        // optimization level and the most verbose log severity) — confirm against the
        // onnxruntime presets' constants.
        this.sessionOptions.SetGraphOptimizationLevel(2);
        this.sessionOptions.SetIntraOpNumThreads(1);
        this.sessionOptions.SetLogSeverityLevel(0);
        this.sessionOptions.retainReference();
        this.allocator = new OrtAllocator();
        this.allocator.retainReference();
        if (str != null) {
            // The native Session constructor takes a wide-character path on Windows and a
            // byte-string path elsewhere.
            this.bp = Loader.getPlatform().toLowerCase().startsWith("windows") ? new CharPointer(str) : new BytePointer(str);
            this.session = new Session(env, this.bp, this.sessionOptions);
            this.session.retainReference();
            try {
                this.modelProto = Onnx.ModelProto.parseFrom(FileUtils.readFileToByteArray(new File(str)));
            } catch (IOException e) {
                // Fail fast instead of continuing with a null modelProto, and do not leave
                // the native objects allocated above behind.
                releaseNativeResources();
                throw new IllegalStateException("Unable to read ONNX model from " + str, e);
            }
            Onnx.GraphProto graph = this.modelProto.getGraph();
            for (int i = 0; i < graph.getInitializerCount(); i++) {
                this.initializers.add(graph.getInitializer(i));
            }
            for (int i = 0; i < graph.getInputCount(); i++) {
                this.inputs.add(graph.getInput(i));
            }
        }
        this.runOptions = new RunOptions();
        // NOTE(review): (1, 0) are raw native enum values, presumably the arena allocator
        // and the default memory type — confirm.
        this.memoryInfo = MemoryInfo.CreateCpu(1, 0);
    }

    /**
     * Closes the session (if any) and releases the references held on the session options,
     * allocator and run options. Calling this method more than once has no further effect.
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        releaseNativeResources();
    }

    /**
     * Runs the model on inputs given as {@link SDValue}s.
     *
     * <p>For each session input, a single-element value whose input type is a tensor is bound
     * as a tensor; a multi-element value, or any value whose input type is a sequence, is
     * bound as a sequence.
     *
     * @param map input values keyed by the session's input names; every session input must
     *            be present
     * @return outputs keyed by output name, in session output order; tensor outputs become
     *         tensor values and all other outputs are converted as sequences
     * @throws IllegalArgumentException if an input is missing from {@code map} or has an
     *                                  empty list value
     * @throws IllegalStateException    if no model is loaded or the runner is closed
     */
    public Map<String, SDValue> execValues(Map<String, SDValue> map) {
        checkRunnable();
        long inputCount = this.session.GetInputCount();
        long outputCount = this.session.GetOutputCount();
        PointerPointer<BytePointer> inputNames = new PointerPointer<>(inputCount);
        Value inputValues = new Value(inputCount);
        for (long i = 0; i < inputCount; i++) {
            BytePointer namePointer = this.session.GetInputNameAllocated(i, this.allocator);
            inputNames.put(i, namePointer);
            String inputName = namePointer.getString();
            ONNXType inputType = ONNXUtils.getTypeForInput(this.session, i);
            SDValue supplied = map.get(inputName);
            if (supplied == null) {
                throw new IllegalArgumentException("No value supplied for model input " + inputName);
            }
            List<INDArray> arrays = supplied.getListValue();
            if (arrays.isEmpty()) {
                throw new IllegalArgumentException("Onnx Runtime does not support empty sequences! Found at input name " + inputName);
            }
            if (arrays.size() == 1 && inputType == ONNXType.ONNX_TYPE_TENSOR) {
                Value tensor = ONNXUtils.getTensor(arrays.get(0), this.memoryInfo);
                Preconditions.checkState(tensor.IsTensor(), "Input must be a tensor.");
                inputValues.position(i).put(tensor);
            } else if (arrays.size() > 1 || inputType == ONNXType.ONNX_TYPE_SEQUENCE) {
                inputValues.position(i).put(Value.CreateSequence(ONNXUtils.getSequence(arrays, this.memoryInfo)));
            } else {
                // A single array for an input that is neither a tensor nor a sequence:
                // nothing is bound for this slot (unchanged behavior), but say so.
                log.warn("Input {} has ONNX type {}; no value was bound for it", inputName, inputType);
            }
        }
        // Rewind so that Run reads the inputs starting from the first slot.
        inputValues.position(0L);
        PointerPointer<BytePointer> outputNames = collectOutputNames(outputCount);
        ValueVector results = this.session.Run(this.runOptions, inputNames, inputValues, inputCount, outputNames, outputCount);
        results.retainReference();
        Map<String, SDValue> outputs = new LinkedHashMap<>();
        for (int i = 0; i < outputCount; i++) {
            Value result = results.get(i);
            result.retainReference();
            String outputName = outputNames.get(BytePointer.class, i).getString();
            if (result.IsTensor()) {
                outputs.put(outputName, SDValue.create(ONNXUtils.getArray(result)));
            } else {
                outputs.put(outputName, SDValue.create(Arrays.asList(ONNXUtils.ndarraysFromSequence(result, this.allocator))));
            }
        }
        return outputs;
    }

    /**
     * Runs the model on tensor inputs.
     *
     * @param map input arrays keyed by the session's input names; every session input must
     *            be present
     * @return tensor outputs keyed by output name, in session output order; outputs whose
     *         type is a sequence are not added to the returned map
     * @throws IllegalArgumentException if an input is missing from {@code map}
     * @throws IllegalStateException    if no model is loaded, the runner is closed, an input
     *                                  is not a tensor, or an output is neither a tensor nor
     *                                  a sequence
     */
    public Map<String, INDArray> exec(Map<String, INDArray> map) {
        checkRunnable();
        long inputCount = this.session.GetInputCount();
        long outputCount = this.session.GetOutputCount();
        PointerPointer<BytePointer> inputNames = new PointerPointer<>(inputCount);
        Value inputValues = new Value(inputCount);
        for (int i = 0; i < inputCount; i++) {
            BytePointer namePointer = this.session.GetInputNameAllocated(i, this.allocator);
            inputNames.put(i, namePointer);
            String inputName = namePointer.getString();
            INDArray supplied = map.get(inputName);
            if (supplied == null) {
                throw new IllegalArgumentException("No value supplied for model input " + inputName);
            }
            Value tensor = ONNXUtils.getTensor(supplied, this.memoryInfo);
            Preconditions.checkState(tensor.IsTensor(), "Input must be a tensor.");
            inputValues.position(i).put(tensor);
        }
        // Rewind so that Run reads the inputs starting from the first slot.
        inputValues.position(0L);
        PointerPointer<BytePointer> outputNames = collectOutputNames(outputCount);
        ValueVector results = this.session.Run(this.runOptions, inputNames, inputValues, inputCount, outputNames, outputCount);
        results.retainReference();
        Map<String, INDArray> outputs = new LinkedHashMap<>();
        for (int i = 0; i < outputCount; i++) {
            Value result = results.get(i);
            result.retainReference();
            ONNXType outputType = ONNXUtils.getTypeForOutput(this.session, i);
            switch (outputType) {
                case ONNX_TYPE_SEQUENCE:
                    // The count is queried but unused: this method's return type has no
                    // way to carry a sequence, so the output is left out of the map.
                    result.GetCount();
                    break;
                case ONNX_TYPE_TENSOR:
                    DataBuffer buffer = ONNXUtils.getDataBuffer(result);
                    LongVector shapeVector = result.GetTensorTypeAndShapeInfo().GetShape();
                    String outputName = outputNames.get(BytePointer.class, i).getString();
                    if (shapeVector != null) {
                        long[] shape = new long[(int) shapeVector.capacity()];
                        for (int d = 0; d < shape.length; d++) {
                            shape[d] = shapeVector.get(d);
                        }
                        outputs.put(outputName, Nd4j.create(buffer).reshape(shape));
                    } else {
                        outputs.put(outputName, Nd4j.create(buffer));
                    }
                    break;
                case ONNX_TYPE_MAP:
                case ONNX_TYPE_OPAQUE:
                case ONNX_TYPE_UNKNOWN:
                case ONNX_TYPE_OPTIONAL:
                case ONNX_TYPE_SPARSE_TENSOR:
                default:
                    throw new IllegalStateException("Unable to get type " + outputType + " only accepts tensors and sequences.");
            }
        }
        return outputs;
    }

    /** Fails with an {@link IllegalStateException} when there is no usable session to run. */
    private void checkRunnable() {
        Preconditions.checkState(this.session != null, "No model loaded: this runner was created without a model URI.");
        Preconditions.checkState(!this.closed, "This runner has already been closed.");
    }

    /** Fetches every output name from the session into a native array of C strings. */
    private PointerPointer<BytePointer> collectOutputNames(long outputCount) {
        PointerPointer<BytePointer> outputNames = new PointerPointer<>(outputCount);
        for (long i = 0; i < outputCount; i++) {
            outputNames.put(i, this.session.GetOutputNameAllocated(i, this.allocator));
        }
        return outputNames;
    }

    /** Releases the native handles at most once, skipping any that were never created. */
    private void releaseNativeResources() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        if (this.session != null) {
            this.session.close();
        }
        if (this.sessionOptions != null) {
            this.sessionOptions.releaseReference();
        }
        if (this.allocator != null) {
            this.allocator.releaseReference();
        }
        if (this.runOptions != null) {
            this.runOptions.releaseReference();
        }
    }

    @Generated
    public static OnnxRuntimeRunnerBuilder builder() {
        return new OnnxRuntimeRunnerBuilder();
    }

    @Generated
    public Session getSession() {
        return this.session;
    }

    @Generated
    public RunOptions getRunOptions() {
        return this.runOptions;
    }

    @Generated
    public MemoryInfo getMemoryInfo() {
        return this.memoryInfo;
    }

    @Generated
    public OrtAllocator getAllocator() {
        return this.allocator;
    }

    @Generated
    public SessionOptions getSessionOptions() {
        return this.sessionOptions;
    }

    @Generated
    public Pointer getBp() {
        return this.bp;
    }

    @Generated
    public Onnx.ModelProto getModelProto() {
        return this.modelProto;
    }

    @Generated
    public List<Onnx.TensorProto> getInitializers() {
        return this.initializers;
    }

    @Generated
    public List<Onnx.ValueInfoProto> getInputs() {
        return this.inputs;
    }
}
