The following is based on Hive 3.1.2.
There are two ways to implement a custom UDF in Hive: by extending org.apache.hadoop.hive.ql.exec.UDF, or by extending org.apache.hadoop.hive.ql.udf.generic.GenericUDF.
Either way, the workflow is the same: add the Maven dependency, write the UDF class, package it into a jar, and register the jar and function in Hive.
First, add the pom dependency:
<dependency>
    <groupId>org.apache.hive</groupId>
    <artifactId>hive-exec</artifactId>
    <version>3.1.2</version>
</dependency>
When extending the UDF class, all you need to do is implement an evaluate method. Before writing my own, I looked up the source of the built-in replace function for reference; it is pasted below:
package org.apache.hadoop.hive.ql.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.Text;

/**
 * UDFReplace replaces all substrings that are matched with a replacement substring.
 */
@Description(name = "replace",
    value = "_FUNC_(str, search, rep) - replace all substrings of 'str' that "
        + "match 'search' with 'rep'",
    extended = "Example:\n"
        + " > SELECT _FUNC_('Hack and Hue', 'H', 'BL') FROM src LIMIT 1;\n"
        + " 'BLack and BLue'")
public class UDFReplace extends UDF {
  private Text result = new Text();

  public UDFReplace() {
  }

  public Text evaluate(Text s, Text search, Text replacement) {
    if (s == null || search == null || replacement == null) {
      return null;
    }
    String r = s.toString().replace(search.toString(), replacement.toString());
    result.set(r);
    return result;
  }
}
Following that template, I defined my own function with the same behavior as Hive's built-in repeat:
package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

@Description(name = "my_repeat", // the function name this class maps to in Hive; usually kept identical to the name used when registering
    value = "_FUNC_(str, n): repeat str n times", // shown by "desc function xxx"
    extended = "Example SQL: select _FUNC_('a',3);\nResult: 'aaa'") // shown by "desc function extended xxx"
public class MyUDFRepeat extends UDF {
  // For Hive char/string types, it is advisable to work with the Text class
  private Text res = new Text();

  public Text evaluate(Text str, IntWritable n) {
    if (str == null || n == null) {
      return null;
    }
    if (n.get() > 0) {
      // Use getLength(), not getBytes().length: Text.getBytes() returns the
      // whole backing buffer, which can be longer than the valid data
      byte[] arr = str.getBytes();
      byte[] newArr = new byte[str.getLength() * n.get()];
      for (int i = 0; i < n.get(); i++) {
        System.arraycopy(arr, 0, newArr, i * str.getLength(), str.getLength());
      }
      res.set(newArr);
    } else {
      // n <= 0: reset to empty, otherwise res would keep the previous row's value
      res.set("");
    }
    return res;
  }
}
While writing this function I initially hit a problem that no amount of staring at the logic revealed; it took me almost a day to discover that Text's getBytes() behaves slightly differently from String's getBytes(): the returned byte arrays are not necessarily the same length. Text.getBytes() returns the underlying backing buffer, which may be longer than the valid data, while getLength() gives the number of valid bytes. Replacing every
str.getBytes().length
with str.getLength()
fixed it. API docs for the Text class: https://hadoop.apache.org/docs/r3.1.2/api/index.html
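A minimal sketch of the pitfall, using plain Hadoop Text with no Hive involved (the class name TextLengthDemo is mine; the exact buffer capacity printed depends on the Hadoop version):

package com.demo.hive;

import java.nio.charset.StandardCharsets;
import org.apache.hadoop.io.Text;

public class TextLengthDemo {
    public static void main(String[] args) {
        Text t = new Text("a fairly long initial value");
        // set() reuses the existing backing buffer when it is big enough,
        // so the buffer stays larger than the new content
        t.set("hi".getBytes(StandardCharsets.UTF_8));
        System.out.println(t.getLength());       // 2: number of valid bytes
        System.out.println(t.getBytes().length); // >= 2: length of the whole backing buffer
    }
}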
Package the UDF into a jar and upload it to the host running the Hive service (or to HDFS), then execute from the Hive client in the local IDEA:
add jar /root/HiveLib/hive_udf-1.0-SNAPSHOT.jar;  -- add the jar to the Hive session
create temporary function my_repeat as 'com.demo.hive.MyUDFRepeat';  -- create a temporary function, valid only for the current session
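Optionally, both steps can be sanity-checked with Hive's standard list/show commands (not part of the original flow, just a quick check):

list jars;                        -- the uploaded jar should appear in the output
show functions like 'my_repeat';  -- the new function should be listed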
After creating the function, you can inspect its details:
desc function extended my_repeat;
Run some test data to verify the result:
select *,my_repeat(name,2),repeat(name,2) from db_prac.employee;
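Since evaluate() here is an ordinary Java method, it can also be smoke-tested before deploying, without any cluster (a minimal sketch; the test class name TestMyUDFRepeat is mine):

package com.demo.hive;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

public class TestMyUDFRepeat {
    public static void main(String[] args) {
        MyUDFRepeat udf = new MyUDFRepeat();
        System.out.println(udf.evaluate(new Text("ab"), new IntWritable(3))); // ababab
        System.out.println(udf.evaluate(new Text("ab"), new IntWritable(0))); // empty string
        System.out.println(udf.evaluate(null, new IntWritable(3)));           // null
    }
}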
Now for the GenericUDF approach. As before, let's first look at the source of a built-in function, length. Extending GenericUDF requires implementing three abstract methods of the parent class: initialize(), evaluate(), and getDisplayString().
package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.StringLength;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.lazy.LazyBinary;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * GenericUDFLength.
 */
@Description(name = "length",
    value = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data",
    extended = "Example:\n"
        + " > SELECT _FUNC_('Facebook') FROM src LIMIT 1;\n"
        + " 8")
@VectorizedExpressions({StringLength.class})
public class GenericUDFLength extends GenericUDF {
  private final IntWritable result = new IntWritable();
  private transient PrimitiveObjectInspector argumentOI;
  private transient PrimitiveObjectInspectorConverter.StringConverter stringConverter;
  private transient PrimitiveObjectInspectorConverter.BinaryConverter binaryConverter;
  private transient boolean isInputString;

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (arguments.length != 1) {
      throw new UDFArgumentLengthException(
          "LENGTH requires 1 argument, got " + arguments.length);
    }

    if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException(
          "LENGTH only takes primitive types, got " + argumentOI.getTypeName());
    }
    argumentOI = (PrimitiveObjectInspector) arguments[0];

    PrimitiveObjectInspector.PrimitiveCategory inputType = argumentOI.getPrimitiveCategory();
    ObjectInspector outputOI = null;
    switch (inputType) {
    case CHAR:
    case VARCHAR:
    case STRING:
      isInputString = true;
      stringConverter = new PrimitiveObjectInspectorConverter.StringConverter(argumentOI);
      break;
    case BINARY:
      isInputString = false;
      binaryConverter = new PrimitiveObjectInspectorConverter.BinaryConverter(argumentOI,
          PrimitiveObjectInspectorFactory.writableBinaryObjectInspector);
      break;
    default:
      throw new UDFArgumentException(
          " LENGTH() only takes STRING/CHAR/VARCHAR/BINARY types as first argument, got "
              + inputType);
    }
    outputOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    return outputOI;
  }

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    byte[] data = null;
    if (isInputString) {
      String val = null;
      if (arguments[0] != null) {
        val = (String) stringConverter.convert(arguments[0].get());
      }
      if (val == null) {
        return null;
      }
      data = val.getBytes();
      int len = 0;
      for (int i = 0; i < data.length; i++) {
        if (GenericUDFUtils.isUtfStartByte(data[i])) {
          len++;
        }
      }
      result.set(len);
      return result;
    } else {
      BytesWritable val = null;
      if (arguments[0] != null) {
        val = (BytesWritable) binaryConverter.convert(arguments[0].get());
      }
      if (val == null) {
        return null;
      }
      result.set(val.getLength());
      return result;
    }
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("length", children);
  }
}
Modeled on that, here is a function that checks whether one string contains another:
package com.demo.hive;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;

@Description(name = "str_contains",
    value = "_FUNC_(str1, str2): return true if str1 contains str2, else return false")
public class MyGenericUDFContains extends GenericUDF {
  private StringObjectInspector pos1;
  private StringObjectInspector pos2;

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    // Check the number of arguments
    if (arguments.length != 2) {
      throw new UDFArgumentLengthException("str_contains requires exactly 2 arguments");
    }
    // Check the argument types
    if (!(arguments[0] instanceof StringObjectInspector)
        || !(arguments[1] instanceof StringObjectInspector)) {
      throw new UDFArgumentException("both arguments must be of String type");
    }
    this.pos1 = (StringObjectInspector) arguments[0];
    this.pos2 = (StringObjectInspector) arguments[1];
    // The function's return type is boolean
    return PrimitiveObjectInspectorFactory.javaBooleanObjectInspector;
  }

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    String str1 = this.pos1.getPrimitiveJavaObject(arguments[0].get());
    String str2 = this.pos2.getPrimitiveJavaObject(arguments[1].get());
    // Guard against SQL NULL inputs, which arrive here as Java null
    if (str1 == null || str2 == null) {
      return null;
    }
    return str1.contains(str2) ? Boolean.TRUE : Boolean.FALSE;
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("str_contains", children);
  }
}
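Before packaging, a GenericUDF can also be driven directly from plain Java via GenericUDF.DeferredJavaObject, which wraps a plain value as a DeferredObject (a minimal sketch; the test class name TestMyGenericUDFContains is mine):

package com.demo.hive;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class TestMyGenericUDFContains {
    public static void main(String[] args) throws Exception {
        MyGenericUDFContains udf = new MyGenericUDFContains();
        // Both arguments are plain Java strings
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        udf.initialize(new ObjectInspector[]{stringOI, stringOI});
        Object r = udf.evaluate(new GenericUDF.DeferredObject[]{
            new GenericUDF.DeferredJavaObject("alice"),
            new GenericUDF.DeferredJavaObject("i")});
        System.out.println(r); // true
    }
}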
After packaging and uploading the jar, create the mapping function:
create temporary function str_contains as 'com.demo.hive.MyGenericUDFContains';
Check the function info:
desc function extended str_contains;
Run some test data:
select name, str_contains(name,"i") from db_prac.employee;
Finally, a temporary function disappears when the session ends. To register a function permanently, put the jar on HDFS and use create function ... using jar:

create function my_repeat as 'com.demo.hive.MyUDFRepeat' using jar "hdfs:/user/hive/lib/hive_udf-1.0-SNAPSHOT.jar";

Either kind of function can be removed with drop [temporary] function xxx;