diff --git a/go.mod b/go.mod index 0b1185a..0aa7e2b 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,14 @@ module wfa go 1.23.2 +require ( + github.com/schollz/progressbar/v3 v3.17.0 + golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c +) + require ( github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/schollz/progressbar/v3 v3.16.1 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect ) diff --git a/pkg/custom_slice.go b/pkg/custom_slice.go index ea65e12..958ef3e 100644 --- a/pkg/custom_slice.go +++ b/pkg/custom_slice.go @@ -1,63 +1,59 @@ package wfa -type IntegerSlice[T any] struct { - data []T - valid []bool - defaultValue T +import ( + "golang.org/x/exp/constraints" +) + +type Wavefront[T constraints.Integer] struct { // since wavefronts store diag distance, they should never be negative, and traceback data can be stored as uint8 + data []T + valid []bool + lo int } -func (a *IntegerSlice[T]) TranslateIndex(idx int) int { - if idx >= 0 { // 0 -> 0, 1 -> 2, 2 -> 4, 3 -> 6, ... - return 2 * idx - } else { // -1 -> 1, -2 -> 3, -3 -> 5, ... - return (-2 * idx) - 1 - } -} +func NewWavefront[T constraints.Integer](lo int, hi int) *Wavefront[T] { + a := &Wavefront[T]{} -func (a *IntegerSlice[T]) Valid(idx int) bool { - actualIdx := a.TranslateIndex(idx) - return 0 <= actualIdx && actualIdx < len(a.valid) && a.valid[actualIdx] -} + a.lo = lo + size := a.TranslateIndex(hi) -func (a *IntegerSlice[T]) Get(idx int) T { - actualIdx := a.TranslateIndex(idx) - if 0 <= actualIdx && actualIdx < len(a.valid) && a.valid[actualIdx] { // idx is in the slice - return a.data[actualIdx] - } else { // idx is out of the slice - return a.defaultValue - } -} - -func (a *IntegerSlice[T]) Set(idx int, value T) { - actualIdx := a.TranslateIndex(idx) - if actualIdx >= len(a.valid) { // idx is outside the slice - // expand data array to actualIdx - newData := make([]T, 2*actualIdx+1) - copy(newData, a.data) - a.data = newData - - // expand valid array to actualIdx - newValid := make([]bool, 2*actualIdx+1) - copy(newValid, a.valid) - a.valid = newValid - } - - a.data[actualIdx] = value - a.valid[actualIdx] = true -} - -func (a *IntegerSlice[T]) Preallocate(lo int, hi int) { - actualLo := a.TranslateIndex(lo) - actualHi := a.TranslateIndex(hi) - size := max(actualHi, actualLo) - - // expand data array to actualIdx newData := make([]T, size+1) a.data = newData - // expand valid array to actualIdx newValid := make([]bool, size+1) a.valid = newValid + + return a +} + +func (a *Wavefront[T]) TranslateIndex(idx int) int { + return idx - a.lo +} + +func (a *Wavefront[T]) Valid(idx int) bool { + actualIdx := a.TranslateIndex(idx) + return 0 <= actualIdx && actualIdx < len(a.data) && a.valid[actualIdx] +} + +func (a *Wavefront[T]) Get(idx int) T { + actualIdx := a.TranslateIndex(idx) + if 0 <= actualIdx && actualIdx < len(a.data) { // idx is in the slice + return a.data[actualIdx] + } else { // idx is out of the slice + return 0 + } +} + +func (a *Wavefront[T]) Set(idx int, value T) { + actualIdx := a.TranslateIndex(idx) + + /* in theory idx is always in bounds because the wavefront is preallocated + if actualIdx < 0 || actualIdx >= len(a.data) { + return + } + */ + + a.data[actualIdx] = value + a.valid[actualIdx] = true } type PositiveSlice[T any] struct { diff --git a/pkg/debug.go b/pkg/debug.go new file mode 100644 index 0000000..9b2f982 --- /dev/null +++ b/pkg/debug.go @@ -0,0 +1,79 @@ +//go:build debug + +package wfa + +import ( + "fmt" + "math" +) + +func (w *WavefrontComponent) String(score int) string { + traceback_str := []string{"OI", "EI", "OD", "ED", "SB", "IN", "DL", "EN"} + s := "<" + min_lo := math.MaxInt + max_hi := math.MinInt + + for i := 0; i <= score; i++ { + if w.lo.Valid(i) && w.lo.Get(i) < min_lo { + min_lo = w.lo.Get(i) + } + if w.hi.Valid(i) && w.hi.Get(i) > max_hi { + max_hi = w.hi.Get(i) + } + } + + for k := min_lo; k <= max_hi; k++ { + s = s + fmt.Sprintf("%02d", k) + if k < max_hi { + s = s + "|" + } + } + + s = s + ">\t<" + + for k := min_lo; k <= max_hi; k++ { + s = s + fmt.Sprintf("%02d", k) + if k < max_hi { + s = s + "|" + } + } + + s = s + ">\n" + + for i := 0; i <= score; i++ { + s = s + "[" + lo := w.lo.Get(i) + hi := w.hi.Get(i) + // print out wavefront matrix + for k := min_lo; k <= max_hi; k++ { + if w.W.Valid(i) && w.W.Get(i).Valid(k) { + s = s + fmt.Sprintf("%02d", w.W.Get(i).Get(k)) + } else if k < lo || k > hi { + s = s + "--" + } else { + s = s + " " + } + + if k < max_hi { + s = s + "|" + } + } + s = s + "]\t[" + // print out traceback matrix + for k := min_lo; k <= max_hi; k++ { + if w.A.Valid(i) && w.A.Get(i).Valid(k) { + s = s + traceback_str[w.A.Get(i).Get(k)] + } else if k < lo || k > hi { + s = s + "--" + } else { + s = s + " " + } + + if k < max_hi { + s = s + "|" + } + } + s = s + "]\n" + } + return s +} diff --git a/pkg/types.go b/pkg/types.go index ac605d1..9a74991 100644 --- a/pkg/types.go +++ b/pkg/types.go @@ -1,10 +1,5 @@ package wfa -import ( - "fmt" - "math" -) - type Result struct { Score int CIGAR string @@ -31,10 +26,10 @@ const ( ) type WavefrontComponent struct { - lo *PositiveSlice[int] // lo for each wavefront - hi *PositiveSlice[int] // hi for each wavefront - W *PositiveSlice[*IntegerSlice[int]] // wavefront diag distance for each wavefront - A *PositiveSlice[*IntegerSlice[traceback]] // compact CIGAR for backtrace for each wavefront + lo *PositiveSlice[int] // lo for each wavefront + hi *PositiveSlice[int] // hi for each wavefront + W *PositiveSlice[*Wavefront[int]] // wavefront diag distance for each wavefront + A *PositiveSlice[*Wavefront[traceback]] // compact CIGAR for backtrace for each wavefront } func NewWavefrontComponent(preallocateSize int) WavefrontComponent { @@ -53,16 +48,16 @@ func NewWavefrontComponent(preallocateSize int) WavefrontComponent { data: []int{0}, valid: []bool{true}, }, - W: &PositiveSlice[*IntegerSlice[int]]{ - defaultValue: &IntegerSlice[int]{ - data: []int{}, - valid: []bool{}, + W: &PositiveSlice[*Wavefront[int]]{ + defaultValue: &Wavefront[int]{ + data: []int{0}, + valid: []bool{false}, }, }, - A: &PositiveSlice[*IntegerSlice[traceback]]{ - defaultValue: &IntegerSlice[traceback]{ - data: []traceback{}, - valid: []bool{}, + A: &PositiveSlice[*Wavefront[traceback]]{ + defaultValue: &Wavefront[traceback]{ + data: []traceback{0}, + valid: []bool{false}, }, }, } @@ -114,81 +109,10 @@ func (w *WavefrontComponent) SetLoHi(score int, lo int, hi int) { w.hi.Set(score, hi) // preemptively setup w.A - w.A.Set(score, &IntegerSlice[traceback]{}) - w.A.Get(score).Preallocate(lo, hi) + a := NewWavefront[traceback](lo, hi) + w.A.Set(score, a) // preemptively setup w.W - w.W.Set(score, &IntegerSlice[int]{}) - w.W.Get(score).Preallocate(lo, hi) -} - -func (w *WavefrontComponent) String(score int) string { - traceback_str := []string{"OI", "EI", "OD", "ED", "SB", "IN", "DL", "EN"} - s := "<" - min_lo := math.MaxInt - max_hi := math.MinInt - - for i := 0; i <= score; i++ { - if w.lo.Valid(i) && w.lo.Get(i) < min_lo { - min_lo = w.lo.Get(i) - } - if w.hi.Valid(i) && w.hi.Get(i) > max_hi { - max_hi = w.hi.Get(i) - } - } - - for k := min_lo; k <= max_hi; k++ { - s = s + fmt.Sprintf("%02d", k) - if k < max_hi { - s = s + "|" - } - } - - s = s + ">\t<" - - for k := min_lo; k <= max_hi; k++ { - s = s + fmt.Sprintf("%02d", k) - if k < max_hi { - s = s + "|" - } - } - - s = s + ">\n" - - for i := 0; i <= score; i++ { - s = s + "[" - lo := w.lo.Get(i) - hi := w.hi.Get(i) - // print out wavefront matrix - for k := min_lo; k <= max_hi; k++ { - if w.W.Valid(i) && w.W.Get(i).Valid(k) { - s = s + fmt.Sprintf("%02d", w.W.Get(i).Get(k)) - } else if k < lo || k > hi { - s = s + "--" - } else { - s = s + " " - } - - if k < max_hi { - s = s + "|" - } - } - s = s + "]\t[" - // print out traceback matrix - for k := min_lo; k <= max_hi; k++ { - if w.A.Valid(i) && w.A.Get(i).Valid(k) { - s = s + traceback_str[w.A.Get(i).Get(k)] - } else if k < lo || k > hi { - s = s + "--" - } else { - s = s + " " - } - - if k < max_hi { - s = s + "|" - } - } - s = s + "]\n" - } - return s + b := NewWavefront[int](lo, hi) + w.W.Set(score, b) } diff --git a/test/wfa_test.go b/test/wfa_test.go index e8d8387..5d56445 100644 --- a/test/wfa_test.go +++ b/test/wfa_test.go @@ -64,7 +64,7 @@ func TestWFA(t *testing.T) { s2 := sequences.Text() s2 = s2[1:] - x := wfa.WFAlign(s1, s2, testPenalties, false) + x := wfa.WFAlign(s1, s2, testPenalties, true) gotScore := x.Score if gotScore != -1*expectedScore {