package dev.langchain4j.community.chain;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import java.util.Arrays;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

/**
 * Tests for {@link RetrievalQAChain}: checks that retrieved content is injected into the prompt sent to the
 * {@link ChatModel}, with the default template, with custom templates (via an explicit
 * {@link DefaultRetrievalAugmentor} or the builder's {@code prompt(...)} shortcut), and that the builder rejects
 * configurations without a content source.
 */
@ExtendWith(MockitoExtension.class)
public class RetrievalQAChainTest {

    /** Canned response returned by the mocked chat model for any String prompt. */
    private static final String ANSWER = "answer";

    /** Query executed by the tests that do not need query metadata. */
    private static final Query QUERY = Query.from("query");

    /** Custom prompt template used by the metadata test. */
    public static final PromptTemplate promptTemplate = PromptTemplate.from(
            "Answer the question based only on the context provided.\n\nContext:\n{{contents}}\n\nQuestion:\n{{userMessage}}\n\nAnswer:\n");

    @Mock
    ChatModel chatModel;

    @Mock
    ContentRetriever contentRetriever;

    /** Captures the String prompt(s) passed to {@code chatModel.chat(String)}. */
    @Captor
    ArgumentCaptor<String> messagesCaptor;

    @BeforeEach
    void beforeEach() {
        // lenient: the builder-validation tests never call the chat model, and strict stubs would flag that.
        Mockito.lenient().when(chatModel.chat(ArgumentMatchers.anyString())).thenReturn(ANSWER);
    }

    @Test
    void should_inject_retrieved_segments() {
        Mockito.when(contentRetriever.retrieve(ArgumentMatchers.any(Query.class)))
                .thenReturn(Arrays.asList(Content.from("Segment 1"), Content.from("Segment 2")));

        RetrievalQAChain chain = RetrievalQAChain.builder()
                .chatModel(chatModel)
                .contentRetriever(contentRetriever)
                .build();

        Assertions.assertThat(chain.execute(QUERY)).isEqualTo(ANSWER);

        Mockito.verify(chatModel).chat(messagesCaptor.capture());
        Assertions.assertThat(messagesCaptor.getValue())
                .isEqualTo("query\n\nAnswer using the following information:\nSegment 1\n\nSegment 2");
    }

    @Test
    void should_inject_retrieved_segments_using_custom_prompt_template() {
        Mockito.when(contentRetriever.retrieve(ArgumentMatchers.any(Query.class)))
                .thenReturn(Arrays.asList(Content.from("Segment 1"), Content.from("Segment 2")));
        PromptTemplate customTemplate = PromptTemplate.from(
                "Answer the question based only on the context provided.\n\nContext: {{contents}}\n\nQuestion: {{userMessage}}\n\nAnswer:\n");
        String expectedPrompt =
                "Answer the question based only on the context provided.\nContext:\nSegment 1\nSegment 2\n\nQuestion:\nquery\nAnswer:\n";

        // Variant 1: the template is supplied through an explicitly built retrieval augmentor.
        RetrievalQAChain chainWithAugmentor = RetrievalQAChain.builder()
                .chatModel(chatModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .contentRetriever(contentRetriever)
                        .contentInjector(DefaultContentInjector.builder()
                                .promptTemplate(customTemplate)
                                .build())
                        .build())
                .build();
        Assertions.assertThat(chainWithAugmentor.execute(QUERY)).isEqualTo(ANSWER);

        // Variant 2: the same template is supplied via the builder's prompt(...) shortcut.
        RetrievalQAChain chainWithPrompt = RetrievalQAChain.builder()
                .chatModel(chatModel)
                .contentRetriever(contentRetriever)
                .prompt(customTemplate)
                .build();
        Assertions.assertThat(chainWithPrompt.execute(QUERY)).isEqualTo(ANSWER);

        // Capture the prompts of BOTH calls. Re-reading the captor without a new verify would only
        // return the first call's prompt, so the second variant would go unchecked.
        Mockito.verify(chatModel, Mockito.times(2)).chat(messagesCaptor.capture());
        List<String> prompts = messagesCaptor.getAllValues();
        Assertions.assertThat(prompts).hasSize(2);
        for (String prompt : prompts) {
            Assertions.assertThat(prompt).isEqualToIgnoringWhitespace(expectedPrompt);
        }
    }

    @Test
    void should_inject_retrieved_segments_using_custom_prompt_template_and_metadata() {
        // Only answer when the chain forwards a query that carries metadata.
        Mockito.when(contentRetriever.retrieve(
                        ArgumentMatchers.argThat((Query query) -> query != null && query.metadata() != null)))
                .thenReturn(List.of(
                        Content.from(TextSegment.from("Segment 1 with meta")),
                        Content.from(TextSegment.from("Segment 2  with meta"))));

        RetrievalQAChain chain = RetrievalQAChain.builder()
                .chatModel(chatModel)
                .retrievalAugmentor(DefaultRetrievalAugmentor.builder()
                        .contentRetriever(contentRetriever)
                        .contentInjector(DefaultContentInjector.builder()
                                .promptTemplate(promptTemplate)
                                .build())
                        .build())
                .build();
        Metadata metadata = Metadata.from(
                UserMessage.from("user message"),
                42,
                List.of(UserMessage.from("Hello"), AiMessage.from("Hi, how can I help you today?")));

        Assertions.assertThat(chain.execute(Query.from("query", metadata))).isEqualTo(ANSWER);

        Mockito.verify(chatModel).chat(messagesCaptor.capture());
        Assertions.assertThat(messagesCaptor.getValue())
                .isEqualToIgnoringWhitespace(
                        "Answer the question based only on the context provided.\nContext:\nSegment 1 with meta\nSegment 2  with meta\n\nQuestion:\nquery\nAnswer:\n");
    }

    @Test
    void should_throws_exception_if_neither_retriever_nor_retrieval_augmentor_is_defined() {
        Assertions.assertThatThrownBy(() -> RetrievalQAChain.builder().chatModel(chatModel).build())
                .as("Should fail due to missing builder configurations")
                .isInstanceOf(Exception.class)
                .hasMessageContaining("queryRouter cannot be null");
    }

    @Test
    void should_throws_exception_if_retriever_is_null() {
        Assertions.assertThatThrownBy(() -> RetrievalQAChain.builder()
                        .chatModel(chatModel)
                        .contentRetriever((ContentRetriever) null)
                        .build())
                .as("Should fail due to missing builder configurations")
                .isInstanceOf(Exception.class)
                .hasMessageContaining("queryRouter cannot be null");
    }
}
