package dev.langchain4j.model.github;

import com.azure.ai.inference.ChatCompletionsAsyncClient;
import com.azure.ai.inference.ModelServiceVersion;
import com.azure.ai.inference.models.ChatCompletionsOptions;
import com.azure.ai.inference.models.ChatCompletionsResponseFormat;
import com.azure.ai.inference.models.StreamingChatChoiceUpdate;
import com.azure.ai.inference.models.StreamingChatCompletionsUpdate;
import com.azure.ai.inference.models.StreamingChatResponseMessageUpdate;
import com.azure.core.exception.HttpResponseException;
import com.azure.core.http.ProxyOptions;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.internal.ChatRequestValidationUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.ModelProvider;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ToolChoice;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.github.spi.GitHubModelsStreamingChatModelBuilderFactory;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.spi.ServiceHelper;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

/* loaded from: input_file:dev/langchain4j/model/github/GitHubModelsStreamingChatModel.class */
public class GitHubModelsStreamingChatModel implements StreamingChatModel {
    private static final Logger logger = LoggerFactory.getLogger(GitHubModelsStreamingChatModel.class);
    private ChatCompletionsAsyncClient client;
    private final String modelName;
    private final Integer maxTokens;
    private final Double temperature;
    private final Double topP;
    private final List<String> stop;
    private final Double presencePenalty;
    private final Double frequencyPenalty;
    private final Long seed;
    private final ChatCompletionsResponseFormat responseFormat;
    private final List<ChatModelListener> listeners;

    /* loaded from: input_file:dev/langchain4j/model/github/GitHubModelsStreamingChatModel$Builder.class */
    public static class Builder {
        private String endpoint;
        private ModelServiceVersion serviceVersion;
        private String gitHubToken;
        private String modelName;
        private Integer maxTokens;
        private Double temperature;
        private Double topP;
        private List<String> stop;
        private Double presencePenalty;
        private Double frequencyPenalty;
        private Duration timeout;
        private Long seed;
        private ChatCompletionsResponseFormat responseFormat;
        private Integer maxRetries;
        private ProxyOptions proxyOptions;
        private boolean logRequestsAndResponses;
        private ChatCompletionsAsyncClient client;
        private String userAgentSuffix;
        private List<ChatModelListener> listeners;
        private Map<String, String> customHeaders;

        public Builder endpoint(String str) {
            this.endpoint = str;
            return this;
        }

        public Builder serviceVersion(ModelServiceVersion modelServiceVersion) {
            this.serviceVersion = modelServiceVersion;
            return this;
        }

        public Builder gitHubToken(String str) {
            this.gitHubToken = str;
            return this;
        }

        public Builder modelName(String str) {
            this.modelName = str;
            return this;
        }

        public Builder modelName(GitHubModelsChatModelName gitHubModelsChatModelName) {
            this.modelName = gitHubModelsChatModelName.toString();
            return this;
        }

        public Builder maxTokens(Integer num) {
            this.maxTokens = num;
            return this;
        }

        public Builder temperature(Double d) {
            this.temperature = d;
            return this;
        }

        public Builder topP(Double d) {
            this.topP = d;
            return this;
        }

        public Builder stop(List<String> list) {
            this.stop = list;
            return this;
        }

        public Builder presencePenalty(Double d) {
            this.presencePenalty = d;
            return this;
        }

        public Builder frequencyPenalty(Double d) {
            this.frequencyPenalty = d;
            return this;
        }

        public Builder seed(Long l) {
            this.seed = l;
            return this;
        }

        public Builder responseFormat(ChatCompletionsResponseFormat chatCompletionsResponseFormat) {
            this.responseFormat = chatCompletionsResponseFormat;
            return this;
        }

        public Builder timeout(Duration duration) {
            this.timeout = duration;
            return this;
        }

        public Builder maxRetries(Integer num) {
            this.maxRetries = num;
            return this;
        }

        public Builder proxyOptions(ProxyOptions proxyOptions) {
            this.proxyOptions = proxyOptions;
            return this;
        }

        public Builder logRequestsAndResponses(boolean z) {
            this.logRequestsAndResponses = z;
            return this;
        }

        public Builder chatCompletionsAsyncClient(ChatCompletionsAsyncClient chatCompletionsAsyncClient) {
            this.client = chatCompletionsAsyncClient;
            return this;
        }

        public Builder userAgentSuffix(String str) {
            this.userAgentSuffix = str;
            return this;
        }

        public Builder listeners(List<ChatModelListener> list) {
            this.listeners = list;
            return this;
        }

        public Builder customHeaders(Map<String, String> map) {
            this.customHeaders = map;
            return this;
        }

