Preface
The previous post, "hive UDAF开发入门和运行过程详解(转)" (an introduction to Hive UDAF development and its execution process, reposted), walked through how a UDAF is developed, and suggested that to really understand how a UDAF executes you should read the source code of the built-in average function.
After reading that source I still didn't fully understand what was going on, so I decided to write a new function of my own and try to understand it through practice.
What the Function Does
The function itself is admittedly a bit contrived. We all know Hive ships with several everyday aggregate functions: sum, max, min, avg.
The goal here is one function that does two jobs at once: for each key, return both the maximum and the minimum of the associated set of values.
The tricky part is that the function receives only a single input at a time, yet has to produce two new values at once, and those two values are logically tied to each other.
I'm not great at explaining this in words; after mulling it over for a while, it's easier to just show the code.
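To pin down the target semantics before touching any Hive machinery, here is a minimal plain-Java sketch (the class and method names are made up for illustration): a single pass keeps both extremes in one buffer and packs them into a single return value.

public class MaxMinSketch {
  // One pass, one buffer, two logically linked outputs in one string
  static String maxMin(double[] values) {
    // Infinities as sentinels so any real value replaces them
    double max = Double.NEGATIVE_INFINITY;
    double min = Double.POSITIVE_INFINITY;
    for (double v : values) {
      if (v > max) { max = v; }
      if (v < min) { min = v; }
    }
    return max + "\t" + min;
  }

  public static void main(String[] args) {
    // Prints "7.0" and "1.0" separated by a tab
    System.out.println(maxMin(new double[] { 3.0, 7.0, 1.0 }));
  }
}

Packing the pair into one value is how the one-input, two-output requirement is resolved: inside the UDAF below, the pair travels between execution stages as a two-field struct and is only joined into a tab-separated string at the very end. Once the class is compiled into a jar and added to the session, it would be registered with CREATE TEMPORARY FUNCTION maxmin AS 'org.juefan.udaf.GenericUDAFMaxMin' and called as maxmin(col) in a GROUP BY query.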
Source Code
package org.juefan.udaf;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;

/**
 * GenericUDAFMaxMin.
 */
@Description(name = "maxmin",
    value = "_FUNC_(x) - Returns the max and min value of a set of numbers")
public class GenericUDAFMaxMin extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFMaxMin.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }
    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
    case STRING:
    case TIMESTAMP:
      return new GenericUDAFMaxMinEvaluator();
    case BOOLEAN:
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
  }

  /**
   * GenericUDAFMaxMinEvaluator.
   */
  public static class GenericUDAFMaxMinEvaluator extends GenericUDAFEvaluator {

    // For PARTIAL1 and COMPLETE: inspector for the raw input value
    PrimitiveObjectInspector inputOI;

    // For PARTIAL2 and FINAL: inspector for the partial-result struct
    StructObjectInspector soi;
    // Struct fields that carry the running max and min between stages
    StructField maxField;
    StructField minField;
    // Field inspectors; get() returns the value directly as a double
    DoubleObjectInspector maxFieldOI;
    DoubleObjectInspector minFieldOI;

    // For PARTIAL1 and PARTIAL2: the intermediate result
    Object[] partialResult;

    // For FINAL and COMPLETE: the final output
    Text result;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
      assert (parameters.length == 1);
      super.init(m, parameters);

      // Set up the input side
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        inputOI = (PrimitiveObjectInspector) parameters[0];
      } else {
        // The input is an intermediate result, so treat it as a struct
        soi = (StructObjectInspector) parameters[0];
        // Look up the named fields in the struct
        maxField = soi.getStructFieldRef("max");
        minField = soi.getStructFieldRef("min");
        // Inspectors that extract the fields' actual values
        maxFieldOI = (DoubleObjectInspector) maxField.getFieldObjectInspector();
        minFieldOI = (DoubleObjectInspector) minField.getFieldObjectInspector();
      }

      // Set up the output side
      if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
        // The output is a struct holding the max and the min
        // Field types of the struct
        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        // Field names of the struct
        ArrayList<String> fname = new ArrayList<String>();
        fname.add("max");
        fname.add("min");
        partialResult = new Object[2];
        partialResult[0] = new DoubleWritable(0);
        partialResult[1] = new DoubleWritable(0);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
      } else {
        // Last stage: declare the final output type
        result = new Text("");
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
      }
    }

    /** The buffer name is kept from the average UDAF this was adapted from. */
    static class AverageAgg implements AggregationBuffer {
      double max;
      double min;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      AverageAgg result = new AverageAgg();
      reset(result);
      return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      AverageAgg myagg = (AverageAgg) agg;
      // Use the infinities as sentinels so any real value replaces them.
      // (Double.MIN_VALUE would be wrong for max: it is the smallest
      // positive double, not the most negative one, so an all-negative
      // input set would never update it.)
      myagg.max = Double.NEGATIVE_INFINITY;
      myagg.min = Double.POSITIVE_INFINITY;
    }

    boolean warned = false;

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      assert (parameters.length == 1);
      Object p = parameters[0];
      if (p != null) {
        AverageAgg myagg = (AverageAgg) agg;
        try {
          // Read the input value and update the running extremes
          double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
          if (myagg.max < v) {
            myagg.max = v;
          }
          if (myagg.min > v) {
            myagg.min = v;
          }
        } catch (NumberFormatException e) {
          if (!warned) {
            warned = true;
            LOG.warn(getClass().getSimpleName() + " "
                + StringUtils.stringifyException(e));
            LOG.warn(getClass().getSimpleName()
                + " ignoring similar exceptions.");
          }
        }
      }
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      // Package the intermediate result for the next stage
      AverageAgg myagg = (AverageAgg) agg;
      ((DoubleWritable) partialResult[0]).set(myagg.max);
      ((DoubleWritable) partialResult[1]).set(myagg.min);
      return partialResult;
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
      if (partial != null) {
        // partial is the output of an earlier terminatePartial call
        AverageAgg myagg = (AverageAgg) agg;
        Object partialmax = soi.getStructFieldData(partial, maxField);
        Object partialmin = soi.getStructFieldData(partial, minField);
        if (myagg.max < maxFieldOI.get(partialmax)) {
          myagg.max = maxFieldOI.get(partialmax);
        }
        if (myagg.min > minFieldOI.get(partialmin)) {
          myagg.min = minFieldOI.get(partialmin);
        }
      }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      // Join the final max and min into one output string
      AverageAgg myagg = (AverageAgg) agg;
      if (myagg.max == Double.NEGATIVE_INFINITY
          && myagg.min == Double.POSITIVE_INFINITY) {
        // The sentinels were never touched: no non-null value was aggregated
        return null;
      } else {
        result.set(myagg.max + "\t" + myagg.min);
        return result;
      }
    }
  }
}
Even after writing this I still don't feel I've fully grasped the whole process, so take the comments above with a grain of salt; I can't guarantee they are all correct!
This afternoon I'll add some log output to trace the execution path. The logic of the code is sound, though; I have run it myself!
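For tracing, one option is to drive the evaluator by hand outside Hive. The driver below is a hypothetical sketch (the class name MaxMinDriver and the sample values are mine, and it assumes the Hive jars are on the classpath): it pushes a few values through a PARTIAL1 stage and feeds the partial result to a FINAL stage, which roughly mirrors what Hive's map and reduce sides do.

package org.juefan.udaf;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class MaxMinDriver {
  public static void main(String[] args) throws Exception {
    ObjectInspector inputOI =
        PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;

    // "Map side": PARTIAL1 consumes raw values, emits a {max, min} struct
    GenericUDAFEvaluator partialEval =
        new GenericUDAFMaxMin.GenericUDAFMaxMinEvaluator();
    ObjectInspector partialOI =
        partialEval.init(Mode.PARTIAL1, new ObjectInspector[] { inputOI });
    AggregationBuffer buf = partialEval.getNewAggregationBuffer();
    for (double v : new double[] { 3.0, 7.0, 1.0 }) {
      partialEval.iterate(buf, new Object[] { v });
    }
    Object partial = partialEval.terminatePartial(buf);

    // "Reduce side": FINAL consumes the struct, emits the final string
    GenericUDAFEvaluator finalEval =
        new GenericUDAFMaxMin.GenericUDAFMaxMinEvaluator();
    finalEval.init(Mode.FINAL, new ObjectInspector[] { partialOI });
    AggregationBuffer finalBuf = finalEval.getNewAggregationBuffer();
    finalEval.merge(finalBuf, partial);
    // Prints "7.0" and "1.0" separated by a tab
    System.out.println(finalEval.terminate(finalBuf));
  }
}

Sprinkling System.out.println or LOG.warn calls into init, iterate, terminatePartial, merge, and terminate while running such a driver makes it easy to see which modes fire and in what order.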