spark编写UDF和UDAF


UDF:

一、编写udf类,在其中定义udf函数

package spark._sql.UDF

import org.apache.spark.sql.functions._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-9:41
  **/
object udfs {
  def len(str: String): Int = str.length

  def ageThan(age: Int, small: Int): Boolean = age > small

  val ageThaner = udf((age: Int, bigger: Int) => age < bigger)
} 

二、在主方法中进行调用  

package spark._sql

import org.apache.log4j.Logger
import org.apache.spark.sql
import spark._sql.UDF.udfs._
import org.apache.spark.sql.functions._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-9:42
  **/
object UDFMain {
  val log = Logger.getLogger("UDFMain")

  def main(args: Array[String]): Unit = {
    val ssc = new sql.SparkSession.Builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    ssc.sparkContext.setLogLevel("warn")

    val df = ssc.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2))).toDF("age", "class_id")
    df.createOrReplaceTempView("table")

    ssc.udf.register("len", len _)
    ssc.sql("select age,len(age) as len from table").show(20, false)
    println("=====================================")
    ssc.udf.register("ageThan", ageThan _)
    ssc.sql("select age from table where ageThan(age,15)").show()
    println("=====================================")
    import ssc.implicits._
    val r = ssc.sql("select * from table")
    r.filter(ageThaner($"age", lit(20))).show()
    println("=====================================")

    ssc.stop()
  }
}

  运行结果:

  

  可以看到,以上代码中一共定义了三个不同的udf函数,分别对三个函数进行说明:

  • len(str: String):该函数使用用来获取传入字段的长度,str 即为所需要传入的字段
    •   在使用的时候,需要现将其进行注册并赋予其函数名:ssc.udf.register("len", len _),调用的时候直接在sql语句中通过函数名来进行调用
  • ageThan(age: Int, small: Int):该函数式用来比较传入的age与已有的small大小,返回一个boolean值,该函数需要是用在where条件语句中用来进行过滤使用
    •     在使用的时候,需要现将其进行注册并赋予其函数名:ssc.udf.register("ageThan", ageThan _),调用的时候直接在sql语句中通过函数名来进行调用
  • ageThaner:该函数跟上面两个不同,所谓的不同指的是:
    •   定义方式不同:通过使用org.apache.spark.sql.functions._ 中的udf函数在定义的时候就将其注册好
    •        使用场景不同:使用在dataframe中,用来进行select,filter操作中
    •        对于该函数的第二列来说,如果是常量的话,需要使用org.apache.spark.sql.function._ 中的lit进行包装,不能将常量直接传入,否则,程序不认识该常量会报错,如果是列名的话,则没问题,使用($"colName")方式即可。

UDAF:

  UDAF相对于udf来说稍微麻烦一下,且需要完全理解当中每个函数的含义才可以轻而易举的写出符合自己预期的UDAF函数,      

     UDAF需要继承 UserDefinedAggregateFunction ,并且复写当中的方法

方法含义说明:

def inputSchema: StructType =

    StructType(Array(StructField("value", IntegerType)))

  inputSchema用来定义,输入的字段的类型,字段名可以随便定义,这里定义为value,也可以是其他的,不重要,关键是字段类型一定要与所要传入计算的字段进行对应,且必须使用org.apche.spark.sql.type. _ 中的类型

def bufferSchema: StructType = StructType(Array(

    StructField("count", IntegerType), StructField("ages", DoubleType)))

  bufferSchema用来定义生成中间数据的结果类型,例如在求和的时候,要求a+b+c,相加顺序为a+b=ab,ab+c=abc ,ab即为中间结果。

def dataType: DataType = DoubleType

  dataType为函数返回值的类型,例子中,该UDAF最终返回的结果为double类型,这里的类型不能写成double,要写成org.apache.spark.sql.type._支持的类型DoubleType.

 def deterministic: Boolean = true

  daterministic 为代表结果是否为确定性的,也就是说,相同的输入是否有相同的输出。

def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.0
  }

  initalize 初始化中间结果,即count和ages的初始值。

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1 //更新计数器
    buffer(1) = buffer.getDouble(1) + input.getInt(0) //更新值
  }

  update用来更新中间结果,input为dataframe中的一行,将要合并到buffer中的数据,buffer则为已经进行合并后的中间结果。

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  merge 合并所有分片的结果,buffer2是一个分片的中间结果,buffer1是整个合并过程中的结果。

def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getInt(0)
  }

  evaluate 函数式真正进行计算的函数,计算返回函数的结果,buffer是merge合并后的结果

 

案例需求:求分组中age的平均数

  先上代码:

一、定义UDAF函数

package spark._sql.UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-14:47
  **/
class udafs() extends UserDefinedAggregateFunction {

  def inputSchema: StructType =

    StructType(Array(StructField("value", IntegerType)))

  def bufferSchema: StructType = StructType(Array(

    StructField("count", IntegerType), StructField("ages", DoubleType)))

  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.0
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1 //更新计数器
    buffer(1) = buffer.getDouble(1) + input.getInt(0) //更新值
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getInt(0)
  }
}

二、主函数引用:

package spark._sql.UDF

import org.apache.spark.sql
import org.apache.spark.sql.functions._
import spark._sql.UDAF.udafs

/**
  * AUTHOR Guozy
  * DATE   2019/7/19-16:04
  **/
object UDAFMain {
  def main(args: Array[String]): Unit = {
    val ssc = new sql.SparkSession.Builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    ssc.sparkContext.setLogLevel("warn")

    val ageDF = ssc.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2))).toDF("age", "class_id")
    ssc.udf.register("avgage", new udafs)
    ageDF.createOrReplaceTempView("table")
    ssc.sql("select avgage(age) from table group by class_id").show()

    ssc.stop()
  }
}

 运行结果:

  

智能推荐

注意!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系我们删除。



 
© 2014-2019 ITdaan.com 粤ICP备14056181号  

赞助商广告