package dev.langchain4j.community.chain;

import dev.langchain4j.Experimental;
import dev.langchain4j.chain.Chain;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import java.util.List;
import java.util.Objects;

@Experimental
/* loaded from: input_file:dev/langchain4j/community/chain/RetrievalQAChain.class */
public class RetrievalQAChain implements Chain<Query, String> {
    private final ChatModel chatModel;
    private final RetrievalAugmentor retrievalAugmentor;

    /* loaded from: input_file:dev/langchain4j/community/chain/RetrievalQAChain$Builder.class */
    public static class Builder {
        private ChatModel chatModel;
        private final DefaultRetrievalAugmentor.DefaultRetrievalAugmentorBuilder augmentorBuilder = DefaultRetrievalAugmentor.builder();
        private RetrievalAugmentor retrievalAugmentor;

        public Builder chatModel(ChatModel chatModel) {
            this.chatModel = chatModel;
            return this;
        }

        public Builder contentRetriever(ContentRetriever contentRetriever) {
            if (contentRetriever != null) {
                this.augmentorBuilder.contentRetriever(contentRetriever);
            }
            return this;
        }

        public Builder prompt(PromptTemplate promptTemplate) {
            this.augmentorBuilder.contentInjector(DefaultContentInjector.builder().promptTemplate(promptTemplate).build());
            return this;
        }

        public Builder retrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
            this.retrievalAugmentor = retrievalAugmentor;
            return this;
        }

        public RetrievalQAChain build() {
            return this.retrievalAugmentor == null ? new RetrievalQAChain(this.chatModel, this.augmentorBuilder.build()) : new RetrievalQAChain(this.chatModel, this.retrievalAugmentor);
        }
    }

    public RetrievalQAChain(ChatModel chatModel, RetrievalAugmentor retrievalAugmentor) {
        this.chatModel = chatModel;
        this.retrievalAugmentor = retrievalAugmentor;
    }

    public String execute(Query query) {
        return this.chatModel.chat(augment(query).singleText());
    }

    private UserMessage augment(Query query) {
        UserMessage from = UserMessage.from(query.text());
        return this.retrievalAugmentor.augment(new AugmentationRequest(from, query.metadata() == null ? Metadata.from(from, (Object) null, (List) null) : query.metadata())).chatMessage();
    }

    public static Builder builder() {
        return new Builder();
    }
}
