OopsOutOfMemory 盛利's Blog
专注大数据领域,分布式计算,Spark Contributor
<原创文章>

Hive UDAF 中位数

盛利的博客
, in 19 November 2014

第一次写UDAF,拿中位数来练手。

先看下中位数(MEDIAN)的定义:把一组数据按从小到大的顺序依次排列,处在中间位置的一个数(数据个数为奇数时),或最中间两个数的平均数(数据个数为偶数时)。下面把它写成 GenericUDAF 的形式。

例如:1 2 3 4 的中位数为 (2+3)/2 = 2.5;1 2 3 的中位数为 2。

package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.util.StringUtils;


@Description(name="median",value=""
        + "_FUNC_(x) return the median number of a number array. eg: median(x)")
public class GenericUDAFMedian extends AbstractGenericUDAFResolver {

    static final Log LOG = LogFactory.getLog(GenericUDAFMedian.class.getName());

    /**
     * Returns the evaluator matching the type of the single argument of {@code median(x)}.
     *
     * <p>BYTE, SHORT and INT arguments use {@link GenericUDAFMedianEvaluatorInt}, LONG uses
     * {@link GenericUDAFMedianEvaluatorLong}, FLOAT and DOUBLE use
     * {@link GenericUDAFMedianEvaluatorDouble}; every evaluator produces a double.
     *
     * @param parameters type information of the call's arguments
     * @return the evaluator for the argument type
     * @throws UDFArgumentTypeException if the call does not have exactly one argument, if the
     *     argument type is not comparable, or if it is not one of the numeric primitives above
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Only 1 parameter is accepted!");
        }

        ObjectInspector objectInspector =
                TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        if (!ObjectInspectorUtils.compareSupported(objectInspector)) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Cannot support comparison of map<> type or complex type containing map<>.");
        }

        // A comparable non-primitive type (e.g. an array) passes the check above; reject it here
        // instead of letting the PrimitiveTypeInfo cast below throw a ClassCastException.
        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                    "Only numeric type(int long double) arguments are accepted but "
                    + parameters[0].getTypeName() + " was passed as parameter of index->1.");
        }

        switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
        case BYTE:
        case SHORT:
        case INT:
            return new GenericUDAFMedianEvaluatorInt();
        case LONG:
            return new GenericUDAFMedianEvaluatorLong();
        case FLOAT:
        case DOUBLE:
            return new GenericUDAFMedianEvaluatorDouble();
        default:
            // STRING, BOOLEAN and every other primitive category are rejected.
            throw new UDFArgumentTypeException(0,
                    "Only numeric type(int long double) arguments are accepted but "
                    + parameters[0].getTypeName() + " was passed as parameter of index->1.");
        }
    }

    /**
     * Median evaluator for BYTE, SHORT and INT arguments.
     *
     * <p>Every non-null input is buffered as an {@link IntWritable}. Partial aggregations are
     * emitted as a struct with one field, {@code list}, holding those values. The final step
     * sorts the buffer and returns the median as a {@link DoubleWritable}. All values of a group
     * are kept in memory.
     */
    public static class GenericUDAFMedianEvaluatorInt extends GenericUDAFEvaluator {

        private DoubleWritable result = new DoubleWritable();
        PrimitiveObjectInspector inputOI;
        StructObjectInspector structOI;
        StandardListObjectInspector listOI;
        StructField listField;
        Object[] partialResult;
        ListObjectInspector listFieldOI;

