package org.noear.solon.ai.mcp.client;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.client.transport.StdioClientTransport;
import io.modelcontextprotocol.client.transport.WebRxSseClientTransport;
import io.modelcontextprotocol.client.transport.WebRxStreamableClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import org.noear.snack.ONode;
import org.noear.solon.Utils;
import org.noear.solon.ai.chat.message.ChatMessage;
import org.noear.solon.ai.chat.tool.FunctionTool;
import org.noear.solon.ai.chat.tool.FunctionToolDesc;
import org.noear.solon.ai.chat.tool.ToolProvider;
import org.noear.solon.ai.mcp.McpChannel;
import org.noear.solon.ai.mcp.exception.McpException;
import org.noear.solon.ai.mcp.server.prompt.FunctionPrompt;
import org.noear.solon.ai.mcp.server.prompt.FunctionPromptDesc;
import org.noear.solon.ai.mcp.server.prompt.PromptProvider;
import org.noear.solon.ai.mcp.server.resource.FunctionResource;
import org.noear.solon.ai.mcp.server.resource.FunctionResourceDesc;
import org.noear.solon.ai.mcp.server.resource.ResourceProvider;
import org.noear.solon.ai.media.Image;
import org.noear.solon.ai.media.Text;
import org.noear.solon.core.Props;
import org.noear.solon.core.util.RunUtil;
import org.noear.solon.data.cache.LocalCacheService;
import org.noear.solon.net.http.HttpTimeout;
import org.noear.solon.net.http.HttpUtilsBuilder;

/* loaded from: input_file:org/noear/solon/ai/mcp/client/McpClientProvider.class */
/**
 * Client-side MCP provider.
 *
 * <p>Lazily builds an {@link McpSyncClient} from {@link McpClientProperties}. The transport is chosen by
 * the configured channel: stdio, streamable HTTP, or SSE otherwise. The client is initialized on first
 * use. The remote server's tools, resources, resource templates and prompts are exposed through the
 * {@link ToolProvider}, {@link ResourceProvider} and {@link PromptProvider} interfaces. Listing results
 * are cached in a {@link LocalCacheService} for the configured {@code cacheSeconds}.
 *
 * <p>Creating, initializing and discarding the client is guarded by {@link #LOCKER}. When a heartbeat
 * interval is configured, a background daemon thread periodically pings the server. It drops the client
 * on failure so that the next call reconnects.
 */
public class McpClientProvider implements ToolProvider, ResourceProvider, PromptProvider, Closeable {
    /** Guards creation, initialization and disposal of {@link #client} and the heartbeat executor. */
    private final ReentrantLock LOCKER;
    /** True after {@link #close()} until {@link #reopen()}. */
    private final AtomicBoolean isClosed;
    /** Set once {@link #getClient()} has been used; the heartbeat only pings after that. */
    private final AtomicBoolean isStarted;
    private final McpClientProperties clientProps;
    /** Created only when a heartbeat interval is configured; null otherwise and after close. */
    private ScheduledExecutorService heartbeatExecutor;
    private McpSyncClient client;
    private McpSchema.LoggingLevel loggingLevel;
    private LocalCacheService localCache;

    /** Fluent builder for {@link McpClientProvider}; each setter writes into a {@link McpClientProperties}. */
    public static class Builder {
        private final McpClientProperties props = new McpClientProperties();

        public Builder name(String str) {
            this.props.setName(str);
            return this;
        }

        public Builder version(String str) {
            this.props.setVersion(str);
            return this;
        }

        public Builder channel(String str) {
            this.props.setChannel(str);
            return this;
        }

        public Builder apiUrl(String str) {
            this.props.setApiUrl(str);
            return this;
        }

        /** Sets the key sent as {@code Authorization: Bearer <key>} on HTTP channels. */
        public Builder apiKey(String str) {
            this.props.setApiKey(str);
            return this;
        }

        public Builder headerSet(String str, String str2) {
            this.props.getHeaders().put(str, str2);
            return this;
        }

        /** Adds all given headers; a null or empty map is ignored. */
        public Builder headerSet(Map<String, String> map) {
            if (Utils.isNotEmpty(map)) {
                this.props.getHeaders().putAll(map);
            }
            return this;
        }

        public Builder httpTimeout(HttpTimeout httpTimeout) {
            this.props.setHttpTimeout(httpTimeout);
            return this;
        }

