Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,17 @@
public class BucketOffsetsRetrieverImpl implements OffsetsInitializer.BucketOffsetsRetriever {
private final Admin flussAdmin;
private final TablePath tablePath;
private final boolean fetchEarliestOffset;

/**
 * Creates a retriever with the pre-existing default behavior: {@code earliestOffsets} returns
 * the {@code EARLIEST_OFFSET} sentinel per bucket instead of querying the cluster.
 */
public BucketOffsetsRetrieverImpl(Admin flussAdmin, TablePath tablePath) {
// Delegate with fetchEarliestOffset = false to keep backward compatibility for existing callers.
this(flussAdmin, tablePath, false);
}

/**
 * Creates a retriever.
 *
 * @param flussAdmin admin client used to list offsets from the cluster
 * @param tablePath path of the table whose bucket offsets are retrieved
 * @param fetchEarliestOffset when {@code true}, {@code earliestOffsets} queries the cluster for
 *     concrete earliest offsets; when {@code false}, it returns the {@code EARLIEST_OFFSET}
 *     sentinel for every bucket
 */
public BucketOffsetsRetrieverImpl(
Admin flussAdmin, TablePath tablePath, boolean fetchEarliestOffset) {
this.flussAdmin = flussAdmin;
this.tablePath = tablePath;
this.fetchEarliestOffset = fetchEarliestOffset;
}

@Override
Expand All @@ -52,11 +59,15 @@ public Map<Integer, Long> latestOffsets(
@Override
public Map<Integer, Long> earliestOffsets(
@Nullable String partitionName, Collection<Integer> buckets) {
Map<Integer, Long> bucketWithOffset = new HashMap<>(buckets.size());
for (Integer bucket : buckets) {
bucketWithOffset.put(bucket, EARLIEST_OFFSET);
if (!fetchEarliestOffset) {
Map<Integer, Long> bucketWithOffset = new HashMap<>(buckets.size());
for (Integer bucket : buckets) {
bucketWithOffset.put(bucket, EARLIEST_OFFSET);
}
return bucketWithOffset;
} else {
return listOffsets(partitionName, buckets, new OffsetSpec.EarliestSpec());
}
return bucketWithOffset;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,14 @@ object SparkFlussConf {
.durationType()
.defaultValue(Duration.ofMillis(10000L))
.withDescription("The timeout for log scanner to poll records.")

// Optional cap on records per Spark input partition for log-table batch reads.
// Deliberately has no default: when unset, splitting is disabled and each Fluss
// bucket maps to exactly one Spark input partition.
val SCAN_MAX_RECORDS_PER_PARTITION: ConfigOption[java.lang.Long] =
ConfigBuilder
.key("scan.max.records.per.partition")
.longType()
.noDefaultValue()
.withDescription(
"The maximum number of records per Spark input partition when reading a log table. " +
"When set, each Fluss bucket whose offset range exceeds this value will be split " +
"into multiple partitions. Disabled by default (one partition per bucket).")
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.fluss.client.table.scanner.log.LogScanner
import org.apache.fluss.config.Configuration
import org.apache.fluss.metadata.{PartitionInfo, TableBucket, TableInfo, TablePath}
import org.apache.fluss.predicate.Predicate
import org.apache.fluss.spark.SparkFlussConf
import org.apache.fluss.spark.utils.SparkPartitionPredicate

import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory}
Expand Down Expand Up @@ -93,26 +94,58 @@ class FlussAppendBatch(
}

override def planInputPartitions(): Array[InputPartition] = {
val bucketOffsetsRetrieverImpl = new BucketOffsetsRetrieverImpl(admin, tablePath)
val maxRecordsPerPartition: Option[Long] = {
val opt = flussConfig.getOptional(SparkFlussConf.SCAN_MAX_RECORDS_PER_PARTITION)
if (opt.isPresent) Some(opt.get().longValue()) else None
}

val bucketOffsetsRetrieverImpl = maxRecordsPerPartition match {
case Some(_) => new BucketOffsetsRetrieverImpl(admin, tablePath, true)
case None => new BucketOffsetsRetrieverImpl(admin, tablePath)
}
val buckets = (0 until tableInfo.getNumBuckets).toSeq

def splitOffsetRange(
tableBucket: TableBucket,
startOffset: Long,
stopOffset: Long,
maxRecords: Long): Seq[InputPartition] = {
if (
startOffset < 0 || stopOffset <= startOffset || stopOffset <= (startOffset + maxRecords)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the earliest mode we have sentinel -2L, I think it would result in a bug here, since we clamp to 1 split

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The startup mode is designed for streaming scenarios. It will not be used in batch. I will refactor it in a future PR.

) {
return Seq(FlussAppendInputPartition(tableBucket, startOffset, stopOffset))
}
val rangeSize = stopOffset - startOffset
val numSplits = ((rangeSize + maxRecords - 1) / maxRecords).toInt
val step = (rangeSize + numSplits - 1) / numSplits

Iterator
.from(0)
.take(numSplits)
.map(i => startOffset + i * step)
.map {
from => FlussAppendInputPartition(tableBucket, from, math.min(from + step, stopOffset))
}
.toSeq
}

def createPartitions(
partitionId: Option[Long],
startBucketOffsets: Map[Integer, Long],
stoppingBucketOffsets: Map[Integer, Long]): Array[InputPartition] = {
buckets.map {
buckets.flatMap {
bucketId =>
val (startBucketOffset, stoppingBucketOffset) =
val (startOffset, stopOffset) =
(startBucketOffsets(bucketId), stoppingBucketOffsets(bucketId))
partitionId match {
case Some(partitionId) =>
val tableBucket = new TableBucket(tableInfo.getTableId, partitionId, bucketId)
FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset)
.asInstanceOf[InputPartition]
case None =>
val tableBucket = new TableBucket(tableInfo.getTableId, bucketId)
FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset)
.asInstanceOf[InputPartition]
val tableBucket = partitionId match {
case Some(pid) => new TableBucket(tableInfo.getTableId, pid, bucketId)
case None => new TableBucket(tableInfo.getTableId, bucketId)
}
maxRecordsPerPartition match {
case Some(maxRecs) if maxRecs > 0 =>
splitOffsetRange(tableBucket, startOffset, stopOffset, maxRecs)
case _ =>
Seq(FlussAppendInputPartition(tableBucket, startOffset, stopOffset))
}
}.toArray
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import org.apache.fluss.spark.read.{FlussMetrics, FlussScan}
import org.apache.fluss.spark.read.FlussAppendScan

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.Row
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation}
import org.assertj.core.api.Assertions.assertThat

Expand Down Expand Up @@ -603,4 +603,66 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
assert(numRowsRead == 5L, s"Expected 5 rows read, got $numRowsRead")
}
}

// Verifies that `scan.max.records.per.partition` splits each bucket's offset range
// into multiple Spark input partitions, for both non-partitioned and partitioned tables.
// (Restored: this block had scraped review-thread lines embedded in it, which were
// not code and broke the syntax.)
test("Spark Read: split partition by config") {
  // Non-partitioned sample table: 5 rows, max 2 records per partition
  // => expected ceil(5 / 2) = 3 input partitions.
  withSampleTable {
    withSQLConf("spark.sql.fluss.scan.max.records.per.partition" -> "2") {
      val df = sql(s"SELECT amount FROM $DEFAULT_DATABASE.t ORDER BY orderId")
      checkAnswer(df, Row(601) :: Row(602) :: Row(603) :: Row(604) :: Row(605) :: Nil)

      val partitions = getInputPartitions(df)
      assertThat(partitions.length).isEqualTo(3)
    }
  }

  // Partitioned table: three dt partitions holding 2 + 2 + 1 rows.
  withTable("t_partition") {
    sql(
      s"""
         |CREATE TABLE $DEFAULT_DATABASE.t_partition (orderId BIGINT, itemId BIGINT, amount INT, address STRING, dt STRING)
         |PARTITIONED BY (dt)
         |""".stripMargin
    )

    sql(s"""
           |INSERT INTO $DEFAULT_DATABASE.t_partition VALUES
           |(600L, 21L, 601, "addr1", "2026-01-01"), (700L, 22L, 602, "addr2", "2026-01-01"),
           |(800L, 23L, 603, "addr3", "2026-01-02"), (900L, 24L, 604, "addr4", "2026-01-02"),
           |(1000L, 25L, 605, "addr5", "2026-01-03")
           |""".stripMargin)

    // maxRecords = 0 disables splitting (one partition per bucket => 3);
    // maxRecords = 1 splits the two 2-row ranges => 2 + 2 + 1 = 5;
    // maxRecords = 2 fits every range into a single split => 3.
    Seq((0, 3), (1, 5), (2, 3)).foreach {
      case (maxRecords, expectedPartitions) =>
        withClue(s"maxRecords = $maxRecords, expectedPartitions = $expectedPartitions") {
          withSQLConf("spark.sql.fluss.scan.max.records.per.partition" -> maxRecords.toString) {
            val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_partition ORDER BY orderId")
            checkAnswer(
              df,
              Row(600L, 21L, 601, "addr1", "2026-01-01") ::
                Row(700L, 22L, 602, "addr2", "2026-01-01") ::
                Row(800L, 23L, 603, "addr3", "2026-01-02") ::
                Row(900L, 24L, 604, "addr4", "2026-01-02") ::
                Row(1000L, 25L, 605, "addr5", "2026-01-03") :: Nil
            )

            val partitions = getInputPartitions(df)
            assertThat(partitions.length).isEqualTo(expectedPartitions)
          }
        }
    }
  }
}

// Extracts the planned input partitions for a query, preferring the executed
// plan (available after adaptive execution) and falling back to re-planning
// from the FlussAppendScan found in the optimized plan.
private def getInputPartitions(df: DataFrame): Array[InputPartition] = {
  val executedScanPartitions = df.queryExecution.executedPlan.collect {
    case scanExec: BatchScanExec => scanExec.inputPartitions.toArray
  }
  executedScanPartitions.headOption.getOrElse {
    // No BatchScanExec in the executed plan: locate the Fluss scan in the
    // optimized plan and plan its partitions directly.
    df.queryExecution.optimizedPlan
      .collect { case DataSourceV2ScanRelation(_, appendScan: FlussAppendScan, _, _, _) => appendScan }
      .headOption
      .map(_.toBatch.planInputPartitions())
      .getOrElse(Array.empty[InputPartition])
  }
}
}
Loading