ChiMerge算法 (java)

系统 2277 0

韩家炜 数据挖掘概念与技术 第三版 习题3.12

鸢尾花数据集iris.data 作为待离散化的数据集合,使用ChiMerge算法,对四个数值属性进 行离散化,对四个属性进行区间合并,最终合并区间个数剩下为6个即停:即max_interval=6。

一、样本数据

iris.data 数据形式为:前面4列是属性,最后一列是数据类名,

      
        5.1,3.5,1.4,0.2,Iris-setosa

4.9,3.0,1.4,0.2,Iris-setosa

4.7,3.2,1.3,0.2,Iris-setosa

6.6,2.9,4.6,1.3,Iris-versicolor

5.2,2.7,3.9,1.4,Iris-versicolor

6.0,3.0,4.8,1.8,Iris-virginica

6.9,3.1,5.4,2.1,Iris-virginica
      
      
........

 此数据集一共3个类名:String[] classifies = {"Iris-setosa","Iris-versicolor","Iris-virginica"};


 

二、算法理论:

算法理论步骤参考: http://blog.csdn.net/zhaoyl03/article/details/8689440

第一步:初始化 

初始化时,一个数据认为是一个区间,每个属性对该属性下的各个区间进行升序排序

第二步:合并区间:(直到剩下区间数目为6)

(1) 计算每一对相邻区间的卡方值

   卡方公式是: 

  其中observed是expected是一个二行n列矩阵,二行是两个区间,n列是指数据一共有n个类。

       这里iris.data数据中一共有三个类,所以是2行3列矩阵:e.g


 observedmatrix:(下面表只有红色数字部分才为observedmatrix[2][3]的值。)
区间: 类别Iris-setosa 类别 Iris-versicolor 类别 Iris-virginica i行计算1的总个数
{3.0} 1 0 0 1
{3.1,3.2,3.3} 0 1 2 3
j列计算1的总个数 1 1 2 4    (矩阵里 1的总个数 )
 
  expectedmatrix[i][j]是由上面 observedmatrix所得, expectedmatrix[i][j]= (obser矩阵的i行1的总个数* j列1的总个数)/(observ矩阵里1的总个数) 
 expectedmatrix:
区间: 类别 Iris-setosa 类别 Iris-versicolor 类别 Iris-virginica
{3.0} 1*1/4=0.25 1/4=0.25 2*1/4=0.5
{3.2,3.3} 1*3/4=0.75 1/4=0.25 2*3/4 = 1.5
 
因为是2行3列矩阵,所以一共卡方迭加了2*3=6次; ( observedmatrix[i][j]- expectedmatrix[i][j])/  expectedmatrix[i][j]
所以  chisquare = (1-0.25)^2 /0.25 + (0-0.25)^2/0.25 +(0-0.5)^2/0.5 + (0-0.75)^2/0.75 + (1-0.25)^2/0.25 + (2-1.5)^2/1.5

(2) 将上面卡方值最小的一对区间合并

      

第三步:输出结果:6个区间的最大最小值

 


 

三、算法理论数据结构化

将上面算法理论数据结构化:

iris.data  中

1.属性:每个属性都有多个区间,所以定义属性是一个list,list的元素是什么类型呢? 是一个区间类型(所以写一个区间类:包括 区间最大最小值,区间包含的元素)。

2.区间:每个区间会包含很多元素,所以也需要一个list来存,list元素什么类型好?  再写一个数据 Data类,包括(数据,数据对应的类别(在卡方运算里会用到类别))

 所有数据都具备了结构了,整体结构这是最重要的。

      List<Interval>[] attributelists = 
      
        new
      
      
         ArrayList[attributenum]; 




      
      
        for
      
      (
      
        int
      
       i=0;i<attributenum;i++
      
        ) {

            attributelists[i] 
      
      = 
      
        new
      
       ArrayList<Interval>
      
        ();

        }






      
      
        class
      
      
         Interval {

    
      
      
        //
      
      
        每个区间都是有最小值最大值,以及该区间所包含的所有数据
      
      
        public
      
      
        double
      
       maxvalue = 0.0
      
        ;

    
      
      
        public
      
      
        double
      
       minvalue = 0.0
      
        ;

    
      
      
        public
      
       List<Data> intervallist = 
      
        new
      
       ArrayList<Data>();  
      
        //
      
      
        区间里的list每个元素都是Data类型
      
      
        }




      
      
        class
      
      
         Data {       

    
      
      
        //
      
      
        每个数据都包含它的值和类别
      
      
        public
      
      
        double
      
        value = 0.0
      
        ;

    
      
      
        public
      
       String  classify = ""
      
        ;

}
      
    

 


 

