package dev.langchain4j.community.store.memory.chat.neo4j;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.neo4j.cypherdsl.core.Cypher;
import org.neo4j.cypherdsl.core.Expression;
import org.neo4j.cypherdsl.core.FunctionInvocation;
import org.neo4j.cypherdsl.core.IdentifiableElement;
import org.neo4j.cypherdsl.core.Named;
import org.neo4j.cypherdsl.core.Node;
import org.neo4j.cypherdsl.core.PatternElement;
import org.neo4j.cypherdsl.core.Relationship;
import org.neo4j.cypherdsl.core.Statement;
import org.neo4j.cypherdsl.core.StatementBuilder;
import org.neo4j.cypherdsl.core.renderer.Renderer;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.exceptions.Neo4jException;

/* loaded from: input_file:dev/langchain4j/community/store/memory/chat/neo4j/Neo4jChatMemoryStore.class */
/**
 * {@link ChatMemoryStore} implementation that keeps chat messages in a Neo4j graph.
 *
 * <p>Graph layout, as written and read by the queries in this class:
 * <ul>
 *   <li>one "memory" node per memory id, labelled {@code memoryLabel} and identified by
 *       {@code idProperty};</li>
 *   <li>one node per stored message, labelled {@code messageLabel}, carrying the JSON form of the
 *       message in {@code messageProperty};</li>
 *   <li>{@code nextMessageRelType} relationships between message nodes;</li>
 *   <li>a {@code lastMessageRelType} relationship from the memory node to the most recently
 *       written message node.</li>
 * </ul>
 *
 * <p>{@link #updateMessages(Object, List)} calls the {@code apoc.create.nodes} and
 * {@code apoc.nodes.link} procedures, so the APOC plugin must be available on the server.
 *
 * <p>This class never closes the {@link Driver} it is given; every operation opens and closes its
 * own {@link Session}.
 */
public class Neo4jChatMemoryStore implements ChatMemoryStore {
    /** Label of the per-conversation memory node when none is configured. */
    public static final String DEFAULT_MEMORY_LABEL = "Memory";
    /** Label of message nodes when none is configured. */
    public static final String DEFAULT_MESSAGE_LABEL = "Message";
    /** Type of the memory-to-latest-message relationship when none is configured. */
    public static final String DEFAULT_LAST_REL_TYPE = "LAST_MESSAGE";
    /** Type of the message-to-message relationship when none is configured. */
    public static final String DEFAULT_REL_TYPE_NEXT = "NEXT";
    /** Name of the memory node's id property when none is configured. */
    public static final String DEFAULT_ID_PROP = "id";
    /** Name of the message node's payload property when none is configured. */
    public static final String DEFAULT_MESSAGE_PROP = "message";
    /** Database used to build the default {@link SessionConfig}. */
    public static final String DEFAULT_DATABASE_NAME = "neo4j";
    /** Default upper bound on the number of relationships followed when reading history. */
    public static final int DEFAULT_SIZE_VALUE = 10;

    private final Driver driver;
    private final SessionConfig config;
    private final String memoryLabel;
    private final String messageLabel;
    private final String lastMessageRelType;
    private final String nextMessageRelType;
    private final String idProperty;
    private final String messageProperty;
    private final int size;

    /**
     * Fluent builder for {@link Neo4jChatMemoryStore}. Every setting except the driver is
     * optional; unset values fall back to the {@code DEFAULT_*} constants of the store.
     */
    public static class Builder {
        private Driver driver;
        private SessionConfig config;
        private String memoryLabel;
        private String messageLabel;
        private String lastMessageRelType;
        private String nextMessageRelType;
        private String idProperty;
        private String messageProperty;
        private String databaseName;
        private Integer size;

        /**
         * @param driver the Neo4j driver used to open sessions (required)
         * @return this builder
         */
        public Builder driver(Driver driver) {
            this.driver = driver;
            return this;
        }

        /**
         * @param config session configuration; when set, {@link #databaseName(String)} is not used
         * @return this builder
         */
        public Builder config(SessionConfig config) {
            this.config = config;
            return this;
        }

        /**
         * @param memoryLabel label of the memory node
         * @return this builder
         */
        public Builder memoryLabel(String memoryLabel) {
            this.memoryLabel = memoryLabel;
            return this;
        }

        /**
         * @param messageLabel label applied to message nodes
         * @return this builder
         */
        public Builder messageLabel(String messageLabel) {
            this.messageLabel = messageLabel;
            return this;
        }