        /**
         * Sets up input and output object inspectors for the given mode.
         *
         * @param m the aggregation mode
         * @param parameters raw column inspector (PARTIAL1/COMPLETE) or partial-struct inspector
         *     (PARTIAL2/FINAL)
         * @return the partial-struct inspector for PARTIAL1/PARTIAL2, otherwise a writable double
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            super.init(m, parameters);

            listOI = ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.writableIntObjectInspector);

            // Input side: raw values in PARTIAL1/COMPLETE, partial structs in PARTIAL2/FINAL.
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                structOI = (StructObjectInspector) parameters[0];
                listField = structOI.getStructFieldRef("list");
                listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
            }

            // Output side: struct{list} for partial modes, the final double otherwise.
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                foi.add(listOI);
                ArrayList<String> fname = new ArrayList<String>();
                fname.add("list");
                partialResult = new Object[1];
                partialResult[0] = new ArrayList<IntWritable>();
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
            } else {
                return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            }
        }

        /** Per-group buffer holding every value seen so far. */
        static class MedianNumberAgg implements AggregationBuffer {
            List<IntWritable> aggIntegerList;
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            MedianNumberAgg resultAgg = new MedianNumberAgg();
            reset(resultAgg);
            return resultAgg;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((MedianNumberAgg) agg).aggIntegerList = new ArrayList<IntWritable>();
        }

        /** Set after the first conversion failure so only one stack trace is logged. */
        boolean warned = false;

        /**
         * Adds one input value to the buffer.
         *
         * <p>NULLs and values that cannot be converted to an int are skipped, so they do not
         * take part in the median.
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            if (parameters[0] == null) {
                return;
            }
            int val;
            try {
                val = PrimitiveObjectInspectorUtils.getInt(parameters[0], inputOI);
            } catch (NullPointerException e) {
                LOG.warn("got a null value, skip it");
                return;
            } catch (NumberFormatException e) {
                if (!warned) {
                    warned = true;
                    LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
                    LOG.warn("ignore similar exceptions.");
                }
                // Skip the unconvertible value instead of counting it as 0.
                return;
            }
            ((MedianNumberAgg) agg).aggIntegerList.add(new IntWritable(val));
        }

        /**
         * Sorts the buffered values and returns their median.
         *
         * @return the middle value for an odd count, the mean of the two middle values for an
         *     even count, or {@code null} when the group contained no usable values
         */
        @SuppressWarnings("unchecked")
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            List<IntWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
            int size = values.size();
            if (size == 0) {
                // Empty group (e.g. all inputs NULL): there is no middle element to read.
                return null;
            }
            Collections.sort(values);
            int midIndex = size / 2;
            double rs;
            if (size % 2 == 1) {
                rs = values.get(midIndex).get();
            } else {
                // Halve in double arithmetic: adding the two ints first could overflow.
                rs = values.get(midIndex - 1).get() / 2.0 + values.get(midIndex).get() / 2.0;
            }
            result.set(rs);
            return result;
        }

        /** Emits a snapshot of the buffered values as the single field of the partial struct. */
        @Override
        public Object terminatePartial(AggregationBuffer agg)
                throws HiveException {
            List<IntWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
            partialResult[0] = new ArrayList<IntWritable>(values);
            return partialResult;
        }

        /**
         * Folds a partial result produced by {@link #terminatePartial} into the buffer.
         *
         * <p>Each element is read through the list's element object inspector and copied into a
         * new {@link IntWritable}. The deserialized elements are not guaranteed to be
         * {@code IntWritable} instances, or to stay unchanged after this call returns.
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }
            Object partialObject = structOI.getStructFieldData(partial, listField);
            List<?> partialList = listFieldOI.getList(partialObject);
            if (partialList == null) {
                return;
            }
            PrimitiveObjectInspector elementOI =
                    (PrimitiveObjectInspector) listFieldOI.getListElementObjectInspector();
            List<IntWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
            for (Object element : partialList) {
                if (element != null) {
                    values.add(new IntWritable(PrimitiveObjectInspectorUtils.getInt(element, elementOI)));
                }
            }
        }

    }


/**
 * Median evaluator for FLOAT and DOUBLE arguments.
 *
 * <p>Every non-null input is buffered as a {@link DoubleWritable}. Partial aggregations are
 * emitted as a struct with one field, {@code list}, holding those values. The final step sorts
 * the buffer and returns the median as a {@link DoubleWritable}. All values of a group are kept
 * in memory.
 */
public static class GenericUDAFMedianEvaluatorDouble extends GenericUDAFEvaluator {