四、Java 实现

      
        public class
      
       ChiMergeTest {
      
public static int classificationnum = 3; // 类个数 public static int attributenum = 4 ; public static List<Interval>[] attributelists = new ArrayList[attributenum]; // 右边不能Arraylist<interval>!! public static String[] classifies = {"Iris-setosa","Iris-versicolor","Iris-virginica" }; public static void main(String[] args) throws Exception { String inputpath = "iris.data" ; readFile(inputpath); // 将输入数据的 结构化 chiMerge(); printresult(); }

 

对应上面算法步骤:

第一步:初始化 

初始化时,一个数据认为是一个区间,每个属性对该属性下的各个区间进行升序排序

      
        public
      
      
        static
      
      
        void
      
       readFile(String inputpath) 
      
        throws
      
      
         Exception {

        BufferedReader br 
      
      = 
      
        new
      
       BufferedReader(
      
        new
      
      
         FileReader(inputpath));

        String line 
      
      =
      
         br.readLine();

        

        
      
      
        for
      
      (
      
        int
      
       i=0;i<attributenum;i++
      
        ) {

            attributelists[i] 
      
      = 
      
        new
      
       ArrayList<Interval>
      
        ();

        }

        

        
      
      
        while
      
      (line!= 
      
        null
      
      && line.length()>0
      
        ) {

            String[] temp 
      
      = line.split(",");  
      
        //
      
      
        将数据分隔,
      
      
        for
      
      (
      
        int
      
       i=0; i<attributelists.length; i++) {  
      
        //
      
      
        遍历属性名
      
      

                Interval interval = 
      
        new
      
      
         Interval();

                Data onedata 
      
      = 
      
        new
      
      
         Data();

                

                onedata.value 
      
      =
      
         Double.parseDouble(temp[i]);

                onedata.classify 
      
      = temp[4
      
        ];

                

                interval.minvalue 
      
      = interval.maxvalue =
      
         onedata.value;

                interval.intervallist.add(onedata);  
      
      
        //
      
      
        区间加入了一个数据
      
      

                attributelists[i].add(interval);     
      
        //
      
      
        第i个属性加入了一个区间
      
      
                    }

        line 
      
      =
      
         br.readLine();

        }

        br.close();

        sort();
      
      
        

    }     





    
      
      
        public
      
      
        static
      
      
        void
      
       sort() {   
      
        //
      
      
        初步建立属性list时,对区间进行排序
      
      
        for
      
      (
      
        int
      
       i = 0; i<attributenum; i++
      
        ){

            List
      
      <Interval> attrlist =
      
         attributelists[i]; 

            Collections.sort(attrlist,
      
      
        new
      
       IntervalComparator());  
      
        //
      
      
        排序
      
      
                    combineRepeatedData(attrlist);  


      
      
        //
      
      
                    CombineRepeatedDatawithHash(attrlist); 
      
      
        //
      
      
        等同于上面方法,不同顺序会再被打算。麻烦。
      
      

        }
      
        //
      
      
        for
      
      
        

    }

        

    
      
      
        public
      
      
        static
      
      
        void
      
       combineRepeatedData(List<Interval>
      
         attrlist) {

        
      
      
        for
      
      (
      
        int
      
       j=0; j<attrlist.size()-1 ;j++
      
        ) {

            Interval inteFront 
      
      =
      
         attrlist
        
          .get(j)
        
        ;

            Interval intevbehind 
      
      = attrlist.get(j+1
      
        );

            List
      
      <Data> listFront =
      
         inteFront.intervallist;

            List
      
      <Data> listbehind =
      
         intevbehind.intervallist;

            Data dataFront 
      
      =  listFront.get(0
      
        );

            Data  databehind 
      
      = listbehind.get(0
      
        );

    

            
      
      
        while
      
      (databehind.value == dataFront.value &&(j<attrlist.size()-1)   ) { 
      
        //
      
      
        属性list已经排序,如果后面一个data值跟前面data相同,则合并到前面的。
      
      
                        attrlist.get(j).intervallist
        
          .addAll
        
        (listbehind);  
        
          //用得不熟!!
        
        

                attrlist
        
          .remove
        
        (j
      
      +1
      
        );
        
if ((j<attrlist.size()-1 )) { inteFront = attrlist.get(j); intevbehind = attrlist.get(j+1 ); listFront = inteFront.intervallist; listbehind = intevbehind.intervallist; dataFront = listFront.get(0 ); databehind = listbehind.get(0 ); } } } }
        
          class
        
        
          IntervalComparator
        
        
          implements
        
         Comparator {  
        
          //
        
        
          升序了。
          
            对list引用类型写compartor排序方法很重要!!
          
        
        
          public
        
        
          int
        
        
           compare(Object arg0, Object arg1) {

        Interval i1 
        
        =
        
           (Interval)arg0;

        Interval i2 
        
        =
        
           (Interval)arg1;

        

        Data x1 
        
        = i1.intervallist.get(0); 
        
          //
        
        
          一开始所有区间就一个元素而已
        
        

        Data x2 = i2.intervallist.get(0
        
          );

        
        
        
          int
        
         result = 0
        
          ;

        
        
        
          if
        
        (x2.value<
        
          x1.value)

        {result 
        
        = 1
        
          ; }

        
        
        
          if
        
        (x2.value>
        
          x1.value)

        {result 
        
        = -1
        
          ; }

        
        
        
          return
        
        
           result;        

    }

}
        
      
  

 

第二步:合并区间:(直到剩下区间数目为6)

      
        public
      
      
        static
      
      
        void
      
      
         chiMerge() {

        
      
      
        for
      
      (
      
        int
      
       i=0; i<attributelists.length; i++
      
        ){

            List
      
      <Interval> attrlist =
      
        attributelists[i]; 

            
      
      
        while
      
      (attrlist.size()>6){    
      
        //
      
      
        最终的终止条件是形成6个区间
      
      
        double
      
       minchisquare = 10000000;  
      
        //
      
      
        定义一个属性里最小的卡方值 。。  变量放在的位置很重要,是放在循环里面还是外面很重要,就因为这个找错误还挑不出来,白花了两个小时
      
      
        int
      
       minchisquareindex =0;  
      
        //
      
      
        记住两个区间最小卡方值的第一个区间在属性list的下标              

                  
      
      
        //
      
      
        遍历一个属性的相邻的两个区间
      
      
        for
      
      (
      
        int
      
       j=0; j<attrlist.size()-1;j++){  
      
        //
      
      
        遍历一个属性的每个两个区间比较  
      
      

                      Interval interval1 = attrlist.get(j);   
      
        //
      
      
        要比较两个区间
      
      

                      Interval interval2 = attrlist.get(j+1
      
        ); 

      

                      Matrixs matrixs 
      
      = 
      
        buildObseredandExpectedMatrixs
      
      (attrlist,interval1, interval2); 
      
        //
      
      
        返回了两个observed,expected矩阵
      
      
        double
      
       chisquarevalue = 
      
        calchi
      
      (matrixs);          
      
        //
      
      
        计算两个区间的卡方值
      
      
        if
      
      (chisquarevalue < minchisquare ){  
      
        //
      
      
        找最小的卡方值
      
      

                          minchisquare =
      
         chisquarevalue;

                          minchisquareindex 
      
      = j; 
      
        //
      
      
        表示当前最小的卡方值的两个区间中第一个区间在该属性list的下标位置
      
      
                             }

                  }
      
      
        //
      
      
        for
      
      
         mergetwoIntervals(
      
      attrlist,minchisquareindex);  
      
        //
      
      
        合并第i个属性list里的最小两个区间。最终的合并!
      
      

              }  
      
        //
      
      
        while
      
      
                } 

    }
      
    

 

(1) 计算每一对相邻区间的卡方值

      
        public
      
      
        static
      
      
        double
      
      
         calchi(Matrixs matrixs) {

         
      
      
        double
      
      [][] observedMatrix = 
      
        new
      
      
        double
      
      [2][3
      
        ]; 

          
      
      
        double
      
      [][] expectedMatrix = 
      
        new
      
      
        double
      
      [2][3
      
        ];

          observedMatrix 
      
      =
      
         matrixs.observedMatrix;

          expectedMatrix 
      
      =
      
         matrixs.expectedMatrix;            

          

          
      
      
        //
      
      
        求卡方
      
      
        int
      
       chisquarevalue =0
      
        ;

          
      
      
        for
      
      (
      
        int
      
       r=0; r<2; r++
      
        ) {

              
      
      
        for
      
      (
      
        int
      
       c=0;c<3;c++
      
         ) {

                chisquarevalue 
      
      += (observedMatrix[r][c]- expectedMatrix[r][c]) *(observedMatrix[r][c]- expectedMatrix[r][c]) /
      
        expectedMatrix[r][c] ; 

              }

          }


      
      
        //
      
      
                  System.out.println("卡方值:"+chisquarevalue);
      
      
        return
      
      
         chisquarevalue;

    }



    
      
      
        public
      
      
        static
      
       Matrixs buildObseredandExpectedMatrixs(List<Interval> attrlist,Interval interval1,Interval interval2) {  
      
        //
      
      
        返回两个矩阵:obeserved和expected矩阵
      
      
        //
      
      
        建立observedMatrix 
      
      
        double
      
      [][] observedMatrix = 
      
        new
      
      
        double
      
      [2][3
      
        ]; 

          
      
      
        double
      
      [][] expectedMatrix = 
      
        new
      
      
        double
      
      [2][3
      
        ];

          

          
      
      
        int
      
      [] linesum = 
      
        new
      
      
        int
      
      [2] ;  
      
        //
      
      
        矩阵两行的计算
      
      
        int
      
      [] columnsum = 
      
        new
      
      
        int
      
      [3]; 
      
        //
      
      
        矩阵三列都计算
      
      
                   

          linesum[
      
      0] =
      
         interval1.intervallist.size();

          linesum[
      
      1] =
      
         interval2.intervallist.size();

          
      
      
        int
      
       allsum = linesum[0] + linesum[1
      
        ];

          columnsum[
      
      0]= columnsum[1] = columnsum[2] = 0; 
      
        //
      
      
        初始化列

          

          
      
      
        //
      
      
        取第一个区间
      
      
        for
      
      (
      
        int
      
       k=0; k< interval1.intervallist.size() ; k++) { 
      
        //
      
      
        遍历一个区间里所有元素
      
      

                Data data =
      
         interval1.intervallist.get(k);

                
      
      
        if
      
      (data.classify.equals(classifies[0])) {  
      
        //
      
      
        是类别1:Iris-setosa
      
      

                    columnsum[0]++
      
        ;

                    observedMatrix[
      
      0][0]++
      
        ;

                }

                
      
      
        else
      
      
        if
      
      (data.classify.equals(classifies[1])) {  
      
        //
      
      
        是类别2:Iris-versicolor
      
      

                    columnsum[1]++
      
        ;

                    observedMatrix[
      
      0][1]++
      
        ;

                }

                
      
      
        else
      
      
        if
      
      (data.classify.equals(classifies[2])) {  
      
        //
      
      
        是类3
      
      

                    columnsum[2]++
      
        ;

                    observedMatrix[
      
      0][2]++
      
        ;

                }

           }
      
      
        //
      
      
        for

          

          
      
      
        //
      
      
        取第2个区间
      
      
        for
      
      (
      
        int
      
       k=0; k< interval2.intervallist.size() ; k++) { 
      
        //
      
      
        遍历一个区间里所有元素
      
      

                Data data =
      
         interval2.intervallist.get(k);

                
      
      
        if
      
      (data.classify.equals(classifies[0])) {  
      
        //
      
      
        是类别1:Iris-setosa
      
      

                    columnsum[0]++
      
        ;

                    observedMatrix[
      
      1][0]++
      
        ;

                }

                
      
      
        else
      
      
        if
      
      (data.classify.equals(classifies[1])) {  
      
        //
      
      
        是类别2:Iris-versicolor
      
      

                    columnsum[1]++
      
        ;

                    observedMatrix[
      
      1][1]++
      
        ;

                }

                
      
      
        else
      
      
        if
      
      (data.classify.equals(classifies[2])) {  
      
        //
      
      
        是类3
      
      

                    columnsum[2]++
      
        ;

                    observedMatrix[
      
      1][2]++
      
        ;

                }

           }
      
      
        //
      
      
        for       

      

          
      
      
        //
      
      
        建立expectedMatrix
      
      
        for
      
      (
      
        int
      
       r=0; r<2; r++
      
        ) {

            
      
      
        for
      
      (
      
        int
      
       c=0;c<3;c++
      
         ) {

                expectedMatrix[r][c]
      
      = linesum[r] * columnsum[c] /
      
        allsum;

                
      
      
        if
      
      (expectedMatrix[r][c]==0.0
      
        )

                    expectedMatrix[r][c]
      
      =0.0001; 
      
        //
      
      
        因为求卡方的时候,这个值会作分母,所以分母不能作0.分母变小,则卡方值就大,卡方值越大,越不相似,越不会被合并了
      
      
                    } 

          }

          

         Matrixs matrixs 
      
      = 
      
        new
      
      
         Matrixs();

         matrixs.expectedMatrix 
      
      =
      
         expectedMatrix;

         matrixs.observedMatrix 
      
      =
      
         observedMatrix;

         

         
      
      
        return
      
      
         matrixs;

    }
      
      
        

}




      
      
        class
      
      
         Matrixs {

    
      
      
        public
      
      
        double
      
      [][] observedMatrix = 
      
        new
      
      
        double
      
      [2][3
      
        ];

    
      
      
        public
      
      
        double
      
      [][] expectedMatrix  = 
      
        new
      
      
        double
      
      [2][3
      
        ];

}
      
    

