#include "crf.h"
#include "log/log.h"

/* Magic number written at the start of a serialized CRF and verified on read. */
#define CRF_SIGNATURE 0xCFCFCFCF

/*
 * Look up `feature` in the state-feature trie.
 *
 * On success the id stored in the trie is written to *feature_id.
 * Returns whatever trie_get_data returns (false means "not found" to the
 * callers in this file, which then skip the feature).
 */
static inline bool crf_get_feature_id(crf_t *self, char *feature, uint32_t *feature_id) {
    return trie_get_data(self->state_features, feature, feature_id);
}

/*
 * Look up `feature` in the state-transition-feature trie.
 * Same contract as crf_get_feature_id, but against self->state_trans_features.
 */
static inline bool crf_get_state_trans_feature_id(crf_t *self, char *feature, uint32_t *feature_id) {
    return trie_get_data(self->state_trans_features, feature, feature_id);
}

/*
 * Add one row of a sparse weight matrix into a dense score vector.
 *
 * The matrix is given as three parallel arrays: the entries of row `row`
 * live at positions indptr[row] .. indptr[row + 1] - 1, indices[col] is the
 * class (column) id of an entry and data[col] is its weight.
 */
static inline void crf_add_row_weights(const uint32_t *indptr, const uint32_t *indices, const double *data,
                                       uint32_t row, double *scores) {
    // Unsigned index: the row bounds are uint32_t, so a signed counter would
    // mix signedness in the comparison and could misbehave for large offsets.
    for (uint32_t col = indptr[row]; col < indptr[row + 1]; col++) {
        uint32_t class_id = indices[col];
        scores[class_id] += data[col];
    }
}

/*
 * Fill the CRF context with state and state-transition scores for every
 * token of `tokenized`.
 *
 * For each token index t, `features` and `prev_tag_features` are cleared,
 * feature_function(tagger, tagger_context, tokenized, t) is invoked, and the
 * strings then present in the two arrays are looked up in the model's tries.
 * Known features contribute their weights to the context; unknown features
 * are skipped silently.
 *
 * If print_features is true, the features of each token are printed to
 * stdout as "{ f1, f2, prev tag+g1 }".
 *
 * Returns false if a required argument is NULL, if the transition weights
 * cannot be copied into the context, or if feature_function fails.
 */
bool crf_tagger_score(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features) {
    if (self == NULL || feature_function == NULL || tokenized == NULL || tokenized->tokens == NULL ||
        features == NULL || prev_tag_features == NULL) {
        return false;
    }

    size_t num_tokens = tokenized->tokens->n;

    crf_context_t *crf_context = self->context;
    crf_context_set_num_items(crf_context, num_tokens);
    crf_context_reset(crf_context, CRF_CONTEXT_RESET_ALL);

    if (!double_matrix_copy(self->trans_weights, crf_context->trans)) {
        return false;
    }

    for (uint32_t t = 0; t < num_tokens; t++) {
        cstring_array_clear(features);
        cstring_array_clear(prev_tag_features);

        if (!feature_function(tagger, tagger_context, tokenized, t)) {
            log_error("Could not add address parser features\n");
            return false;
        }

        uint32_t fidx;
        char *feature;

        if (print_features) {
            printf("{ ");
            size_t num_features = cstring_array_num_strings(features);
            cstring_array_foreach(features, fidx, feature, {
                printf("%s", feature);
                if (fidx < num_features - 1) printf(", ");
            })
            size_t num_prev_tag_features = cstring_array_num_strings(prev_tag_features);
            // Only emit the separator when something was printed before it;
            // otherwise the output would start with a dangling ", ".
            if (num_features > 0 && num_prev_tag_features > 0) {
                printf(", ");
            }
            cstring_array_foreach(prev_tag_features, fidx, feature, {
                printf("prev tag+%s", feature);
                if (fidx < num_prev_tag_features - 1) printf(", ");
            })
            printf(" }\n");
        }

        uint32_t feature_id;

        // State scores for token t.
        double *state_scores = state_score(crf_context, t);

        uint32_t *indptr = self->weights->indptr->a;
        uint32_t *indices = self->weights->indices->a;
        double *data = self->weights->data->a;

        cstring_array_foreach(features, fidx, feature, {
            if (!crf_get_feature_id(self, feature, &feature_id)) {
                continue;
            }
            crf_add_row_weights(indptr, indices, data, feature_id, state_scores);
        })

        // State-transition scores for token t.
        double *state_trans_scores = state_trans_score_all(crf_context, t);

        indptr = self->state_trans_weights->indptr->a;
        indices = self->state_trans_weights->indices->a;
        data = self->state_trans_weights->data->a;

        cstring_array_foreach(prev_tag_features, fidx, feature, {
            if (!crf_get_state_trans_feature_id(self, feature, &feature_id)) {
                continue;
            }
            // Note: here there are L * L classes
            crf_add_row_weights(indptr, indices, data, feature_id, state_trans_scores);
        })
    }

    return true;
}

