diff --git a/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java b/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java index e868a84cc1..2abd998828 100644 --- a/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java +++ b/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java @@ -37,10 +37,17 @@ public class BucketOffsetsRetrieverImpl implements OffsetsInitializer.BucketOffsetsRetriever { private final Admin flussAdmin; private final TablePath tablePath; + private final boolean fetchEarliestOffset; public BucketOffsetsRetrieverImpl(Admin flussAdmin, TablePath tablePath) { + this(flussAdmin, tablePath, false); + } + + public BucketOffsetsRetrieverImpl( + Admin flussAdmin, TablePath tablePath, boolean fetchEarliestOffset) { this.flussAdmin = flussAdmin; this.tablePath = tablePath; + this.fetchEarliestOffset = fetchEarliestOffset; } @Override @@ -52,11 +59,15 @@ public Map<Integer, Long> latestOffsets( @Override public Map<Integer, Long> earliestOffsets( @Nullable String partitionName, Collection<Integer> buckets) { - Map<Integer, Long> bucketWithOffset = new HashMap<>(buckets.size()); - for (Integer bucket : buckets) { - bucketWithOffset.put(bucket, EARLIEST_OFFSET); + if (!fetchEarliestOffset) { + Map<Integer, Long> bucketWithOffset = new HashMap<>(buckets.size()); + for (Integer bucket : buckets) { + bucketWithOffset.put(bucket, EARLIEST_OFFSET); + } + return bucketWithOffset; + } else { + return listOffsets(partitionName, buckets, new OffsetSpec.EarliestSpec()); } - return bucketWithOffset; } @Override diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala index 28fb633b52..00d6400f64 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala +++ 
b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala @@ -50,4 +50,14 @@ object SparkFlussConf { .durationType() .defaultValue(Duration.ofMillis(10000L)) .withDescription("The timeout for log scanner to poll records.") + + val SCAN_MAX_RECORDS_PER_PARTITION: ConfigOption[java.lang.Long] = + ConfigBuilder + .key("scan.max.records.per.partition") + .longType() + .noDefaultValue() + .withDescription( + "The maximum number of records per Spark input partition when reading a log table. " + + "When set, each Fluss bucket whose offset range exceeds this value will be split " + + "into multiple partitions. Disabled by default (one partition per bucket).") } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala index 128094a04c..c5a9a3e49c 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala @@ -25,6 +25,7 @@ import org.apache.fluss.client.table.scanner.log.LogScanner import org.apache.fluss.config.Configuration import org.apache.fluss.metadata.{PartitionInfo, TableBucket, TableInfo, TablePath} import org.apache.fluss.predicate.Predicate +import org.apache.fluss.spark.SparkFlussConf import org.apache.fluss.spark.utils.SparkPartitionPredicate import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} @@ -93,26 +94,58 @@ class FlussAppendBatch( } override def planInputPartitions(): Array[InputPartition] = { - val bucketOffsetsRetrieverImpl = new BucketOffsetsRetrieverImpl(admin, tablePath) + val maxRecordsPerPartition: Option[Long] = { + val opt = flussConfig.getOptional(SparkFlussConf.SCAN_MAX_RECORDS_PER_PARTITION) + if (opt.isPresent) Some(opt.get().longValue()) else None + } + + val bucketOffsetsRetrieverImpl = 
maxRecordsPerPartition match { + case Some(_) => new BucketOffsetsRetrieverImpl(admin, tablePath, true) + case None => new BucketOffsetsRetrieverImpl(admin, tablePath) + } val buckets = (0 until tableInfo.getNumBuckets).toSeq + def splitOffsetRange( + tableBucket: TableBucket, + startOffset: Long, + stopOffset: Long, + maxRecords: Long): Seq[InputPartition] = { + if ( + startOffset < 0 || stopOffset <= startOffset || stopOffset <= (startOffset + maxRecords) + ) { + return Seq(FlussAppendInputPartition(tableBucket, startOffset, stopOffset)) + } + val rangeSize = stopOffset - startOffset + val numSplits = ((rangeSize + maxRecords - 1) / maxRecords).toInt + val step = (rangeSize + numSplits - 1) / numSplits + + Iterator + .from(0) + .take(numSplits) + .map(i => startOffset + i * step) + .map { + from => FlussAppendInputPartition(tableBucket, from, math.min(from + step, stopOffset)) + } + .toSeq + } + def createPartitions( partitionId: Option[Long], startBucketOffsets: Map[Integer, Long], stoppingBucketOffsets: Map[Integer, Long]): Array[InputPartition] = { - buckets.map { + buckets.flatMap { bucketId => - val (startBucketOffset, stoppingBucketOffset) = + val (startOffset, stopOffset) = (startBucketOffsets(bucketId), stoppingBucketOffsets(bucketId)) - partitionId match { - case Some(partitionId) => - val tableBucket = new TableBucket(tableInfo.getTableId, partitionId, bucketId) - FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset) - .asInstanceOf[InputPartition] - case None => - val tableBucket = new TableBucket(tableInfo.getTableId, bucketId) - FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset) - .asInstanceOf[InputPartition] + val tableBucket = partitionId match { + case Some(pid) => new TableBucket(tableInfo.getTableId, pid, bucketId) + case None => new TableBucket(tableInfo.getTableId, bucketId) + } + maxRecordsPerPartition match { + case Some(maxRecs) if maxRecs > 0 => + splitOffsetRange(tableBucket, 
startOffset, stopOffset, maxRecs) + case _ => + Seq(FlussAppendInputPartition(tableBucket, startOffset, stopOffset)) } }.toArray } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala index 42b0aa62d0..4ccc58f928 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala @@ -21,8 +21,8 @@ import org.apache.fluss.spark.read.{FlussMetrics, FlussScan} import org.apache.fluss.spark.read.FlussAppendScan import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.Row import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.assertj.core.api.Assertions.assertThat @@ -603,4 +603,66 @@ class SparkLogTableReadTest extends FlussSparkTestBase { assert(numRowsRead == 5L, s"Expected 5 rows read, got $numRowsRead") } } + + test("Spark Read: split partition by config") { + withSampleTable { + withSQLConf("spark.sql.fluss.scan.max.records.per.partition" -> "2") { + val df = sql(s"SELECT amount FROM $DEFAULT_DATABASE.t ORDER BY orderId") + checkAnswer(df, Row(601) :: Row(602) :: Row(603) :: Row(604) :: Row(605) :: Nil) + + val partitions = getInputPartitions(df) + assertThat(partitions.length).isEqualTo(3) + } + } + + withTable("t_partition") { + sql( + s""" + |CREATE TABLE $DEFAULT_DATABASE.t_partition (orderId BIGINT, itemId BIGINT, amount INT, address STRING, dt STRING) + |PARTITIONED BY (dt) + |""".stripMargin + ) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_partition VALUES + |(600L, 21L, 601, "addr1", "2026-01-01"), (700L, 22L, 602, "addr2", "2026-01-01"), + |(800L, 23L, 603, "addr3", "2026-01-02"), 
(900L, 24L, 604, "addr4", "2026-01-02"), + |(1000L, 25L, 605, "addr5", "2026-01-03") + |""".stripMargin) + Seq((0, 3), (1, 5), (2, 3)).foreach { + case (maxRecords, expectedPartitions) => + withClue(s"maxRecords = $maxRecords, expectedPartitions = $expectedPartitions") { + withSQLConf("spark.sql.fluss.scan.max.records.per.partition" -> maxRecords.toString) { + val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_partition ORDER BY orderId") + checkAnswer( + df, + Row(600L, 21L, 601, "addr1", "2026-01-01") :: + Row(700L, 22L, 602, "addr2", "2026-01-01") :: + Row(800L, 23L, 603, "addr3", "2026-01-02") :: + Row(900L, 24L, 604, "addr4", "2026-01-02") :: + Row(1000L, 25L, 605, "addr5", "2026-01-03") :: Nil + ) + + val partitions = getInputPartitions(df) + assertThat(partitions.length).isEqualTo(expectedPartitions) + } + } + } + } + } + + private def getInputPartitions(df: DataFrame): Array[InputPartition] = { + // Try executedPlan first (after AQE), then optimizedPlan + val fromExecutedPlan = df.queryExecution.executedPlan.collect { + case b: BatchScanExec => b.inputPartitions.toArray + } + if (fromExecutedPlan.nonEmpty) { + fromExecutedPlan.head + } else { + val scans = df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan: FlussAppendScan, _, _, _) => scan + } + scans.headOption.map(_.toBatch.planInputPartitions()).getOrElse(Array.empty[InputPartition]) + } + } }