From 547dffd8ee7c8d190f3924356b9c9a1f1a0623f9 Mon Sep 17 00:00:00 2001 From: Arthur Lu Date: Thu, 24 Oct 2024 18:07:10 +0000 Subject: [PATCH] rewrite in go and compile to wasm --- .eslintrc.json | 43 ---- .gitignore | 5 +- Makefile | 17 ++ go.mod | 11 + main.go | 72 ++++++ package.json | 20 -- pkg/custom_slice.go | 97 ++++++++ pkg/types.go | 214 +++++++++++++++++ pkg/utils.go | 168 ++++++++++++++ pkg/wfa.go | 146 ++++++++++++ src/wfa.js | 356 ----------------------------- {tests => test}/sequences | 0 {tests => test}/test_affine_p0_sol | 0 {tests => test}/test_affine_p1_sol | 2 +- {tests => test}/test_affine_p2_sol | 2 +- {tests => test}/tests.json | 6 +- test/wfa_test.go | 78 +++++++ tests/test.js | 41 ---- 18 files changed, 810 insertions(+), 468 deletions(-) delete mode 100644 .eslintrc.json create mode 100644 Makefile create mode 100644 go.mod create mode 100644 main.go delete mode 100644 package.json create mode 100644 pkg/custom_slice.go create mode 100644 pkg/types.go create mode 100644 pkg/utils.go create mode 100644 pkg/wfa.go delete mode 100644 src/wfa.js rename {tests => test}/sequences (100%) rename {tests => test}/test_affine_p0_sol (100%) rename {tests => test}/test_affine_p1_sol (99%) rename {tests => test}/test_affine_p2_sol (99%) rename {tests => test}/tests.json (72%) create mode 100644 test/wfa_test.go delete mode 100644 tests/test.js diff --git a/.eslintrc.json b/.eslintrc.json deleted file mode 100644 index 32fdd82..0000000 --- a/.eslintrc.json +++ /dev/null @@ -1,43 +0,0 @@ -{ - "env": { - "es2021": true, - "node": true - }, - "extends": "standard", - "parserOptions": { - "ecmaVersion": "latest", - "sourceType": "module" - }, - "rules": { - "no-tabs": [ - "error", - { - "allowIndentationTabs": true - } - ], - "indent": [ - "error", - "tab" - ], - "linebreak-style": [ - "error", - "unix" - ], - "quotes": [ - "error", - "double" - ], - "semi": [ - "error", - "always" - ], - "brace-style": [ - "error", - "stroustrup", - { - "allowSingleLine": false - } - ], - "camelcase": 0 - } -} \ No newline at end of file diff --git a/.gitignore b/.gitignore index e4291ab..80271e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ -**/package-lock.json -**/node_modules -dist/* +go.sum +dist/* \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b4c5a27 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +.PHONY: build clean test + +build: clean + @echo "======================== Building Binary =======================" + GOOS=js GOARCH=wasm CGO_ENABLED=0 tinygo build -no-debug -opt=2 -target=wasm -o dist/wfa.wasm . + +clean: + @echo "======================== Cleaning Project ======================" + go clean + rm -f dist/wfa.wasm + +test: + @echo "======================== Running Tests =========================" + go test -v -cover -coverpkg=./pkg/ -coverprofile coverage ./test/ + @echo "======================= Coverage Report ========================" + go tool cover -func=coverage + @rm -f coverage \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0b1185a --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module wfa + +go 1.23.2 + +require ( + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/schollz/progressbar/v3 v3.16.1 // indirect + golang.org/x/sys v0.26.0 // indirect + golang.org/x/term v0.25.0 // indirect +) diff --git a/main.go b/main.go new file mode 100644 index 0000000..35cd236 --- /dev/null +++ b/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "fmt" + "syscall/js" + wfa "wfa/pkg" +) + +func main() { + c := make(chan bool) + js.Global().Set("wfAlign", js.FuncOf(wfAlign)) + <-c +} + +func wfAlign(this js.Value, args []js.Value) interface{} { + if len(args) != 4 { + fmt.Println("invalid number of args, requires 4: s1, s2, penalties, doCIGAR") + return nil + } + + if args[0].Type() != js.TypeString { + fmt.Println("s1 should be a string") + return nil + } + + s1 := args[0].String() + + if args[1].Type() != js.TypeString { + fmt.Println("s2 should be a string") + return nil + } + + s2 := args[1].String() + + if args[2].Type() != js.TypeObject { + fmt.Println("penalties should be a map with key values m, x, o, e") + return nil + } + + if args[2].Get("m").IsUndefined() || args[2].Get("x").IsUndefined() || args[2].Get("o").IsUndefined() || args[2].Get("e").IsUndefined() { + fmt.Println("penalties should be a map with key values m, x, o, e") + return nil + } + + m := args[2].Get("m").Int() + x := args[2].Get("x").Int() + o := args[2].Get("o").Int() + e := args[2].Get("e").Int() + + penalties := wfa.Penalty{ + M: m, + X: x, + O: o, + E: e, + } + + if args[3].Type() != js.TypeBoolean { + fmt.Println("doCIGAR should be a boolean") + return nil + } + + doCIGAR := args[3].Bool() + + // Call the actual func. + result := wfa.WFAlign(s1, s2, penalties, doCIGAR) + resultMap := map[string]interface{}{ + "score": result.Score, + "CIGAR": result.CIGAR, + } + + return js.ValueOf(resultMap) +} diff --git a/package.json b/package.json deleted file mode 100644 index f3627b9..0000000 --- a/package.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "name": "wfa-js", - "version": "1.0.0", - "description": "Wavefront alignment algorithm in JS", - "main": "tests/test.js", - "type": "module", - "devDependencies": { - "eslint": "^8.43.0", - "eslint-config-standard": "^17.1.0", - "eslint-plugin-import": "^2.27.5", - "eslint-plugin-n": "^16.0.1", - "eslint-plugin-promise": "^6.1.1", - "progress": "^2.0.3" - }, - "scripts": { - "test": "node tests/test.js", - "lint": "DEBUG=eslint:cli-engine eslint --fix src/*.js tests/*.js", - "minify": "sed -ze 's/\\t//g; s/\\/\\/[[:print:]]*//g;s/\\n//g;' src/wfa.js > dist/wfa.js" - } -} diff --git a/pkg/custom_slice.go b/pkg/custom_slice.go new file mode 100644 index 0000000..f213fe1 --- /dev/null +++ b/pkg/custom_slice.go @@ -0,0 +1,97 @@ +package wfa + +type IntegerSlice[T any] struct { + data []T + valid []bool + defaultValue T +} + +func (a *IntegerSlice[T]) TranslateIndex(idx int) int { + if idx >= 0 { // 0 -> 0, 1 -> 2, 2 -> 4, 3 -> 6, ... + return 2 * idx + } else { // -1 -> 1, -2 -> 3, -3 -> 5, ... + return (-2 * idx) - 1 + } +} + +func (a *IntegerSlice[T]) Valid(idx int) bool { + actualIdx := a.TranslateIndex(idx) + if actualIdx < len(a.valid) { // idx is in the slice + return a.valid[actualIdx] + } else { // idx is out of the slice + return false + } +} + +func (a *IntegerSlice[T]) Get(idx int) T { + actualIdx := a.TranslateIndex(idx) + if actualIdx < len(a.valid) { // idx is in the slice + return a.data[actualIdx] + } else { // idx is out of the slice + return a.defaultValue + } +} + +func (a *IntegerSlice[T]) Set(idx int, value T) { + actualIdx := a.TranslateIndex(idx) + if actualIdx >= len(a.valid) { // idx is outside the slice + // expand data array to actualIdx + newData := make([]T, actualIdx+1) + copy(newData, a.data) + a.data = newData + + // expand valid array to actualIdx + newValid := make([]bool, actualIdx+1) + copy(newValid, a.valid) + a.valid = newValid + } + + a.data[actualIdx] = value + a.valid[actualIdx] = true +} + +type PositiveSlice[T any] struct { + data []T + valid []bool + defaultValue T +} + +func (a *PositiveSlice[T]) TranslateIndex(idx int) int { + return idx +} + +func (a *PositiveSlice[T]) Valid(idx int) bool { + actualIdx := a.TranslateIndex(idx) + if actualIdx >= 0 && actualIdx < len(a.valid) { // idx is in the slice + return a.valid[actualIdx] + } else { // idx is out of the slice + return false + } +} + +func (a *PositiveSlice[T]) Get(idx int) T { + actualIdx := a.TranslateIndex(idx) + if actualIdx >= 0 && actualIdx < len(a.valid) { // idx is in the slice + return a.data[actualIdx] + } else { // idx is out of the slice + return a.defaultValue + } +} + +func (a *PositiveSlice[T]) Set(idx int, value T) { + actualIdx := a.TranslateIndex(idx) + if actualIdx < 0 || actualIdx >= len(a.valid) { // idx is outside the slice + // expand data array to actualIdx + newData := make([]T, actualIdx+1) + copy(newData, a.data) + a.data = newData + + // expand valid array to actualIdx + newValid := make([]bool, actualIdx+1) + copy(newValid, a.valid) + a.valid = newValid + } + + a.data[actualIdx] = value + a.valid[actualIdx] = true +} diff --git a/pkg/types.go b/pkg/types.go new file mode 100644 index 0000000..46ffbcb --- /dev/null +++ b/pkg/types.go @@ -0,0 +1,214 @@ +package wfa + +import ( + "fmt" + "math" +) + +type Penalty struct { + M int + X int + O int + E int +} + +type traceback byte + +const ( + OpenIns traceback = iota + ExtdIns + OpenDel + ExtdDel + Sub + Ins + Del + End +) + +type WavefrontComponent struct { + lo *PositiveSlice[int] // lo for each wavefront + hi *PositiveSlice[int] // hi for each wavefront + W *PositiveSlice[*IntegerSlice[int]] // wavefront diag distance for each wavefront + A *PositiveSlice[*IntegerSlice[traceback]] // compact CIGAR for backtrace for each wavefront +} + +func NewWavefrontComponent() WavefrontComponent { + // new wavefront component = { + // lo = [0] + // hi = [0] + // W = [] + // A = [] + // } + return WavefrontComponent{ + lo: &PositiveSlice[int]{ + data: []int{0}, + valid: []bool{true}, + }, + hi: &PositiveSlice[int]{ + data: []int{0}, + valid: []bool{true}, + }, + W: &PositiveSlice[*IntegerSlice[int]]{}, + A: &PositiveSlice[*IntegerSlice[traceback]]{}, + } +} + +// get value for wavefront=score, diag=k => returns ok, value +func (w *WavefrontComponent) GetVal(score int, k int) (bool, int) { + // if W[score][k] is valid + if w.W.Valid(score) && w.W.Get(score).Valid(k) { + // return W[score][k] + return true, w.W.Get(score).Get(k) + } else { + return false, 0 + } +} + +// set value for wavefront=score, diag=k +func (w *WavefrontComponent) SetVal(score int, k int, val int) { + // if W[score] is valid + if w.W.Valid(score) { + // W[score][k] = val + w.W.Get(score).Set(k, val) + } else { + // W[score] = [] + w.W.Set(score, &IntegerSlice[int]{}) + // W[score][k] = val + w.W.Get(score).Set(k, val) + } +} + +// get alignment traceback for wavefront=score, diag=k => returns ok, value +func (w *WavefrontComponent) GetTraceback(score int, k int) (bool, traceback) { + // if W[score][k] is valid + if w.A.Valid(score) && w.A.Get(score).Valid(k) { + // return W[score][k] + return true, w.A.Get(score).Get(k) + } else { + return false, 0 + } +} + +// set alignment traceback for wavefront=score, diag=k +func (w *WavefrontComponent) SetTraceback(score int, k int, val traceback) { + // if A[score] is valid + if w.A.Valid(score) { + // A[score][k] = val + w.A.Get(score).Set(k, val) + } else { + // W[score] = [] + w.A.Set(score, &IntegerSlice[traceback]{}) + // W[score][k] = val + w.A.Get(score).Set(k, val) + } +} + +// get hi for wavefront=score +func (w *WavefrontComponent) GetHi(score int) (bool, int) { + // if hi[score] is valid + if w.hi.Valid(score) { + // return hi[score] + return true, w.hi.Get(score) + } else { + return false, 0 + } +} + +// set hi for wavefront=score +func (w *WavefrontComponent) SetHi(score int, hi int) { + // hi[score] = hi + w.hi.Set(score, hi) +} + +// get lo for wavefront=score +func (w *WavefrontComponent) GetLo(score int) (bool, int) { + // if lo[score] is valid + if w.lo.Valid(score) { + // return lo[score] + return true, w.lo.Get(score) + } else { + return false, 0 + } +} + +// set hi for wavefront=score +func (w *WavefrontComponent) SetLo(score int, lo int) { + // lo[score] = lo + w.lo.Set(score, lo) +} + +type Result struct { + Score int + CIGAR string +} + +func (w *WavefrontComponent) String(score int) string { + traceback_str := []string{"OI", "EI", "OD", "ED", "SB", "IN", "DL", "EN"} + s := "<" + min_lo := math.MaxInt + max_hi := math.MinInt + + for i := 0; i <= score; i++ { + if w.lo.Valid(i) && w.lo.Get(i) < min_lo { + min_lo = w.lo.Get(i) + } + if w.hi.Valid(i) && w.hi.Get(i) > max_hi { + max_hi = w.hi.Get(i) + } + } + + for k := min_lo; k <= max_hi; k++ { + s = s + fmt.Sprintf("%02d", k) + if k < max_hi { + s = s + "|" + } + } + + s = s + ">\t<" + + for k := min_lo; k <= max_hi; k++ { + s = s + fmt.Sprintf("%02d", k) + if k < max_hi { + s = s + "|" + } + } + + s = s + ">\n" + + for i := 0; i <= score; i++ { + s = s + "[" + lo := w.lo.Get(i) + hi := w.hi.Get(i) + // print out wavefront matrix + for k := min_lo; k <= max_hi; k++ { + if w.W.Valid(i) && w.W.Get(i).Valid(k) { + s = s + fmt.Sprintf("%02d", w.W.Get(i).Get(k)) + } else if k < lo || k > hi { + s = s + "--" + } else { + s = s + " " + } + + if k < max_hi { + s = s + "|" + } + } + s = s + "]\t[" + // print out traceback matrix + for k := min_lo; k <= max_hi; k++ { + if w.A.Valid(i) && w.A.Get(i).Valid(k) { + s = s + traceback_str[w.A.Get(i).Get(k)] + } else if k < lo || k > hi { + s = s + "--" + } else { + s = s + " " + } + + if k < max_hi { + s = s + "|" + } + } + s = s + "]\n" + } + return s +} diff --git a/pkg/utils.go b/pkg/utils.go new file mode 100644 index 0000000..c95c63d --- /dev/null +++ b/pkg/utils.go @@ -0,0 +1,168 @@ +package wfa + +import ( + "math" + "unicode/utf8" +) + +func SafeMin(valids []bool, values []int) (bool, int) { + ok, idx := SafeArgMin(valids, values) + return ok, values[idx] +} + +func SafeMax(valids []bool, values []int) (bool, int) { + ok, idx := SafeArgMax(valids, values) + return ok, values[idx] +} + +func SafeArgMax(valids []bool, values []int) (bool, int) { + hasValid := false + maxIndex := 0 + maxValue := math.MinInt + for i := 0; i < len(valids); i++ { + if valids[i] && values[i] > maxValue { + hasValid = true + maxIndex = i + maxValue = values[i] + } + } + if hasValid { + return true, maxIndex + } else { + return false, 0 + } +} + +func SafeArgMin(valids []bool, values []int) (bool, int) { + hasValid := false + minIndex := 0 + minValue := math.MaxInt + for i := 0; i < len(valids); i++ { + if valids[i] && values[i] < minValue { + hasValid = true + minIndex = i + minValue = values[i] + } + } + if hasValid { + return true, minIndex + } else { + return false, 0 + } +} + +func Reverse(s string) string { + size := len(s) + buf := make([]byte, size) + for start := 0; start < size; { + r, n := utf8.DecodeRuneInString(s[start:]) + start += n + utf8.EncodeRune(buf[size-start:], r) + } + return string(buf) +} + +func Splice(s string, c rune, idx int) string { + return s[:idx] + string(c) + s[idx:] +} + +func NextLo(M WavefrontComponent, I WavefrontComponent, D WavefrontComponent, score int, penalties Penalty) int { + x := penalties.X + o := penalties.O + e := penalties.E + + a_ok, a := M.GetLo(score - x) + b_ok, b := M.GetLo(score - o - e) + c_ok, c := I.GetLo(score - e) + d_ok, d := D.GetLo(score - e) + + ok, lo := SafeMin( + []bool{a_ok, b_ok, c_ok, d_ok}, + []int{a, b, c, d}, + ) + lo-- + if ok { + M.SetLo(score, lo) + I.SetLo(score, lo) + D.SetLo(score, lo) + } + return lo +} + +func NextHi(M WavefrontComponent, I WavefrontComponent, D WavefrontComponent, score int, penalties Penalty) int { + x := penalties.X + o := penalties.O + e := penalties.E + + a_ok, a := M.GetHi(score - x) + b_ok, b := M.GetHi(score - o - e) + c_ok, c := I.GetHi(score - e) + d_ok, d := D.GetHi(score - e) + + ok, hi := SafeMax( + []bool{a_ok, b_ok, c_ok, d_ok}, + []int{a, b, c, d}, + ) + hi++ + if ok { + M.SetHi(score, hi) + I.SetHi(score, hi) + D.SetHi(score, hi) + } + return hi +} + +func NextI(M WavefrontComponent, I WavefrontComponent, score int, k int, penalties Penalty) { + o := penalties.O + e := penalties.E + + a_ok, a := M.GetVal(score-o-e, k-1) + b_ok, b := I.GetVal(score-e, k-1) + + ok, nextIVal := SafeMax([]bool{a_ok, b_ok}, []int{a, b}) + if ok { + I.SetVal(score, k, nextIVal+1) // important that the +1 is here + } + + ok, nextITraceback := SafeArgMax([]bool{a_ok, b_ok}, []int{a, b}) + if ok { + I.SetTraceback(score, k, []traceback{OpenIns, ExtdIns}[nextITraceback]) + } +} + +func NextD(M WavefrontComponent, D WavefrontComponent, score int, k int, penalties Penalty) { + o := penalties.O + e := penalties.E + + a_ok, a := M.GetVal(score-o-e, k+1) + b_ok, b := D.GetVal(score-e, k+1) + + ok, nextDVal := SafeMax([]bool{a_ok, b_ok}, []int{a, b}) + if ok { + D.SetVal(score, k, nextDVal) // nothing special + } + + ok, nextDTraceback := SafeArgMax([]bool{a_ok, b_ok}, []int{a, b}) + if ok { + D.SetTraceback(score, k, []traceback{OpenDel, ExtdDel}[nextDTraceback]) + } +} + +func NextM(M WavefrontComponent, I WavefrontComponent, D WavefrontComponent, score int, k int, penalties Penalty) { + x := penalties.X + + a_ok, a := M.GetVal(score-x, k) + a++ // important to have +1 here + b_ok, b := I.GetVal(score, k) + c_ok, c := D.GetVal(score, k) + + ok, nextMVal := SafeMax([]bool{a_ok, b_ok, c_ok}, []int{a, b, c}) + if ok { + M.SetVal(score, k, nextMVal) + } + + ok, nextMTraceback := SafeArgMax([]bool{a_ok, b_ok, c_ok}, []int{a, b, c}) + if ok { + M.SetTraceback(score, k, []traceback{Sub, Ins, Del}[nextMTraceback]) + } +} diff --git a/pkg/wfa.go b/pkg/wfa.go new file mode 100644 index 0000000..1e442d6 --- /dev/null +++ b/pkg/wfa.go @@ -0,0 +1,146 @@ +package wfa + +func WFAlign(s1 string, s2 string, penalties Penalty, doCIGAR bool) Result { + n := len(s1) + m := len(s2) + A_k := m - n + A_offset := m + score := 0 + M := NewWavefrontComponent() + M.SetVal(0, 0, 0) + M.SetHi(0, 0) + M.SetLo(0, 0) + M.SetTraceback(0, 0, End) + I := NewWavefrontComponent() + D := NewWavefrontComponent() + + for { + WFExtend(M, s1, n, s2, m, score) + ok, val := M.GetVal(score, A_k) + if ok && val >= A_offset { + break + } + score = score + 1 + WFNext(M, I, D, score, penalties) + } + + CIGAR := "" + if doCIGAR { + CIGAR = WFBacktrace(M, I, D, score, penalties, A_k, s1, s2) + } + + return Result{ + Score: score, + CIGAR: CIGAR, + } +} + +func WFExtend(M WavefrontComponent, s1 string, n int, s2 string, m int, score int) { + _, lo := M.GetLo(score) + _, hi := M.GetHi(score) + for k := lo; k <= hi; k++ { + // v = M[score][k] - k + // h = M[score][k] + ok, h := M.GetVal(score, k) + v := h - k + + // exit early if v or h are invalid + if !ok { + continue + } + for v < n && h < m && s1[v] == s2[h] { + _, val := M.GetVal(score, k) + M.SetVal(score, k, val+1) + v++ + h++ + } + } +} + +func WFNext(M WavefrontComponent, I WavefrontComponent, D WavefrontComponent, score int, penalties Penalty) { + // get this score's lo + lo := NextLo(M, I, D, score, penalties) + + // get this score's hi + hi := NextHi(M, I, D, score, penalties) + + for k := lo; k <= hi; k++ { + NextI(M, I, score, k, penalties) + NextD(M, D, score, k, penalties) + NextM(M, I, D, score, k, penalties) + } +} + +func WFBacktrace(M WavefrontComponent, I WavefrontComponent, D WavefrontComponent, score int, penalties Penalty, A_k int, s1 string, s2 string) string { + traceback_CIGAR := []string{"I", "I", "D", "D", "X", "", "", ""} + x := penalties.X + o := penalties.O + e := penalties.E + CIGAR_rev := "" + tb_s := score + tb_k := A_k + _, current_traceback := M.GetTraceback(tb_s, tb_k) + done := false + + for !done { + CIGAR_rev = CIGAR_rev + traceback_CIGAR[current_traceback] + switch current_traceback { + case OpenIns: + tb_s = tb_s - o - e + tb_k = tb_k - 1 + _, current_traceback = M.GetTraceback(tb_s, tb_k) + case ExtdIns: + tb_s = tb_s - e + tb_k = tb_k - 1 + _, current_traceback = I.GetTraceback(tb_s, tb_k) + case OpenDel: + tb_s = tb_s - o - e + tb_k = tb_k + 1 + _, current_traceback = M.GetTraceback(tb_s, tb_k) + case ExtdDel: + tb_s = tb_s - e + tb_k = tb_k + 1 + _, current_traceback = D.GetTraceback(tb_s, tb_k) + case Sub: + tb_s = tb_s - x + // tb_k = tb_k; + _, current_traceback = M.GetTraceback(tb_s, tb_k) + case Ins: + // tb_s = tb_s; + // tb_k = tb_k; + _, current_traceback = I.GetTraceback(tb_s, tb_k) + case Del: + // tb_s = tb_s; + // tb_k = tb_k; + _, current_traceback = D.GetTraceback(tb_s, tb_k) + case End: + done = true + } + } + + CIGAR_part := Reverse(CIGAR_rev) + c := 0 + i := 0 + j := 0 + for i < len(s1) && j < len(s2) { + if s1[i] == s2[j] { + //CIGAR_part.splice(c, 0, "M") + CIGAR_part = Splice(CIGAR_part, 'M', c) + c++ + i++ + j++ + } else if CIGAR_part[c] == 'X' { + c++ + i++ + j++ + } else if CIGAR_part[c] == 'I' { + c++ + j++ + } else if CIGAR_part[c] == 'D' { + c++ + i++ + } + } + + return CIGAR_part +} diff --git a/src/wfa.js b/src/wfa.js deleted file mode 100644 index 9483ff4..0000000 --- a/src/wfa.js +++ /dev/null @@ -1,356 +0,0 @@ -class WavefrontComponent { - constructor () { - this.lo = [0]; // lo for each wavefront - this.hi = [0]; // hi for each wavefront - this.W = []; // wavefront diag distance for each wavefront - this.A = []; // compact CIGAR for backtrace - } - - // get value for wavefront=score, diag=k - getVal (score, k) { - if (this.W[score] !== undefined && this.W[score][k] !== undefined) { - return this.W[score][k]; - } - else { - return NaN; - } - } - - // set value for wavefront=score, diag=k - setVal (score, k, val) { - if (this.W[score]) { - this.W[score][k] = val; - } - else { - this.W[score] = []; - this.W[score][k] = val; - } - } - - // get alignment traceback - getTraceback (score, k) { - if (this.A[score] !== undefined && this.A[score][k] !== undefined) { - return this.A[score][k]; - } - else { - return undefined; - } - } - - // set alignment traceback - setTraceback (score, k, traceback) { - if (this.A[score]) { - this.A[score][k] = traceback; - } - else { - this.A[score] = []; - this.A[score][k] = traceback; - } - } - - // get hi for wavefront=score - getHi (score) { - const hi = this.hi[score]; - return isNaN(hi) ? 0 : hi; - } - - // set hi for wavefront=score - setHi (score, hi) { - this.hi[score] = hi; - } - - // get lo for wavefront=score - getLo (score) { - const lo = this.lo[score]; - return isNaN(lo) ? 0 : lo; - } - - // set lo for wavefront=score - setLo (score, lo) { - this.lo[score] = lo; - } - - // string representation of all wavefronts - toString () { - const traceback_str = ["OI", "EI", "OD", "ED", "SB", "IN", "DL", "EN"]; - let s = "<"; - let min_lo = Infinity; - let max_hi = -Infinity; - // get the min lo and max hi values across all wavefronts - for (let i = 0; i < this.W.length; i++) { - const lo = this.lo[i]; - const hi = this.hi[i]; - if (lo < min_lo) { - min_lo = lo; - } - if (hi > max_hi) { - max_hi = hi; - } - } - // print out two headers, one for wavefront and one for traceback - for (let k = min_lo; k <= max_hi; k++) { - s += FormatNumberLength(k, 2); - if (k < max_hi) { - s += "|"; - } - } - s += ">\t<"; - for (let k = min_lo; k <= max_hi; k++) { - s += FormatNumberLength(k, 2); - if (k < max_hi) { - s += "|"; - } - } - s += ">\n"; - // for each wavefront - for (let i = 0; i < this.W.length; i++) { - s += "["; - const lo = this.lo[i]; - const hi = this.hi[i]; - // print out the wavefront matrix - for (let k = min_lo; k <= max_hi; k++) { - if (this.W[i] !== undefined && this.W[i][k] !== undefined && !isNaN(this.W[i][k])) { - s += FormatNumberLength(this.W[i][k], 2); - } - else if (k < lo || k > hi) { - s += "--"; - } - else { - s += " "; - } - if (k < max_hi) { - s += "|"; - } - } - s += "]\t["; - // print out the traceback matrix - for (let k = min_lo; k <= max_hi; k++) { - if (this.A[i] !== undefined && this.A[i][k] !== undefined) { - s += traceback_str[this.A[i][k].toString()]; - } - else if (k < lo || k > hi) { - s += "--"; - } - else { - s += " "; - } - if (k < max_hi) { - s += "|"; - } - } - s += "]\n"; - } - return s; - } -} - -const traceback = { - OpenIns: 0, - ExtdIns: 1, - OpenDel: 2, - ExtdDel: 3, - Sub: 4, - Ins: 5, - Del: 6, - End: 7 -}; - -function FormatNumberLength (num, length) { - let r = "" + num; - while (r.length < length) { - r = " " + r; - } - return r; -} - -function min (args) { - args.forEach((el, idx, arr) => { - arr[idx] = isNaN(el) ? Infinity : el; - }); - const min = Math.min.apply(Math, args); - return min === Infinity ? NaN : min; -} - -function max (args) { - args.forEach((el, idx, arr) => { - arr[idx] = isNaN(el) ? -Infinity : el; - }); - const max = Math.max.apply(Math, args); - return max === -Infinity ? NaN : max; -} - -function argmax (args) { - const val = max(args); - return args.indexOf(val); -} - -export default function wfAlign (s1, s2, penalties, doCIGAR = false) { - const n = s1.length; - const m = s2.length; - const A_k = m - n; - const A_offset = m; - let score = 0; - const M = new WavefrontComponent(); - M.setVal(0, 0, 0); - M.setHi(0, 0); - M.setLo(0, 0); - M.setTraceback(0, 0, traceback.End); - const I = new WavefrontComponent(); - const D = new WavefrontComponent(); - while (true) { - wfExtend(M, s1, n, s2, m, score); - if (M.getVal(score, A_k) >= A_offset) { - break; - } - score++; - wfNext(M, I, D, score, penalties); - } - let CIGAR = null; - if (doCIGAR) { - CIGAR = wfBacktrace(M, I, D, score, penalties, A_k, s1, s2); - } - return { score, CIGAR }; -} - -function wfExtend (M, s1, n, s2, m, score) { - const lo = M.getLo(score); - const hi = M.getHi(score); - for (let k = lo; k <= hi; k++) { - let v = M.getVal(score, k) - k; - let h = M.getVal(score, k); - if (isNaN(v) || isNaN(h)) { - continue; - } - while (s1[v] === s2[h]) { - M.setVal(score, k, M.getVal(score, k) + 1); - v++; - h++; - if (v > n || h > m) { - break; - } - } - } -} - -function wfNext (M, I, D, score, penalties, do_traceback) { - const x = penalties.x; - const o = penalties.o; - const e = penalties.e; - const lo = min([M.getLo(score - x), M.getLo(score - o - e), I.getLo(score - e), D.getLo(score - e)]) - 1; - const hi = max([M.getHi(score - x), M.getHi(score - o - e), I.getHi(score - e), D.getHi(score - e)]) + 1; - M.setHi(score, hi); - I.setHi(score, hi); - D.setHi(score, hi); - M.setLo(score, lo); - I.setLo(score, lo); - D.setLo(score, lo); - for (let k = lo; k <= hi; k++) { - I.setVal(score, k, max([ - M.getVal(score - o - e, k - 1), - I.getVal(score - e, k - 1) - ]) + 1); - I.setTraceback(score, k, [traceback.OpenIns, traceback.ExtdIns][argmax([ - M.getVal(score - o - e, k - 1), - I.getVal(score - e, k - 1) - ])]); - D.setVal(score, k, max([ - M.getVal(score - o - e, k + 1), - D.getVal(score - e, k + 1) - ])); - D.setTraceback(score, k, [traceback.OpenDel, traceback.ExtdDel][argmax([ - M.getVal(score - o - e, k + 1), - D.getVal(score - e, k + 1) - ])]); - M.setVal(score, k, max([ - M.getVal(score - x, k) + 1, - I.getVal(score, k), - D.getVal(score, k) - ])); - M.setTraceback(score, k, [traceback.Sub, traceback.Ins, traceback.Del][argmax([ - M.getVal(score - x, k) + 1, - I.getVal(score, k), - D.getVal(score, k) - ])]); - } -} - -function wfBacktrace (M, I, D, score, penalties, A_k, s1, s2) { - const traceback_CIGAR = ["I", "I", "D", "D", "X", "", "", ""]; - const x = penalties.x; - const o = penalties.o; - const e = penalties.e; - let CIGAR_rev = ""; // reversed CIGAR - let tb_s = score; // traceback score - let tb_k = A_k; // traceback diag k - let current_traceback = M.getTraceback(tb_s, tb_k); - let done = false; - while (!done) { - CIGAR_rev += traceback_CIGAR[current_traceback]; - switch (current_traceback) { - case traceback.OpenIns: - tb_s = tb_s - o - e; - tb_k = tb_k - 1; - current_traceback = M.getTraceback(tb_s, tb_k); - break; - case traceback.ExtdIns: - tb_s = tb_s - e; - tb_k = tb_k - 1; - current_traceback = I.getTraceback(tb_s, tb_k); - break; - case traceback.OpenDel: - tb_s = tb_s - o - e; - tb_k = tb_k + 1; - current_traceback = M.getTraceback(tb_s, tb_k); - break; - case traceback.ExtdDel: - tb_s = tb_s - e; - tb_k = tb_k + 1; - current_traceback = D.getTraceback(tb_s, tb_k); - break; - case traceback.Sub: - tb_s = tb_s - x; - // tb_k = tb_k; - current_traceback = M.getTraceback(tb_s, tb_k); - break; - case traceback.Ins: - // tb_s = tb_s; - // tb_k = tb_k; - current_traceback = I.getTraceback(tb_s, tb_k); - break; - case traceback.Del: - // tb_s = tb_s; - // tb_k = tb_k; - current_traceback = D.getTraceback(tb_s, tb_k); - break; - case traceback.End: - done = true; - break; - } - } - const CIGAR_part = Array.from(CIGAR_rev).reverse(); // still missing Match positions - let c = 0; - let i = 0; - let j = 0; - while (i < s1.length && j < s2.length) { // iterate through the strings to back-solve match positions - if (s1[i] === s2[j]) { // match, insert M and then increment c, i, j - CIGAR_part.splice(c, 0, "M"); - c++; - i++; - j++; - } - else if (CIGAR_part[c] === "X") { // mismatch, increment c, i, j - c++; - i++; - j++; - } - else if (CIGAR_part[c] === "I") { // insertion of character to s1 to reach s2, increment c,j - c++; - j++; - } - else if (CIGAR_part[c] === "D") { // deletion of character from s1 to reach s2, increment c,i - c++; - i++; - } - } - return CIGAR_part.join(""); -} diff --git a/tests/sequences b/test/sequences similarity index 100% rename from tests/sequences rename to test/sequences diff --git a/tests/test_affine_p0_sol b/test/test_affine_p0_sol similarity index 100% rename from tests/test_affine_p0_sol rename to test/test_affine_p0_sol diff --git a/tests/test_affine_p1_sol b/test/test_affine_p1_sol similarity index 99% rename from tests/test_affine_p1_sol rename to test/test_affine_p1_sol index 592251f..3a65d7c 100644 --- a/tests/test_affine_p1_sol +++ b/test/test_affine_p1_sol @@ -302,4 +302,4 @@ -154 2M1D1M2D2M1X3M3X1M2X4M1X1M2X2M1X1M2X1M1X5M1D2X4M1X1M2X3M1I3M1I8M1I2X3M1X1M1I1X3M2X1M1X1M3X1M1X1M1X1M2X2M1D2M1D1X1M2X2M -164 1I1M2X3M1I1M2X1M2X1M1X2M1X1M1I1X1M2X3M2I1M1X3M1I3M2X1M1I3M1X2M2I1M1X1M1X4M4I1M1X3M2I1M1X3M2X1M3X3M1I2M1X3M2D4M1X1M1X2M -160 1M1I1X1M2I1M1X1M2X2M1X2M1D6M2X1M1I1M1X2M1X1M1X4M1X2M1I4M2I2M1I1X1M1X5M2X4M1X1M2X1M3X2M1X1M1X1M2X1M1X3M2D1X3M1D3X8M3D --178 1I1M1X5M1I1M2X2M2I4M1X1M1X1M1I1X3M1I3X1M1I2M1X1M2X1M1X1M1X1M2I2X1M1I2M2X3M2X5M2I1X2M1X3M2I1M1I1M1I2M1I1X3M3I3M1X1M2I1X1M +-178 1I1M1X5M1I1M2X2M2I4M1X1M1X1M1I1X3M1I3X1M1I2M1X1M2X1M1X1M1X1M2I2X1M1I2M2X3M2X5M2I1X2M1X3M2I1M1I1M1I2M1I1X3M3I3M1X1M2I1X1M \ No newline at end of file diff --git a/tests/test_affine_p2_sol b/test/test_affine_p2_sol similarity index 99% rename from tests/test_affine_p2_sol rename to test/test_affine_p2_sol index 88cdaf0..62d8d47 100644 --- a/tests/test_affine_p2_sol +++ b/test/test_affine_p2_sol @@ -302,4 +302,4 @@ -184 2M1I4M3I4M5D1M2D4M1X1M2X2M1X1M1I2M1D1X5M2D1M1I4M2I4M1D5M1I8M2I1M1D3M1X1M2I4M2D2M2D1X1M1X2M1X1M1D1M1D2M1D4M5I1M1D2M -186 2D4M4D2M1D2M1I1X2M4D6M1I4M4I3M4I1M8I3M1X2M2I1M1X1M1X4M4I1M1X3M2I1M1X3M2D1X3M2I3M1I2M1X3M2D4M1X1M1X2M -171 1M2D2M2I1M4I4M1D2M1D6M3I1M1D2M1X3M1D1X4M1X2M1I4M2I2M1D2M2I5M2X4M2I2M1I3M1I1M2D7M1I3M2D2M1I2M1X1M5D8M3D --192 1I1M1X5M1I1M2X2M2I4M1X1M2I2M1D3M2D2M4I2M1X1M5I1M1I1X5M1D1X1M1X3M1I2M1X2M4I4M1D2X2M1X3M1X2M1I1X3M3I3M1X1M2I1X1M +-192 1I1M1X5M1I1M2X2M2I4M1X1M2I2M1D3M2D2M4I2M1X1M5I1M1I1X5M1D1X1M1X3M1I2M1X2M4I4M1D2X2M1X3M1X2M1I1X3M3I3M1X1M2I1X1M \ No newline at end of file diff --git a/tests/tests.json b/test/tests.json similarity index 72% rename from tests/tests.json rename to test/tests.json index fc5b943..95936a1 100644 --- a/tests/tests.json +++ b/test/tests.json @@ -6,7 +6,7 @@ "o": 2, "e": 1 }, - "solutions": "./tests/test_affine_p0_sol" + "solutions": "test_affine_p0_sol" }, "p1": { "penalties": { @@ -15,7 +15,7 @@ "o": 1, "e": 4 }, - "solutions": "./tests/test_affine_p1_sol" + "solutions": "test_affine_p1_sol" }, "p2": { "penalties": { @@ -24,6 +24,6 @@ "o": 3, "e": 2 }, - "solutions": "./tests/test_affine_p2_sol" + "solutions": "test_affine_p2_sol" } } \ No newline at end of file diff --git a/test/wfa_test.go b/test/wfa_test.go new file mode 100644 index 0000000..e8d8387 --- /dev/null +++ b/test/wfa_test.go @@ -0,0 +1,78 @@ +package tests + +import ( + "bufio" + "encoding/json" + "os" + "strconv" + "strings" + "testing" + wfa "wfa/pkg" + + "github.com/schollz/progressbar/v3" +) + +const testJsonPath = "tests.json" +const testSequences = "sequences" + +type TestPenalty struct { + M int `json:"m"` + X int `json:"x"` + O int `json:"o"` + E int `json:"e"` +} + +type TestCase struct { + Penalties TestPenalty `json:"penalties"` + Solutions string `json:"solutions"` +} + +func TestWFA(t *testing.T) { + content, _ := os.ReadFile(testJsonPath) + + var testMap map[string]TestCase + json.Unmarshal(content, &testMap) + + for k, v := range testMap { + testName := k + + testPenalties := wfa.Penalty{ + M: v.Penalties.M, + X: v.Penalties.X, + O: v.Penalties.O, + E: v.Penalties.E, + } + + sequencesFile, _ := os.Open(testSequences) + sequences := bufio.NewScanner(sequencesFile) + solutionsFile, _ := os.Open(v.Solutions) + solutions := bufio.NewScanner(solutionsFile) + + bar := progressbar.Default(305, k) + + idx := 0 + + for solutions.Scan() { + solution := solutions.Text() + expectedScore, _ := strconv.Atoi(strings.Split(solution, "\t")[0]) + + sequences.Scan() + s1 := sequences.Text() + s1 = s1[1:] + + sequences.Scan() + s2 := sequences.Text() + s2 = s2[1:] + + x := wfa.WFAlign(s1, s2, testPenalties, false) + gotScore := x.Score + + if gotScore != -1*expectedScore { + t.Errorf(`test: %s#%d, s1: %s, s2: %s, got: %d, expected: %d\n`, testName, idx, s1, s2, gotScore, expectedScore) + } + + idx++ + bar.Add(1) + } + } +} diff --git a/tests/test.js b/tests/test.js deleted file mode 100644 index 327de3e..0000000 --- a/tests/test.js +++ /dev/null @@ -1,41 +0,0 @@ -import wfAlign from "../src/wfa.js"; -import fs from "fs"; -import ProgressBar from "progress"; - -let data = fs.readFileSync("./tests/tests.json"); -data = JSON.parse(data); -const sequences = fs.readFileSync("./tests/sequences").toString().split("\n"); -// const total = sequences.length; -const total = 500; // skip the later tests because of memory usage -const timePerChar = []; - -for (const test_name of Object.keys(data)) { - const test = data[test_name]; - const penalties = test.penalties; - const solutions = fs.readFileSync(test.solutions).toString().split("\n"); - const bar = new ProgressBar(":bar :current/:total", { total: total / 2 }); - console.log(`test: ${test_name}`); - let correct = 0; - let j = 0; - for (let i = 0; i < total; i += 2) { - const s1 = sequences[i].replace(">"); - const s2 = sequences[i + 1].replace("<"); - const start = process.hrtime()[1]; - const { score } = wfAlign(s1, s2, penalties, false); - const elapsed = process.hrtime()[1] - start; - timePerChar.push((elapsed / 1e9) / (s1.length + s2.length)); - const solution_score = Number(solutions[j].split("\t")[0]); - if (solution_score === -score) { - correct += 1; - } - j += 1; - bar.tick(); - } - console.log(`correct: ${correct}\ntotal: ${total / 2}\n`); - console.log(`average time per character (ms): ${average(timePerChar) * 1000}`); -} - -function average (arr) { - const sum = arr.reduce((a, b) => a + b, 0); - return sum / arr.length; -}