/*
 * Score `tokenized` with crf_tagger_score, then run Viterbi over the context.
 *
 * The decoded path is stored in self->viterbi (resized to one entry per
 * token) and the value returned by crf_context_viterbi is written to *score.
 *
 * Returns false if score is NULL, if scoring fails, or if the path buffer
 * cannot be resized.
 */
bool crf_tagger_score_viterbi(crf_t *self, void *tagger, void *tagger_context, cstring_array *features, cstring_array *prev_tag_features, tagger_feature_function feature_function, tokenized_string_t *tokenized, double *score, bool print_features) {
    if (score == NULL) {
        return false;
    }

    if (!crf_tagger_score(self, tagger, tagger_context, features, prev_tag_features, feature_function, tokenized, print_features)) {
        return false;
    }

    size_t num_tokens = tokenized->tokens->n;

    // The path buffer is handed to crf_context_viterbi right below, so a
    // failed resize must stop here instead of decoding into a short buffer.
    if (!uint32_array_resize_fixed(self->viterbi, num_tokens)) {
        log_error("Could not resize viterbi array\n");
        return false;
    }

    *score = crf_context_viterbi(self->context, self->viterbi->a);

    return true;
}

/*
 * Predict one label per token and append the label strings to `labels`.
 *
 * Runs crf_tagger_score_viterbi and maps every class id of the decoded path
 * to its string in self->classes.
 *
 * Returns false if labels is NULL or if scoring/decoding fails. An empty
 * decoded path is logged as an error, but the function still returns true
 * with nothing appended.
 */
bool crf_tagger_predict(crf_t *self, void *tagger, void *context, cstring_array *features, cstring_array *prev_tag_features, cstring_array *labels, tagger_feature_function feature_function, tokenized_string_t *tokenized, bool print_features) {
    double score;

    if (labels == NULL) return false;

    if (!crf_tagger_score_viterbi(self, tagger, context, features, prev_tag_features, feature_function, tokenized, &score, print_features)) {
        return false;
    }

    uint32_t *viterbi = self->viterbi->a;

    if (self->viterbi->n == 0) {
        log_error("self->viterbi->n == 0\n");
    }

    for (size_t i = 0; i < self->viterbi->n; i++) {
        char *predicted = cstring_array_get_string(self->classes, viterbi[i]);
        cstring_array_add_string(labels, predicted);
    }

    return true;
}

/*
 * Serialize the model to the open stream `f`.
 *
 * Layout, in order: signature, number of classes, byte length of the class
 * string buffer, the class string buffer, state-feature trie, state weights,
 * state-transition-feature trie, state-transition weights, transition
 * weights. crf_read consumes the same sequence.
 *
 * Returns false (after logging) if any required member is NULL or if any
 * write fails.
 */
bool crf_write(crf_t *self, FILE *f) {
    // Validate every member that is serialized below, so nothing NULL is
    // handed to a writer after part of the file has already been emitted.
    if (self == NULL || f == NULL || self->weights == NULL || self->classes == NULL ||
        self->state_features == NULL || self->state_trans_features == NULL ||
        self->state_trans_weights == NULL || self->trans_weights == NULL) {
        log_info("something was NULL\n");
        return false;
    }

    if (!file_write_uint32(f, CRF_SIGNATURE) ||
        !file_write_uint32(f, self->num_classes)) {
        log_info("error writing header\n");
        return false;
    }

    uint64_t classes_str_len = (uint64_t) cstring_array_used(self->classes);
    if (!file_write_uint64(f, classes_str_len)) {
        log_info("error writing classes_str_len\n");
        return false;
    }

    if (!file_write_chars(f, self->classes->str->a, classes_str_len)) {
        log_info("error writing chars\n");
        return false;
    }

    if (!trie_write(self->state_features, f)) {
        log_info("error state_features\n");
        return false;
    }

    if (!sparse_matrix_write(self->weights, f)) {
        log_info("error weights\n");
        return false;
    }

    if (!trie_write(self->state_trans_features, f)) {
        log_info("error state_trans_features\n");
        return false;
    }

    if (!sparse_matrix_write(self->state_trans_weights, f)) {
        log_info("error state_trans_weights\n");
        return false;
    }

    if (!double_matrix_write(self->trans_weights, f)) {
        log_info("error trans_weights\n");
        return false;
    }

    return true;
}

/*
 * Write the model to `filename` (opened in binary write mode).
 *
 * Returns true only if the file could be opened, crf_write succeeded and
 * the stream was closed without error.
 */
