package io.modelcontextprotocol.client;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.client.McpClientFeatures;
import io.modelcontextprotocol.spec.ClientMcpTransport;
import io.modelcontextprotocol.spec.DefaultMcpSession;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransport;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:io/modelcontextprotocol/client/McpAsyncClient.class */
/**
 * Asynchronous, Project Reactor based client for the Model Context Protocol (MCP).
 *
 * <p>All traffic goes through a {@link DefaultMcpSession} built over the supplied transport. The
 * constructor registers:
 * <ul>
 * <li>a {@code roots/list} request handler when the client capabilities declare roots;</li>
 * <li>a sampling {@code createMessage} request handler when the client capabilities declare
 * sampling;</li>
 * <li>notification handlers for tools/resources/prompts list changes and for server log
 * messages.</li>
 * </ul>
 *
 * <p>{@link #initialize()} performs the handshake. This class does not itself enforce that it is
 * called before other requests.
 */
public class McpAsyncClient {

    private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class);

    // Jackson type tokens, created once and shared by every call that needs them.
    private static final TypeReference<Void> VOID_TYPE_REFERENCE = new TypeReference<Void>() {
    };
    private static final TypeReference<Object> OBJECT_TYPE_REF = new TypeReference<Object>() {
    };
    private static final TypeReference<String> STRING_TYPE_REF = new TypeReference<String>() {
    };
    private static final TypeReference<McpSchema.InitializeResult> INITIALIZE_RESULT_TYPE_REF = new TypeReference<McpSchema.InitializeResult>() {
    };
    private static final TypeReference<McpSchema.CreateMessageRequest> CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<McpSchema.CreateMessageRequest>() {
    };
    private static final TypeReference<McpSchema.LoggingMessageNotification> LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<McpSchema.LoggingMessageNotification>() {
    };
    private static final TypeReference<McpSchema.CallToolResult> CALL_TOOL_RESULT_TYPE_REF = new TypeReference<McpSchema.CallToolResult>() {
    };
    private static final TypeReference<McpSchema.ListToolsResult> LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<McpSchema.ListToolsResult>() {
    };
    private static final TypeReference<McpSchema.ListResourcesResult> LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<McpSchema.ListResourcesResult>() {
    };
    private static final TypeReference<McpSchema.ReadResourceResult> READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<McpSchema.ReadResourceResult>() {
    };
    private static final TypeReference<McpSchema.ListResourceTemplatesResult> LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<McpSchema.ListResourceTemplatesResult>() {
    };
    private static final TypeReference<McpSchema.ListPromptsResult> LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<McpSchema.ListPromptsResult>() {
    };
    private static final TypeReference<McpSchema.GetPromptResult> GET_PROMPT_RESULT_TYPE_REF = new TypeReference<McpSchema.GetPromptResult>() {
    };

    private final DefaultMcpSession mcpSession;
    private final McpSchema.ClientCapabilities clientCapabilities;
    private final McpSchema.Implementation clientInfo;

    // Written from the initialize() response callback and read through the getters; volatile so
    // readers on other threads observe the values once they are set.
    private volatile McpSchema.ServerCapabilities serverCapabilities;
    private volatile McpSchema.Implementation serverInfo;

    // Keyed by root URI; concurrent so addRoot/removeRoot and roots/list requests can interleave.
    private final ConcurrentHashMap<String, McpSchema.Root> roots;

    // Non-null only when the client capabilities declare sampling (enforced in the constructor).
    private final Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler;

    private final McpTransport transport;

    // Supported protocol versions; the LAST entry is the one requested in initialize().
    private List<String> protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);

    /**
     * Creates a client over the given transport and registers its request and notification
     * handlers with a new {@link DefaultMcpSession}.
     *
     * <p>NOTE(review): the decompiler reports that this constructor was package-private in the
     * original source; it is kept public here so existing callers continue to compile.
     *
     * @param clientMcpTransport transport used for the session and for (un)marshalling payloads
     * @param duration request timeout handed to the session
     * @param async client features: info, capabilities, initial roots, sampling handler, and
     * change/logging consumers
     * @throws McpError if sampling is declared in the capabilities but no sampling handler is set
     */
    public McpAsyncClient(ClientMcpTransport clientMcpTransport, Duration duration, McpClientFeatures.Async async) {
        Assert.notNull(clientMcpTransport, "Transport must not be null");
        Assert.notNull(duration, "Request timeout must not be null");
        Assert.notNull(async, "Client features must not be null");

        this.clientInfo = async.clientInfo();
        this.clientCapabilities = async.clientCapabilities();
        this.transport = clientMcpTransport;
        this.roots = new ConcurrentHashMap<>(async.roots());

        Map<String, DefaultMcpSession.RequestHandler<?>> requestHandlers = new HashMap<>();
        if (this.clientCapabilities.roots() != null) {
            requestHandlers.put(McpSchema.METHOD_ROOTS_LIST, rootsListRequestHandler());
        }
        if (this.clientCapabilities.sampling() != null) {
            if (async.samplingHandler() == null) {
                throw new McpError("Sampling handler must not be null when client capabilities include sampling");
            }
            this.samplingHandler = async.samplingHandler();
            requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler());
        }
        else {
            this.samplingHandler = null;
        }

        // Each notification type gets a built-in INFO-logging consumer first, then the user's.
        Map<String, DefaultMcpSession.NotificationHandler> notificationHandlers = new HashMap<>();
        notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED,
                asyncToolsChangeNotificationHandler(
                        withLoggingConsumer("Tools changed: {}", async.toolsChangeConsumers())));
        notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED,
                asyncResourcesChangeNotificationHandler(
                        withLoggingConsumer("Resources changed: {}", async.resourcesChangeConsumers())));
        notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED,
                asyncPromptsChangeNotificationHandler(
                        withLoggingConsumer("Prompts changed: {}", async.promptsChangeConsumers())));
        notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
                asyncLoggingNotificationHandler(withLoggingConsumer("Logging: {}", async.loggingConsumers())));

        this.mcpSession = new DefaultMcpSession(duration, clientMcpTransport, requestHandlers, notificationHandlers);
    }

    /**
     * Builds a new consumer list. The first element logs each value at INFO with
     * {@code logFormat}; any {@code userConsumers} follow in their original order.
     */
    private static <T> List<Function<T, Mono<Void>>> withLoggingConsumer(String logFormat,
            List<Function<T, Mono<Void>>> userConsumers) {
        List<Function<T, Mono<Void>>> consumers = new ArrayList<>();
        consumers.add(value -> Mono.fromRunnable(() -> logger.info(logFormat, value)));
        if (!Utils.isEmpty(userConsumers)) {
            consumers.addAll(userConsumers);
        }
        return consumers;
    }

    /**
     * Runs every consumer on {@code value} and completes when all of them have completed.
     *
     * <p>If any consumer fails, the error is logged with {@code errorMessage} and the returned
     * Mono completes empty rather than erroring.
     */
    private static <T> Mono<Void> notifyConsumers(List<Function<T, Mono<Void>>> consumers, T value,
            String errorMessage) {
        return Flux.fromIterable(consumers)
            .flatMap(consumer -> consumer.apply(value))
            .onErrorResume(error -> {
                logger.error(errorMessage, error);
                return Mono.empty();
            })
            .then();
    }

    /**
     * Sends the {@code initialize} request, requesting the last entry of the configured protocol
     * versions.
     *
     * <p>On a response, it records the server capabilities and info and logs them. It then checks
     * that the server's protocol version is one of the configured versions and, if so, sends the
     * {@code initialized} notification.
     *
     * @return the server's initialize result; errors with {@link McpError} if the server's
     * protocol version is not supported
     */
    public Mono<McpSchema.InitializeResult> initialize() {
        // Snapshot so the requested version and the acceptance check use the same list.
        List<String> supportedVersions = this.protocolVersions;
        String requestedVersion = supportedVersions.get(supportedVersions.size() - 1);
        McpSchema.InitializeRequest request = new McpSchema.InitializeRequest(requestedVersion,
                this.clientCapabilities, this.clientInfo);

        return this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, request, INITIALIZE_RESULT_TYPE_REF)
            .flatMap(initializeResult -> {
                this.serverCapabilities = initializeResult.capabilities();
                this.serverInfo = initializeResult.serverInfo();
                logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
                        initializeResult.protocolVersion(), initializeResult.capabilities(),
                        initializeResult.serverInfo(), initializeResult.instructions());
                if (!supportedVersions.contains(initializeResult.protocolVersion())) {
                    return Mono.error(new McpError(
                            "Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
                }
                return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
                    .thenReturn(initializeResult);
            });
    }

    /**
     * @return the server capabilities recorded from the last initialize response, or {@code null}
     * if none has been received
     */
    public McpSchema.ServerCapabilities getServerCapabilities() {
        return this.serverCapabilities;
    }

    /**
     * @return the server implementation info recorded from the last initialize response, or
     * {@code null} if none has been received
     */
    public McpSchema.Implementation getServerInfo() {
        return this.serverInfo;
    }

    /** @return the capabilities this client was configured with */
    public McpSchema.ClientCapabilities getClientCapabilities() {
        return this.clientCapabilities;
    }

    /** @return the implementation info this client was configured with */
    public McpSchema.Implementation getClientInfo() {
        return this.clientInfo;
    }

    /** Closes the underlying session by delegating to {@link DefaultMcpSession#close()}. */
    public void close() {
        this.mcpSession.close();
    }

    /** @return the session's graceful-close Mono ({@link DefaultMcpSession#closeGracefully()}) */
    public Mono<Void> closeGracefully() {
        return this.mcpSession.closeGracefully();
    }

    /** Sends a {@code ping} request with no parameters. */
    public Mono<Object> ping() {
        return this.mcpSession.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF);
    }

    /**
     * Registers a new root.
     *
     * <p>If the roots capability has {@code listChanged} set to true, this also sends a roots
     * list-changed notification.
     *
     * @param root root to add; its URI must not already be registered
     * @return an error Mono with {@link McpError} if {@code root} is null, the client has no roots
     * capability, or the URI already exists
     */
    public Mono<Void> addRoot(McpSchema.Root root) {
        if (root == null) {
            return Mono.error(new McpError("Root must not be null"));
        }
        if (this.clientCapabilities.roots() == null) {
            return Mono.error(new McpError("Client must be configured with roots capabilities"));
        }
        // Atomic check-and-insert: a separate containsKey/put would let two concurrent callers
        // both succeed, with the second silently overwriting the first.
        if (this.roots.putIfAbsent(root.uri(), root) != null) {
            return Mono.error(new McpError("Root with uri '" + root.uri() + "' already exists"));
        }
        logger.info("Added root: {}", root);
        return notifyRootsChangedIfEnabled();
    }

    /**
     * Removes the root registered under {@code str}.
     *
     * <p>If the roots capability has {@code listChanged} set to true, this also sends a roots
     * list-changed notification.
     *
     * @param str URI of the root to remove
     * @return an error Mono with {@link McpError} if the URI is null, the client has no roots
     * capability, or no root is registered under it
     */
    public Mono<Void> removeRoot(String str) {
        if (str == null) {
            return Mono.error(new McpError("Root uri must not be null"));
        }
        if (this.clientCapabilities.roots() == null) {
            return Mono.error(new McpError("Client must be configured with roots capabilities"));
        }
        if (this.roots.remove(str) == null) {
            return Mono.error(new McpError("Root with uri '" + str + "' not found"));
        }
        logger.info("Removed Root: {}", str);
        return notifyRootsChangedIfEnabled();
    }

    /**
     * Sends the roots list-changed notification only when the roots capability's
     * {@code listChanged} flag is {@code Boolean.TRUE}. A null flag is treated as false instead of
     * throwing.
     */
    private Mono<Void> notifyRootsChangedIfEnabled() {
        return Boolean.TRUE.equals(this.clientCapabilities.roots().listChanged()) ? rootsListChangedNotification()
                : Mono.empty();
    }

    /** Sends the roots list-changed notification. */
    public Mono<Void> rootsListChangedNotification() {
        return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED);
    }

    /** Handler that answers {@code roots/list} with a snapshot of the registered roots. */
    private DefaultMcpSession.RequestHandler<McpSchema.ListRootsResult> rootsListRequestHandler() {
        return params -> Mono.just(new McpSchema.ListRootsResult(this.roots.values().stream().toList()));
    }

    /**
     * Handler that unmarshals the request parameters into a {@code CreateMessageRequest} and
     * forwards them to the configured sampling handler.
     */
    private DefaultMcpSession.RequestHandler<McpSchema.CreateMessageResult> samplingCreateMessageHandler() {
        return params -> {
            McpSchema.CreateMessageRequest request = this.transport.unmarshalFrom(params,
                    CREATE_MESSAGE_REQUEST_TYPE_REF);
            return this.samplingHandler.apply(request);
        };
    }

    /** Sends a {@code tools/call} request. */
    public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
    }

    /** Lists tools with a {@code null} cursor. */
    public Mono<McpSchema.ListToolsResult> listTools() {
        return listTools(null);
    }

    /**
     * Sends a {@code tools/list} request.
     *
     * @param str cursor placed in a {@code PaginatedRequest}; may be null
     */
    public Mono<McpSchema.ListToolsResult> listTools(String str) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(str),
                LIST_TOOLS_RESULT_TYPE_REF);
    }

    /**
     * On a tools list-changed notification, re-fetches the tool list (no cursor) and passes it
     * to every consumer. Consumer errors are logged, not propagated.
     */
    private DefaultMcpSession.NotificationHandler asyncToolsChangeNotificationHandler(
            List<Function<List<McpSchema.Tool>, Mono<Void>>> list) {
        return params -> listTools().flatMap(
                result -> notifyConsumers(list, result.tools(), "Error handling tools list change notification"));
    }

    /** Lists resources with a {@code null} cursor. */
    public Mono<McpSchema.ListResourcesResult> listResources() {
        return listResources(null);
    }

    /**
     * Sends a {@code resources/list} request.
     *
     * @param str cursor placed in a {@code PaginatedRequest}; may be null
     */
    public Mono<McpSchema.ListResourcesResult> listResources(String str) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(str),
                LIST_RESOURCES_RESULT_TYPE_REF);
    }

    /** Reads a resource by building a {@code ReadResourceRequest} from its URI. */
    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resource) {
        return readResource(new McpSchema.ReadResourceRequest(resource.uri()));
    }

    /** Sends a {@code resources/read} request. */
    public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest,
                READ_RESOURCE_RESULT_TYPE_REF);
    }

    /** Lists resource templates with a {@code null} cursor. */
    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
        return listResourceTemplates(null);
    }

    /**
     * Sends a {@code resources/templates/list} request.
     *
     * @param str cursor placed in a {@code PaginatedRequest}; may be null
     */
    public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String str) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST,
                new McpSchema.PaginatedRequest(str), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
    }

    /**
     * Sends the resources list-changed notification.
     *
     * <p>NOTE(review): in the MCP specification this notification normally flows from server to
     * client. Confirm the intended use before relying on it.
     */
    public Mono<Void> sendResourcesListChanged() {
        return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED);
    }

    /** Sends a {@code resources/subscribe} request and ignores the response body. */
    public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest,
                VOID_TYPE_REFERENCE);
    }

    /** Sends a {@code resources/unsubscribe} request and ignores the response body. */
    public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest,
                VOID_TYPE_REFERENCE);
    }

    /**
     * On a resources list-changed notification, re-fetches the resource list (no cursor) and
     * passes it to every consumer. Consumer errors are logged, not propagated.
     */
    private DefaultMcpSession.NotificationHandler asyncResourcesChangeNotificationHandler(
            List<Function<List<McpSchema.Resource>, Mono<Void>>> list) {
        return params -> listResources().flatMap(result -> notifyConsumers(list, result.resources(),
                "Error handling resources list change notification"));
    }

    /** Lists prompts with a {@code null} cursor. */
    public Mono<McpSchema.ListPromptsResult> listPrompts() {
        return listPrompts(null);
    }

    /**
     * Sends a prompts list request.
     *
     * @param str cursor placed in a {@code PaginatedRequest}; may be null
     */
    public Mono<McpSchema.ListPromptsResult> listPrompts(String str) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_LIST, new McpSchema.PaginatedRequest(str),
                LIST_PROMPTS_RESULT_TYPE_REF);
    }

    /** Sends a prompt get request. */
    public Mono<McpSchema.GetPromptResult> getPrompt(McpSchema.GetPromptRequest getPromptRequest) {
        return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF);
    }

    /**
     * Sends the prompts list-changed notification.
     *
     * <p>NOTE(review): in the MCP specification this notification normally flows from server to
     * client. Confirm the intended use before relying on it.
     */
    public Mono<Void> promptListChangedNotification() {
        return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED);
    }

    /**
     * On a prompts list-changed notification, re-fetches the prompt list (no cursor) and passes
     * it to every consumer. Consumer errors are logged, not propagated.
     */
    private DefaultMcpSession.NotificationHandler asyncPromptsChangeNotificationHandler(
            List<Function<List<McpSchema.Prompt>, Mono<Void>>> list) {
        return params -> listPrompts().flatMap(result -> notifyConsumers(list, result.prompts(),
                "Error handling prompts list change notification"));
    }

    /**
     * Unmarshals an incoming log-message notification and passes it to every consumer.
     *
     * <p>Consumer errors are logged, not propagated, consistent with the list-change handlers.
     */
    private DefaultMcpSession.NotificationHandler asyncLoggingNotificationHandler(
            List<Function<McpSchema.LoggingMessageNotification, Mono<Void>>> list) {
        return params -> {
            McpSchema.LoggingMessageNotification notification = this.transport.unmarshalFrom(params,
                    LOGGING_MESSAGE_NOTIFICATION_TYPE_REF);
            return notifyConsumers(list, notification, "Error handling logging notification");
        };
    }

    /**
     * Asks the server to change its minimum logging level. The level is converted to its string
     * form through the transport and sent as {@code {"level": ...}}.
     *
     * <p>NOTE(review): the MCP specification defines {@code logging/setLevel} as a request that
     * expects a response, but this sends it as a notification. Switching to {@code sendRequest}
     * would time out against servers that do not reply, so confirm server behavior before
     * changing it.
     *
     * @param loggingLevel level to request; must not be null
     */
    public Mono<Void> setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
        Assert.notNull(loggingLevel, "Logging level must not be null");
        String levelName = this.transport.unmarshalFrom(loggingLevel, STRING_TYPE_REF);
        Map<String, Object> params = Map.of("level", levelName);
        return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params);
    }

    /**
     * Replaces the supported protocol versions. The last entry is the version requested by
     * {@link #initialize()}.
     *
     * @param list non-empty list of versions; it is copied, so later changes to it are ignored
     * @throws IllegalArgumentException if {@code list} is empty (an empty list would make
     * {@link #initialize()} fail with an index error)
     */
    void setProtocolVersions(List<String> list) {
        Assert.notNull(list, "Protocol versions must not be null");
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Protocol versions must not be empty");
        }
        this.protocolVersions = List.copyOf(list);
    }

}
