Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,30 @@ public GetResponse consume() {
return responseQueue.poll();
}

public void acknowledge(final GetResponse response) {
public void acknowledge(final long deliveryTag) {
if (autoAcknowledge) {
return;
}

try {
getChannel().basicAck(response.getEnvelope().getDeliveryTag(), true);
getChannel().basicAck(deliveryTag, true);
} catch (Exception e) {
throw new AMQPException("Failed to acknowledge message", e);
}
}

public void negativeAcknowledge(final long deliveryTag) {
if (autoAcknowledge) {
return;
}

try {
getChannel().basicNack(deliveryTag, true, true);
} catch (Exception e) {
throw new AMQPException("Failed to negatively acknowledge message", e);
}
}

@Override
public void close() throws TimeoutException, IOException {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,19 @@ protected void processResource(final Connection connection, final AMQPConsumer c
}

if (lastReceived != null) {
final GetResponse finalGetResponse = lastReceived;
session.commitAsync(() -> consumer.acknowledge(finalGetResponse), null);
final long lastDeliveryTag = lastReceived.getEnvelope().getDeliveryTag();

session.commitAsync(
() -> consumer.acknowledge(lastDeliveryTag),
failure -> {
getLogger().warn(
"ProcessSession commit failed after consuming AMQP messages up to delivery tag {}; negatively acknowledging with requeue",
lastDeliveryTag,
failure
);
consumer.negativeAcknowledge(lastDeliveryTag);
}
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.MessageProperties;
import org.apache.nifi.amqp.processors.ConsumeAMQP.OutputHeaderFormat;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.io.OutputStreamCallback;
import org.apache.nifi.provenance.ProvenanceReporter;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.PropertyMigrationResult;
import org.apache.nifi.util.TestRunner;
Expand All @@ -40,12 +44,20 @@
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class ConsumeAMQPTest {

Expand All @@ -63,6 +75,7 @@ public void testMessageAcked() throws TimeoutException, IOException {
ConsumeAMQP proc = new LocalConsumeAMQP(connection);
TestRunner runner = initTestRunner(proc);
runner.setProperty(ConsumeAMQP.AUTO_ACKNOWLEDGE, "false");
runner.setProperty(ConsumeAMQP.BATCH_SIZE, "10");

runner.run();

Expand All @@ -75,8 +88,74 @@ public void testMessageAcked() throws TimeoutException, IOException {
worldFF.assertContentEquals("world");

// A single cumulative ack should be used
assertFalse(((TestChannel) connection.createChannel()).isAck(0));
assertTrue(((TestChannel) connection.createChannel()).isAck(1));
final TestChannel channel = (TestChannel) connection.createChannel();
assertFalse(channel.isAck(0));
assertTrue(channel.isAck(1));
assertEquals(1, channel.getBasicAckCount());
assertEquals(1L, channel.getLastBasicAckDeliveryTag());
assertTrue(channel.isLastBasicAckMultiple());
assertEquals(0, channel.getBasicNackCount());
}
}

@Test
public void testMessageNackedOnSessionCommitFailure() throws TimeoutException, IOException {
final Map<String, List<String>> routingMap = Collections.singletonMap("key1", Collections.singletonList("queue1"));
final Map<String, String> exchangeToRoutingKeymap = Collections.singletonMap("myExchange", "key1");

final Connection connection = new TestConnection(exchangeToRoutingKeymap, routingMap);

try (AMQPPublisher sender = new AMQPPublisher(connection, mock(ComponentLog.class));
AMQPConsumer consumer = new AMQPConsumer(connection, "queue1", false, 0, mock(ComponentLog.class))) {
sender.publish("hello".getBytes(), MessageProperties.PERSISTENT_TEXT_PLAIN, "key1", "myExchange");
sender.publish("world".getBytes(), MessageProperties.PERSISTENT_TEXT_PLAIN, "key1", "myExchange");

final ConsumeAMQP proc = new LocalConsumeAMQP(connection);
final TestRunner runner = initTestRunner(proc);
runner.setProperty(ConsumeAMQP.AUTO_ACKNOWLEDGE, "false");
runner.setProperty(ConsumeAMQP.BATCH_SIZE, "10");

final RuntimeException commitFailure = new RuntimeException("commit failed");
final ProcessSession session = failingCommitSession(commitFailure);

final RuntimeException thrown = assertThrows(RuntimeException.class,
() -> proc.processResource(connection, consumer, runner.getProcessContext(), session));
assertSame(commitFailure, thrown);

final TestChannel channel = (TestChannel) connection.createChannel();
assertEquals(0, channel.getBasicAckCount());
assertFalse(channel.isNack(0));
assertTrue(channel.isNack(1));
assertEquals(1, channel.getBasicNackCount());
assertEquals(1L, channel.getLastBasicNackDeliveryTag());
assertTrue(channel.isLastBasicNackMultiple());
assertTrue(channel.isLastBasicNackRequeue());
}
}

@Test
public void testAutoAcknowledgeDoesNotIssueManualAcknowledgements() throws TimeoutException, IOException {
final Map<String, List<String>> routingMap = Collections.singletonMap("key1", Collections.singletonList("queue1"));
final Map<String, String> exchangeToRoutingKeymap = Collections.singletonMap("myExchange", "key1");

final Connection connection = new TestConnection(exchangeToRoutingKeymap, routingMap);

try (AMQPPublisher sender = new AMQPPublisher(connection, mock(ComponentLog.class))) {
sender.publish("hello".getBytes(), MessageProperties.PERSISTENT_TEXT_PLAIN, "key1", "myExchange");
sender.publish("world".getBytes(), MessageProperties.PERSISTENT_TEXT_PLAIN, "key1", "myExchange");

ConsumeAMQP proc = new LocalConsumeAMQP(connection);
TestRunner runner = initTestRunner(proc);
runner.setProperty(ConsumeAMQP.AUTO_ACKNOWLEDGE, "true");
runner.setProperty(ConsumeAMQP.BATCH_SIZE, "10");

runner.run();

runner.assertTransferCount(ConsumeAMQP.REL_SUCCESS, 2);

final TestChannel channel = (TestChannel) connection.createChannel();
assertEquals(0, channel.getBasicAckCount());
assertEquals(0, channel.getBasicNackCount());
}
}

Expand Down Expand Up @@ -401,6 +480,23 @@ private TestRunner initTestRunner(ConsumeAMQP proc) {
return runner;
}

private ProcessSession failingCommitSession(final RuntimeException commitFailure) {
final ProcessSession session = mock(ProcessSession.class);
final FlowFile flowFile = new MockFlowFile(1L);

when(session.create()).thenReturn(flowFile);
when(session.write(eq(flowFile), any(OutputStreamCallback.class))).thenReturn(flowFile);
when(session.putAllAttributes(eq(flowFile), anyMap())).thenReturn(flowFile);
when(session.getProvenanceReporter()).thenReturn(mock(ProvenanceReporter.class));
doAnswer(invocation -> {
final Consumer<Throwable> onFailure = invocation.getArgument(1);
onFailure.accept(commitFailure);
throw commitFailure;
}).when(session).commitAsync(any(Runnable.class), any());

return session;
}

public static class LocalConsumeAMQP extends ConsumeAMQP {
private final Connection connection;
private AMQPConsumer consumer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ class TestChannel implements Channel {
private long deliveryTag = 0L;
private final BitSet acknowledgments = new BitSet();
private final BitSet nacks = new BitSet();
private int basicAckCount;
private long lastBasicAckDeliveryTag;
private boolean lastBasicAckMultiple;
private int basicNackCount;
private long lastBasicNackDeliveryTag;
private boolean lastBasicNackMultiple;
private boolean lastBasicNackRequeue;
private int prefetchCount = 0;

public TestChannel(Map<String, String> exchangeToRoutingKeyMappings,
Expand Down Expand Up @@ -485,21 +492,56 @@ public GetResponse basicGet(String queue, boolean autoAck) throws IOException {
@Override
public void basicAck(long deliveryTag, boolean multiple) throws IOException {
acknowledgments.set((int) deliveryTag);
basicAckCount++;
lastBasicAckDeliveryTag = deliveryTag;
lastBasicAckMultiple = multiple;
}

public boolean isAck(final int deliveryTag) {
return acknowledgments.get(deliveryTag);
}

public int getBasicAckCount() {
return basicAckCount;
}

public long getLastBasicAckDeliveryTag() {
return lastBasicAckDeliveryTag;
}

public boolean isLastBasicAckMultiple() {
return lastBasicAckMultiple;
}

@Override
public void basicNack(long deliveryTag, boolean multiple, boolean requeue) throws IOException {
nacks.set((int) deliveryTag);
basicNackCount++;
lastBasicNackDeliveryTag = deliveryTag;
lastBasicNackMultiple = multiple;
lastBasicNackRequeue = requeue;
}

public boolean isNack(final int deliveryTag) {
return nacks.get(deliveryTag);
}

public int getBasicNackCount() {
return basicNackCount;
}

public long getLastBasicNackDeliveryTag() {
return lastBasicNackDeliveryTag;
}

public boolean isLastBasicNackMultiple() {
return lastBasicNackMultiple;
}

public boolean isLastBasicNackRequeue() {
return lastBasicNackRequeue;
}

@Override
public void basicReject(long deliveryTag, boolean requeue) throws IOException {
nacks.set((int) deliveryTag);
Expand Down
Loading