    private DoubleWritable result = new DoubleWritable();
    PrimitiveObjectInspector inputOI;
    StructObjectInspector structOI;
    StandardListObjectInspector listOI;
    StructField listField;
    Object[] partialResult;
    ListObjectInspector listFieldOI;

    /**
     * Sets up input and output object inspectors for the given mode.
     *
     * @param m the aggregation mode
     * @param parameters raw column inspector (PARTIAL1/COMPLETE) or partial-struct inspector
     *     (PARTIAL2/FINAL)
     * @return the partial-struct inspector for PARTIAL1/PARTIAL2, otherwise a writable double
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        super.init(m, parameters);

        listOI = ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);

        // Input side: raw values in PARTIAL1/COMPLETE, partial structs in PARTIAL2/FINAL.
        if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
            inputOI = (PrimitiveObjectInspector) parameters[0];
        } else {
            structOI = (StructObjectInspector) parameters[0];
            listField = structOI.getStructFieldRef("list");
            listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
        }

        // Output side: struct{list} for partial modes, the final double otherwise.
        if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
            ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
            foi.add(listOI);
            ArrayList<String> fname = new ArrayList<String>();
            fname.add("list");
            partialResult = new Object[1];
            partialResult[0] = new ArrayList<DoubleWritable>();
            return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
        } else {
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }
    }

    /** Per-group buffer holding every value seen so far. */
    static class MedianNumberAgg implements AggregationBuffer {
        List<DoubleWritable> aggIntegerList;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        MedianNumberAgg resultAgg = new MedianNumberAgg();
        reset(resultAgg);
        return resultAgg;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
        ((MedianNumberAgg) agg).aggIntegerList = new ArrayList<DoubleWritable>();
    }

    /** Set after the first conversion failure so only one stack trace is logged. */
    boolean warned = false;

    /**
     * Adds one input value to the buffer.
     *
     * <p>NULLs and values that cannot be converted to a double are skipped, so they do not take
     * part in the median.
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        if (parameters[0] == null) {
            return;
        }
        double val;
        try {
            val = PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputOI);
        } catch (NullPointerException e) {
            LOG.warn("got a null value, skip it");
            return;
        } catch (NumberFormatException e) {
            if (!warned) {
                warned = true;
                LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
                LOG.warn("ignore similar exceptions.");
            }
            // Skip the unconvertible value instead of counting it as 0.0.
            return;
        }
        ((MedianNumberAgg) agg).aggIntegerList.add(new DoubleWritable(val));
    }

    /**
     * Sorts the buffered values and returns their median.
     *
     * @return the middle value for an odd count, the mean of the two middle values for an even
     *     count, or {@code null} when the group contained no usable values
     */
    @SuppressWarnings("unchecked")
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
        List<DoubleWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
        int size = values.size();
        if (size == 0) {
            // Empty group (e.g. all inputs NULL): there is no middle element to read.
            return null;
        }
        Collections.sort(values);
        int midIndex = size / 2;
        double rs;
        if (size % 2 == 1) {
            rs = values.get(midIndex).get();
        } else {
            // Halve before adding so two values near Double.MAX_VALUE do not sum to infinity.
            rs = values.get(midIndex - 1).get() / 2.0 + values.get(midIndex).get() / 2.0;
        }
        result.set(rs);
        return result;
    }

    /** Emits a snapshot of the buffered values as the single field of the partial struct. */
    @Override
    public Object terminatePartial(AggregationBuffer agg)
            throws HiveException {
        List<DoubleWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
        partialResult[0] = new ArrayList<DoubleWritable>(values);
        return partialResult;
    }

