閱讀930 返回首頁    go 阿裏雲 go 技術社區[雲棲]


初學Scala(1): Scala實現Hash Join

最近看了些Scala相關的內容,寫了個簡單的hash join。

初步實現

jion過程

  1. 從數據源讀取兩個List[List[Any]](),我把所有的操作都放到List容器裏進行
  2. 將兩份數據集,hash到自己寫的簡單的SimpleHashTable裏,每次put進去的時候會返回一個Int值,用於記錄兩份數據占據的bucket number集合
  3. 由於兩份數據都是基於同一個hash方法進行hash的,join具體發生在兩個hashTable對應的bucket之間
  4. 遍曆需要進行join的buckets,每對bucket之間的join回歸到最基本的二層遍曆
幾點說明

  1. 整個過程一共兩個文件,SimpleHashTable.scala和HashJoin.scala
  2. 輸入是兩個二維的List,輸出是一個二維List,支持的是單個鍵的inner join
  3. 測試速度:兩個10000大小的20個字段的寬表進行hash join,大約0.4s
  4. HashTable的M值可以針對數據集大小自己定製,盡量讓數據集在buckets裏打散

可以改進的點有很多,這個hash join還是相當簡單的,我比較依賴於foldLeft和map方法,體會到Scala編程代碼量很少,用起來很舒服,很強大。

class SimpleHashTable {
  
  val M = 991
  
  val container = new Array[List[Any]](M)

  for (i <- 0 to M-1) {
    container(i) = List[Any]()
  }

  def hash(key: String): Int = (key.hashCode() & 0x7fffffff) % M

  def put(key: String, value: List[Any]): Int = { // return the hash number
    val indice = hash(key)
    container(indice) = value :: container(indice)
    indice
  }
  
  def get(indice: Int): List[Any] = container(indice) 

  def get(key: String): List[Any] = get(hash(key))

  def dataset() = container

}


class HashJoin(list1: List[List[Any]], list2: List[List[Any]]) {

  val _list1 = list1
  val _list2 = list2
  
  def innerHashJion(col: Int): List[Any] = {
    val start = System.currentTimeMillis()
    var keys1 = Set[Int]()
    var keys2 = Set[Int]()

    val sht1 = _list1.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys1 = keys1 + i
      sht
    }
    
    val sht2 = _list2.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys2 = keys2 + i
      sht
    }
    val end = System.currentTimeMillis()
    println("Hash took: " + (end-start) + "ms")
    getJointRecords((keys1&keys2).toArray, sht1, sht2, col)
  }

  def getJointRecords(inds: Array[Int], sht1: SimpleHashTable, sht2: SimpleHashTable, col: Int): List[Any] = {
    println("joint-keys: " + inds.size)
    var ret = scala.collection.immutable.List[Any]()
    inds.par.foreach(ind => { 
      println(Thread.currentThread)
      sht1.get(ind).map(record1 => {
        sht2.get(ind).map(record2 => {
          val r1 = record1.asInstanceOf[List[Any]]
          val r2 = record2.asInstanceOf[List[Any]]
          if (r1(col) == r2(col)) ret = (r1 ::: r2) :: ret
        })
      })
    })
    ret
  }

}

測試可以使用下麵單例:

object HashJoinTest {
  def main(args: Array[String]): Unit = {
    test()
  }

  def test(): Unit = {
    val c1 = List(111, "asfd", 23)
    val c11 = List(111, "asf", 231)
    val c2 = List(333, "e",    1)
    val c3 = List(222, "ewr",  80)

    val t1 = List(111, "e",    40)
    val t11 = List(111, "fge", 30)
    val t2 = List(333, "asfd", 80)
    val t3 = List(444, "e",    1)

    val list1 = List(c1, c11, c2, c3)
    val list2 = List(t1, t11, t2, t3)
    
    val hj = new HashJoin(list1, list2)
    val ret = hj.innerHashJion(2)
    for (i <- (0 to 1)) println(ret(i))
  }
}


