package io.modelcontextprotocol.client.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebRxSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.noear.solon.net.http.HttpResponse;
import org.noear.solon.net.http.HttpUtils;
import org.noear.solon.net.http.HttpUtilsBuilder;
import org.noear.solon.net.http.textstream.TextStreamUtil;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;

/**
 * MCP client transport for the Streamable HTTP transport, built on Solon's {@link HttpUtils}.
 *
 * <p>Every outgoing JSON-RPC message is POSTed to a single endpoint ({@code /mcp} unless
 * configured otherwise). A reply may be {@code 202 Accepted} (nothing to dispatch), a single
 * {@code application/json} document, an {@code application/json-seq} text sequence, or a
 * {@code text/event-stream}; each decoded message is passed to the handler supplied to
 * {@link #connect} or {@link #sendMessage(McpSchema.JSONRPCMessage, Function)}.
 *
 * <p>If the server answers {@code 404} or {@code 405}, a flag is set and from then on every call
 * is delegated to the {@link WebRxSseClientTransport} given at construction time. Nothing in this
 * class clears that flag again.
 *
 * <p>The session id returned with the {@code initialize} response is sent as
 * {@code Mcp-Session-Id} on later requests, and the id of the last SSE event seen is sent as
 * {@code Last-Event-ID} when {@link #connect} opens a stream. Both are held in atomic references.
 */
public class WebRxStreamableClientTransport implements McpClientTransport {
    private static final Logger LOGGER = LoggerFactory.getLogger(WebRxStreamableClientTransport.class);

    private static final String DEFAULT_MCP_ENDPOINT = "/mcp";

    private static final String MCP_SESSION_ID = "Mcp-Session-Id";
    private static final String LAST_EVENT_ID = "Last-Event-ID";
    private static final String ACCEPT = "Accept";
    private static final String CONTENT_TYPE = "Content-Type";

    private static final String APPLICATION_JSON = "application/json";
    private static final String APPLICATION_JSON_SEQ = "application/json-seq";
    private static final String TEXT_EVENT_STREAM = "text/event-stream";
    private static final String DEFAULT_ACCEPT_VALUES = String.format("%s, %s", APPLICATION_JSON, TEXT_EVENT_STREAM);

    /** Legacy HTTP+SSE transport used once the server has rejected the streamable endpoint. */
    private final WebRxSseClientTransport sseClientTransport;
    private final HttpUtilsBuilder webBuilder;
    private final String endpoint;
    private final ObjectMapper objectMapper;

    /** Id of the last SSE event that carried a non-empty id; {@code null} when none seen. */
    private final AtomicReference<String> lastEventId = new AtomicReference<>();
    /** Session id issued by the server with the {@code initialize} response; {@code null} when none. */
    private final AtomicReference<String> mcpSessionId = new AtomicReference<>();
    /** Becomes {@code true} after a 404/405 reply; this class never resets it. */
    private final AtomicBoolean fallbackToSse = new AtomicBoolean(false);

    /** Fluent builder for {@link WebRxStreamableClientTransport}. */
    public static class Builder {
        private final HttpUtilsBuilder webBuilder;
        private ObjectMapper objectMapper = new ObjectMapper();
        private String endpoint = WebRxStreamableClientTransport.DEFAULT_MCP_ENDPOINT;

        /**
         * @param httpUtilsBuilder factory used to create every HTTP request; must not be {@code null}
         */
        public Builder(HttpUtilsBuilder httpUtilsBuilder) {
            Assert.notNull(httpUtilsBuilder, "webBuilder must not be empty");
            this.webBuilder = httpUtilsBuilder;
        }

        /**
         * Overrides the MCP endpoint (default {@code /mcp}).
         *
         * @param endpoint endpoint passed to {@link HttpUtilsBuilder#build}; must contain text
         * @return this builder
         */
        public Builder endpoint(String endpoint) {
            Assert.hasText(endpoint, "endpoint must not be null");
            this.endpoint = endpoint;
            return this;
        }

        /**
         * Overrides the Jackson mapper (default: a plain {@code new ObjectMapper()}).
         *
         * @param objectMapper mapper used for (de)serialization; must not be {@code null}
         * @return this builder
         */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "objectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /**
         * Builds the transport together with an SSE fallback transport that shares the same
         * request builder, endpoint and object mapper.
         */
        public WebRxStreamableClientTransport build() {
            WebRxSseClientTransport fallback = new WebRxSseClientTransport(this.webBuilder, this.endpoint, this.objectMapper);
            return new WebRxStreamableClientTransport(this.webBuilder, this.objectMapper, this.endpoint, fallback);
        }
    }

