Spark MLlib特征处理 之 StringIndexer、IndexToString使用说明以及源码剖析 (2)

运行后发现异常:

18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist. at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266) at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266) at scala.collection.MapLike$class.getOrElse(MapLike.scala:128) at scala.collection.AbstractMap.getOrElse(Map.scala:59) at org.apache.spark.sql.types.StructType.apply(StructType.scala:265) at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338) at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74) at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352) at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37) at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)

这是为什么呢?跟随源码来看吧!

源码剖析

首先我们创建一个DataFrame,获得原始数据:

val df = spark.createDataFrame(Seq( (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c") )).toDF("id", "category")

然后创建对应的StringIndexer:

val indexer = new StringIndexer() .setInputCol("category") .setOutputCol("categoryIndex") .setHandleInvalid("skip") .fit(df)

这里面的fit就是在训练转换器了,进入fit():

override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) // 这里针对需要转换的列先强制转换成字符串,然后遍历统计每个字符串出现的次数 val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) .countByValue() // counts是一个map,里面的内容为{a->3, b->1, c->2} val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray // 按照个数大小排序,返回数组,[a, c, b] // 把这个label保存起来,并返回对应的model(mllib里边的模型都是这个套路,跟sklearn学的) copyValues(new StringIndexerModel(uid, labels).setParent(this)) }

这样就得到了一个列表,列表里面的内容是[a, c, b],然后执行transform来进行转换:

val indexed = indexer.transform(df)

这个transform可想而知就是用这个数组对每一行的该列进行转换,但是它其实还做了其他的事情:

override def transform(dataset: Dataset[_]): DataFrame = { ... // -------- // 通过label生成一个Metadata,这个很关键!!! // metadata其实是一个map,内容为: // {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}} // -------- val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // 如果是skip则过滤一些数据 ... // 下面是针对不同的情况处理转换的列,逻辑很简单 val indexer = udf { label: String => ... if (labelToIndex.contains(label)) { labelToIndex(label) //如果正常,就进行转换 } else if (keepInvalid) { labels.length // 如果是keep,就返回索引的最大值(即数组的长度) } else { ... // 如果是error,就抛出异常 } } // 保留之前所有的列,新增一个字段,并设置字段的StructField中的Metadata!!!! // 并设置字段的StructField中的Metadata!!!! // 并设置字段的StructField中的Metadata!!!! // 并设置字段的StructField中的Metadata!!!! filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) }

看到了吗!关键的地方在这里,给新增加的字段的类型StructField设置了一个Metadata。这个Metadata正常都是空的{},但是这里设置了metadata之后,里面包含了label数组的信息。

接下来看看IndexToString是怎么用的,由于IndexToString是一个Transformer,因此只有一个trasform方法:

override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata // 关键是这里: // 如果IndexToString设置了labels数组,就直接返回; // 否则,就读取了传入的DataFrame的StructField中的Metadata val values = if (!isDefined(labels) || $(labels).isEmpty) { Attribute.fromStructField(inputColSchema) .asInstanceOf[NominalAttribute].values.get } else { $(labels) } // 基于这个values把index转成对应的值 val indexer = udf { index: Double => val idx = index.toInt if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") } } val outputColName = $(outputCol) dataset.select(col("*"), indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) }

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wpysgj.html