Writing a Custom UDAF in Spark


Hive has both UDFs and UDAFs. Spark supported UDFs early on, but UDAFs (User Defined Aggregate Functions) were only introduced in Spark 1.5.x.

A UDF operates on a single row at a time, while a UDAF performs an aggregate computation over multiple input rows.
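The distinction can be sketched in a few lines of plain Python (illustrative only, not Spark API): a UDF transforms each row independently, while a UDAF folds many rows into a single value.

```python
rows = [4.5, 2.4, 1.8]

# A UDF-style operation: each row is mapped independently.
doubled = [x * 2 for x in rows]

# A UDAF-style operation: all rows are aggregated into one value.
average = sum(rows) / len(rows)

print(doubled)  # [9.0, 4.8, 3.6]
print(average)
```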

Writing a UDAF that computes an average


1. To define a UDAF, extend the abstract class org.apache.spark.sql.expressions.UserDefinedAggregateFunction and implement its 8 required methods.

package com.spark.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * UDAF that computes the average of a Double column.
 */
class NumsAvg extends UserDefinedAggregateFunction {

  // Schema of the input: a single Double column.
  def inputSchema: StructType =
    StructType(StructField("nums", DoubleType) :: Nil)

  // Schema of the aggregation buffer: a row count and a running sum.
  def bufferSchema: StructType = StructType(
    StructField("cnt", LongType) ::
    StructField("sum", DoubleType) :: Nil)

  // Type of the final result.
  def dataType: DataType = DoubleType

  // The same input always produces the same output.
  def deterministic: Boolean = true

  // Initialize the buffer for each partition.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the partition's buffer.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) + input.getAs[Double](0)
  }

  // Combine the buffers of two partitions.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
  }

  // Compute the final result, rounded to 5 decimal places.
  def evaluate(buffer: Row): Any = {
    val t = buffer.getDouble(1) / buffer.getLong(0)
    f"$t%1.5f".toDouble
  }
}
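To make the lifecycle of the four buffer methods concrete, here is a plain-Python simulation (the names mirror the Scala methods, but none of this is Spark API): each partition gets its own buffer, update() folds rows into it, merge() combines partition buffers, and evaluate() produces the final value.

```python
# Plain-Python sketch of the UDAF buffer lifecycle (illustrative only).

def initialize():
    return [0, 0.0]              # (cnt, sum) -- mirrors bufferSchema

def update(buffer, value):
    buffer[0] += 1               # row count
    buffer[1] += value           # running sum

def merge(buffer1, buffer2):
    buffer1[0] += buffer2[0]
    buffer1[1] += buffer2[1]

def evaluate(buffer):
    return round(buffer[1] / buffer[0], 5)

# Rows split across two partitions, as Spark might distribute them.
partitions = [[4.5, 2.4], [1.8]]

# Per-partition aggregation: one buffer per partition.
buffers = []
for rows in partitions:
    buf = initialize()
    for v in rows:
        update(buf, v)
    buffers.append(buf)

# Cross-partition merge, then the final evaluation.
final = buffers[0]
for b in buffers[1:]:
    merge(final, b)
print(evaluate(final))  # 2.9
```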

2. Testing the custom UDAF

Compare the built-in avg() function with the custom numsAvg.

package com.spark.sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object NumsAvgTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UDAF").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Build a single-column DataFrame of Doubles.
    val nums = List(4.5, 2.4, 1.8)
    val numsRDD = sc.parallelize(nums, 1)
    val numsRowRDD = numsRDD.map { x => Row(x) }
    val structType = StructType(Array(StructField("num", DoubleType, true)))
    val numsDF = sqlContext.createDataFrame(numsRowRDD, structType)
    numsDF.registerTempTable("numtest")

    // Built-in avg().
    sqlContext.sql("select avg(num) from numtest").collect().foreach(println)

    // Custom UDAF: register it, then call it like any other SQL function.
    sqlContext.udf.register("numsAvg", new NumsAvg)
    sqlContext.sql("select numsAvg(num) from numtest").collect().foreach(println)
  }
}

Comparing the results

Built-in avg():

[2.9000000000000004]

Custom numsAvg:

[2.9]

With a custom function, you can control the precision of the result yourself.
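The difference comes from binary floating-point arithmetic: summing and dividing doubles leaves a tiny representation error, and the f"$t%1.5f".toDouble step in evaluate() formats the value to five decimal places and parses it back, trimming that error. A minimal Python illustration of the same trick (values taken from the test data above):

```python
# Raw division carries binary floating-point error.
raw = (4.5 + 2.4 + 1.8) / 3
print(raw)                     # e.g. 2.9000000000000004

# Format to 5 decimal places, then parse back -- same idea as
# f"$t%1.5f".toDouble in the Scala UDAF.
trimmed = float(f"{raw:.5f}")
print(trimmed)                 # 2.9
```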


3. In short: a UDAF is written by extending UserDefinedAggregateFunction and implementing its methods; to use it, simply register it first.

