Hive provides both UDFs and UDAFs. Spark has supported UDFs since its early releases, but UDAF (User Defined Aggregate Function) support, i.e. user-defined aggregate functions, was only introduced in Spark 1.5.x.
A UDAF, by contrast, aggregates over multiple input rows to produce a single result.
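For contrast, a plain UDF maps a single row's input to a single output value. A minimal sketch of registering an ordinary UDF through the same sqlContext.udf API used later in this article (the name strLen and the column it would be applied to are purely illustrative):

// A plain UDF: one value in, one value out, evaluated independently per row
sqlContext.udf.register("strLen", (s: String) => s.length)
// e.g. select strLen(name) from people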
Writing a UDAF that computes an average
1. To define a custom UDAF, extend org.apache.spark.sql.expressions.UserDefinedAggregateFunction and implement its 8 abstract methods:
package com.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * @author Administrator
 * A UDAF that computes the average of a Double column.
 */
class NumsAvg extends UserDefinedAggregateFunction {

  // Schema of the input column the function accepts
  def inputSchema: StructType = StructType(StructField("nums", DoubleType) :: Nil)

  // Schema of the intermediate aggregation buffer: a row count and a running sum
  def bufferSchema: StructType = StructType(
    StructField("cnt", LongType) ::
    StructField("sum", DoubleType) :: Nil)

  // Type of the final result
  def dataType: DataType = DoubleType

  // The function always returns the same result for the same input
  def deterministic: Boolean = true

  // Initialize the buffer for a new group
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the buffer
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
  }

  // Combine two partial buffers (e.g. from different partitions)
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Produce the final result: sum / count, rounded to 5 decimal places
  def evaluate(buffer: Row): Any = {
    val t = buffer.getDouble(1) / buffer.getLong(0)
    f"$t%1.5f".toDouble
  }
}
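Note the buffer design: because averages are not associative, the buffer keeps a count and a running sum rather than a partial average. update folds individual rows into the buffer within a partition, merge combines the partial buffers of different partitions, and only evaluate performs the final division.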
2. Use the built-in avg() function and the custom numsAvg side by side:
package com.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
 * @author Administrator
 */
object NumsAvgTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UDAF").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Build a one-column DataFrame of Doubles and register it as a temp table
    val nums = List(4.5, 2.4, 1.8)
    val numsRDD = sc.parallelize(nums, 1)
    val numsRowRDD = numsRDD.map { x => Row(x) }
    val structType = StructType(Array(StructField("num", DoubleType, true)))
    val numsDF = sqlContext.createDataFrame(numsRowRDD, structType)
    numsDF.registerTempTable("numtest")

    // Built-in avg()
    sqlContext.sql("select avg(num) from numtest").collect().foreach { x => println(x) }

    // Register and call the custom UDAF
    sqlContext.udf.register("numsAvg", new NumsAvg)
    sqlContext.sql("select numsAvg(num) from numtest").collect().foreach { x => println(x) }
  }
}
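Besides registering it by name for SQL, the UDAF instance should also be usable directly through the DataFrame API; a minimal sketch, assuming the Column-based apply that UserDefinedAggregateFunction exposes in Spark 1.5.x:

import org.apache.spark.sql.functions.col

// Apply the UDAF to a column without SQL registration (assumes 1.5.x apply(Column*))
val myAvg = new NumsAvg
numsDF.agg(myAvg(col("num"))).show()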
Output of the built-in avg():
[2.9000000000000004]
Output of the custom numsAvg:
[2.9]
With a custom function you control the precision of the result yourself.
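The precision control here is nothing more than the f-interpolator in evaluate: the raw Double is formatted to five decimal places and parsed back. A quick illustration of the difference, using the same numbers as the test above:

val raw = (4.5 + 2.4 + 1.8) / 3      // 2.9000000000000004 under binary floating point
val rounded = f"$raw%1.5f".toDouble  // formats to "2.90000", parses back to 2.9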
3. In short, writing a UDAF comes down to extending the UserDefinedAggregateFunction abstract class; to use it, simply register it with sqlContext.udf.register.