(2) 将上面卡方值最小的一对区间合并

      
        public
      
      
        static
      
      
        void
      
       mergetwoIntervals(List<Interval> attrlist,
      
        int
      
      
         minchisquareindex) {
        
// List<Interval> attrlist =attributelists[0]; // 将当前最小的卡方值对应的两个区间进行合并;删去已被合并的区间 List<Data> mergedlist = attrlist.get(minchisquareindex+1).intervallist; // 被合并的区间里的数据list attrlist.get(minchisquareindex).intervallist.addAll(mergedlist); attrlist.get(minchisquareindex).maxvalue = attrlist.get(minchisquareindex+1).maxvalue; // 该区间的最大值是第二个区间的最大值,因为区间已经排过序了 attrlist.remove(minchisquareindex+1); // 该属性删去已被合并的区间 }

 

 

 第三步:输出结果:6个区间的最大最小值

      
        public
      
      
        static
      
      
        void
      
      
         printresult() {

        
      
      
        for
      
      (
      
        int
      
       i=0; i<attributenum; i++
      
        ){

            System.out.println(
      
      "第"+(i+1)+"个属性:"
      
        );

             
      
      
        for
      
      (
      
        int
      
       j=0; j<attributelists[i].size(); j++) {  
      
        //
      
      
        每个属性是list,遍历属性list每一个元素
      
      

                 Interval in =
      
         attributelists[i].get(j); 

                 System.out.println(
      
      "( "+in.minvalue +" , " + in.maxvalue + " )" );  
      
        //
      
      
        每个interval类里的list每个元素都是一个Data类型    
      
      
                     }

         }

    }
      
    

 

 最终结果如下:

      
        第1个属性:

( 
      
      4.3 , 4.8
      
         ) ( 
      
      4.9 , 5.2
      
         ) ( 
      
      5.3 , 5.3
      
         ) ( 
      
      5.4 , 6.9
      
         ) ( 
      
      7.0 , 7.0
      
         ) ( 
      
      7.1 , 7.9
      
         )

第2个属性:

( 
      
      2.0 , 2.0
      
         )( 
      
      2.2 , 2.2
      
         ) ( 
      
      2.3 , 2.3
      
         ) ( 
      
      2.4 , 3.5
      
         ) ( 
      
      3.6 , 3.6
      
         ) ( 
      
      3.7 , 4.4
      
         )

第3个属性:

( 
      
      1.0 , 1.9
      
         ) ( 
      
      3.0 , 4.4
      
         ) ( 
      
      4.5 , 4.5
      
         ) ( 
      
      4.6 , 4.7
      
         ) ( 
      
      4.8 , 5.1
      
         ) ( 
      
      5.2 , 6.9
      
         )

第4个属性:

( 
      
      0.1 , 0.6
      
         ) ( 
      
      1.0 , 1.5
      
         ) ( 
      
      1.6 , 1.6
      
         ) ( 
      
      1.7 , 1.7
      
         ) ( 
      
      1.8 , 1.8
      
         )  ( 
      
      1.9 , 2.5 )
    

 

 

 

ChiMerge算法 (java)


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论