package dev.langchain4j.community.model.xinference.client;

import dev.langchain4j.community.model.xinference.client.chat.ChatCompletionRequest;
import dev.langchain4j.community.model.xinference.client.chat.ChatCompletionResponse;
import dev.langchain4j.community.model.xinference.client.completion.CompletionRequest;
import dev.langchain4j.community.model.xinference.client.completion.CompletionResponse;
import dev.langchain4j.community.model.xinference.client.embedding.EmbeddingRequest;
import dev.langchain4j.community.model.xinference.client.embedding.EmbeddingResponse;
import dev.langchain4j.community.model.xinference.client.image.ImageRequest;
import dev.langchain4j.community.model.xinference.client.image.ImageResponse;
import dev.langchain4j.community.model.xinference.client.image.OcrRequest;
import dev.langchain4j.community.model.xinference.client.rerank.RerankRequest;
import dev.langchain4j.community.model.xinference.client.rerank.RerankResponse;
import dev.langchain4j.community.model.xinference.client.shared.StreamOptions;
import dev.langchain4j.community.model.xinference.client.utils.JsonUtil;
import dev.langchain4j.internal.Utils;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import okhttp3.Cache;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
import okhttp3.RequestBody;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;

/**
 * HTTP client for a Xinference server.
 *
 * <p>Non-streaming calls go through a Retrofit-generated {@link XinferenceApi}. Streaming
 * (chat-)completion calls are issued directly against {@code baseUrl + path} using the same
 * {@link OkHttpClient}. Instances are created via {@link #builder()}; call {@link #shutdown()}
 * to release the underlying HTTP resources.
 */
public class XinferenceClient {
    private static final Logger log = LoggerFactory.getLogger(XinferenceClient.class);

    /**
     * Content type used for raw image uploads.
     *
     * <p>{@code MediaType.parse("image")} would yield {@code null} (a media type needs a subtype),
     * which silently omits the Content-Type header. The concrete image format is not known here,
     * so the generic binary type is declared explicitly.
     */
    private static final MediaType IMAGE_MEDIA_TYPE = MediaType.get("application/octet-stream");

    private final String baseUrl;
    private final OkHttpClient okHttpClient;
    private final XinferenceApi xinferenceApi;
    private final boolean logStreamingResponses;

    /** Fluent builder for {@link XinferenceClient}. All timeouts default to 60 seconds. */
    public static class Builder {
        private String baseUrl;
        private String apiKey;
        private Duration callTimeout = Duration.ofSeconds(60);
        private Duration connectTimeout = Duration.ofSeconds(60);
        private Duration readTimeout = Duration.ofSeconds(60);
        private Duration writeTimeout = Duration.ofSeconds(60);
        private Proxy proxy;
        private boolean logRequests;
        private boolean logResponses;
        private boolean logStreamingResponses;
        private Map<String, String> customHeaders;

        /**
         * Sets the server base URL; a trailing {@code /} is appended if missing.
         *
         * @param baseUrl the base URL, e.g. {@code http://localhost:9997}
         * @return this builder
         * @throws IllegalArgumentException if {@code baseUrl} is null or blank
         */
        public Builder baseUrl(String baseUrl) {
            if (baseUrl == null || baseUrl.trim().isEmpty()) {
                throw new IllegalArgumentException("baseUrl cannot be null or empty");
            }
            this.baseUrl = baseUrl.endsWith("/") ? baseUrl : baseUrl + "/";
            return this;
        }

        /**
         * Sets the API key; when non-null, an authorization interceptor is installed.
         *
         * @param apiKey the API key, or {@code null} for none
         * @return this builder
         */
        public Builder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /**
         * Sets the overall call timeout.
         *
         * @param callTimeout the timeout, must not be null
         * @return this builder
         * @throws IllegalArgumentException if {@code callTimeout} is null
         */
        public Builder callTimeout(Duration callTimeout) {
            if (callTimeout == null) {
                throw new IllegalArgumentException("callTimeout cannot be null");
            }
            this.callTimeout = callTimeout;
            return this;
        }

