package io.modelcontextprotocol.client.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.ClientMcpTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.SynchronousSink;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;

/* loaded from: input_file:io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.class */
/**
 * {@link ClientMcpTransport} implementation that receives messages over a Server-Sent
 * Events stream and sends messages with HTTP POST requests, both issued through a
 * Spring {@link WebClient}.
 *
 * <p>Inbound flow: a GET request to {@value #SSE_ENDPOINT} yields a stream of
 * {@link ServerSentEvent}s. An event of type {@value #ENDPOINT_EVENT_TYPE} carries the URI
 * that outbound messages are POSTed to; events of type {@value #MESSAGE_EVENT_TYPE} carry
 * a JSON-RPC message as their data. Any other event type is reported as an error.
 *
 * <p>Outbound flow: {@link #sendMessage} waits until the message endpoint has been
 * received, serializes the message with the configured {@link ObjectMapper} and POSTs it
 * as JSON.
 */
public class WebFluxSseClientTransport implements ClientMcpTransport {

    private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class);

    /** SSE event type whose data is a JSON-RPC message. */
    private static final String MESSAGE_EVENT_TYPE = "message";

    /** SSE event type whose data is the URI to POST outbound messages to. */
    private static final String ENDPOINT_EVENT_TYPE = "endpoint";

    /** Path (resolved by the {@link WebClient}) of the SSE stream. */
    private static final String SSE_ENDPOINT = "/sse";

    /** Type token used to decode the response body as a stream of string-valued SSE events. */
    private static final ParameterizedTypeReference<ServerSentEvent<String>> SSE_TYPE =
            new ParameterizedTypeReference<ServerSentEvent<String>>() {
            };

    private final WebClient webClient;

    /** Mapper used to (de)serialize JSON-RPC messages and to convert values in {@link #unmarshalFrom}. */
    protected ObjectMapper objectMapper;

    /** Subscription to the inbound SSE pipeline; {@code null} until {@link #connect} is called. */
    private Disposable inboundSubscription;

    /** Set once {@link #closeGracefully()} runs; suppresses retries, sends and error logging. */
    private volatile boolean isClosing = false;

    /** Receives the message endpoint URI announced by the server's endpoint event. */
    protected final Sinks.One<String> messageEndpointSink = Sinks.one();

    /**
     * Decides, per failure of the SSE stream, whether to resubscribe. Emitting a value
     * into the sink requests a retry; emitting an error ends the stream with that error.
     * Only {@link IOException} failures are retried, and never while closing.
     */
    private final BiConsumer<Retry.RetrySignal, SynchronousSink<Object>> inboundRetryHandler =
            (retrySignal, sink) -> {
                if (this.isClosing) {
                    logger.debug("SSE connection closed during shutdown");
                    sink.error(retrySignal.failure());
                } else if (retrySignal.failure() instanceof IOException) {
                    logger.debug("Retrying SSE connection after IO error");
                    sink.next(retrySignal);
                } else {
                    logger.error("Fatal SSE error, not retrying: {}", retrySignal.failure().getMessage());
                    sink.error(retrySignal.failure());
                }
            };

    /**
     * Creates a transport that uses a default-configured {@link ObjectMapper}.
     *
     * @param builder builder used to create the {@link WebClient}; must not be null
     */
    public WebFluxSseClientTransport(WebClient.Builder builder) {
        this(builder, new ObjectMapper());
    }

    /**
     * Creates a transport with the given client builder and JSON mapper.
     *
     * @param builder builder used to create the {@link WebClient}; must not be null
     * @param objectMapper mapper used for JSON (de)serialization; must not be null
     */
    public WebFluxSseClientTransport(WebClient.Builder builder, ObjectMapper objectMapper) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(builder, "WebClient.Builder must not be null");
        this.objectMapper = objectMapper;
        this.webClient = builder.build();
    }

    /**
     * Starts consuming the SSE stream and routes every decoded JSON-RPC message through
     * {@code function}.
     *
     * <p>The stream is subscribed to when this method is called, not when the returned
     * {@link Mono} is subscribed. Events are processed one at a time, in order.
     *
     * @param function transformation applied to the {@link Mono} wrapping each inbound
     * message; must not be null
     * @return a {@link Mono} that completes once the message endpoint has been received
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> function) {
        // Fail fast: without this check a null handler only surfaces inside the pipeline
        // when the first event arrives.
        Assert.notNull(function, "Message handler must not be null");
        this.inboundSubscription = eventStream()
            .concatMap(event -> Mono.just(event)
                .<McpSchema.JSONRPCMessage>handle(this::handleServerSentEvent)
                .transform(function))
            .subscribe();
        return this.messageEndpointSink.asMono().then();
    }

    /**
     * Dispatches a single SSE event: publishes the message endpoint, emits a decoded
     * JSON-RPC message, or signals an error for anything else.
     */
    private void handleServerSentEvent(ServerSentEvent<String> event,
            SynchronousSink<McpSchema.JSONRPCMessage> sink) {
        String eventType = event.event();
        if (ENDPOINT_EVENT_TYPE.equals(eventType)) {
            // The endpoint event produces no message; it only unblocks senders.
            if (this.messageEndpointSink.tryEmitValue(event.data()).isSuccess()) {
                sink.complete();
            } else {
                sink.error(new McpError("Failed to handle SSE endpoint event"));
            }
        } else if (MESSAGE_EVENT_TYPE.equals(eventType)) {
            try {
                sink.next(McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()));
            } catch (IOException e) {
                sink.error(e);
            }
        } else {
            sink.error(new McpError("Received unrecognized SSE event type: " + eventType));
        }
    }

    /**
     * POSTs a JSON-RPC message to the message endpoint announced by the server.
     *
     * <p>The request is deferred until the endpoint is known. While the transport is
     * closing, nothing is sent and serialization failures are ignored.
     *
     * @param jSONRPCMessage the message to send
     * @return a {@link Mono} that completes when the request has finished, or errors with
     * a {@link RuntimeException} if the message cannot be serialized
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage jSONRPCMessage) {
        return this.messageEndpointSink.asMono().flatMap(messageEndpoint -> {
            if (this.isClosing) {
                return Mono.empty();
            }
            try {
                String body = this.objectMapper.writeValueAsString(jSONRPCMessage);
                return this.webClient.post()
                    .uri(messageEndpoint)
                    .contentType(MediaType.APPLICATION_JSON)
                    .bodyValue(body)
                    .retrieve()
                    .toBodilessEntity()
                    .doOnSuccess(response -> logger.debug("Message sent successfully"))
                    .doOnError(error -> {
                        // Failures caused by shutting down are expected; keep the log quiet.
                        if (!this.isClosing) {
                            logger.error("Error sending message: {}", error.getMessage());
                        }
                    });
            } catch (IOException e) {
                if (this.isClosing) {
                    return Mono.empty();
                }
                return Mono.error(new RuntimeException("Failed to serialize message", e));
            }
        }).then();
    }

    /**
     * Opens the SSE stream with a GET request to {@value #SSE_ENDPOINT}. Failures are
     * passed to the retry handler, which decides whether to resubscribe.
     *
     * @return the stream of server-sent events
     */
    protected Flux<ServerSentEvent<String>> eventStream() {
        return this.webClient.get()
            .uri(SSE_ENDPOINT)
            .accept(MediaType.TEXT_EVENT_STREAM)
            .retrieve()
            .bodyToFlux(SSE_TYPE)
            .retryWhen(Retry.from(retrySignals -> retrySignals.handle(this.inboundRetryHandler)));
    }

    /**
     * Marks the transport as closing and disposes the inbound SSE subscription, if any.
     * The work runs on the bounded-elastic scheduler when the returned {@link Mono} is
     * subscribed.
     *
     * @return a {@link Mono} that completes after the subscription has been disposed
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.fromRunnable(() -> {
            this.isClosing = true;
            if (this.inboundSubscription != null) {
                this.inboundSubscription.dispose();
            }
        }).then().subscribeOn(Schedulers.boundedElastic());
    }

    /**
     * Converts {@code obj} to the type described by {@code typeReference} using the
     * configured {@link ObjectMapper}.
     *
     * @param obj the value to convert
     * @param typeReference the target type
     * @param <T> the target type
     * @return the converted value
     */
    @Override
    public <T> T unmarshalFrom(Object obj, TypeReference<T> typeReference) {
        return this.objectMapper.convertValue(obj, typeReference);
    }
}
