Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

algorithm - scala median implementation

What's a fast implementation of median in Scala?

This is what I found on Rosetta Code:

  def median(s: Seq[Double])  =
  {
    val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
    if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
  }

I don't like it because it does a sort. I know there are ways to compute the median in linear time.

EDIT:

I would like to have a set of median functions that I can use in various scenarios:

  1. fast, in place median computation that can be done in linear time
  2. median that works on a stream that you can traverse multiple times, but you can only keep O(log n) values in memory like this
  3. median that works on a stream, where you can hold at most O(log n) values in memory, and you can traverse the stream at most once (is this even possible?)

Please only post code that compiles and correctly computes the median. For simplicity, you may assume that all inputs contain an odd number of values.

question from:https://stackoverflow.com/questions/4662292/scala-median-implementation


1 Answer


Immutable Algorithm

The first algorithm indicated by Taylor Leese is quadratic in the worst case but linear on average. That, however, depends on the pivot selection. So I'm providing here a version with pluggable pivot selection, along with both a random pivot and a median of medians pivot (which guarantees linear time).

import scala.annotation.tailrec

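// Returns the k-th smallest element (0-based) of arr, using choosePivot to pick each pivot
@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    // s: elements smaller than the pivot; b: the pivot and everything not smaller than it
    val (s, b) = arr.partition(_ < a)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        // The pivot is the minimum: split off its duplicates so the input shrinks
        val (equal, larger) = arr.partition(_ == a)
        if (equal.size > k) a
        else findKMedian(larger, k - equal.size)
    } else if (s.size < k) findKMedian(b, k - s.size)
    else findKMedian(s, k)
}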

def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)

Random Pivot (quadratic, linear average), Immutable

This is the random pivot selection. Analysis of algorithms with random factors is trickier than normal, because it deals largely with probability and statistics.

def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
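
To make the pluggable-pivot idea concrete, here is a minimal self-contained sketch. It is not the code above: it uses a three-way split (smaller, equal, larger) on an immutable Vector, passes the pivot chooser as an explicit argument rather than an implicit one, and assumes the input contains no NaN. The names PluggableSelect, Pivot, randomPivot and firstElement are made up for illustration.

```scala
import scala.annotation.tailrec
import scala.util.Random

object PluggableSelect {
  // Hypothetical alias: a pivot chooser picks one element of a non-empty vector
  type Pivot = Vector[Double] => Double

  // Returns the k-th smallest element (0-based); assumes 0 <= k < xs.size and no NaN
  @tailrec
  def select(xs: Vector[Double], k: Int)(choosePivot: Pivot): Double = {
    val pivot = choosePivot(xs)
    val smaller = xs.filter(_ < pivot)
    val equalCount = xs.count(_ == pivot) // at least 1, because the pivot comes from xs
    if (k < smaller.size) select(smaller, k)(choosePivot)
    else if (k < smaller.size + equalCount) pivot
    else select(xs.filter(_ > pivot), k - smaller.size - equalCount)(choosePivot)
  }

  // Two interchangeable pivot choosers
  def randomPivot(rng: Random): Pivot = xs => xs(rng.nextInt(xs.size))
  val firstElement: Pivot = _.head

  // Lower median, matching the (size - 1) / 2 convention used in this answer
  def median(xs: Seq[Double])(choosePivot: Pivot): Double = {
    require(xs.nonEmpty, "median of an empty sequence is undefined")
    select(xs.toVector, (xs.size - 1) / 2)(choosePivot)
  }
}
```

Any chooser gives the same result; only the running time differs. For example, PluggableSelect.median(Seq(5.0, 1.0, 4.0, 2.0, 3.0))(PluggableSelect.randomPivot(new Random(42L))) and the same call with PluggableSelect.firstElement both return 3.0.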

Median of Medians (linear), Immutable

The median of medians method guarantees linear time when used with the algorithm above. First, an algorithm to compute the median of up to 5 numbers, which is the basis of the median of medians algorithm. This one was provided by Rex Kerr in this answer; the overall algorithm depends a lot on its speed.

def medianUpTo5(five: Array[Double]): Double = {
  def order2(a: Array[Double], i: Int, j: Int) = {
    if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }

  def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i)<a(k)) { order2(a,j,k); a(j) }
    else { order2(a,i,l); a(i) }
  }

  if (five.length < 2) return five(0)
  order2(five,0,1)
  if (five.length < 4) return (
    if (five.length==2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five,2,3)
  if (five.length < 5) pairs(five,0,1,2,3)
  else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
  else { order2(five,3,4); pairs(five,0,1,3,4) }
}

And then the median of medians algorithm itself. Basically, it guarantees that the chosen pivot will be greater than at least 30% of the list and smaller than another 30%, which is enough to guarantee the linearity of the previous algorithm. Look up the Wikipedia link provided in another answer for details.

def medianOfMedians(arr: Array[Double]): Double = {
    // Median of each group of up to 5 elements (grouped copies, so arr itself is untouched)
    val medians = arr.grouped(5).map(medianUpTo5).toArray
    if (medians.size <= 5) medianUpTo5(medians)
    else medianOfMedians(medians)
}
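
For comparison, here is a sketch of the textbook formulation (usually credited to Blum, Floyd, Pratt, Rivest and Tarjan), in which the median of the group medians is found exactly by a recursive call to the selection routine itself. This differs from medianOfMedians above, which keeps reducing groups of five until at most five values remain. The sketch works on immutable Vectors, sorts the small groups instead of using medianUpTo5, and assumes no NaN in the input; TextbookSelect and its members are names made up for illustration.

```scala
object TextbookSelect {
  // Lower median of a small group, found by sorting (groups hold at most 5 values)
  private def smallMedian(group: Vector[Double]): Double =
    group.sorted.apply((group.size - 1) / 2)

  // Returns the k-th smallest element (0-based); assumes 0 <= k < xs.size and no NaN
  def select(xs: Vector[Double], k: Int): Double =
    if (xs.size <= 5) xs.sorted.apply(k)
    else {
      val medians = xs.grouped(5).map(smallMedian).toVector
      // Exact median of the group medians, via recursive selection
      val pivot = select(medians, (medians.size - 1) / 2)
      val smaller = xs.filter(_ < pivot)
      val equalCount = xs.count(_ == pivot)
      if (k < smaller.size) select(smaller, k)
      else if (k < smaller.size + equalCount) pivot
      else select(xs.filter(_ > pivot), k - smaller.size - equalCount)
    }

  // Lower median, matching the (size - 1) / 2 convention used in this answer
  def median(xs: Seq[Double]): Double = {
    require(xs.nonEmpty, "median of an empty sequence is undefined")
    select(xs.toVector, (xs.size - 1) / 2)
  }
}
```

With an exact median of the group medians as the pivot, the running time satisfies T(n) <= T(n/5) + T(7n/10) + O(n), and because 1/5 + 7/10 < 1 this solves to O(n). That is where the 30% figure comes from in the textbook analysis.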

In-place Algorithm

So, here's an in-place version of the algorithm. I'm using a class that implements a partition in-place, with a backing array, so that the changes to the algorithms are minimal.

case class ArrayView(arr: Array[Double], from: Int, until: Int) {
    def apply(n: Int) = 
        if (from + n < until) arr(from + n)
        else throw new ArrayIndexOutOfBoundsException(n)

    def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
      var upper = until - 1
      var lower = from
      while (lower < upper) {
        while (lower < until && p(arr(lower))) lower += 1
        while (upper >= from && !p(arr(upper))) upper -= 1
        if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
      }
      (copy(until = lower), copy(from = lower))
    }

    def size = until - from
    def isEmpty = size <= 0

    override def toString = arr mkString ("ArrayView(", ", ", ")")
}; object ArrayView {
    def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}

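// Same selection as findKMedian, but rearranging the backing array instead of copying
@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
    val a = choosePivot(arr)
    // s: view of the elements smaller than the pivot; b: view of the rest
    val (s, b) = arr.partitionInPlace(_ < a)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        // The pivot is the minimum: split off its duplicates so the view shrinks
        val (equal, larger) = arr.partitionInPlace(_ == a)
        if (equal.size > k) a
        else findKMedianInPlace(larger, k - equal.size)
    } else if (s.size < k) findKMedianInPlace(b, k - s.size)
    else findKMedianInPlace(s, k)
}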

def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)

Random Pivot, In-place

I'm only implementing the random pivot for the in-place algorithm, as the median of medians would require more support than what is presently provided by the ArrayView class I defined.

def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
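
As a point of comparison, in-place selection can also be sketched without a view class, by tracking index bounds directly. The version below is an assumption-laden illustration rather than the ArrayView approach above: it uses a Lomuto-style partition with a seeded random pivot, it reorders the caller's array, and it can go quadratic on inputs with many duplicates. InPlaceSelect and its members are hypothetical names.

```scala
import scala.util.Random

object InPlaceSelect {
  private def swap(a: Array[Double], i: Int, j: Int): Unit = {
    val t = a(i); a(i) = a(j); a(j) = t
  }

  // Lomuto partition of a(lo..hi) (inclusive) around a(pivotIndex).
  // Returns the pivot's final index; smaller elements end up to its left.
  private def partition(a: Array[Double], lo: Int, hi: Int, pivotIndex: Int): Int = {
    val pivot = a(pivotIndex)
    swap(a, pivotIndex, hi) // park the pivot at the end
    var store = lo
    var i = lo
    while (i < hi) {
      if (a(i) < pivot) { swap(a, i, store); store += 1 }
      i += 1
    }
    swap(a, store, hi) // move the pivot into its final place
    store
  }

  // Returns the k-th smallest element (0-based), reordering `a` as a side effect
  def select(a: Array[Double], k: Int, rng: Random = new Random(0L)): Double = {
    require(k >= 0 && k < a.length, "k out of range")
    var lo = 0
    var hi = a.length - 1
    // Invariant: lo <= k <= hi, so the loop ends with lo == hi == k
    while (lo < hi) {
      val p = partition(a, lo, hi, lo + rng.nextInt(hi - lo + 1))
      if (p == k) return a(k)
      else if (p < k) lo = p + 1
      else hi = p - 1
    }
    a(lo)
  }

  // Lower median, matching the (size - 1) / 2 convention used in this answer
  def median(a: Array[Double]): Double = select(a, (a.length - 1) / 2)
}
```

Calling InPlaceSelect.median(Array(5.0, 1.0, 4.0, 2.0, 3.0)) returns 3.0 and leaves the array permuted; pass a copy if the original order matters.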

Histogram Algorithm (O(log(n)) memory), Immutable

So, about streams. It is impossible to use less than O(n) memory for a stream that can only be traversed once, unless you happen to know what the stream length is (in which case it ceases to be a stream in my book).

Using buckets is also a bit problematic, but if we can traverse it multiple times, then we can know its size, maximum and minimum, and work from there. For example:

def findMedianHistogram(s: Traversable[Double]) = {
    def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
        // The buckets
        def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
        val buckets = new Array[Int](numberOfBuckets)

        // The upper limit of each bucket
        val max = s.max
        val min = s.min
        val increment = (max - min) / numberOfBuckets
        val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

        // Return the bucket a number is supposed to be in
        def bucketIndex(d: Double) = indices.indexWhere(d <= _)

        // Compute how many in each bucket
        s foreach { d => buckets(bucketIndex(d)) += 1 }

        // Now make the buckets cumulative
        val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)

        // The bucket where our target is at
        val medianBucket = partialTotals.indexWhere(medianIndex < _)

        // Keep track of how many numbers there are that are less 
        // than the median bucket
        val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

        // Test whether a number is in the median bucket
        def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

        // Get a view of the target bucket
        val view = s.view filter insideMedianBucket

        // If all numbers in the bucket are equal, return that
        if (view.forall(_ == view.head)) view.head
        // Otherwise, recurse on that bucket
        else medianHistogram(view, newDiscarded, medianIndex)
    }

    medianHistogram(s, 0, (s.size - 1) / 2)
}
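
To show what a single histogram pass does, here is a small stand-alone sketch. It is a simplification and not the function above: the bucket count is passed in, buckets are indexed arithmetically from the minimum instead of by searching upper limits, and NaN is assumed absent. It returns the bucket that holds the element of a given rank, plus how many values fall in earlier buckets, which is the information the recursion needs to narrow down to one bucket. HistogramPass and locate are hypothetical names.

```scala
object HistogramPass {
  // Returns (index of the bucket containing the element of rank targetRank,
  //          number of values that fall in earlier buckets)
  def locate(values: Iterable[Double], buckets: Int, targetRank: Int): (Int, Int) = {
    require(buckets > 0, "need at least one bucket")
    require(targetRank >= 0 && targetRank < values.size, "rank out of range")
    // Extra traversals to find the value range
    val lo = values.reduce(_ min _)
    val hi = values.reduce(_ max _)
    val width = (hi - lo) / buckets
    // Equal-width buckets; the maximum is clamped into the last bucket
    def bucketOf(d: Double): Int =
      if (width == 0.0) 0 else math.min(((d - lo) / width).toInt, buckets - 1)
    // Counting traversal: only `buckets` counters are kept in memory
    val counts = new Array[Int](buckets)
    values.foreach(d => counts(bucketOf(d)) += 1)
    // Walk the cumulative counts until the target rank is covered
    var below = 0
    var b = 0
    while (below + counts(b) <= targetRank) { below += counts(b); b += 1 }
    (b, below)
  }
}
```

For the values 1.0 to 9.0 with 3 buckets and target rank 4 (the median), locate returns (1, 3): the median sits in the middle bucket, and three values lie below that bucket. A next pass would then look only at the values in that bucket.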

Test and Benchmark

To test the algorithms, I'm using ScalaCheck and comparing the output of each algorithm to the output of a trivial implementation that sorts. That assumes the sorting version is correct, of course.

I'm benchmarking each of the above algorithms with all provided pivot selections, plus a fixed pivot selection (halfway through the array, rounded down). Each algorithm is benchmarked with three different input array sizes, three times for each size.

Here's the testing code:

import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._

def test(algorithm: Array[Double] => Double, 
         reference: Array[Double] => Double): String = {
    def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
    val resultEqualsReference = forAll { (arr: Array[Double]) => 
        arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
    }
    Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}

import java.lang.System.currentTimeMillis

def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

import scala.util.Random.nextDouble

def benchmark(algorithm: Array[Double] => Double,
              arraySizes: List[Int]): List[Iterable[Long]] = 
    for (size <- arraySizes)
    yield for (iteration <- 1 to 3)
        yield bench(50000)(algorithm(Array.fill(size)(nextDouble())))

def testAndBenchmark: String = {
    val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
        "Random Pivot"      -> chooseRandomPivot,
        "Median of Medians" -> medianOfMedians,
        "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
    )
    val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
        "Random Pivot (in-place)" -> chooseRandomPivotInPlace,
        "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
    )
    val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
        yield name -> (findMedian(_: Array[Double])(pivotSelection))
    val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
        yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
    val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
    val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
    val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms

    val formattingString = "%%-%ds  %%s".format(algorithms.map(_._1.length).max)

    // Tests
    val testResults = for ((name, algorithm) <- algorithms)
        yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

    // Benchmarks
    val arraySizes = List(100, 500, 1000)
    def formatResults(results: List[Long]) = results.map("%8d".format(_)).mkString

    val benchmarkResults: List[String] = for {
        (name, algorithm) <- algorithms
        results <- benchmark(algorithm, arraySizes).transpose
    } yield formattingString format (name, formatResults(results))

    val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))

    "Tests" :: "*****" :: testResults :::
    ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}
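
The compare-against-sorting idea can also be sketched without ScalaCheck, using only scala.util.Random. This is a rough stand-in rather than a replacement for the property test above: there is no shrinking, the inputs are small integer-valued doubles so that duplicates show up often, and the seed is fixed for repeatability. RandomizedCheck, referenceMedian and firstFailure are names made up for illustration.

```scala
import scala.util.Random

object RandomizedCheck {
  // Trivial reference: sort a copy and take the lower median
  def referenceMedian(xs: Array[Double]): Double = xs.sorted.apply((xs.length - 1) / 2)

  // Runs `trials` random non-empty inputs and returns the first one on which
  // `algorithm` disagrees with the reference, or None if they always agree
  def firstFailure(algorithm: Array[Double] => Double,
                   trials: Int = 100,
                   maxSize: Int = 50,
                   seed: Long = 1L): Option[Array[Double]] = {
    val rng = new Random(seed)
    val inputs = Iterator.fill(trials) {
      Array.fill(1 + rng.nextInt(maxSize))(rng.nextInt(20).toDouble)
    }
    // The algorithm gets a copy, so in-place implementations cannot disturb the reference
    inputs.find(arr => algorithm(arr.clone()) != referenceMedian(arr))
  }
}
```

A correct implementation should yield None, for example RandomizedCheck.firstFailure(RandomizedCheck.referenceMedian). A deliberately wrong one, such as taking the maximum, should produce a counterexample array that can be printed with mkString.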

Results

Tests:

Tests
*****
Sorting                OK, passed 100 tests.
Histogram              OK, passed 100 tests.
Random Pivot           OK, passed 100 tests.
Median of Medians      OK, passed 100 tests.
Midpoint               OK, passed 100 tests.
Random Pivot (
