Hive has both UDFs and UDAFs. Spark added UDF support early on, but UDAFs (User Defined Aggregate Functions), which let users define their own aggregation logic, were only introduced in Spark 1.5.x.
A UDAF aggregates over multiple input rows to produce a single result.
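For contrast, a scalar UDF maps each input row to exactly one output value. A minimal sketch, assuming a SQLContext named sqlContext and the numtest table that is registered later in this post:

// A scalar UDF: one input row in, one value out. No aggregation buffer needed.
sqlContext.udf.register("twice", (x: Double) => x * 2)
sqlContext.sql("select twice(num) from numtest").show()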
Writing a UDAF that computes an average
1. A custom UDAF needs to extend org.apache.spark.sql.expressions.UserDefinedAggregateFunction and implement the abstract class's 8 methods: inputSchema, bufferSchema, dataType, deterministic, initialize, update, merge, and evaluate.
package com.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

class NumsAvg extends UserDefinedAggregateFunction {

  // Schema of the input column(s) this UDAF consumes.
  def inputSchema: StructType =
    StructType(StructField("nums", DoubleType) :: Nil)

  // Schema of the intermediate aggregation buffer: a running count and a running sum.
  def bufferSchema: StructType = StructType(
    StructField("cnt", LongType) ::
    StructField("sum", DoubleType) :: Nil)

  // Type of the final result.
  def dataType: DataType = DoubleType

  // The same input always yields the same output.
  def deterministic: Boolean = true

  // Initialize the buffer: zero rows seen, zero sum.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the buffer.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
  }

  // Merge two partial buffers (e.g. from different partitions).
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Compute the final average, rounded to 5 decimal places.
  def evaluate(buffer: Row): Any = {
    val avg = buffer.getDouble(1) / buffer.getLong(0)
    f"$avg%1.5f".toDouble
  }
}
2. Use both the built-in avg() function and the custom numsAvg:
package com.spark.sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object NumsAvgTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UDAF").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Build a one-column DataFrame of doubles and register it as a temp table.
    val nums = List(4.5, 2.4, 1.8)
    val numsRDD = sc.parallelize(nums, 1)
    val numsRowRDD = numsRDD.map { x => Row(x) }
    val structType = StructType(Array(StructField("num", DoubleType, true)))
    val numsDF = sqlContext.createDataFrame(numsRowRDD, structType)
    numsDF.registerTempTable("numtest")

    // Built-in average.
    sqlContext.sql("select avg(num) from numtest").collect().foreach { x => println(x) }

    // Register the custom UDAF under the name "numsAvg" and call it from SQL.
    sqlContext.udf.register("numsAvg", new NumsAvg)
    sqlContext.sql("select numsAvg(num) from numtest").collect().foreach { x => println(x) }
  }
}
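Once registered, the UDAF can also be invoked through the DataFrame API rather than a SQL string. A minimal sketch, assuming the numsDF DataFrame and the "numsAvg" registration from the code above:

import org.apache.spark.sql.functions.{callUDF, col}

// Equivalent to "select numsAvg(num) from numtest", via the DataFrame API.
numsDF.agg(callUDF("numsAvg", col("num"))).collect().foreach(println)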
Output of the built-in avg():
[2.9000000000000004]
Output of the custom numsAvg:
[2.9]
A custom function lets you control the precision of the result yourself.
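The rounding happens in evaluate: the f string interpolator formats the raw double to 5 decimal places before parsing it back. The same trick applied directly to the raw average of the test data:

// Plain floating-point average of 4.5, 2.4, 1.8.
val raw = (4.5 + 2.4 + 1.8) / 3      // 2.9000000000000004
// Format to 5 decimal places, then parse back to Double.
val rounded = f"$raw%1.5f".toDouble  // 2.9
println(rounded)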
3. To implement a UDAF, extend the UserDefinedAggregateFunction abstract class; to use it, just register it first.