diff --git a/src/trie_search.c b/src/trie_search.c index ff6bbdeb..1808b7ec 100644 --- a/src/trie_search.c +++ b/src/trie_search.c @@ -163,19 +163,24 @@ phrase_array *trie_search(trie_t *self, char *text) { return phrases; } -int trie_node_search_tail_tokens(trie_t *self, trie_node_t node, tokenized_string_t *response, int tail_index, int token_index) { +int trie_node_search_tail_tokens(trie_t *self, trie_node_t node, char *str, token_array *tokens, int tail_index, int token_index) { int32_t data_index = -1*node.base; trie_data_node_t old_data_node = self->data->a[data_index]; uint32_t current_tail_pos = old_data_node.tail; - token_array *tokens = response->tokens; - unsigned char *tail_ptr = self->tail->a + current_tail_pos + tail_index; + + if (!(*tail_ptr)) { + log_debug("tail matches!\n", NULL); + return token_index-1; + } + log_debug("Searching tail: %s\n", tail_ptr); for (int i = token_index; i < tokens->n; i++) { token_t token = tokens->a[i]; - char *ptr = tokenized_string_get_token(response, i); - int token_len = token.len; + + char *ptr = str + token.offset; + int token_length = token.len; if (!(*tail_ptr)) { log_debug("tail matches!\n", NULL); @@ -186,10 +191,10 @@ int trie_node_search_tail_tokens(trie_t *self, trie_node_t node, tokenized_strin tail_ptr++; } - log_debug("Tail string compare: %s with %s\n", tail_ptr, ptr); + log_debug("Tail string compare: %s with %.*s\n", tail_ptr, (int)token_length, ptr); - if (strncmp((char *)tail_ptr, ptr, token_len) == 0) { - tail_ptr += token_len; + if (strncmp((char *)tail_ptr, ptr, token_length) == 0) { + tail_ptr += token_length; } else { return -1; } @@ -198,10 +203,8 @@ int trie_node_search_tail_tokens(trie_t *self, trie_node_t node, tokenized_strin } -phrase_array *trie_search_tokens(trie_t *self, tokenized_string_t *response) { - if (response == NULL || response->tokens->n == 0) return NULL; - - token_array *tokens = response->tokens; +phrase_array *trie_search_tokens(trie_t *self, char *str, token_array *tokens) { + if (str == NULL || tokens == NULL || tokens->n == 0) return NULL; phrase_array *phrases = phrase_array_new(); @@ -214,20 +217,31 @@ phrase_array *trie_search_tokens(trie_t *self, tokenized_string_t *response) { trie_search_state_t state = SEARCH_STATE_BEGIN, last_state = SEARCH_STATE_BEGIN; + token_t token; + size_t token_length, token_consumed; + log_debug("num_tokens: %zu\n", tokens->n); for (int i = 0; i < tokens->n; i++, last_state = state) { - char *ptr = tokenized_string_get_token(response, i); - log_debug("On %d, token=%s\n", i, ptr); + token = tokens->a[i]; + token_length = token.len; + + char *ptr = str + token.offset; + log_debug("On %d, token=%.*s\n", i, (int)token_length, ptr); - for (; *ptr; ptr++, last_node = node, last_node_id = node_id) { + for (int j = 0; j < token_length; j++, ptr++, last_node = node, last_node_id = node_id) { log_debug("Getting transition index for %d, (%d, %d)\n", node_id, node.base, node.check); - node_id = trie_get_transition_index(self, node, *ptr); - node = trie_get_node(self, node_id); - log_debug("Doing %c, got node_id=%d\n", *ptr, node_id); + if (j > 0 || last_node.base >= 0) { + node_id = trie_get_transition_index(self, node, *ptr); + node = trie_get_node(self, node_id); + log_debug("Doing %c, got node_id=%d\n", *ptr, node_id); + } else { + log_debug("Tail stored on space node, rolling back one character\n"); + ptr--; + } //if (last_node.check && last_node->tail) { node = last_node; node_id = last_node_id; } - if (node.check != last_node_id) { + if (node.check != last_node_id && last_node.base >= 0) { log_debug("Fell off trie. last_node_id=%d and node.check=%d\n", last_node_id, node.check); node = trie_get_root(self); node_id = ROOT_ID; @@ -242,16 +256,16 @@ phrase_array *trie_search_tokens(trie_t *self, tokenized_string_t *response) { unsigned char *current_tail = self->tail->a + current_tail_pos; - log_debug("next node tail: %s vs %s\n", current_tail, ptr + 1); - size_t ptr_len = strlen(ptr+1); + log_debug("next node tail: %s vs %.*s\n", current_tail, (int)token_length - j, ptr + 1); + size_t ptr_len = token_length - j; if (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN) { log_debug("phrase start at %d\n", i); phrase_start = i; } if (strncmp((char *)current_tail, ptr + 1, ptr_len) == 0) { - log_debug("node tail matches first token\n", NULL); - int tail_search_result = trie_node_search_tail_tokens(self, node, response, ptr_len, i+1); + log_debug("node tail matches first token\n"); + int tail_search_result = trie_node_search_tail_tokens(self, node, str, tokens, ptr_len, i+1); if (tail_search_result == -1) { node = trie_get_root(self); node_id = ROOT_ID; @@ -282,13 +296,13 @@ phrase_array *trie_search_tokens(trie_t *self, tokenized_string_t *response) { phrase_start = 0; continue; } else if (last_state == SEARCH_STATE_PARTIAL_MATCH) { - log_debug("last_state == SEARCH_STATE_PARTIAL_MATCH\n", NULL); + log_debug("last_state == SEARCH_STATE_PARTIAL_MATCH\n"); i = phrase_start; continue; } else { phrase_start = phrase_len = 0; // this token was not a phrase - log_debug("Plain token=%s\n", tokenized_string_get_token(response, i)); + log_debug("Plain token=%.*s\n", token.len, str + token.offset); } last_node = trie_get_root(self); last_node_id = ROOT_ID; diff --git a/src/trie_search.h b/src/trie_search.h index dc581b9e..72d8f534 100644 --- a/src/trie_search.h +++ b/src/trie_search.h @@ -24,7 +24,7 @@ typedef struct phrase { VECTOR_INIT(phrase_array, phrase_t) phrase_array *trie_search(trie_t *self, char *text); -phrase_array *trie_search_tokens(trie_t *self, tokenized_string_t *response); +phrase_array *trie_search_tokens(trie_t *self, char *str, token_array *tokens); phrase_t trie_search_suffixes(trie_t *self, char *word); phrase_t trie_search_prefixes(trie_t *self, char *word);