mirror of
https://github.com/sst/opencode.git
synced 2025-08-04 13:30:52 +00:00
fix: simplify parallel map using channels (#582)
This commit is contained in:
parent
2ace57404b
commit
73c012c76c
2 changed files with 45 additions and 32 deletions
|
@ -2,49 +2,39 @@ package util
|
|||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MapReducePar performs a parallel map-reduce operation on a slice of items.
|
||||
// It applies a function to each item in the slice concurrently,
|
||||
// and combines the results serially using a reducer returned from
|
||||
// each one of the functions, allowing the use of closures.
|
||||
func MapReducePar[a, b any](items []a, init b, fn func(a) func(b) b) b {
|
||||
itemCount := len(items)
|
||||
locks := make([]*sync.Mutex, itemCount)
|
||||
mapped := make([]func(b) b, itemCount)
|
||||
func mapParallel[in, out any](items []in, fn func(in) out) chan out {
|
||||
mapChans := make([]chan out, 0, len(items))
|
||||
|
||||
for i, value := range items {
|
||||
lock := &sync.Mutex{}
|
||||
lock.Lock()
|
||||
locks[i] = lock
|
||||
for _, v := range items {
|
||||
ch := make(chan out)
|
||||
mapChans = append(mapChans, ch)
|
||||
go func() {
|
||||
defer lock.Unlock()
|
||||
mapped[i] = fn(value)
|
||||
defer close(ch)
|
||||
ch <- fn(v)
|
||||
}()
|
||||
}
|
||||
|
||||
result := init
|
||||
for i := range itemCount {
|
||||
locks[i].Lock()
|
||||
defer locks[i].Unlock()
|
||||
f := mapped[i]
|
||||
if f != nil {
|
||||
result = f(result)
|
||||
}
|
||||
}
|
||||
resultChan := make(chan out)
|
||||
|
||||
return result
|
||||
go func() {
|
||||
defer close(resultChan)
|
||||
for _, ch := range mapChans {
|
||||
v := <-ch
|
||||
resultChan <- v
|
||||
}
|
||||
}()
|
||||
|
||||
return resultChan
|
||||
}
|
||||
|
||||
// WriteStringsPar allows to iterate over a list and compute strings in parallel,
|
||||
// yet write them in order.
|
||||
func WriteStringsPar[a any](sb *strings.Builder, items []a, fn func(a) string) {
|
||||
MapReducePar(items, sb, func(item a) func(*strings.Builder) *strings.Builder {
|
||||
str := fn(item)
|
||||
return func(sbdr *strings.Builder) *strings.Builder {
|
||||
sbdr.WriteString(str)
|
||||
return sbdr
|
||||
}
|
||||
})
|
||||
ch := mapParallel(items, fn)
|
||||
|
||||
for v := range ch {
|
||||
sb.WriteString(v)
|
||||
}
|
||||
}
|
||||
|
|
23
packages/tui/internal/util/concurrency_test.go
Normal file
23
packages/tui/internal/util/concurrency_test.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package util_test
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sst/opencode/internal/util"
|
||||
)
|
||||
|
||||
func TestWriteStringsPar(t *testing.T) {
|
||||
items := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
sb := strings.Builder{}
|
||||
util.WriteStringsPar(&sb, items, func(i int) string {
|
||||
// sleep for the inverse duration so that later items finish first
|
||||
time.Sleep(time.Duration(10-i) * time.Millisecond)
|
||||
return strconv.Itoa(i)
|
||||
})
|
||||
if sb.String() != "0123456789" {
|
||||
t.Fatalf("expected 0123456789, got %s", sb.String())
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue