package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.noear.solon.SolonApp;
import org.noear.solon.Utils;
import org.noear.solon.core.handle.Context;
import org.noear.solon.core.handle.Entity;
import org.noear.solon.web.sse.SseEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;

/* loaded from: input_file:io/modelcontextprotocol/server/transport/WebRxSseServerTransportProvider.class */
/**
 * MCP server transport provider for Solon that exposes two HTTP routes: a GET Server-Sent Events
 * endpoint on which each client receives its session's outgoing messages, and a POST message
 * endpoint on which clients submit JSON-RPC messages tagged with a {@code sessionId} query parameter.
 *
 * <p>Each SSE connection creates a new {@link McpServerSession} (via the configured
 * {@link McpServerSession.Factory}) and the first event sent on the stream is an
 * {@value #ENDPOINT_EVENT_TYPE} event carrying the message endpoint URL with the session id appended.
 */
public class WebRxSseServerTransportProvider implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(WebRxSseServerTransportProvider.class);

    /** SSE event name used for JSON-RPC messages pushed to the client. */
    public static final String MESSAGE_EVENT_TYPE = "message";

    /** SSE event name of the first event on a stream, whose data is the message endpoint URL. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";

    /** SSE endpoint path used when none is configured explicitly. */
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";

    private final ObjectMapper objectMapper;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private McpServerSession.Factory sessionFactory;

    /** Active sessions keyed by session id; entries are removed when their SSE stream terminates. */
    private final ConcurrentHashMap<String, McpServerSession> sessions = new ConcurrentHashMap<>();

    /** Set when {@link #closeGracefully()} starts; afterwards both endpoints answer 503. */
    private volatile boolean isClosing = false;

    /** Fluent builder for {@link WebRxSseServerTransportProvider}. */
    public static class Builder {
        private ObjectMapper objectMapper;
        private String messageEndpoint;
        private String sseEndpoint = WebRxSseServerTransportProvider.DEFAULT_SSE_ENDPOINT;

        /**
         * Sets the JSON mapper used to (de)serialize JSON-RPC messages.
         *
         * @param objectMapper the mapper, must not be null
         * @return this builder
         */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /**
         * Sets the path of the POST endpoint that receives client messages.
         *
         * @param messageEndpoint the path, must not be null
         * @return this builder
         */
        public Builder messageEndpoint(String messageEndpoint) {
            Assert.notNull(messageEndpoint, "Message endpoint must not be null");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        /**
         * Sets the path of the GET SSE endpoint; defaults to {@value WebRxSseServerTransportProvider#DEFAULT_SSE_ENDPOINT}.
         *
         * @param sseEndpoint the path, must not be null
         * @return this builder
         */
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /**
         * Builds the provider.
         *
         * @return a new provider
         * @throws IllegalArgumentException (via {@code Assert}) if the mapper or message endpoint is unset
         */
        public WebRxSseServerTransportProvider build() {
            Assert.notNull(this.objectMapper, "ObjectMapper must be set");
            Assert.notNull(this.messageEndpoint, "Message endpoint must be set");
            return new WebRxSseServerTransportProvider(this.objectMapper, this.messageEndpoint, this.sseEndpoint);
        }
    }

    /**
     * Per-connection transport: outgoing JSON-RPC messages are serialized and emitted as SSE events
     * on the connection's sink.
     */
    public class WebRxMcpSessionTransport implements McpServerTransport {
        private final Context context;
        private final FluxSink<SseEvent> sink;

        public WebRxMcpSessionTransport(Context context, FluxSink<SseEvent> sink) {
            this.context = context;
            this.sink = sink;
        }

        /** @return the Solon request context of the SSE connection this transport writes to */
        public Context getContext() {
            return this.context;
        }

        /** Emits an SSE comment event ("heartbeat") on this connection's stream. */
        public void sendHeartbeat() {
            this.sink.next(new SseEvent().comment("heartbeat"));
        }

        /**
         * Serializes {@code message} to JSON and emits it as a {@value #MESSAGE_EVENT_TYPE} SSE event.
         * If serialization fails, the SSE stream is terminated with that error and the returned Mono
         * errors as well.
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromSupplier(() -> {
                try {
                    return WebRxSseServerTransportProvider.this.objectMapper.writeValueAsString(message);
                } catch (IOException e) {
                    throw Exceptions.propagate(e);
                }
            }).doOnNext(json -> this.sink.next(new SseEvent().name(MESSAGE_EVENT_TYPE).data(json)))
                    .doOnError(error -> this.sink.error(Exceptions.unwrap(error)))
                    .then();
        }

        /** Converts a generic deserialized value into {@code T} using the provider's mapper. */
        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return WebRxSseServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
        }

        /** Completes the SSE stream when the returned Mono is subscribed. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(this.sink::complete);
        }

        /** Completes the SSE stream immediately. */
        @Override
        public void close() {
            this.sink.complete();
        }
    }

    /**
     * Creates a provider.
     *
     * @param objectMapper    mapper for JSON-RPC (de)serialization, must not be null
     * @param messageEndpoint path of the POST message endpoint, must not be null
     * @param sseEndpoint     path of the GET SSE endpoint, must not be null
     */
    public WebRxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
        this.objectMapper = objectMapper;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
    }

    /** Creates a provider using {@value #DEFAULT_SSE_ENDPOINT} as the SSE endpoint. */
    public WebRxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }

    /** Sends an SSE heartbeat comment to every active session. */
    public void sendHeartbeat() {
        for (McpServerSession session : this.sessions.values()) {
            // every session in the map was created with a WebRxMcpSessionTransport in handleSseConnection
            ((WebRxMcpSessionTransport) session.getTransport()).sendHeartbeat();
        }
    }

    /**
     * Registers the SSE (GET) and message (POST) routes on the given app; does nothing when
     * {@code solonApp} is null.
     */
    public void toHttpHandler(SolonApp solonApp) {
        if (solonApp != null) {
            solonApp.get(this.sseEndpoint, this::handleSseConnection);
            solonApp.post(this.messageEndpoint, this::handleMessage);
        }
    }

    public String getSseEndpoint() {
        return this.sseEndpoint;
    }

    public String getMessageEndpoint() {
        return this.messageEndpoint;
    }

    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every active session. Per-session failures are logged and ignored so
     * that one broken client does not prevent delivery to the others.
     */
    @Override
    public Mono<Void> notifyClients(String method, Map<String, Object> params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        // fromIterable (not fromStream) so the returned Mono can be re-subscribed or retried
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(error -> logger.error("Failed to send message to session {}: {}",
                                session.getId(), error.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /**
     * Marks the provider as closing (new SSE connections and messages are answered with 503), then
     * gracefully closes every active session.
     */
    @Override
    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values())
                .doFirst(() -> {
                    this.isClosing = true;
                    logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
                })
                .flatMap(McpServerSession::closeGracefully)
                .then();
    }

    /**
     * Handles a GET on the SSE endpoint: creates a session bound to a new SSE stream, registers it,
     * and sends the {@value #ENDPOINT_EVENT_TYPE} event telling the client where to POST messages.
     */
    public void handleSseConnection(Context context) throws Throwable {
        if (this.isClosing) {
            context.status(503);
            context.output("Server is shutting down");
            return;
        }
        Flux<SseEvent> events = Flux.create(sink -> {
            McpServerSession session = this.sessionFactory.create(new WebRxMcpSessionTransport(context, sink));
            String sessionId = session.getId();
            logger.debug("Created new SSE connection for session: {}", sessionId);
            this.sessions.put(sessionId, session);
            logger.debug("Sending initial endpoint event to session: {}", sessionId);
            sink.next(new SseEvent().name(ENDPOINT_EVENT_TYPE).data(this.messageEndpoint + "?sessionId=" + sessionId));
            // onDispose fires on cancel, completion and error alike, so closed sessions do not leak
            sink.onDispose(() -> {
                logger.debug("Session {} cancelled", sessionId);
                this.sessions.remove(sessionId);
            });
        });
        context.contentType("text/event-stream");
        context.returnValue(events);
    }

    /**
     * Handles a POST on the message endpoint: resolves the session from the {@code sessionId} query
     * parameter, deserializes the body as a JSON-RPC message and hands it to the session.
     *
     * <p>Responds 503 while closing, 404 when the session id is missing or unknown, 400 when the body
     * cannot be deserialized, and 500 (with an {@link McpError} body) when the session fails to handle
     * the message.
     */
    public void handleMessage(Context context) throws Throwable {
        if (this.isClosing) {
            context.status(503);
            context.output("Server is shutting down");
            return;
        }
        String sessionId = context.param("sessionId");
        if (Utils.isEmpty(sessionId)) {
            context.status(404);
            context.render(new McpError("Session ID missing in message endpoint"));
            return;
        }
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            // previously this dereferenced null and surfaced as an uncaught NullPointerException
            context.status(404);
            context.render(new McpError("Session not found: " + sessionId));
            return;
        }
        McpSchema.JSONRPCMessage message;
        try {
            message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, context.body());
        } catch (IOException | IllegalArgumentException e) {
            logger.error("Failed to deserialize message: {}", e.getMessage());
            context.status(400);
            context.render(new McpError("Invalid message format"));
            return;
        }
        // handle() completes empty, so use then(...) to emit the success Entity (flatMap would never run)
        context.returnValue(session.handle(message)
                .then(Mono.fromSupplier(Entity::new))
                .onErrorResume(error -> {
                    logger.error("Error processing  message: {}", error.getMessage(), error);
                    return Mono.just(new Entity().status(500).body(new McpError(error.getMessage())));
                }));
    }

    /** @return a new {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }
}
