package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@WebServlet(asyncSupported = true)
/* loaded from: input_file:io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.class */
public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) HttpServletSseServerTransportProvider.class);
    public static final String UTF_8 = "UTF-8";
    public static final String APPLICATION_JSON = "application/json";
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String DEFAULT_BASE_URL = "";
    private final ObjectMapper objectMapper;
    private final String baseUrl;
    private final String messageEndpoint;
    private final String sseEndpoint;
    private final Map<String, McpServerSession> sessions;
    private final AtomicBoolean isClosing;
    private McpServerSession.Factory sessionFactory;

    /* loaded from: input_file:io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider$Builder.class */
    public static class Builder {
        private String messageEndpoint;
        private ObjectMapper objectMapper = new ObjectMapper();
        private String baseUrl = "";
        private String sseEndpoint = "/sse";

        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        public Builder baseUrl(String str) {
            Assert.notNull(str, "Base URL must not be null");
            this.baseUrl = str;
            return this;
        }

        public Builder messageEndpoint(String str) {
            Assert.hasText(str, "Message endpoint must not be empty");
            this.messageEndpoint = str;
            return this;
        }

        public Builder sseEndpoint(String str) {
            Assert.hasText(str, "SSE endpoint must not be empty");
            this.sseEndpoint = str;
            return this;
        }

        public HttpServletSseServerTransportProvider build() {
            if (this.objectMapper == null) {
                throw new IllegalStateException("ObjectMapper must be set");
            }
            if (this.messageEndpoint == null) {
                throw new IllegalStateException("MessageEndpoint must be set");
            }
            return new HttpServletSseServerTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint);
        }
    }

    /* loaded from: input_file:io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider$HttpServletMcpSessionTransport.class */
    private class HttpServletMcpSessionTransport implements McpServerTransport {
        private final String sessionId;
        private final AsyncContext asyncContext;
        private final PrintWriter writer;

        HttpServletMcpSessionTransport(String str, AsyncContext asyncContext, PrintWriter printWriter) {
            this.sessionId = str;
            this.asyncContext = asyncContext;
            this.writer = printWriter;
            HttpServletSseServerTransportProvider.logger.debug("Session transport {} initialized with SSE writer", str);
        }

        @Override // io.modelcontextprotocol.spec.McpTransport
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage jSONRPCMessage) {
            return Mono.fromRunnable(() -> {
                try {
                    HttpServletSseServerTransportProvider.this.sendEvent(this.writer, "message", HttpServletSseServerTransportProvider.this.objectMapper.writeValueAsString(jSONRPCMessage));
                    HttpServletSseServerTransportProvider.logger.debug("Message sent to session {}", this.sessionId);
                } catch (Exception e) {
                    HttpServletSseServerTransportProvider.logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
                    HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
                    this.asyncContext.complete();
                }
            });
        }

        @Override // io.modelcontextprotocol.spec.McpTransport
        public <T> T unmarshalFrom(Object obj, TypeReference<T> typeReference) {
            return (T) HttpServletSseServerTransportProvider.this.objectMapper.convertValue(obj, typeReference);
        }

        @Override // io.modelcontextprotocol.spec.McpTransport
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                HttpServletSseServerTransportProvider.logger.debug("Closing session transport: {}", this.sessionId);
                try {
                    HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
                    this.asyncContext.complete();
                    HttpServletSseServerTransportProvider.logger.debug("Successfully completed async context for session {}", this.sessionId);
                } catch (Exception e) {
                    HttpServletSseServerTransportProvider.logger.warn("Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
                }
            });
        }

        @Override // io.modelcontextprotocol.spec.McpTransport
        public void close() {
            try {
                HttpServletSseServerTransportProvider.this.sessions.remove(this.sessionId);
                this.asyncContext.complete();
                HttpServletSseServerTransportProvider.logger.debug("Successfully completed async context for session {}", this.sessionId);
            } catch (Exception e) {
                HttpServletSseServerTransportProvider.logger.warn("Failed to complete async context for session {}: {}", this.sessionId, e.getMessage());
            }
        }
    }

    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String str, String str2) {
        this(objectMapper, "", str, str2);
    }

    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String str, String str2, String str3) {
        this.sessions = new ConcurrentHashMap();
        this.isClosing = new AtomicBoolean(false);
        this.objectMapper = objectMapper;
        this.baseUrl = str;
        this.messageEndpoint = str2;
        this.sseEndpoint = str3;
    }

    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String str) {
        this(objectMapper, str, "/sse");
    }

    @Override // io.modelcontextprotocol.spec.McpServerTransportProvider
    public void setSessionFactory(McpServerSession.Factory factory) {
        this.sessionFactory = factory;
    }

    @Override // io.modelcontextprotocol.spec.McpServerTransportProvider
    public Mono<Void> notifyClients(String str, Object obj) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", Integer.valueOf(this.sessions.size()));
        return Flux.fromIterable(this.sessions.values()).flatMap(mcpServerSession -> {
            return mcpServerSession.sendNotification(str, obj).doOnError(th -> {
                logger.error("Failed to send message to session {}: {}", mcpServerSession.getId(), th.getMessage());
            }).onErrorComplete();
        }).then();
    }

    @Override // jakarta.servlet.http.HttpServlet
    protected void doGet(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws ServletException, IOException {
        if (!httpServletRequest.getRequestURI().endsWith(this.sseEndpoint)) {
            httpServletResponse.sendError(404);
            return;
        }
        if (this.isClosing.get()) {
            httpServletResponse.sendError(503, "Server is shutting down");
            return;
        }
        httpServletResponse.setContentType(MediaType.TEXT_EVENT_STREAM_VALUE);
        httpServletResponse.setCharacterEncoding(UTF_8);
        httpServletResponse.setHeader("Cache-Control", "no-cache");
        httpServletResponse.setHeader("Connection", "keep-alive");
        httpServletResponse.setHeader("Access-Control-Allow-Origin", "*");
        String uuid = UUID.randomUUID().toString();
        AsyncContext startAsync = httpServletRequest.startAsync();
        startAsync.setTimeout(0L);
        PrintWriter writer = httpServletResponse.getWriter();
        this.sessions.put(uuid, this.sessionFactory.create(new HttpServletMcpSessionTransport(uuid, startAsync, writer)));
        sendEvent(writer, "endpoint", this.baseUrl + this.messageEndpoint + "?sessionId=" + uuid);
    }

    @Override // jakarta.servlet.http.HttpServlet
    protected void doPost(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws ServletException, IOException {
        if (this.isClosing.get()) {
            httpServletResponse.sendError(503, "Server is shutting down");
            return;
        }
        if (!httpServletRequest.getRequestURI().endsWith(this.messageEndpoint)) {
            httpServletResponse.sendError(404);
            return;
        }
        String parameter = httpServletRequest.getParameter("sessionId");
        if (parameter == null) {
            httpServletResponse.setContentType("application/json");
            httpServletResponse.setCharacterEncoding(UTF_8);
            httpServletResponse.setStatus(400);
            String writeValueAsString = this.objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint"));
            PrintWriter writer = httpServletResponse.getWriter();
            writer.write(writeValueAsString);
            writer.flush();
            return;
        }
        McpServerSession mcpServerSession = this.sessions.get(parameter);
        if (mcpServerSession == null) {
            httpServletResponse.setContentType("application/json");
            httpServletResponse.setCharacterEncoding(UTF_8);
            httpServletResponse.setStatus(404);
            String writeValueAsString2 = this.objectMapper.writeValueAsString(new McpError("Session not found: " + parameter));
            PrintWriter writer2 = httpServletResponse.getWriter();
            writer2.write(writeValueAsString2);
            writer2.flush();
            return;
        }
        try {
            BufferedReader reader = httpServletRequest.getReader();
            StringBuilder sb = new StringBuilder();
            while (true) {
                String readLine = reader.readLine();
                if (readLine == null) {
                    mcpServerSession.handle(McpSchema.deserializeJsonRpcMessage(this.objectMapper, sb.toString())).block();
                    httpServletResponse.setStatus(200);
                    return;
                }
                sb.append(readLine);
            }
        } catch (Exception e) {
            logger.error("Error processing message: {}", e.getMessage());
            try {
                McpError mcpError = new McpError(e.getMessage());
                httpServletResponse.setContentType("application/json");
                httpServletResponse.setCharacterEncoding(UTF_8);
                httpServletResponse.setStatus(500);
                String writeValueAsString3 = this.objectMapper.writeValueAsString(mcpError);
                PrintWriter writer3 = httpServletResponse.getWriter();
                writer3.write(writeValueAsString3);
                writer3.flush();
            } catch (IOException e2) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, e2.getMessage());
                httpServletResponse.sendError(500, "Error processing message");
            }
        }
    }

    @Override // io.modelcontextprotocol.spec.McpServerTransportProvider
    public Mono<Void> closeGracefully() {
        this.isClosing.set(true);
        logger.debug("Initiating graceful shutdown with {} active sessions", Integer.valueOf(this.sessions.size()));
        return Flux.fromIterable(this.sessions.values()).flatMap((v0) -> {
            return v0.closeGracefully();
        }).then();
    }

    private void sendEvent(PrintWriter printWriter, String str, String str2) throws IOException {
        printWriter.write("event: " + str + "\n");
        printWriter.write("data: " + str2 + "\n\n");
        printWriter.flush();
        if (printWriter.checkError()) {
            throw new IOException("Client disconnected");
        }
    }

    @Override // jakarta.servlet.GenericServlet, jakarta.servlet.Servlet
    public void destroy() {
        closeGracefully().block();
        super.destroy();
    }

    public static Builder builder() {
        return new Builder();
    }
}
