From 0558475a50d61c5dc4aad3d549de3b272dbb9316 Mon Sep 17 00:00:00 2001 From: Al Date: Sun, 10 Jan 2016 01:20:17 -0500 Subject: [PATCH] [language_classifier] Language classifier structs, I/O and API --- src/language_classifier.c | 279 ++++++++++++++++++++++++++++++++++++++ src/language_classifier.h | 55 ++++++++ 2 files changed, 334 insertions(+) create mode 100644 src/language_classifier.c create mode 100644 src/language_classifier.h diff --git a/src/language_classifier.c b/src/language_classifier.c new file mode 100644 index 00000000..af1284e8 --- /dev/null +++ b/src/language_classifier.c @@ -0,0 +1,279 @@ +#include "language_classifier.h" + +#include + +#include "language_features.h" +#include "minibatch.h" +#include "normalize.h" +#include "token_types.h" +#include "unicode_scripts.h" + +#define LANGUAGE_CLASSIFIER_SIGNATURE 0xCCCCCCCC + +#define MIN_PROB (0.05 - DBL_EPSILON) + +static language_classifier_t *language_classifier = NULL; +static language_classifier_t *language_classifier_country = NULL; + +void language_classifier_destroy(language_classifier_t *self) { + if (self == NULL) return; + + if (self->features != NULL) { + trie_destroy(self->features); + } + + if (self->labels != NULL) { + cstring_array_destroy(self->labels); + } + + if (self->weights != NULL) { + matrix_destroy(self->weights); + } + + free(self); +} + +language_classifier_t *language_classifier_new(void) { + language_classifier_t *language_classifier = malloc(sizeof(language_classifier_t)); + return language_classifier; +} + +language_classifier_t *get_language_classifier(void) { + return language_classifier; +} + +language_classifier_t *get_language_classifier_country(void) { + return language_classifier_country; +} + +language_classifier_response_t *classify_languages(char *address, char *country) { + language_classifier_t *classifier = NULL; + + if (country == NULL) { + classifier = get_language_classifier(); + } else { + classifier = get_language_classifier_country(); + } + + if (classifier == NULL) { + log_error("classifier NULL\n"); + return NULL; + } + + char *normalized = language_classifier_normalize_string(address); + + token_array *tokens = token_array_new(); + char_array *feature_array = char_array_new(); + + khash_t(str_double) *feature_counts = extract_language_features(normalized, country, tokens, feature_array); + + sparse_matrix_t *x = feature_vector(classifier->features, feature_counts); + + size_t n = classifier->num_labels; + matrix_t *p_y = matrix_new_zeros(1, n); + + language_classifier_response_t *response = NULL; + if (logistic_regression_model_expectation(classifier->weights, x, p_y)) { + double *predictions = matrix_get_row(p_y, 0); + size_t *indices = double_array_argsort(predictions, n); + size_t num_languages = 0; + size_t i; + double prob; + + double min_prob = 1.0 / n; + if (min_prob < MIN_PROB) min_prob = MIN_PROB; + + for (i = 0; i < n; i++) { + size_t idx = indices[n - i - 1]; + prob = predictions[idx]; + + if (i == 0 || prob > min_prob) { + num_languages++; + } else { + break; + } + } + char **languages = malloc(sizeof(char *) * num_languages); + double *probs = malloc(sizeof(double) * num_languages); + + for (i = 0; i < num_languages; i++) { + size_t idx = indices[n - i - 1]; + char *lang = cstring_array_get_string(classifier->labels, (uint32_t)idx); + prob = predictions[idx]; + languages[i] = lang; + probs[i] = prob; + } + + free(indices); + + response = malloc(sizeof(language_classifier_response_t)); + response->num_languages = num_languages; + response->languages = languages; + response->probs = probs; + } + +exit_tokens_created: + token_array_destroy(tokens); + char_array_destroy(feature_array); + const char *key; + kh_foreach_key(feature_counts, key, { + free((char *)key); + }) + kh_destroy(str_double, feature_counts); + return response; + +} + +language_classifier_t *language_classifier_read(FILE *f) { + if (f == NULL) return NULL; + long save_pos = ftell(f); + + uint32_t signature; + + if (!file_read_uint32(f, &signature) || signature != LANGUAGE_CLASSIFIER_SIGNATURE) { + goto exit_file_read; + } + + language_classifier_t *classifier = language_classifier_new(); + if (classifier == NULL) { + goto exit_file_read; + } + + trie_t *features = trie_read(f); + if (features == NULL) { + goto exit_classifier_created; + } + classifier->features = features; + uint64_t num_features; + if (!file_read_uint64(f, &num_features)) { + goto exit_classifier_created; + } + classifier->num_features = (size_t)num_features; + + uint64_t labels_str_len; + + if (!file_read_uint64(f, &labels_str_len)) { + goto exit_classifier_created; + } + + char_array *array = char_array_new_size(labels_str_len); + + if (array == NULL) { + goto exit_classifier_created; + } + + if (!file_read_chars(f, array->a, labels_str_len)) { + char_array_destroy(array); + goto exit_classifier_created; + } + + array->n = labels_str_len; + + classifier->labels = cstring_array_from_char_array(array); + if (classifier->labels == NULL) { + goto exit_classifier_created; + } + classifier->num_labels = cstring_array_num_strings(classifier->labels); + + matrix_t *weights = matrix_read(f); + + if (weights == NULL) { + goto exit_classifier_created; + } + + classifier->weights = weights; + + return classifier; + +exit_classifier_created: + language_classifier_destroy(classifier); +exit_file_read: + fseek(f, save_pos, SEEK_SET); + return NULL; +} + + +language_classifier_t *language_classifier_load(char *path) { + FILE *f; + + f = fopen(path, "rb"); + if (!f) return NULL; + + language_classifier_t *classifier = language_classifier_read(f); + + fclose(f); + return classifier; +} + +bool language_classifier_write(language_classifier_t *self, FILE *f) { + if (f == NULL || self == NULL) return false; + + if (!file_write_uint32(f, LANGUAGE_CLASSIFIER_SIGNATURE) || + !trie_write(self->features, f) || + !file_write_uint64(f, self->num_features) || + !file_write_uint64(f, self->labels->str->n) || + !file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n) || + !matrix_write(self->weights, f)) { + return false; + } + + return true; +} + +bool language_classifier_save(language_classifier_t *self, char *path) { + if (self == NULL || path == NULL) return false; + + FILE *f = fopen(path, "wb"); + if (!f) return false; + + bool result = language_classifier_write(self, f); + fclose(f); + + return result; +} + +// Module setup/teardown + +bool language_classifier_module_setup(char *dir) { + if (language_classifier != NULL && language_classifier_country != NULL) { + return true; + } + + if (dir == NULL) { + dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR; + } + + char *classifier_path; + + char_array *path = char_array_new_size(strlen(dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_COUNTRY_FILENAME)); + if (language_classifier == NULL) { + char_array_cat_joined(path, PATH_SEPARATOR, true, 2, dir, LANGUAGE_CLASSIFIER_FILENAME); + classifier_path = char_array_get_string(path); + + language_classifier = language_classifier_load(classifier_path); + + } + + if (language_classifier_country == NULL) { + char_array_clear(path); + char_array_cat_joined(path, PATH_SEPARATOR, true, 2, dir, LANGUAGE_CLASSIFIER_COUNTRY_FILENAME); + classifier_path = char_array_get_string(path); + + language_classifier_country = language_classifier_load(classifier_path); + + } + + char_array_destroy(path); + return true; +} + +void language_classifier_module_teardown(void) { + if (language_classifier != NULL) { + language_classifier_destroy(language_classifier); + } + + if (language_classifier_country != NULL) { + language_classifier_destroy(language_classifier_country); + } +} + diff --git a/src/language_classifier.h b/src/language_classifier.h new file mode 100644 index 00000000..a3d51c89 --- /dev/null +++ b/src/language_classifier.h @@ -0,0 +1,55 @@ +#ifndef LANGUAGE_CLASSIFIER_H +#define LANGUAGE_CLASSIFIER_H + +#include +#include +#include +#include + +#include "collections.h" +#include "language_features.h" +#include "logistic_regression.h" +#include "tokens.h" +#include "string_utils.h" +#include "trie.h" + +#define LANGUAGE_CLASSIFIER_FILENAME "language_classifier.dat" +#define LANGUAGE_CLASSIFIER_COUNTRY_FILENAME "language_classifier_country.dat" + +typedef struct language_classifier { + size_t num_labels; + size_t num_features; + trie_t *features; + cstring_array *labels; + matrix_t *weights; +} language_classifier_t; + + +typedef struct language_classifier_response { + size_t num_languages; + char **languages; + double *probs; +} language_classifier_response_t; + +// General usage + +language_classifier_t *language_classifier_new(void); +language_classifier_t *get_language_classifier(void); +language_classifier_t *get_language_classifier_country(void); + +language_classifier_response_t *classify_languages(char *address, char *country); + +void language_classifier_destroy(language_classifier_t *self); + +// I/O methods + +language_classifier_t *language_classifier_load(char *path); +bool language_classifier_save(language_classifier_t *self, char *output_dir); + +// Module setup/teardown + +bool language_classifier_module_setup(char *dir); +void language_classifier_module_teardown(void); + + +#endif \ No newline at end of file