package io.modelcontextprotocol.client.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.FlowSseClient;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

/* loaded from: input_file:io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.class */
/**
 * {@link McpClientTransport} that receives server messages over a Server-Sent Events (SSE)
 * stream and sends client messages as HTTP POST requests, using the JDK {@link HttpClient}.
 *
 * <p>{@link #connect} subscribes to {@code baseUri + sseEndpoint}. Two event types are handled:
 * <ul>
 * <li>An {@code endpoint} event supplies the path (appended to {@code baseUri}) that
 * {@link #sendMessage} POSTs to. It also completes the connection future.</li>
 * <li>A {@code message} event is deserialized into a JSON-RPC message and passed to the
 * handler given to {@link #connect}.</li>
 * </ul>
 */
public class HttpClientSseClientTransport implements McpClientTransport {

    private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class);

    /** SSE event type that carries a JSON-RPC message from the server. */
    private static final String MESSAGE_EVENT_TYPE = "message";

    /** SSE event type that carries the message endpoint path to POST to. */
    private static final String ENDPOINT_EVENT_TYPE = "endpoint";

    /** SSE endpoint path used when none is configured. */
    private static final String DEFAULT_SSE_ENDPOINT = "/sse";

    /** How long {@link #sendMessage} blocks waiting for the {@code endpoint} event. */
    private static final long ENDPOINT_WAIT_SECONDS = 10L;

    private final String baseUri;

    private final String sseEndpoint;

    private final FlowSseClient sseClient;

    private final HttpClient httpClient;

    /**
     * Template for outgoing requests. It is shared with {@link FlowSseClient}, so
     * {@link #sendMessage} never mutates it directly and works on a per-request copy instead.
     */
    private final HttpRequest.Builder requestBuilder;

    protected ObjectMapper objectMapper;

    /** Set by {@link #closeGracefully()}; once true, events are ignored and sends are no-ops. */
    private volatile boolean isClosing = false;

    /** Released when the first {@code endpoint} event arrives. */
    private final CountDownLatch closeLatch = new CountDownLatch(1);

    /** Message endpoint path received from the server via the {@code endpoint} event. */
    private final AtomicReference<String> messageEndpoint = new AtomicReference<>();

    /** Future created by the most recent {@link #connect} call. */
    private final AtomicReference<CompletableFuture<Void>> connectionFuture = new AtomicReference<>();

    /**
     * Fluent builder for {@link HttpClientSseClientTransport}.
     *
     * <p>Defaults:
     * <ul>
     * <li>SSE endpoint {@code /sse}.</li>
     * <li>An HTTP/1.1 client with a 10 second connect timeout.</li>
     * <li>A fresh {@link ObjectMapper}.</li>
     * <li>A request builder with a JSON {@code Content-Type} header.</li>
     * </ul>
     */
    public static class Builder {

        private String baseUri;

        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

        private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
            .version(HttpClient.Version.HTTP_1_1)
            .connectTimeout(Duration.ofSeconds(10));

        private ObjectMapper objectMapper = new ObjectMapper();

        private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
            .header("Content-Type", HttpServletSseServerTransportProvider.APPLICATION_JSON);

        Builder() {
        }

        /**
         * Creates a builder for the given base URI.
         * @param baseUri base URI of the MCP server; must not be empty
         * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead.
         */
        @Deprecated(forRemoval = true)
        public Builder(String baseUri) {
            Assert.hasText(baseUri, "baseUri must not be empty");
            this.baseUri = baseUri;
        }

        /**
         * Sets the base URI of the MCP server.
         * @param baseUri base URI; must not be empty
         * @return this builder
         */
        Builder baseUri(String baseUri) {
            Assert.hasText(baseUri, "baseUri must not be empty");
            this.baseUri = baseUri;
            return this;
        }

        /**
         * Sets the SSE endpoint path, which is appended to the base URI.
         * @param sseEndpoint endpoint path; must not be empty
         * @return this builder
         */
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /**
         * Replaces the HTTP client builder.
         * @param clientBuilder client builder; must not be null
         * @return this builder
         */
        public Builder clientBuilder(HttpClient.Builder clientBuilder) {
            Assert.notNull(clientBuilder, "clientBuilder must not be null");
            this.clientBuilder = clientBuilder;
            return this;
        }

        /**
         * Applies a customizer to the current HTTP client builder.
         * @param clientCustomizer customizer; must not be null
         * @return this builder
         */
        public Builder customizeClient(Consumer<HttpClient.Builder> clientCustomizer) {
            Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
            clientCustomizer.accept(this.clientBuilder);
            return this;
        }

        /**
         * Replaces the request builder used as the template for all requests.
         * @param requestBuilder request builder; must not be null
         * @return this builder
         */
        public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
            Assert.notNull(requestBuilder, "requestBuilder must not be null");
            this.requestBuilder = requestBuilder;
            return this;
        }

        /**
         * Applies a customizer to the current request builder, for example to add headers.
         * @param requestCustomizer customizer; must not be null
         * @return this builder
         */
        public Builder customizeRequest(Consumer<HttpRequest.Builder> requestCustomizer) {
            Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
            requestCustomizer.accept(this.requestBuilder);
            return this;
        }

        /**
         * Sets the object mapper used for JSON (de)serialization.
         * @param objectMapper mapper; must not be null
         * @return this builder
         */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "objectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /**
         * Builds the transport. The HTTP client is built from the configured client builder.
         * @return a new transport instance
         */
        public HttpClientSseClientTransport build() {
            return new HttpClientSseClientTransport(this.clientBuilder.build(), this.requestBuilder, this.baseUri,
                    this.sseEndpoint, this.objectMapper);
        }

    }

    /**
     * Creates a transport for the given base URI with default settings.
     * @param baseUri base URI of the MCP server
     * @deprecated Use {@link #builder(String)} instead.
     */
    @Deprecated(forRemoval = true)
    public HttpClientSseClientTransport(String baseUri) {
        this(HttpClient.newBuilder(), baseUri, new ObjectMapper());
    }

    /**
     * Creates a transport that uses the default SSE endpoint.
     * @param clientBuilder HTTP client builder
     * @param baseUri base URI of the MCP server
     * @param objectMapper JSON mapper
     * @deprecated Use {@link #builder(String)} instead.
     */
    @Deprecated(forRemoval = true)
    public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) {
        this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper);
    }

    /**
     * Creates a transport with an explicit SSE endpoint and a plain request builder.
     * @param clientBuilder HTTP client builder
     * @param baseUri base URI of the MCP server
     * @param sseEndpoint SSE endpoint path
     * @param objectMapper JSON mapper
     * @deprecated Use {@link #builder(String)} instead.
     */
    @Deprecated(forRemoval = true)
    public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint,
            ObjectMapper objectMapper) {
        this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper);
    }

    /**
     * Creates a transport, applying a 10 second connect timeout to the client builder.
     * @param clientBuilder HTTP client builder (mutated: its connect timeout is set)
     * @param requestBuilder request template
     * @param baseUri base URI of the MCP server
     * @param sseEndpoint SSE endpoint path
     * @param objectMapper JSON mapper
     * @deprecated Use {@link #builder(String)} instead.
     */
    @Deprecated(forRemoval = true)
    public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder,
            String baseUri, String sseEndpoint, ObjectMapper objectMapper) {
        this(clientBuilder.connectTimeout(Duration.ofSeconds(10L)).build(), requestBuilder, baseUri, sseEndpoint,
                objectMapper);
    }

    /**
     * Primary constructor.
     * @param httpClient client used for both the SSE stream and POST requests; must not be null
     * @param requestBuilder request template; must not be null
     * @param baseUri base URI of the MCP server; must not be empty
     * @param sseEndpoint SSE endpoint path; must not be empty
     * @param objectMapper JSON mapper; must not be null
     */
    HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
            String sseEndpoint, ObjectMapper objectMapper) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.hasText(baseUri, "baseUri must not be empty");
        Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
        Assert.notNull(httpClient, "httpClient must not be null");
        Assert.notNull(requestBuilder, "requestBuilder must not be null");
        this.baseUri = baseUri;
        this.sseEndpoint = sseEndpoint;
        this.objectMapper = objectMapper;
        this.httpClient = httpClient;
        this.requestBuilder = requestBuilder;
        // NOTE(review): the same builder instance is handed to FlowSseClient. If FlowSseClient
        // mutates it (uri/headers/method), those changes become visible to the template used by
        // sendMessage. Confirm against FlowSseClient.
        this.sseClient = new FlowSseClient(this.httpClient, requestBuilder);
    }

    /**
     * Returns a new builder targeting the given base URI.
     * @param baseUri base URI of the MCP server; must not be empty
     * @return a new builder
     */
    public static Builder builder(String baseUri) {
        return new Builder().baseUri(baseUri);
    }

    /**
     * Subscribes to the SSE endpoint and routes incoming events.
     *
     * <p>The returned {@link Mono} completes in either of two cases:
     * <ul>
     * <li>It completes normally when the server sends the {@code endpoint} event.</li>
     * <li>It errors if processing an event fails with an {@link IOException}, or if the SSE
     * stream reports an error.</li>
     * </ul>
     * Events and errors are ignored once {@link #closeGracefully()} has run.
     * @param handler receives each incoming JSON-RPC message; its result is subscribed to and
     * discarded
     * @return a Mono that completes when the message endpoint is known
     */
    @Override
    public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        final CompletableFuture<Void> connectionReady = new CompletableFuture<>();
        this.connectionFuture.set(connectionReady);

        this.sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() {

            @Override
            public void onEvent(FlowSseClient.SseEvent event) {
                if (isClosing) {
                    return;
                }
                try {
                    if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
                        messageEndpoint.set(event.data());
                        // Unblock any sendMessage calls waiting for the endpoint.
                        closeLatch.countDown();
                        connectionReady.complete(null);
                    }
                    else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
                        McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper,
                                event.data());
                        handler.apply(Mono.just(message)).subscribe();
                    }
                    else {
                        logger.error("Received unrecognized SSE event type: {}", event.type());
                    }
                }
                catch (IOException e) {
                    logger.error("Error processing SSE event", e);
                    connectionReady.completeExceptionally(e);
                }
            }

            @Override
            public void onError(Throwable error) {
                if (isClosing) {
                    return;
                }
                logger.error("SSE connection error", error);
                connectionReady.completeExceptionally(error);
            }

        });

        return Mono.fromFuture(connectionReady);
    }

    /**
     * POSTs a JSON-RPC message to the message endpoint announced by the server.
     *
     * <p>This method blocks the calling thread, at assembly time, for up to 10 seconds while it
     * waits for the {@code endpoint} event. Possible outcomes:
     * <ul>
     * <li>If the transport is closing, it returns an empty Mono.</li>
     * <li>A response with status 200, 201, 202 or 206 is treated as success.</li>
     * <li>Any other status is logged rather than propagated.</li>
     * </ul>
     * @param message message to send
     * @return a Mono that completes when the HTTP response has been received
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
        if (this.isClosing) {
            return Mono.empty();
        }
        try {
            if (!this.closeLatch.await(ENDPOINT_WAIT_SECONDS, TimeUnit.SECONDS)) {
                return Mono.error(new McpError("Failed to wait for the message endpoint"));
            }
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can observe the interruption.
            Thread.currentThread().interrupt();
            return Mono.error(new McpError("Failed to wait for the message endpoint"));
        }

        String endpoint = this.messageEndpoint.get();
        if (endpoint == null) {
            return Mono.error(new McpError("No message endpoint available"));
        }

        String body;
        try {
            body = this.objectMapper.writeValueAsString(message);
        }
        catch (IOException e) {
            return this.isClosing ? Mono.empty() : Mono.error(new RuntimeException("Failed to serialize message", e));
        }

        // Copy the shared template so concurrent sends cannot overwrite each other's
        // uri/method/body on the one mutable builder instance.
        HttpRequest request = this.requestBuilder.copy()
            .uri(URI.create(this.baseUri + endpoint))
            .POST(HttpRequest.BodyPublishers.ofString(body))
            .build();

        return Mono.fromFuture(this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())
            .thenAccept(response -> {
                int status = response.statusCode();
                if (status != 200 && status != 201 && status != 202 && status != 206) {
                    logger.error("Error sending message: {}", status);
                }
            }));
    }

    /**
     * Marks the transport as closing and cancels the pending connection future, if any.
     *
     * <p>NOTE(review): this does not visibly close the SSE subscription held by
     * {@link FlowSseClient}. Confirm whether the stream keeps running after close.
     * @return a Mono that performs the close when subscribed
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            CompletableFuture<Void> pending = this.connectionFuture.get();
            if (pending != null && !pending.isDone()) {
                pending.cancel(true);
            }
        });
    }

    /**
     * Converts an arbitrary object (for example a deserialized map) into the requested type.
     * @param data source value
     * @param typeRef target type
     * @param <T> target type
     * @return the converted value
     */
    @Override
    public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
        return this.objectMapper.convertValue(data, typeRef);
    }

}
