package org.axonframework.messaging.unitofwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.axonframework.messaging.GenericMessage;
import org.axonframework.messaging.GenericResultMessage;
import org.axonframework.messaging.Message;
import org.axonframework.messaging.unitofwork.UnitOfWork;
import org.axonframework.utils.MockException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

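/**
 * Tests for {@link BatchingUnitOfWork}: each test runs a batch of three messages and checks the order of the phase
 * callbacks and the execution result recorded per message, both when every task succeeds and when a task or a
 * prepare-commit handler throws.
 * <p>
 * Decompiler note: loaded from input_file:org/axonframework/messaging/unitofwork/BatchingUnitOfWorkTest.class
 */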
class BatchingUnitOfWorkTest {
    private List<PhaseTransition> transitions;
    private BatchingUnitOfWork<?> subject;

    /**
     * Value object pairing a message with the unit of work phase in which a callback for that message fired.
     * <p>
     * Two transitions are equal when their phase and message are equal, so a recorded callback sequence can be
     * compared element by element against an expected one.
     * <p>
     * NOTE(review): the decompiler reported that the access modifier of this class was changed from
     * {@code private}; it is left {@code public} here so existing references keep compiling.
     */
    public static class PhaseTransition {
        private final UnitOfWork.Phase phase;
        private final Message<?> message;

        /**
         * @param message the message the unit of work was handling when the callback fired
         * @param phase   the phase the callback was registered for
         */
        public PhaseTransition(Message<?> message, UnitOfWork.Phase phase) {
            this.message = message;
            this.phase = phase;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            PhaseTransition other = (PhaseTransition) obj;
            return this.phase == other.phase && Objects.equals(this.message, other.message);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.phase, this.message);
        }

        /**
         * Renders as {@code "<phase> -> <payload>"}.
         * <p>
         * A {@code null} message renders as {@code "<phase> -> null"} instead of throwing, matching the null
         * tolerance of {@link #equals(Object)} and {@link #hashCode()}; this text is what an assertion failure would
         * print, so it must not fail itself.
         */
        @Override
        public String toString() {
            Object payload = this.message == null ? null : this.message.getPayload();
            return this.phase + " -> " + payload;
        }
    }

    // Explicit package-private no-argument constructor with an empty body; the transition log is created in setUp()
    // and the unit of work under test is created inside each test method.
    BatchingUnitOfWorkTest() {
    }

    /**
     * Gives every test a fresh, empty log of observed phase transitions.
     */
    @BeforeEach
    void setUp() {
        // Diamond instead of a raw ArrayList, so the assignment to List<PhaseTransition> is type-checked.
        this.transitions = new ArrayList<>();
    }

    /**
     * Runs a task that succeeds for each of three messages and verifies that the prepare-commit, commit,
     * after-commit and cleanup callbacks were recorded for all messages, and that every message maps to an execution
     * result holding {@link #resultFor(Message)} of that message.
     */
    @Test
    void executeTask() {
        List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
        this.subject = new BatchingUnitOfWork<>(messages);

        this.subject.executeWithResult(() -> {
            // Listeners are registered from inside the task, i.e. once per message being processed.
            registerListeners(this.subject);
            return resultFor(this.subject.getMessage());
        });

        validatePhaseTransitions(
                Arrays.asList(UnitOfWork.Phase.PREPARE_COMMIT,
                              UnitOfWork.Phase.COMMIT,
                              UnitOfWork.Phase.AFTER_COMMIT,
                              UnitOfWork.Phase.CLEANUP),
                messages);

        Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
        for (Message<?> message : messages) {
            expectedResults.put(message,
                                new ExecutionResult(GenericResultMessage.asResultMessage(resultFor(message))));
        }
        assertExecutionResults(expectedResults, this.subject.getExecutionResults());
    }

    /**
     * Makes the task throw while handling the second message (payload {@code 1}) and verifies that rollback and
     * cleanup callbacks were recorded for the first two messages only, and that all three messages map to an
     * execution result carrying the thrown exception.
     */
    @Test
    void rollback() {
        List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
        this.subject = new BatchingUnitOfWork<>(messages);
        MockException taskException = new MockException();

        try {
            this.subject.executeWithResult(() -> {
                registerListeners(this.subject);
                if (this.subject.getMessage().getPayload().equals(1)) {
                    throw taskException;
                }
                return resultFor(this.subject.getMessage());
            });
        } catch (Exception ignored) {
            // Intentionally ignored: the assertions below inspect the recorded transitions and execution results,
            // whether or not the failure is also propagated to the caller.
        }

        // The third message is never reached by the task, so no listeners exist for it.
        validatePhaseTransitions(Arrays.asList(UnitOfWork.Phase.ROLLBACK, UnitOfWork.Phase.CLEANUP),
                                 messages.subList(0, 2));

        Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
        for (Message<?> message : messages) {
            expectedResults.put(message, new ExecutionResult(GenericResultMessage.asResultMessage(taskException)));
        }
        assertExecutionResults(expectedResults, this.subject.getExecutionResults());
    }

    /**
     * Combines a failing task, a failing prepare-commit handler and a failing cleanup handler.
     * <p>
     * The task throws for the third message (payload {@code 2}) after registering a prepare-commit handler that
     * throws as well; the rollback configuration passed to {@code executeWithResult} answers {@code false} for every
     * throwable. The test then verifies that:
     * <ul>
     * <li>prepare-commit, rollback and cleanup callbacks were recorded for all three messages;</li>
     * <li>the first two messages carry the commit exception and the third carries the task exception;</li>
     * <li>the commit exception is the first suppressed exception of the task exception;</li>
     * <li>both counting cleanup handlers ran although the handler registered between them throws.</li>
     * </ul>
     */
    @Test
    void suppressedExceptionOnRollback() {
        List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
        AtomicInteger cleanupCounter = new AtomicInteger();
        this.subject = new BatchingUnitOfWork<>(messages);
        MockException taskException = new MockException("task exception");
        MockException commitException = new MockException("commit exception");
        MockException cleanupException = new MockException("cleanup exception");

        // A throwing cleanup handler sandwiched between two counting ones.
        this.subject.onCleanup(unitOfWork -> cleanupCounter.incrementAndGet());
        this.subject.onCleanup(unitOfWork -> {
            throw cleanupException;
        });
        this.subject.onCleanup(unitOfWork -> cleanupCounter.incrementAndGet());

        try {
            this.subject.executeWithResult(() -> {
                registerListeners(this.subject);
                if (this.subject.getMessage().getPayload().equals(2)) {
                    this.subject.addHandler(UnitOfWork.Phase.PREPARE_COMMIT, unitOfWork -> {
                        throw commitException;
                    });
                    throw taskException;
                }
                return resultFor(this.subject.getMessage());
            }, throwable -> false);
        } catch (Exception ignored) {
            // Intentionally ignored: the assertions below inspect the recorded transitions, execution results and
            // suppressed exceptions, whether or not a failure is also propagated to the caller.
        }

        validatePhaseTransitions(
                Arrays.asList(UnitOfWork.Phase.PREPARE_COMMIT, UnitOfWork.Phase.ROLLBACK, UnitOfWork.Phase.CLEANUP),
                messages);

        Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
        expectedResults.put(messages.get(0),
                            new ExecutionResult(GenericResultMessage.asResultMessage(commitException)));
        expectedResults.put(messages.get(1),
                            new ExecutionResult(GenericResultMessage.asResultMessage(commitException)));
        expectedResults.put(messages.get(2),
                            new ExecutionResult(GenericResultMessage.asResultMessage(taskException)));
        assertExecutionResults(expectedResults, this.subject.getExecutionResults());

        Assertions.assertSame(commitException, taskException.getSuppressed()[0]);
        Assertions.assertEquals(2, cleanupCounter.get());
    }

    /**
     * Registers one recording callback per phase (prepare-commit, commit, after-commit, rollback, cleanup) on the
     * given unit of work. When a callback fires it appends a {@link PhaseTransition}, built from the message the
     * unit of work reports at that moment and the phase, to the transition log.
     *
     * @param unitOfWork the unit of work to observe
     */
    private void registerListeners(UnitOfWork<?> unitOfWork) {
        unitOfWork.onPrepareCommit(
                uow -> transitions.add(new PhaseTransition(uow.getMessage(), UnitOfWork.Phase.PREPARE_COMMIT)));
        unitOfWork.onCommit(
                uow -> transitions.add(new PhaseTransition(uow.getMessage(), UnitOfWork.Phase.COMMIT)));
        unitOfWork.afterCommit(
                uow -> transitions.add(new PhaseTransition(uow.getMessage(), UnitOfWork.Phase.AFTER_COMMIT)));
        unitOfWork.onRollback(
                uow -> transitions.add(new PhaseTransition(uow.getMessage(), UnitOfWork.Phase.ROLLBACK)));
        unitOfWork.onCleanup(
                uow -> transitions.add(new PhaseTransition(uow.getMessage(), UnitOfWork.Phase.CLEANUP)));
    }

    /**
     * Wraps the given payload in a {@link GenericMessage}.
     *
     * @param obj the payload of the message
     * @return a new message carrying {@code obj}
     */
    private static Message<?> toMessage(Object obj) {
        // Diamond instead of the raw GenericMessage type, so the construction is type-checked.
        return new GenericMessage<>(obj);
    }

    /**
     * Builds the result the test tasks return for a message: the text {@code "Result for: "} followed by the
     * string form of the message's payload.
     *
     * @param message the message to derive a result from
     * @return the derived result string
     */
    public static Object resultFor(Message<?> message) {
        Object payload = message.getPayload();
        return "Result for: " + payload;
    }

    /**
     * Checks that the transition log starts with the expected callbacks, in order.
     * <p>
     * For each phase in turn, one transition per message is expected. Messages are expected in the given order,
     * except for phases whose {@code isReverseCallbackOrder()} is {@code true}, for which they are expected in
     * reverse order.
     *
     * @param list  the phases expected, in the order they should have occurred
     * @param list2 the messages for which every listed phase should have been recorded
     */
    private void validatePhaseTransitions(List<UnitOfWork.Phase> list, List<Message<?>> list2) {
        Iterator<PhaseTransition> recorded = this.transitions.iterator();
        for (UnitOfWork.Phase phase : list) {
            // Explicitly typed: with a raw LinkedList the whole conditional degrades to a raw Iterator and the
            // elements can no longer be passed to the PhaseTransition constructor.
            Iterator<Message<?>> expectedOrder = phase.isReverseCallbackOrder()
                    ? new LinkedList<Message<?>>(list2).descendingIterator()
                    : list2.iterator();
            while (expectedOrder.hasNext()) {
                PhaseTransition expected = new PhaseTransition(expectedOrder.next(), phase);
                Assertions.assertTrue(recorded.hasNext());
                Assertions.assertEquals(expected, recorded.next());
            }
        }
    }

    /**
     * Asserts that two message-to-result maps describe the same outcome.
     * <p>
     * The key sets must be equal. The values are compared as groups rather than entry by entry: the payloads of
     * non-exceptional results, and separately the exceptions of exceptional results, must match in count and every
     * actual element must occur among the expected ones. Every actual metadata value must likewise occur among the
     * expected metadata values.
     *
     * @param map  the expected execution result per message
     * @param map2 the actual execution result per message
     */
    private void assertExecutionResults(Map<Message<?>, ExecutionResult> map, Map<Message<?>, ExecutionResult> map2) {
        Assertions.assertEquals(map.keySet(), map2.keySet());

        List<Object> expectedPayloads = successfulPayloads(map);
        List<Object> actualPayloads = successfulPayloads(map2);
        List<Object> expectedExceptions = exceptionResults(map);
        List<Object> actualExceptions = exceptionResults(map2);
        List<Object> expectedMetaData = resultMetaData(map);
        List<Object> actualMetaData = resultMetaData(map2);

        Assertions.assertEquals(expectedPayloads.size(), actualPayloads.size());
        Assertions.assertTrue(expectedPayloads.containsAll(actualPayloads));
        Assertions.assertEquals(expectedExceptions.size(), actualExceptions.size());
        Assertions.assertTrue(expectedExceptions.containsAll(actualExceptions));
        Assertions.assertTrue(expectedMetaData.containsAll(actualMetaData));
    }

    /** Collects the payloads of all non-exceptional results in the given map. */
    private static List<Object> successfulPayloads(Map<Message<?>, ExecutionResult> results) {
        return results.values().stream()
                      .map(ExecutionResult::getResult)
                      .filter(result -> !result.isExceptional())
                      .<Object>map(result -> result.getPayload())
                      .collect(Collectors.toList());
    }

    /** Collects the exceptions of all exceptional results in the given map. */
    private static List<Object> exceptionResults(Map<Message<?>, ExecutionResult> results) {
        return results.values().stream()
                      .map(ExecutionResult::getResult)
                      .filter(result -> result.isExceptional())
                      .<Object>map(result -> result.exceptionResult())
                      .collect(Collectors.toList());
    }

    /** Collects the metadata of every result in the given map. */
    private static List<Object> resultMetaData(Map<Message<?>, ExecutionResult> results) {
        return results.values().stream()
                      .map(ExecutionResult::getResult)
                      .<Object>map(result -> result.getMetaData())
                      .collect(Collectors.toList());
    }
}