        /**
         * Sets the connect timeout.
         *
         * @param connectTimeout the timeout, must not be null
         * @return this builder
         * @throws IllegalArgumentException if {@code connectTimeout} is null
         */
        public Builder connectTimeout(Duration connectTimeout) {
            if (connectTimeout == null) {
                throw new IllegalArgumentException("connectTimeout cannot be null");
            }
            this.connectTimeout = connectTimeout;
            return this;
        }

        /**
         * Sets the read timeout.
         *
         * @param readTimeout the timeout, must not be null
         * @return this builder
         * @throws IllegalArgumentException if {@code readTimeout} is null
         */
        public Builder readTimeout(Duration readTimeout) {
            if (readTimeout == null) {
                throw new IllegalArgumentException("readTimeout cannot be null");
            }
            this.readTimeout = readTimeout;
            return this;
        }

        /**
         * Sets the write timeout.
         *
         * @param writeTimeout the timeout, must not be null
         * @return this builder
         * @throws IllegalArgumentException if {@code writeTimeout} is null
         */
        public Builder writeTimeout(Duration writeTimeout) {
            if (writeTimeout == null) {
                throw new IllegalArgumentException("writeTimeout cannot be null");
            }
            this.writeTimeout = writeTimeout;
            return this;
        }

        /**
         * Configures a proxy from its type, host and port.
         *
         * @param type the proxy type
         * @param host the proxy host
         * @param port the proxy port
         * @return this builder
         */
        public Builder proxy(Proxy.Type type, String host, int port) {
            this.proxy = new Proxy(type, new InetSocketAddress(host, port));
            return this;
        }

        /**
         * Configures a proxy; {@code null} means no explicit proxy.
         *
         * @param proxy the proxy, or {@code null}
         * @return this builder
         */
        public Builder proxy(Proxy proxy) {
            this.proxy = proxy;
            return this;
        }

        /** Enables request logging. */
        public Builder logRequests() {
            return logRequests(true);
        }

        /**
         * Enables or disables request logging; {@code null} is treated as {@code false}.
         *
         * @param logRequests whether to log requests
         * @return this builder
         */
        public Builder logRequests(Boolean logRequests) {
            this.logRequests = logRequests != null && logRequests;
            return this;
        }

        /** Enables response logging. */
        public Builder logResponses() {
            return logResponses(true);
        }

        /**
         * Enables or disables response logging; {@code null} is treated as {@code false}.
         *
         * @param logResponses whether to log responses
         * @return this builder
         */
        public Builder logResponses(Boolean logResponses) {
            this.logResponses = logResponses != null && logResponses;
            return this;
        }

        /** Enables logging of streaming responses. */
        public Builder logStreamingResponses() {
            return logStreamingResponses(true);
        }

        /**
         * Enables or disables logging of streaming responses; {@code null} is treated as
         * {@code false}.
         *
         * @param logStreamingResponses whether to log streaming responses
         * @return this builder
         */
        public Builder logStreamingResponses(Boolean logStreamingResponses) {
            this.logStreamingResponses = logStreamingResponses != null && logStreamingResponses;
            return this;
        }

        /**
         * Sets extra headers added to every request. The map is copied when the client is built.
         *
         * @param customHeaders header name to value, or {@code null} for none
         * @return this builder
         */
        public Builder customHeaders(Map<String, String> customHeaders) {
            this.customHeaders = customHeaders;
            return this;
        }

        /**
         * Builds the client.
         *
         * @return a new {@link XinferenceClient}
         * @throws NullPointerException if {@link #baseUrl(String)} was never called
         */
        public XinferenceClient build() {
            return new XinferenceClient(this);
        }
    }

