
SparkSQL - User-Defined Functions

Date: 2023-04-18
1. Preparation

Spark version: 3.0.0

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.12</artifactId>
    <version>3.0.0</version>
</dependency>

The input file contains data like the following:
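The screenshot of the sample data did not survive extraction. The queries below assume a JSON Lines file (one object per line, which is what spark.read.json expects) with name and age fields; the concrete names and ages here are illustrative, not the original data:

```json
{"name": "zhangsan", "age": 20}
{"name": "lisi", "age": 18}
{"name": "wangwu", "age": 22}
```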

2. Basic Usage

2.1 Registering a UDF Directly

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkSQL_UDF {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    val inputDF: DataFrame = spark.read.json("datas/student.json")
    inputDF.createOrReplaceTempView("student")

    // Register the UDF
    spark.udf.register("prefix", (name: String) => {
      "hnu:" + name
    })

    // Every name in the result now carries the "hnu:" prefix
    spark.sql("select prefix(name) from student").show()

    spark.close()
  }
}
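Outside Spark, the registered function is just an ordinary Scala closure that the engine applies once per row. A minimal sketch of the same per-row behavior on a plain list (the sample names are made up):

```scala
object PrefixSketch {
  // The same closure that was registered as the "prefix" UDF
  val prefix: String => String = name => "hnu:" + name

  def main(args: Array[String]): Unit = {
    // Simulate applying the UDF to the name column, row by row
    val names = List("zhangsan", "lisi")
    println(names.map(prefix).mkString(", "))
  }
}
```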

So when do custom functions actually come up most often? Usually when aggregation is involved. SQL already provides built-ins such as avg() and count(), but as an exercise we can implement the averaging aggregation ourselves.
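Before wiring this into Spark, the core idea can be sketched in plain Scala: carry a (sum, count) pair through a fold and divide once at the end. This is exactly the state the UDAF buffer below holds; the ages here are illustrative:

```scala
object AvgSketch {
  // Fold ages into a running (sum, count) pair, then divide once at the end
  def average(ages: Seq[Double]): Double = {
    val (sum, count) = ages.foldLeft((0.0, 0.0)) {
      case ((s, c), age) => (s + age, c + 1)
    }
    sum / count
  }

  def main(args: Array[String]): Unit = {
    println(average(Seq(18.0, 20.0, 22.0))) // average of the sample ages
  }
}
```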

2.2 Extending UserDefinedAggregateFunction (deprecated since 3.x)

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

object SparkSQL_UDAF_1 {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    val inputDF: DataFrame = spark.read.json("datas/student.json")
    inputDF.createOrReplaceTempView("student")

    spark.udf.register("myAvg", new MyAvg)
    spark.sql("select myAvg(age) from student").show()

    spark.close()
  }
}

class MyAvg extends UserDefinedAggregateFunction {
  // Schema of the input values
  override def inputSchema: StructType = {
    StructType(Array(
      StructField("age", DoubleType)
    ))
  }

  // Schema of the aggregation buffer: a running sum and count, used to compute the average
  override def bufferSchema: StructType = {
    StructType(Array(
      StructField("sum", DoubleType),
      StructField("count", DoubleType)
    ))
  }

  // Type of the output: the average
  override def dataType: DataType = DoubleType

  // Deterministic: the same input always produces the same result
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // update(i, value) sets the buffer field at index i to value
    buffer.update(0, 0.0)
    buffer.update(1, 0.0)
  }

  // Fold one input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Add to the sum, increment the count
    buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
    buffer.update(1, buffer.getDouble(1) + 1)
  }

  // Merge two buffers (e.g. partial results from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1))
  }

  // Compute the final average
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) / buffer.getDouble(1)
  }
}

2.3 Extending Aggregator (recommended since 3.x)

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

object SparkSQL_UDAF_2 {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("udf")
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    val inputDF: DataFrame = spark.read.json("datas/student.json")
    inputDF.createOrReplaceTempView("student")

    spark.udf.register("myAvg", functions.udaf(new MyAvg2))
    spark.sql("select myAvg(age) from student").show()

    spark.close()
  }
}

case class Buffer(var sum: Double, var count: Double)

// The three type parameters are the input type, the buffer type, and the output type
class MyAvg2 extends Aggregator[Double, Buffer, Double] {
  // The "zero" value used to initialize the buffer
  override def zero: Buffer = Buffer(0.0, 0.0)

  // Fold one input value into the buffer
  override def reduce(b: Buffer, a: Double): Buffer = {
    b.sum += a
    b.count += 1
    b
  }

  // Merge two buffers (e.g. partial results from different partitions)
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: Buffer): Double = reduction.sum / reduction.count

  // Encoders: use Encoders.product for a custom case class,
  // and the Encoders.scalaX variants for built-in Scala types
  override def bufferEncoder: Encoder[Buffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
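Why does the Aggregator need both reduce and merge? Spark first folds each partition's values into its own local buffer with reduce, then combines those partial buffers with merge before calling finish. A standalone sketch of that two-phase plan, without Spark (the partition contents are illustrative):

```scala
object TwoPhaseSketch {
  final case class Buf(var sum: Double, var count: Double)

  // Phase 1: fold one partition's values into a local buffer (the "reduce" role)
  def reducePartition(values: Seq[Double]): Buf =
    values.foldLeft(Buf(0.0, 0.0)) { (b, v) => b.sum += v; b.count += 1; b }

  // Phase 2: merge the per-partition buffers, then finish with one division
  def average(partitions: Seq[Seq[Double]]): Double = {
    val merged = partitions.map(reducePartition).reduce { (b1, b2) =>
      b1.sum += b2.sum; b1.count += b2.count; b1
    }
    merged.sum / merged.count
  }

  def main(args: Array[String]): Unit = {
    // Two simulated partitions of ages
    println(average(Seq(Seq(18.0, 20.0), Seq(22.0))))
  }
}
```

Because only (sum, count) pairs cross partition boundaries, the average is computed correctly no matter how the rows are split, which is the whole point of keeping a buffer instead of a finished value.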
