diff --git a/src/trie.c b/src/trie.c index 291df578..d7d2c1d0 100644 --- a/src/trie.c +++ b/src/trie.c @@ -679,25 +679,23 @@ bool trie_add_suffix(trie_t *self, char *key, uint32_t data) { return success; } -trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node, char *str) { - if (node.base >= 0) { - return NULL_DATA_NODE; - } - int32_t data_index = -1*node.base; - trie_data_node_t data_node = self->data->a[data_index]; - unsigned char *current_tail = self->tail->a + data_node.tail; +bool trie_compare_tail(trie_t *self, char *str, uint32_t tail_index) { + unsigned char *current_tail = self->tail->a + tail_index; size_t tail_len = strlen((char *)current_tail); char *query_tail = *str ? str + 1 : str; size_t query_tail_len = strlen(query_tail); - int tail_match = strncmp((char *)current_tail, query_tail, query_tail_len); + return strncmp((char *)current_tail, query_tail, query_tail_len) == 0; +} - if (tail_match == 0) { - return data_node; - } else { +inline trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node) { + if (node.base >= 0) { return NULL_DATA_NODE; } + int32_t data_index = -1*node.base; + trie_data_node_t data_node = self->data->a[data_index]; + return data_node; } uint32_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_t i) { @@ -720,6 +718,10 @@ uint32_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_ if (node.check != node_id) { return NULL_NODE_ID; } + + if (node.base < 0) { + return next_id; + } } return next_id; @@ -756,9 +758,9 @@ uint32_t trie_get_from_index(trie_t *self, char *word, size_t len, uint32_t i) { } if (node.check == node_id && node.base < 0) { - trie_data_node_t data_node = trie_get_data_node(self, node, (char *)ptr); + trie_data_node_t data_node = trie_get_data_node(self, node); - if (data_node.tail != 0) { + if (data_node.tail != 0 && trie_compare_tail(self, (char *)ptr, data_node.tail)) { return next_id; } else { return NULL_NODE_ID; diff --git a/src/trie.h b/src/trie.h index 0191b243..892d7950 100644 --- a/src/trie.h +++ b/src/trie.h @@ -85,7 +85,8 @@ void trie_set_check(trie_t *self, uint32_t index, int32_t check); trie_node_t trie_get_root(trie_t *self); trie_node_t trie_get_free_list(trie_t *self); -trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node, char *str); +trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node); +bool trie_tail_match(trie_t *self, char *str, uint32_t tail_index); uint32_t trie_add_transition(trie_t *self, uint32_t node_id, unsigned char c);