diff --git a/src/lib.rs b/src/lib.rs index f59af63..18cdc35 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -319,6 +319,35 @@ where TrieHard::U256(trie) => TrieIter::U256(trie.prefix_search(prefix)), } } + + /// Find the closest ancestor to the given key, where an ancestor is defined as the longest + /// string present in the trie that appears as a prefix of the given key. + /// + /// ``` + /// # use trie_hard::TrieHard; + /// let trie = ["dad", "ant", "and", "dot", "do"] + /// .into_iter() + /// .collect::>(); + /// + /// assert_eq!( + /// trie.ancestor("dada").map(|(_, v)| v), + /// Some("dad") + /// ); + /// assert_eq!( + /// trie.ancestor("an").map(|(_, v)| v), + /// None + /// ); + /// ``` + pub fn ancestor>(&self, key: K) -> Option<(&[u8], T)> { + match self { + TrieHard::U8(trie) => trie.ancestor(key), + TrieHard::U16(trie) => trie.ancestor(key), + TrieHard::U32(trie) => trie.ancestor(key), + TrieHard::U64(trie) => trie.ancestor(key), + TrieHard::U128(trie) => trie.ancestor(key), + TrieHard::U256(trie) => trie.ancestor(key), + } + } } /// Structure used for iterative over the contents of trie @@ -583,6 +612,71 @@ macro_rules! trie_impls { TrieIterSized::new(self, node_index) } + + /// Find the closest ancestor to the given key, where an ancestor is defined as the + /// longest string present in the trie that appears as a prefix of the given key. + /// + /// ``` + /// # use trie_hard::TrieHard; + /// let trie = ["dad", "ant", "and", "dot", "do"] + /// .into_iter() + /// .collect::>(); + /// + /// let TrieHard::U8(sized_trie) = trie else { + /// unreachable!() + /// }; + /// + /// assert_eq!( + /// sized_trie.ancestor("dada").map(|(_, v)| v), + /// Some("dad") + /// ); + /// assert_eq!( + /// sized_trie.ancestor("an").map(|(_, v)| v), + /// None + /// ); + /// ``` + pub fn ancestor>( + &self, + key: K, + ) -> Option<(&[u8], T)> { + self.ancestor_recurse(0, key.as_ref(), self.nodes.get(0)?) + } + + fn ancestor_recurse( + &self, + i: usize, + key: &[u8], + state: &TrieState<'a, T, $int_type>, + ) -> Option<(&[u8], T)> { + match state { + TrieState::Leaf(k, value) => { + ( + k.len() <= key.len() + && k[i..] == key[i..k.len()] + ).then_some((k, *value)) + } + TrieState::Search(search) => { + let c = key.get(i)?; + let next_state_index = search.evaluate(*c, self)?; + self.ancestor_recurse(i + 1, key, &self.nodes[next_state_index]) + } + TrieState::SearchOrLeaf(k, value, search) => { + // lambda to enable using `?` operator + let search = || { + let c = key.get(i)?; + let next_state_index = search.evaluate(*c, self)?; + self.ancestor_recurse(i + 1, key, &self.nodes[next_state_index]) + }; + + search().or_else(|| { + ( + k.len() <= key.len() + && k[i..] == key[i..k.len()] + ).then_some((k, *value)) + }) + } + } + } } impl<'a, T> TrieHardSized<'a, T, $int_type> where T: 'a + Copy { @@ -924,4 +1018,30 @@ mod tests { .collect::>(); assert_eq!(emitted, output); } + + #[rstest] + #[case(&[], "", None)] + #[case(&[""], "", Some(""))] + #[case(&["aaa", "a", ""], "", Some(""))] + #[case(&["aaa", "a", ""], "a", Some("a"))] + #[case(&["aaa", "a", ""], "aa", Some("a"))] + #[case(&["aaa", "a", ""], "aab", Some("a"))] + #[case(&["aaa", "a", ""], "aaa", Some("aaa"))] + #[case(&["aaa", "a", ""], "b", Some(""))] + #[case(&["dad", "ant", "and", "dot", "do"], "d", None)] + #[case(&["dad", "ant", "and", "dot", "do"], "dad", Some("dad"))] + #[case(&["dad", "ant", "and", "dot", "do"], "dada", Some("dad"))] + #[case(&["dad", "ant", "and", "dot", "do"], "do", Some("do"))] + #[case(&["dad", "ant", "and", "dot", "do"], "dot", Some("dot"))] + #[case(&["dad", "ant", "and", "dot", "do"], "dob", Some("do"))] + #[case(&["dad", "ant", "and", "dot", "do"], "doto", Some("dot"))] + fn test_ancestor( + #[case] input: &[&str], + #[case] key: &str, + #[case] output: Option<&str>, + ) { + let trie = input.iter().copied().collect::>(); + let emitted = trie.ancestor(key).map(|(_, v)| v); + assert_eq!(emitted, output); + } }