Searching for RH Counterexamples in Golang - Improving Performance via Parallel Processing

Performance

Although I had promised in the previous post to follow the original series closely, a conversation with a friend prompted me to check whether parallel processing could improve the current performance of both the Python and Golang implementations. So I am taking a small detour to explore that before coming back to the database implementation in the next post.

In Python, numba supports parallel loops through the parallel=True option of its njit decorator, combined with prange. Golang has native support for goroutines, which make concurrency and parallel processing straightforward. In this post I will try out both approaches to see how they fare.

I will not be rigorously following benchmarking techniques to get the most precise results. The focus will be mostly on the approaches rather than the exact, reproducible results.

Status Quo

Before we start, let us revisit where we left off in both Python and Golang. I have also added a longer 20M test so that each run lasts long enough to reduce the impact of startup and teardown times.

| Implementation | 10M (s) | 20M (s) |
|---|---|---|
| Python | ~25 | ~60 |
| Golang | ~15 | ~40 |

We can also take a look at the code that gave the above results:

We start with the DivisorSum function. In Python, we added JIT compilation using numba's njit decorator, which keeps its performance competitive with the compiled Go code.

import math

from numba import njit

@njit
def divisor_sum(n: int) -> int:
    '''Compute the sum of divisors of a positive integer.'''
    if n <= 0:
        return 0

    i = 1
    limit = math.sqrt(n)
    the_sum = 0
    while i <= limit:
        if n % i == 0:
            the_sum += i
            if i != limit:
                the_sum += n // i
        i += 1

    return the_sum

This is the equivalent Golang code.

import (
	"errors"
	"math"
)

func DivisorSum(n int64) (int64, error) {
	if n <= 0 {
		return 0, errors.New("value cannot be less than or equal to zero")
	}
	limit := math.Sqrt(float64(n))
	sum := int64(0)

	for i := int64(1); float64(i) <= limit; i++ {
		if n%i == 0 {
			sum += i
			if float64(i) != limit {
				sum += n / i
			}
		}
	}
	return sum, nil
}
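The Go version compares a float64 loop counter against math.Sqrt(n). That is fine for the ranges in this post, but the same divisor-pairing idea can stay entirely in integer arithmetic by looping while i*i <= n. The sketch below is my own variant, not the code behind the timings; the name divisorSumInt is hypothetical, and it returns 0 for n <= 0 like the Python version rather than an error.

```go
package main

// divisorSumInt returns sigma(n), the sum of all positive divisors of n.
// Hypothetical integer-only variant of DivisorSum: each divisor i <= sqrt(n)
// is paired with its cofactor n/i, and the loop bound is i*i <= n instead
// of a floating-point square root. Returns 0 for n <= 0.
func divisorSumInt(n int64) int64 {
	if n <= 0 {
		return 0
	}
	var sum int64
	for i := int64(1); i*i <= n; i++ {
		if n%i == 0 {
			sum += i
			// Add the paired divisor, unless i is the exact square root.
			if j := n / i; j != i {
				sum += j
			}
		}
	}
	return sum
}
```

For example, divisorSumInt(12) adds the pairs (1, 12), (2, 6) and (3, 4) for a total of 28, while divisorSumInt(16) counts the square root 4 only once.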

Next is the Python wrapper, which evaluates divisor_sum serially across the whole range to find the best witness.

def witness_value(n: int) -> float:
    denominator = n * math.log(math.log(n))
    return divisor_sum(n) / denominator

def best_witness(max_range: int, search_start: int = SEARCH_START) -> int:
    return max(range(search_start, max_range), key=witness_value)

The Golang implementation is just as straightforward:

func BestWitness(maxRange, searchStart int64) (int64, float64) {
	maxVal := 0.0
	bestWitness := searchStart
	for i := searchStart; i < maxRange; i++ {
		currentWitness := WitnessValue(i, -1)
		if currentWitness > maxVal {
			bestWitness = i
			maxVal = currentWitness
		}
	}
	return bestWitness, maxVal
}
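WitnessValue itself is not shown in this post, and the meaning of its second argument (-1 above) is not covered here, so the following is only a sketch that mirrors the Python witness_value and best_witness in Go. The names sigma, witnessValue and bestWitnessSerial are hypothetical.

```go
package main

import "math"

// sigma returns the sum of the positive divisors of n (assumes n >= 1).
func sigma(n int64) int64 {
	var sum int64
	for i := int64(1); i*i <= n; i++ {
		if n%i == 0 {
			sum += i
			if j := n / i; j != i {
				sum += j
			}
		}
	}
	return sum
}

// witnessValue mirrors the Python witness_value: sigma(n) / (n * ln(ln(n))).
// Only meaningful for n >= 3, where ln(ln(n)) > 0.
func witnessValue(n int64) float64 {
	return float64(sigma(n)) / (float64(n) * math.Log(math.Log(float64(n))))
}

// bestWitnessSerial mirrors best_witness: it scans [searchStart, maxRange)
// and returns the n with the largest witness value. Like Python's max with
// a key function, it keeps the first n on ties.
func bestWitnessSerial(maxRange, searchStart int64) int64 {
	best, bestVal := searchStart, math.Inf(-1)
	for n := searchStart; n < maxRange; n++ {
		if v := witnessValue(n); v > bestVal {
			best, bestVal = n, v
		}
	}
	return best
}
```

With these definitions, bestWitnessSerial(60_000, 11_000) should return 55440, the same answer the 20M test expects, since 55440 falls inside the smaller range.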

Finally, the tests in each language.

Python:

def test_best_witness_million():
    start_time = datetime.datetime.now()
    output = best_witness(20_000_000, search_start = 11000)
    end_time = datetime.datetime.now()
    print("Total time: ", end_time - start_time)
    assert 55440 == output

Golang:

	It("should find best witness successfully", func() {
		countTill := int64(20_000_000)

		output, witnessVal := riemann.BestWitness(countTill, 11000)
		fmt.Println("\nCurrent Best till", humanize.Comma(countTill), "is", output, "at value", witnessVal)

		Expect(output).To(Equal(int64(55440)))
	})

Parallelization

Python

For Python, numba makes parallelizing easy: we add the parallel=True argument to the njit decorator and replace range with prange in the loop we want to run in parallel. However, the parallel iterations do not run in any particular order, so we can no longer use max with a key function as we did previously. Instead, we preallocate an array and let each iteration write its result into its own slot independently. At the end of the calculation, np.argmax returns the index of the largest value, and since each index is the number itself, that index is the best witness.

The parallel version of best_witness is given below:

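import numpy as np
from numba import njit, prange

# witness_value (and the divisor_sum it calls) must also be decorated with @njit,
# because nopython-mode code cannot call plain Python functions.
@njit(parallel=True)
def best_witness(max_range: int, search_start: int = SEARCH_START) -> int:
    # One slot per number; slots below search_start stay 0 and never win.
    witness_array = np.zeros(max_range)
    for i in prange(search_start, max_range):
        witness_array[i] = witness_value(i)
    return np.argmax(witness_array)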

With the above parallel code, the 20M test ran in ~20 seconds on my MacBook Air (M1). That is 3x faster than the serial version (~60 seconds)!
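Go has no prange, but numba's preallocate-then-argmax pattern translates directly: each worker fills its own disjoint set of slots in a shared slice, so no locks or channels are needed, and a single serial pass finds the maximum at the end. A minimal sketch, assuming the scoring function (a hypothetical score parameter standing in for WitnessValue) is safe to call from several goroutines at once; argmaxParallel is also a hypothetical name.

```go
package main

import "sync"

// argmaxParallel evaluates score(n) for every n in [start, end) using a
// fixed number of workers and a preallocated slice. Worker w writes only
// indices w, w+workers, w+2*workers, ..., so no two goroutines touch the
// same element. Assumes start < end. Returns the n with the largest score;
// ties go to the smallest n, as with np.argmax.
func argmaxParallel(start, end int64, workers int, score func(int64) float64) int64 {
	values := make([]float64, end-start)
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func(offset int64) {
			defer wg.Done()
			for i := offset; i < int64(len(values)); i += int64(workers) {
				values[i] = score(start + i)
			}
		}(int64(w))
	}
	wg.Wait()

	// Serial argmax over the filled slice; strict > keeps the first maximum.
	best := 0
	for i, v := range values {
		if v > values[best] {
			best = i
		}
	}
	return start + int64(best)
}
```

Calling it with runtime.NumCPU() workers and a witness-value function as score would be a rough Go counterpart of the numba version. Like np.zeros, it costs one float64 (8 bytes) per number scanned.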

Golang

In Golang, parallelizing the code takes a little more effort. First, we define a witness struct to hold each result.

type witness struct {
	idx   int64
	value float64
}

We make a buffered channel of the witness type with a buffer size equal to the number of iterations:

    allValues := make(chan witness, maxRange-searchStart)

Next, we create a sync.WaitGroup and add a count equal to the number of iterations, so that we only proceed once all of them are done.

	var wg sync.WaitGroup
	wg.Add(int(maxRange - searchStart))

Now we iterate over the range and start one goroutine per number to compute its witness value. Each goroutine pushes its result to the channel once it has computed it, and uses defer to call wg.Done(), which decrements the WaitGroup counter when the goroutine returns. The loop variable i is passed in as the argument j so that each goroutine gets its own copy; before Go 1.22, all iterations of a for loop shared a single loop variable.

	for i := searchStart; i < maxRange; i++ {
		go func(j int64) {
			defer wg.Done()
			currentWitness := WitnessValue(j, -1)
			allValues <- witness{j, currentWitness}
		}(i)
	}

We wait for all the goroutines to finish with wg.Wait() and then close the channel. Then we can range over the channel to find the number with the maximum witness value.

	wg.Wait()
	close(allValues)
	bestWitness := int64(0)
	maxVal := -1.0
	for val := range allValues {
		if val.value > maxVal {
			maxVal = val.value
			bestWitness = val.idx
		}
	}
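The buffer has to be as large as the whole range here: nothing reads from allValues until wg.Wait() returns, so with a smaller buffer the goroutines would block on their sends, wg.Wait() would never return, and the program would deadlock. A common alternative, sketched below with hypothetical names (collectMax, score), closes the channel from a separate goroutine once the WaitGroup drains. The caller can then consume results while they are produced, and an unbuffered channel is enough. It still starts one goroutine per number, so it keeps that per-goroutine cost.

```go
package main

import (
	"math"
	"sync"
)

// witness mirrors the struct used in BestWitness.
type witness struct {
	idx   int64
	value float64
}

// collectMax is a hypothetical variant of the naive BestWitness: still one
// goroutine per number, but results flow through an unbuffered channel that
// a helper goroutine closes once every worker has called wg.Done().
func collectMax(start, end int64, score func(int64) float64) witness {
	results := make(chan witness) // unbuffered: each send waits for the reader below
	var wg sync.WaitGroup
	for n := start; n < end; n++ {
		wg.Add(1)
		go func(j int64) {
			defer wg.Done()
			results <- witness{j, score(j)}
		}(n)
	}

	// Close the channel only after all senders finish, so the range loop ends.
	go func() {
		wg.Wait()
		close(results)
	}()

	// Consume results as they arrive. Ties are resolved by arrival order,
	// which is not deterministic.
	best := witness{idx: -1, value: math.Inf(-1)}
	for r := range results {
		if r.value > best.value {
			best = r
		}
	}
	return best
}
```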

Here is the final function in all its glory:

func BestWitness(maxRange, searchStart int64) (int64, float64) {
	allValues := make(chan witness, maxRange-searchStart)
	var wg sync.WaitGroup
	wg.Add(int(maxRange - searchStart))
	for i := searchStart; i < maxRange; i++ {
		go func(j int64) {
			defer wg.Done()
			currentWitness := WitnessValue(j, -1)
			allValues <- witness{j, currentWitness}
		}(i)
	}
	wg.Wait()
	close(allValues)
	bestWitness := int64(0)
	maxVal := -1.0
	for val := range allValues {
		if val.value > maxVal {
			maxVal = val.value
			bestWitness = val.idx
		}
	}
	return bestWitness, maxVal
}

The above function took almost 40 seconds to run, roughly the same as the serial version! All the cores were busy, but creating 20 million goroutines, each doing only a tiny amount of work, and funneling 20 million results through a single channel added enough overhead to cancel out the gains. On top of that, both the loop that spawns the goroutines and the final pass over the channel are serial, and by Amdahl's law that serial work caps the overall speedup. We will have to do better!

Parallel Optimization

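In the next iteration of the code, I limit the parallelism to a fixed number of goroutines, 100 in this case, and split the work evenly among them: goroutine j handles searchStart+j, searchStart+j+100, searchStart+j+200, and so on. Each goroutine tracks its own best witness and sends only that single result to the channel, so the channel carries 100 values instead of 20 million.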

func BestWitness(maxRange, searchStart int64) (int64, float64) {
	parallelism := int64(100)
	allValues := make(chan witness, parallelism)
	var wg sync.WaitGroup
	wg.Add(int(parallelism))
	for i := int64(0); i < parallelism; i++ {
		go func(j int64) {
			defer wg.Done()
			maxVal := 0.0
			bestWitness := searchStart
			for k := searchStart + j; k < maxRange; k += parallelism {
				currentWitness := WitnessValue(k, -1)
				if currentWitness > maxVal {
					maxVal = currentWitness
					bestWitness = k
				}
			}
			allValues <- witness{bestWitness, maxVal}
		}(i)
	}
	wg.Wait()
	close(allValues)
	bestWitness := int64(0)
	maxVal := -1.0
	for val := range allValues {
		if val.value > maxVal {
			maxVal = val.value
			bestWitness = val.idx
		}
	}
	return bestWitness, maxVal
}

Now the test runs in ~10 seconds, about 4x faster than the serial version!
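Why hand out numbers with a stride instead of in contiguous blocks? DivisorSum loops up to sqrt(n), so larger numbers cost more. With contiguous blocks, the last goroutine would get the most expensive numbers, while striding gives every goroutine a nearly identical mix. The sketch below only models this, assuming a cost of sqrt(n) per number; stridedLoads, contiguousLoads and imbalance are hypothetical helpers, not part of the code above. With 100 goroutines on 8 cores the Go scheduler can also smooth out some imbalance, so treat this as an illustration rather than a measurement.

```go
package main

import "math"

// cost models the work for one number (an assumption): DivisorSum loops
// up to sqrt(n), so we charge sqrt(n) steps.
func cost(n int64) float64 { return math.Sqrt(float64(n)) }

// stridedLoads returns the modeled work per worker when worker j handles
// start+j, start+j+workers, ... (the split used in BestWitness above).
func stridedLoads(start, end int64, workers int) []float64 {
	loads := make([]float64, workers)
	for n := start; n < end; n++ {
		loads[(n-start)%int64(workers)] += cost(n)
	}
	return loads
}

// contiguousLoads returns the modeled work per worker when the range is cut
// into equal contiguous blocks instead; the last worker takes any remainder.
// Assumes end-start >= workers.
func contiguousLoads(start, end int64, workers int) []float64 {
	loads := make([]float64, workers)
	size := (end - start) / int64(workers)
	for n := start; n < end; n++ {
		w := (n - start) / size
		if w >= int64(workers) {
			w = int64(workers) - 1
		}
		loads[w] += cost(n)
	}
	return loads
}

// imbalance is max(load)/min(load); 1.0 means perfectly even.
func imbalance(loads []float64) float64 {
	lo, hi := loads[0], loads[0]
	for _, l := range loads[1:] {
		lo = math.Min(lo, l)
		hi = math.Max(hi, l)
	}
	return hi / lo
}
```

For the range [11000, 1011000) split 4 ways, this model gives a max-to-min ratio of essentially 1.0 for the strided split and about 2.7 for contiguous blocks.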

Revisiting Python

We can now refactor the Python implementation using the same optimization strategy.

@njit(parallel=True)
def best_witness(max_range: int, search_start: int = SEARCH_START) -> int:
    parallelism = 100
    witness_array = np.zeros(max_range)
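    for i in prange(parallelism):
        # Each of the 100 parallel tasks scans every 100th number, starting at search_start + i.
        current_best_witness = -1.0
        for k in range(search_start + i, max_range, parallelism):
            witness = witness_value(k)
            # Only store values that beat this task's best so far; argmax needs each task's maximum.
            if witness > current_best_witness:
                current_best_witness = witness
                witness_array[k] = witness
    return np.argmax(witness_array)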

This optimized implementation takes ~15 seconds to run, which is 4x faster than the serial version and about 1.3x faster than the naive parallel version!

big.LITTLE

I was initially puzzled that multi-core processing gave only a ~4x speedup, rather than something closer to 7-8x with the 8 cores in my MacBook. It turns out the M1 has a big.LITTLE-style design with 4 high-performance cores and 4 efficiency cores that are considerably slower. Most of the throughput therefore comes from the 4 performance cores, so a speedup much beyond 4x is unlikely, and the ~4x we got in both Python and Golang is about what we should expect.
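With cores of different speeds, a fixed split means whatever runs on the efficiency cores finishes last. One common way to hedge against that, which is not what the timings above used, is a work queue: one worker per logical CPU (runtime.NumCPU()) keeps pulling small batches from a channel, so faster cores naturally take more batches. The Go scheduler already moves goroutines between threads, so the gain over 100 static goroutines may be small. A minimal sketch; bestDynamic, batchSize and score are hypothetical names.

```go
package main

import (
	"math"
	"runtime"
	"sync"
)

type result struct {
	idx   int64
	value float64
}

// bestDynamic finds the n in [start, end) that maximizes score(n). The range
// is cut into batches of batchSize numbers; one worker per logical CPU pulls
// batch start points from a channel until it is drained, so a worker on a
// faster core simply processes more batches. Ties go to the smallest n.
func bestDynamic(start, end, batchSize int64, score func(int64) float64) result {
	workers := runtime.NumCPU()
	batches := make(chan int64)           // each value is the first n of a batch
	results := make(chan result, workers) // buffered: one final send per worker

	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			local := result{idx: -1, value: math.Inf(-1)}
			for lo := range batches {
				hi := lo + batchSize
				if hi > end {
					hi = end
				}
				for n := lo; n < hi; n++ {
					if v := score(n); v > local.value {
						local = result{n, v}
					}
				}
			}
			results <- local
		}()
	}

	// Feed batch start points, then signal that no more work is coming.
	for lo := start; lo < end; lo += batchSize {
		batches <- lo
	}
	close(batches)
	wg.Wait()
	close(results)

	best := result{idx: -1, value: math.Inf(-1)}
	for r := range results {
		if r.value > best.value || (r.value == best.value && r.idx < best.idx) {
			best = r
		}
	}
	return best
}
```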

Conclusion

Here are the final results (20M test) of the different approaches we tried out:

| Implementation | Serial (s) | Parallel (s) | Parallel, optimized (s) |
|---|---|---|---|
| Python | ~60 | ~20 | ~15 |
| Golang | ~40 | ~40 | ~10 |

Looks like we were able to improve both implementations by about 4x, which I suppose is close to the maximum speedup we can expect on this machine. With the best implementations, Golang still takes about a third less time than Python (~10 s vs ~15 s). However, the Python implementation was far simpler than the Golang one, which is always an aspect (probably the most important one) we need to consider before chasing benchmarks.