        public Builder httpProxy(Proxy proxy) {
            this.props.setHttpProxy(proxy);
            return this;
        }

        /** Convenience for an HTTP proxy at {@code host:port}. */
        public Builder httpProxy(String str, int i) {
            return httpProxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress(str, i)));
        }

        public Builder requestTimeout(Duration duration) {
            this.props.setRequestTimeout(duration);
            return this;
        }

        public Builder initializationTimeout(Duration duration) {
            this.props.setInitializationTimeout(duration);
            return this;
        }

        /** Sets the ping interval; the provider constructor rejects values below 10 seconds. */
        public Builder heartbeatInterval(Duration duration) {
            this.props.setHeartbeatInterval(duration);
            return this;
        }

        public Builder cacheSeconds(int i) {
            this.props.setCacheSeconds(i);
            return this;
        }

        /** Sets the process parameters used by the stdio channel. */
        public Builder serverParameters(McpServerParameters mcpServerParameters) {
            this.props.setServerParameters(mcpServerParameters);
            return this;
        }

        public McpClientProvider build() {
            return new McpClientProvider(this.props);
        }
    }

    /** Creates a provider by binding {@code properties} onto a new {@link McpClientProperties}. */
    public McpClientProvider(Properties properties) {
        this((McpClientProperties) Props.from(properties).bindTo(new McpClientProperties()));
    }

    /** Creates a provider from the given string, passed to {@code new McpClientProperties(String)}. */
    public McpClientProvider(String str) {
        this(new McpClientProperties(str));
    }

    /**
     * Creates a provider and, when a heartbeat interval is configured, starts the heartbeat.
     *
     * @throws IllegalArgumentException if the heartbeat interval is below 10s, if the stdio channel has no
     *     server parameters, or if a non-stdio channel has no api url
     */
    public McpClientProvider(McpClientProperties mcpClientProperties) {
        this.LOCKER = new ReentrantLock();
        this.isClosed = new AtomicBoolean(false);
        this.isStarted = new AtomicBoolean(false);
        this.loggingLevel = McpSchema.LoggingLevel.INFO;
        this.localCache = new LocalCacheService();
        if (mcpClientProperties.getHeartbeatInterval() != null && mcpClientProperties.getHeartbeatInterval().getSeconds() < 10) {
            throw new IllegalArgumentException("HeartbeatInterval cannot be less than 10s!");
        }
        if (McpChannel.STDIO.equals(mcpClientProperties.getChannel())) {
            if (mcpClientProperties.getServerParameters() == null) {
                throw new IllegalArgumentException("ServerParameters is null!");
            }
        } else if (Utils.isEmpty(mcpClientProperties.getApiUrl())) {
            throw new IllegalArgumentException("ApiUrl is empty!");
        }
        this.clientProps = mcpClientProperties;
        heartbeatHandle();
    }

    /** Drops all cached tool/resource/template/prompt listings. */
    public void clearCache() {
        this.localCache.clear();
    }

    /** Builds (but does not initialize) a sync client over the transport selected by the channel. */
    private McpSyncClient buildClient() {
        McpClientTransport build;
        if (McpChannel.STDIO.equals(this.clientProps.getChannel())) {
            build = new StdioClientTransport(ServerParameters.builder(this.clientProps.getServerParameters().getCommand()).args(this.clientProps.getServerParameters().getArgs()).env(this.clientProps.getServerParameters().getEnv()).build());
        } else {
            // The api url is split into base uri (scheme + authority) and endpoint path (+ query).
            URI create = URI.create(this.clientProps.getApiUrl());
            String str = create.getScheme() + "://" + create.getAuthority();
            String rawPath = Utils.isEmpty(create.getRawQuery()) ? create.getRawPath() : create.getRawPath() + "?" + create.getRawQuery();
            if (Utils.isEmpty(rawPath)) {
                throw new IllegalArgumentException("SseEndpoint is empty!");
            }
            HttpUtilsBuilder httpUtilsBuilder = new HttpUtilsBuilder();
            httpUtilsBuilder.baseUri(str);
            if (Utils.isNotEmpty(this.clientProps.getApiKey())) {
                httpUtilsBuilder.headerSet("Authorization", "Bearer " + this.clientProps.getApiKey());
            }
            this.clientProps.getHeaders().forEach((str2, str3) -> {
                httpUtilsBuilder.headerSet(str2, str3);
            });
            if (this.clientProps.getHttpTimeout() != null) {
                httpUtilsBuilder.timeout(this.clientProps.getHttpTimeout());
            }
            if (this.clientProps.getHttpProxy() != null) {
                httpUtilsBuilder.proxy(this.clientProps.getHttpProxy());
            }
            build = McpChannel.STREAMABLE.equals(this.clientProps.getChannel()) ? WebRxStreamableClientTransport.builder(httpUtilsBuilder).endpoint(rawPath).build() : WebRxSseClientTransport.builder(httpUtilsBuilder).sseEndpoint(rawPath).build();
        }
        return McpClient.sync(build).clientInfo(new McpSchema.Implementation(this.clientProps.getName(), this.clientProps.getVersion())).requestTimeout(this.clientProps.getRequestTimeout()).initializationTimeout(this.clientProps.getInitializationTimeout()).loggingConsumer(loggingMessageNotification -> {
            loggingMessageNotification.setLevel(this.loggingLevel);
        }).build();
    }

    /**
     * Returns the shared client, building and initializing it on first use (or after a reset).
     *
     * @throws IllegalStateException if this provider has been closed
     */
    public McpSyncClient getClient() {
        this.LOCKER.lock();
        try {
            if (this.isClosed.get()) {
                throw new IllegalStateException("The current status has been closed.");
            }
            this.isStarted.set(true);
            if (this.client == null) {
                this.client = buildClient();
            }
            if (!this.client.isInitialized()) {
                this.client.initialize();
            }
            return this.client;
        } finally {
            this.LOCKER.unlock();
        }
    }

    /** Sets the level applied by the client's logging consumer; null is ignored. */
    public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
        if (loggingLevel != null) {
            this.loggingLevel = loggingLevel;
        }
    }

    /**
     * Starts the heartbeat if an interval is configured.
     *
     * <p>The executor is created lazily and uses a daemon thread, so an un-closed provider does not keep
     * the JVM alive.
     */
    private void heartbeatHandle() {
        if (this.clientProps.getHeartbeatInterval() == null) {
            return;
        }
        if (this.heartbeatExecutor == null) {
            this.heartbeatExecutor = Executors.newSingleThreadScheduledExecutor(runnable -> {
                Thread thread = new Thread(runnable, "mcp-client-heartbeat");
                thread.setDaemon(true);
                return thread;
            });
        }
        heartbeatHandleDo();
    }

    /**
     * Schedules {@link #heartbeatTick()} at a fixed delay on the current executor.
     *
     * <p>A single fixed-delay schedule replaces a self-rescheduling task. The task therefore never reads
     * {@link #heartbeatExecutor} from the heartbeat thread, and {@code shutdownNow()} in {@link #close()}
     * cancels it. This avoids duplicate ping chains across close/reopen.
     */
    private void heartbeatHandleDo() {
        ScheduledExecutorService executor = this.heartbeatExecutor;
        Duration interval = this.clientProps.getHeartbeatInterval();
        if (executor == null || interval == null) {
            return;
        }
        long intervalMillis = interval.toMillis();
        executor.scheduleWithFixedDelay(this::heartbeatTick, intervalMillis, intervalMillis, TimeUnit.MILLISECONDS);
    }

    /** One heartbeat: pings the server once the client has been started; on any failure drops the client. */
    private void heartbeatTick() {
        if (Thread.currentThread().isInterrupted() || this.isClosed.get()) {
            return;
        }
        if (this.isStarted.get()) {
            RunUtil.runAndTry(() -> {
                try {
                    getClient().ping();
                } catch (Throwable th) {
                    // Deliberate best-effort: the next getClient() call rebuilds the connection.
                    reset();
                }
            });
        }
    }

    /** Stops the heartbeat and closes the client. Idempotent; {@link #reopen()} reverses it. */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        this.LOCKER.lock();
        try {
            if (!this.isClosed.get()) {
                this.isClosed.set(true);
                this.isStarted.set(false);
                if (this.heartbeatExecutor != null) {
                    this.heartbeatExecutor.shutdownNow();
                    this.heartbeatExecutor = null;
                }
                reset();
            }
        } finally {
            this.LOCKER.unlock();
        }
    }

    /** Reopens a closed provider: restarts the heartbeat and reconnects. No-op if not closed. */
    public void reopen() {
        this.LOCKER.lock();
        try {
            if (this.isClosed.get()) {
                this.isClosed.set(false);
                // Start the heartbeat first so a failed reconnect below does not leave it stopped.
                heartbeatHandle();
                getClient();
            }
        } finally {
            this.LOCKER.unlock();
        }
    }

    /** Closes and discards the current client (if any); the next {@link #getClient()} rebuilds it. */
    private void reset() {
        this.LOCKER.lock();
        try {
            if (this.client != null) {
                this.client.close();
                this.client = null;
            }
        } finally {
            this.LOCKER.unlock();
        }
    }

    /**
     * Calls a tool and returns its first content item as text.
     *
     * @return the text, or null when the result has no content
     * @throws IllegalArgumentException if the first content item is not text
     */
    public Text callToolAsText(String str, Map<String, Object> map) {
        McpSchema.CallToolResult callTool = callTool(str, map);
        if (Utils.isEmpty(callTool.getContent())) {
            return null;
        }
        McpSchema.Content content = callTool.getContent().get(0);
        if (content instanceof McpSchema.TextContent) {
            return Text.of(false, ((McpSchema.TextContent) content).getText());
        }
        throw new IllegalArgumentException("The tool result content is not a text content.");
    }

    /**
     * Calls a tool and returns its first content item as a base64 image.
     *
     * @return the image, or null when the result has no content
     * @throws IllegalArgumentException if the first content item is not an image
     */
    public Image callToolAsImage(String str, Map<String, Object> map) {
        McpSchema.CallToolResult callTool = callTool(str, map);
        if (Utils.isEmpty(callTool.getContent())) {
            return null;
        }
        McpSchema.Content content = callTool.getContent().get(0);
        if (!(content instanceof McpSchema.ImageContent)) {
            throw new IllegalArgumentException("The tool result content is not a image content.");
        }
        McpSchema.ImageContent imageContent = (McpSchema.ImageContent) content;
        return Image.ofBase64(imageContent.getData(), imageContent.getMimeType());
    }

    /**
     * Calls a tool on the server.
     *
     * <p>A transport/protocol failure (RuntimeException from the client) resets the connection and is
     * rethrown. A tool-level error ({@code isError == true}) is a normal response. It is raised as
     * {@link McpException} without tearing down the shared connection.
     *
     * @throws McpException if the tool reports an error
     */
    public McpSchema.CallToolResult callTool(String str, Map<String, Object> map) {
        McpSchema.CallToolResult callTool;
        try {
            callTool = getClient().callTool(new McpSchema.CallToolRequest(str, map));
        } catch (RuntimeException e) {
            reset();
            throw e;
        }
        if (!Boolean.TRUE.equals(callTool.getIsError())) {
            return callTool;
        }
        if (Utils.isEmpty(callTool.getContent())) {
            throw new McpException("Call Tool Failed");
        }
        McpSchema.Content first = callTool.getContent().get(0);
        if (first instanceof McpSchema.TextContent) {
            // Prefer the error text itself over the content object's toString().
            throw new McpException(((McpSchema.TextContent) first).getText());
        }
        throw new McpException(first.toString());
    }

    /**
     * Reads a resource and returns its first content item as text. Blob contents are returned with the
     * base64 flag set.
     */
    public Text readResourceAsText(String str) {
        McpSchema.ReadResourceResult readResource = readResource(str);
        if (Utils.isEmpty(readResource.getContents())) {
            return null;
        }
        McpSchema.ResourceContents resourceContents = readResource.getContents().get(0);
        if (resourceContents instanceof McpSchema.TextResourceContents) {
            McpSchema.TextResourceContents textResourceContents = (McpSchema.TextResourceContents) resourceContents;
            return Text.of(false, textResourceContents.getText(), textResourceContents.getMimeType());
        }
        McpSchema.BlobResourceContents blobResourceContents = (McpSchema.BlobResourceContents) resourceContents;
        return Text.of(true, blobResourceContents.getBlob(), blobResourceContents.getMimeType());
    }

    /**
     * Reads a resource by uri.
     *
     * <p>Transport/protocol failures reset the connection. An empty result is reported as
     * {@link McpException} without resetting it.
     *
     * @throws McpException if the server returns no contents
     */
    public McpSchema.ReadResourceResult readResource(String str) {
        McpSchema.ReadResourceResult readResource;
        try {
            readResource = getClient().readResource(new McpSchema.ReadResourceRequest(str));
        } catch (RuntimeException e) {
            reset();
            throw e;
        }
        if (Utils.isEmpty(readResource.getContents())) {
            throw new McpException("Read resource Failed");
        }
        return readResource;
    }

    /**
     * Fetches a prompt and converts its messages into chat messages.
     *
     * <p>Assistant messages keep only text content. Other roles become user messages with text or an
     * image; image data containing {@code "://"} is treated as a url, otherwise as base64. Other content
     * types are skipped.
     */
    public List<ChatMessage> getPromptAsMessages(String str, Map<String, Object> map) {
        List<ChatMessage> messages = new ArrayList<>();
        for (McpSchema.PromptMessage promptMessage : getPrompt(str, map).getMessages()) {
            McpSchema.Content content = promptMessage.getContent();
            if (promptMessage.getRole() == McpSchema.Role.ASSISTANT) {
                if (content instanceof McpSchema.TextContent) {
                    messages.add(ChatMessage.ofAssistant(((McpSchema.TextContent) content).getText()));
                }
            } else if (content instanceof McpSchema.TextContent) {
                messages.add(ChatMessage.ofUser(((McpSchema.TextContent) content).getText()));
            } else if (content instanceof McpSchema.ImageContent) {
                McpSchema.ImageContent imageContent = (McpSchema.ImageContent) content;
                String data = imageContent.getData();
                if (data.contains("://")) {
                    messages.add(ChatMessage.ofUser(Image.ofUrl(data)));
                } else {
                    messages.add(ChatMessage.ofUser(Image.ofBase64(data, imageContent.getMimeType())));
                }
            }
        }
        return messages;
    }

    /**
     * Fetches a prompt by name.
     *
     * <p>Transport/protocol failures reset the connection. An empty result is reported as
     * {@link McpException} without resetting it.
     *
     * @throws McpException if the server returns no messages
     */
    public McpSchema.GetPromptResult getPrompt(String str, Map<String, Object> map) {
        McpSchema.GetPromptResult prompt;
        try {
            prompt = getClient().getPrompt(new McpSchema.GetPromptRequest(str, map));
        } catch (RuntimeException e) {
            reset();
            throw e;
        }
        if (Utils.isEmpty(prompt.getMessages())) {
            throw new McpException("Get prompt Failed");
        }
        return prompt;
    }

    public Collection<FunctionTool> getTools() {
        return getTools(null);
    }

    /** Lists tools (optionally from a pagination cursor), cached per cursor for {@code cacheSeconds}. */
    public Collection<FunctionTool> getTools(String str) {
        return (Collection) this.localCache.getOrStore("getTools:" + str, Collection.class, this.clientProps.getCacheSeconds(), () -> {
            return getToolsDo(str);
        });
    }

    /** Wraps each remote tool as a {@link FunctionToolDesc} whose handler calls it and returns its text. */
    private Collection<FunctionTool> getToolsDo(String str) {
        List<FunctionTool> tools = new ArrayList<>();
        for (McpSchema.Tool tool : (str == null ? getClient().listTools() : getClient().listTools(str)).getTools()) {
            String name = tool.getName();
            tools.add(new FunctionToolDesc(name, tool.getDescription(), tool.getReturnDirect(), ONode.load(tool.getInputSchema()).toJson(), tool.getOutputSchema() == null ? null : ONode.load(tool.getOutputSchema()).toJson(), map -> {
                Text text = callToolAsText(name, map);
                // callToolAsText returns null when the tool produced no content; avoid an NPE here.
                return text == null ? "" : text.getContent();
            }));
        }
        return tools;
    }

    @Override // org.noear.solon.ai.mcp.server.resource.ResourceProvider
    public Collection<FunctionResource> getResources() {
        return getResources(null);
    }

    /** Lists resources (optionally from a pagination cursor), cached per cursor for {@code cacheSeconds}. */
    public Collection<FunctionResource> getResources(String str) {
        return (Collection) this.localCache.getOrStore("getResources:" + str, Collection.class, this.clientProps.getCacheSeconds(), () -> {
            return getResourcesDo(str);
        });
    }

    /** Wraps each remote resource as a {@link FunctionResourceDesc} whose handler reads it as text. */
    private Collection<FunctionResource> getResourcesDo(String str) {
        List<FunctionResource> resources = new ArrayList<>();
        for (McpSchema.Resource resource : (str == null ? getClient().listResources() : getClient().listResources(str)).getResources()) {
            FunctionResourceDesc functionResourceDesc = new FunctionResourceDesc(resource.getName());
            functionResourceDesc.description(resource.getDescription());
            functionResourceDesc.uri(resource.getUri());
            functionResourceDesc.mimeType(resource.getMimeType());
            functionResourceDesc.doHandle(str2 -> {
                return readResourceAsText(str2);
            });
            resources.add(functionResourceDesc);
        }
        return resources;
    }

    public Collection<FunctionResource> getResourceTemplates() {
        return getResourceTemplates(null);
    }

    /** Lists resource templates (optionally from a cursor), cached per cursor for {@code cacheSeconds}. */
    public Collection<FunctionResource> getResourceTemplates(String str) {
        return (Collection) this.localCache.getOrStore("getResourceTemplates:" + str, Collection.class, this.clientProps.getCacheSeconds(), () -> {
            return getResourceTemplatesDo(str);
        });
    }

    /** Wraps each resource template as a {@link FunctionResourceDesc} using the template as its uri. */
    private Collection<FunctionResource> getResourceTemplatesDo(String str) {
        List<FunctionResource> templates = new ArrayList<>();
        for (McpSchema.ResourceTemplate resourceTemplate : (str == null ? getClient().listResourceTemplates() : getClient().listResourceTemplates(str)).getResourceTemplates()) {
            FunctionResourceDesc functionResourceDesc = new FunctionResourceDesc(resourceTemplate.getName());
            functionResourceDesc.description(resourceTemplate.getDescription());
            functionResourceDesc.uri(resourceTemplate.getUriTemplate());
            functionResourceDesc.mimeType(resourceTemplate.getMimeType());
            functionResourceDesc.doHandle(str2 -> {
                return readResourceAsText(str2);
            });
            templates.add(functionResourceDesc);
        }
        return templates;
    }

    @Override // org.noear.solon.ai.mcp.server.prompt.PromptProvider
    public Collection<FunctionPrompt> getPrompts() {
        return getPrompts(null);
    }

    /** Lists prompts (optionally from a pagination cursor), cached per cursor for {@code cacheSeconds}. */
    public Collection<FunctionPrompt> getPrompts(String str) {
        return (Collection) this.localCache.getOrStore("getPrompts:" + str, Collection.class, this.clientProps.getCacheSeconds(), () -> {
            return getPromptsDo(str);
        });
    }

    /**
     * Wraps each remote prompt as a {@link FunctionPromptDesc}.
     *
     * <p>A prompt's {@code arguments} and an argument's {@code required} flag may be absent; a missing
     * argument list adds no params, and a missing flag is treated as not required.
     */
    private Collection<FunctionPrompt> getPromptsDo(String str) {
        List<FunctionPrompt> prompts = new ArrayList<>();
        for (McpSchema.Prompt prompt : (str == null ? getClient().listPrompts() : getClient().listPrompts(str)).getPrompts()) {
            String name = prompt.getName();
            FunctionPromptDesc functionPromptDesc = new FunctionPromptDesc(name);
            functionPromptDesc.description(prompt.getDescription());
            if (prompt.getArguments() != null) {
                for (McpSchema.PromptArgument promptArgument : prompt.getArguments()) {
                    functionPromptDesc.paramAdd(promptArgument.getName(), Boolean.TRUE.equals(promptArgument.getRequired()), promptArgument.getDescription());
                }
            }
            functionPromptDesc.doHandle(map -> {
                return getPromptAsMessages(name, map);
            });
            prompts.add(functionPromptDesc);
        }
        return prompts;
    }

    /** @deprecated use {@code McpProviders.fromMcpServers(str).getProviders()} instead. */
    @Deprecated
    public static Map<String, McpClientProvider> fromMcpServers(String str) throws IOException {
        return McpProviders.fromMcpServers(str).getProviders();
    }

    public static Builder builder() {
        return new Builder();
    }
}
