diff --git a/src/trie.c b/src/trie.c index d7d2c1d0..ca32c05d 100644 --- a/src/trie.c +++ b/src/trie.c @@ -679,14 +679,11 @@ bool trie_add_suffix(trie_t *self, char *key, uint32_t data) { return success; } -bool trie_compare_tail(trie_t *self, char *str, uint32_t tail_index) { +bool trie_compare_tail(trie_t *self, char *str, size_t len, size_t tail_index) { + if (tail_index >= self->tail->n) return false; + unsigned char *current_tail = self->tail->a + tail_index; - - size_t tail_len = strlen((char *)current_tail); - char *query_tail = *str ? str + 1 : str; - size_t query_tail_len = strlen(query_tail); - - return strncmp((char *)current_tail, query_tail, query_tail_len) == 0; + return strncmp((char *)current_tail, str, len) == 0; } inline trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node) { @@ -698,42 +695,63 @@ inline trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node) { return data_node; } -uint32_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_t i) { - if (key == NULL) return NULL_NODE_ID; +uint32_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_t i, size_t *tail_pos) { + if (key == NULL) { + *tail_pos = 0; + return NULL_NODE_ID; + } unsigned char *ptr = (unsigned char *)key; uint32_t node_id = i; trie_node_t node = trie_get_node(self, i); - if (node.base == NULL_NODE_ID) return NULL_NODE_ID; + if (node.base == NULL_NODE_ID) { + *tail_pos = 0; + return NULL_NODE_ID; + } uint32_t next_id = NULL_NODE_ID; - // Include NUL-byte. It may be stored if this phrase is a prefix of a longer one + if (node.base >= 0) { + // Include NUL-byte. It may be stored if this phrase is a prefix of a longer one + for (int i = 0; i < len; i++, ptr++, node_id = next_id) { + next_id = trie_get_transition_index(self, node, *ptr); + node = trie_get_node(self, next_id); - for (int i = 0; i < len; i++, ptr++, node_id = next_id) { - next_id = trie_get_transition_index(self, node, *ptr); - node = trie_get_node(self, next_id); + if (node.check != node_id) { + return NULL_NODE_ID; + } - if (node.check != node_id) { + if (node.base < 0) break; + } + } + + if (node.base < 0) { + trie_data_node_t data_node = trie_get_data_node(self, node); + + char *query_tail = *ptr ? (char *)ptr + 1 : (char *)ptr; + size_t query_len = strlen(query_tail); + + if (data_node.tail != 0 && trie_compare_tail(self, query_tail, query_len, data_node.tail)) { + *tail_pos = query_len; + return next_id; + } else { + *tail_pos = 0; return NULL_NODE_ID; } - if (node.base < 0) { - return next_id; - } } return next_id; } -uint32_t trie_get_prefix_len(trie_t *self, char *key, size_t len) { - return trie_get_prefix_from_index(self, key, len, ROOT_NODE_ID); +uint32_t trie_get_prefix_len(trie_t *self, char *key, size_t len, size_t *tail_pos) { + return trie_get_prefix_from_index(self, key, len, ROOT_NODE_ID, tail_pos); } -uint32_t trie_get_prefix(trie_t *self, char *key) { - return trie_get_prefix_from_index(self, key, strlen(key), ROOT_NODE_ID); +uint32_t trie_get_prefix(trie_t *self, char *key, size_t *tail_pos) { + return trie_get_prefix_from_index(self, key, strlen(key), ROOT_NODE_ID, tail_pos); } uint32_t trie_get_from_index(trie_t *self, char *word, size_t len, uint32_t i) { @@ -760,7 +778,9 @@ uint32_t trie_get_from_index(trie_t *self, char *word, size_t len, uint32_t i) { if (node.check == node_id && node.base < 0) { trie_data_node_t data_node = trie_get_data_node(self, node); - if (data_node.tail != 0 && trie_compare_tail(self, (char *)ptr, data_node.tail)) { + char *query_tail = *ptr ? (char *) ptr + 1 : (char *) ptr; + + if (data_node.tail != 0 && trie_compare_tail(self, query_tail, strlen(query_tail) + 1, data_node.tail)) { return next_id; } else { return NULL_NODE_ID; diff --git a/src/trie.h b/src/trie.h index 892d7950..fea3f837 100644 --- a/src/trie.h +++ b/src/trie.h @@ -105,9 +105,9 @@ uint32_t trie_get_from_index(trie_t *self, char *word, size_t len, uint32_t i); uint32_t trie_get_len(trie_t *self, char *word, size_t len); uint32_t trie_get(trie_t *self, char *word); -uint32_t trie_get_prefix(trie_t *self, char *key); -uint32_t trie_get_prefix_len(trie_t *self, char *key, size_t len); -uint32_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_t i); +uint32_t trie_get_prefix(trie_t *self, char *key, size_t *tail_pos); +uint32_t trie_get_prefix_len(trie_t *self, char *key, size_t len, size_t *tail_pos); +uint32_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_t i, size_t *tail_pos); void trie_print(trie_t *self);