    private XinferenceClient(Builder builder) {
        // Fail fast with a descriptive message rather than deep inside Retrofit.
        this.baseUrl = Objects.requireNonNull(builder.baseUrl, "baseUrl must be set");

        OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder()
                .callTimeout(builder.callTimeout)
                .connectTimeout(builder.connectTimeout)
                .readTimeout(builder.readTimeout)
                .writeTimeout(builder.writeTimeout);

        if (builder.apiKey != null) {
            httpClientBuilder.addInterceptor(new AuthorizationHeaderInjector(builder.apiKey));
        }

        // Copy so later changes to the caller's map do not leak into this client.
        Map<String, String> headers = new HashMap<>();
        if (builder.customHeaders != null) {
            headers.putAll(builder.customHeaders);
        }
        if (!headers.isEmpty()) {
            httpClientBuilder.addInterceptor(new GenericHeaderInjector(headers));
        }

        if (builder.proxy != null) {
            httpClientBuilder.proxy(builder.proxy);
        }

        // Logging interceptors are added after the header injectors, preserving the original order.
        if (builder.logRequests) {
            httpClientBuilder.addInterceptor(new RequestLoggingInterceptor());
        }
        if (builder.logResponses) {
            httpClientBuilder.addInterceptor(new ResponseLoggingInterceptor());
        }

        this.logStreamingResponses = builder.logStreamingResponses;
        this.okHttpClient = httpClientBuilder.build();
        this.xinferenceApi = new Retrofit.Builder()
                .baseUrl(this.baseUrl)
                .client(this.okHttpClient)
                .addConverterFactory(JacksonConverterFactory.create(JsonUtil.getObjectMapper()))
                .build()
                .create(XinferenceApi.class);
    }

    /**
     * Releases HTTP resources: shuts down the dispatcher executor, evicts pooled connections and
     * closes the response cache if one is configured. A failure to close the cache is logged.
     */
    public void shutdown() {
        okHttpClient.dispatcher().executorService().shutdown();
        okHttpClient.connectionPool().evictAll();
        Cache cache = okHttpClient.cache();
        if (cache != null) {
            try {
                cache.close();
            } catch (IOException e) {
                log.error("Failed to close cache", e);
            }
        }
    }

    /**
     * Text completion against {@code v1/completions}.
     *
     * <p>The sync/async variant sends the request with its {@code stream} flag cleared; the
     * streaming variant re-builds it with {@code stream=true} and {@code StreamOptions.of(true)}.
     *
     * @param request the completion request
     * @return an executor supporting sync, async and streaming execution
     */
    public SyncOrAsyncOrStreaming<CompletionResponse> completions(CompletionRequest request) {
        CompletionRequest syncRequest =
                CompletionRequest.builder().from(request).stream(null).build();
        return new RequestExecutor<>(
                xinferenceApi.completions(syncRequest),
                response -> response,
                okHttpClient,
                formatUrl("v1/completions"),
                () -> CompletionRequest.builder()
                        .from(request)
                        .stream(true)
                        .streamOptions(StreamOptions.of(true))
                        .build(),
                CompletionResponse.class,
                response -> response,
                logStreamingResponses);
    }

    /**
     * Chat completion against {@code v1/chat/completions}.
     *
     * <p>The sync/async variant sends the request with its {@code stream} flag cleared; the
     * streaming variant re-builds it with {@code stream=true} and {@code StreamOptions.of(true)}.
     *
     * @param request the chat completion request
     * @return an executor supporting sync, async and streaming execution
     */
    public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletions(ChatCompletionRequest request) {
        ChatCompletionRequest syncRequest =
                ChatCompletionRequest.builder().from(request).stream(null).build();
        return new RequestExecutor<>(
                xinferenceApi.chatCompletions(syncRequest),
                response -> response,
                okHttpClient,
                formatUrl("v1/chat/completions"),
                () -> ChatCompletionRequest.builder()
                        .from(request)
                        .stream(true)
                        .streamOptions(StreamOptions.of(true))
                        .build(),
                ChatCompletionResponse.class,
                response -> response,
                logStreamingResponses);
    }

    /**
     * Computes embeddings.
     *
     * @param request the embedding request
     * @return a sync/async executor
     */
    public SyncOrAsync<EmbeddingResponse> embeddings(EmbeddingRequest request) {
        return new RequestExecutor<>(xinferenceApi.embeddings(request), response -> response);
    }

