package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebRxSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpSession;
import io.modelcontextprotocol.spec.StatelessMcpSession;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.noear.solon.SolonApp;
import org.noear.solon.core.handle.Context;
import org.noear.solon.core.handle.Entity;
import org.noear.solon.web.sse.SseEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:io/modelcontextprotocol/server/transport/WebRxStreamableServerTransportProvider.class */
/**
 * {@link McpServerTransportProvider} implementing the MCP "Streamable HTTP" transport on Solon.
 *
 * <p>A single endpoint (default {@code /mcp}) accepts JSON-RPC messages via {@code POST} and
 * terminates sessions via {@code DELETE}. Depending on the request's {@code Accept} header, the
 * response content type is set to {@code text/event-stream} or {@code application/json}, and
 * {@link StreamableHttpServerTransport} writes outgoing messages accordingly.
 *
 * <p>When a session factory has been installed via {@link #setSessionFactory}, sessions are
 * tracked by the {@code Mcp-Session-Id} header; otherwise each request is served by a new
 * {@link StatelessMcpSession}.
 */
public class WebRxStreamableServerTransportProvider implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(WebRxStreamableServerTransportProvider.class);
    private static final String MCP_SESSION_ID = "Mcp-Session-Id";
    private static final String APPLICATION_JSON = "application/json";
    private static final String TEXT_EVENT_STREAM = "text/event-stream";
    private static final String DEFAULT_MCP_ENDPOINT = "/mcp";
    /**
     * Session id that marks a stateless session; such ids are not echoed back to the client.
     * Presumably this is what {@link StatelessMcpSession#getId()} returns — NOTE(review) confirm.
     */
    private static final String STATELESS_SESSION_ID = "stateless";

    private final ObjectMapper objectMapper;
    private final String endpoint;
    /** Live stateful sessions keyed by session id. */
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
    /** Null until {@link #setSessionFactory} is called; null means "serve requests statelessly". */
    private McpServerSession.Factory sessionFactory;

    /** Fluent builder for {@link WebRxStreamableServerTransportProvider}. */
    public static class Builder {
        private ObjectMapper objectMapper;
        private String endpoint = WebRxStreamableServerTransportProvider.DEFAULT_MCP_ENDPOINT;

        /**
         * Sets the mapper used to (de)serialize JSON-RPC messages.
         *
         * @param objectMapper the mapper; must not be null
         * @return this builder
         */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /**
         * Sets the HTTP path the transport is mounted on (default {@code /mcp}).
         *
         * @param str the endpoint path; must not be null
         * @return this builder
         */
        public Builder endpoint(String str) {
            Assert.notNull(str, "Endpoint must not be null");
            this.endpoint = str;
            return this;
        }

        /**
         * Builds the provider.
         *
         * @return a new provider using the configured mapper and endpoint
         */
        public WebRxStreamableServerTransportProvider build() {
            Assert.notNull(this.objectMapper, "ObjectMapper must be set");
            Assert.notNull(this.endpoint, "Endpoint must be set");
            return new WebRxStreamableServerTransportProvider(this.objectMapper, this.endpoint);
        }
    }

    /**
     * Transport bound to a single Solon {@link Context}: every message sent through it is written
     * to that context's response.
     */
    public static class StreamableHttpServerTransport implements McpServerTransport {
        private final ObjectMapper objectMapper;
        private final Context ctx;

        public StreamableHttpServerTransport(Context context, ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
            this.ctx = context;
        }

        /**
         * Serializes {@code message} and writes it to the bound response: as raw JSON when the
         * response content type is {@code application/json}, otherwise as an SSE event carrying a
         * random UUID id.
         *
         * @param message the JSON-RPC message to send
         * @return a Mono that performs the write on subscription and errors if it fails
         */
        @Override // io.modelcontextprotocol.spec.McpTransport
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                try {
                    String json = this.objectMapper.writeValueAsString(message);
                    if (isJsonResponse()) {
                        this.ctx.output(json);
                    } else {
                        this.ctx.render(new SseEvent().id(UUID.randomUUID().toString()).data(json));
                    }
                } catch (Throwable th) {
                    // Broad on purpose: the Solon write calls may declare broad checked exceptions;
                    // wrapping turns any failure into an error signal of the returned Mono.
                    throw new RuntimeException("Failed to send message", th);
                }
            });
        }

        /**
         * Returns whether the response content type was set to {@code application/json}.
         * Trailing parameters (e.g. {@code ;charset=utf-8}) are tolerated.
         */
        private boolean isJsonResponse() {
            String contentType = this.ctx.contentTypeNew();
            return contentType != null && contentType.startsWith(APPLICATION_JSON);
        }

        /** Converts an already-parsed JSON value into {@code T} using the transport's mapper. */
        @Override // io.modelcontextprotocol.spec.McpTransport
        public <T> T unmarshalFrom(Object obj, TypeReference<T> typeReference) {
            return this.objectMapper.convertValue(obj, typeReference);
        }

        /**
         * Flushes and closes the bound response. Close failures are logged rather than propagated,
         * since the client may already have gone away.
         */
        @Override // io.modelcontextprotocol.spec.McpTransport
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                try {
                    this.ctx.flush();
                    this.ctx.close();
                } catch (IOException e) {
                    logger.debug("Failed to close streamable HTTP response: {}", e.getMessage(), e);
                }
            });
        }
    }

    /**
     * Creates a provider; prefer {@link #builder()}.
     *
     * @param objectMapper mapper used for JSON-RPC (de)serialization; must not be null
     * @param str          HTTP path to mount the transport on; must not be null
     */
    public WebRxStreamableServerTransportProvider(ObjectMapper objectMapper, String str) {
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        Assert.notNull(str, "Endpoint must not be null");
        this.objectMapper = objectMapper;
        this.endpoint = str;
    }

    /**
     * Sends a heartbeat on every session whose transport supports it.
     *
     * <p>Only {@link WebRxSseServerTransportProvider.WebRxMcpSessionTransport} has a heartbeat;
     * sessions created by this provider use {@link StreamableHttpServerTransport} and are skipped
     * instead of failing with a {@link ClassCastException}.
     */
    public void sendHeartbeat() {
        for (McpServerSession session : this.sessions.values()) {
            Object transport = session.getTransport();
            if (transport instanceof WebRxSseServerTransportProvider.WebRxMcpSessionTransport) {
                ((WebRxSseServerTransportProvider.WebRxMcpSessionTransport) transport).sendHeartbeat();
            }
        }
    }

    /**
     * Registers the POST and DELETE handlers for the configured endpoint on {@code solonApp}.
     * Does nothing when {@code solonApp} is null.
     */
    public void toHttpHandler(SolonApp solonApp) {
        if (solonApp != null) {
            solonApp.post(this.endpoint, this::doPost);
            solonApp.delete(this.endpoint, this::doDelete);
        }
    }

    /** Installs the factory used to create stateful sessions; without it, requests are stateless. */
    @Override // io.modelcontextprotocol.spec.McpServerTransportProvider
    public void setSessionFactory(McpServerSession.Factory factory) {
        this.sessionFactory = factory;
    }

    /**
     * Sends a notification to every tracked session. Per-session failures are logged and
     * swallowed, so one broken session does not fail the broadcast.
     *
     * @param str notification method name
     * @param map notification params
     * @return a Mono completing when all sends have finished
     */
    @Override // io.modelcontextprotocol.spec.McpServerTransportProvider
    public Mono<Void> notifyClients(String str, Map<String, Object> map) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(str, map)
                        .doOnError(th -> logger.error("Failed to send message to session {}: {}",
                                session.getId(), th.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /**
     * Gracefully closes every tracked session.
     *
     * @return a Mono completing when all sessions have closed
     */
    @Override // io.modelcontextprotocol.spec.McpServerTransportProvider
    public Mono<Void> closeGracefully() {
        logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(McpServerSession::closeGracefully)
                .then();
    }

    /**
     * Handles a POST carrying one JSON-RPC message or a batch (JSON array).
     *
     * <p>Responds 406 when {@code Accept} lists neither {@code application/json} nor
     * {@code text/event-stream}. Responds 404 when {@code Mcp-Session-Id} names an unknown session.
     * When both media types are accepted, SSE is preferred.
     */
    public void doPost(Context context) throws Throwable {
        String accept = context.headerOrDefault("Accept", "");
        boolean acceptsSse = acceptsMediaType(accept, TEXT_EVENT_STREAM);
        boolean acceptsJson = acceptsMediaType(accept, APPLICATION_JSON);
        if (!acceptsSse && !acceptsJson) {
            context.status(406, "Legacy transport not available");
            return;
        }

        // NOTE(review): a stored session keeps the transport (and thus the Context) of the request
        // that created it; later requests for the same session id write through that old transport.
        // Fixing this needs a way to rebind a session's transport per request — confirm the
        // McpServerSession API.
        McpSession session = getOrCreateSession(context.header(MCP_SESSION_ID),
                new StreamableHttpServerTransport(context, this.objectMapper));
        if (session == null) {
            // The client sent a session id this server does not know (expired, deleted, or bogus).
            context.status(404, "Session not found");
            return;
        }
        if (!STATELESS_SESSION_ID.equals(session.getId())) {
            context.headerSet(MCP_SESSION_ID, session.getId());
        }

        context.contentType(acceptsSse ? TEXT_EVENT_STREAM : APPLICATION_JSON);
        context.returnValue(parseRequestBodyAsStream(context)
                .flatMap(session::handle)
                // Handlers write their output through the transport; the entity only ends the
                // exchange.
                .then(Mono.fromSupplier(Entity::new))
                .onErrorResume(th -> {
                    logger.error("Error processing  message: {}", th.getMessage(), th);
                    return Mono.just(new Entity().status(500).body(new McpError(th.getMessage())));
                }));
    }

    /**
     * Returns whether a comma-separated {@code Accept} header lists {@code mediaType}. Parameters
     * such as {@code ;q=0.9} are ignored, and the match is case-insensitive. Wildcards are not
     * expanded.
     */
    private static boolean acceptsMediaType(String acceptHeader, String mediaType) {
        return Arrays.stream(acceptHeader.split(","))
                .map(part -> {
                    int paramStart = part.indexOf(';');
                    return (paramStart >= 0 ? part.substring(0, paramStart) : part).trim();
                })
                .anyMatch(mediaType::equalsIgnoreCase);
    }

    /**
     * Terminates the session named by {@code Mcp-Session-Id}: responds 204 on success, or 404 when
     * the header is missing or the session is unknown. Closing runs asynchronously; failures are
     * logged.
     */
    public void doDelete(Context context) throws IOException {
        String sessionId = context.header(MCP_SESSION_ID);
        // A single atomic remove avoids the check-then-act race between concurrent DELETEs.
        McpServerSession session = sessionId == null ? null : this.sessions.remove(sessionId);
        if (session == null) {
            context.status(404, "Session not found");
            return;
        }
        session.closeGracefully()
                .doOnError(th -> logger.warn("Failed to close session {}: {}", sessionId, th.getMessage()))
                .onErrorComplete()
                .subscribe();
        context.status(204);
    }

    /**
     * Lazily reads the request body and emits the JSON-RPC messages it contains. An object yields
     * one message and an array yields one per element. A missing/empty body or any other JSON
     * value yields none. Parse failures surface as an error signal.
     */
    private Flux<McpSchema.JSONRPCMessage> parseRequestBodyAsStream(Context context) {
        return Mono.<List<McpSchema.JSONRPCMessage>>fromCallable(() -> {
            try (InputStream body = context.bodyAsStream()) {
                if (body == null) {
                    return Collections.emptyList();
                }
                JsonNode root = this.objectMapper.readTree(body);
                if (root == null) {
                    return Collections.emptyList();
                }
                if (root.isArray()) {
                    List<McpSchema.JSONRPCMessage> messages = new ArrayList<>(root.size());
                    for (JsonNode element : root) {
                        messages.add(McpSchema.deserializeJsonRpcMessage(this.objectMapper, element));
                    }
                    return messages;
                }
                if (root.isObject()) {
                    return Collections.singletonList(McpSchema.deserializeJsonRpcMessage(this.objectMapper, root));
                }
                return Collections.emptyList();
            }
        }).flatMapMany(Flux::fromIterable);
    }

    /**
     * Resolves the session for a request:
     * <ul>
     *   <li>without a factory, returns a new stateless session;</li>
     *   <li>with a factory and a session id, returns the tracked session or {@code null} if unknown;</li>
     *   <li>with a factory and no id, creates and tracks a new session.</li>
     * </ul>
     */
    private McpSession getOrCreateSession(String str, McpServerTransport mcpServerTransport) {
        McpServerSession.Factory factory = this.sessionFactory;
        if (factory == null) {
            return new StatelessMcpSession(mcpServerTransport);
        }
        if (str != null) {
            return this.sessions.get(str);
        }
        McpServerSession created = factory.create(mcpServerTransport);
        this.sessions.put(created.getId(), created);
        return created;
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }
}
