package org.nd4j.compression.impl;

import lombok.Generated;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.NDArrayCompressor;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/compression/impl/AbstractCompressor.class */
/**
 * Base class for {@link NDArrayCompressor} implementations.
 *
 * <p>Subclasses supply the codec-specific operations ({@link #compress(DataBuffer)},
 * {@link #decompress(DataBuffer, DataType)} and {@link #compressPointer}). This class wires them
 * into the array-level compress/decompress methods and the primitive-array entry points.
 */
public abstract class AbstractCompressor implements NDArrayCompressor {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(AbstractCompressor.class);

    /**
     * Returns a compressed copy of {@code array}; the argument itself is not modified.
     *
     * @param array the array to compress
     * @return a duplicate of {@code array} (same ordering) whose buffer has been replaced by the
     *     compressed buffer and which is marked as compressed
     */
    public INDArray compress(INDArray array) {
        INDArray copy = array.dup(array.ordering());
        // Commit queued executioner work (including the dup above) before the buffer is read.
        Nd4j.getExecutioner().commit();
        copy.setData(compress(copy.data()));
        copy.markAsCompressed(true);
        return copy;
    }

    /**
     * Default configuration hook: ignores all arguments. Subclasses may override.
     *
     * @param vars codec-specific configuration values (unused here)
     */
    public void configure(Object... vars) {
    }

    /**
     * Compresses {@code array} in place: its buffer is replaced by the compressed buffer and the
     * array is marked as compressed.
     *
     * @param array the array to compress; must not be a view
     * @throws UnsupportedOperationException if {@code array} is a view
     */
    public void compressi(INDArray array) {
        if (array.isView()) {
            throw new UnsupportedOperationException("Impossible to apply inplace compression on View");
        }
        // Same synchronization point as compress(INDArray): commit queued executioner work before
        // this array's buffer is handed to the codec.
        Nd4j.getExecutioner().commit();
        array.setData(compress(array.data()));
        array.markAsCompressed(true);
    }

    /**
     * Decompresses {@code array} in place. Does nothing if the array is not marked as compressed.
     *
     * @param array the array to decompress
     */
    public void decompressi(INDArray array) {
        if (!array.isCompressed()) {
            return;
        }
        DataBuffer compressed = array.data();
        // Decode before touching the array's state: if the codec throws, the array keeps both its
        // compressed flag and its compressed buffer instead of ending up inconsistent.
        DataBuffer restored =
                decompress(compressed, compressed.getCompressionDescriptor().getOriginalDataType());
        array.markAsCompressed(false);
        array.setData(restored);
    }

    /**
     * Returns a decompressed version of {@code array}.
     *
     * @param array the array to decompress
     * @return {@code array} itself if it is not marked as compressed; otherwise a new array built
     *     from the decompressed buffer and {@code array}'s shape-info buffer
     */
    public INDArray decompress(INDArray array) {
        if (!array.isCompressed()) {
            return array;
        }
        DataBuffer compressed = array.data();
        DataBuffer restored =
                decompress(compressed, compressed.getCompressionDescriptor().getOriginalDataType());
        return Nd4j.createArrayFromShapeBuffer(restored, array.shapeInfoDataBuffer());
    }

    /**
     * Decodes a compressed buffer.
     *
     * @param buffer the compressed buffer
     * @param dataType the data type to decode into (callers in this class pass the original data
     *     type recorded in the buffer's compression descriptor)
     * @return the decompressed buffer
     */
    public abstract DataBuffer decompress(DataBuffer buffer, DataType dataType);

    /**
     * Encodes {@code buffer} with this codec.
     *
     * @param buffer the buffer to compress
     * @return the compressed buffer
     */
    public abstract DataBuffer compress(DataBuffer buffer);

    /**
     * Maps a floating-point {@link DataType} to its {@link DataTypeEx} counterpart.
     *
     * @param dataType the type to convert
     * @return {@code FLOAT16}, {@code FLOAT} or {@code DOUBLE} for {@code HALF}, {@code FLOAT} or
     *     {@code DOUBLE} respectively
     * @throws IllegalStateException for any other type, including {@code null}
     */
    protected static DataTypeEx convertType(DataType dataType) {
        if (dataType != null) {
            switch (dataType) {
                case HALF:
                    return DataTypeEx.FLOAT16;
                case FLOAT:
                    return DataTypeEx.FLOAT;
                case DOUBLE:
                    return DataTypeEx.DOUBLE;
                default:
                    break;
            }
        }
        throw new IllegalStateException("Unknown dataType: [" + dataType + "]");
    }