    /**
     * Reranks documents.
     *
     * @param request the rerank request
     * @return a sync/async executor
     */
    public SyncOrAsync<RerankResponse> rerank(RerankRequest request) {
        return new RequestExecutor<>(xinferenceApi.rerank(request), response -> response);
    }

    /**
     * Generates images from a JSON request.
     *
     * @param request the image request
     * @return a sync/async executor
     */
    public SyncOrAsync<ImageResponse> generations(ImageRequest request) {
        return new RequestExecutor<>(xinferenceApi.generations(request), response -> response);
    }

    /**
     * Creates variations of an image, sent as a multipart form.
     *
     * @param request the image parameters
     * @param image the raw source image bytes
     * @return a sync/async executor
     */
    public SyncOrAsync<ImageResponse> variations(ImageRequest request, byte[] image) {
        MultipartBody.Builder form = toMultipartBuilder(request);
        form.addFormDataPart("image", "image", RequestBody.create(image, IMAGE_MEDIA_TYPE));
        return new RequestExecutor<>(xinferenceApi.variations(form.build()), response -> response);
    }

    /**
     * Inpaints an image using a mask, sent as a multipart form.
     *
     * @param request the image parameters
     * @param image the raw source image bytes
     * @param maskImage the raw mask image bytes
     * @return a sync/async executor
     */
    public SyncOrAsync<ImageResponse> inpainting(ImageRequest request, byte[] image, byte[] maskImage) {
        MultipartBody.Builder form = toMultipartBuilder(request);
        form.addFormDataPart("image", "image", RequestBody.create(image, IMAGE_MEDIA_TYPE));
        form.addFormDataPart("mask_image", "mask_image", RequestBody.create(maskImage, IMAGE_MEDIA_TYPE));
        return new RequestExecutor<>(xinferenceApi.inpainting(form.build()), response -> response);
    }

    /**
     * Runs OCR on an image, sent as a multipart form; the response body is returned as a string.
     *
     * @param request the OCR request (model, image bytes, optional kwargs)
     * @return a sync/async executor
     */
    public SyncOrAsync<String> ocr(OcrRequest request) {
        MultipartBody.Builder form = new MultipartBody.Builder()
                .setType(MultipartBody.FORM)
                .addFormDataPart("model", request.getModel())
                .addFormDataPart("image", "image", RequestBody.create(request.getImage(), IMAGE_MEDIA_TYPE));
        if (Utils.isNotNullOrBlank(request.getKwargs())) {
            form.addFormDataPart("kwargs", request.getKwargs());
        }
        return new RequestExecutor<>(xinferenceApi.ocr(form.build()), response -> response);
    }

    /** Resolves an endpoint path against the base URL (which always ends with {@code /}). */
    private String formatUrl(String path) {
        return baseUrl + path;
    }

    /**
     * Builds the common multipart form fields of an image request.
     *
     * <p>OkHttp's {@code addFormDataPart} does not accept null values, so every field other than
     * {@code model} is only added when present.
     */
    private static MultipartBody.Builder toMultipartBuilder(ImageRequest request) {
        MultipartBody.Builder form = new MultipartBody.Builder()
                .setType(MultipartBody.FORM)
                .addFormDataPart("model", request.getModel());
        if (request.getPrompt() != null) {
            form.addFormDataPart("prompt", request.getPrompt());
        }
        if (request.getResponseFormat() != null) {
            form.addFormDataPart("response_format", request.getResponseFormat().getValue());
        }
        if (Utils.isNotNullOrBlank(request.getNegativePrompt())) {
            form.addFormDataPart("negative_prompt", request.getNegativePrompt());
        }
        if (Objects.nonNull(request.getN())) {
            form.addFormDataPart("n", String.valueOf(request.getN()));
        }
        if (Utils.isNotNullOrBlank(request.getSize())) {
            form.addFormDataPart("size", request.getSize());
        }
        if (Utils.isNotNullOrBlank(request.getKwargs())) {
            form.addFormDataPart("kwargs", request.getKwargs());
        }
        return form;
    }

    /**
     * Creates a new builder.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