bool crf_save(crf_t *self, char *filename) {
    if (self == NULL || filename == NULL) {
        log_info("crf or filename was NULL\n");
        return false;
    }

    FILE *f = fopen(filename, "wb");
    if (f == NULL) return false;

    bool ret_val = crf_write(self, f);

    // fclose flushes buffered output; a failure here means the data may not
    // have reached the file, so it must not be reported as success.
    if (fclose(f) != 0) {
        log_info("error closing file\n");
        ret_val = false;
    }

    return ret_val;
}

/*
 * Deserialize a model from the open stream `f`, reading the fields in the
 * order crf_write emits them, then allocate the Viterbi path buffer and the
 * CRF context.
 *
 * Returns a newly allocated crf_t that the caller releases with
 * crf_destroy, or NULL if f is NULL, the signature does not match, the
 * class count is zero, or any read/allocation fails.
 */
crf_t *crf_read(FILE *f) {
    if (f == NULL) return NULL;

    uint32_t signature;

    if (!file_read_uint32(f, &signature) || signature != CRF_SIGNATURE) {
        return NULL;
    }

    // calloc zeroes every member, so crf_destroy can be called on a
    // partially constructed object from the error path below.
    crf_t *crf = calloc(1, sizeof(crf_t));
    if (crf == NULL) return NULL;

    if (!file_read_uint32(f, &crf->num_classes) || crf->num_classes == 0) {
        free(crf);
        return NULL;
    }

    uint64_t classes_str_len;

    if (!file_read_uint64(f, &classes_str_len)) {
        goto exit_crf_created;
    }

    char_array *array = char_array_new_size(classes_str_len);
    if (array == NULL) {
        goto exit_crf_created;
    }

    if (!file_read_chars(f, array->a, classes_str_len)) {
        char_array_destroy(array);
        goto exit_crf_created;
    }

    array->n = classes_str_len;

    // NOTE(review): if cstring_array_from_char_array fails, it is not visible
    // from here who owns `array` - confirm it is not leaked on this path.
    crf->classes = cstring_array_from_char_array(array);
    if (crf->classes == NULL) {
        goto exit_crf_created;
    }

    crf->state_features = trie_read(f);
    if (crf->state_features == NULL) {
        goto exit_crf_created;
    }

    crf->weights = sparse_matrix_read(f);
    if (crf->weights == NULL) {
        goto exit_crf_created;
    }

    crf->state_trans_features = trie_read(f);
    if (crf->state_trans_features == NULL) {
        goto exit_crf_created;
    }

    crf->state_trans_weights = sparse_matrix_read(f);
    if (crf->state_trans_weights == NULL) {
        goto exit_crf_created;
    }

    crf->trans_weights = double_matrix_read(f);
    if (crf->trans_weights == NULL) {
        goto exit_crf_created;
    }

    crf->viterbi = uint32_array_new();
    if (crf->viterbi == NULL) {
        goto exit_crf_created;
    }

    crf->context = crf_context_new(CRF_CONTEXT_VITERBI | CRF_CONTEXT_MARGINALS, crf->num_classes, CRF_CONTEXT_DEFAULT_NUM_ITEMS);
    if (crf->context == NULL) {
        goto exit_crf_created;
    }

    return crf;

exit_crf_created:
    crf_destroy(crf);
    return NULL;
}

/*
 * Open `filename` in binary read mode and deserialize a model from it.
 * Returns NULL if filename is NULL, the file cannot be opened, or crf_read
 * fails; otherwise the model returned by crf_read.
 */
crf_t *crf_load(char *filename) {
    if (filename == NULL) return NULL;

    FILE *f = fopen(filename, "rb");
    if (f == NULL) return NULL;

    crf_t *crf = crf_read(f);
    fclose(f);

    return crf;
}

/*
 * Release a model and every non-NULL member it holds. A NULL `self` is a
 * no-op, and members that are still NULL are skipped, so this is usable on
 * a partially constructed object.
 */
void crf_destroy(crf_t *self) {
    if (self == NULL) return;

    if (self->classes != NULL) {
        cstring_array_destroy(self->classes);
    }

    if (self->state_features != NULL) {
        trie_destroy(self->state_features);
    }

    if (self->weights != NULL) {
        sparse_matrix_destroy(self->weights);
    }

    if (self->state_trans_features != NULL) {
        trie_destroy(self->state_trans_features);
    }

    if (self->state_trans_weights != NULL) {
        sparse_matrix_destroy(self->state_trans_weights);
    }

    if (self->trans_weights != NULL) {
        double_matrix_destroy(self->trans_weights);
    }

    if (self->viterbi != NULL) {
        uint32_array_destroy(self->viterbi);
    }

    if (self->context != NULL) {
        crf_context_destroy(self->context);
    }

    free(self);
}