    /**
     * Returns the {@link DataTypeEx} for the global default data type ({@code Nd4j.dataType()}).
     *
     * @throws IllegalStateException if the global type is not HALF, FLOAT or DOUBLE
     */
    protected DataTypeEx getGlobalTypeEx() {
        return convertType(Nd4j.dataType());
    }

    /**
     * Returns the {@link DataTypeEx} for {@code buffer}'s data type.
     *
     * @throws IllegalStateException if the buffer's type is not HALF, FLOAT or DOUBLE
     */
    public static DataTypeEx getBufferTypeEx(DataBuffer buffer) {
        return convertType(buffer.dataType());
    }

    /**
     * Compresses {@code data} as a {@code [1, data.length]} row vector in the default order
     * ({@code Nd4j.order()}).
     */
    public INDArray compress(float[] data) {
        return compress(data, new int[] {1, data.length}, Nd4j.order().charValue());
    }

    /**
     * Compresses {@code data} as a {@code [1, data.length]} row vector in the default order
     * ({@code Nd4j.order()}).
     */
    public INDArray compress(double[] data) {
        return compress(data, new int[] {1, data.length}, Nd4j.order().charValue());
    }

    /**
     * Compresses {@code data} into an array of the given shape and order.
     *
     * @param data the values to compress
     * @param shape target shape; its element count must equal {@code data.length}
     * @param order target ordering character
     * @return a compressed array with FLOAT shape information
     * @throws IllegalArgumentException if {@code shape} does not describe {@code data.length} elements
     */
    public INDArray compress(float[] data, int[] shape, char order) {
        checkShapeMatchesLength(shape, data.length);
        FloatPointer pointer = new FloatPointer(data);
        CompressedDataBuffer buffer = compressPointer(DataTypeEx.FLOAT, pointer, data.length, Float.BYTES);
        DataBuffer shapeInfo = (DataBuffer) Nd4j.getShapeInfoProvider()
                .createShapeInformation(ArrayUtil.toLongArray(shape), order, DataType.FLOAT)
                .getFirst();
        return Nd4j.createArrayFromShapeBuffer(buffer, shapeInfo);
    }

    /**
     * Compresses {@code data} into an array of the given shape and order.
     *
     * @param data the values to compress
     * @param shape target shape; its element count must equal {@code data.length}
     * @param order target ordering character
     * @return a compressed array with DOUBLE shape information
     * @throws IllegalArgumentException if {@code shape} does not describe {@code data.length} elements
     */
    public INDArray compress(double[] data, int[] shape, char order) {
        checkShapeMatchesLength(shape, data.length);
        DoublePointer pointer = new DoublePointer(data);
        CompressedDataBuffer buffer = compressPointer(DataTypeEx.DOUBLE, pointer, data.length, Double.BYTES);
        DataBuffer shapeInfo = (DataBuffer) Nd4j.getShapeInfoProvider()
                .createShapeInformation(ArrayUtil.toLongArray(shape), order, DataType.DOUBLE)
                .getFirst();
        return Nd4j.createArrayFromShapeBuffer(buffer, shapeInfo);
    }

    /** Throws IllegalArgumentException if the element count of {@code shape} differs from {@code length}. */
    private static void checkShapeMatchesLength(int[] shape, int length) {
        long elements = 1;
        for (int dim : shape) {
            elements *= dim;
        }
        if (elements != length) {
            throw new IllegalArgumentException("Shape " + java.util.Arrays.toString(shape)
                    + " describes " + elements + " elements, but " + length + " values were supplied");
        }
    }

    /**
     * Compresses raw native data.
     *
     * @param srcType element type of the source data
     * @param srcPointer pointer to the source data
     * @param length number of elements (not bytes) to compress
     * @param elementSize size of one element in bytes
     * @return the compressed buffer
     */
    protected abstract CompressedDataBuffer compressPointer(DataTypeEx srcType, Pointer srcPointer, int length, int elementSize);
}