    /**
     * @param httpUtilsBuilder          factory used to create every HTTP request
     * @param objectMapper              mapper used for JSON (de)serialization
     * @param endpoint                  MCP endpoint passed to {@code httpUtilsBuilder.build(...)}
     * @param webRxSseClientTransport   transport to delegate to after a 404/405 reply
     */
    public WebRxStreamableClientTransport(HttpUtilsBuilder httpUtilsBuilder, ObjectMapper objectMapper, String endpoint, WebRxSseClientTransport webRxSseClientTransport) {
        this.webBuilder = httpUtilsBuilder;
        this.objectMapper = objectMapper;
        this.endpoint = endpoint;
        this.sseClientTransport = webRxSseClientTransport;
    }

    /** Creates a {@link Builder} around the given request factory. */
    public static Builder builder(HttpUtilsBuilder httpUtilsBuilder) {
        return new Builder(httpUtilsBuilder);
    }

    /**
     * Opens the server-to-client stream, or delegates to the SSE transport once fallback is active.
     *
     * <p>The request is retried up to 3 times with exponential backoff (3s base) when it fails with
     * an {@link IllegalStateException}. A 404/405 reply switches this transport to SSE fallback.
     *
     * @param handler receives every message decoded from the stream
     * @return a Mono that completes when the stream has been consumed (or the fallback connected)
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (this.fallbackToSse.get()) {
            return this.sseClientTransport.connect(handler);
        }
        return Mono.defer(() -> Mono.fromFuture(() -> {
                    HttpUtils request = this.webBuilder.build(this.endpoint);
                    request.header(ACCEPT, TEXT_EVENT_STREAM);
                    String resumeFrom = this.lastEventId.get();
                    if (resumeFrom != null) {
                        request.header(LAST_EVENT_ID, resumeFrom);
                    }
                    withSessionHeader(request);
                    // NOTE(review): the MCP Streamable HTTP spec opens the server-to-client stream
                    // with a GET; this sends a body-less POST. Confirm against the paired server
                    // before changing the method.
                    return request.execAsync("POST");
                })
                .flatMap(response -> {
                    int status = response.code();
                    if (status == 404) {
                        // 404 may mean the session is gone; drop the id either way.
                        this.mcpSessionId.set(null);
                    }
                    if (status == 404 || status == 405) {
                        LOGGER.warn("Operation not allowed, falling back to SSE");
                        this.fallbackToSse.set(true);
                        return this.sseClientTransport.connect(handler);
                    }
                    return handleStreamingResponse(response, handler);
                })
                .retryWhen(Retry.backoff(3L, Duration.ofSeconds(3L))
                        .filter(error -> error instanceof IllegalStateException))
                .doOnError(error -> LOGGER.error("Streamable transport connection error", error)))
                // NOTE(review): closeGracefully() is used as a plain callback, so the Mono it
                // returns is never subscribed: only its eager effects (clearing the session id and
                // last event id) happen. Subscribing it might close a fallback SSE connection right
                // after connect completes, so it is left as-is — confirm the intended semantics.
                .doOnTerminate(this::closeGracefully);
    }

    /** Sends a message, subscribing to any decoded reply messages without further processing. */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage jSONRPCMessage) {
        return sendMessage(jSONRPCMessage, Function.identity());
    }

    /**
     * POSTs a JSON-RPC message and dispatches any messages contained in the reply.
     *
     * <p>The {@code Mcp-Session-Id} header of the reply to an {@code initialize} request is stored
     * for later requests. {@code 202} completes without dispatching; {@code 404}/{@code 405}
     * switches to SSE fallback and resends through it; any other status {@code >= 400} fails with
     * {@link IllegalArgumentException}.
     *
     * @param message the message to send
     * @param handler receives every message decoded from the reply body
     * @return a Mono completing once the reply has been fully handled
     */
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (this.fallbackToSse.get()) {
            return fallbackToSse(message);
        }
        return serializeJson(message)
                .flatMap(json -> {
                    HttpUtils request = this.webBuilder.build(this.endpoint)
                            .bodyOfJson(json)
                            .header(ACCEPT, DEFAULT_ACCEPT_VALUES)
                            .header(CONTENT_TYPE, APPLICATION_JSON);
                    withSessionHeader(request);
                    return Mono.fromFuture(request.execAsync("POST"))
                            .flatMap(response -> handlePostResponse(message, response, handler));
                })
                .doOnError(error -> LOGGER.error("Streamable transport sendMessages error", error));
    }

    /** Interprets the HTTP reply to a POSTed message (see {@link #sendMessage(McpSchema.JSONRPCMessage, Function)}). */
    private Mono<Void> handlePostResponse(McpSchema.JSONRPCMessage message, HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (isInitializeRequest(message)) {
            String issuedSessionId = response.header(MCP_SESSION_ID);
            if (issuedSessionId != null) {
                this.mcpSessionId.set(issuedSessionId);
            }
        }
        int status = response.code();
        if (status == 202) {
            return Mono.empty();
        }
        if (status == 404) {
            // 404 may mean the session is gone; drop the id either way.
            this.mcpSessionId.set(null);
        }
        if (status == 404 || status == 405) {
            LOGGER.warn("Operation not allowed, falling back to SSE");
            this.fallbackToSse.set(true);
            return fallbackToSse(message);
        }
        if (status >= 400) {
            return Mono.error(new IllegalArgumentException("Unexpected status code: " + status));
        }
        return handleStreamingResponse(response, handler);
    }

    /** True for a JSON-RPC request whose method is {@code initialize}; null-safe on the method name. */
    private static boolean isInitializeRequest(McpSchema.JSONRPCMessage message) {
        return message instanceof McpSchema.JSONRPCRequest
                && McpSchema.METHOD_INITIALIZE.equals(((McpSchema.JSONRPCRequest) message).getMethod());
    }

    /** Adds the current session id header, if any, reading the reference only once. */
    private HttpUtils withSessionHeader(HttpUtils request) {
        String sessionId = this.mcpSessionId.get();
        if (sessionId != null) {
            request.header(MCP_SESSION_ID, sessionId);
        }
        return request;
    }

    /** Sends a message through the SSE transport, unpacking batches into individual sends. */
    private Mono<Void> fallbackToSse(McpSchema.JSONRPCMessage message) {
        if (message instanceof McpSchema.JSONRPCBatchRequest) {
            return sendEachViaSse(((McpSchema.JSONRPCBatchRequest) message).getItems());
        }
        if (message instanceof McpSchema.JSONRPCBatchResponse) {
            return sendEachViaSse(((McpSchema.JSONRPCBatchResponse) message).getItems());
        }
        return this.sseClientTransport.sendMessage(message);
    }

    /** Sends each batch item through the SSE transport (concurrently, via flatMap). */
    private Mono<Void> sendEachViaSse(Iterable<? extends McpSchema.JSONRPCMessage> items) {
        return Flux.<McpSchema.JSONRPCMessage>fromIterable(items)
                .flatMap(this.sseClientTransport::sendMessage)
                .then();
    }

    /** Serializes a message with the configured mapper, logging and signalling errors. */
    private Mono<String> serializeJson(McpSchema.JSONRPCMessage message) {
        try {
            return Mono.just(this.objectMapper.writeValueAsString(message));
        } catch (IOException e) {
            LOGGER.error("Error serializing JSON-RPC message", e);
            return Mono.error(e);
        }
    }

    /**
     * Dispatches a reply body according to its Content-Type.
     *
     * <p>A reply without a Content-Type (e.g. {@code 204 No Content}) has nothing to dispatch and
     * completes empty. Matching is case-insensitive; {@code application/json-seq} is tested before
     * {@code application/json} because the latter is a substring of the former.
     */
    private Mono<Void> handleStreamingResponse(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        String contentType = response.header(CONTENT_TYPE);
        if (contentType == null || contentType.trim().isEmpty()) {
            return Mono.empty();
        }
        String mediaType = contentType.toLowerCase(Locale.ROOT);
        if (mediaType.contains(APPLICATION_JSON_SEQ)) {
            return handleJsonStream(response, handler);
        }
        if (mediaType.contains(TEXT_EVENT_STREAM)) {
            return handleSseStream(response, handler);
        }
        if (mediaType.contains(APPLICATION_JSON)) {
            return handleSingleJson(response, handler);
        }
        return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType));
    }

    /** Decodes the whole body (as UTF-8) as one JSON-RPC message and hands it to the handler. */
    private Mono<Void> handleSingleJson(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Mono.fromCallable(() -> McpSchema.deserializeJsonRpcMessage(
                        this.objectMapper, new String(response.bodyAsBytes(), StandardCharsets.UTF_8)))
                .doOnError(IOException.class, e -> LOGGER.error("Error processing JSON response", e))
                .flatMap(message -> handler.apply(Mono.just(message)))
                .then();
    }

    /**
     * Decodes an {@code application/json-seq} body line by line.
     *
     * <p>{@link String#trim()} removes every char {@code <= U+0020}, which includes the RFC 7464
     * record separator (U+001E) that prefixes each record; blank lines are skipped.
     */
    private Mono<Void> handleJsonStream(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Flux.from(TextStreamUtil.parseLineStream(response.body()))
                .map(String::trim)
                .filter(line -> !line.isEmpty())
                .flatMap(line -> {
                    try {
                        return handler.apply(Mono.just(McpSchema.deserializeJsonRpcMessage(this.objectMapper, line)));
                    } catch (IOException e) {
                        LOGGER.error("Error processing JSON line", e);
                        return Mono.<McpSchema.JSONRPCMessage>error(e);
                    }
                })
                .then();
    }

    /** Decodes a {@code text/event-stream} body, handling message events strictly in order. */
    private Mono<Void> handleSseStream(HttpResponse response, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Flux.from(TextStreamUtil.parseSseStream(response.body()))
                .filter(event -> isMessageEvent(event.getEvent()))
                .concatMap(event -> handleSseEvent(event.getData(), event.getId(), handler))
                .then();
    }

    /**
     * An SSE event with no {@code event:} field has type "message", so a missing or empty type is
     * accepted alongside the server's message type.
     * NOTE(review): assumes {@code MESSAGE_EVENT_TYPE} is "message" — confirm.
     */
    private static boolean isMessageEvent(String eventType) {
        return eventType == null
                || eventType.isEmpty()
                || WebRxSseServerTransportProvider.MESSAGE_EVENT_TYPE.equals(eventType);
    }

    /**
     * Decodes one SSE event payload (a JSON object or an array of objects), hands each message to
     * the handler in order, then records the event id (if non-empty) as the last seen id.
     * Events with blank data carry no message and are skipped.
     */
    private Mono<Void> handleSseEvent(String rawData, String eventId, Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        if (rawData == null || rawData.trim().isEmpty()) {
            return Mono.empty();
        }
        String data = rawData.trim();
        List<McpSchema.JSONRPCMessage> messages = new ArrayList<>();
        try {
            JsonNode root = this.objectMapper.readTree(data);
            if (root.isArray()) {
                for (JsonNode item : root) {
                    messages.add(McpSchema.deserializeJsonRpcMessage(this.objectMapper, item.toString()));
                }
            } else if (root.isObject()) {
                messages.add(McpSchema.deserializeJsonRpcMessage(this.objectMapper, root.toString()));
            } else {
                String problem = "Unexpected JSON in SSE data: " + data;
                LOGGER.warn(problem);
                return Mono.error(new IllegalArgumentException(problem));
            }
        } catch (IOException e) {
            LOGGER.error("Error parsing SSE JSON: {}", data, e);
            return Mono.error(e);
        }
        return Flux.fromIterable(messages)
                .concatMap(message -> handler.apply(Mono.just(message)))
                .then(Mono.fromRunnable(() -> {
                    if (eventId != null && !eventId.isEmpty()) {
                        this.lastEventId.set(eventId);
                    }
                }));
    }

    /**
     * Clears the session id and last event id immediately (when called, not on subscription).
     * Returns the SSE transport's close Mono if fallback is active, otherwise an empty Mono.
     */
    @Override
    public Mono<Void> closeGracefully() {
        this.mcpSessionId.set(null);
        this.lastEventId.set(null);
        return this.fallbackToSse.get() ? this.sseClientTransport.closeGracefully() : Mono.empty();
    }

    /** Converts an already-parsed value into the requested type with the configured mapper. */
    @Override
    public <T> T unmarshalFrom(Object obj, TypeReference<T> typeReference) {
        return this.objectMapper.convertValue(obj, typeReference);
    }
}
