From 93523e7e6c1fea3a9c11f058a620ba151c1d56e5 Mon Sep 17 00:00:00 2001
From: Tuan Pham
Date: Sat, 10 May 2025 17:15:10 +1000
Subject: [PATCH 1/5] user specified schema

---
 .../elasticsearch/hadoop/rest/RestClient.java      | 12 +++++++--
 .../hadoop/rest/RestRepository.java                |  4 +++
 .../dto/mapping/FieldParser.java                   | 19 ++++++++++----
 .../serialization/dto/mapping/MappingSet.java      | 26 +++++++++----------
 .../spark/sql/DefaultSource.scala                  |  9 +++++--
 .../elasticsearch/spark/sql/SchemaUtils.scala      | 20 +++++---------
 6 files changed, 54 insertions(+), 36 deletions(-)

diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java
index 1e05bfad4..444c39468 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java
@@ -311,14 +311,22 @@ public List<List<Map<String, Object>>> targetShards(String index, String routing
     }
 
     public MappingSet getMappings(Resource indexResource) {
+        return getMappings(indexResource, Collections.emptyList());
+    }
+
+    public MappingSet getMappings(Resource indexResource, Collection<String> includeFields) {
         if (indexResource.isTyped()) {
-            return getMappings(indexResource.index() + "/_mapping/" + indexResource.type(), true);
+            return getMappings(indexResource.index() + "/_mapping/" + indexResource.type(), true, includeFields);
         } else {
-            return getMappings(indexResource.index() + "/_mapping" + (indexReadMissingAsEmpty ? "?ignore_unavailable=true" : ""), false);
+            return getMappings(indexResource.index() + "/_mapping" + (indexReadMissingAsEmpty ? "?ignore_unavailable=true" : ""), false, includeFields);
         }
     }
 
     public MappingSet getMappings(String query, boolean includeTypeName) {
+        return getMappings(query, includeTypeName, Collections.emptyList());
+    }
+
+    public MappingSet getMappings(String query, boolean includeTypeName, Collection<String> includeFields) {
         // If the version is not at least 7, then the property isn't guaranteed to exist. If it is, then defer to the flag.
         boolean requestTypeNameInResponse = clusterInfo.getMajorVersion().onOrAfter(EsMajorVersion.V_7_X) && includeTypeName;
         // Response will always have the type name in it if node version is before 7, and if it is not, defer to the flag.
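The getMappings overloads above thread an optional whitelist of field names into mapping parsing; the FieldParser and MappingSet changes below then skip any field whose full dotted name is not in that list, while an empty list keeps everything. A minimal sketch of that pruning rule, using a stand-in field model rather than the actual Field/MappingSet classes (EsField and filterFields are illustrative names, not part of the patch):

    // Stand-in for the mapping field tree; compound fields carry children.
    case class EsField(name: String, children: Seq[EsField] = Nil)

    // Keep a field only if the whitelist is empty (no user schema) or its full
    // dotted name is listed; compound fields recurse into their children, so a
    // nested field needs both its parent name and its own dotted name listed.
    def filterFields(fields: Seq[EsField], include: Set[String], parent: String = ""): Seq[EsField] =
      fields.flatMap { f =>
        val fullName = parent + f.name
        if (include.nonEmpty && !include.contains(fullName)) None
        else Some(f.copy(children = filterFields(f.children, include, fullName + ".")))
      }

    // Example: a schema naming only "title" and "author.name" drops "body" and "author.age".
    val mapping = Seq(
      EsField("title"),
      EsField("body"),
      EsField("author", Seq(EsField("name"), EsField("age"))))
    filterFields(mapping, Set("title", "author", "author.name"))
    // => Seq(EsField("title"), EsField("author", Seq(EsField("name"))))
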
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java index 23609a4cf..518addd65 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java @@ -300,6 +300,10 @@ public MappingSet getMappings() { return client.getMappings(resources.getResourceRead()); } + public MappingSet getMappings(List includeFields) { + return client.getMappings(resources.getResourceRead(), includeFields); + } + public Map sampleGeoFields(Mapping mapping) { Map fields = MappingUtils.geoFields(mapping); Map geoMapping = client.sampleForFields(resources.getResourceRead(), fields.keySet()); diff --git a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java index 8a1cd7763..99f0114a0 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java @@ -19,10 +19,7 @@ package org.elasticsearch.hadoop.serialization.dto.mapping; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException; import org.elasticsearch.hadoop.serialization.FieldType; @@ -52,13 +49,25 @@ public static MappingSet parseTypelessMappings(Map content) { * @return MappingSet for that response. */ public static MappingSet parseMappings(Map content, boolean includeTypeName) { + return parseMappings(content, includeTypeName, Collections.emptyList()); + } + + /** + * Convert the deserialized mapping request body into an object + * @param content entire mapping request body for all indices and types + * @param includeTypeName true if the given content to be parsed includes type names within the structure, + * or false if it is in the typeless format + * @param includeFields list of field that should have mapping checked + * @return MappingSet for that response. + */ + public static MappingSet parseMappings(Map content, boolean includeTypeName, Collection includeFields) { Iterator> indices = content.entrySet().iterator(); List indexMappings = new ArrayList(); while(indices.hasNext()) { // These mappings are ordered by index, then optionally type. 
parseIndexMappings(indices.next(), indexMappings, includeTypeName); } - return new MappingSet(indexMappings); + return new MappingSet(indexMappings, includeFields); } private static void parseIndexMappings(Map.Entry indexToMappings, List collector, boolean includeTypeName) { diff --git a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java index 8e1be4f09..438165948 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java @@ -20,12 +20,7 @@ package org.elasticsearch.hadoop.serialization.dto.mapping; import java.io.Serializable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; +import java.util.*; import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException; import org.elasticsearch.hadoop.serialization.FieldType; @@ -46,7 +41,7 @@ public class MappingSet implements Serializable { private final Map> indexTypeMap = new HashMap>(); private final Mapping resolvedSchema; - public MappingSet(List mappings) { + public MappingSet(List mappings, Collection includeFields) { if (mappings.isEmpty()) { this.empty = true; this.resolvedSchema = new Mapping(RESOLVED_INDEX_NAME, RESOLVED_MAPPING_NAME, Field.NO_FIELDS); @@ -78,15 +73,15 @@ public MappingSet(List mappings) { mappingsToSchema.put(typeName, mapping); } - this.resolvedSchema = mergeMappings(mappings); + this.resolvedSchema = mergeMappings(mappings, includeFields); } } - private static Mapping mergeMappings(List mappings) { + private static Mapping mergeMappings(List mappings, Collection includeFields) { Map fieldMap = new LinkedHashMap(); for (Mapping mapping: mappings) { for (Field field : mapping.getFields()) { - addToFieldTable(field, "", fieldMap); + addToFieldTable(field, "", fieldMap, includeFields); } } Field[] collapsed = collapseFields(fieldMap); @@ -94,10 +89,13 @@ private static Mapping mergeMappings(List mappings) { } @SuppressWarnings("unchecked") - private static void addToFieldTable(Field field, String parent, Map fieldTable) { + private static void addToFieldTable(Field field, String parent, Map fieldTable, Collection includeFields) { String fullName = parent + field.name(); Object[] entry = fieldTable.get(fullName); - if (entry == null) { + if (!includeFields.isEmpty() && !includeFields.contains(fullName)) { + return; + } + else if (entry == null) { // Haven't seen field yet. 
if (FieldType.isCompound(field.type())) { // visit its children @@ -105,7 +103,7 @@ private static void addToFieldTable(Field field, String parent, Map subTable = (Map)entry[1]; String prefix = fullName + "."; for (Field subField : field.properties()) { - addToFieldTable(subField, prefix, subTable); + addToFieldTable(subField, prefix, subTable, includeFields); } } } diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala index 86ffbfa17..8b754f381 100644 --- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala +++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala @@ -80,6 +80,7 @@ import org.elasticsearch.hadoop.util.StringUtils import org.elasticsearch.hadoop.util.Version import org.elasticsearch.spark.cfg.SparkSettingsManager import org.elasticsearch.spark.serialization.ScalaValueWriter +import org.elasticsearch.spark.sql.SchemaUtils.{Schema, discoverMapping} import org.elasticsearch.spark.sql.streaming.EsSparkSqlStreamingSink import org.elasticsearch.spark.sql.streaming.SparkSqlStreamingConfigs import org.elasticsearch.spark.sql.streaming.StructuredStreamingVersionLock @@ -235,11 +236,15 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @ conf } - @transient lazy val lazySchema = { SchemaUtils.discoverMapping(cfg) } + @transient lazy val lazySchema = userSchema match { + case None => SchemaUtils.discoverMapping(cfg) + //TODO: properly flatten the schema so we can selectively check mapping of nested field as well + case Some(s) => SchemaUtils.discoverMapping(cfg, s.names) // Or we just take the user specified schema as it is: Schema(s) + } @transient lazy val valueWriter = { new ScalaValueWriter } - override def schema = userSchema.getOrElse(lazySchema.struct) + override def schema: StructType = lazySchema.struct // TableScan def buildScan(): RDD[Row] = buildScan(Array.empty) diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala index 849ff5a78..d01dcddd9 100644 --- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala +++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala @@ -23,7 +23,6 @@ import java.util.{LinkedHashSet => JHashSet} import java.util.{List => JList} import java.util.{Map => JMap} import java.util.Properties - import scala.collection.JavaConverters.asScalaBufferConverter import scala.collection.JavaConverters.propertiesAsScalaMapConverter import scala.collection.mutable.ArrayBuffer @@ -70,12 +69,7 @@ import org.elasticsearch.hadoop.serialization.FieldType.SHORT import org.elasticsearch.hadoop.serialization.FieldType.STRING import org.elasticsearch.hadoop.serialization.FieldType.TEXT import org.elasticsearch.hadoop.serialization.FieldType.WILDCARD -import org.elasticsearch.hadoop.serialization.dto.mapping.Field -import org.elasticsearch.hadoop.serialization.dto.mapping.GeoField -import org.elasticsearch.hadoop.serialization.dto.mapping.GeoPointType -import org.elasticsearch.hadoop.serialization.dto.mapping.GeoShapeType -import org.elasticsearch.hadoop.serialization.dto.mapping.Mapping -import org.elasticsearch.hadoop.serialization.dto.mapping.MappingUtils +import org.elasticsearch.hadoop.serialization.dto.mapping.{Field, GeoField, GeoPointType, GeoShapeType, Mapping, MappingSet, MappingUtils} import 
org.elasticsearch.hadoop.serialization.field.FieldFilter import org.elasticsearch.hadoop.serialization.field.FieldFilter.NumberedInclude import org.elasticsearch.hadoop.util.Assert @@ -87,22 +81,22 @@ import org.elasticsearch.spark.sql.Utils.ROW_INFO_ARRAY_PROPERTY import org.elasticsearch.spark.sql.Utils.ROW_INFO_ORDER_PROPERTY private[sql] object SchemaUtils { - case class Schema(mapping: Mapping, struct: StructType) + case class Schema(struct: StructType) - def discoverMapping(cfg: Settings): Schema = { - val (mapping, geoInfo) = discoverMappingAndGeoFields(cfg) + def discoverMapping(cfg: Settings, includeFields: Seq[String] = Seq.empty[String]): Schema = { + val (mapping, geoInfo) = discoverMappingAndGeoFields(cfg, includeFields) val struct = convertToStruct(mapping, geoInfo, cfg) - Schema(mapping, struct) + Schema(struct) } - def discoverMappingAndGeoFields(cfg: Settings): (Mapping, JMap[String, GeoField]) = { + def discoverMappingAndGeoFields(cfg: Settings, includeFields: Seq[String]): (Mapping, JMap[String, GeoField]) = { InitializationUtils.validateSettings(cfg) InitializationUtils.discoverClusterInfo(cfg, Utils.LOGGER) val repo = new RestRepository(cfg) try { if (repo.resourceExists(true)) { - var mappingSet = repo.getMappings + val mappingSet = repo.getMappings if (mappingSet == null || mappingSet.isEmpty) { throw new EsHadoopIllegalArgumentException(s"Cannot find mapping for ${cfg.getResourceRead} - one is required before using Spark SQL") } From 77f5a61698512c2f6d83f2df30b19c585dff4821 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 10 May 2025 17:26:54 +1000 Subject: [PATCH 2/5] selective mapping top level fields --- .../org/elasticsearch/hadoop/rest/RestRepository.java | 8 ++------ .../scala/org/elasticsearch/spark/sql/SchemaUtils.scala | 4 +++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java index 518addd65..47248aaa3 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java @@ -54,11 +54,7 @@ import java.io.Closeable; import java.io.IOException; import java.io.InputStream; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; import static org.elasticsearch.hadoop.rest.Request.Method.POST; @@ -300,7 +296,7 @@ public MappingSet getMappings() { return client.getMappings(resources.getResourceRead()); } - public MappingSet getMappings(List includeFields) { + public MappingSet getMappings(Collection includeFields) { return client.getMappings(resources.getResourceRead(), includeFields); } diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala index d01dcddd9..78df5c99d 100644 --- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala +++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala @@ -80,6 +80,8 @@ import org.elasticsearch.spark.sql.Utils.ROOT_LEVEL_NAME import org.elasticsearch.spark.sql.Utils.ROW_INFO_ARRAY_PROPERTY import org.elasticsearch.spark.sql.Utils.ROW_INFO_ORDER_PROPERTY +import scala.jdk.CollectionConverters.SeqHasAsJava + private[sql] object SchemaUtils { case class Schema(struct: StructType) @@ -96,7 +98,7 @@ private[sql] object 
SchemaUtils { val repo = new RestRepository(cfg) try { if (repo.resourceExists(true)) { - val mappingSet = repo.getMappings + val mappingSet = repo.getMappings(includeFields.asJava) if (mappingSet == null || mappingSet.isEmpty) { throw new EsHadoopIllegalArgumentException(s"Cannot find mapping for ${cfg.getResourceRead} - one is required before using Spark SQL") } From f3d46373f271028eb231c1a45353a7a798256290 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sun, 13 Jul 2025 19:35:58 +1000 Subject: [PATCH 3/5] - only include fields exist in user specified schema - AbstractEsRDD to use mapping from first discoverMapping --- .../hadoop/mr/EsInputFormat.java | 2 +- .../elasticsearch/hadoop/rest/RestClient.java | 2 +- .../hadoop/rest/RestService.java | 18 +++--- .../spark/rdd/AbstractEsRDD.scala | 9 ++- .../spark/sql/DefaultSource.scala | 6 +- .../spark/sql/ScalaEsRowRDD.scala | 2 +- .../elasticsearch/spark/sql/SchemaUtils.scala | 56 +++++++++++-------- 7 files changed, 49 insertions(+), 46 deletions(-) diff --git a/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java b/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java index cde07ec01..555859415 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java @@ -412,7 +412,7 @@ public EsInputRecordReader createRecordReader(InputSplit split, TaskAttemp public org.apache.hadoop.mapred.InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { Settings settings = HadoopSettingsManager.loadFrom(job); - Collection partitions = RestService.findPartitions(settings, log); + Collection partitions = RestService.findPartitions(settings, log, null); EsInputSplit[] splits = new EsInputSplit[partitions.size()]; int index = 0; diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java index 444c39468..d9d7a13d4 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java @@ -336,7 +336,7 @@ public MappingSet getMappings(String query, boolean includeTypeName, Collection< } Map result = get(query, null); if (result != null && !result.isEmpty()) { - return FieldParser.parseMappings(result, typeNameInResponse); + return FieldParser.parseMappings(result, typeNameInResponse, includeFields); } return null; } diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java index d0b5ad58b..2fb5f8078 100644 --- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java +++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java @@ -212,7 +212,7 @@ public void remove() { } @SuppressWarnings("unchecked") - public static List findPartitions(Settings settings, Log log) { + public static List findPartitions(Settings settings, Log log, Mapping resolvedMapping) { Version.logVersion(); InitializationUtils.validateSettings(settings); @@ -244,16 +244,18 @@ public static List findPartitions(Settings settings, Log lo log.info(String.format("Reading from [%s]", settings.getResourceRead())); - MappingSet mapping = null; + Mapping mapping = resolvedMapping; if (!shards.isEmpty()) { - mapping = client.getMappings(); + if (mapping == null) { + mapping = client.getMappings().getResolvedView(); + } if (log.isDebugEnabled()) { - log.debug(String.format("Discovered resolved mapping {%s} for [%s]", 
mapping.getResolvedView(), settings.getResourceRead())); + log.debug(String.format("Discovered resolved mapping {%s} for [%s]", mapping, settings.getResourceRead())); } // validate if possible FieldPresenceValidation validation = settings.getReadFieldExistanceValidation(); if (validation.isRequired()) { - MappingUtils.validateMapping(SettingsUtils.determineSourceFields(settings), mapping.getResolvedView(), validation, log); + MappingUtils.validateMapping(SettingsUtils.determineSourceFields(settings), mapping, validation, log); } } final Map nodesMap = new HashMap(); @@ -278,9 +280,8 @@ public static List findPartitions(Settings settings, Log lo /** * Create one {@link PartitionDefinition} per shard for each requested index. */ - static List findShardPartitions(Settings settings, MappingSet mappingSet, Map nodes, + static List findShardPartitions(Settings settings, Mapping resolvedMapping, Map nodes, List>> shards, Log log) { - Mapping resolvedMapping = mappingSet == null ? null : mappingSet.getResolvedView(); List partitions = new ArrayList(shards.size()); PartitionDefinition.PartitionDefinitionBuilder partitionBuilder = PartitionDefinition.builder(settings, resolvedMapping); for (List> group : shards) { @@ -316,13 +317,12 @@ static List findShardPartitions(Settings settings, MappingS /** * Partitions the query based on the max number of documents allowed per partition {@link Settings#getMaxDocsPerPartition()}. */ - static List findSlicePartitions(RestClient client, Settings settings, MappingSet mappingSet, + static List findSlicePartitions(RestClient client, Settings settings, Mapping resolvedMapping, Map nodes, List>> shards, Log log) { QueryBuilder query = QueryUtils.parseQueryAndFilters(settings); Integer maxDocsPerPartition = settings.getMaxDocsPerPartition(); Assert.notNull(maxDocsPerPartition, "Attempting to find slice partitions but maximum documents per partition is not set."); Resource readResource = new Resource(settings, true); - Mapping resolvedMapping = mappingSet == null ? 
null : mappingSet.getResolvedView(); PartitionDefinition.PartitionDefinitionBuilder partitionBuilder = PartitionDefinition.builder(settings, resolvedMapping); List partitions = new ArrayList(shards.size()); diff --git a/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala b/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala index 559664144..f558d8013 100644 --- a/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala +++ b/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala @@ -19,7 +19,6 @@ package org.elasticsearch.spark.rdd; import JDKCollectionConvertersCompat.Converters._ -import scala.reflect.ClassTag import org.apache.commons.logging.LogFactory import org.apache.spark.Partition import org.apache.spark.SparkContext @@ -31,12 +30,12 @@ import org.elasticsearch.hadoop.rest.PartitionDefinition import org.elasticsearch.hadoop.util.ObjectUtils import org.elasticsearch.spark.cfg.SparkSettingsManager import org.elasticsearch.hadoop.rest.RestRepository - -import scala.annotation.meta.param +import org.elasticsearch.hadoop.serialization.dto.mapping.{Mapping, MappingSet} private[spark] abstract class AbstractEsRDD[T: ClassTag]( @(transient @param) sc: SparkContext, - val params: scala.collection.Map[String, String] = Map.empty) + val params: scala.collection.Map[String, String] = Map.empty, + @(transient @param) mapping: Mapping) extends RDD[T](sc, Nil) { private val init = { ObjectUtils.loadClass("org.elasticsearch.spark.rdd.CompatUtils", classOf[ObjectUtils].getClassLoader) } @@ -75,7 +74,7 @@ private[spark] abstract class AbstractEsRDD[T: ClassTag]( } @transient private[spark] lazy val esPartitions = { - RestService.findPartitions(esCfg, logger) + RestService.findPartitions(esCfg, logger, mapping) } } diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala index 8b754f381..fb57e422d 100644 --- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala +++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala @@ -236,11 +236,7 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @ conf } - @transient lazy val lazySchema = userSchema match { - case None => SchemaUtils.discoverMapping(cfg) - //TODO: properly flatten the schema so we can selectively check mapping of nested field as well - case Some(s) => SchemaUtils.discoverMapping(cfg, s.names) // Or we just take the user specified schema as it is: Schema(s) - } + @transient lazy val lazySchema = SchemaUtils.discoverMapping(cfg, userSchema) @transient lazy val valueWriter = { new ScalaValueWriter } diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala index 7b545f15c..e9791f06b 100644 --- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala +++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala @@ -41,7 +41,7 @@ private[spark] class ScalaEsRowRDD( @(transient @param) sc: SparkContext, params: Map[String, String] = Map.empty, schema: SchemaUtils.Schema) - extends AbstractEsRDD[Row](sc, params) { + extends AbstractEsRDD[Row](sc, params, schema.mapping) { override def compute(split: Partition, context: TaskContext): ScalaEsRowRDDIterator = { new ScalaEsRowRDDIterator(context, split.asInstanceOf[EsPartition].esPartition, 
schema) diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala index 78df5c99d..c00609732 100644 --- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala +++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala @@ -23,25 +23,7 @@ import java.util.{LinkedHashSet => JHashSet} import java.util.{List => JList} import java.util.{Map => JMap} import java.util.Properties -import scala.collection.JavaConverters.asScalaBufferConverter -import scala.collection.JavaConverters.propertiesAsScalaMapConverter -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.types.ArrayType -import org.apache.spark.sql.types.BinaryType -import org.apache.spark.sql.types.BooleanType -import org.apache.spark.sql.types.ByteType -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.types.DataTypes -import org.apache.spark.sql.types.DoubleType -import org.apache.spark.sql.types.FloatType -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.sql.types.LongType -import org.apache.spark.sql.types.NullType -import org.apache.spark.sql.types.ShortType -import org.apache.spark.sql.types.StringType -import org.apache.spark.sql.types.StructField -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.TimestampType +import org.apache.spark.sql.types._ import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException import org.elasticsearch.hadoop.cfg.InternalConfigurationOptions import org.elasticsearch.hadoop.cfg.Settings @@ -79,16 +61,42 @@ import org.elasticsearch.hadoop.util.StringUtils import org.elasticsearch.spark.sql.Utils.ROOT_LEVEL_NAME import org.elasticsearch.spark.sql.Utils.ROW_INFO_ARRAY_PROPERTY import org.elasticsearch.spark.sql.Utils.ROW_INFO_ORDER_PROPERTY - -import scala.jdk.CollectionConverters.SeqHasAsJava +import scala.annotation.tailrec private[sql] object SchemaUtils { - case class Schema(struct: StructType) + case class Schema(struct: StructType, mapping: Mapping) - def discoverMapping(cfg: Settings, includeFields: Seq[String] = Seq.empty[String]): Schema = { + def discoverMapping(cfg: Settings, userSchema: Option[StructType] = None): Schema = { + val includeFields = structToColumnsNames(userSchema) val (mapping, geoInfo) = discoverMappingAndGeoFields(cfg, includeFields) val struct = convertToStruct(mapping, geoInfo, cfg) - Schema(struct) + Schema(struct, mapping) + } + + def structToColumnsNames(struct: Option[StructType]): Seq[String] = { + @tailrec + def getInnerMostType(dType: DataType): DataType = dType match { + case at: ArrayType => getInnerMostType(at.elementType) + case t => t + } + + @tailrec + def flattenFields(remaining: Seq[(String, DataType)], acc: Seq[String]): Seq[String] = remaining match { + case Nil => acc + case (name, dataType) :: tail => + getInnerMostType(dataType) match { + case s: StructType => + val nestedFields = s.fields.map(f => (s"$name.${f.name}", f.dataType)) + flattenFields(nestedFields ++ tail, acc :+ name) + case _ => + flattenFields(tail, name +: acc) + } + } + + struct match { + case None => Seq.empty + case Some(s) => flattenFields(s.fields.map(f => (f.name, f.dataType)), Seq.empty) + } } def discoverMappingAndGeoFields(cfg: Settings, includeFields: Seq[String]): (Mapping, JMap[String, GeoField]) = { From a7222f47554eab0239d6e687fa573051af9318ad Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sun, 13 Jul 2025 22:34:09 +1000 Subject: 
[PATCH 4/5] default null for mapping

---
 .../main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala b/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala
index f558d8013..2f00c2eb4 100644
--- a/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala
+++ b/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala
@@ -35,7 +35,7 @@ import org.elasticsearch.hadoop.serialization.dto.mapping.{Mapping, MappingSet}
 private[spark] abstract class AbstractEsRDD[T: ClassTag](
   @(transient @param) sc: SparkContext,
   val params: scala.collection.Map[String, String] = Map.empty,
-  @(transient @param) mapping: Mapping)
+  @(transient @param) mapping: Mapping = null)
   extends RDD[T](sc, Nil) {
 
   private val init = { ObjectUtils.loadClass("org.elasticsearch.spark.rdd.CompatUtils", classOf[ObjectUtils].getClassLoader) }

From af1e47b858b82a6409e489d8cc5855856f915582 Mon Sep 17 00:00:00 2001
From: Tuan Pham
Date: Mon, 14 Jul 2025 21:39:48 +1000
Subject: [PATCH 5/5] Revert Schema case class changes

---
 .../main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
index c00609732..fc1f5e366 100644
--- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
+++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
@@ -64,13 +64,13 @@ import org.elasticsearch.spark.sql.Utils.ROW_INFO_ORDER_PROPERTY
 import scala.annotation.tailrec
 
 private[sql] object SchemaUtils {
-  case class Schema(struct: StructType, mapping: Mapping)
+  case class Schema(mapping: Mapping, struct: StructType)
 
   def discoverMapping(cfg: Settings, userSchema: Option[StructType] = None): Schema = {
     val includeFields = structToColumnsNames(userSchema)
     val (mapping, geoInfo) = discoverMappingAndGeoFields(cfg, includeFields)
     val struct = convertToStruct(mapping, geoInfo, cfg)
-    Schema(struct, mapping)
+    Schema(mapping, struct)
   }
 
   def structToColumnsNames(struct: Option[StructType]): Seq[String] = {
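For reference, the flattening performed by structToColumnsNames (introduced in PATCH 3/5 and kept through the revert above) turns a user-supplied StructType into dotted column names, including the parent name of each nested struct, so the names can be handed to mapping discovery as the includeFields whitelist. The sketch below is an illustrative standalone rewrite of that idea, not the method from the patch; columnNames and unwrap are hypothetical names:

    import org.apache.spark.sql.types._

    // Unwrap arrays down to their element type, as the patch's getInnerMostType does.
    @scala.annotation.tailrec
    def unwrap(dt: DataType): DataType = dt match {
      case ArrayType(elem, _) => unwrap(elem)
      case other              => other
    }

    // Flatten a StructType into dotted field names; nested structs contribute both
    // their own name and the names of their children.
    def columnNames(struct: StructType, prefix: String = ""): Seq[String] =
      struct.fields.toSeq.flatMap { f =>
        val name = prefix + f.name
        unwrap(f.dataType) match {
          case s: StructType => name +: columnNames(s, name + ".")
          case _             => Seq(name)
        }
      }

    val userSchema = StructType(Seq(
      StructField("title", StringType),
      StructField("author", StructType(Seq(
        StructField("name", StringType),
        StructField("age", IntegerType))))))

    columnNames(userSchema)
    // => Seq("title", "author", "author.name", "author.age")

The intent of the series, as I read it, is that a read such as spark.read.format("org.elasticsearch.spark.sql").schema(userSchema).load("my-index") (index name is a placeholder) only fetches and validates mappings for the named columns, and the resolved Mapping carried on Schema is passed through ScalaEsRowRDD/AbstractEsRDD to RestService.findPartitions so the mapping is not requested a second time.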