優化

上麵的這種實現,在join結果集並發往同一個List()容器裏寫的時候會出現性能瓶頸,寫的速度會達到10W-100W行/s,而且需要在寫的時候加上synchronized實現同步。雖然scala.collection.immutable.List類是不可變的,也是線程安全的,但是在1W join 1W的測試中,0.4s內寫入10W行出現了數據丟失,加上synchronized字段可以簡單避免這個問題,但同時帶來了額外開銷。

下麵新的HashJoin.scala類,為每個需要join的bucket申請了一個數組空間,讓每個線程返回的單個bucket join結果集保存在統一的數組中,最後對結果集進行merge,同時保留了並發求join的特性。

優化HashJoin.scala類之後,測試速度 1W join 1W 隻要 0.1s,2W join 2W 時間是 0.2s-0.4s,(M=991的情況下,M可以調整)

class HashJoin(list1: List[List[Any]], list2: List[List[Any]]) {

  val _list1 = list1
  val _list2 = list2
  val M = 991
  val retContainer = new Array[List[Any]](M)
  for (i <- 0 to M-1) retContainer(i) = List[Any]()
  
  var ret = List[Any]()

  def innerHashJion(col: Int): Unit = {
    val start = System.currentTimeMillis()

    var keys1 = Set[Int]()
    var keys2 = Set[Int]()

    val sht1 = _list1.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys1 = keys1 + i
      sht
    }
    
    val sht2 = _list2.par.foldLeft(new SimpleHashTable) { (sht, list) =>
      val i = sht.put(list(col).toString, list)
      keys2 = keys2 + i
      sht
    }

    val end = System.currentTimeMillis()
    println("Hash took: " + (end-start) + "ms")
    
    val jointKeys = (keys1&keys2).toArray
    println("JointKeys Size: " + jointKeys.size)
    jointKeys.par.foreach(ind => retContainer(ind) = getBucketRecords(ind, sht1, sht2, col))

    def getBucketRecords(ind: Int, sht1: SimpleHashTable, sht2: SimpleHashTable, col: Int): List[Any] = {
      var bucketRet = List[Any]()
      sht1.get(ind).map(record1 => {
        sht2.get(ind).map(record2 => {
          val r1 = record1.asInstanceOf[List[Any]]
          val r2 = record2.asInstanceOf[List[Any]]
          if (r1(col) == r2(col)) bucketRet = (r1 ::: r2) :: bucketRet
        })
      })
      bucketRet
    }
  }

  def getRet: List[Any] = {
    mergeRets
    ret
  }

  def mergeRets = {
    val t1 = System.currentTimeMillis()
    retContainer.foreach({r =>
      ret = r ::: ret
    })
    val t2 = System.currentTimeMillis()
    println("Merge Rets took: " + (t2-t1) + " ms")
  }
}
我的測試單例如下,數據來自mongodb,進行了一次BSON to List的轉換,可以替換掉傳入的list1和list2,傳入自己想要的測試數據:

object HashJoinTest {
  def main(args: Array[String]): Unit = {
    mongo()
  }
   
  def mongo(): Unit = {
    val loadS = System.currentTimeMillis()
    val list1 = BsonToList.getMongoList(0, 10000)
    val list2 = BsonToList.getMongoList(100000, 10000)
    val loadE = System.currentTimeMillis()
    println("Load Data took: " + (loadE-loadS) + "ms")

    val hj = new HashJoin(list1, list2)
    hj.innerHashJion(8)
    val ret = hj.getRet 
    val joinE = System.currentTimeMillis()
    println("HashJoin Totally took: " + (joinE-loadE) + "ms")
    
    println("Result size: " + ret.size)
    for (i <- (0 to 1)) println(ret(i))
  }
}


後續如果有優化結果,還會更新在這裏。

(全文完)

最後更新:2017-04-03 14:54:38

  上一篇:go 時間子係統12_clockevent設備注冊
  下一篇:go 動態設置android:drawableLeft|Right|Top|Bottom