    /**
     * Folds a partial result produced by {@link #terminatePartial} into the buffer.
     *
     * <p>Each element is read through the list's element object inspector and copied into a new
     * {@link DoubleWritable}. The deserialized elements are not guaranteed to be
     * {@code DoubleWritable} instances, or to stay unchanged after this call returns.
     */
    @Override
    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if (partial == null) {
            return;
        }
        Object partialObject = structOI.getStructFieldData(partial, listField);
        List<?> partialList = listFieldOI.getList(partialObject);
        if (partialList == null) {
            return;
        }
        PrimitiveObjectInspector elementOI =
                (PrimitiveObjectInspector) listFieldOI.getListElementObjectInspector();
        List<DoubleWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
        for (Object element : partialList) {
            if (element != null) {
                values.add(new DoubleWritable(PrimitiveObjectInspectorUtils.getDouble(element, elementOI)));
            }
        }
    }

}


/**
 * Median evaluator for LONG (bigint) arguments.
 *
 * <p>Every non-null input is buffered as a {@link LongWritable}. Partial aggregations are
 * emitted as a struct with one field, {@code list}, holding those values. The final step sorts
 * the buffer and returns the median as a {@link DoubleWritable}. All values of a group are kept
 * in memory.
 */
public static class GenericUDAFMedianEvaluatorLong extends GenericUDAFEvaluator {

    private DoubleWritable result = new DoubleWritable();
    PrimitiveObjectInspector inputOI;
    StructObjectInspector structOI;
    StandardListObjectInspector listOI;
    StructField listField;
    Object[] partialResult;
    ListObjectInspector listFieldOI;

    /**
     * Sets up input and output object inspectors for the given mode.
     *
     * @param m the aggregation mode
     * @param parameters raw column inspector (PARTIAL1/COMPLETE) or partial-struct inspector
     *     (PARTIAL2/FINAL)
     * @return the partial-struct inspector for PARTIAL1/PARTIAL2, otherwise a writable double
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        super.init(m, parameters);

        listOI = ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableLongObjectInspector);

        // Input side: raw values in PARTIAL1/COMPLETE, partial structs in PARTIAL2/FINAL.
        if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
            inputOI = (PrimitiveObjectInspector) parameters[0];
        } else {
            structOI = (StructObjectInspector) parameters[0];
            listField = structOI.getStructFieldRef("list");
            listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
        }

        // Output side: struct{list} for partial modes, the final double otherwise.
        if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
            ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
            foi.add(listOI);
            ArrayList<String> fname = new ArrayList<String>();
            fname.add("list");
            partialResult = new Object[1];
            partialResult[0] = new ArrayList<LongWritable>();
            return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
        } else {
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }
    }

    /** Per-group buffer holding every value seen so far. */
    static class MedianNumberAgg implements AggregationBuffer {
        List<LongWritable> aggIntegerList;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        MedianNumberAgg resultAgg = new MedianNumberAgg();
        reset(resultAgg);
        return resultAgg;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
        ((MedianNumberAgg) agg).aggIntegerList = new ArrayList<LongWritable>();
    }

    /** Set after the first conversion failure so only one stack trace is logged. */
    boolean warned = false;

