diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala index 128094a04c..87f2fdad0f 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala @@ -63,6 +63,40 @@ abstract class FlussBatch( } } + protected def createUpsertPartitions( + partitionName: String, + kvSnapshots: KvSnapshots, + bucketOffsetsRetriever: BucketOffsetsRetrieverImpl): Array[InputPartition] = { + val tableId = kvSnapshots.getTableId + val partitionId = kvSnapshots.getPartitionId + val bucketIds = kvSnapshots.getBucketIds + val bucketIdToLogOffset = + stoppingOffsetsInitializer.getBucketOffsets(partitionName, bucketIds, bucketOffsetsRetriever) + bucketIds.asScala + .map { + bucketId => + val tableBucket = new TableBucket(tableId, partitionId, bucketId) + val snapshotIdOpt = kvSnapshots.getSnapshotId(bucketId) + val logStartingOffsetOpt = kvSnapshots.getLogOffset(bucketId) + val logEndingOffset = bucketIdToLogOffset.get(bucketId) + + if (snapshotIdOpt.isPresent) { + assert( + logStartingOffsetOpt.isPresent, + "Log offset must be present when snapshot id is present") + FlussUpsertInputPartition( + tableBucket, + snapshotIdOpt.getAsLong, + logStartingOffsetOpt.getAsLong, + logEndingOffset) + } else { + FlussUpsertInputPartition(tableBucket, -1L, LogScanner.EARLIEST_OFFSET, logEndingOffset) + } + } + .map(_.asInstanceOf[InputPartition]) + .toArray + } + override def close(): Unit = { if (admin != null) { admin.close() @@ -196,48 +230,6 @@ class FlussUpsertBatch( private val bucketOffsetsRetriever = new BucketOffsetsRetrieverImpl(admin, tablePath) override def planInputPartitions(): Array[InputPartition] = { - def createPartitions(partitionName: String, kvSnapshots: KvSnapshots): Array[InputPartition] = { - val tableId = kvSnapshots.getTableId - val partitionId = kvSnapshots.getPartitionId - val bucketIds = kvSnapshots.getBucketIds - val bucketIdToLogOffset = - stoppingOffsetsInitializer.getBucketOffsets( - partitionName, - bucketIds, - bucketOffsetsRetriever) - bucketIds.asScala - .map { - bucketId => - val tableBucket = new TableBucket(tableId, partitionId, bucketId) - val snapshotIdOpt = kvSnapshots.getSnapshotId(bucketId) - val logStartingOffsetOpt = kvSnapshots.getLogOffset(bucketId) - val logEndingOffset = bucketIdToLogOffset.get(bucketId) - - if (snapshotIdOpt.isPresent) { - assert( - logStartingOffsetOpt.isPresent, - "Log offset must be present when snapshot id is present") - - // Create hybrid partition - FlussUpsertInputPartition( - tableBucket, - snapshotIdOpt.getAsLong, - logStartingOffsetOpt.getAsLong, - logEndingOffset - ) - } else { - // No snapshot yet, only read log from beginning - FlussUpsertInputPartition( - tableBucket, - -1L, - LogScanner.EARLIEST_OFFSET, - logEndingOffset) - } - } - .map(_.asInstanceOf[InputPartition]) - .toArray - } - if (tableInfo.isPartitioned) { val matching = SparkPartitionPredicate.filterPartitions(partitionInfos.asScala.toSeq, partitionPredicate) @@ -246,11 +238,11 @@ class FlussUpsertBatch( val partitionName = partitionInfo.getPartitionName val kvSnapshots = admin.getLatestKvSnapshots(tablePath, partitionName).get() - createPartitions(partitionName, kvSnapshots) + createUpsertPartitions(partitionName, kvSnapshots, bucketOffsetsRetriever) }.toArray } else { val kvSnapshots = admin.getLatestKvSnapshots(tablePath).get() - createPartitions(null, kvSnapshots) + createUpsertPartitions(null, kvSnapshots, bucketOffsetsRetriever) } } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala index 2f99e39a57..2ef6ab7701 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala @@ -267,38 +267,18 @@ class FlussLakeUpsertBatch( } private def planFallbackPartitions(): Array[InputPartition] = { - // Fallback to pure Fluss kv reading when no lake snapshot exists val bucketOffsetsRetriever = new BucketOffsetsRetrieverImpl(admin, tablePath) - val buckets = (0 until tableInfo.getNumBuckets).toSeq - - def createPartitions( - partitionId: Option[Long], - partitionName: String): Array[InputPartition] = { - val stoppingOffsets = - getBucketOffsets(stoppingOffsetsInitializer, partitionName, buckets, bucketOffsetsRetriever) - - buckets.map { - bucketId => - val tableBucket = partitionId match { - case Some(pid) => new TableBucket(tableInfo.getTableId, pid, bucketId) - case None => new TableBucket(tableInfo.getTableId, bucketId) - } - // Use FlussUpsertInputPartition for fallback (reads from Fluss kv snapshot) - FlussUpsertInputPartition( - tableBucket, - -1L, // no snapshot - LogScanner.EARLIEST_OFFSET, - stoppingOffsets(bucketId) - ): InputPartition - }.toArray - } if (tableInfo.isPartitioned) { partitionInfos.asScala.flatMap { - pi => createPartitions(Some(pi.getPartitionId), pi.getPartitionName) + pi => + val partitionName = pi.getPartitionName + val kvSnapshots = admin.getLatestKvSnapshots(tablePath, partitionName).get() + createUpsertPartitions(partitionName, kvSnapshots, bucketOffsetsRetriever) }.toArray } else { - createPartitions(None, null) + val kvSnapshots = admin.getLatestKvSnapshots(tablePath).get() + createUpsertPartitions(null, kvSnapshots, bucketOffsetsRetriever) } } } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala index 4913ee29e6..0de25f100e 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala @@ -91,6 +91,86 @@ abstract class SparkLakePrimaryKeyTableReadTestBase extends SparkLakeTableReadTe } } + test("Spark Lake Read: pk table fallback uses Fluss kv snapshot for log tail merge") { + // Non-partitioned table + withTable("t_fb_hybrid") { + val tablePath = createTablePath("t_fb_hybrid") + sql(s""" + |CREATE TABLE $DEFAULT_DATABASE.t_fb_hybrid (id INT, name STRING, score INT) + | TBLPROPERTIES ( + | '${ConfigOptions.TABLE_DATALAKE_ENABLED.key()}' = true, + | '${ConfigOptions.TABLE_DATALAKE_FRESHNESS.key()}' = '1s', + | '${PRIMARY_KEY.key()}' = 'id', + | '${BUCKET_NUMBER.key()}' = 1) + |""".stripMargin) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_fb_hybrid VALUES + |(1, "alice", 90), (2, "bob", 85), (3, "charlie", 95) + |""".stripMargin) + + flussServer.triggerAndWaitSnapshot(tablePath) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_fb_hybrid VALUES + |(2, "bob_updated", 100), (4, "david", 88) + |""".stripMargin) + + val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_fb_hybrid ORDER BY id") + val partitions = lakeUpsertInputPartitions(df) + assert( + partitions.exists(_.snapshotId >= 0), + s"Expected at least one hybrid partition with snapshotId >= 0, got: ${partitions.mkString(", ")}") + checkAnswer( + df, + Row(1, "alice", 90) :: Row(2, "bob_updated", 100) :: + Row(3, "charlie", 95) :: Row(4, "david", 88) :: Nil + ) + } + + // Partitioned table + withTable("t_fb_hybrid_partitioned") { + val tablePath = createTablePath("t_fb_hybrid_partitioned") + sql(s""" + |CREATE TABLE $DEFAULT_DATABASE.t_fb_hybrid_partitioned (id INT, name STRING, score INT, dt STRING) + | PARTITIONED BY (dt) + | TBLPROPERTIES ( + | '${ConfigOptions.TABLE_DATALAKE_ENABLED.key()}' = true, + | '${ConfigOptions.TABLE_DATALAKE_FRESHNESS.key()}' = '1s', + | '${PRIMARY_KEY.key()}' = 'id,dt', + | '${BUCKET_NUMBER.key()}' = 1) + |""".stripMargin) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_fb_hybrid_partitioned VALUES + |(1, "alice", 90, "2026-01-01"), + |(2, "bob", 85, "2026-01-01"), + |(3, "charlie", 95, "2026-01-02") + |""".stripMargin) + + flussServer.triggerAndWaitSnapshot(tablePath) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_fb_hybrid_partitioned VALUES + |(2, "bob_updated", 100, "2026-01-01"), + |(4, "david", 88, "2026-01-02") + |""".stripMargin) + + val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_fb_hybrid_partitioned ORDER BY id") + val partitions = lakeUpsertInputPartitions(df) + assert( + partitions.exists(_.snapshotId >= 0), + s"Expected at least one hybrid partition with snapshotId >= 0, got: ${partitions.mkString(", ")}") + checkAnswer( + df, + Row(1, "alice", 90, "2026-01-01") :: + Row(2, "bob_updated", 100, "2026-01-01") :: + Row(3, "charlie", 95, "2026-01-02") :: + Row(4, "david", 88, "2026-01-02") :: Nil + ) + } + } + test("Spark Lake Read: pk table lake-only (all data in lake, no kv tail)") { // Test non-partitioned table withTable("t_lake_only") { diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala index 9ca8d81a0d..526e898899 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala @@ -22,7 +22,7 @@ import org.apache.fluss.flink.tiering.LakeTieringJobBuilder import org.apache.fluss.flink.tiering.source.TieringSourceOptions import org.apache.fluss.metadata.{DataLakeFormat, TableBucket} import org.apache.fluss.spark.FlussSparkTestBase -import org.apache.fluss.spark.read.FlussScan +import org.apache.fluss.spark.read.{FlussLakeUpsertScan, FlussScan, FlussUpsertInputPartition} import org.apache.flink.api.common.RuntimeExecutionMode import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment @@ -151,4 +151,17 @@ abstract class SparkLakeTableReadTestBase extends FlussSparkTestBase { expected.exists(pushed.contains), s"Expected any of $expected in pushed predicates, got $pushed") } + + protected def lakeUpsertInputPartitions(df: DataFrame): Array[FlussUpsertInputPartition] = { + val scans = + df.queryExecution.executedPlan.collect { + case b: BatchScanExec => b.scan + } ++ df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _, _, _) => scan + } + scans + .collect { case s: FlussLakeUpsertScan => s } + .flatMap(_.toBatch.planInputPartitions().collect { case p: FlussUpsertInputPartition => p }) + .toArray + } }