fix: simplify parallel map using channels (#582)

This commit is contained in:
Craig Andrews 2025-07-03 11:43:10 +01:00 committed by GitHub
parent 2ace57404b
commit 73c012c76c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 45 additions and 32 deletions

View file

@ -2,49 +2,39 @@ package util
import (
"strings"
"sync"
)
// MapReducePar performs a parallel map-reduce operation on a slice of items.
// It applies a function to each item in the slice concurrently,
// and combines the results serially using a reducer returned from
// each one of the functions, allowing the use of closures.
func MapReducePar[a, b any](items []a, init b, fn func(a) func(b) b) b {
itemCount := len(items)
locks := make([]*sync.Mutex, itemCount)
mapped := make([]func(b) b, itemCount)
func mapParallel[in, out any](items []in, fn func(in) out) chan out {
mapChans := make([]chan out, 0, len(items))
for i, value := range items {
lock := &sync.Mutex{}
lock.Lock()
locks[i] = lock
for _, v := range items {
ch := make(chan out)
mapChans = append(mapChans, ch)
go func() {
defer lock.Unlock()
mapped[i] = fn(value)
defer close(ch)
ch <- fn(v)
}()
}
result := init
for i := range itemCount {
locks[i].Lock()
defer locks[i].Unlock()
f := mapped[i]
if f != nil {
result = f(result)
}
}
resultChan := make(chan out)
return result
go func() {
defer close(resultChan)
for _, ch := range mapChans {
v := <-ch
resultChan <- v
}
}()
return resultChan
}
// WriteStringsPar allows to iterate over a list and compute strings in parallel,
// yet write them in order.
func WriteStringsPar[a any](sb *strings.Builder, items []a, fn func(a) string) {
MapReducePar(items, sb, func(item a) func(*strings.Builder) *strings.Builder {
str := fn(item)
return func(sbdr *strings.Builder) *strings.Builder {
sbdr.WriteString(str)
return sbdr
}
})
ch := mapParallel(items, fn)
for v := range ch {
sb.WriteString(v)
}
}

View file

@ -0,0 +1,23 @@
package util_test
import (
"strconv"
"strings"
"testing"
"time"
"github.com/sst/opencode/internal/util"
)
func TestWriteStringsPar(t *testing.T) {
items := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
sb := strings.Builder{}
util.WriteStringsPar(&sb, items, func(i int) string {
// sleep for the inverse duration so that later items finish first
time.Sleep(time.Duration(10-i) * time.Millisecond)
return strconv.Itoa(i)
})
if sb.String() != "0123456789" {
t.Fatalf("expected 0123456789, got %s", sb.String())
}
}