package bastion_init import ( "fmt" "sort" ) type TrieNode struct { children [26]*TrieNode isEnd bool word string } type Trie struct { root *TrieNode } func NewTrie() *Trie { return &Trie{root: &TrieNode{}} } func (tn *TrieNode) Insert(word string) { node := tn for _, c := range word { c -= 'a' // Convert char to int index if node.children[c] == nil { node.children[c] = &TrieNode{} } node = node.children[c] } node.isEnd = true node.word = word } func (t *Trie) InsertAll(words []string) { for _, word := range words { t.root.Insert(word) } } func (tn *TrieNode) Find(prefix string) []string { node := tn result := make([]string, 0) for _, c := range prefix { c -= 'a' // Convert char to int index if node.children[c] != nil { node = node.children[c] } else { return result // No more matching nodes } } trieWalk(node, &result) return result } func trieWalk(node *TrieNode, result *[]string) { if node.isEnd { *result = append(*result, node.word) } for _, child := range node.children { if child != nil { trieWalk(child, result) } } } type Word struct { Word string Rank int } func SortWords(words []string) []Word { sortedWords := make([]Word, len(words)) for i, word := range words { sortedWords[i] = Word{Word: word, Rank: i} } sort.Slice(sortedWords, func(i, j int) bool { return sortedWords[i].Word < sortedWords[j].Word }) return sortedWords } func FindClosestWord(trie *Trie, prefix string) (string, error) { words := trie.root.Find(prefix) if len(words) == 0 { return "", fmt.Errorf("no words found for prefix: %s", prefix) } sortedWords := SortWords(words) minDistance := len(prefix) // Initialize with the maximum possible distance closestWord := sortedWords[0].Word for _, word := range sortedWords { distance := levDist(prefix, word.Word) if distance < minDistance { minDistance = distance closestWord = word.Word } } return closestWord, nil } func levDist(s1, s2 string) int { m, n := len(s1), len(s2) if m == 0 { return n } if n == 0 { return m } dp := make([][]int, m+1) for i := range dp { dp[i] = make([]int, n+1) } for i := 1; i <= m; i++ { for j := 1; j <= n; j++ { cost := 0 if s1[i-1] != s2[j-1] { cost = 1 } dp[i][j] = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + cost } } return dp[m][n] } func min(a, b, c int) int { if a < b && a < c { return a } else if b < a && b < c { return b } else { return c } }