    /**
     * Adds one input value to the buffer.
     *
     * <p>NULLs and values that cannot be converted to a long are skipped, so they do not take
     * part in the median.
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        if (parameters[0] == null) {
            return;
        }
        long val;
        try {
            val = PrimitiveObjectInspectorUtils.getLong(parameters[0], inputOI);
        } catch (NullPointerException e) {
            LOG.warn("got a null value, skip it");
            return;
        } catch (NumberFormatException e) {
            if (!warned) {
                warned = true;
                LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
                LOG.warn("ignore similar exceptions.");
            }
            // Skip the unconvertible value instead of counting it as 0.
            return;
        }
        ((MedianNumberAgg) agg).aggIntegerList.add(new LongWritable(val));
    }

    /**
     * Sorts the buffered values and returns their median.
     *
     * @return the middle value for an odd count, the mean of the two middle values for an even
     *     count, or {@code null} when the group contained no usable values
     */
    @SuppressWarnings("unchecked")
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
        List<LongWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
        int size = values.size();
        if (size == 0) {
            // Empty group (e.g. all inputs NULL): there is no middle element to read.
            return null;
        }
        Collections.sort(values);
        int midIndex = size / 2;
        double rs;
        if (size % 2 == 1) {
            rs = values.get(midIndex).get();
        } else {
            // Halve in double arithmetic: adding the two longs first could overflow.
            rs = values.get(midIndex - 1).get() / 2.0 + values.get(midIndex).get() / 2.0;
        }
        result.set(rs);
        return result;
    }

    /** Emits a snapshot of the buffered values as the single field of the partial struct. */
    @Override
    public Object terminatePartial(AggregationBuffer agg)
            throws HiveException {
        List<LongWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
        partialResult[0] = new ArrayList<LongWritable>(values);
        return partialResult;
    }

    /**
     * Folds a partial result produced by {@link #terminatePartial} into the buffer.
     *
     * <p>Each element is read through the list's element object inspector and copied into a new
     * {@link LongWritable}. The deserialized elements are not guaranteed to be
     * {@code LongWritable} instances, or to stay unchanged after this call returns.
     */
    @Override
    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if (partial == null) {
            return;
        }
        Object partialObject = structOI.getStructFieldData(partial, listField);
        List<?> partialList = listFieldOI.getList(partialObject);
        if (partialList == null) {
            return;
        }
        PrimitiveObjectInspector elementOI =
                (PrimitiveObjectInspector) listFieldOI.getListElementObjectInspector();
        List<LongWritable> values = ((MedianNumberAgg) agg).aggIntegerList;
        for (Object element : partialList) {
            if (element != null) {
                values.add(new LongWritable(PrimitiveObjectInspectorUtils.getLong(element, elementOI)));
            }
        }
    }

}

}

写好之后,用 Eclipse 打成 jar 包。

测试:

use datawarehouse;
add jar /home/hadoop/shengli/median.jar;
create temporary function median as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMedian';

-- int, odd count (1 7 8): expected 7.0
select median(id) from
(
select 7 id from dual
union all
select 8 id from dual
union all
select 1 id from dual
) a;

-- bigint, even count (1 2): expected 1.5
select median(id) from
(
select cast(1 as bigint) id from dual
union all
select cast(2 as bigint) id from dual
) a;

-- fractional values, even count (1.0 2.3): expected 1.65
select median(id) from
(
select 1.0 id from dual
union all
select 2.3 id from dual
) a;

-- int, odd count (1 2 3): expected 2.0
select median(id) from
(
select 1 id from dual
union all
select 2 id from dual
union all
select 3 id from dual
) a;

-- NULL-only input: exercises the empty-group path of terminate().
-- The literal is cast to int because a bare NULL is void-typed, which getEvaluator rejects.
select median(id) from
(
select cast(null as int) id from dual
) a;
---------------------------------
-- grouped: a = (-2 1 3 4) -> 2.0, b = (4 5 6) -> 5.0
select type, median(id) from
(
select 'a' type, 3 id from dual
union all
select 'a' type, -2 id from dual
union all
select 'a' type, 1 id from dual
union all
select 'a' type, 4 id from dual
union all
select 'b' type, 6 id from dual
union all
select 'b' type, 5 id from dual
union all
select 'b' type, 4 id from dual
) a
group by type;

其实写好UDAF的关键是理解各个阶段的含义,后续我会写一篇UDAF的介绍文章。

(The End)
<原创文章> From OopsOutOfMemory 盛利's Blog
转载请注明出自: http://oopsoutofmemory.github.io/hive/2014/11/19/hive-udfa--zhong-wei-shu