From 3da3ddf10c841d9abac8a51b118ab8260519f5dd Mon Sep 17 00:00:00 2001 From: Arthur Lu Date: Tue, 5 Nov 2024 18:28:28 +0000 Subject: [PATCH] major optimization by packing wavefront values --- pkg/custom_slice.go | 56 ------------------- pkg/debug.go | 12 ++-- pkg/types.go | 130 ++++++++++++++++++++++++++++++-------------- pkg/utils.go | 53 +++++++++--------- pkg/wfa.go | 30 +++++----- test/wfa_test.go | 19 +++++++ 6 files changed, 154 insertions(+), 146 deletions(-) diff --git a/pkg/custom_slice.go b/pkg/custom_slice.go index 958ef3e..5720372 100644 --- a/pkg/custom_slice.go +++ b/pkg/custom_slice.go @@ -1,61 +1,5 @@ package wfa -import ( - "golang.org/x/exp/constraints" -) - -type Wavefront[T constraints.Integer] struct { // since wavefronts store diag distance, they should never be negative, and traceback data can be stored as uint8 - data []T - valid []bool - lo int -} - -func NewWavefront[T constraints.Integer](lo int, hi int) *Wavefront[T] { - a := &Wavefront[T]{} - - a.lo = lo - size := a.TranslateIndex(hi) - - newData := make([]T, size+1) - a.data = newData - - newValid := make([]bool, size+1) - a.valid = newValid - - return a -} - -func (a *Wavefront[T]) TranslateIndex(idx int) int { - return idx - a.lo -} - -func (a *Wavefront[T]) Valid(idx int) bool { - actualIdx := a.TranslateIndex(idx) - return 0 <= actualIdx && actualIdx < len(a.data) && a.valid[actualIdx] -} - -func (a *Wavefront[T]) Get(idx int) T { - actualIdx := a.TranslateIndex(idx) - if 0 <= actualIdx && actualIdx < len(a.data) { // idx is in the slice - return a.data[actualIdx] - } else { // idx is out of the slice - return 0 - } -} - -func (a *Wavefront[T]) Set(idx int, value T) { - actualIdx := a.TranslateIndex(idx) - - /* in theory idx is always in bounds because the wavefront is preallocated - if actualIdx < 0 || actualIdx >= len(a.data) { - return - } - */ - - a.data[actualIdx] = value - a.valid[actualIdx] = true -} - type PositiveSlice[T any] struct { data []T valid []bool diff --git a/pkg/debug.go b/pkg/debug.go index 9b2f982..a3dd795 100644 --- a/pkg/debug.go +++ b/pkg/debug.go @@ -1,5 +1,3 @@ -//go:build debug - package wfa import ( @@ -46,8 +44,9 @@ func (w *WavefrontComponent) String(score int) string { hi := w.hi.Get(i) // print out wavefront matrix for k := min_lo; k <= max_hi; k++ { - if w.W.Valid(i) && w.W.Get(i).Valid(k) { - s = s + fmt.Sprintf("%02d", w.W.Get(i).Get(k)) + valid, val, _ := UnpackWavefrontValue(w.W.Get(i).Get(k)) + if valid { + s = s + fmt.Sprintf("%02d", val) } else if k < lo || k > hi { s = s + "--" } else { @@ -61,8 +60,9 @@ func (w *WavefrontComponent) String(score int) string { s = s + "]\t[" // print out traceback matrix for k := min_lo; k <= max_hi; k++ { - if w.A.Valid(i) && w.A.Get(i).Valid(k) { - s = s + traceback_str[w.A.Get(i).Get(k)] + valid, _, tb := UnpackWavefrontValue(w.W.Get(i).Get(k)) + if valid { + s = s + traceback_str[tb] } else if k < lo || k > hi { s = s + "--" } else { diff --git a/pkg/types.go b/pkg/types.go index 9a74991..e7733b5 100644 --- a/pkg/types.go +++ b/pkg/types.go @@ -12,10 +12,10 @@ type Penalty struct { E int } -type traceback byte +type Traceback byte const ( - OpenIns traceback = iota + OpenIns Traceback = iota ExtdIns OpenDel ExtdDel @@ -25,19 +25,87 @@ const ( End ) -type WavefrontComponent struct { - lo *PositiveSlice[int] // lo for each wavefront - hi *PositiveSlice[int] // hi for each wavefront - W *PositiveSlice[*Wavefront[int]] // wavefront diag distance for each wavefront - A *PositiveSlice[*Wavefront[traceback]] // compact CIGAR for backtrace for each wavefront +// bitpacked wavefront values with 1 valid bit, 3 traceback bits, and 28 bits for the diag distance +// technically this restricts to solutions within 268 million score but that should be sufficient for most cases +type WavefrontValue uint32 + +// TODO: add 64 bit packed value in case more than 268 million score is needed + +// PackWavefrontValue: packs a diag value and traceback into a WavefrontValue +func PackWavefrontValue(value uint32, traceback Traceback) WavefrontValue { + valueBM := value & 0x0FFF_FFFF + tracebackBM := uint32(traceback&0x0000_0007) << 28 + return WavefrontValue(0x8000_0000 | valueBM | tracebackBM) } +// UnpackWavefrontValue: opens a WavefrontValue into a valid bool, diag value and traceback +func UnpackWavefrontValue(wf WavefrontValue) (bool, uint32, Traceback) { + valueBM := uint32(wf & 0x0FFF_FFFF) + tracebackBM := uint8(wf & 0x7000_0000 >> 28) + validBM := wf&0x8000_0000 != 0 + return validBM, valueBM, Traceback(tracebackBM) +} + +// Wavefront: stores a single wavefront, stores wavefront's lo value and hi is naturally lo + len(data) +type Wavefront struct { // since wavefronts store diag distance, they should never be negative, and traceback data can be stored as uint8 + data []WavefrontValue + lo int +} + +// NewWavefront: returns a new wavefront with size accomodating lo and hi (inclusive) +func NewWavefront(lo int, hi int) *Wavefront { + a := &Wavefront{} + + a.lo = lo + size := a.TranslateIndex(hi) + + newData := make([]WavefrontValue, size+1) + a.data = newData + + return a +} + +// TranslateIndex: utility function for getting the data index given a diagonal +func (a *Wavefront) TranslateIndex(diagonal int) int { + return diagonal - a.lo +} + +// Get: returns WavefrontValue for given diagonal +func (a *Wavefront) Get(diagonal int) WavefrontValue { + actualIdx := a.TranslateIndex(diagonal) + if 0 <= actualIdx && actualIdx < len(a.data) { // idx is in the slice + return a.data[actualIdx] + } else { // idx is out of the slice + return 0 + } +} + +// Set: the diagonal to a WavefrontValue +func (a *Wavefront) Set(diagonal int, value WavefrontValue) { + actualIdx := a.TranslateIndex(diagonal) + + /* in theory idx is always in bounds because the wavefront is preallocated + if actualIdx < 0 || actualIdx >= len(a.data) { + return + } + */ + + a.data[actualIdx] = value +} + +// WavefrontComponent: each M/I/D wavefront matrix including the wavefront data, lo and hi +type WavefrontComponent struct { + lo *PositiveSlice[int] // lo for each wavefront + hi *PositiveSlice[int] // hi for each wavefront + W *PositiveSlice[*Wavefront] // wavefront diag distance and traceback for each wavefront +} + +// NewWavefrontComponent: returns initialized WavefrontComponent func NewWavefrontComponent(preallocateSize int) WavefrontComponent { // new wavefront component = { // lo = [0] // hi = [0] // W = [] - // A = [] // } w := WavefrontComponent{ lo: &PositiveSlice[int]{ @@ -48,16 +116,9 @@ func NewWavefrontComponent(preallocateSize int) WavefrontComponent { data: []int{0}, valid: []bool{true}, }, - W: &PositiveSlice[*Wavefront[int]]{ - defaultValue: &Wavefront[int]{ - data: []int{0}, - valid: []bool{false}, - }, - }, - A: &PositiveSlice[*Wavefront[traceback]]{ - defaultValue: &Wavefront[traceback]{ - data: []traceback{0}, - valid: []bool{false}, + W: &PositiveSlice[*Wavefront]{ + defaultValue: &Wavefront{ + data: []WavefrontValue{0}, }, }, } @@ -65,32 +126,21 @@ func NewWavefrontComponent(preallocateSize int) WavefrontComponent { w.lo.Preallocate(preallocateSize) w.hi.Preallocate(preallocateSize) w.W.Preallocate(preallocateSize) - w.A.Preallocate(preallocateSize) return w } -// get value for wavefront=score, diag=k => returns ok, value -func (w *WavefrontComponent) GetVal(score int, k int) (bool, int) { - return w.W.Valid(score) && w.W.Get(score).Valid(k), w.W.Get(score).Get(k) +// GetVal: get value for wavefront=score, diag=k => returns ok, value, traceback +func (w *WavefrontComponent) GetVal(score int, k int) (bool, uint32, Traceback) { + return UnpackWavefrontValue(w.W.Get(score).Get(k)) } -// set value for wavefront=score, diag=k -func (w *WavefrontComponent) SetVal(score int, k int, val int) { - w.W.Get(score).Set(k, val) +// SetVal: set value, traceback for wavefront=score, diag=k +func (w *WavefrontComponent) SetVal(score int, k int, val uint32, tb Traceback) { + w.W.Get(score).Set(k, PackWavefrontValue(val, tb)) } -// get alignment traceback for wavefront=score, diag=k => returns ok, value -func (w *WavefrontComponent) GetTraceback(score int, k int) (bool, traceback) { - return w.A.Valid(score) && w.A.Get(score).Valid(k), w.A.Get(score).Get(k) -} - -// set alignment traceback for wavefront=score, diag=k -func (w *WavefrontComponent) SetTraceback(score int, k int, val traceback) { - w.A.Get(score).Set(k, val) -} - -// get hi for wavefront=score +// GetLoHi: get lo and hi for wavefront=score func (w *WavefrontComponent) GetLoHi(score int) (bool, int, int) { // if lo[score] and hi[score] are valid if w.lo.Valid(score) && w.hi.Valid(score) { @@ -101,18 +151,14 @@ func (w *WavefrontComponent) GetLoHi(score int) (bool, int, int) { } } -// set hi for wavefront=score +// SetLoHi: set lo and hi for wavefront=score func (w *WavefrontComponent) SetLoHi(score int, lo int, hi int) { // lo[score] = lo w.lo.Set(score, lo) // hi[score] = hi w.hi.Set(score, hi) - // preemptively setup w.A - a := NewWavefront[traceback](lo, hi) - w.A.Set(score, a) - // preemptively setup w.W - b := NewWavefront[int](lo, hi) + b := NewWavefront(lo, hi) w.W.Set(score, b) } diff --git a/pkg/utils.go b/pkg/utils.go index 49b4c5e..a806159 100644 --- a/pkg/utils.go +++ b/pkg/utils.go @@ -3,25 +3,27 @@ package wfa import ( "math" "unicode/utf8" + + "golang.org/x/exp/constraints" ) -func SafeMin(values []int, idx int) int { +func SafeMin[T constraints.Integer](values []T, idx int) T { return values[idx] } -func SafeMax(values []int, idx int) int { +func SafeMax[T constraints.Integer](values []T, idx int) T { return values[idx] } -func SafeArgMax(valids []bool, values []int) (bool, int) { +func SafeArgMax[T constraints.Integer](valids []bool, values []T) (bool, int) { hasValid := false maxIndex := 0 maxValue := math.MinInt for i := 0; i < len(valids); i++ { - if valids[i] && values[i] > maxValue { + if valids[i] && int(values[i]) > maxValue { hasValid = true maxIndex = i - maxValue = values[i] + maxValue = int(values[i]) } } if hasValid { @@ -31,15 +33,15 @@ func SafeArgMax(valids []bool, values []int) (bool, int) { } } -func SafeArgMin(valids []bool, values []int) (bool, int) { +func SafeArgMin[T constraints.Integer](valids []bool, values []T) (bool, int) { hasValid := false minIndex := 0 minValue := math.MaxInt for i := 0; i < len(valids); i++ { - if valids[i] && values[i] < minValue { + if valids[i] && int(values[i]) < minValue { hasValid = true minIndex = i - minValue = values[i] + minValue = int(values[i]) } } if hasValid { @@ -98,14 +100,13 @@ func NextI(M WavefrontComponent, I WavefrontComponent, score int, k int, penalti o := penalties.O e := penalties.E - a_ok, a := M.GetVal(score-o-e, k-1) - b_ok, b := I.GetVal(score-e, k-1) + a_ok, a, _ := M.GetVal(score-o-e, k-1) + b_ok, b, _ := I.GetVal(score-e, k-1) - ok, nextITraceback := SafeArgMax([]bool{a_ok, b_ok}, []int{a, b}) - nextIVal := SafeMax([]int{a, b}, nextITraceback) + 1 // important that the +1 is here + ok, nextITraceback := SafeArgMax([]bool{a_ok, b_ok}, []uint32{a, b}) + nextIVal := SafeMax([]uint32{a, b}, nextITraceback) + 1 // important that the +1 is here if ok { - I.SetVal(score, k, nextIVal) - I.SetTraceback(score, k, []traceback{OpenIns, ExtdIns}[nextITraceback]) + I.SetVal(score, k, nextIVal, []Traceback{OpenIns, ExtdIns}[nextITraceback]) } } @@ -113,33 +114,31 @@ func NextD(M WavefrontComponent, D WavefrontComponent, score int, k int, penalti o := penalties.O e := penalties.E - a_ok, a := M.GetVal(score-o-e, k+1) - b_ok, b := D.GetVal(score-e, k+1) + a_ok, a, _ := M.GetVal(score-o-e, k+1) + b_ok, b, _ := D.GetVal(score-e, k+1) ok, nextDTraceback := SafeArgMax( []bool{a_ok, b_ok}, - []int{a, b}, + []uint32{a, b}, ) - nextDVal := SafeMax([]int{a, b}, nextDTraceback) // nothing special + nextDVal := SafeMax([]uint32{a, b}, nextDTraceback) // nothing special if ok { - D.SetVal(score, k, nextDVal) - D.SetTraceback(score, k, []traceback{OpenDel, ExtdDel}[nextDTraceback]) + D.SetVal(score, k, nextDVal, []Traceback{OpenDel, ExtdDel}[nextDTraceback]) } } func NextM(M WavefrontComponent, I WavefrontComponent, D WavefrontComponent, score int, k int, penalties Penalty) { x := penalties.X - a_ok, a := M.GetVal(score-x, k) + a_ok, a, _ := M.GetVal(score-x, k) a++ // important to have +1 here - b_ok, b := I.GetVal(score, k) - c_ok, c := D.GetVal(score, k) + b_ok, b, _ := I.GetVal(score, k) + c_ok, c, _ := D.GetVal(score, k) - ok, nextMTraceback := SafeArgMax([]bool{a_ok, b_ok, c_ok}, []int{a, b, c}) - nextMVal := SafeMax([]int{a, b, c}, nextMTraceback) + ok, nextMTraceback := SafeArgMax([]bool{a_ok, b_ok, c_ok}, []uint32{a, b, c}) + nextMVal := SafeMax([]uint32{a, b, c}, nextMTraceback) if ok { - M.SetVal(score, k, nextMVal) - M.SetTraceback(score, k, []traceback{Sub, Ins, Del}[nextMTraceback]) + M.SetVal(score, k, nextMVal, []Traceback{Sub, Ins, Del}[nextMTraceback]) } } diff --git a/pkg/wfa.go b/pkg/wfa.go index 876cf10..1bc52d5 100644 --- a/pkg/wfa.go +++ b/pkg/wfa.go @@ -4,19 +4,18 @@ func WFAlign(s1 string, s2 string, penalties Penalty, doCIGAR bool) Result { n := len(s1) m := len(s2) A_k := m - n - A_offset := m + A_offset := uint32(m) score := 0 estimatedScore := (max(n, m) * max(penalties.M, penalties.X, penalties.O, penalties.E)) / 4 M := NewWavefrontComponent(estimatedScore) M.SetLoHi(0, 0, 0) - M.SetVal(0, 0, 0) - M.SetTraceback(0, 0, End) + M.SetVal(0, 0, 0, End) I := NewWavefrontComponent(estimatedScore) D := NewWavefrontComponent(estimatedScore) for { WFExtend(M, s1, n, s2, m, score) - ok, val := M.GetVal(score, A_k) + ok, val, _ := M.GetVal(score, A_k) if ok && val >= A_offset { break } @@ -40,7 +39,8 @@ func WFExtend(M WavefrontComponent, s1 string, n int, s2 string, m int, score in for k := lo; k <= hi; k++ { // v = M[score][k] - k // h = M[score][k] - ok, h := M.GetVal(score, k) + ok, hu, _ := M.GetVal(score, k) + h := int(hu) v := h - k // exit early if v or h are invalid @@ -48,8 +48,8 @@ func WFExtend(M WavefrontComponent, s1 string, n int, s2 string, m int, score in continue } for v < n && h < m && s1[v] == s2[h] { - _, val := M.GetVal(score, k) - M.SetVal(score, k, val+1) + _, val, tb := M.GetVal(score, k) + M.SetVal(score, k, val+1, tb) v++ h++ } @@ -75,7 +75,7 @@ func WFBacktrace(M WavefrontComponent, I WavefrontComponent, D WavefrontComponen CIGAR_rev := "" tb_s := score tb_k := A_k - _, current_traceback := M.GetTraceback(tb_s, tb_k) + _, _, current_traceback := M.GetVal(tb_s, tb_k) done := false for !done { @@ -84,31 +84,31 @@ func WFBacktrace(M WavefrontComponent, I WavefrontComponent, D WavefrontComponen case OpenIns: tb_s = tb_s - o - e tb_k = tb_k - 1 - _, current_traceback = M.GetTraceback(tb_s, tb_k) + _, _, current_traceback = M.GetVal(tb_s, tb_k) case ExtdIns: tb_s = tb_s - e tb_k = tb_k - 1 - _, current_traceback = I.GetTraceback(tb_s, tb_k) + _, _, current_traceback = I.GetVal(tb_s, tb_k) case OpenDel: tb_s = tb_s - o - e tb_k = tb_k + 1 - _, current_traceback = M.GetTraceback(tb_s, tb_k) + _, _, current_traceback = M.GetVal(tb_s, tb_k) case ExtdDel: tb_s = tb_s - e tb_k = tb_k + 1 - _, current_traceback = D.GetTraceback(tb_s, tb_k) + _, _, current_traceback = D.GetVal(tb_s, tb_k) case Sub: tb_s = tb_s - x // tb_k = tb_k; - _, current_traceback = M.GetTraceback(tb_s, tb_k) + _, _, current_traceback = M.GetVal(tb_s, tb_k) case Ins: // tb_s = tb_s; // tb_k = tb_k; - _, current_traceback = I.GetTraceback(tb_s, tb_k) + _, _, current_traceback = I.GetVal(tb_s, tb_k) case Del: // tb_s = tb_s; // tb_k = tb_k; - _, current_traceback = D.GetTraceback(tb_s, tb_k) + _, _, current_traceback = D.GetVal(tb_s, tb_k) case End: done = true } diff --git a/test/wfa_test.go b/test/wfa_test.go index 5d56445..5107b66 100644 --- a/test/wfa_test.go +++ b/test/wfa_test.go @@ -3,6 +3,7 @@ package tests import ( "bufio" "encoding/json" + "math/rand/v2" "os" "strconv" "strings" @@ -27,6 +28,24 @@ type TestCase struct { Solutions string `json:"solutions"` } +func randRange(min, max int) uint32 { + return uint32(rand.IntN(max-min) + min) +} + +func TestWavefrontPacking(t *testing.T) { + for range 1000 { + val := randRange(0, 1000) + tb := wfa.Traceback(randRange(0, 7)) + v := wfa.PackWavefrontValue(val, tb) + + valid, gotVal, gotTB := wfa.UnpackWavefrontValue(v) + + if !valid || gotVal != val || gotTB != tb { + t.Errorf(`test WavefrontPack/Unpack, val: %d, tb: %d, packedval: %x, gotok: %t, gotval: %d, gottb: %d\n`, val, tb, v, valid, gotVal, gotTB) + } + } +} + func TestWFA(t *testing.T) { content, _ := os.ReadFile(testJsonPath)