        public GitHubModelsStreamingChatModel build() {
            return this.client != null ? new GitHubModelsStreamingChatModel(this.client, this.modelName, this.maxTokens, this.temperature, this.topP, this.stop, this.presencePenalty, this.frequencyPenalty, this.seed, this.responseFormat, this.listeners) : new GitHubModelsStreamingChatModel(this.endpoint, this.serviceVersion, this.gitHubToken, this.modelName, this.maxTokens, this.temperature, this.topP, this.stop, this.presencePenalty, this.frequencyPenalty, this.seed, this.responseFormat, this.timeout, this.maxRetries, this.proxyOptions, this.logRequestsAndResponses, this.listeners, this.userAgentSuffix, this.customHeaders);
        }
    }

    private GitHubModelsStreamingChatModel(ChatCompletionsAsyncClient chatCompletionsAsyncClient, String str, Integer num, Double d, Double d2, List<String> list, Double d3, Double d4, Long l, ChatCompletionsResponseFormat chatCompletionsResponseFormat, List<ChatModelListener> list2) {
        this(str, num, d, d2, list, d3, d4, l, chatCompletionsResponseFormat, list2);
        this.client = chatCompletionsAsyncClient;
    }

    private GitHubModelsStreamingChatModel(String str, ModelServiceVersion modelServiceVersion, String str2, String str3, Integer num, Double d, Double d2, List<String> list, Double d3, Double d4, Long l, ChatCompletionsResponseFormat chatCompletionsResponseFormat, Duration duration, Integer num2, ProxyOptions proxyOptions, boolean z, List<ChatModelListener> list2, String str4, Map<String, String> map) {
        this(str3, num, d, d2, list, d3, d4, l, chatCompletionsResponseFormat, list2);
        this.client = InternalGitHubModelHelper.setupChatCompletionsBuilder(str, modelServiceVersion, str2, duration, num2, proxyOptions, z, str4, map).buildAsyncClient();
    }

    private GitHubModelsStreamingChatModel(String str, Integer num, Double d, Double d2, List<String> list, Double d3, Double d4, Long l, ChatCompletionsResponseFormat chatCompletionsResponseFormat, List<ChatModelListener> list2) {
        this.modelName = ValidationUtils.ensureNotBlank(str, "modelName");
        this.maxTokens = num;
        this.temperature = d;
        this.topP = d2;
        this.stop = Utils.copyIfNotNull(list);
        this.presencePenalty = d3;
        this.frequencyPenalty = d4;
        this.seed = l;
        this.responseFormat = chatCompletionsResponseFormat;
        this.listeners = list2 == null ? Collections.emptyList() : new ArrayList<>(list2);
    }

    public void chat(ChatRequest chatRequest, final StreamingChatResponseHandler streamingChatResponseHandler) {
        ChatRequestParameters parameters = chatRequest.parameters();
        ChatRequestValidationUtils.validateParameters(parameters);
        ChatRequestValidationUtils.validate(parameters.responseFormat());
        StreamingResponseHandler<AiMessage> streamingResponseHandler = new StreamingResponseHandler<AiMessage>() { // from class: dev.langchain4j.model.github.GitHubModelsStreamingChatModel.1
            public void onNext(String str) {
                streamingChatResponseHandler.onPartialResponse(str);
            }

            public void onComplete(Response<AiMessage> response) {
                streamingChatResponseHandler.onCompleteResponse(ChatResponse.builder().aiMessage((AiMessage) response.content()).metadata(ChatResponseMetadata.builder().tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build()).build());
            }

            public void onError(Throwable th) {
                streamingChatResponseHandler.onError(th);
            }
        };
        List<ToolSpecification> list = parameters.toolSpecifications();
        if (Utils.isNullOrEmpty(list)) {
            generate(chatRequest.messages(), streamingResponseHandler);
        } else if (parameters.toolChoice() != ToolChoice.REQUIRED) {
            generate(chatRequest.messages(), list, streamingResponseHandler);
        } else {
            if (list.size() != 1) {
                throw new UnsupportedFeatureException("%s.%s is currently supported only when there is a single tool".formatted(ToolChoice.class.getSimpleName(), ToolChoice.REQUIRED.name()));
            }
            generate(chatRequest.messages(), list.get(0), streamingResponseHandler);
        }
    }

    private void generate(List<ChatMessage> list, StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        generate(list, null, null, streamingResponseHandler);
    }

    private void generate(List<ChatMessage> list, List<ToolSpecification> list2, StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        generate(list, list2, null, streamingResponseHandler);
    }

    private void generate(List<ChatMessage> list, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        generate(list, null, toolSpecification, streamingResponseHandler);
    }

    private void generate(List<ChatMessage> list, List<ToolSpecification> list2, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        ChatCompletionsOptions responseFormat = new ChatCompletionsOptions(InternalGitHubModelHelper.toAzureAiMessages(list)).setModel(this.modelName).setMaxTokens(this.maxTokens).setTemperature(this.temperature).setTopP(this.topP).setStop(this.stop).setPresencePenalty(this.presencePenalty).setFrequencyPenalty(this.frequencyPenalty).setSeed(this.seed).setResponseFormat(this.responseFormat);
        if (toolSpecification != null) {
            responseFormat.setTools(InternalGitHubModelHelper.toToolDefinitions(Collections.singletonList(toolSpecification)));
            responseFormat.setToolChoice(InternalGitHubModelHelper.toToolChoice(toolSpecification));
        }
        if (!Utils.isNullOrEmpty(list2)) {
            responseFormat.setTools(InternalGitHubModelHelper.toToolDefinitions(list2));
        }
        GitHubModelsStreamingResponseBuilder gitHubModelsStreamingResponseBuilder = new GitHubModelsStreamingResponseBuilder();
        ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(InternalGitHubModelHelper.createListenerRequest(responseFormat, list, list2), provider(), new ConcurrentHashMap());
        this.listeners.forEach(chatModelListener -> {
            try {
                chatModelListener.onRequest(chatModelRequestContext);
            } catch (Exception e) {
                logger.warn("Exception while calling model listener", e);
            }
        });
        asyncCall(streamingResponseHandler, responseFormat, gitHubModelsStreamingResponseBuilder, chatModelRequestContext);
    }

    private void handleResponseException(Throwable th, StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        if (!(th instanceof HttpResponseException)) {
            streamingResponseHandler.onError(th);
            return;
        }
        HttpResponseException httpResponseException = (HttpResponseException) th;
        logger.info("Error generating response, {}", httpResponseException.getValue());
        FinishReason contentFilterManagement = InternalGitHubModelHelper.contentFilterManagement(httpResponseException, "content_filter");
        if (contentFilterManagement == FinishReason.CONTENT_FILTER) {
            streamingResponseHandler.onComplete(Response.from(AiMessage.aiMessage(httpResponseException.getMessage()), (TokenUsage) null, contentFilterManagement));
        } else {
            streamingResponseHandler.onError(th);
        }
    }

    private void asyncCall(StreamingResponseHandler<AiMessage> streamingResponseHandler, ChatCompletionsOptions chatCompletionsOptions, GitHubModelsStreamingResponseBuilder gitHubModelsStreamingResponseBuilder, ChatModelRequestContext chatModelRequestContext) {
        Flux completeStream = this.client.completeStream(chatCompletionsOptions);
        AtomicReference atomicReference = new AtomicReference();
        AtomicReference atomicReference2 = new AtomicReference();
        completeStream.subscribe(streamingChatCompletionsUpdate -> {
            gitHubModelsStreamingResponseBuilder.append(streamingChatCompletionsUpdate);
            handle(streamingChatCompletionsUpdate, streamingResponseHandler);
            if (Utils.isNotNullOrBlank(streamingChatCompletionsUpdate.getId())) {
                atomicReference.set(streamingChatCompletionsUpdate.getId());
            }
            if (Utils.isNullOrBlank(streamingChatCompletionsUpdate.getModel())) {
                return;
            }
            atomicReference2.set(streamingChatCompletionsUpdate.getModel());
        }, th -> {
            ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(th, chatModelRequestContext.chatRequest(), provider(), chatModelRequestContext.attributes());
            this.listeners.forEach(chatModelListener -> {
                try {
                    chatModelListener.onError(chatModelErrorContext);
                } catch (Exception e) {
                    logger.warn("Exception while calling model listener", e);
                }
            });
            handleResponseException(th, streamingResponseHandler);
        }, () -> {
            Response<AiMessage> build = gitHubModelsStreamingResponseBuilder.build();
            ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext(InternalGitHubModelHelper.createListenerResponse((String) atomicReference.get(), chatCompletionsOptions.getModel(), build), chatModelRequestContext.chatRequest(), provider(), chatModelRequestContext.attributes());
            this.listeners.forEach(chatModelListener -> {
                try {
                    chatModelListener.onResponse(chatModelResponseContext);
                } catch (Exception e) {
                    logger.warn("Exception while calling model listener", e);
                }
            });
            streamingResponseHandler.onComplete(build);
        });
    }

    private static void handle(StreamingChatCompletionsUpdate streamingChatCompletionsUpdate, StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        StreamingChatResponseMessageUpdate delta;
        List choices = streamingChatCompletionsUpdate.getChoices();
        if (choices == null || choices.isEmpty() || (delta = ((StreamingChatChoiceUpdate) choices.get(0)).getDelta()) == null || delta.getContent() == null) {
            return;
        }
        streamingResponseHandler.onNext(delta.getContent());
    }

    public List<ChatModelListener> listeners() {
        return this.listeners;
    }

    public ModelProvider provider() {
        return ModelProvider.GITHUB_MODELS;
    }

    public static Builder builder() {
        Iterator it = ServiceHelper.loadFactories(GitHubModelsStreamingChatModelBuilderFactory.class).iterator();
        return it.hasNext() ? ((GitHubModelsStreamingChatModelBuilderFactory) it.next()).get() : new Builder();
    }
}