        /**
         * @param idProperty name of the memory node property holding the memory id
         * @return this builder
         */
        public Builder idProperty(String idProperty) {
            this.idProperty = idProperty;
            return this;
        }

        /**
         * @param messageProperty name of the message node property holding the serialized message
         * @return this builder
         */
        public Builder messageProperty(String messageProperty) {
            this.messageProperty = messageProperty;
            return this;
        }

        /**
         * @param lastMessageRelType type of the relationship from the memory node to the latest message
         * @return this builder
         */
        public Builder lastMessageRelType(String lastMessageRelType) {
            this.lastMessageRelType = lastMessageRelType;
            return this;
        }

        /**
         * @param nextMessageRelType type of the relationship between message nodes
         * @return this builder
         */
        public Builder nextMessageRelType(String nextMessageRelType) {
            this.nextMessageRelType = nextMessageRelType;
            return this;
        }

        /**
         * @param databaseName database used for the default session configuration
         * @return this builder
         */
        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /**
         * @param size maximum number of relationships followed when reading history; a value
         *             below 1 removes the upper bound
         * @return this builder
         */
        public Builder size(Integer size) {
            this.size = size;
            return this;
        }

        /**
         * Creates a new driver with basic authentication and uses it for the store, replacing any
         * driver set earlier on this builder.
         *
         * @param uri      connection URI passed to {@link GraphDatabase#driver}
         * @param user     user name
         * @param password password
         * @return this builder
         */
        public Builder withBasicAuth(String uri, String user, String password) {
            this.driver = GraphDatabase.driver(uri, AuthTokens.basic(user, password));
            return this;
        }

        /**
         * @return a store configured from this builder
         */
        public Neo4jChatMemoryStore build() {
            return new Neo4jChatMemoryStore(
                    driver,
                    config,
                    memoryLabel,
                    messageLabel,
                    lastMessageRelType,
                    nextMessageRelType,
                    idProperty,
                    messageProperty,
                    databaseName,
                    size);
        }
    }

    /**
     * Creates a store. Every argument except {@code driver} may be {@code null}, in which case
     * the matching {@code DEFAULT_*} constant is used.
     *
     * @param driver             Neo4j driver; must not be {@code null}
     * @param config             session configuration; defaults to one targeting {@code databaseName}
     * @param memoryLabel        label of the memory node
     * @param messageLabel       label of message nodes
     * @param lastMessageRelType relationship type from the memory node to the latest message
     * @param nextMessageRelType relationship type between message nodes
     * @param idProperty         memory node property holding the memory id
     * @param messageProperty    message node property holding the serialized message
     * @param databaseName       database for the default session configuration
     * @param size               maximum number of relationships followed when reading history
     */
    public Neo4jChatMemoryStore(
            Driver driver,
            SessionConfig config,
            String memoryLabel,
            String messageLabel,
            String lastMessageRelType,
            String nextMessageRelType,
            String idProperty,
            String messageProperty,
            String databaseName,
            Integer size) {
        this.driver = ValidationUtils.ensureNotNull(driver, "driver");
        this.config = Utils.getOrDefault(
                config, SessionConfig.forDatabase(Utils.getOrDefault(databaseName, DEFAULT_DATABASE_NAME)));
        this.memoryLabel = Utils.getOrDefault(memoryLabel, DEFAULT_MEMORY_LABEL);
        this.messageLabel = Utils.getOrDefault(messageLabel, DEFAULT_MESSAGE_LABEL);
        this.lastMessageRelType = Utils.getOrDefault(lastMessageRelType, DEFAULT_LAST_REL_TYPE);
        this.nextMessageRelType = Utils.getOrDefault(nextMessageRelType, DEFAULT_REL_TYPE_NEXT);
        this.idProperty = Utils.getOrDefault(idProperty, DEFAULT_ID_PROP);
        this.messageProperty = Utils.getOrDefault(messageProperty, DEFAULT_MESSAGE_PROP);
        this.size = Utils.getOrDefault(size, DEFAULT_SIZE_VALUE);
    }

    /**
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Ensures the memory node for {@code memoryId} exists by running a {@code MERGE} on it.
     *
     * @param memoryId value stored in the memory node's id property
     * @throws RuntimeException wrapping any {@link Neo4jException} raised by the driver
     */
    private void createSessionNode(Object memoryId) {
        // try-with-resources: the session must be released even when run() throws.
        try (Session session = session()) {
            String query = Cypher.merge(
                            Cypher.node(memoryLabel).withProperties(idProperty, Cypher.parameter("memoryId")))
                    .build()
                    .getCypher();
            Map<String, Object> params = Map.of("label", memoryLabel, "window", size, "memoryId", memoryId);
            session.run(query, params);
        } catch (Neo4jException e) {
            getDescriptiveProcedureNotFoundError(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads the messages stored for a memory id by running {@link #buildHistoryQuery()}.
     * Rows whose message property is {@code null} are skipped.
     *
     * @param memoryId memory id; its {@code toString()} must not be blank
     * @return the deserialized messages in the order returned by the history query
     * @throws IllegalArgumentException if {@code memoryId} is {@code null} or blank
     * @throws RuntimeException         wrapping any {@link Neo4jException} raised by the driver
     */
    public List<ChatMessage> getMessages(Object memoryId) {
        String memoryIdString = toMemoryIdString(memoryId);
        // The stream is fully collected inside the try block, before the session is closed.
        try (Session session = session()) {
            String query = Renderer.getDefaultRenderer().render(buildHistoryQuery());
            Map<String, Object> params = Map.of(
                    "label", memoryLabel,
                    "window", size < 1 ? "" : Long.toString(size),
                    "memoryId", memoryIdString);
            return session.run(query, params).stream()
                    .map(row -> row.get("msg").asString((String) null))
                    .filter(Objects::nonNull)
                    .map(ChatMessageDeserializer::messageFromJson)
                    .toList();
        } catch (Neo4jException e) {
            getDescriptiveProcedureNotFoundError(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the statement used by {@link #getMessages(Object)}.
     *
     * <p>The statement matches the memory node (by the {@code memoryId} parameter) and the node its
     * {@code lastMessageRelType} relationship points to, then matches paths that reach that node
     * through incoming {@code nextMessageRelType} relationships: between 0 and {@code size} of them,
     * or 0 or more when {@code size < 1}. It keeps the longest such path, reverses its node list
     * and returns each node's {@code messageProperty} as {@code msg}.
     *
     * @return the history statement
     */
    public Statement buildHistoryQuery() {
        Node memoryNode = Cypher.node(memoryLabel).named("s");
        Node lastNode = Cypher.anyNode().named("lastNode");
        Relationship towardsLast = lastNode.relationshipFrom(Cypher.anyNode(), nextMessageRelType);
        var path = Cypher.path("p").definedBy(size < 1 ? towardsLast.min(0) : towardsLast.length(0, size));

        return Cypher.match(memoryNode.relationshipTo(lastNode, lastMessageRelType))
                .where(memoryNode.property(idProperty).isEqualTo(Cypher.parameter("memoryId")))
                .match(path)
                .with(path, FunctionInvocation.create(Neo4jUtils.functionDef("length"), Cypher.name("p")).as("length"))
                .orderBy(Cypher.name("length"))
                .descending()
                .limit(1)
                .unwind(FunctionInvocation.create(
                        Neo4jUtils.functionDef("reverse"),
                        FunctionInvocation.create(Neo4jUtils.functionDef("nodes"), Cypher.name("p"))))
                .as("node")
                .returning(Cypher.name("node").property(messageProperty).as("msg"))
                .build();
    }

    /**
     * Stores the given messages for a memory id.
     *
     * <p>After making sure the memory node exists, a single statement creates one
     * {@code messageLabel} node per message ({@code apoc.create.nodes}), links those nodes with
     * {@code nextMessageRelType} relationships ({@code apoc.nodes.link}) and points a new
     * {@code lastMessageRelType} relationship from the memory node at the last created node
     * ({@code nodes[-1]}). If a previous latest node existed, it is linked to that same node with
     * {@code nextMessageRelType} and the old {@code lastMessageRelType} relationship is deleted.
     *
     * <p>NOTE(review): unlike the other operations, a {@link Neo4jException} raised by the main
     * statement is propagated as-is, without the APOC hint from
     * {@code getDescriptiveProcedureNotFoundError}, even though this is the statement that calls
     * APOC. Kept unchanged so callers see the same exception types.
     *
     * @param memoryId memory id; its {@code toString()} must not be blank
     * @param messages messages to store; must not be empty
     * @throws IllegalArgumentException if {@code memoryId} is {@code null} or blank
     */
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        String memoryIdString = toMemoryIdString(memoryId);
        ValidationUtils.ensureNotEmpty(messages, "messages");

        // One property map per message, in the shape apoc.create.nodes expects.
        List<Map<String, String>> messageRows = messages.stream()
                .map(ChatMessageSerializer::messageToJson)
                .map(json -> Map.of(messageProperty, json))
                .toList();

        createSessionNode(memoryIdString);

        try (Session session = session()) {
            Node memoryNode = Cypher.node(memoryLabel).named("s");
            Node lastNode = Cypher.anyNode().named("lastNode");
            Relationship lastRel = memoryNode.relationshipTo(lastNode, lastMessageRelType).named("lastRel");
            Node newNode = Cypher.anyNode().named("new");

            String query = Cypher.match(memoryNode)
                    .where(memoryNode.property(idProperty).isEqualTo(Cypher.parameter("memoryId")))
                    .optionalMatch(lastRel)
                    .call("apoc.create.nodes")
                    .withArgs(Cypher.raw("[$label], $messages"))
                    .yield("node")
                    .with(Cypher.raw("collect(node)").as("nodes"), memoryNode, lastNode, lastRel)
                    .call("apoc.nodes.link")
                    .withArgs(Cypher.raw("nodes, $relType, {avoidDuplicates: true}"))
                    .withoutResults()
                    .with(Cypher.raw("nodes[-1]").as("new"), memoryNode, lastNode, lastRel)
                    .create(memoryNode.relationshipTo(newNode, lastMessageRelType))
                    .with(newNode, lastRel, lastNode)
                    // Only when a previous latest node was found by the OPTIONAL MATCH.
                    .where(lastNode.isNotNull())
                    .create(lastNode.relationshipTo(newNode, nextMessageRelType))
                    .delete(lastRel)
                    .build()
                    .getCypher();

            Map<String, Object> params = Map.of(
                    "memoryId", memoryIdString,
                    "relType", nextMessageRelType,
                    "label", messageLabel,
                    "messages", messageRows);
            session.run(query, params);
        }
    }

    /**
     * Deletes the memory node for a memory id together with the nodes on the longest path that
     * starts at the memory node, follows its {@code lastMessageRelType} relationship and then any
     * number of incoming {@code nextMessageRelType} relationships.
     *
     * @param memoryId memory id; its {@code toString()} must not be blank
     * @throws IllegalArgumentException if {@code memoryId} is {@code null} or blank
     * @throws RuntimeException         wrapping any {@link Neo4jException} raised by the driver
     */
    public void deleteMessages(Object memoryId) {
        String memoryIdString = toMemoryIdString(memoryId);
        // try-with-resources: the session must be released even when run() throws.
        try (Session session = session()) {
            Node memoryNode = Cypher.node(memoryLabel).named("s");
            var path = Cypher.path("p")
                    .definedBy(memoryNode
                            .relationshipTo(Cypher.anyNode().named("lastNode"), lastMessageRelType)
                            .named("lastRel")
                            .relationshipFrom(Cypher.anyNode(), nextMessageRelType)
                            .min(0));

            String query = Cypher.match(memoryNode)
                    .where(memoryNode.property(idProperty).isEqualTo(Cypher.parameter("memoryId")))
                    .optionalMatch(path)
                    .with(memoryNode, path, Cypher.raw("length(p)").as("length"))
                    .orderBy(Cypher.raw("length"))
                    .descending()
                    .limit(1)
                    .detachDelete(memoryNode, path)
                    .build()
                    .getCypher();

            Map<String, Object> params =
                    Map.of("memoryId", memoryIdString, "relType", lastMessageRelType, "label", memoryLabel);
            session.run(query, params);
        } catch (Neo4jException e) {
            getDescriptiveProcedureNotFoundError(e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Throws a {@link Neo4jException} with an APOC installation hint when {@code e} reports a
     * missing procedure; otherwise returns normally so the caller can handle {@code e} itself.
     *
     * @param e the driver exception to inspect
     */
    private static void getDescriptiveProcedureNotFoundError(Neo4jException e) {
        if ("Neo.ClientError.Procedure.ProcedureNotFound".equals(e.code())) {
            throw new Neo4jException("Please ensure the APOC plugin is installed in Neo4j", e);
        }
    }

    /**
     * Converts a memory id to the string stored in the graph.
     *
     * @param memoryId memory id
     * @return {@code memoryId.toString()}, untrimmed
     * @throws IllegalArgumentException if {@code memoryId} is {@code null} or its string form is
     *                                  empty after trimming
     */
    private static String toMemoryIdString(Object memoryId) {
        String text = memoryId == null ? null : memoryId.toString();
        if (text == null || text.trim().isEmpty()) {
            throw new IllegalArgumentException("memoryId cannot be null or empty");
        }
        return text;
    }

    /**
     * @return a new session opened with the configured {@link SessionConfig}; the caller closes it
     */
    private Session session() {
        return driver.session(config);
    }
}
